diff --git a/packages/pam/local/database-proxy.go b/packages/pam/local/database-proxy.go index c418795c..659c40a0 100644 --- a/packages/pam/local/database-proxy.go +++ b/packages/pam/local/database-proxy.go @@ -119,19 +119,19 @@ func StartDatabaseLocalProxy(accessToken string, accessParams PAMAccessParams, p switch pamResponse.ResourceType { case session.ResourceTypePostgres: - util.PrintfStderr("postgres://%s@localhost:%d/%s", username, proxy.port, database) + util.PrintfStderr("postgres://%s@127.0.0.1:%d/%s", username, proxy.port, database) case session.ResourceTypeMysql: - util.PrintfStderr("mysql://%s@localhost:%d/%s", username, proxy.port, database) + util.PrintfStderr("mysql://%s@127.0.0.1:%d/%s", username, proxy.port, database) case session.ResourceTypeMssql: - util.PrintfStderr("sqlserver://%s@localhost:%d?database=%s&encrypt=false&trustServerCertificate=true", username, proxy.port, database) + util.PrintfStderr("sqlserver://%s@127.0.0.1:%d?database=%s&encrypt=false&trustServerCertificate=true", username, proxy.port, database) case session.ResourceTypeMongodb: - util.PrintfStderr("mongodb://localhost:%d/%s?serverSelectionTimeoutMS=15000", proxy.port, database) + util.PrintfStderr("mongodb://127.0.0.1:%d/%s?serverSelectionTimeoutMS=15000", proxy.port, database) case session.ResourceTypeOracledb: - util.PrintfStderr("%s/%s@localhost:%d/%s", username, oracle.ProxyPasswordPlaceholder, proxy.port, database) - util.PrintfStderr("\njdbc:oracle:thin:@localhost:%d/%s (user: %s, password: %s)", proxy.port, database, username, oracle.ProxyPasswordPlaceholder) + util.PrintfStderr("%s/%s@127.0.0.1:%d/%s", username, oracle.ProxyPasswordPlaceholder, proxy.port, database) + util.PrintfStderr("\njdbc:oracle:thin:@127.0.0.1:%d/%s (user: %s, password: %s)", proxy.port, database, username, oracle.ProxyPasswordPlaceholder) util.PrintfStderr("\n\nNote: the password shown is a protocol placeholder required by Oracle, not a secret.") default: - util.PrintfStderr("localhost:%d", proxy.port) + util.PrintfStderr("127.0.0.1:%d", proxy.port) } util.PrintfStderr("\n**********************************************************************\n") util.PrintfStderr("\n") @@ -151,9 +151,9 @@ func StartDatabaseLocalProxy(accessToken string, accessParams PAMAccessParams, p func (p *DatabaseProxyServer) Start(port int) error { var err error if port == 0 { - p.server, err = net.Listen("tcp", ":0") + p.server, err = net.Listen("tcp", "127.0.0.1:0") // Bind to 127.0.0.1 only } else { - p.server, err = net.Listen("tcp", fmt.Sprintf(":%d", port)) + p.server, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) } if err != nil { diff --git a/packages/pam/local/kubernetes-proxy.go b/packages/pam/local/kubernetes-proxy.go index 0e94fe72..2da147ce 100644 --- a/packages/pam/local/kubernetes-proxy.go +++ b/packages/pam/local/kubernetes-proxy.go @@ -113,7 +113,7 @@ func StartKubernetesLocalProxy(accessToken string, accessParams PAMAccessParams, clusterName := fmt.Sprintf("infisical-k8s-pam/%s/%s", accessParams.ResourceName, accessParams.AccountName) config.Clusters[clusterName] = &k8sapi.Cluster{ - Server: fmt.Sprintf("http://localhost:%d", proxy.port), + Server: fmt.Sprintf("http://127.0.0.1:%d", proxy.port), } config.AuthInfos[clusterName] = &k8sapi.AuthInfo{} config.Contexts[clusterName] = &k8sapi.Context{ @@ -158,9 +158,9 @@ func StartKubernetesLocalProxy(accessToken string, accessParams PAMAccessParams, func (p *KubernetesProxyServer) Start(port int) error { var err error if port == 0 { - p.server, err = net.Listen("tcp", ":0") + p.server, err = net.Listen("tcp", "127.0.0.1:0") // Bind to 127.0.0.1 only } else { - p.server, err = net.Listen("tcp", fmt.Sprintf(":%d", port)) + p.server, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) } if err != nil { diff --git a/packages/pam/local/proxy_loopback_test.go b/packages/pam/local/proxy_loopback_test.go new file mode 100644 index 00000000..0e9a5b85 --- /dev/null +++ b/packages/pam/local/proxy_loopback_test.go @@ -0,0 +1,40 @@ +package pam + +import ( + "net" + "testing" +) + +// TestLocalProxiesBindLoopback guards that the local PAM proxies bind to a +// loopback address rather than all interfaces. Start() only creates the +// listener (the accept loop lives in Run), so it can be exercised in isolation +// without a gateway or an active session. +func TestLocalProxiesBindLoopback(t *testing.T) { + cases := []struct { + name string + start func() (net.Listener, error) + }{ + {"database", func() (net.Listener, error) { p := &DatabaseProxyServer{}; err := p.Start(0); return p.server, err }}, + {"redis", func() (net.Listener, error) { p := &RedisProxyServer{}; err := p.Start(0); return p.server, err }}, + {"kubernetes", func() (net.Listener, error) { p := &KubernetesProxyServer{}; err := p.Start(0); return p.server, err }}, + {"rdp", func() (net.Listener, error) { p := &RDPProxyServer{}; err := p.Start(0); return p.server, err }}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ln, err := tc.start() + if err != nil { + t.Fatalf("Start: %v", err) + } + defer func() { _ = ln.Close() }() + + addr, ok := ln.Addr().(*net.TCPAddr) + if !ok { + t.Fatalf("unexpected listener address type %T", ln.Addr()) + } + if !addr.IP.IsLoopback() { + t.Fatalf("%s proxy bound to %s; must bind a loopback address, not all interfaces", tc.name, addr.IP) + } + }) + } +} diff --git a/packages/pam/local/redis-proxy.go b/packages/pam/local/redis-proxy.go index 901f1659..5e26defa 100644 --- a/packages/pam/local/redis-proxy.go +++ b/packages/pam/local/redis-proxy.go @@ -107,9 +107,9 @@ func StartRedisLocalProxy(accessToken string, accessParams PAMAccessParams, proj util.PrintfStderr("\n") util.PrintfStderr("You can now connect to your Redis instance using:\n") if username != "" { - util.PrintfStderr("redis://%s@localhost:%d", username, proxy.port) + util.PrintfStderr("redis://%s@127.0.0.1:%d", username, proxy.port) } else { - util.PrintfStderr("redis://localhost:%d", proxy.port) + util.PrintfStderr("redis://127.0.0.1:%d", proxy.port) } util.PrintfStderr("\n**********************************************************************\n") util.PrintfStderr("\n") @@ -129,9 +129,9 @@ func StartRedisLocalProxy(accessToken string, accessParams PAMAccessParams, proj func (p *RedisProxyServer) Start(port int) error { var err error if port == 0 { - p.server, err = net.Listen("tcp", ":0") + p.server, err = net.Listen("tcp", "127.0.0.1:0") // Bind to 127.0.0.1 only } else { - p.server, err = net.Listen("tcp", fmt.Sprintf(":%d", port)) + p.server, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) } if err != nil {