Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions packages/pam/local/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,19 @@ var databaseConfigs = map[string]DatabaseDisplayConfig{
TypeLabel: "PostgreSQL",
DefaultPort: 5432,
ConnectionString: func(username, database string, port int) string {
return fmt.Sprintf("postgres://%s@localhost:%d/%s", username, port, database)
return fmt.Sprintf("postgres://%s@127.0.0.1:%d/%s", username, port, database)
},
UsageExamples: func(username, database string, port int) []string {
return []string{
fmt.Sprintf("psql -h localhost -p %d -U %s -d %s", port, username, database),
fmt.Sprintf("psql -h 127.0.0.1 -p %d -U %s -d %s", port, username, database),
}
},
},
AccountTypeMySQL: {
TypeLabel: "MySQL",
DefaultPort: 3306,
ConnectionString: func(username, database string, port int) string {
return fmt.Sprintf("mysql://%s@localhost:%d/%s", username, port, database)
return fmt.Sprintf("mysql://%s@127.0.0.1:%d/%s", username, port, database)
},
UsageExamples: func(username, database string, port int) []string {
return []string{
Expand All @@ -138,35 +138,35 @@ var databaseConfigs = map[string]DatabaseDisplayConfig{
TypeLabel: "SQL Server",
DefaultPort: 1433,
ConnectionString: func(username, database string, port int) string {
return fmt.Sprintf("sqlserver://%s@localhost:%d?database=%s", username, port, database)
return fmt.Sprintf("sqlserver://%s@127.0.0.1:%d?database=%s", username, port, database)
},
UsageExamples: func(username, database string, port int) []string {
return []string{
fmt.Sprintf("sqlcmd -S localhost,%d -U %s -d %s", port, username, database),
fmt.Sprintf("sqlcmd -S 127.0.0.1,%d -U %s -d %s", port, username, database),
}
},
},
AccountTypeMongoDB: {
TypeLabel: "MongoDB",
DefaultPort: 27017,
ConnectionString: func(username, database string, port int) string {
return fmt.Sprintf("mongodb://localhost:%d/%s", port, database)
return fmt.Sprintf("mongodb://127.0.0.1:%d/%s", port, database)
},
UsageExamples: func(username, database string, port int) []string {
return []string{
fmt.Sprintf("mongosh --host localhost --port %d %s", port, database),
fmt.Sprintf("mongosh --host 127.0.0.1 --port %d %s", port, database),
}
},
},
AccountTypeOracleDB: {
TypeLabel: "Oracle",
DefaultPort: 1521,
ConnectionString: func(username, database string, port int) string {
return fmt.Sprintf("%s@localhost:%d/%s", username, port, database)
return fmt.Sprintf("%s@127.0.0.1:%d/%s", username, port, database)
},
UsageExamples: func(username, database string, port int) []string {
return []string{
fmt.Sprintf("sqlplus %s@localhost:%d/%s", username, port, database),
fmt.Sprintf("sqlplus %s@127.0.0.1:%d/%s", username, port, database),
}
},
},
Expand Down Expand Up @@ -367,7 +367,7 @@ func printDatabaseSessionInfo(config DatabaseDisplayConfig, folder, account stri
fmt.Printf(" Connection Details \n")
fmt.Printf("----------------------------------------------------------------------\n")
fmt.Printf("\n")
fmt.Printf(" Host: localhost\n")
fmt.Printf(" Host: 127.0.0.1\n")
fmt.Printf(" Port: %d\n", port)
if username != "" {
fmt.Printf(" Username: %s\n", username)
Expand All @@ -382,7 +382,7 @@ func printDatabaseSessionInfo(config DatabaseDisplayConfig, folder, account stri
fmt.Printf("----------------------------------------------------------------------\n")
fmt.Printf("\n")
fmt.Printf(" Use your preferred database client (CLI, GUI, or IDE) to connect\n")
fmt.Printf(" to localhost:%d. No password is needed.\n", port)
fmt.Printf(" to 127.0.0.1:%d. No password is needed.\n", port)
fmt.Printf("\n")
if config.UsageExamples != nil {
examples := config.UsageExamples(username, database, port)
Expand Down
4 changes: 2 additions & 2 deletions packages/pam/local/database-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ const (
func (p *DatabaseProxyServer) Start(port int) error {
var err error
if port == 0 {
p.server, err = net.Listen("tcp", ":0")
p.server, err = net.Listen("tcp", "127.0.0.1:0") // Bind to 127.0.0.1 only
} else {
p.server, err = net.Listen("tcp", fmt.Sprintf(":%d", port))
p.server, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
}

if err != nil {
Comment thread
x032205 marked this conversation as resolved.
Expand Down
6 changes: 3 additions & 3 deletions packages/pam/local/kubernetes-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (p *KubernetesProxyServer) SetupKubeconfig(clusterName string) error {
}

config.Clusters[clusterName] = &k8sapi.Cluster{
Server: fmt.Sprintf("http://localhost:%d", p.port),
Server: fmt.Sprintf("http://127.0.0.1:%d", p.port),
}
config.AuthInfos[clusterName] = &k8sapi.AuthInfo{}
config.Contexts[clusterName] = &k8sapi.Context{
Expand All @@ -53,9 +53,9 @@ func (p *KubernetesProxyServer) SetupKubeconfig(clusterName string) error {
func (p *KubernetesProxyServer) Start(port int) error {
var err error
if port == 0 {
p.server, err = net.Listen("tcp", ":0")
p.server, err = net.Listen("tcp", "127.0.0.1:0") // Bind to 127.0.0.1 only
} else {
p.server, err = net.Listen("tcp", fmt.Sprintf(":%d", port))
p.server, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
}

if err != nil {
Expand Down
41 changes: 41 additions & 0 deletions packages/pam/local/proxy_loopback_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package pam

import (
"net"
"testing"
)

// TestLocalProxiesBindLoopback guards that the local PAM proxies bind to a
// loopback address rather than all interfaces. Start() only creates the
// listener (the accept loop lives in Run), so it can be exercised in isolation
// without a gateway or an active session.
func TestLocalProxiesBindLoopback(t *testing.T) {
cases := []struct {
name string
start func() (net.Listener, error)
}{
{"database", func() (net.Listener, error) { p := &DatabaseProxyServer{}; err := p.Start(0); return p.server, err }},
{"redis", func() (net.Listener, error) { p := &RedisProxyServer{}; err := p.Start(0); return p.server, err }},
{"kubernetes", func() (net.Listener, error) { p := &KubernetesProxyServer{}; err := p.Start(0); return p.server, err }},
{"ssh", func() (net.Listener, error) { p := &SSHProxyServer{}; err := p.Start(0); return p.server, err }},
{"rdp", func() (net.Listener, error) { p := &RDPProxyServer{}; err := p.Start(0); return p.server, err }},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ln, err := tc.start()
if err != nil {
t.Fatalf("Start: %v", err)
}
defer func() { _ = ln.Close() }()

addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatalf("unexpected listener address type %T", ln.Addr())
}
if !addr.IP.IsLoopback() {
t.Fatalf("%s proxy bound to %s; must bind a loopback address, not all interfaces", tc.name, addr.IP)
}
})
}
}
8 changes: 4 additions & 4 deletions packages/pam/local/redis-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ func StartRedisLocalProxy(accessToken string, accessParams PAMAccessParams, proj
util.PrintfStderr("\n")
util.PrintfStderr("You can now connect to your Redis instance using:\n")
if username != "" {
util.PrintfStderr("redis://%s@localhost:%d", username, proxy.port)
util.PrintfStderr("redis://%s@127.0.0.1:%d", username, proxy.port)
} else {
util.PrintfStderr("redis://localhost:%d", proxy.port)
util.PrintfStderr("redis://127.0.0.1:%d", proxy.port)
}
util.PrintfStderr("\n**********************************************************************\n")
util.PrintfStderr("\n")
Expand All @@ -129,9 +129,9 @@ func StartRedisLocalProxy(accessToken string, accessParams PAMAccessParams, proj
func (p *RedisProxyServer) Start(port int) error {
var err error
if port == 0 {
p.server, err = net.Listen("tcp", ":0")
p.server, err = net.Listen("tcp", "127.0.0.1:0") // Bind to 127.0.0.1 only
} else {
p.server, err = net.Listen("tcp", fmt.Sprintf(":%d", port))
p.server, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
}

if err != nil {
Expand Down
Loading