Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 102 additions & 2 deletions packages/pam/local/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ func StartPAMAccess(accessToken, path, reason, durationStr string, port int) {
case AccountTypePostgres, AccountTypeMySQL, AccountTypeMsSQL, AccountTypeMongoDB, AccountTypeOracleDB:
startDatabaseProxy(httpClient, &pamResponse, displayPath, durationStr, port)

// Non-database types - not yet implemented
case AccountTypeSSH:
util.PrintErrorMessageAndExit("SSH access not yet supported in the new PAM model")
startSSHAccess(httpClient, &pamResponse, displayPath, durationStr, port)
Comment thread
bernie-g marked this conversation as resolved.
case AccountTypeRedis:
util.PrintErrorMessageAndExit("Redis access not yet supported in the new PAM model")
case AccountTypeKubernetes:
Expand Down Expand Up @@ -250,6 +249,107 @@ func startDatabaseProxy(httpClient *resty.Client, response *api.PAMAccessRespons
proxy.Run()
}

func startSSHAccess(httpClient *resty.Client, response *api.PAMAccessResponse, path, durationStr string, port int) {
duration, err := time.ParseDuration(durationStr)
if err != nil {
util.HandleError(err, "Failed to parse duration")
return
}

username, ok := response.Metadata["username"]
if !ok {
util.HandleError(fmt.Errorf("PAM response metadata is missing 'username'"), "Failed to start SSH session")
return
}

ctx, cancel := context.WithCancel(context.Background())

proxy := &SSHProxyServer{
BaseProxyServer: BaseProxyServer{
httpClient: httpClient,
relayHost: response.RelayHost,
relayClientCert: response.RelayClientCertificate,
relayClientKey: response.RelayClientPrivateKey,
relayServerCertChain: response.RelayServerCertificateChain,
gatewayClientCert: response.GatewayClientCertificate,
gatewayClientKey: response.GatewayClientPrivateKey,
gatewayServerCertChain: response.GatewayServerCertificateChain,
sessionExpiry: time.Now().Add(duration),
sessionId: response.SessionId,
resourceType: response.AccountType,
ctx: ctx,
cancel: cancel,
shutdownCh: make(chan struct{}),
},
}
Comment thread
bernie-g marked this conversation as resolved.

if err := proxy.ValidateResourceTypeSupported(); err != nil {
util.HandleError(err, "Gateway version outdated")
return
}

err = proxy.Start(port)
if err != nil {
util.HandleError(err, "Failed to start SSH proxy server")
return
}

folder, account := parsePath(path)

log.Info().Msgf("SSH proxy server listening on port %d", proxy.port)
Comment thread
bernie-g marked this conversation as resolved.
printSSHSessionInfo(folder, account, duration, username, proxy.port)

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

go func() {
sig := <-sigChan
log.Info().Msgf("Received signal %v, initiating graceful shutdown...", sig)
proxy.gracefulShutdown()
}()

proxy.Run()
}

func printSSHSessionInfo(folder, account string, duration time.Duration, username string, port int) {
fmt.Printf("\n")
fmt.Printf("**********************************************************************\n")
fmt.Printf(" SSH Proxy Session Started! \n")
fmt.Printf("**********************************************************************\n")
fmt.Printf("\n")
if folder != "" {
fmt.Printf(" Folder: %s\n", folder)
}
fmt.Printf(" Account: %s\n", account)
fmt.Printf(" Duration: %s\n", duration.String())
fmt.Printf("\n")
fmt.Printf("----------------------------------------------------------------------\n")
fmt.Printf(" Connection Details \n")
fmt.Printf("----------------------------------------------------------------------\n")
fmt.Printf("\n")
fmt.Printf(" Host: 127.0.0.1\n")
fmt.Printf(" Port: %d\n", port)
if username != "" {
fmt.Printf(" Username: %s\n", username)
}
fmt.Printf("\n")
fmt.Printf("----------------------------------------------------------------------\n")
fmt.Printf(" How to Connect \n")
fmt.Printf("----------------------------------------------------------------------\n")
fmt.Printf("\n")
fmt.Printf(" Use your preferred SSH client to connect to 127.0.0.1:%d.\n", port)
fmt.Printf(" Credentials are handled automatically by the gateway.\n")
fmt.Printf("\n")
fmt.Printf(" Examples:\n")
util.PrintfStderr(" $ ssh -p %d -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null %s@127.0.0.1\n", port, username)
util.PrintfStderr(" $ scp -P %d -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null <local-file> %s@127.0.0.1:<remote-path>\n", port, username)
Comment thread
bernie-g marked this conversation as resolved.
fmt.Printf("\n")
fmt.Printf(" Press Ctrl+C to stop the proxy.\n")
fmt.Printf("\n")
fmt.Printf("**********************************************************************\n")
fmt.Printf("\n")
}

// printDatabaseSessionInfo prints the connection info banner for database sessions
func printDatabaseSessionInfo(config DatabaseDisplayConfig, folder, account string, duration time.Duration, username, database string, port int) {
fmt.Printf("\n")
Expand Down
220 changes: 1 addition & 219 deletions packages/pam/local/ssh-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,155 +6,15 @@ import (
"io"
"net"
"os"
"os/exec"
"os/signal"
"strconv"
"strings"
"syscall"
"time"

"github.com/Infisical/infisical-merge/packages/pam/session"
"github.com/Infisical/infisical-merge/packages/util"
"github.com/go-resty/resty/v2"
"github.com/rs/zerolog/log"
)

type SSHProxyServer struct {
BaseProxyServer // Embed common functionality
server net.Listener
port int
sshProcess *exec.Cmd
options SSHAccessOptions
sshExitCode int // Exit code from SSH process (for exec mode)
}

// SSHAccessOptions configures SSH access behavior
type SSHAccessOptions struct {
ExecCommand string // If set, run this command instead of interactive shell
ProxyOnly bool // If true, start proxy without launching SSH client
}

func StartSSHLocalProxy(accessToken string, accessParams PAMAccessParams, projectID string, durationStr string, options SSHAccessOptions) {
httpClient := resty.New()
httpClient.SetAuthToken(accessToken)
httpClient.SetHeader("User-Agent", "infisical-cli")

pamRequest := accessParams.ToAPIRequest(projectID, durationStr)

interactive := options.ExecCommand == ""
pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest, interactive)
if err != nil {
if HandleApprovalWorkflow(httpClient, err, projectID, accessParams, durationStr) {
return
}
util.HandleError(err, "Failed to access PAM account")
return
}

// Verify this is an SSH resource
if pamResponse.ResourceType != session.ResourceTypeSSH {
util.HandleError(fmt.Errorf("account is not an SSH resource, got: %s", pamResponse.ResourceType), "Invalid resource type")
return
}

duration, err := time.ParseDuration(durationStr)
if err != nil {
util.HandleError(err, "Failed to parse duration")
return
}

ctx, cancel := context.WithCancel(context.Background())

proxy := &SSHProxyServer{
BaseProxyServer: BaseProxyServer{
httpClient: httpClient,
relayHost: pamResponse.RelayHost,
relayClientCert: pamResponse.RelayClientCertificate,
relayClientKey: pamResponse.RelayClientPrivateKey,
relayServerCertChain: pamResponse.RelayServerCertificateChain,
gatewayClientCert: pamResponse.GatewayClientCertificate,
gatewayClientKey: pamResponse.GatewayClientPrivateKey,
gatewayServerCertChain: pamResponse.GatewayServerCertificateChain,
sessionExpiry: time.Now().Add(duration),
sessionId: pamResponse.SessionId,
resourceType: pamResponse.ResourceType,
ctx: ctx,
cancel: cancel,
shutdownCh: make(chan struct{}),
},
options: options,
}

if err := proxy.ValidateResourceTypeSupported(); err != nil {
util.HandleError(err, "Gateway version outdated")
return
}

// Start the local TCP proxy on a random port
err = proxy.Start(0) // 0 = random port
if err != nil {
util.HandleError(err, "Failed to start SSH proxy server")
return
}

// Extract metadata
username, ok := pamResponse.Metadata["username"]
if !ok {
util.HandleError(fmt.Errorf("PAM response metadata is missing 'username'"), "Failed to start proxy server")
return
}

log.Debug().
Str("sessionID", pamResponse.SessionId).
Str("username", username).
Int("port", proxy.port).
Msg("SSH proxy ready")

// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

go func() {
sig := <-sigChan
log.Debug().Msgf("Received signal %v, initiating graceful shutdown...", sig)
proxy.gracefulShutdown()
}()

// Start the proxy server in a goroutine
go proxy.Run()

// Give the proxy a moment to start accepting connections
time.Sleep(500 * time.Millisecond)

if options.ProxyOnly {
// Proxy-only mode: print connection info and wait
fmt.Printf("SSH proxy listening on 127.0.0.1:%d\n", proxy.port)
fmt.Printf("Username: %s\n", username)
fmt.Printf("Session expires: %s\n", proxy.sessionExpiry.Format(time.RFC3339))
fmt.Println("")
fmt.Println("Use this proxy with SSH, SCP, SFTP, or rsync:")
fmt.Printf(" ssh -p %d -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null %s@127.0.0.1\n", proxy.port, username)
fmt.Printf(" scp -P %d -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null <local-file> %s@127.0.0.1:<remote-path>\n", proxy.port, username)
fmt.Println("")
fmt.Println("Press Ctrl+C to stop the proxy.")

// Wait for context cancellation (Ctrl+C triggers gracefulShutdown which cancels context)
<-proxy.ctx.Done()
} else {
// Launch SSH client connected to the local proxy (transparent to user)
err = proxy.launchSSHClient(username)
if err != nil {
log.Error().Err(err).Msg("Failed to launch SSH client")
proxy.gracefulShutdown()
return
}

// Wait for SSH process to complete
proxy.waitForSSHCompletion()

// SSH client exited, shutdown gracefully
proxy.gracefulShutdown()
}
}

func (p *SSHProxyServer) Start(port int) error {
Expand All @@ -177,81 +37,10 @@ func (p *SSHProxyServer) Start(port int) error {
return nil
}

func (p *SSHProxyServer) launchSSHClient(username string) error {
// Build SSH command: ssh -p <local-port> <username>@localhost [command]
sshArgs := []string{
"-p", strconv.Itoa(p.port),
"-o", "StrictHostKeyChecking=no", // Skip host key verification (we're connecting to localhost)
"-o", "UserKnownHostsFile=/dev/null",
"-o", "LogLevel=ERROR",
fmt.Sprintf("%s@127.0.0.1", username),
}

// If exec command is specified, append it (non-interactive mode)
if p.options.ExecCommand != "" {
sshArgs = append(sshArgs, p.options.ExecCommand)
}

p.sshProcess = exec.Command("ssh", sshArgs...)
p.sshProcess.Stdin = os.Stdin
p.sshProcess.Stdout = os.Stdout
p.sshProcess.Stderr = os.Stderr

log.Debug().Msgf("Executing: ssh %s", formatSSHArgs(sshArgs))

err := p.sshProcess.Start()
if err != nil {
return fmt.Errorf("failed to start SSH client: %w", err)
}

log.Debug().Msgf("SSH client started with PID: %d", p.sshProcess.Process.Pid)
return nil
}

// formatSSHArgs formats SSH arguments for logging, quoting args with spaces
func formatSSHArgs(args []string) string {
formatted := make([]string, len(args))
for i, arg := range args {
if strings.ContainsRune(arg, ' ') {
formatted[i] = fmt.Sprintf("%q", arg)
} else {
formatted[i] = arg
}
}
return strings.Join(formatted, " ")
}

func (p *SSHProxyServer) waitForSSHCompletion() {
if p.sshProcess == nil {
return
}

err := p.sshProcess.Wait()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
p.sshExitCode = exitErr.ExitCode()
log.Debug().Msgf("SSH client exited with code: %d", p.sshExitCode)
} else {
log.Error().Err(err).Msg("Error waiting for SSH client")
p.sshExitCode = 1
}
} else {
p.sshExitCode = 0
log.Debug().Msg("SSH client exited successfully")
}
}

func (p *SSHProxyServer) gracefulShutdown() {
p.shutdownOnce.Do(func() {
log.Debug().Msg("Starting graceful shutdown of SSH proxy...")

// Kill SSH process if it's still running
if p.sshProcess != nil && p.sshProcess.Process != nil {
log.Debug().Msg("Terminating SSH client process")
p.sshProcess.Process.Signal(syscall.SIGTERM)
}

// Send session termination notification before cancelling context
p.NotifySessionTermination()

// Signal the accept loop to stop
Expand All @@ -269,14 +58,7 @@ func (p *SSHProxyServer) gracefulShutdown() {
p.WaitForConnectionsWithTimeout(10 * time.Second)

log.Debug().Msg("SSH proxy shutdown complete")

// Only propagate SSH exit code in exec mode (non-interactive)
// For interactive sessions, always exit 0 on clean shutdown
exitCode := 0
if p.options.ExecCommand != "" {
exitCode = p.sshExitCode
}
os.Exit(exitCode)
os.Exit(0)
})
}

Expand Down
Loading