diff --git a/url/url.go b/url/url.go index cbd2bdc..85cb2f8 100644 --- a/url/url.go +++ b/url/url.go @@ -2,6 +2,7 @@ package urlutil import ( "bytes" + "net" "net/url" "strings" @@ -146,11 +147,7 @@ func (u *URL) UpdatePort(newport string) { if newport == "" { return } - if u.Port() != "" { - u.Host = strings.Replace(u.Host, u.Port(), newport, 1) - return - } - u.Host += ":" + newport + u.Host = net.JoinHostPort(u.Hostname(), newport) } // TrimPort if any diff --git a/url/url_test.go b/url/url_test.go index a60a053..0861361 100644 --- a/url/url_test.go +++ b/url/url_test.go @@ -82,6 +82,27 @@ func TestPortUpdate(t *testing.T) { require.Equalf(t, urlx.String(), expected, "expected %v but got %v", expected, urlx.String()) } +func TestPortUpdateWhenPortMatchesHost(t *testing.T) { + // UpdatePort uses strings.Replace on the whole Host, so the port substring + // can accidentally match a piece of the IP/hostname instead of the actual port. + testcases := []struct { + inputURL string + newport string + expected string + }{ + // last octet "80" matches before the real ":80" port + {"http://37.228.93.80:80/", "443", "http://37.228.93.80:443/"}, + {"http://37.228.93.443:80/", "8080", "http://37.228.93.443:8080/"}, + {"http://[2001:db8::80]:80/", "8080", "http://[2001:db8::80]:8080/"}, + } + for _, tc := range testcases { + urlx, err := Parse(tc.inputURL) + require.Nil(t, err) + urlx.UpdatePort(tc.newport) + require.Equalf(t, tc.expected, urlx.String(), "expected %v but got %v", tc.expected, urlx.String()) + } +} + func TestUpdateRelPath(t *testing.T) { // updates existing relative path with new one exURL := "https://scanme.sh/somepath/abc?key=true"