diff --git a/internal/core/internal_network.go b/internal/core/internal_network.go index 0a8d0c55..ee4cc538 100644 --- a/internal/core/internal_network.go +++ b/internal/core/internal_network.go @@ -86,6 +86,11 @@ func isInternalAddr(addr netip.Addr) bool { return false } + // Unmap IPv4-in-IPv6 (e.g. ::ffff:10.0.0.1) so all subsequent checks + // work uniformly. Without this, Is4() returns false for mapped addresses + // and the CGNAT prefix check would be skipped. + addr = addr.Unmap() + if addr.IsPrivate() || addr.IsLoopback() || addr.IsLinkLocalUnicast() || @@ -102,56 +107,67 @@ func isInternalAddr(addr netip.Addr) bool { return false } -func (s *Server) validateToolEndpoint(ctx context.Context, endpoint *url.URL) error { - // Skip validation if internal network access check is disabled +// validateToolEndpoint checks whether the endpoint is allowed by the +// internal-network policy. It returns the first resolved IP address (if DNS +// resolution was performed) so the caller can pin the connection and prevent +// DNS-rebinding attacks. When the host is already an IP literal or validation +// is disabled, resolvedIP will be empty. +func (s *Server) validateToolEndpoint(ctx context.Context, endpoint *url.URL) (resolvedIP string, err error) { if !s.internalNetEnabled { - return nil + return "", nil } if endpoint == nil { - return fmt.Errorf("tool endpoint is empty") + return "", fmt.Errorf("tool endpoint is empty") } host := endpoint.Hostname() if host == "" { - return fmt.Errorf("tool endpoint host is empty") + return "", fmt.Errorf("tool endpoint host is empty") } if s.internalNetACL.allowsHost(host) { - return nil + return "", nil } if addr, err := netip.ParseAddr(host); err == nil { if isInternalAddr(addr) && !s.internalNetACL.allowsAddr(addr) { - return fmt.Errorf("internal network access is disabled for tool endpoints") + return "", fmt.Errorf("internal network access is disabled for tool endpoints") } - return nil + // Host is an IP literal; no DNS involved, no rebinding risk. + return "", nil } lookupCtx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() addrs, err := net.DefaultResolver.LookupIPAddr(lookupCtx, host) if err != nil { - return fmt.Errorf("failed to resolve tool endpoint host for internal access check: %w", err) + return "", fmt.Errorf("failed to resolve tool endpoint host for internal access check: %w", err) } internalFound := false + var firstAddr string for _, addr := range addrs { ip, ok := netip.AddrFromSlice(addr.IP) if !ok { continue } + ip = ip.Unmap() + if firstAddr == "" { + firstAddr = ip.String() + } if isInternalAddr(ip) { internalFound = true if s.internalNetACL.allowsAddr(ip) { - return nil + return firstAddr, nil } } } if internalFound { - return fmt.Errorf("internal network access is disabled for tool endpoints") + return "", fmt.Errorf("internal network access is disabled for tool endpoints") } - return nil + // Return the resolved IP so the transport can pin the connection. + return firstAddr, nil } diff --git a/internal/core/internal_network_test.go b/internal/core/internal_network_test.go index 40e84e2c..38811fa8 100644 --- a/internal/core/internal_network_test.go +++ b/internal/core/internal_network_test.go @@ -2,6 +2,10 @@ package core import ( "context" + "fmt" + "net/http" + "net/http/httptest" + "net/netip" "net/url" "testing" @@ -12,7 +16,8 @@ func TestValidateToolEndpoint_InternalBlocked(t *testing.T) { s := &Server{internalNetEnabled: true} u, err := url.Parse("http://127.0.0.1:8080") assert.NoError(t, err) - assert.Error(t, s.validateToolEndpoint(context.Background(), u)) + _, verr := s.validateToolEndpoint(context.Background(), u) + assert.Error(t, verr) } func TestValidateToolEndpoint_InternalAllowlistedCIDR(t *testing.T) { @@ -22,7 +27,8 @@ func TestValidateToolEndpoint_InternalAllowlistedCIDR(t *testing.T) { s := &Server{internalNetEnabled: true, internalNetACL: allowlist} u, err := url.Parse("http://127.0.0.1:8080") assert.NoError(t, err) - assert.NoError(t, s.validateToolEndpoint(context.Background(), u)) + _, verr := s.validateToolEndpoint(context.Background(), u) + assert.NoError(t, verr) } func TestValidateToolEndpoint_InternalAllowlistedHost(t *testing.T) { @@ -32,14 +38,16 @@ func TestValidateToolEndpoint_InternalAllowlistedHost(t *testing.T) { s := &Server{internalNetEnabled: true, internalNetACL: allowlist} u, err := url.Parse("http://internal.local/health") assert.NoError(t, err) - assert.NoError(t, s.validateToolEndpoint(context.Background(), u)) + _, verr := s.validateToolEndpoint(context.Background(), u) + assert.NoError(t, verr) } func TestValidateToolEndpoint_PublicIPAllowed(t *testing.T) { s := &Server{internalNetEnabled: true} u, err := url.Parse("http://8.8.8.8") assert.NoError(t, err) - assert.NoError(t, s.validateToolEndpoint(context.Background(), u)) + _, verr := s.validateToolEndpoint(context.Background(), u) + assert.NoError(t, verr) } func TestValidateToolEndpoint_Disabled(t *testing.T) { @@ -47,5 +55,116 @@ func TestValidateToolEndpoint_Disabled(t *testing.T) { u, err := url.Parse("http://127.0.0.1:8080") assert.NoError(t, err) // When disabled, internal addresses should be allowed - assert.NoError(t, s.validateToolEndpoint(context.Background(), u)) + _, verr := s.validateToolEndpoint(context.Background(), u) + assert.NoError(t, verr) +} + +func TestIsInternalAddr_IPv4MappedIPv6(t *testing.T) { + // ::ffff:127.0.0.1 should be detected as loopback + mapped := netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 127, 0, 0, 1}) + assert.True(t, isInternalAddr(mapped), "IPv4-mapped loopback should be internal") + + // ::ffff:10.0.0.1 should be detected as private + mappedPrivate := netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 10, 0, 0, 1}) + assert.True(t, isInternalAddr(mappedPrivate), "IPv4-mapped RFC1918 should be internal") + + // ::ffff:100.64.0.1 (CGNAT) should be detected as internal + mappedCGNAT := netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 100, 64, 0, 1}) + assert.True(t, isInternalAddr(mappedCGNAT), "IPv4-mapped CGNAT should be internal") + + // ::ffff:8.8.8.8 should NOT be internal + mappedPublic := netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 8, 8, 8, 8}) + assert.False(t, isInternalAddr(mappedPublic), "IPv4-mapped public IP should not be internal") +} + +func TestCreateHTTPClient_RedirectToInternalBlocked(t *testing.T) { + // Simulate a public server that redirects to an internal address + redirectTarget := "http://169.254.169.254/latest/meta-data/" + publicServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, redirectTarget, http.StatusFound) + })) + defer publicServer.Close() + + s := &Server{internalNetEnabled: true} + cli, err := s.createHTTPClient(nil, "") + assert.NoError(t, err) + + req, _ := http.NewRequest("GET", publicServer.URL, nil) + _, err = cli.Do(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "redirect") + assert.Contains(t, err.Error(), "blocked") +} + +func TestCreateHTTPClient_RedirectToInternalAllowlisted(t *testing.T) { + // Internal target server + internalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "ok") + })) + defer internalServer.Close() + + // Public server that redirects to the internal target + publicServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, internalServer.URL+"/data", http.StatusFound) + })) + defer publicServer.Close() + + allowlist, _ := parseInternalNetworkAllowlist([]string{"127.0.0.0/8"}) + s := &Server{internalNetEnabled: true, internalNetACL: allowlist} + cli, err := s.createHTTPClient(nil, "") + assert.NoError(t, err) + + req, _ := http.NewRequest("GET", publicServer.URL, nil) + resp, err := cli.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() +} + +func TestCreateHTTPClient_ChainedRedirectBlocked(t *testing.T) { + // Simulates a multi-hop redirect: public -> public -> internal + internalTarget := "http://10.0.0.1/admin" + + hop2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, internalTarget, http.StatusFound) + })) + defer hop2.Close() + + hop1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, hop2.URL, http.StatusFound) + })) + defer hop1.Close() + + s := &Server{internalNetEnabled: true} + cli, err := s.createHTTPClient(nil, "") + assert.NoError(t, err) + + req, _ := http.NewRequest("GET", hop1.URL, nil) + _, err = cli.Do(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "blocked") +} + +func TestCreateHTTPClient_PinnedAddr(t *testing.T) { + // Verify that when a pinnedAddr is provided, the client connects to it + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "pinned") + })) + defer target.Close() + + // Extract host:port from the test server to use as pinnedAddr + u, _ := url.Parse(target.URL) + + s := &Server{} + cli, err := s.createHTTPClient(nil, u.Hostname()) + assert.NoError(t, err) + + // Request a URL with the same port but "127.0.0.1" host; + // the pinned addr should override the hostname resolution. + req, _ := http.NewRequest("GET", target.URL, nil) + resp, err := cli.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() } diff --git a/internal/core/tool.go b/internal/core/tool.go index 018a422a..8a4bb80b 100644 --- a/internal/core/tool.go +++ b/internal/core/tool.go @@ -272,10 +272,16 @@ func fillDefaultArgs(tool *config.ToolConfig, args map[string]any) { } } -// createHTTPClient creates an HTTP client with proxy support if configured -func createHTTPClient(tool *config.ToolConfig) (*http.Client, error) { +// createHTTPClient creates an HTTP client with proxy support if configured. +// It installs a CheckRedirect hook that re-validates every redirect target +// against the internal-network allowlist (preventing 302-based SSRF), and +// pins the initial connection to pinnedAddr (when non-empty) to prevent +// DNS-rebinding attacks. +func (s *Server) createHTTPClient(tool *config.ToolConfig, pinnedAddr string) (*http.Client, error) { + var baseTransport *http.Transport + if tool != nil && tool.Proxy != nil { - transport := &http.Transport{} + baseTransport = &http.Transport{} switch tool.Proxy.Type { case "http", "https": @@ -284,22 +290,47 @@ func createHTTPClient(tool *config.ToolConfig) (*http.Client, error) { if err != nil { return nil, fmt.Errorf("invalid %s proxy configuration: %w", tool.Proxy.Type, err) } - transport.Proxy = http.ProxyURL(proxyURL) + baseTransport.Proxy = http.ProxyURL(proxyURL) case "socks5": dialer, err := proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", tool.Proxy.Host, tool.Proxy.Port), nil, proxy.Direct) if err != nil { return nil, fmt.Errorf("failed to create SOCKS5 dialer: %w", err) } - transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + baseTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) } } + } else { + baseTransport = http.DefaultTransport.(*http.Transport).Clone() + } - return &http.Client{Transport: otelhttp.NewTransport(transport)}, nil + // Pin DNS: when validateToolEndpoint resolved a hostname to an IP, we + // force the transport to connect to that IP instead of re-resolving. + // This closes the DNS-rebinding TOCTOU window. + if pinnedAddr != "" && baseTransport.DialContext == nil { + stdDialer := &net.Dialer{} + baseTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + _, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + return stdDialer.DialContext(ctx, network, net.JoinHostPort(pinnedAddr, port)) + } } - return &http.Client{Transport: otelhttp.NewTransport(http.DefaultTransport)}, nil + return &http.Client{ + Transport: otelhttp.NewTransport(baseTransport), + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + if _, err := s.validateToolEndpoint(req.Context(), req.URL); err != nil { + return fmt.Errorf("redirect to %s blocked: %w", req.URL.Redacted(), err) + } + return nil + }, + }, nil } // executeHTTPTool executes a tool with the given arguments @@ -355,7 +386,8 @@ func (s *Server) executeHTTPTool(c *gin.Context, conn session.Connection, tool * return nil, err } - if err := s.validateToolEndpoint(ctx, req.URL); err != nil { + pinnedAddr, err := s.validateToolEndpoint(ctx, req.URL) + if err != nil { logger.Warn("blocked tool endpoint", zap.String("tool", tool.Name), zap.String("session_id", conn.Meta().ID), @@ -403,7 +435,7 @@ func (s *Server) executeHTTPTool(c *gin.Context, conn session.Connection, tool * processArguments(req, tool, args) // Execute request - cli, err := createHTTPClient(tool) + cli, err := s.createHTTPClient(tool, pinnedAddr) if err != nil { logger.Error("failed to create HTTP client", zap.String("tool", tool.Name), diff --git a/internal/core/tool_test.go b/internal/core/tool_test.go index bb197241..fccb412e 100644 --- a/internal/core/tool_test.go +++ b/internal/core/tool_test.go @@ -81,23 +81,25 @@ func TestFillDefaultArgs(t *testing.T) { } func TestCreateHTTPClient(t *testing.T) { + s := &Server{} + // default client when no proxy - cli, err := createHTTPClient(nil) + cli, err := s.createHTTPClient(nil, "") assert.NoError(t, err) assert.NotNil(t, cli) // http proxy - cli2, err := createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "http", Host: "127.0.0.1", Port: 8080}}) + cli2, err := s.createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "http", Host: "127.0.0.1", Port: 8080}}, "") assert.NoError(t, err) assert.NotNil(t, cli2) // socks5 proxy - cli3, err := createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "socks5", Host: "127.0.0.1", Port: 1080}}) + cli3, err := s.createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "socks5", Host: "127.0.0.1", Port: 1080}}, "") assert.NoError(t, err) assert.NotNil(t, cli3) // invalid proxy - _, err = createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "https", Host: "invalid host with space", Port: 1}}) + _, err = s.createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "https", Host: "invalid host with space", Port: 1}}, "") assert.Error(t, err) }