diff --git a/.github/workflows/release_build_infisical_cli.yml b/.github/workflows/release_build_infisical_cli.yml index 9c8e3656..3669e210 100644 --- a/.github/workflows/release_build_infisical_cli.yml +++ b/.github/workflows/release_build_infisical_cli.yml @@ -137,6 +137,13 @@ jobs: sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 3B4FE6ACC0B21F32 sudo apt update sudo apt-get install -y libssl1.0-dev + - name: Install glibc cross-compilers for PKCS#11 (HSM) builds + run: | + set -euo pipefail + # PKCS#11 driver loading uses dlopen; the artifact must be dynamically + # linked against glibc. We use the system gcc for amd64 (native) and + # gcc-aarch64-linux-gnu for arm64. + sudo apt-get install -y gcc-aarch64-linux-gnu - name: Install cross-compile toolchains for RDP tier run: | set -euo pipefail @@ -253,10 +260,19 @@ jobs: chmod 600 /tmp/infisical-apk.rsa env: APK_PRIVATE_KEY: ${{ secrets.APK_PRIVATE_KEY }} + # upload_to_s3.sh syncs the whole apk repo down before rebuilding the index, and + # that repo holds the full backfilled history. + - name: Free disk space before publish + if: github.event_name == 'push' || (github.event_name == 'workflow_dispatch' && !inputs.dry_run) + run: | + df -h / + sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL || true + df -h / - name: Publish packages to repositories if: github.event_name == 'push' || (github.event_name == 'workflow_dispatch' && !inputs.dry_run) run: bash upload_to_s3.sh env: + AWS_DEFAULT_REGION: us-east-1 INFISICAL_CLI_S3_BUCKET: ${{ secrets.INFISICAL_CLI_S3_BUCKET }} INFISICAL_CLI_REPO_SIGNING_KEY_ID: ${{ secrets.INFISICAL_CLI_REPO_SIGNING_KEY_ID }} AWS_ACCESS_KEY_ID: ${{ secrets.INFISICAL_CLI_REPO_AWS_ACCESS_KEY_ID }} @@ -266,6 +282,7 @@ jobs: if: github.event_name == 'push' || (github.event_name == 'workflow_dispatch' && !inputs.dry_run) run: aws cloudfront create-invalidation --distribution-id $CLOUDFRONT_DISTRIBUTION_ID --paths '/rpm/*' '/deb/dists/stable/*' '/apk/stable/main/*' env: + AWS_DEFAULT_REGION: us-east-1 AWS_ACCESS_KEY_ID: ${{ secrets.INFISICAL_CLI_REPO_AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.INFISICAL_CLI_REPO_AWS_SECRET_ACCESS_KEY }} CLOUDFRONT_DISTRIBUTION_ID: ${{ secrets.INFISICAL_CLI_REPO_CLOUDFRONT_DISTRIBUTION_ID }} diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 1b0a5362..399bc264 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -117,6 +117,42 @@ builds: goarm: - "7" + # PKCS#11-enabled HSM companion. Loads the vendor's PKCS#11 driver via dlopen + # at runtime, so it MUST be dynamically linked (no -extldflags "-static") and + # built with a glibc toolchain. Shipped as a separate artifact and fetched by + # the launcher in packages/gateway-v2/pkcs11_launcher.go. + - id: linux-amd64-pkcs11 + binary: infisical-pkcs11 + ldflags: + - -X github.com/Infisical/infisical-merge/packages/util.CLI_VERSION={{ .Version }} + - -X github.com/Infisical/infisical-merge/packages/telemetry.POSTHOG_API_KEY_FOR_CLI={{ .Env.POSTHOG_API_KEY_FOR_CLI }} + flags: + - -trimpath + - -tags=pkcs11 + env: + - CGO_ENABLED=1 + - CC=x86_64-linux-gnu-gcc + goos: + - linux + goarch: + - amd64 + + - id: linux-arm64-pkcs11 + binary: infisical-pkcs11 + ldflags: + - -X github.com/Infisical/infisical-merge/packages/util.CLI_VERSION={{ .Version }} + - -X github.com/Infisical/infisical-merge/packages/telemetry.POSTHOG_API_KEY_FOR_CLI={{ .Env.POSTHOG_API_KEY_FOR_CLI }} + flags: + - -trimpath + - -tags=pkcs11 + env: + - CGO_ENABLED=1 + - CC=aarch64-linux-gnu-gcc + goos: + - linux + goarch: + - arm64 + # BSDs and windows/arm64 stay on CGO=0 stub; see build-rdp-bridge.yml. - id: all-other-builds env: @@ -151,7 +187,18 @@ builds: goarch: arm archives: - - format_overrides: + - id: default + builds_info: + group: default + builds: + - linux-amd64-rdp + - linux-arm64-rdp + - linux-386-rdp + - linux-armv6-rdp + - linux-armv7-rdp + - windows-amd64-rdp + - all-other-builds + format_overrides: - goos: windows format: zip files: @@ -160,6 +207,15 @@ archives: - manpages/* - completions/* + - id: pkcs11 + builds: + - linux-amd64-pkcs11 + - linux-arm64-pkcs11 + name_template: "infisical-pkcs11_{{ .Version }}_{{ .Os }}_{{ .Arch }}" + files: + - README* + - LICENSE* + release: mode: append diff --git a/README.md b/README.md index 8e03f0e1..ef33d992 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,11 @@ The official Infisical CLI: Inject secrets into applications and manage your Infisical infrastructure.

+> [!IMPORTANT] +> **The Infisical CLI Linux package repository is moving off Cloudsmith.** To keep up with download volume, we're migrating the Linux package repository to our own host at `artifacts-cli.infisical.com`. Cloudsmith downloads will stop being served on **September 16th, 2026**, after which installs and updates from the old URL will fail. +> +> Every release, including all older versions, is already available on the new host. If you're on an existing setup, you don't need to change anything else, just repoint your machine to the new artifact URL by following the [migration steps](https://infisical.com/docs/cli/cloudsmith-migration). + ## Introduction The **[Infisical CLI](https://infisical.com/docs/cli/overview)** is a powerful command-line tool for secret management that allows you to: diff --git a/go.mod b/go.mod index eb6d44de..dd08a0ed 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/jackc/pgx/v5 v5.9.2 github.com/jcmturner/gokrb5/v8 v8.4.4 github.com/mattn/go-isatty v0.0.20 + github.com/miekg/pkcs11 v1.1.1 github.com/muesli/ansi v0.0.0-20221106050444-61f0cd9a192a github.com/muesli/mango-cobra v1.2.0 github.com/muesli/reflow v0.3.0 diff --git a/go.sum b/go.sum index 63fd6ff6..15176da8 100644 --- a/go.sum +++ b/go.sum @@ -444,6 +444,8 @@ github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRC github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/miekg/pkcs11 v1.1.1 h1:Ugu9pdy6vAYku5DEpVWVFPYnzV+bxB+iRdbuFSu7TvU= +github.com/miekg/pkcs11 v1.1.1/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= diff --git a/packages/api/api.go b/packages/api/api.go index 25c72961..325c1c2e 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -812,10 +812,11 @@ func CallGatewayHeartBeatV1(httpClient *resty.Client) error { return nil } -func CallGatewayHeartBeatV2(httpClient *resty.Client) error { +func CallGatewayHeartBeatV2(httpClient *resty.Client, request GatewayHeartbeatRequest) error { response, err := httpClient. R(). SetHeader("User-Agent", USER_AGENT). + SetBody(request). Post(fmt.Sprintf("%v/v2/gateways/heartbeat", config.INFISICAL_URL)) if err != nil { diff --git a/packages/api/model.go b/packages/api/model.go index a17441f5..b96ba3fd 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -1015,6 +1015,10 @@ type RelayHeartbeatRequest struct { Name string `json:"name"` } +type GatewayHeartbeatRequest struct { + Capabilities map[string]any `json:"capabilities,omitempty"` +} + type RelayLoginRequest struct { Method string `json:"method"` Token string `json:"token,omitempty"` diff --git a/packages/cmd/export.go b/packages/cmd/export.go index a27c5a0c..a75707e6 100644 --- a/packages/cmd/export.go +++ b/packages/cmd/export.go @@ -24,6 +24,7 @@ const ( FormatCSV string = "csv" FormatYaml string = "yaml" FormatDotEnvExport string = "dotenv-export" + FormatDotEnvEval string = "dotenv-eval" ) // exportCmd represents the export command @@ -237,6 +238,8 @@ func getDefaultFilename(format string) string { return "secrets.yaml" case FormatDotEnvExport: return ".env" + case FormatDotEnvEval: + return ".env" case FormatDotenv: return ".env" default: @@ -255,6 +258,8 @@ func getDefaultExtension(format string) string { return ".yaml" case FormatDotEnvExport: return ".env" + case FormatDotEnvEval: + return ".env" case FormatDotenv: return ".env" default: @@ -266,7 +271,7 @@ func init() { RootCmd.AddCommand(exportCmd) exportCmd.Flags().StringP("env", "e", "dev", "Set the environment (dev, prod, etc.) from which your secrets should be pulled from") exportCmd.Flags().Bool("expand", true, "Parse shell parameter expansions in your secrets") - exportCmd.Flags().StringP("format", "f", "dotenv", "Set the format of the output file (dotenv, json, csv)") + exportCmd.Flags().StringP("format", "f", "dotenv", "Set the format of the output file (dotenv, dotenv-export, dotenv-eval, json, csv, yaml)") exportCmd.Flags().Bool("secret-overriding", true, "Prioritizes personal secrets, if any, with the same name over shared secrets") exportCmd.Flags().Bool("include-imports", true, "Imported linked secrets") exportCmd.Flags().String("token", "", "Fetch secrets using service token or machine identity access token") @@ -284,6 +289,8 @@ func formatEnvs(envs []models.SingleEnvironmentVariable, format string) (string, return formatAsDotEnv(envs), nil case FormatDotEnvExport: return formatAsDotEnvExport(envs), nil + case FormatDotEnvEval: + return formatAsDotEnvEval(envs), nil case FormatJson: return formatAsJson(envs), nil case FormatCSV: @@ -291,7 +298,7 @@ func formatEnvs(envs []models.SingleEnvironmentVariable, format string) (string, case FormatYaml: return formatAsYaml(envs) default: - return "", fmt.Errorf("invalid format type: %s. Available format types are [%s]", format, []string{FormatDotenv, FormatJson, FormatCSV, FormatYaml, FormatDotEnvExport}) + return "", fmt.Errorf("invalid format type: %s. Available format types are [%s]", format, []string{FormatDotenv, FormatJson, FormatCSV, FormatYaml, FormatDotEnvExport, FormatDotEnvEval}) } } @@ -325,6 +332,26 @@ func formatAsDotEnvExport(envs []models.SingleEnvironmentVariable) string { return dotenv } +// Format environment variables for shell eval/source. Values are wrapped in +// single quotes with POSIX escaping so the output is safe to evaluate via +// `eval "$(infisical export --format=dotenv-eval)"` regardless of value +// contents (newlines, single quotes, $, ", \, etc.). +func formatAsDotEnvEval(envs []models.SingleEnvironmentVariable) string { + var dotenv string + for _, env := range envs { + dotenv += fmt.Sprintf("export %s=%s\n", env.Key, posixShellQuote(env.Value)) + } + return dotenv +} + +// posixShellQuote wraps a value in single quotes and escapes any embedded +// single quotes using the standard `'\”` sequence. Single-quoted POSIX +// strings preserve every other character verbatim (including newlines, +// backslashes, $, and "), so this is sufficient for eval/source. +func posixShellQuote(value string) string { + return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'" +} + func formatAsYaml(envs []models.SingleEnvironmentVariable) (string, error) { m := make(map[string]string) for _, env := range envs { diff --git a/packages/cmd/export_test.go b/packages/cmd/export_test.go index 1be0a7ed..0e3921ec 100644 --- a/packages/cmd/export_test.go +++ b/packages/cmd/export_test.go @@ -77,3 +77,79 @@ func TestFormatAsYaml(t *testing.T) { }) } } + +func TestFormatAsDotEnvEval(t *testing.T) { + tests := []struct { + name string + input []models.SingleEnvironmentVariable + expected string + }{ + { + name: "Empty input", + input: []models.SingleEnvironmentVariable{}, + expected: "", + }, + { + name: "Simple value", + input: []models.SingleEnvironmentVariable{ + {Key: "KEY1", Value: "simple"}, + }, + expected: "export KEY1='simple'\n", + }, + { + name: "Value containing single quote", + input: []models.SingleEnvironmentVariable{ + {Key: "KEY1", Value: "it's a value"}, + }, + expected: "export KEY1='it'\\''s a value'\n", + }, + { + name: "Multiline value is preserved verbatim", + input: []models.SingleEnvironmentVariable{ + {Key: "KEY1", Value: "line1\nline2"}, + }, + expected: "export KEY1='line1\nline2'\n", + }, + { + name: "Multiline value with skipMultilineEncoding set still emits real newlines", + input: []models.SingleEnvironmentVariable{ + {Key: "KEY1", Value: "line1\nline2", SkipMultilineEncoding: true}, + }, + expected: "export KEY1='line1\nline2'\n", + }, + { + name: "Shell metacharacters are preserved literally inside single quotes", + input: []models.SingleEnvironmentVariable{ + {Key: "KEY1", Value: `$(rm -rf /) "quotes" \backslash`}, + }, + expected: "export KEY1='$(rm -rf /) \"quotes\" \\backslash'\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, formatAsDotEnvEval(tt.input)) + }) + } +} + +func TestPosixShellQuote(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {input: "", expected: "''"}, + {input: "plain", expected: "'plain'"}, + {input: "it's", expected: `'it'\''s'`}, + {input: "'leading", expected: `''\''leading'`}, + {input: "trailing'", expected: `'trailing'\'''`}, + {input: "a'b'c", expected: `'a'\''b'\''c'`}, + {input: "with\nnewline", expected: "'with\nnewline'"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + assert.Equal(t, tt.expected, posixShellQuote(tt.input)) + }) + } +} diff --git a/packages/cmd/gateway.go b/packages/cmd/gateway.go index 4cef610b..d06a2ed8 100644 --- a/packages/cmd/gateway.go +++ b/packages/cmd/gateway.go @@ -6,6 +6,7 @@ import ( "os" "os/exec" "os/signal" + "path/filepath" "runtime" "sync/atomic" "syscall" @@ -209,6 +210,23 @@ var gatewayStartCmd = &cobra.Command{ Example: "infisical gateway start my-gateway --token=", DisableFlagsInUseLine: true, Args: cobra.MaximumNArgs(1), + PreRunE: func(cmd *cobra.Command, args []string) error { + pkcs11ModulePath, _ := util.GetCmdFlagOrEnv(cmd, "pkcs11-module", []string{gatewayv2.INFISICAL_PKCS11_MODULE_ENV_NAME}) + if pkcs11ModulePath == "" { + return nil + } + if !filepath.IsAbs(pkcs11ModulePath) { + return fmt.Errorf("--pkcs11-module must be an absolute path (got %q)", pkcs11ModulePath) + } + info, err := os.Stat(pkcs11ModulePath) + if err != nil { + return fmt.Errorf("PKCS#11 driver not found at %q: %w", pkcs11ModulePath, err) + } + if info.IsDir() { + return fmt.Errorf("--pkcs11-module path is a directory, expected a driver file: %q", pkcs11ModulePath) + } + return gatewayv2.MaybeExecPkcs11Launcher(pkcs11ModulePath, os.Args) + }, Run: func(cmd *cobra.Command, args []string) { enrollMethod, _ := cmd.Flags().GetString("enroll-method") // Fall back to env var for systemd-managed runs where flags aren't set. @@ -401,11 +419,14 @@ var gatewayStartCmd = &cobra.Command{ } } + pkcs11ModulePath, _ := util.GetCmdFlagOrEnv(cmd, "pkcs11-module", []string{gatewayv2.INFISICAL_PKCS11_MODULE_ENV_NAME}) + gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ - Name: gatewayName, - RelayName: relayName, - ReconnectDelay: 10 * time.Second, - UseV3Connect: runningWithStoredToken, + Name: gatewayName, + RelayName: relayName, + ReconnectDelay: 10 * time.Second, + UseV3Connect: runningWithStoredToken, + Pkcs11ModulePath: pkcs11ModulePath, }) if err != nil { @@ -591,6 +612,11 @@ var gatewaySystemdInstallCmd = &cobra.Command{ enrollMethod, _ := cmd.Flags().GetString("enroll-method") + pkcs11ModulePath, _ := cmd.Flags().GetString("pkcs11-module") + if pkcs11ModulePath != "" && !filepath.IsAbs(pkcs11ModulePath) { + util.HandleError(fmt.Errorf("--pkcs11-module must be an absolute path (got %q)", pkcs11ModulePath)) + } + var installedServiceName string if enrollMethod == gatewayv2.EnrollMethodToken { @@ -616,7 +642,7 @@ var gatewaySystemdInstallCmd = &cobra.Command{ } // Install systemd service using the long-lived access token - svcName, installErr := gatewayv2.InstallEnrolledGatewaySystemdService(enrollResp.AccessToken, domain, gatewayName, relayName, serviceLogFile) + svcName, installErr := gatewayv2.InstallEnrolledGatewaySystemdService(enrollResp.AccessToken, domain, gatewayName, relayName, serviceLogFile, pkcs11ModulePath) if installErr != nil { util.HandleError(installErr, "Unable to install systemd service") } @@ -632,7 +658,7 @@ var gatewaySystemdInstallCmd = &cobra.Command{ relayName, _ := util.GetRelayName(cmd, false, "") - svcName, installErr := gatewayv2.InstallAwsAuthGatewaySystemdService(gatewayID, domain, gatewayName, relayName, serviceLogFile) + svcName, installErr := gatewayv2.InstallAwsAuthGatewaySystemdService(gatewayID, domain, gatewayName, relayName, serviceLogFile, pkcs11ModulePath) if installErr != nil { util.HandleError(installErr, "Unable to install systemd service") } @@ -653,7 +679,7 @@ var gatewaySystemdInstallCmd = &cobra.Command{ util.HandleError(relayErr, "unable to get relay name") } - svcName, installErr := gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, relayName, serviceLogFile) + svcName, installErr := gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, relayName, serviceLogFile, pkcs11ModulePath) if installErr != nil { util.HandleError(installErr, "Unable to install systemd service") } @@ -759,6 +785,7 @@ func init() { gatewayStartCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") gatewayStartCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") gatewayStartCmd.Flags().String("pam-session-recording-path", "", "directory path for PAM session recordings (defaults to /var/lib/infisical/session_recordings)") + gatewayStartCmd.Flags().String("pkcs11-module", "", "absolute path to a PKCS#11 driver (e.g. /opt/fortanix/pkcs11/fortanix_pkcs11.so). When set, the gateway loads the driver and serves HSM operations through it.") // Legacy install command flags (v1) gatewayInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") @@ -774,6 +801,7 @@ func init() { gatewaySystemdInstallCmd.Flags().String("relay", "", "The name of the relay (deprecated, use --target-relay-name)") // Deprecated, use --target-relay-name instead gatewaySystemdInstallCmd.Flags().String("target-relay-name", "", "The name of the relay") gatewaySystemdInstallCmd.Flags().String("log-file", "", "The file to write the service logs to. Example: /var/log/infisical/gateway.log. If not provided, logs will not be written to a file.") + gatewaySystemdInstallCmd.Flags().String("pkcs11-module", "", "absolute path to a PKCS#11 driver (e.g. /opt/fortanix/pkcs11/fortanix_pkcs11.so). When set, the systemd service starts the gateway with the PKCS#11 driver loaded for HSM operations.") // Gateway relay command flags gatewayRelayCmd.Flags().String("config", "", "Relay config yaml file path") diff --git a/packages/cmd/root.go b/packages/cmd/root.go index de1b0b82..e690ed6c 100644 --- a/packages/cmd/root.go +++ b/packages/cmd/root.go @@ -132,9 +132,9 @@ func init() { config.INFISICAL_URL = util.AppendAPIEndpoint(resolveDomain(cmd, config.INFISICAL_URL)) - // util.DisplayAptInstallationChangeBannerWithWriter(silent, cmd.ErrOrStderr()) if !util.IsRunningInDocker() && !silent && !isStructuredOutputRequested(cmd) { util.CheckForUpdateWithWriter(cmd.ErrOrStderr()) + util.DisplayPackageRepoMigrationNoticeWithWriter(silent, cmd.ErrOrStderr()) } loggedInDetails, err := util.GetCurrentLoggedInUserDetails(false) diff --git a/packages/gateway-v2/constants.go b/packages/gateway-v2/constants.go index ce92df1b..be470823 100644 --- a/packages/gateway-v2/constants.go +++ b/packages/gateway-v2/constants.go @@ -6,6 +6,7 @@ const ( KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" INFISICAL_PAM_SESSION_RECORDING_PATH_ENV_NAME = "INFISICAL_PAM_SESSION_RECORDING_PATH" + INFISICAL_PKCS11_MODULE_ENV_NAME = "INFISICAL_PKCS11_MODULE" RELAY_NAME_ENV_NAME = "INFISICAL_RELAY_NAME" RELAY_HOST_ENV_NAME = "INFISICAL_RELAY_HOST" diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index de87d37b..adf66f2b 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -40,6 +40,7 @@ const ( ForwardModePAMCapabilities ForwardMode = "PAM_CAPABILITIES" ForwardModePing ForwardMode = "PING" ForwardModeHealth ForwardMode = "HEALTH" + ForwardModePkcs11 ForwardMode = "PKCS11" ) type ActorType string @@ -82,12 +83,13 @@ type ActorDetails struct { } type GatewayConfig struct { - Name string - RelayName string - IdentityToken string - SSHPort int - ReconnectDelay time.Duration - UseV3Connect bool // Use V3 /connect endpoint instead of V2 /gateways for cert refresh + Name string + RelayName string + IdentityToken string + SSHPort int + ReconnectDelay time.Duration + UseV3Connect bool // Use V3 /connect endpoint instead of V2 /gateways for cert refresh + Pkcs11ModulePath string } type pamSessionEntry struct { @@ -132,6 +134,7 @@ type Gateway struct { // MongoDB proxy registry: one topology per session, shared across connections mongoProxies map[string]*mongoProxyEntry mongoProxiesMu sync.Mutex + pkcs11Module Pkcs11Module } // mongoProxyEntry holds a session-level MongoDB proxy with a ready signal. @@ -160,6 +163,12 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { pamCredentialsManager := session.NewCredentialsManager(httpClient) + pkcs11Module, err := setupPkcs11ModuleForConfig(config.Pkcs11ModulePath) + if err != nil { + cancel() + return nil, fmt.Errorf("failed to load PKCS#11 module: %w", err) + } + return &Gateway{ httpClient: httpClient, config: config, @@ -169,6 +178,7 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { pamSessionUploader: session.NewSessionUploader(httpClient, pamCredentialsManager), pamSessions: make(map[string][]*pamSessionEntry), mongoProxies: make(map[string]*mongoProxyEntry), + pkcs11Module: pkcs11Module, }, nil } @@ -366,7 +376,12 @@ func (g *Gateway) reapIdleSessions() { func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { sendHeartbeat := func() error { - if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil { + capabilities := map[string]any{} + if g.pkcs11Module != nil { + capabilities[CapabilityPkcs11] = true + } + req := api.GatewayHeartbeatRequest{Capabilities: capabilities} + if err := api.CallGatewayHeartBeatV2(g.httpClient, req); err != nil { log.Warn().Msgf("Heartbeat failed: %v", err) select { case errCh <- err: @@ -502,6 +517,13 @@ func (g *Gateway) Stop() { if g.pamCredentialsManager != nil { g.pamCredentialsManager.Shutdown() } + + if g.pkcs11Module != nil { + if err := g.pkcs11Module.Finalize(); err != nil { + log.Warn().Err(err).Msg("PKCS#11 module Finalize returned an error") + } + g.pkcs11Module = nil + } } func (g *Gateway) startHeartbeatOnce(ctx context.Context, errCh chan error) { @@ -707,7 +729,7 @@ func (g *Gateway) setupTLSConfig() error { ClientCAs: clientCAPool, ClientAuth: tls.RequireAndVerifyClientCert, MinVersion: tls.VersionTLS12, - NextProtos: []string{"infisical-http-proxy", "infisical-tcp-proxy", "infisical-health", "infisical-ping", "infisical-pam-proxy", "infisical-pam-rdp-browser", "infisical-pam-session-cancellation", "infisical-pam-capabilities"}, + NextProtos: nextProtosForGateway(g.pkcs11Module != nil), } return nil @@ -921,6 +943,14 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { log.Info().Msg("Health handler completed") } return + } else if forwardConfig.Mode == ForwardModePkcs11 { + log.Info().Msg("Starting PKCS#11 handler") + if err := servePkcs11OverTLS(g.ctx, tlsConn, reader, g.pkcs11Module); err != nil { + log.Error().Err(err).Msg("PKCS#11 handler ended with error") + } else { + log.Info().Msg("PKCS#11 handler completed") + } + return } } @@ -975,6 +1005,10 @@ func (g *Gateway) parseForwardConfigFromALPN(tlsConn *tls.Conn, reader *bufio.Re config.Mode = ForwardModeHealth return config, nil + case "infisical-pkcs11": + config.Mode = ForwardModePkcs11 + return config, nil + default: return nil, fmt.Errorf("unsupported ALPN protocol: %s", negotiatedProtocol) } @@ -1137,3 +1171,20 @@ func (g *Gateway) renewCertificates() error { return nil } + +func nextProtosForGateway(pkcs11Loaded bool) []string { + base := []string{ + "infisical-http-proxy", + "infisical-tcp-proxy", + "infisical-health", + "infisical-ping", + "infisical-pam-proxy", + "infisical-pam-rdp-browser", + "infisical-pam-session-cancellation", + "infisical-pam-capabilities", + } + if pkcs11Loaded { + base = append(base, "infisical-pkcs11") + } + return base +} diff --git a/packages/gateway-v2/pkcs11.go b/packages/gateway-v2/pkcs11.go new file mode 100644 index 00000000..59077a01 --- /dev/null +++ b/packages/gateway-v2/pkcs11.go @@ -0,0 +1,53 @@ +package gatewayv2 + +type Pkcs11Module interface { + Test(slotLabel string, pin []byte) (SlotInfo, error) + + GenerateKeyPair(slotLabel string, pin []byte, keyLabel string, keyAlgorithm string) ([]byte, error) + + GetPublicKey(slotLabel string, pin []byte, keyLabel string) ([]byte, error) + + Sign(slotLabel string, pin []byte, keyLabel string, mechanism string, data []byte, isDigest bool) ([]byte, error) + + Finalize() error +} + +type SlotInfo struct { + Manufacturer string `json:"manufacturer"` + Model string `json:"model"` + Firmware string `json:"firmware"` +} + +type Pkcs11ErrorCode string + +const ( + Pkcs11ErrPinIncorrect Pkcs11ErrorCode = "pin_incorrect" + Pkcs11ErrPinLocked Pkcs11ErrorCode = "pin_locked" + Pkcs11ErrSlotNotFound Pkcs11ErrorCode = "slot_not_found" + Pkcs11ErrKeyNotFound Pkcs11ErrorCode = "key_not_found" + Pkcs11ErrMechanismInvalid Pkcs11ErrorCode = "mechanism_invalid" + Pkcs11ErrDriverUnavailable Pkcs11ErrorCode = "driver_unavailable" + Pkcs11ErrLoginFailed Pkcs11ErrorCode = "login_failed" + Pkcs11ErrNotSupported Pkcs11ErrorCode = "pkcs11_not_supported" + Pkcs11ErrBadRequest Pkcs11ErrorCode = "bad_request" + Pkcs11ErrInternal Pkcs11ErrorCode = "internal" +) + +type Pkcs11Error struct { + Code Pkcs11ErrorCode + Message string +} + +func (e *Pkcs11Error) Error() string { + return string(e.Code) + ": " + e.Message +} + +// Supported keyAlgorithm values. +const ( + KeyAlgorithmRSA2048 = "RSA_2048" + KeyAlgorithmRSA4096 = "RSA_4096" + KeyAlgorithmECCP256 = "ECC_P256" + KeyAlgorithmECCP384 = "ECC_P384" +) + +const CapabilityPkcs11 = "pkcs11" diff --git a/packages/gateway-v2/pkcs11_enabled.go b/packages/gateway-v2/pkcs11_enabled.go new file mode 100644 index 00000000..5808a3ea --- /dev/null +++ b/packages/gateway-v2/pkcs11_enabled.go @@ -0,0 +1,452 @@ +//go:build pkcs11 + +package gatewayv2 + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "encoding/asn1" + "encoding/binary" + "fmt" + "math/big" + "strings" + "sync" + + "github.com/miekg/pkcs11" + "github.com/rs/zerolog/log" +) + +var ( + ecParamsP256 = []byte{0x06, 0x08, 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07} + ecParamsP384 = []byte{0x06, 0x05, 0x2B, 0x81, 0x04, 0x00, 0x22} +) + +type pkcs11ModuleImpl struct { + mu sync.Mutex + ctx *pkcs11.Ctx +} + +func LoadPkcs11Module(path string) (Pkcs11Module, error) { + if strings.TrimSpace(path) == "" { + return nil, &Pkcs11Error{ + Code: Pkcs11ErrDriverUnavailable, + Message: "Empty --pkcs11-module path", + } + } + ctx := pkcs11.New(path) + if ctx == nil { + return nil, &Pkcs11Error{ + Code: Pkcs11ErrDriverUnavailable, + Message: fmt.Sprintf("Failed to dlopen PKCS#11 driver at %q", path), + } + } + if err := ctx.Initialize(); err != nil { + if e, ok := err.(pkcs11.Error); !ok || e != pkcs11.CKR_CRYPTOKI_ALREADY_INITIALIZED { + ctx.Destroy() + return nil, &Pkcs11Error{ + Code: Pkcs11ErrDriverUnavailable, + Message: fmt.Sprintf("PKCS#11 C_Initialize failed: %v", err), + } + } + } + if _, err := ctx.GetSlotList(true); err != nil { + _ = ctx.Finalize() + ctx.Destroy() + return nil, &Pkcs11Error{ + Code: Pkcs11ErrDriverUnavailable, + Message: fmt.Sprintf("PKCS#11 C_GetSlotList failed: %v", err), + } + } + return &pkcs11ModuleImpl{ctx: ctx}, nil +} + +func (m *pkcs11ModuleImpl) Finalize() error { + m.mu.Lock() + defer m.mu.Unlock() + if m.ctx == nil { + return nil + } + err := m.ctx.Finalize() + m.ctx.Destroy() + m.ctx = nil + return err +} + +type sessionFn func(slot uint, sh pkcs11.SessionHandle) error + +func (m *pkcs11ModuleImpl) withSession(slotLabel string, pin []byte, fn sessionFn) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.ctx == nil { + return &Pkcs11Error{Code: Pkcs11ErrDriverUnavailable, Message: "Module is not loaded"} + } + slots, err := m.ctx.GetSlotList(true) + if err != nil { + return &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "GetSlotList failed"} + } + slot, ok := findSlotByLabel(m.ctx, slots, slotLabel) + if !ok { + return &Pkcs11Error{Code: Pkcs11ErrSlotNotFound, Message: fmt.Sprintf("Slot %q not found on this HSM", slotLabel)} + } + session, err := m.ctx.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + if err != nil { + return &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "OpenSession failed"} + } + defer func() { + if closeErr := m.ctx.CloseSession(session); closeErr != nil { + log.Warn().Err(closeErr).Msg("pkcs11: CloseSession failed") + } + }() + + loggedIn := false + loginErr := m.ctx.Login(session, pkcs11.CKU_USER, string(pin)) + if loginErr != nil { + if e, ok := loginErr.(pkcs11.Error); !ok || e != pkcs11.CKR_USER_ALREADY_LOGGED_IN { + return mapPkcs11LoginError(loginErr) + } + } else { + loggedIn = true + } + if loggedIn { + defer func() { + if logoutErr := m.ctx.Logout(session); logoutErr != nil { + log.Warn().Err(logoutErr).Msg("pkcs11: Logout failed") + } + }() + } + + return fn(slot, session) +} + +func findSlotByLabel(ctx *pkcs11.Ctx, slots []uint, label string) (uint, bool) { + for _, slot := range slots { + ti, err := ctx.GetTokenInfo(slot) + if err != nil { + continue + } + if strings.TrimRight(ti.Label, " \x00") == label { + return slot, true + } + } + return 0, false +} + +func mapPkcs11LoginError(err error) error { + if e, ok := err.(pkcs11.Error); ok { + switch e { + case pkcs11.CKR_PIN_INCORRECT: + return &Pkcs11Error{Code: Pkcs11ErrPinIncorrect, Message: "The HSM rejected the PIN"} + case pkcs11.CKR_PIN_LOCKED: + return &Pkcs11Error{Code: Pkcs11ErrPinLocked, Message: "The HSM has locked the slot"} + case pkcs11.CKR_TOKEN_NOT_PRESENT, pkcs11.CKR_DEVICE_REMOVED, pkcs11.CKR_DEVICE_ERROR: + return &Pkcs11Error{Code: Pkcs11ErrDriverUnavailable, Message: "Driver unavailable"} + } + } + return &Pkcs11Error{Code: Pkcs11ErrLoginFailed, Message: "The HSM rejected the login"} +} + +func (m *pkcs11ModuleImpl) Test(slotLabel string, pin []byte) (SlotInfo, error) { + var info SlotInfo + err := m.withSession(slotLabel, pin, func(slot uint, _ pkcs11.SessionHandle) error { + ti, err := m.ctx.GetTokenInfo(slot) + if err != nil { + return &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "GetTokenInfo failed"} + } + info = SlotInfo{ + Manufacturer: strings.TrimRight(ti.ManufacturerID, " \x00"), + Model: strings.TrimRight(ti.Model, " \x00"), + Firmware: fmt.Sprintf("%d.%d", ti.FirmwareVersion.Major, ti.FirmwareVersion.Minor), + } + return nil + }) + return info, err +} + +func (m *pkcs11ModuleImpl) GenerateKeyPair(slotLabel string, pin []byte, keyLabel, keyAlgorithm string) ([]byte, error) { + var spkiDer []byte + err := m.withSession(slotLabel, pin, func(_ uint, session pkcs11.SessionHandle) error { + mech, pubTpl, privTpl, err := generateKeyPairTemplates(keyLabel, keyAlgorithm) + if err != nil { + return err + } + pubHandle, _, err := m.ctx.GenerateKeyPair(session, mech, pubTpl, privTpl) + if err != nil { + return &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "GenerateKeyPair failed"} + } + der, err := buildSpkiFromHandle(m.ctx, session, pubHandle, keyAlgorithm) + if err != nil { + return err + } + spkiDer = der + return nil + }) + return spkiDer, err +} + +func generateKeyPairTemplates(keyLabel, keyAlgorithm string) ([]*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute, error) { + commonPriv := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(keyLabel)), + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_PRIVATE, true), + pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), + pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), + pkcs11.NewAttribute(pkcs11.CKA_SIGN, true), + } + commonPub := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(keyLabel)), + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true), + } + switch keyAlgorithm { + case KeyAlgorithmRSA2048, KeyAlgorithmRSA4096: + modulusBits := 2048 + if keyAlgorithm == KeyAlgorithmRSA4096 { + modulusBits = 4096 + } + pubTpl := append(commonPub, + pkcs11.NewAttribute(pkcs11.CKA_MODULUS_BITS, modulusBits), + pkcs11.NewAttribute(pkcs11.CKA_PUBLIC_EXPONENT, []byte{0x01, 0x00, 0x01}), + ) + return []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_KEY_PAIR_GEN, nil)}, pubTpl, commonPriv, nil + case KeyAlgorithmECCP256: + pubTpl := append(commonPub, pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, ecParamsP256)) + return []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_EC_KEY_PAIR_GEN, nil)}, pubTpl, commonPriv, nil + case KeyAlgorithmECCP384: + pubTpl := append(commonPub, pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, ecParamsP384)) + return []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_EC_KEY_PAIR_GEN, nil)}, pubTpl, commonPriv, nil + default: + return nil, nil, nil, &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: fmt.Sprintf("Unsupported keyAlgorithm %q", keyAlgorithm)} + } +} + +func buildSpkiFromHandle(ctx *pkcs11.Ctx, session pkcs11.SessionHandle, pubHandle pkcs11.ObjectHandle, keyAlgorithm string) ([]byte, error) { + switch keyAlgorithm { + case KeyAlgorithmRSA2048, KeyAlgorithmRSA4096: + attrs, err := ctx.GetAttributeValue(session, pubHandle, []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_MODULUS, nil), + pkcs11.NewAttribute(pkcs11.CKA_PUBLIC_EXPONENT, nil), + }) + if err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "RSA GetAttributeValue failed"} + } + var modulus, exp []byte + for _, a := range attrs { + switch a.Type { + case pkcs11.CKA_MODULUS: + modulus = a.Value + case pkcs11.CKA_PUBLIC_EXPONENT: + exp = a.Value + } + } + pub := &rsa.PublicKey{ + N: new(big.Int).SetBytes(modulus), + E: int(new(big.Int).SetBytes(exp).Int64()), + } + der, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "MarshalPKIXPublicKey failed"} + } + return der, nil + + case KeyAlgorithmECCP256, KeyAlgorithmECCP384: + attrs, err := ctx.GetAttributeValue(session, pubHandle, []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, nil), + }) + if err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "EC GetAttributeValue failed"} + } + if len(attrs) == 0 { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "CKA_EC_POINT missing from response"} + } + // CKA_EC_POINT is DER OCTET STRING wrapping the raw point. + var raw []byte + if _, err := asn1.Unmarshal(attrs[0].Value, &raw); err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Unmarshal CKA_EC_POINT failed"} + } + var curve elliptic.Curve + if keyAlgorithm == KeyAlgorithmECCP256 { + curve = elliptic.P256() + } else { + curve = elliptic.P384() + } + // Parse the uncompressed point format (RFC 5480 Section 2.2): 0x04 || X || Y + // with each coordinate padded to (BitSize + 7) / 8 bytes. Stdlib's + // elliptic.Unmarshal is deprecated and there is no ECDSA-specific replacement, + // so do the parse inline. + byteLen := (curve.Params().BitSize + 7) / 8 + if len(raw) != 1+2*byteLen || raw[0] != 0x04 { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Failed to unmarshal EC point"} + } + x := new(big.Int).SetBytes(raw[1 : 1+byteLen]) + y := new(big.Int).SetBytes(raw[1+byteLen:]) + pub := &ecdsa.PublicKey{Curve: curve, X: x, Y: y} + der, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "MarshalPKIXPublicKey failed"} + } + return der, nil + } + return nil, &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: "Unsupported keyAlgorithm for SPKI build"} +} + +func (m *pkcs11ModuleImpl) GetPublicKey(slotLabel string, pin []byte, keyLabel string) ([]byte, error) { + var spkiDer []byte + err := m.withSession(slotLabel, pin, func(_ uint, session pkcs11.SessionHandle) error { + handle, found, err := findObject(m.ctx, session, keyLabel, pkcs11.CKO_PUBLIC_KEY) + if err != nil { + return err + } + if !found { + return &Pkcs11Error{Code: Pkcs11ErrKeyNotFound, Message: fmt.Sprintf("Public key with label %q not found", keyLabel)} + } + alg, err := detectKeyAlgorithm(m.ctx, session, handle) + if err != nil { + return err + } + der, err := buildSpkiFromHandle(m.ctx, session, handle, alg) + if err != nil { + return err + } + spkiDer = der + return nil + }) + return spkiDer, err +} + +func detectKeyAlgorithm(ctx *pkcs11.Ctx, session pkcs11.SessionHandle, handle pkcs11.ObjectHandle) (string, error) { + attrs, err := ctx.GetAttributeValue(session, handle, []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil), + }) + if err != nil || len(attrs) == 0 { + return "", &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Failed to read CKA_KEY_TYPE"} + } + raw := make([]byte, 8) + copy(raw, attrs[0].Value) + keyType := uint(binary.LittleEndian.Uint64(raw)) + switch keyType { + case pkcs11.CKK_RSA: + modAttrs, err := ctx.GetAttributeValue(session, handle, []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_MODULUS, nil)}) + if err != nil || len(modAttrs) == 0 { + return "", &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Failed to read CKA_MODULUS"} + } + switch len(modAttrs[0].Value) { + case 256: + return KeyAlgorithmRSA2048, nil + case 512: + return KeyAlgorithmRSA4096, nil + } + return "", &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: fmt.Sprintf("Unsupported RSA modulus length: %d bits", len(modAttrs[0].Value)*8)} + case pkcs11.CKK_EC: + paramsAttrs, err := ctx.GetAttributeValue(session, handle, []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, nil)}) + if err != nil || len(paramsAttrs) == 0 { + return "", &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Failed to read CKA_EC_PARAMS"} + } + if bytes.Equal(paramsAttrs[0].Value, ecParamsP256) { + return KeyAlgorithmECCP256, nil + } + if bytes.Equal(paramsAttrs[0].Value, ecParamsP384) { + return KeyAlgorithmECCP384, nil + } + return "", &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: "Unsupported EC curve"} + } + return "", &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: fmt.Sprintf("Unsupported PKCS#11 key type: %d", keyType)} +} + +func findObject(ctx *pkcs11.Ctx, session pkcs11.SessionHandle, label string, class uint) (pkcs11.ObjectHandle, bool, error) { + tpl := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(label)), + pkcs11.NewAttribute(pkcs11.CKA_CLASS, class), + } + if err := ctx.FindObjectsInit(session, tpl); err != nil { + return 0, false, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "FindObjectsInit failed"} + } + defer func() { + if finalErr := ctx.FindObjectsFinal(session); finalErr != nil { + log.Warn().Err(finalErr).Msg("pkcs11: FindObjectsFinal failed") + } + }() + objs, _, err := ctx.FindObjects(session, 2) + if err != nil { + return 0, false, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "FindObjects failed"} + } + if len(objs) == 0 { + return 0, false, nil + } + if len(objs) > 1 { + return 0, false, &Pkcs11Error{ + Code: Pkcs11ErrBadRequest, + Message: fmt.Sprintf("Multiple objects on the HSM share label %q. Resolve the duplicate before proceeding.", label), + } + } + return objs[0], true, nil +} + +func (m *pkcs11ModuleImpl) Sign(slotLabel string, pin []byte, keyLabel, mechanism string, data []byte, isDigest bool) ([]byte, error) { + log.Debug().Str("keyLabel", keyLabel).Str("mech", mechanism).Int("dataLen", len(data)).Msg("pkcs11.Sign: enter") + var sig []byte + err := m.withSession(slotLabel, pin, func(_ uint, session pkcs11.SessionHandle) error { + mechCode, params, err := resolveMechanism(mechanism, isDigest) + if err != nil { + return err + } + handle, found, err := findObject(m.ctx, session, keyLabel, pkcs11.CKO_PRIVATE_KEY) + if err != nil { + return err + } + if !found { + return &Pkcs11Error{Code: Pkcs11ErrKeyNotFound, Message: fmt.Sprintf("Private key with label %q not found", keyLabel)} + } + if err := m.ctx.SignInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechCode, params)}, handle); err != nil { + return mapPkcs11SignError(err) + } + out, err := m.ctx.Sign(session, data) + if err != nil { + return mapPkcs11SignError(err) + } + sig = out + return nil + }) + log.Debug().Bool("ok", err == nil).Int("sigLen", len(sig)).Msg("pkcs11.Sign: done") + return sig, err +} + +func resolveMechanism(name string, isDigest bool) (uint, []byte, error) { + switch name { + case "CKM_SHA256_RSA_PKCS": + return pkcs11.CKM_SHA256_RSA_PKCS, nil, nil + case "CKM_SHA384_RSA_PKCS": + return pkcs11.CKM_SHA384_RSA_PKCS, nil, nil + case "CKM_SHA512_RSA_PKCS": + return pkcs11.CKM_SHA512_RSA_PKCS, nil, nil + case "CKM_ECDSA_SHA256": + if isDigest { + return pkcs11.CKM_ECDSA, nil, nil + } + return pkcs11.CKM_ECDSA_SHA256, nil, nil + case "CKM_ECDSA_SHA384": + if isDigest { + return pkcs11.CKM_ECDSA, nil, nil + } + return pkcs11.CKM_ECDSA_SHA384, nil, nil + case "CKM_ECDSA": + return pkcs11.CKM_ECDSA, nil, nil + } + return 0, nil, &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: fmt.Sprintf("Unsupported mechanism %q", name)} +} + +func mapPkcs11SignError(err error) error { + if e, ok := err.(pkcs11.Error); ok { + switch e { + case pkcs11.CKR_KEY_HANDLE_INVALID, pkcs11.CKR_OBJECT_HANDLE_INVALID: + return &Pkcs11Error{Code: Pkcs11ErrKeyNotFound, Message: "The HSM rejected the key handle"} + case pkcs11.CKR_MECHANISM_INVALID, pkcs11.CKR_KEY_TYPE_INCONSISTENT: + return &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: "The HSM does not support the requested signing algorithm"} + case pkcs11.CKR_TOKEN_NOT_PRESENT, pkcs11.CKR_DEVICE_REMOVED, pkcs11.CKR_DEVICE_ERROR: + return &Pkcs11Error{Code: Pkcs11ErrDriverUnavailable, Message: "Driver unavailable"} + } + } + return &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Sign operation failed"} +} diff --git a/packages/gateway-v2/pkcs11_handler.go b/packages/gateway-v2/pkcs11_handler.go new file mode 100644 index 00000000..6a0eb5ec --- /dev/null +++ b/packages/gateway-v2/pkcs11_handler.go @@ -0,0 +1,353 @@ +package gatewayv2 + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/rs/zerolog/log" +) + +type pkcs11RequestEnvelope struct { + SlotLabel string `json:"slotLabel"` + PIN []byte `json:"-"` + Params json.RawMessage `json:"params"` +} + +func (e *pkcs11RequestEnvelope) UnmarshalJSON(data []byte) error { + var raw struct { + SlotLabel string `json:"slotLabel"` + PIN string `json:"pin"` + Params json.RawMessage `json:"params"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + e.SlotLabel = raw.SlotLabel + e.PIN = []byte(raw.PIN) + e.Params = raw.Params + return nil +} + +type pkcs11Response struct { + Result json.RawMessage `json:"result"` +} + +type pkcs11ErrorResponse struct { + Error pkcs11ErrorBody `json:"error"` +} + +type pkcs11ErrorBody struct { + Code Pkcs11ErrorCode `json:"code"` + Message string `json:"message"` +} + +const pkcs11RequestDeadline = 30 * time.Second + +func servePkcs11OverTLS(ctx context.Context, conn *tls.Conn, reader *bufio.Reader, module Pkcs11Module) error { + _ = conn.SetDeadline(time.Now().Add(pkcs11RequestDeadline)) + + if module == nil { + writeErrorResponse(conn, http.StatusServiceUnavailable, Pkcs11ErrNotSupported, "PKCS#11 module not loaded") + return errors.New("PKCS#11 module is nil") + } + + reqCh := make(chan *http.Request, 1) + errCh := make(chan error, 1) + go func() { + req, err := http.ReadRequest(reader) + if err != nil { + errCh <- err + return + } + reqCh <- req + }() + + var req *http.Request + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return fmt.Errorf("failed to read HTTP request: %w", err) + case req = <-reqCh: + } + + log.Debug().Str("path", req.URL.Path).Int64("contentLength", req.ContentLength).Msg("pkcs11: request received") + + rw := newBufferedResponseWriter() + servePkcs11Mux(module).ServeHTTP(rw, req) + if err := rw.writeTo(conn); err != nil { + return fmt.Errorf("failed to write response: %w", err) + } + log.Debug().Int("status", rw.status).Msg("pkcs11: response written") + return nil +} + +type bufferedResponseWriter struct { + header http.Header + body bytes.Buffer + status int + wroteStart bool +} + +func newBufferedResponseWriter() *bufferedResponseWriter { + return &bufferedResponseWriter{header: http.Header{}, status: http.StatusOK} +} +func (b *bufferedResponseWriter) Header() http.Header { return b.header } +func (b *bufferedResponseWriter) WriteHeader(s int) { + if b.wroteStart { + return + } + b.status = s + b.wroteStart = true +} +func (b *bufferedResponseWriter) Write(p []byte) (int, error) { + if !b.wroteStart { + b.WriteHeader(http.StatusOK) + } + return b.body.Write(p) +} +func (b *bufferedResponseWriter) writeTo(conn *tls.Conn) error { + body := b.body.Bytes() + if b.header.Get("Content-Length") == "" { + b.header.Set("Content-Length", strconv.Itoa(len(body))) + } + if b.header.Get("Connection") == "" { + b.header.Set("Connection", "close") + } + var sb strings.Builder + fmt.Fprintf(&sb, "HTTP/1.1 %d %s\r\n", b.status, http.StatusText(b.status)) + for k, vs := range b.header { + for _, v := range vs { + sb.WriteString(k) + sb.WriteString(": ") + sb.WriteString(v) + sb.WriteString("\r\n") + } + } + sb.WriteString("\r\n") + if _, err := conn.Write([]byte(sb.String())); err != nil { + return err + } + if _, err := conn.Write(body); err != nil { + return err + } + return nil +} + +func servePkcs11Mux(module Pkcs11Module) *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/v1/test", wrapPkcs11(module, handleTest)) + mux.HandleFunc("/v1/generate-key-pair", wrapPkcs11(module, handleGenerateKeyPair)) + mux.HandleFunc("/v1/sign", wrapPkcs11(module, handleSign)) + mux.HandleFunc("/v1/get-public-key", wrapPkcs11(module, handleGetPublicKey)) + return mux +} + +type pkcs11Handler func(module Pkcs11Module, env *pkcs11RequestEnvelope) (any, error) + +const maxPkcs11RequestBodyBytes = 256 * 1024 + +func zeroBytes(b []byte) { + for i := range b { + b[i] = 0 + } +} + +func safeMessageForCode(code Pkcs11ErrorCode) string { + switch code { + case Pkcs11ErrPinIncorrect: + return "The HSM rejected the PIN" + case Pkcs11ErrPinLocked: + return "The HSM has locked the slot" + case Pkcs11ErrLoginFailed: + return "The HSM rejected the login" + case Pkcs11ErrSlotNotFound: + return "Slot not found on this HSM" + case Pkcs11ErrKeyNotFound: + return "Key not found on this HSM" + case Pkcs11ErrMechanismInvalid: + return "Mechanism not supported by this HSM" + case Pkcs11ErrDriverUnavailable: + return "Driver unavailable" + case Pkcs11ErrNotSupported: + return "Operation not supported" + case Pkcs11ErrBadRequest: + return "Invalid request" + } + return "Operation failed" +} + +func wrapPkcs11(module Pkcs11Module, fn pkcs11Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + defer func() { + if r.Body != nil { + _ = r.Body.Close() + } + }() + log.Debug().Str("path", r.URL.Path).Str("method", r.Method).Int64("contentLength", r.ContentLength).Msg("pkcs11: handler received request") + if r.Method != http.MethodPost { + writeErrorResponse(w, http.StatusMethodNotAllowed, Pkcs11ErrBadRequest, "Only POST is supported") + return + } + if r.ContentLength > maxPkcs11RequestBodyBytes { + log.Error().Int64("contentLength", r.ContentLength).Msg("pkcs11: request body too large") + writeErrorResponse(w, http.StatusRequestEntityTooLarge, Pkcs11ErrBadRequest, "Request body too large") + return + } + r.Body = http.MaxBytesReader(w, r.Body, maxPkcs11RequestBodyBytes) + var env pkcs11RequestEnvelope + if err := json.NewDecoder(r.Body).Decode(&env); err != nil { + log.Warn().Err(err).Msg("pkcs11: body decode failed") + writeErrorResponse(w, http.StatusBadRequest, Pkcs11ErrBadRequest, "Malformed request body") + return + } + defer zeroBytes(env.PIN) + log.Debug().Bool("hasPin", len(env.PIN) > 0).Msg("pkcs11: body decoded, dispatching to op handler") + result, err := fn(module, &env) + log.Debug().Bool("ok", err == nil).Msg("pkcs11: op handler returned") + if err != nil { + var p11Err *Pkcs11Error + if errors.As(err, &p11Err) { + log.Error().Str("code", string(p11Err.Code)).Str("errorMessage", p11Err.Message).Msg("pkcs11: op handler returned typed error") + writeErrorResponse(w, statusForCode(p11Err.Code), p11Err.Code, safeMessageForCode(p11Err.Code)) + return + } + log.Error().Err(err).Msg("pkcs11: op handler returned untyped error") + writeErrorResponse(w, http.StatusInternalServerError, Pkcs11ErrInternal, safeMessageForCode(Pkcs11ErrInternal)) + return + } + raw, err := json.Marshal(result) + if err != nil { + log.Error().Err(err).Msg("pkcs11: failed to marshal result") + writeErrorResponse(w, http.StatusInternalServerError, Pkcs11ErrInternal, "Failed to marshal result") + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(pkcs11Response{Result: raw}); err != nil { + log.Warn().Err(err).Msg("pkcs11: failed to encode response") + } + } +} + +type generateKeyPairParams struct { + KeyLabel string `json:"keyLabel"` + KeyAlgorithm string `json:"keyAlgorithm"` +} + +type signParams struct { + KeyLabel string `json:"keyLabel"` + Mechanism string `json:"mechanism"` + Data string `json:"data"` + IsDigest bool `json:"isDigest"` +} + +type singleKeyLabelParams struct { + KeyLabel string `json:"keyLabel"` +} + +const ( + maxKeyLabelLen = 256 + maxSignDataBytes = 64 * 1024 +) + +func handleTest(module Pkcs11Module, env *pkcs11RequestEnvelope) (any, error) { + info, err := module.Test(env.SlotLabel, env.PIN) + if err != nil { + return nil, err + } + return map[string]any{"slotInfo": info}, nil +} + +func handleGenerateKeyPair(module Pkcs11Module, env *pkcs11RequestEnvelope) (any, error) { + var p generateKeyPairParams + if err := json.Unmarshal(env.Params, &p); err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "Malformed params for generate-key-pair"} + } + if len(p.KeyLabel) > maxKeyLabelLen { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "keyLabel too long"} + } + spki, err := module.GenerateKeyPair(env.SlotLabel, env.PIN, p.KeyLabel, p.KeyAlgorithm) + if err != nil { + return nil, err + } + return map[string]any{"publicKey": base64.StdEncoding.EncodeToString(spki)}, nil +} + +func handleSign(module Pkcs11Module, env *pkcs11RequestEnvelope) (any, error) { + var p signParams + if err := json.Unmarshal(env.Params, &p); err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "Malformed params for sign"} + } + if len(p.KeyLabel) > maxKeyLabelLen { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "keyLabel too long"} + } + data, err := base64.StdEncoding.DecodeString(p.Data) + if err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "data is not valid base64"} + } + if len(data) == 0 { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "data is empty"} + } + if len(data) > maxSignDataBytes { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "Data too large for signing"} + } + sig, err := module.Sign(env.SlotLabel, env.PIN, p.KeyLabel, p.Mechanism, data, p.IsDigest) + if err != nil { + return nil, err + } + return map[string]any{"signature": base64.StdEncoding.EncodeToString(sig)}, nil +} + +func handleGetPublicKey(module Pkcs11Module, env *pkcs11RequestEnvelope) (any, error) { + var p singleKeyLabelParams + if err := json.Unmarshal(env.Params, &p); err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "Malformed params for get-public-key"} + } + if len(p.KeyLabel) > maxKeyLabelLen { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "keyLabel too long"} + } + spki, err := module.GetPublicKey(env.SlotLabel, env.PIN, p.KeyLabel) + if err != nil { + return nil, err + } + return map[string]any{"publicKey": base64.StdEncoding.EncodeToString(spki)}, nil +} + +func statusForCode(code Pkcs11ErrorCode) int { + switch code { + case Pkcs11ErrPinIncorrect, Pkcs11ErrPinLocked, Pkcs11ErrLoginFailed, Pkcs11ErrSlotNotFound, Pkcs11ErrKeyNotFound, Pkcs11ErrMechanismInvalid, Pkcs11ErrBadRequest: + return http.StatusBadRequest + case Pkcs11ErrDriverUnavailable, Pkcs11ErrInternal: + return http.StatusBadGateway + case Pkcs11ErrNotSupported: + return http.StatusServiceUnavailable + } + return http.StatusBadGateway +} + +func writeErrorResponse(w any, status int, code Pkcs11ErrorCode, message string) { + body, _ := json.Marshal(pkcs11ErrorResponse{Error: pkcs11ErrorBody{Code: code, Message: message}}) + switch sink := w.(type) { + case http.ResponseWriter: + sink.Header().Set("Content-Type", "application/json") + sink.WriteHeader(status) + _, _ = sink.Write(body) + case *tls.Conn: + resp := fmt.Sprintf("HTTP/1.1 %d %s\r\nContent-Type: application/json\r\nContent-Length: %d\r\nConnection: close\r\n\r\n%s", + status, http.StatusText(status), len(body), body) + _, _ = sink.Write([]byte(resp)) + default: + log.Warn().Msg("writeErrorResponse called with unsupported sink type") + } +} diff --git a/packages/gateway-v2/pkcs11_launcher.go b/packages/gateway-v2/pkcs11_launcher.go new file mode 100644 index 00000000..32c21d5f --- /dev/null +++ b/packages/gateway-v2/pkcs11_launcher.go @@ -0,0 +1,173 @@ +//go:build !pkcs11 + +package gatewayv2 + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "syscall" + "time" + + "github.com/Infisical/infisical-merge/packages/util" + "github.com/rs/zerolog/log" +) + +const ( + releaseURLBase = "https://github.com/Infisical/cli/releases/download" + envReleaseURLBaseOverride = "INFISICAL_PKCS11_RELEASE_URL_BASE" +) + +var pkcs11HTTPClient = &http.Client{Timeout: 90 * time.Second} + +func pkcs11TarballName(version, goos, goarch string) string { + return fmt.Sprintf("infisical-pkcs11_%s_%s_%s.tar.gz", version, goos, goarch) +} + +func MaybeExecPkcs11Launcher(pkcs11ModulePath string, originalArgs []string) error { + if strings.TrimSpace(pkcs11ModulePath) == "" { + return nil + } + if runtime.GOOS != "linux" { + return fmt.Errorf("--pkcs11-module is only supported on Linux (detected %s)", runtime.GOOS) + } + if util.IsDevelopmentMode() { + return fmt.Errorf("--pkcs11-module auto-download is not available in development builds") + } + binPath, err := ensurePkcs11Binary(util.CLI_VERSION, runtime.GOOS, runtime.GOARCH) + if err != nil { + return fmt.Errorf("failed to provision infisical-pkcs11: %w", err) + } + newArgv := append([]string{binPath}, originalArgs[1:]...) + return syscall.Exec(binPath, newArgv, os.Environ()) +} + +func ensurePkcs11Binary(version, goos, goarch string) (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + binDir := filepath.Join(home, ".infisical", "bin") + binPath := filepath.Join(binDir, "infisical-pkcs11") + verPath := binPath + ".version" + + if cached, err := os.ReadFile(verPath); err == nil && strings.TrimSpace(string(cached)) == version { + if _, err := os.Stat(binPath); err == nil { + return binPath, nil + } + } + + if err := os.MkdirAll(binDir, 0o755); err != nil { + return "", err + } + + base := releaseURLBase + if override := os.Getenv(envReleaseURLBaseOverride); override != "" { + base = strings.TrimRight(override, "/") + } + tarName := pkcs11TarballName(version, goos, goarch) + sumsURL := fmt.Sprintf("%s/v%s/checksums.txt", base, version) + tarURL := fmt.Sprintf("%s/v%s/%s", base, version, tarName) + + log.Info().Str("version", version).Msg("installing infisical-pkcs11 (one-time setup)") + + expectedSum, err := fetchChecksum(sumsURL, tarName) + if err != nil { + return "", err + } + tarBytes, actualSum, err := downloadAndHash(tarURL) + if err != nil { + return "", err + } + if !strings.EqualFold(actualSum, expectedSum) { + return "", fmt.Errorf("checksum mismatch for %s", tarName) + } + if err := extractPkcs11FromTarball(tarBytes, binPath); err != nil { + return "", err + } + if err := os.WriteFile(verPath, []byte(version), 0o644); err != nil { + return "", err + } + log.Info().Str("path", binPath).Msg("infisical-pkcs11 installed") + return binPath, nil +} + +func downloadAndHash(url string) ([]byte, string, error) { + resp, err := pkcs11HTTPClient.Get(url) + if err != nil { + return nil, "", err + } + defer resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + return nil, "", fmt.Errorf("download %s: HTTP %d", url, resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, "", err + } + sum := sha256.Sum256(body) + return body, hex.EncodeToString(sum[:]), nil +} + +func fetchChecksum(url, filename string) (string, error) { + resp, err := pkcs11HTTPClient.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("fetch %s: HTTP %d", url, resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + for _, line := range strings.Split(string(body), "\n") { + fields := strings.Fields(strings.TrimSpace(line)) + if len(fields) >= 2 && fields[1] == filename { + return strings.ToLower(fields[0]), nil + } + } + return "", fmt.Errorf("checksum for %s not found in %s", filename, url) +} + +func extractPkcs11FromTarball(tarGz []byte, outPath string) error { + gz, err := gzip.NewReader(bytes.NewReader(tarGz)) + if err != nil { + return err + } + defer gz.Close() //nolint:errcheck + tr := tar.NewReader(gz) + for { + hdr, err := tr.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return err + } + if hdr.Typeflag != tar.TypeReg || filepath.Base(hdr.Name) != "infisical-pkcs11" { + continue + } + out, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) + if err != nil { + return err + } + if _, err := io.Copy(out, tr); err != nil { + out.Close() //nolint:errcheck + return err + } + return out.Close() + } + return fmt.Errorf("infisical-pkcs11 binary not found in tarball") +} diff --git a/packages/gateway-v2/pkcs11_launcher_pkcs11.go b/packages/gateway-v2/pkcs11_launcher_pkcs11.go new file mode 100644 index 00000000..2314fa68 --- /dev/null +++ b/packages/gateway-v2/pkcs11_launcher_pkcs11.go @@ -0,0 +1,7 @@ +//go:build pkcs11 + +package gatewayv2 + +func MaybeExecPkcs11Launcher(_ string, _ []string) error { + return nil +} diff --git a/packages/gateway-v2/pkcs11_setup.go b/packages/gateway-v2/pkcs11_setup.go new file mode 100644 index 00000000..798dfade --- /dev/null +++ b/packages/gateway-v2/pkcs11_setup.go @@ -0,0 +1,7 @@ +//go:build !pkcs11 + +package gatewayv2 + +func setupPkcs11ModuleForConfig(_ string) (Pkcs11Module, error) { + return nil, nil +} diff --git a/packages/gateway-v2/pkcs11_setup_pkcs11.go b/packages/gateway-v2/pkcs11_setup_pkcs11.go new file mode 100644 index 00000000..59ac43fd --- /dev/null +++ b/packages/gateway-v2/pkcs11_setup_pkcs11.go @@ -0,0 +1,17 @@ +//go:build pkcs11 + +package gatewayv2 + +import "github.com/rs/zerolog/log" + +func setupPkcs11ModuleForConfig(path string) (Pkcs11Module, error) { + if path == "" { + return nil, nil + } + mod, err := LoadPkcs11Module(path) + if err != nil { + return nil, err + } + log.Info().Str("path", path).Msg("PKCS#11 module loaded") + return mod, nil +} diff --git a/packages/gateway-v2/systemd.go b/packages/gateway-v2/systemd.go index 08128bf6..e6166c35 100644 --- a/packages/gateway-v2/systemd.go +++ b/packages/gateway-v2/systemd.go @@ -76,7 +76,7 @@ func resolveInstallPaths(name string) (installResult, error) { }, nil } -func InstallGatewaySystemdService(token string, domain string, name string, relayName string, serviceLogFile string) (string, error) { +func InstallGatewaySystemdService(token string, domain string, name string, relayName string, serviceLogFile string, pkcs11ModulePath string) (string, error) { if runtime.GOOS != "linux" { log.Info().Msg("Skipping systemd service installation - not on Linux") return "", nil @@ -107,6 +107,9 @@ func InstallGatewaySystemdService(token string, domain string, name string, rela if relayName != "" { configContent += fmt.Sprintf("%s=%s\n", RELAY_NAME_ENV_NAME, relayName) } + if pkcs11ModulePath != "" { + configContent += fmt.Sprintf("%s=%s\n", INFISICAL_PKCS11_MODULE_ENV_NAME, pkcs11ModulePath) + } if err := os.WriteFile(paths.configPath, []byte(configContent), 0600); err != nil { return "", fmt.Errorf("failed to write environment file: %v", err) @@ -139,7 +142,7 @@ func InstallGatewaySystemdService(token string, domain string, name string, rela // InstallEnrolledGatewaySystemdService installs the systemd service for a gateway that was // enrolled via the enrollment token flow. It writes the long-lived gateway access token // (not a machine identity token) into the environment file. -func InstallEnrolledGatewaySystemdService(accessToken string, domain string, name string, relayName string, serviceLogFile string) (string, error) { +func InstallEnrolledGatewaySystemdService(accessToken string, domain string, name string, relayName string, serviceLogFile string, pkcs11ModulePath string) (string, error) { if runtime.GOOS != "linux" { log.Info().Msg("Skipping systemd service installation - not on Linux") return "", nil @@ -169,6 +172,9 @@ func InstallEnrolledGatewaySystemdService(accessToken string, domain string, nam if relayName != "" { configContent += fmt.Sprintf("%s=%s\n", RELAY_NAME_ENV_NAME, relayName) } + if pkcs11ModulePath != "" { + configContent += fmt.Sprintf("%s=%s\n", INFISICAL_PKCS11_MODULE_ENV_NAME, pkcs11ModulePath) + } if err := os.WriteFile(paths.configPath, []byte(configContent), 0600); err != nil { return "", fmt.Errorf("failed to write environment file: %v", err) @@ -203,7 +209,7 @@ func InstallEnrolledGatewaySystemdService(accessToken string, domain string, nam // fresh STS-signed login on each service start using whatever AWS credentials it can resolve // (instance role, env vars, shared profile). We just persist the gateway id, domain, and name // so `gateway start` can re-authenticate. -func InstallAwsAuthGatewaySystemdService(gatewayID string, domain string, name string, relayName string, serviceLogFile string) (string, error) { +func InstallAwsAuthGatewaySystemdService(gatewayID string, domain string, name string, relayName string, serviceLogFile string, pkcs11ModulePath string) (string, error) { if runtime.GOOS != "linux" { log.Info().Msg("Skipping systemd service installation - not on Linux") return "", nil @@ -234,6 +240,9 @@ func InstallAwsAuthGatewaySystemdService(gatewayID string, domain string, name s if relayName != "" { configContent += fmt.Sprintf("%s=%s\n", RELAY_NAME_ENV_NAME, relayName) } + if pkcs11ModulePath != "" { + configContent += fmt.Sprintf("%s=%s\n", INFISICAL_PKCS11_MODULE_ENV_NAME, pkcs11ModulePath) + } if err := os.WriteFile(paths.configPath, []byte(configContent), 0600); err != nil { return "", fmt.Errorf("failed to write environment file: %v", err) diff --git a/packages/pam/local/access.go b/packages/pam/local/access.go index 10be3ef0..41284164 100644 --- a/packages/pam/local/access.go +++ b/packages/pam/local/access.go @@ -92,7 +92,7 @@ func StartPAMAccess(accessToken, path, reason, durationStr string, port int) { case AccountTypeAwsIam: util.PrintErrorMessageAndExit("AWS IAM access not yet supported in the new PAM model") case AccountTypeWindows: - util.PrintErrorMessageAndExit("Windows/RDP access not yet supported in the new PAM model") + startRDPProxy(httpClient, &pamResponse, displayPath, durationStr, port) case AccountTypeActiveDirectory: util.PrintErrorMessageAndExit("Active Directory access not yet supported in the new PAM model") default: @@ -249,6 +249,99 @@ func startDatabaseProxy(httpClient *resty.Client, response *api.PAMAccessRespons proxy.Run() } +func startRDPProxy(httpClient *resty.Client, response *api.PAMAccessResponse, path, durationStr string, port int) { + duration, err := time.ParseDuration(durationStr) + if err != nil { + util.HandleError(err, "Failed to parse duration") + return + } + + username, ok := response.Metadata["username"] + if !ok { + util.HandleError(fmt.Errorf("PAM response metadata is missing 'username'"), "Failed to start RDP proxy") + return + } + + ctx, cancel := context.WithCancel(context.Background()) + + proxy := &RDPProxyServer{ + BaseProxyServer: BaseProxyServer{ + httpClient: httpClient, + relayHost: response.RelayHost, + relayClientCert: response.RelayClientCertificate, + relayClientKey: response.RelayClientPrivateKey, + relayServerCertChain: response.RelayServerCertificateChain, + gatewayClientCert: response.GatewayClientCertificate, + gatewayClientKey: response.GatewayClientPrivateKey, + gatewayServerCertChain: response.GatewayServerCertificateChain, + sessionExpiry: time.Now().Add(duration), + sessionId: response.SessionId, + resourceType: response.AccountType, + ctx: ctx, + cancel: cancel, + shutdownCh: make(chan struct{}), + }, + } + + if err := proxy.ValidateResourceTypeSupported(); err != nil { + util.HandleError(err, "Gateway version outdated") + return + } + + if err := proxy.Start(port); err != nil { + util.HandleError(err, "Failed to start RDP proxy server") + return + } + + rdpFilePath, err := writeRDPFile(proxy.port, response.SessionId, username) + if err != nil { + log.Warn().Err(err).Msg("Failed to write .rdp file; proxy still running") + } else { + proxy.rdpFilePath = rdpFilePath + } + + folder, account := parsePath(path) + + log.Info().Msgf("RDP proxy server listening on port %d", proxy.port) + util.PrintfStderr("\n") + util.PrintfStderr("**********************************************************************\n") + util.PrintfStderr(" RDP Proxy Session Started! \n") + util.PrintfStderr("**********************************************************************\n") + util.PrintfStderr("\n") + if folder != "" { + util.PrintfStderr(" Folder: %s\n", folder) + } + util.PrintfStderr(" Account: %s\n", account) + util.PrintfStderr(" Duration: %s\n", duration.String()) + util.PrintfStderr("\n") + util.PrintfStderr("----------------------------------------------------------------------\n") + util.PrintfStderr(" Connection Details \n") + util.PrintfStderr("----------------------------------------------------------------------\n") + util.PrintfStderr("\n") + util.PrintfStderr(" Host: 127.0.0.1\n") + util.PrintfStderr(" Port: %d\n", proxy.port) + util.PrintfStderr(" Username: %s\n", username) + util.PrintfStderr(" Password: (leave blank)\n") + if proxy.rdpFilePath != "" { + util.PrintfStderr("\n") + util.PrintfStderr(" .rdp file: %s\n", proxy.rdpFilePath) + } + util.PrintfStderr("\n") + util.PrintfStderr(" Press Ctrl+C to terminate the session.\n") + util.PrintfStderr("**********************************************************************\n") + util.PrintfStderr("\n") + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-sigChan + log.Info().Msgf("Received signal %v, initiating graceful shutdown...", sig) + proxy.gracefulShutdown() + }() + + proxy.Run() +} + func startSSHAccess(httpClient *resty.Client, response *api.PAMAccessResponse, path, durationStr string, port int) { duration, err := time.ParseDuration(durationStr) if err != nil { diff --git a/packages/pam/local/base-proxy.go b/packages/pam/local/base-proxy.go index 0cc52108..394ec20b 100644 --- a/packages/pam/local/base-proxy.go +++ b/packages/pam/local/base-proxy.go @@ -297,9 +297,12 @@ func (b *BaseProxyServer) WaitForDisconnect(gatewayErrCh, clientErrCh <-chan err case <-gatewayErrCh: b.HandleGatewayDisconnect() case <-clientErrCh: - // Normal client disconnect, proxy stays running case <-connCtx.Done(): - log.Info().Msg("Connection cancelled by context") + select { + case <-gatewayErrCh: + b.HandleGatewayDisconnect() + default: + } } } diff --git a/packages/pam/local/rdp-proxy.go b/packages/pam/local/rdp-proxy.go index de021915..c8067a2c 100644 --- a/packages/pam/local/rdp-proxy.go +++ b/packages/pam/local/rdp-proxy.go @@ -266,7 +266,7 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) { connCtx, connCancel := context.WithCancel(p.ctx) defer connCancel() - done := make(chan struct{}, 2) + gatewayErrCh, clientErrCh := p.NewDisconnectChannels() go func() { defer connCancel() @@ -278,7 +278,7 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) { log.Debug().Err(err).Msg("Gateway to client copy ended") } } - done <- struct{}{} + gatewayErrCh <- err }() go func() { @@ -291,14 +291,10 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) { log.Debug().Err(err).Msg("Client to gateway copy ended") } } - done <- struct{}{} + clientErrCh <- err }() - select { - case <-done: - case <-connCtx.Done(): - log.Info().Msg("Connection cancelled by context") - } + p.WaitForDisconnect(gatewayErrCh, clientErrCh, connCtx) log.Info().Msgf("RDP connection closed for client: %s", clientConn.RemoteAddr().String()) } diff --git a/packages/util/check-for-update.go b/packages/util/check-for-update.go index d59b9528..dc90713b 100644 --- a/packages/util/check-for-update.go +++ b/packages/util/check-for-update.go @@ -209,27 +209,87 @@ func writeUpdateCheckCache(cache *UpdateCheckCache) error { return nil } -func DisplayAptInstallationChangeBanner(isSilent bool) { - DisplayAptInstallationChangeBannerWithWriter(isSilent, os.Stderr) +const migrationNoticeCacheTTL = 24 * time.Hour + +const migrationGuideURL = "https://infisical.com/docs/cli/cloudsmith-migration" + +const migrationCloudsmithSunset = "September 16, 2026" + +type migrationNoticeCache struct { + LastShownTime time.Time `json:"lastShownTime"` } -func DisplayAptInstallationChangeBannerWithWriter(isSilent bool, w io.Writer) { - if isSilent { +func getMigrationNoticeCachePath() (string, error) { + homeDir, err := GetHomeDir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, CONFIG_FOLDER_NAME, MIGRATION_NOTICE_CACHE_FILE_NAME), nil +} + +// migrationNoticeRecentlyShown returns true if the notice was shown within the TTL. +func migrationNoticeRecentlyShown() bool { + path, err := getMigrationNoticeCachePath() + if err != nil { + return false + } + data, err := os.ReadFile(path) + if err != nil { + return false + } + var cache migrationNoticeCache + if err := json.Unmarshal(data, &cache); err != nil { + return false + } + return time.Since(cache.LastShownTime) < migrationNoticeCacheTTL +} + +// recordMigrationNoticeShown stamps the cache so the notice is throttled. +func recordMigrationNoticeShown() { + path, err := getMigrationNoticeCachePath() + if err != nil { + return + } + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { return } + data, err := json.Marshal(migrationNoticeCache{LastShownTime: time.Now()}) + if err != nil { + return + } + // Best-effort write; a failure just means the notice may show again next run. + _ = os.WriteFile(path, data, 0600) +} - if runtime.GOOS == "linux" { - _, err := exec.LookPath("apt-get") - isApt := err == nil - if isApt { - yellow := color.New(color.FgYellow).SprintFunc() - msg := fmt.Sprintf("%s", - yellow("Update Required: Your current package installation script is outdated and will no longer receive updates.\nPlease update to the new installation script which can be found here https://infisical.com/docs/cli/overview#installation debian section\n"), - ) +func DisplayPackageRepoMigrationNotice(isSilent bool) { + DisplayPackageRepoMigrationNoticeWithWriter(isSilent, os.Stderr) +} - fmt.Fprintln(w, msg) - } +// DisplayPackageRepoMigrationNoticeWithWriter prints a one-time-per-day notice +// that the Linux package repository has moved off Cloudsmith. +// Stays quiet in --silent mode, and can be disabled with INFISICAL_DISABLE_MIGRATION_NOTICE. +func DisplayPackageRepoMigrationNoticeWithWriter(isSilent bool, w io.Writer) { + if isSilent { + return + } + if os.Getenv("INFISICAL_DISABLE_MIGRATION_NOTICE") != "" { + return } + if migrationNoticeRecentlyShown() { + return + } + + yellow := color.New(color.FgYellow).SprintFunc() + bold := color.New(color.FgYellow, color.Bold).SprintFunc() + fmt.Fprintln(w, bold("Important: the Infisical CLI Linux package repository is moving off Cloudsmith.")) + fmt.Fprintln(w, yellow( + "What's happening: Cloudsmith stops serving on "+migrationCloudsmithSunset+". After that, installing or\n"+ + "updating the CLI on Linux from the old Cloudsmith URL (apt, yum/dnf, apk) will fail.\n"+ + "What to do: repoint your machine to the new host (artifacts-cli.infisical.com).\n"+ + "Migration steps: "+migrationGuideURL+"\n", + )) + + recordMigrationNoticeShown() } func getLatestTag(repoOwner string, repoName string) (string, time.Time, bool, error) { diff --git a/packages/util/constants.go b/packages/util/constants.go index 22bdf606..50300e85 100644 --- a/packages/util/constants.go +++ b/packages/util/constants.go @@ -69,7 +69,8 @@ const ( KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" - UPDATE_CHECK_CACHE_FILE_NAME = "update-check.json" + UPDATE_CHECK_CACHE_FILE_NAME = "update-check.json" + MIGRATION_NOTICE_CACHE_FILE_NAME = "migration-notice.json" ) var ( diff --git a/upload_to_s3.sh b/upload_to_s3.sh index 0766c1fe..8e1653db 100755 --- a/upload_to_s3.sh +++ b/upload_to_s3.sh @@ -65,7 +65,20 @@ if ls *.apk 1> /dev/null 2>&1; then # Sync existing packages from S3 (to preserve old versions) echo "Syncing existing APK packages from S3..." aws s3 sync "s3://$INFISICAL_CLI_S3_BUCKET/apk/" apk-staging/ --exclude "*/APKINDEX.tar.gz" - + + # Integrity gate: the APKINDEX is rebuilt from staging, so a partial sync would + # silently drop versions. Staging = synced repo + new apks, so it must have at + # least as many .apk as S3 per arch; if fewer, the sync was incomplete, so abort. + for arch in x86_64 aarch64; do + s3_count=$(aws s3 ls "s3://$INFISICAL_CLI_S3_BUCKET/apk/stable/main/$arch/" 2>/dev/null | grep -c '\.apk$' || true) + local_count=$(ls apk-staging/stable/main/$arch/*.apk 2>/dev/null | wc -l | tr -d ' ') + if [ "$local_count" -lt "$s3_count" ]; then + echo "Error: APK sync incomplete for $arch (S3 has $s3_count, staged $local_count). Aborting to avoid publishing a stale APKINDEX." + exit 1 + fi + echo "APK integrity OK for $arch: staged $local_count >= S3 $s3_count" + done + # Validate APK private key exists if [ ! -f "$APK_PRIVATE_KEY_PATH" ]; then echo "Error: APK private key not found at $APK_PRIVATE_KEY_PATH"