Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions internal/core/internal_network.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ func isInternalAddr(addr netip.Addr) bool {
return false
}

// Unmap IPv4-in-IPv6 (e.g. ::ffff:10.0.0.1) so all subsequent checks
// work uniformly. Without this, Is4() returns false for mapped addresses
// and the CGNAT prefix check would be skipped.
addr = addr.Unmap()

if addr.IsPrivate() ||
addr.IsLoopback() ||
addr.IsLinkLocalUnicast() ||
Expand All @@ -102,56 +107,67 @@ func isInternalAddr(addr netip.Addr) bool {
return false
}

func (s *Server) validateToolEndpoint(ctx context.Context, endpoint *url.URL) error {
// Skip validation if internal network access check is disabled
// validateToolEndpoint checks whether the endpoint is allowed by the
// internal-network policy. It returns the first resolved IP address (if DNS
// resolution was performed) so the caller can pin the connection and prevent
// DNS-rebinding attacks. When the host is already an IP literal or validation
// is disabled, resolvedIP will be empty.
func (s *Server) validateToolEndpoint(ctx context.Context, endpoint *url.URL) (resolvedIP string, err error) {
if !s.internalNetEnabled {
return nil
return "", nil
}

if endpoint == nil {
return fmt.Errorf("tool endpoint is empty")
return "", fmt.Errorf("tool endpoint is empty")
}

host := endpoint.Hostname()
if host == "" {
return fmt.Errorf("tool endpoint host is empty")
return "", fmt.Errorf("tool endpoint host is empty")
}

if s.internalNetACL.allowsHost(host) {
return nil
return "", nil
}

if addr, err := netip.ParseAddr(host); err == nil {
if isInternalAddr(addr) && !s.internalNetACL.allowsAddr(addr) {
return fmt.Errorf("internal network access is disabled for tool endpoints")
return "", fmt.Errorf("internal network access is disabled for tool endpoints")
}
return nil
// Host is an IP literal; no DNS involved, no rebinding risk.
return "", nil
}

lookupCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
addrs, err := net.DefaultResolver.LookupIPAddr(lookupCtx, host)
if err != nil {
return fmt.Errorf("failed to resolve tool endpoint host for internal access check: %w", err)
return "", fmt.Errorf("failed to resolve tool endpoint host for internal access check: %w", err)
}

internalFound := false
var firstAddr string
for _, addr := range addrs {
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
continue
}
ip = ip.Unmap()
if firstAddr == "" {
firstAddr = ip.String()
}
if isInternalAddr(ip) {
internalFound = true
if s.internalNetACL.allowsAddr(ip) {
return nil
return firstAddr, nil
}
}
}

if internalFound {
return fmt.Errorf("internal network access is disabled for tool endpoints")
return "", fmt.Errorf("internal network access is disabled for tool endpoints")
}

return nil
// Return the resolved IP so the transport can pin the connection.
return firstAddr, nil
}
129 changes: 124 additions & 5 deletions internal/core/internal_network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package core

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"testing"

Expand All @@ -12,7 +16,8 @@ func TestValidateToolEndpoint_InternalBlocked(t *testing.T) {
s := &Server{internalNetEnabled: true}
u, err := url.Parse("http://127.0.0.1:8080")
assert.NoError(t, err)
assert.Error(t, s.validateToolEndpoint(context.Background(), u))
_, verr := s.validateToolEndpoint(context.Background(), u)
assert.Error(t, verr)
}

func TestValidateToolEndpoint_InternalAllowlistedCIDR(t *testing.T) {
Expand All @@ -22,7 +27,8 @@ func TestValidateToolEndpoint_InternalAllowlistedCIDR(t *testing.T) {
s := &Server{internalNetEnabled: true, internalNetACL: allowlist}
u, err := url.Parse("http://127.0.0.1:8080")
assert.NoError(t, err)
assert.NoError(t, s.validateToolEndpoint(context.Background(), u))
_, verr := s.validateToolEndpoint(context.Background(), u)
assert.NoError(t, verr)
}

func TestValidateToolEndpoint_InternalAllowlistedHost(t *testing.T) {
Expand All @@ -32,20 +38,133 @@ func TestValidateToolEndpoint_InternalAllowlistedHost(t *testing.T) {
s := &Server{internalNetEnabled: true, internalNetACL: allowlist}
u, err := url.Parse("http://internal.local/health")
assert.NoError(t, err)
assert.NoError(t, s.validateToolEndpoint(context.Background(), u))
_, verr := s.validateToolEndpoint(context.Background(), u)
assert.NoError(t, verr)
}

func TestValidateToolEndpoint_PublicIPAllowed(t *testing.T) {
s := &Server{internalNetEnabled: true}
u, err := url.Parse("http://8.8.8.8")
assert.NoError(t, err)
assert.NoError(t, s.validateToolEndpoint(context.Background(), u))
_, verr := s.validateToolEndpoint(context.Background(), u)
assert.NoError(t, verr)
}

func TestValidateToolEndpoint_Disabled(t *testing.T) {
s := &Server{internalNetEnabled: false}
u, err := url.Parse("http://127.0.0.1:8080")
assert.NoError(t, err)
// When disabled, internal addresses should be allowed
assert.NoError(t, s.validateToolEndpoint(context.Background(), u))
_, verr := s.validateToolEndpoint(context.Background(), u)
assert.NoError(t, verr)
}

func TestIsInternalAddr_IPv4MappedIPv6(t *testing.T) {
// ::ffff:127.0.0.1 should be detected as loopback
mapped := netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 127, 0, 0, 1})
assert.True(t, isInternalAddr(mapped), "IPv4-mapped loopback should be internal")

// ::ffff:10.0.0.1 should be detected as private
mappedPrivate := netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 10, 0, 0, 1})
assert.True(t, isInternalAddr(mappedPrivate), "IPv4-mapped RFC1918 should be internal")

// ::ffff:100.64.0.1 (CGNAT) should be detected as internal
mappedCGNAT := netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 100, 64, 0, 1})
assert.True(t, isInternalAddr(mappedCGNAT), "IPv4-mapped CGNAT should be internal")

// ::ffff:8.8.8.8 should NOT be internal
mappedPublic := netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 8, 8, 8, 8})
assert.False(t, isInternalAddr(mappedPublic), "IPv4-mapped public IP should not be internal")
}

func TestCreateHTTPClient_RedirectToInternalBlocked(t *testing.T) {
// Simulate a public server that redirects to an internal address
redirectTarget := "http://169.254.169.254/latest/meta-data/"
publicServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, redirectTarget, http.StatusFound)
}))
defer publicServer.Close()

s := &Server{internalNetEnabled: true}
cli, err := s.createHTTPClient(nil, "")
assert.NoError(t, err)

req, _ := http.NewRequest("GET", publicServer.URL, nil)
_, err = cli.Do(req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "redirect")
assert.Contains(t, err.Error(), "blocked")
}

func TestCreateHTTPClient_RedirectToInternalAllowlisted(t *testing.T) {
// Internal target server
internalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "ok")
}))
defer internalServer.Close()

// Public server that redirects to the internal target
publicServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, internalServer.URL+"/data", http.StatusFound)
}))
defer publicServer.Close()

allowlist, _ := parseInternalNetworkAllowlist([]string{"127.0.0.0/8"})
s := &Server{internalNetEnabled: true, internalNetACL: allowlist}
cli, err := s.createHTTPClient(nil, "")
assert.NoError(t, err)

req, _ := http.NewRequest("GET", publicServer.URL, nil)
resp, err := cli.Do(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
resp.Body.Close()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (testing): Strengthen TestCreateHTTPClient_PinnedAddr to actually verify that the pinned address is used

The test comment promises to verify that the client actually connects to pinnedAddr, but the current assertions only check for a 200 response. Because httptest.NewServer already binds to 127.0.0.1, the test would still pass even if the pinned address were ignored.

To better match the stated intent, you could either:

  • Inject a custom net.Dialer (or wrapper around DialContext) that records the dialed addr and assert the host matches pinnedAddr, or
  • Update the test comment to say it only checks that a client created with pinnedAddr works, and add a more targeted unit test around the dial function to validate the pinning behavior.

That way, the test either truly enforces DNS pinning or accurately documents its limited scope.

}

func TestCreateHTTPClient_ChainedRedirectBlocked(t *testing.T) {
// Simulates a multi-hop redirect: public -> public -> internal
internalTarget := "http://10.0.0.1/admin"

hop2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, internalTarget, http.StatusFound)
}))
defer hop2.Close()

hop1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, hop2.URL, http.StatusFound)
}))
defer hop1.Close()

s := &Server{internalNetEnabled: true}
cli, err := s.createHTTPClient(nil, "")
assert.NoError(t, err)

req, _ := http.NewRequest("GET", hop1.URL, nil)
_, err = cli.Do(req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "blocked")
}

func TestCreateHTTPClient_PinnedAddr(t *testing.T) {
// Verify that when a pinnedAddr is provided, the client connects to it
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "pinned")
}))
defer target.Close()

// Extract host:port from the test server to use as pinnedAddr
u, _ := url.Parse(target.URL)

s := &Server{}
cli, err := s.createHTTPClient(nil, u.Hostname())
assert.NoError(t, err)

// Request a URL with the same port but "127.0.0.1" host;
// the pinned addr should override the hostname resolution.
req, _ := http.NewRequest("GET", target.URL, nil)
resp, err := cli.Do(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
resp.Body.Close()
}
50 changes: 41 additions & 9 deletions internal/core/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,16 @@ func fillDefaultArgs(tool *config.ToolConfig, args map[string]any) {
}
}

// createHTTPClient creates an HTTP client with proxy support if configured
func createHTTPClient(tool *config.ToolConfig) (*http.Client, error) {
// createHTTPClient creates an HTTP client with proxy support if configured.
// It installs a CheckRedirect hook that re-validates every redirect target
// against the internal-network allowlist (preventing 302-based SSRF), and
// pins the initial connection to pinnedAddr (when non-empty) to prevent
// DNS-rebinding attacks.
func (s *Server) createHTTPClient(tool *config.ToolConfig, pinnedAddr string) (*http.Client, error) {
var baseTransport *http.Transport

if tool != nil && tool.Proxy != nil {
transport := &http.Transport{}
baseTransport = &http.Transport{}

switch tool.Proxy.Type {
case "http", "https":
Expand All @@ -284,22 +290,47 @@ func createHTTPClient(tool *config.ToolConfig) (*http.Client, error) {
if err != nil {
return nil, fmt.Errorf("invalid %s proxy configuration: %w", tool.Proxy.Type, err)
}
transport.Proxy = http.ProxyURL(proxyURL)
baseTransport.Proxy = http.ProxyURL(proxyURL)

case "socks5":
dialer, err := proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", tool.Proxy.Host, tool.Proxy.Port), nil, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create SOCKS5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
baseTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
}
} else {
baseTransport = http.DefaultTransport.(*http.Transport).Clone()
}

return &http.Client{Transport: otelhttp.NewTransport(transport)}, nil
// Pin DNS: when validateToolEndpoint resolved a hostname to an IP, we
// force the transport to connect to that IP instead of re-resolving.
// This closes the DNS-rebinding TOCTOU window.
if pinnedAddr != "" && baseTransport.DialContext == nil {
stdDialer := &net.Dialer{}
baseTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
Comment on lines +311 to +313
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 issue (security): DNS pinning is skipped for the default transport because DialContext is already non-nil there.

Because http.DefaultTransport.(*http.Transport).Clone() already sets DialContext, the pinning logic will not apply in the typical non-proxy case, so the default path does not actually get the DNS-rebinding protection described. Only the custom &http.Transport{} path does. To make pinning effective and consistent, wrap whatever DialContext is present (capture the original and delegate to it with addr rewritten to use pinnedAddr) instead of only setting DialContext when it is nil. This preserves existing dial behavior (timeouts, keep-alive, etc.) while still closing the TOCTOU window.

_, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return stdDialer.DialContext(ctx, network, net.JoinHostPort(pinnedAddr, port))
}
}

return &http.Client{Transport: otelhttp.NewTransport(http.DefaultTransport)}, nil
return &http.Client{
Transport: otelhttp.NewTransport(baseTransport),
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return fmt.Errorf("stopped after 10 redirects")
}
if _, err := s.validateToolEndpoint(req.Context(), req.URL); err != nil {
return fmt.Errorf("redirect to %s blocked: %w", req.URL.Redacted(), err)
}
return nil
},
}, nil
}

// executeHTTPTool executes a tool with the given arguments
Expand Down Expand Up @@ -355,7 +386,8 @@ func (s *Server) executeHTTPTool(c *gin.Context, conn session.Connection, tool *
return nil, err
}

if err := s.validateToolEndpoint(ctx, req.URL); err != nil {
pinnedAddr, err := s.validateToolEndpoint(ctx, req.URL)
if err != nil {
logger.Warn("blocked tool endpoint",
zap.String("tool", tool.Name),
zap.String("session_id", conn.Meta().ID),
Expand Down Expand Up @@ -403,7 +435,7 @@ func (s *Server) executeHTTPTool(c *gin.Context, conn session.Connection, tool *
processArguments(req, tool, args)

// Execute request
cli, err := createHTTPClient(tool)
cli, err := s.createHTTPClient(tool, pinnedAddr)
if err != nil {
logger.Error("failed to create HTTP client",
zap.String("tool", tool.Name),
Expand Down
10 changes: 6 additions & 4 deletions internal/core/tool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,25 @@ func TestFillDefaultArgs(t *testing.T) {
}

func TestCreateHTTPClient(t *testing.T) {
s := &Server{}

// default client when no proxy
cli, err := createHTTPClient(nil)
cli, err := s.createHTTPClient(nil, "")
assert.NoError(t, err)
assert.NotNil(t, cli)

// http proxy
cli2, err := createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "http", Host: "127.0.0.1", Port: 8080}})
cli2, err := s.createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "http", Host: "127.0.0.1", Port: 8080}}, "")
assert.NoError(t, err)
assert.NotNil(t, cli2)

// socks5 proxy
cli3, err := createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "socks5", Host: "127.0.0.1", Port: 1080}})
cli3, err := s.createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "socks5", Host: "127.0.0.1", Port: 1080}}, "")
assert.NoError(t, err)
assert.NotNil(t, cli3)

// invalid proxy
_, err = createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "https", Host: "invalid host with space", Port: 1}})
_, err = s.createHTTPClient(&config.ToolConfig{Proxy: &config.ProxyConfig{Type: "https", Host: "invalid host with space", Port: 1}}, "")
assert.Error(t, err)
}

Expand Down
Loading