diff --git a/packages/pam/local/access.go b/packages/pam/local/access.go index 53c0dcbe..10be3ef0 100644 --- a/packages/pam/local/access.go +++ b/packages/pam/local/access.go @@ -114,11 +114,11 @@ var databaseConfigs = map[string]DatabaseDisplayConfig{ TypeLabel: "PostgreSQL", DefaultPort: 5432, ConnectionString: func(username, database string, port int) string { - return fmt.Sprintf("postgres://%s@localhost:%d/%s", username, port, database) + return fmt.Sprintf("postgres://%s@127.0.0.1:%d/%s", username, port, database) }, UsageExamples: func(username, database string, port int) []string { return []string{ - fmt.Sprintf("psql -h localhost -p %d -U %s -d %s", port, username, database), + fmt.Sprintf("psql -h 127.0.0.1 -p %d -U %s -d %s", port, username, database), } }, }, @@ -126,7 +126,7 @@ var databaseConfigs = map[string]DatabaseDisplayConfig{ TypeLabel: "MySQL", DefaultPort: 3306, ConnectionString: func(username, database string, port int) string { - return fmt.Sprintf("mysql://%s@localhost:%d/%s", username, port, database) + return fmt.Sprintf("mysql://%s@127.0.0.1:%d/%s", username, port, database) }, UsageExamples: func(username, database string, port int) []string { return []string{ @@ -138,11 +138,11 @@ var databaseConfigs = map[string]DatabaseDisplayConfig{ TypeLabel: "SQL Server", DefaultPort: 1433, ConnectionString: func(username, database string, port int) string { - return fmt.Sprintf("sqlserver://%s@localhost:%d?database=%s", username, port, database) + return fmt.Sprintf("sqlserver://%s@127.0.0.1:%d?database=%s", username, port, database) }, UsageExamples: func(username, database string, port int) []string { return []string{ - fmt.Sprintf("sqlcmd -S localhost,%d -U %s -d %s", port, username, database), + fmt.Sprintf("sqlcmd -S 127.0.0.1,%d -U %s -d %s", port, username, database), } }, }, @@ -150,11 +150,11 @@ var databaseConfigs = map[string]DatabaseDisplayConfig{ TypeLabel: "MongoDB", DefaultPort: 27017, ConnectionString: func(username, database string, port int) string { - return fmt.Sprintf("mongodb://localhost:%d/%s", port, database) + return fmt.Sprintf("mongodb://127.0.0.1:%d/%s", port, database) }, UsageExamples: func(username, database string, port int) []string { return []string{ - fmt.Sprintf("mongosh --host localhost --port %d %s", port, database), + fmt.Sprintf("mongosh --host 127.0.0.1 --port %d %s", port, database), } }, }, @@ -162,11 +162,11 @@ var databaseConfigs = map[string]DatabaseDisplayConfig{ TypeLabel: "Oracle", DefaultPort: 1521, ConnectionString: func(username, database string, port int) string { - return fmt.Sprintf("%s@localhost:%d/%s", username, port, database) + return fmt.Sprintf("%s@127.0.0.1:%d/%s", username, port, database) }, UsageExamples: func(username, database string, port int) []string { return []string{ - fmt.Sprintf("sqlplus %s@localhost:%d/%s", username, port, database), + fmt.Sprintf("sqlplus %s@127.0.0.1:%d/%s", username, port, database), } }, }, @@ -367,7 +367,7 @@ func printDatabaseSessionInfo(config DatabaseDisplayConfig, folder, account stri fmt.Printf(" Connection Details \n") fmt.Printf("----------------------------------------------------------------------\n") fmt.Printf("\n") - fmt.Printf(" Host: localhost\n") + fmt.Printf(" Host: 127.0.0.1\n") fmt.Printf(" Port: %d\n", port) if username != "" { fmt.Printf(" Username: %s\n", username) @@ -382,7 +382,7 @@ func printDatabaseSessionInfo(config DatabaseDisplayConfig, folder, account stri fmt.Printf("----------------------------------------------------------------------\n") fmt.Printf("\n") fmt.Printf(" Use your preferred database client (CLI, GUI, or IDE) to connect\n") - fmt.Printf(" to localhost:%d. No password is needed.\n", port) + fmt.Printf(" to 127.0.0.1:%d. No password is needed.\n", port) fmt.Printf("\n") if config.UsageExamples != nil { examples := config.UsageExamples(username, database, port) diff --git a/packages/pam/local/database-proxy.go b/packages/pam/local/database-proxy.go index 872a13ee..27209336 100644 --- a/packages/pam/local/database-proxy.go +++ b/packages/pam/local/database-proxy.go @@ -28,9 +28,9 @@ const ( func (p *DatabaseProxyServer) Start(port int) error { var err error if port == 0 { - p.server, err = net.Listen("tcp", ":0") + p.server, err = net.Listen("tcp", "127.0.0.1:0") // Bind to 127.0.0.1 only } else { - p.server, err = net.Listen("tcp", fmt.Sprintf(":%d", port)) + p.server, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) } if err != nil { diff --git a/packages/pam/local/kubernetes-proxy.go b/packages/pam/local/kubernetes-proxy.go index 6bb87a98..1d94c594 100644 --- a/packages/pam/local/kubernetes-proxy.go +++ b/packages/pam/local/kubernetes-proxy.go @@ -31,7 +31,7 @@ func (p *KubernetesProxyServer) SetupKubeconfig(clusterName string) error { } config.Clusters[clusterName] = &k8sapi.Cluster{ - Server: fmt.Sprintf("http://localhost:%d", p.port), + Server: fmt.Sprintf("http://127.0.0.1:%d", p.port), } config.AuthInfos[clusterName] = &k8sapi.AuthInfo{} config.Contexts[clusterName] = &k8sapi.Context{ @@ -53,9 +53,9 @@ func (p *KubernetesProxyServer) SetupKubeconfig(clusterName string) error { func (p *KubernetesProxyServer) Start(port int) error { var err error if port == 0 { - p.server, err = net.Listen("tcp", ":0") + p.server, err = net.Listen("tcp", "127.0.0.1:0") // Bind to 127.0.0.1 only } else { - p.server, err = net.Listen("tcp", fmt.Sprintf(":%d", port)) + p.server, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) } if err != nil { diff --git a/packages/pam/local/proxy_loopback_test.go b/packages/pam/local/proxy_loopback_test.go new file mode 100644 index 00000000..92e0c473 --- /dev/null +++ b/packages/pam/local/proxy_loopback_test.go @@ -0,0 +1,41 @@ +package pam + +import ( + "net" + "testing" +) + +// TestLocalProxiesBindLoopback guards that the local PAM proxies bind to a +// loopback address rather than all interfaces. Start() only creates the +// listener (the accept loop lives in Run), so it can be exercised in isolation +// without a gateway or an active session. +func TestLocalProxiesBindLoopback(t *testing.T) { + cases := []struct { + name string + start func() (net.Listener, error) + }{ + {"database", func() (net.Listener, error) { p := &DatabaseProxyServer{}; err := p.Start(0); return p.server, err }}, + {"redis", func() (net.Listener, error) { p := &RedisProxyServer{}; err := p.Start(0); return p.server, err }}, + {"kubernetes", func() (net.Listener, error) { p := &KubernetesProxyServer{}; err := p.Start(0); return p.server, err }}, + {"ssh", func() (net.Listener, error) { p := &SSHProxyServer{}; err := p.Start(0); return p.server, err }}, + {"rdp", func() (net.Listener, error) { p := &RDPProxyServer{}; err := p.Start(0); return p.server, err }}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ln, err := tc.start() + if err != nil { + t.Fatalf("Start: %v", err) + } + defer func() { _ = ln.Close() }() + + addr, ok := ln.Addr().(*net.TCPAddr) + if !ok { + t.Fatalf("unexpected listener address type %T", ln.Addr()) + } + if !addr.IP.IsLoopback() { + t.Fatalf("%s proxy bound to %s; must bind a loopback address, not all interfaces", tc.name, addr.IP) + } + }) + } +} diff --git a/packages/pam/local/redis-proxy.go b/packages/pam/local/redis-proxy.go index 901f1659..5e26defa 100644 --- a/packages/pam/local/redis-proxy.go +++ b/packages/pam/local/redis-proxy.go @@ -107,9 +107,9 @@ func StartRedisLocalProxy(accessToken string, accessParams PAMAccessParams, proj util.PrintfStderr("\n") util.PrintfStderr("You can now connect to your Redis instance using:\n") if username != "" { - util.PrintfStderr("redis://%s@localhost:%d", username, proxy.port) + util.PrintfStderr("redis://%s@127.0.0.1:%d", username, proxy.port) } else { - util.PrintfStderr("redis://localhost:%d", proxy.port) + util.PrintfStderr("redis://127.0.0.1:%d", proxy.port) } util.PrintfStderr("\n**********************************************************************\n") util.PrintfStderr("\n") @@ -129,9 +129,9 @@ func StartRedisLocalProxy(accessToken string, accessParams PAMAccessParams, proj func (p *RedisProxyServer) Start(port int) error { var err error if port == 0 { - p.server, err = net.Listen("tcp", ":0") + p.server, err = net.Listen("tcp", "127.0.0.1:0") // Bind to 127.0.0.1 only } else { - p.server, err = net.Listen("tcp", fmt.Sprintf(":%d", port)) + p.server, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) } if err != nil {