From 52900c8c2455ac1b9200b91f062302375fc14bff Mon Sep 17 00:00:00 2001 From: Juraj Hilje Date: Thu, 28 May 2026 10:07:52 +0200 Subject: [PATCH 1/8] feat(token): add TLS auth for token service --- .env.sample | 13 +++ .gitignore | 3 + compose.yml | 5 ++ services/generator/client/token.go | 41 ++++++++- services/generator/config/config.go | 16 +++- services/generator/main.go | 13 ++- services/preauth/client/token.go | 41 ++++++++- services/preauth/config/config.go | 16 +++- services/preauth/main.go | 11 ++- services/token/client/signer_aws.go | 17 ++-- services/token/client/signer_fortanix.go | 17 ++-- services/token/config/config.go | 10 +++ services/token/main.go | 9 +- services/token/service/service.go | 103 +++++++++++++++++------ services/token/service/service_test.go | 9 +- 15 files changed, 247 insertions(+), 77 deletions(-) diff --git a/.env.sample b/.env.sample index 4cf9b74..19a5b88 100644 --- a/.env.sample +++ b/.env.sample @@ -2,6 +2,7 @@ TOKEN_HOST=token TOKEN_PORT=50051 TOKEN_MOCK=true +TOKEN_DEBUG=false AWS_TOKEN_KEY_ID= AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= @@ -10,6 +11,18 @@ FORTANIX_ENDPOINT= FORTANIX_API_KEY= FORTANIX_KEY_ID= +# token TLS (mTLS) — server +# Generate certs once with: scripts/gen-certs.sh (see README for instructions) +TOKEN_TLS_ENABLED=false +TOKEN_TLS_CERT_FILE=/certs/server.crt +TOKEN_TLS_KEY_FILE=/certs/server.key +TOKEN_TLS_CA_FILE=/certs/ca.crt + +# token TLS — clients (generator + preauth) +TOKEN_TLS_CLIENT_CA_FILE=/certs/ca.crt +TOKEN_TLS_CLIENT_CERT_FILE=/certs/client.crt +TOKEN_TLS_CLIENT_KEY_FILE=/certs/client.key + # generator: SERVER_DB_HOST=server-db SERVER_DB_PORT=3306 diff --git a/.gitignore b/.gitignore index dfb5825..91bc717 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,6 @@ go.work.sum # env file .env + +# TLS certs +certs/ diff --git a/compose.yml b/compose.yml index c0eede6..01ee853 100644 --- a/compose.yml +++ b/compose.yml @@ -10,6 +10,8 @@ services: restart: unless-stopped networks: server: + volumes: + - ./certs:/certs:ro generator: build: @@ -25,6 +27,7 @@ services: server: volumes: - data:/app/data + - ./certs:/certs:ro distributor: build: @@ -52,6 +55,8 @@ services: - ${PREAUTH_GET_PORT}:${PREAUTH_GET_PORT} networks: server: + volumes: + - ./certs:/certs:ro verifier: build: diff --git a/services/generator/client/token.go b/services/generator/client/token.go index 0cb99ee..7d2476d 100644 --- a/services/generator/client/token.go +++ b/services/generator/client/token.go @@ -2,9 +2,14 @@ package client import ( "context" + "crypto/tls" + "crypto/x509" + "errors" "fmt" + "os" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "ivpn.net/auth/services/generator/config" proto "ivpn.net/auth/services/proto" @@ -42,10 +47,44 @@ func (c *TokenClient) GenerateToken(input string) (string, error) { func connect(cfg config.TokenServerConfig) (*grpc.ClientConn, error) { address := cfg.Host + ":" + cfg.Port - conn, err := grpc.NewClient(address, grpc.WithTransportCredentials(insecure.NewCredentials())) + + var creds grpc.DialOption + if cfg.TLSEnabled { + tlsCfg, err := buildClientTLS(cfg) + if err != nil { + return nil, fmt.Errorf("tls config: %w", err) + } + creds = grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)) + } else { + creds = grpc.WithTransportCredentials(insecure.NewCredentials()) + } + + conn, err := grpc.NewClient(address, creds) if err != nil { return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err) } return conn, nil } + +func buildClientTLS(cfg config.TokenServerConfig) (*tls.Config, error) { + caPEM, err := os.ReadFile(cfg.TLSCACertFile) + if err != nil { + return nil, fmt.Errorf("read CA cert: %w", err) + } + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(caPEM) { + return nil, errors.New("failed to parse CA certificate") + } + + cert, err := tls.LoadX509KeyPair(cfg.TLSCertFile, cfg.TLSKeyFile) + if err != nil { + return nil, fmt.Errorf("load client cert/key: %w", err) + } + + return &tls.Config{ + RootCAs: caPool, + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS13, + }, nil +} diff --git a/services/generator/config/config.go b/services/generator/config/config.go index dbb748e..d61c3bf 100644 --- a/services/generator/config/config.go +++ b/services/generator/config/config.go @@ -14,8 +14,12 @@ type DBConfig struct { } type TokenServerConfig struct { - Host string - Port string + Host string + Port string + TLSEnabled bool + TLSCACertFile string + TLSCertFile string + TLSKeyFile string } type ServiceConfig struct { @@ -38,8 +42,12 @@ func New() (Config, error) { return Config{ TokenServer: TokenServerConfig{ - Host: os.Getenv("TOKEN_HOST"), - Port: os.Getenv("TOKEN_PORT"), + Host: os.Getenv("TOKEN_HOST"), + Port: os.Getenv("TOKEN_PORT"), + TLSEnabled: os.Getenv("TOKEN_TLS_ENABLED") == "true", + TLSCACertFile: os.Getenv("TOKEN_TLS_CLIENT_CA_FILE"), + TLSCertFile: os.Getenv("TOKEN_TLS_CLIENT_CERT_FILE"), + TLSKeyFile: os.Getenv("TOKEN_TLS_CLIENT_KEY_FILE"), }, DB: DBConfig{ Host: os.Getenv("SERVER_DB_HOST"), diff --git a/services/generator/main.go b/services/generator/main.go index cf4811b..f1c80b6 100644 --- a/services/generator/main.go +++ b/services/generator/main.go @@ -13,17 +13,17 @@ import ( func main() { cfg, err := config.New() if err != nil { - log.Println(err) + log.Fatal(err) } db, err := repository.NewDB(cfg) if err != nil { - log.Println(err) + log.Fatal(err) } tokenClient, err := client.New(cfg.TokenServer) if err != nil { - log.Println(err) + log.Fatal(err) } service := service.New(cfg, db, tokenClient) @@ -34,7 +34,7 @@ func main() { switch args[1] { case "sync": if err := service.Generate(); err != nil { - log.Println(err) + log.Fatal(err) } return @@ -47,8 +47,7 @@ func main() { } } - err = service.Start() - if err != nil { - log.Println(err) + if err = service.Start(); err != nil { + log.Fatal(err) } } diff --git a/services/preauth/client/token.go b/services/preauth/client/token.go index 4f2029b..32bb91c 100644 --- a/services/preauth/client/token.go +++ b/services/preauth/client/token.go @@ -2,9 +2,14 @@ package client import ( "context" + "crypto/tls" + "crypto/x509" + "errors" "fmt" + "os" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "ivpn.net/auth/services/preauth/config" @@ -43,10 +48,44 @@ func (c *TokenClient) GenerateToken(input string) (string, error) { func connect(cfg config.TokenServerConfig) (*grpc.ClientConn, error) { address := cfg.Host + ":" + cfg.Port - conn, err := grpc.NewClient(address, grpc.WithTransportCredentials(insecure.NewCredentials())) + + var creds grpc.DialOption + if cfg.TLSEnabled { + tlsCfg, err := buildClientTLS(cfg) + if err != nil { + return nil, fmt.Errorf("tls config: %w", err) + } + creds = grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)) + } else { + creds = grpc.WithTransportCredentials(insecure.NewCredentials()) + } + + conn, err := grpc.NewClient(address, creds) if err != nil { return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err) } return conn, nil } + +func buildClientTLS(cfg config.TokenServerConfig) (*tls.Config, error) { + caPEM, err := os.ReadFile(cfg.TLSCACertFile) + if err != nil { + return nil, fmt.Errorf("read CA cert: %w", err) + } + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(caPEM) { + return nil, errors.New("failed to parse CA certificate") + } + + cert, err := tls.LoadX509KeyPair(cfg.TLSCertFile, cfg.TLSKeyFile) + if err != nil { + return nil, fmt.Errorf("load client cert/key: %w", err) + } + + return &tls.Config{ + RootCAs: caPool, + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS13, + }, nil +} diff --git a/services/preauth/config/config.go b/services/preauth/config/config.go index d97fcbc..b00007d 100644 --- a/services/preauth/config/config.go +++ b/services/preauth/config/config.go @@ -35,8 +35,12 @@ type RedisConfig struct { } type TokenServerConfig struct { - Host string - Port string + Host string + Port string + TLSEnabled bool + TLSCACertFile string + TLSCertFile string + TLSKeyFile string } type Config struct { @@ -87,8 +91,12 @@ func New() (Config, error) { TLSInsecureSkipVerify: os.Getenv("REDIS_TLS_INSECURE_SKIP_VERIFY") == "true", }, TokenServer: TokenServerConfig{ - Host: os.Getenv("TOKEN_HOST"), - Port: os.Getenv("TOKEN_PORT"), + Host: os.Getenv("TOKEN_HOST"), + Port: os.Getenv("TOKEN_PORT"), + TLSEnabled: os.Getenv("TOKEN_TLS_ENABLED") == "true", + TLSCACertFile: os.Getenv("TOKEN_TLS_CLIENT_CA_FILE"), + TLSCertFile: os.Getenv("TOKEN_TLS_CLIENT_CERT_FILE"), + TLSKeyFile: os.Getenv("TOKEN_TLS_CLIENT_KEY_FILE"), }, }, nil } diff --git a/services/preauth/main.go b/services/preauth/main.go index f6d8db4..a5057e0 100644 --- a/services/preauth/main.go +++ b/services/preauth/main.go @@ -13,23 +13,22 @@ import ( func main() { cfg, err := config.New() if err != nil { - log.Println(err) + log.Fatal(err) } redis, err := repository.New(cfg.Redis) if err != nil { - log.Println(err) + log.Fatal(err) } tokenClient, err := client.New(cfg.TokenServer) if err != nil { - log.Println(err) + log.Fatal(err) } service := service.New(cfg, redis, tokenClient) - err = api.Start(cfg.API, service) - if err != nil { - log.Println(err) + if err = api.Start(cfg.API, service); err != nil { + log.Fatal(err) } } diff --git a/services/token/client/signer_aws.go b/services/token/client/signer_aws.go index ec24c9e..6fd0a81 100644 --- a/services/token/client/signer_aws.go +++ b/services/token/client/signer_aws.go @@ -14,9 +14,7 @@ import ( "ivpn.net/auth/services/token/model" ) -var ( - ErrEmptyInput = "input string cannot be empty" -) +const ErrEmptyInput = "input string cannot be empty" type SignerAWS struct { Cfg *config.Config @@ -46,13 +44,11 @@ func NewSignerAWS(cfg config.Config) (*SignerAWS, error) { }, nil } -func (s *SignerAWS) Generate(input string) (*model.HSMToken, error) { +func (s *SignerAWS) Generate(ctx context.Context, input string) (*model.HSMToken, error) { if input == "" { return nil, fmt.Errorf("%s", ErrEmptyInput) } - // start := time.Now() - digest := sha512.Sum512([]byte(input)) if s.Cfg.Mock { @@ -67,15 +63,16 @@ func (s *SignerAWS) Generate(input string) (*model.HSMToken, error) { MacAlgorithm: types.MacAlgorithmSpecHmacSha256, } - signOut, err := s.Client.GenerateMac(context.Background(), generateInput) + signOut, err := s.Client.GenerateMac(ctx, generateInput) if err != nil { return nil, fmt.Errorf("failed to sign input: %w", err) } - // elapsed := time.Since(start) - // log.Printf("Token() completed in %s", elapsed) - return &model.HSMToken{ Token: base64.StdEncoding.EncodeToString(signOut.Mac), }, nil } + +func (s *SignerAWS) Authenticate() error { + return nil +} diff --git a/services/token/client/signer_fortanix.go b/services/token/client/signer_fortanix.go index 99da715..c0610c5 100644 --- a/services/token/client/signer_fortanix.go +++ b/services/token/client/signer_fortanix.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "fmt" "net/http" + "time" "github.com/fortanix/sdkms-client-go/sdkms" "ivpn.net/auth/services/token/config" @@ -18,9 +19,10 @@ type SignerFortanix struct { } func NewSignerFortanix(cfg config.Config) (*SignerFortanix, error) { + httpClient := &http.Client{Timeout: 30 * time.Second} client := sdkms.Client{ Endpoint: cfg.FortanixEndpoint, - HTTPClient: http.DefaultClient, + HTTPClient: httpClient, } _, err := client.AuthenticateWithAPIKey(context.Background(), cfg.FortanixApiKey) @@ -34,13 +36,11 @@ func NewSignerFortanix(cfg config.Config) (*SignerFortanix, error) { }, nil } -func (s *SignerFortanix) Generate(input string) (*model.HSMToken, error) { +func (s *SignerFortanix) Generate(ctx context.Context, input string) (*model.HSMToken, error) { if input == "" { return nil, fmt.Errorf("%s", ErrEmptyInput) } - // start := time.Now() - digest := sha512.Sum512([]byte(input)) data := sdkms.Blob(digest[:]) keyId := s.Cfg.FortanixKeyId @@ -58,14 +58,11 @@ func (s *SignerFortanix) Generate(input string) (*model.HSMToken, error) { Key: sdkms.SobjectByID(keyId), } - res, err := s.Client.Mac(context.Background(), req) + res, err := s.Client.Mac(ctx, req) if err != nil { return nil, err } - // elapsed := time.Since(start) - // log.Printf("Token() completed in %s", elapsed) - return &model.HSMToken{ Token: base64.StdEncoding.EncodeToString(res.Mac), }, nil @@ -76,7 +73,7 @@ func (s *SignerFortanix) Authenticate() error { return err } -func (s *SignerFortanix) Verify(data [64]byte, signature string) (bool, error) { +func (s *SignerFortanix) Verify(ctx context.Context, data [64]byte, signature string) (bool, error) { sigData, err := base64.StdEncoding.DecodeString(signature) if err != nil { return false, err @@ -92,7 +89,7 @@ func (s *SignerFortanix) Verify(data [64]byte, signature string) (bool, error) { Key: sdkms.SobjectByID(keyId), } - res, err := s.Client.MacVerify(context.Background(), req) + res, err := s.Client.MacVerify(ctx, req) if err != nil { return false, err } diff --git a/services/token/config/config.go b/services/token/config/config.go index f1fda51..93ecf5c 100644 --- a/services/token/config/config.go +++ b/services/token/config/config.go @@ -13,6 +13,11 @@ type Config struct { FortanixEndpoint string FortanixApiKey string FortanixKeyId string + TLSEnabled bool + TLSCertFile string + TLSKeyFile string + TLSCAFile string + Debug bool } func New() (Config, error) { @@ -27,5 +32,10 @@ func New() (Config, error) { FortanixEndpoint: os.Getenv("FORTANIX_ENDPOINT"), FortanixApiKey: os.Getenv("FORTANIX_API_KEY"), FortanixKeyId: os.Getenv("FORTANIX_KEY_ID"), + TLSEnabled: os.Getenv("TOKEN_TLS_ENABLED") == "true", + TLSCertFile: os.Getenv("TOKEN_TLS_CERT_FILE"), + TLSKeyFile: os.Getenv("TOKEN_TLS_KEY_FILE"), + TLSCAFile: os.Getenv("TOKEN_TLS_CA_FILE"), + Debug: os.Getenv("TOKEN_DEBUG") == "true", }, nil } diff --git a/services/token/main.go b/services/token/main.go index baa08c8..cc00106 100644 --- a/services/token/main.go +++ b/services/token/main.go @@ -11,17 +11,16 @@ import ( func main() { cfg, err := config.New() if err != nil { - log.Println(err) + log.Fatal(err) } signer, err := client.NewSignerFortanix(cfg) if err != nil { - log.Println(err) + log.Fatal(err) } server := service.New(signer, cfg) - err = server.Start() - if err != nil { - log.Println(err) + if err = server.Start(); err != nil { + log.Fatal(err) } } diff --git a/services/token/service/service.go b/services/token/service/service.go index 16fac7b..620b7ed 100644 --- a/services/token/service/service.go +++ b/services/token/service/service.go @@ -2,19 +2,31 @@ package service import ( "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" "log" "net" + "os" "strings" + "time" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/reflection" proto "ivpn.net/auth/services/proto" "ivpn.net/auth/services/token/config" "ivpn.net/auth/services/token/model" ) +const maxInputBytes = 4096 + +// ErrAuthRequired is returned when the HSM session needs re-authentication. +var ErrAuthRequired = errors.New("hsm auth required") + type Signer interface { - Generate(input string) (*model.HSMToken, error) + Generate(ctx context.Context, input string) (*model.HSMToken, error) Authenticate() error } @@ -34,53 +46,94 @@ func New(signer Signer, cfg config.Config) *Server { func (s *Server) Start() error { log.Printf("Starting token service on %s:%s", s.Cfg.Host, s.Cfg.Port) - lis, err := net.Listen("tcp", ":"+s.Cfg.Port) + lis, err := net.Listen("tcp", s.Cfg.Host+":"+s.Cfg.Port) if err != nil { - log.Println(err) - return err + return fmt.Errorf("listen: %w", err) } - srv := grpc.NewServer() + var opts []grpc.ServerOption + if s.Cfg.TLSEnabled { + tlsCfg, err := buildServerTLS(s.Cfg) + if err != nil { + return fmt.Errorf("tls config: %w", err) + } + opts = append(opts, grpc.Creds(credentials.NewTLS(tlsCfg))) + log.Println("Token service TLS (mTLS) enabled") + } else { + log.Println("WARNING: Token service is running without TLS — do not use in production") + } + + srv := grpc.NewServer(opts...) proto.RegisterTokenServer(srv, s) - reflection.Register(srv) - err = srv.Serve(lis) - if err != nil { - log.Println(err) - return err + if s.Cfg.Debug { + reflection.Register(srv) + log.Println("WARNING: gRPC reflection enabled — disable in production (TOKEN_DEBUG=false)") } + if err = srv.Serve(lis); err != nil { + return fmt.Errorf("serve: %w", err) + } return nil } +func buildServerTLS(cfg *config.Config) (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(cfg.TLSCertFile, cfg.TLSKeyFile) + if err != nil { + return nil, fmt.Errorf("load server cert/key: %w", err) + } + + caPEM, err := os.ReadFile(cfg.TLSCAFile) + if err != nil { + return nil, fmt.Errorf("read CA cert: %w", err) + } + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(caPEM) { + return nil, errors.New("failed to parse CA certificate") + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientCAs: caPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS13, + }, nil +} + func (s *Server) Generate(ctx context.Context, req *proto.Request) (*proto.Response, error) { - token, err := s.generateToken(req.Input) + if len(req.Input) > maxInputBytes { + return nil, fmt.Errorf("input exceeds maximum allowed size of %d bytes", maxInputBytes) + } + + reqCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + token, err := s.generateToken(reqCtx, req.Input) if err != nil { - if strings.Contains(err.Error(), "Status: 401") || strings.Contains(err.Error(), "Status: 403") { + if isAuthError(err) { log.Println("Re-authenticating Signer session...") - err = s.Signer.Authenticate() - if err == nil { - token, err = s.generateToken(req.Input) + if authErr := s.Signer.Authenticate(); authErr == nil { + token, err = s.generateToken(reqCtx, req.Input) if err != nil { log.Println(err) return nil, err } - - return &proto.Response{ - Token: token.Token, - }, nil + return &proto.Response{Token: token.Token}, nil } } - log.Println(err) return nil, err } - return &proto.Response{ - Token: token.Token, - }, nil + return &proto.Response{Token: token.Token}, nil +} + +// isAuthError detects HSM session expiry responses. +func isAuthError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "Status: 401") || strings.Contains(msg, "Status: 403") } -func (s *Server) generateToken(input string) (*model.HSMToken, error) { - return s.Signer.Generate(input) +func (s *Server) generateToken(ctx context.Context, input string) (*model.HSMToken, error) { + return s.Signer.Generate(ctx, input) } diff --git a/services/token/service/service_test.go b/services/token/service/service_test.go index d3f93a9..d26b4cb 100644 --- a/services/token/service/service_test.go +++ b/services/token/service/service_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "errors" "testing" @@ -16,7 +17,7 @@ type MockHSMClient struct { } // Token implements the HSMClient interface for the mock -func (m *MockHSMClient) Generate(input string) (*model.HSMToken, error) { +func (m *MockHSMClient) Generate(ctx context.Context, input string) (*model.HSMToken, error) { // Store the parameters for verification m.input = input return m.mockToken, m.mockError @@ -46,7 +47,7 @@ func TestGenerateToken_Success(t *testing.T) { inputStr := "test-input" // Act - token, err := svc.generateToken(inputStr) + token, err := svc.generateToken(context.Background(), inputStr) // Assert if err != nil { @@ -79,7 +80,7 @@ func TestGenerateToken_Error(t *testing.T) { inputStr := "test-input" // Act - token, err := svc.generateToken(inputStr) + token, err := svc.generateToken(context.Background(), inputStr) // Assert if err != expectedError { @@ -123,7 +124,7 @@ func TestGenerateToken_DifferentParameters(t *testing.T) { svc := New(mockHSM, cfg) // Act - _, err := svc.generateToken(tc.input) + _, err := svc.generateToken(context.Background(), tc.input) // Assert if err != nil { From d7e602c13fe2bee4738150ffc5ab6d9128997e46 Mon Sep 17 00:00:00 2001 From: Juraj Hilje Date: Fri, 29 May 2026 10:57:32 +0200 Subject: [PATCH 2/8] feat(token): update config.go --- services/token/config/config.go | 38 ++++++++++++++++++++++++++++++- services/token/main.go | 3 +++ services/token/service/service.go | 12 ++++++---- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/services/token/config/config.go b/services/token/config/config.go index 93ecf5c..248f753 100644 --- a/services/token/config/config.go +++ b/services/token/config/config.go @@ -1,6 +1,9 @@ package config -import "os" +import ( + "errors" + "os" +) type Config struct { Host string @@ -39,3 +42,36 @@ func New() (Config, error) { Debug: os.Getenv("TOKEN_DEBUG") == "true", }, nil } + +// Validate returns an error if required credentials are missing for the active signer mode. +func (c *Config) Validate() error { + if c.Port == "" { + return errors.New("TOKEN_PORT is required") + } + if c.Mock { + return nil + } + // Fortanix is the active signer (NewSignerFortanix is called from main.go). + // Validate the fields it needs at runtime so the process fails fast at startup. + if c.FortanixEndpoint == "" { + return errors.New("FORTANIX_ENDPOINT is required when TOKEN_MOCK=false") + } + if c.FortanixApiKey == "" { + return errors.New("FORTANIX_API_KEY is required when TOKEN_MOCK=false") + } + if c.FortanixKeyId == "" { + return errors.New("FORTANIX_KEY_ID is required when TOKEN_MOCK=false") + } + if c.TLSEnabled { + if c.TLSCertFile == "" { + return errors.New("TOKEN_TLS_CERT_FILE is required when TOKEN_TLS_ENABLED=true") + } + if c.TLSKeyFile == "" { + return errors.New("TOKEN_TLS_KEY_FILE is required when TOKEN_TLS_ENABLED=true") + } + if c.TLSCAFile == "" { + return errors.New("TOKEN_TLS_CA_FILE is required when TOKEN_TLS_ENABLED=true") + } + } + return nil +} diff --git a/services/token/main.go b/services/token/main.go index cc00106..e25c1da 100644 --- a/services/token/main.go +++ b/services/token/main.go @@ -13,6 +13,9 @@ func main() { if err != nil { log.Fatal(err) } + if err = cfg.Validate(); err != nil { + log.Fatal(err) + } signer, err := client.NewSignerFortanix(cfg) if err != nil { diff --git a/services/token/service/service.go b/services/token/service/service.go index 620b7ed..1aa6b75 100644 --- a/services/token/service/service.go +++ b/services/token/service/service.go @@ -8,10 +8,11 @@ import ( "fmt" "log" "net" + "net/http" "os" - "strings" "time" + "github.com/fortanix/sdkms-client-go/sdkms" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/reflection" @@ -128,10 +129,13 @@ func (s *Server) Generate(ctx context.Context, req *proto.Request) (*proto.Respo return &proto.Response{Token: token.Token}, nil } -// isAuthError detects HSM session expiry responses. +// isAuthError detects HSM session expiry by inspecting the SDK's typed BackendError. func isAuthError(err error) bool { - msg := err.Error() - return strings.Contains(msg, "Status: 401") || strings.Contains(msg, "Status: 403") + var be *sdkms.BackendError + if errors.As(err, &be) { + return be.StatusCode == http.StatusUnauthorized || be.StatusCode == http.StatusForbidden + } + return false } func (s *Server) generateToken(ctx context.Context, input string) (*model.HSMToken, error) { From 3639fcb05d90bd87c2ce27510849658251cae212 Mon Sep 17 00:00:00 2001 From: Juraj Hilje Date: Fri, 29 May 2026 11:14:20 +0200 Subject: [PATCH 3/8] feat(distributor): update auth.go --- services/distributor/middleware/auth/auth.go | 4 ++++ services/preauth/middleware/auth/auth.go | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/services/distributor/middleware/auth/auth.go b/services/distributor/middleware/auth/auth.go index d5f3edf..dd90323 100644 --- a/services/distributor/middleware/auth/auth.go +++ b/services/distributor/middleware/auth/auth.go @@ -10,6 +10,10 @@ import ( func NewIPFilter(allowedIPs []string) fiber.Handler { return func(c *fiber.Ctx) error { + if slices.Contains(allowedIPs, "*") { + return c.Next() + } + clientIP := c.IP() if slices.Contains(allowedIPs, clientIP) { return c.Next() diff --git a/services/preauth/middleware/auth/auth.go b/services/preauth/middleware/auth/auth.go index d5aabaa..a95c621 100644 --- a/services/preauth/middleware/auth/auth.go +++ b/services/preauth/middleware/auth/auth.go @@ -10,6 +10,10 @@ import ( func NewIPFilter(allowedIPs []string) fiber.Handler { return func(c *fiber.Ctx) error { + if slices.Contains(allowedIPs, "*") { + return c.Next() + } + clientIP := c.IP() if slices.Contains(allowedIPs, clientIP) || c.IP() == "" { return c.Next() From 9fb7bb7f2820c43582200669be1f479e138d4ec3 Mon Sep 17 00:00:00 2001 From: Juraj Hilje Date: Sun, 31 May 2026 08:31:00 +0200 Subject: [PATCH 4/8] feat(generator): update service.go --- services/generator/client/token.go | 5 ++++ services/generator/config/config.go | 34 ++++++++++++++++++++++ services/generator/main.go | 4 +++ services/generator/repository/account.go | 25 ++++++++++++---- services/generator/repository/db.go | 18 ++++++++++-- services/generator/service/service.go | 36 +++++++++++++++++++----- 6 files changed, 106 insertions(+), 16 deletions(-) diff --git a/services/generator/client/token.go b/services/generator/client/token.go index 7d2476d..4a98919 100644 --- a/services/generator/client/token.go +++ b/services/generator/client/token.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "errors" "fmt" + "log" "os" "google.golang.org/grpc" @@ -56,6 +57,10 @@ func connect(cfg config.TokenServerConfig) (*grpc.ClientConn, error) { } creds = grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)) } else { + if os.Getenv("GENERATOR_ALLOW_INSECURE") != "true" { + return nil, errors.New("TLS is disabled but GENERATOR_ALLOW_INSECURE is not set to 'true'; refusing insecure connection") + } + log.Println("WARNING: gRPC connection to token server is unencrypted (TLS disabled)") creds = grpc.WithTransportCredentials(insecure.NewCredentials()) } diff --git a/services/generator/config/config.go b/services/generator/config/config.go index d61c3bf..3245587 100644 --- a/services/generator/config/config.go +++ b/services/generator/config/config.go @@ -1,6 +1,7 @@ package config import ( + "errors" "os" "strconv" ) @@ -63,3 +64,36 @@ func New() (Config, error) { }, }, nil } + +// Validate checks that all required configuration values are present. +func (c Config) Validate() error { + required := map[string]string{ + "TOKEN_HOST": c.TokenServer.Host, + "TOKEN_PORT": c.TokenServer.Port, + "SERVER_DB_HOST": c.DB.Host, + "SERVER_DB_PORT": c.DB.Port, + "SERVER_DB_NAME": c.DB.Name, + "SERVER_DB_USER": c.DB.User, + "SERVER_DB_PASSWORD": c.DB.Password, + } + for name, val := range required { + if val == "" { + return errors.New("required env var not set: " + name) + } + } + if c.Service.TPS <= 0 { + return errors.New("GENERATOR_TPS must be a positive integer") + } + if c.TokenServer.TLSEnabled { + if c.TokenServer.TLSCACertFile == "" { + return errors.New("required env var not set: TOKEN_TLS_CLIENT_CA_FILE") + } + if c.TokenServer.TLSCertFile == "" { + return errors.New("required env var not set: TOKEN_TLS_CLIENT_CERT_FILE") + } + if c.TokenServer.TLSKeyFile == "" { + return errors.New("required env var not set: TOKEN_TLS_CLIENT_KEY_FILE") + } + } + return nil +} diff --git a/services/generator/main.go b/services/generator/main.go index f1c80b6..21f0428 100644 --- a/services/generator/main.go +++ b/services/generator/main.go @@ -16,6 +16,10 @@ func main() { log.Fatal(err) } + if err := cfg.Validate(); err != nil { + log.Fatal("configuration error: ", err) + } + db, err := repository.NewDB(cfg) if err != nil { log.Fatal(err) diff --git a/services/generator/repository/account.go b/services/generator/repository/account.go index c5981ac..1f3efe6 100644 --- a/services/generator/repository/account.go +++ b/services/generator/repository/account.go @@ -34,8 +34,12 @@ func (d *Database) GetAccounts() ([]*model.Account, error) { func (d *Database) GetAccountsMock(count int) ([]*model.Account, error) { accounts := make([]*model.Account, count) for i := range count { + id, err := randomId() + if err != nil { + return nil, fmt.Errorf("GetAccountsMock: %w", err) + } accounts[i] = &model.Account{ - ID: randomId(), + ID: id, CreatedAt: time.Now(), IsActive: true, ActiveUntil: time.Now().AddDate(0, i%12+1, 0), // Active for x months @@ -69,7 +73,7 @@ func (d *Database) PostAccount(account *model.Account) error { return d.Client.Create(account).Error } -func randomId() string { +func randomId() (string, error) { // Generate a random ID, e.g., i-1234-ABCD-XYQZ const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" var id strings.Builder @@ -78,23 +82,32 @@ func randomId() string { max := big.NewInt(int64(len(charset))) for i := range 4 { - n, _ := rand.Int(rand.Reader, max) + n, err := rand.Int(rand.Reader, max) + if err != nil { + return "", fmt.Errorf("randomId: crypto/rand failed: %w", err) + } id.WriteByte(charset[n.Int64()]) if i == 3 { id.WriteByte('-') } } for i := range 4 { - n, _ := rand.Int(rand.Reader, max) + n, err := rand.Int(rand.Reader, max) + if err != nil { + return "", fmt.Errorf("randomId: crypto/rand failed: %w", err) + } id.WriteByte(charset[n.Int64()]) if i == 3 { id.WriteByte('-') } } for range 4 { - n, _ := rand.Int(rand.Reader, max) + n, err := rand.Int(rand.Reader, max) + if err != nil { + return "", fmt.Errorf("randomId: crypto/rand failed: %w", err) + } id.WriteByte(charset[n.Int64()]) } - return id.String() + return id.String(), nil } diff --git a/services/generator/repository/db.go b/services/generator/repository/db.go index f64c7d6..648741d 100644 --- a/services/generator/repository/db.go +++ b/services/generator/repository/db.go @@ -3,6 +3,7 @@ package repository import ( "log" + mysqldrv "github.com/go-sql-driver/mysql" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -44,13 +45,24 @@ func (d *Database) Close() error { } func connect(cfg config.DBConfig) (*gorm.DB, error) { - config := &gorm.Config{ + gormCfg := &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), } - dsn := cfg.User + ":" + cfg.Password + "@tcp(" + cfg.Host + ":" + cfg.Port + ")/" + cfg.Name + "?charset=utf8mb4&parseTime=True&loc=Local" + dsnCfg := mysqldrv.Config{ + User: cfg.User, + Passwd: cfg.Password, + Net: "tcp", + Addr: cfg.Host + ":" + cfg.Port, + DBName: cfg.Name, + Params: map[string]string{"charset": "utf8mb4"}, + ParseTime: true, + Loc: nil, // use UTC + AllowNativePasswords: true, + } + dsn := dsnCfg.FormatDSN() - db, err := gorm.Open(mysql.Open(dsn), config) + db, err := gorm.Open(mysql.Open(dsn), gormCfg) if err != nil { return nil, err } diff --git a/services/generator/service/service.go b/services/generator/service/service.go index b77e1c6..11e9474 100644 --- a/services/generator/service/service.go +++ b/services/generator/service/service.go @@ -2,16 +2,18 @@ package service import ( "context" + "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "log" - "math/rand" + "math/big" "os" "path/filepath" "strings" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -149,6 +151,7 @@ func (s *Service) GenerateSubscriptions() ([]model.Subscription, error) { jobs := make(chan *model.Account, len(accounts)) results := make(chan model.Subscription, len(accounts)) var wg sync.WaitGroup + var failedCount atomic.Int64 // Start workers for w := range workerCount { @@ -159,12 +162,15 @@ func (s *Service) GenerateSubscriptions() ([]model.Subscription, error) { // Wait for limiter before each Sign call if err := limiter.Wait(ctx); err != nil { log.Printf("[worker %d] limiter error: %v", workerID, err) + failedCount.Add(1) continue } // Generate token for account ID token, err := s.Token.GenerateToken(account.ID) if err != nil { + log.Printf("[worker %d] failed to generate token for account %s: %v", workerID, account.ID, err) + failedCount.Add(1) continue } @@ -213,11 +219,14 @@ func (s *Service) GenerateSubscriptions() ([]model.Subscription, error) { signedSubs = append(signedSubs, sub) } - // Randomize order - r := rand.New(rand.NewSource(time.Now().UnixNano())) - r.Shuffle(len(signedSubs), func(i, j int) { - signedSubs[i], signedSubs[j] = signedSubs[j], signedSubs[i] - }) + if n := failedCount.Load(); n > 0 { + return nil, fmt.Errorf("%d account(s) failed token generation; manifest not saved", n) + } + + // Randomize order using crypto/rand to prevent correlation across manifest versions + if err := cryptoShuffle(signedSubs); err != nil { + return nil, fmt.Errorf("shuffle failed: %w", err) + } log.Printf("signed %d subscriptions in %s with %d workers (limit: %d TPS)\n", len(signedSubs), time.Since(start), workerCount, s.Cfg.Service.TPS) @@ -277,8 +286,21 @@ func (s *Service) SignManifest(m *model.Manifest) error { m.Signature = signature - log.Printf("manifest signed: %s", m.Signature) + log.Println("manifest signed successfully") + + return nil +} +// cryptoShuffle performs a Fisher-Yates shuffle using crypto/rand. +func cryptoShuffle(subs []model.Subscription) error { + for i := len(subs) - 1; i > 0; i-- { + n, err := rand.Int(rand.Reader, big.NewInt(int64(i+1))) + if err != nil { + return err + } + j := n.Int64() + subs[i], subs[j] = subs[j], subs[i] + } return nil } From 4e1d5f6a4487dc0dfbf878ba3942e93e024005c0 Mon Sep 17 00:00:00 2001 From: Juraj Hilje Date: Sun, 31 May 2026 08:46:06 +0200 Subject: [PATCH 5/8] feat(preauth): update service.go --- services/preauth/client/http/http.go | 15 +++-- services/preauth/client/token.go | 5 ++ services/preauth/config/config.go | 82 +++++++++++++++++------- services/preauth/main.go | 4 ++ services/preauth/middleware/auth/auth.go | 5 +- services/preauth/repository/redis.go | 33 ++++++++-- services/preauth/service/service.go | 16 ++++- 7 files changed, 124 insertions(+), 36 deletions(-) diff --git a/services/preauth/client/http/http.go b/services/preauth/client/http/http.go index 8427641..d0f8924 100644 --- a/services/preauth/client/http/http.go +++ b/services/preauth/client/http/http.go @@ -1,7 +1,9 @@ package http import ( + "encoding/json" "errors" + "fmt" "log" "net/http" @@ -21,15 +23,20 @@ func New(cfg config.APIConfig) *Http { } func (h Http) PostSession(session model.Session, url string, psk string) error { + body, err := json.Marshal(session) + if err != nil { + return fmt.Errorf("failed to marshal session: %w", err) + } + req := fiber.Post(url) req.Set("Content-Type", "application/json") req.Set("Accept", "application/json") req.Set("Authorization", "Bearer "+psk) - req.Body([]byte(`{"id": "` + session.ID + `", "token": "` + session.Token + `", "preauth_id": "` + session.PreAuthID + `"}`)) + req.Body(body) - status, res, err := req.Bytes() - if err != nil { - log.Printf("Error calling session webhook: %v", err) + status, res, errs := req.Bytes() + if len(errs) > 0 { + log.Printf("Error calling session webhook: %v", errs) return errors.New("error calling session webhook") } diff --git a/services/preauth/client/token.go b/services/preauth/client/token.go index 32bb91c..36ea95a 100644 --- a/services/preauth/client/token.go +++ b/services/preauth/client/token.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "errors" "fmt" + "log" "os" "google.golang.org/grpc" @@ -57,6 +58,10 @@ func connect(cfg config.TokenServerConfig) (*grpc.ClientConn, error) { } creds = grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)) } else { + if os.Getenv("PREAUTH_ALLOW_INSECURE") != "true" { + return nil, errors.New("TLS is disabled but PREAUTH_ALLOW_INSECURE is not set to 'true'; refusing insecure connection") + } + log.Println("WARNING: gRPC connection to token server is unencrypted (TLS disabled)") creds = grpc.WithTransportCredentials(insecure.NewCredentials()) } diff --git a/services/preauth/config/config.go b/services/preauth/config/config.go index b00007d..40fbb01 100644 --- a/services/preauth/config/config.go +++ b/services/preauth/config/config.go @@ -1,6 +1,7 @@ package config import ( + "errors" "os" "strings" "time" @@ -20,18 +21,17 @@ type APIConfig struct { } type RedisConfig struct { - Addr string - Addrs []string - MasterName string - Username string - Password string - FailoverUsername string - FailoverPassword string - TLSEnabled bool - CertFile string - KeyFile string - CACertFile string - TLSInsecureSkipVerify bool // Optional: Only for testing, use false in production + Addr string + Addrs []string + MasterName string + Username string + Password string + FailoverUsername string + FailoverPassword string + TLSEnabled bool + CertFile string + KeyFile string + CACertFile string } type TokenServerConfig struct { @@ -77,18 +77,17 @@ func New() (Config, error) { ApiAllowIPs: apiAllowIPs, }, Redis: RedisConfig{ - Addr: os.Getenv("REDIS_ADDR"), - Addrs: redisAddrs, - MasterName: os.Getenv("REDIS_MASTER_NAME"), - Username: os.Getenv("REDIS_USERNAME"), - Password: os.Getenv("REDIS_PASSWORD"), - FailoverUsername: os.Getenv("REDIS_FAILOVER_USERNAME"), - FailoverPassword: os.Getenv("REDIS_FAILOVER_PASSWORD"), - TLSEnabled: os.Getenv("REDIS_TLS_ENABLED") == "true", - CertFile: os.Getenv("REDIS_CERT_FILE"), - KeyFile: os.Getenv("REDIS_KEY_FILE"), - CACertFile: os.Getenv("REDIS_CA_CERT_FILE"), - TLSInsecureSkipVerify: os.Getenv("REDIS_TLS_INSECURE_SKIP_VERIFY") == "true", + Addr: os.Getenv("REDIS_ADDR"), + Addrs: redisAddrs, + MasterName: os.Getenv("REDIS_MASTER_NAME"), + Username: os.Getenv("REDIS_USERNAME"), + Password: os.Getenv("REDIS_PASSWORD"), + FailoverUsername: os.Getenv("REDIS_FAILOVER_USERNAME"), + FailoverPassword: os.Getenv("REDIS_FAILOVER_PASSWORD"), + TLSEnabled: os.Getenv("REDIS_TLS_ENABLED") == "true", + CertFile: os.Getenv("REDIS_CERT_FILE"), + KeyFile: os.Getenv("REDIS_KEY_FILE"), + CACertFile: os.Getenv("REDIS_CA_CERT_FILE"), }, TokenServer: TokenServerConfig{ Host: os.Getenv("TOKEN_HOST"), @@ -100,3 +99,38 @@ func New() (Config, error) { }, }, nil } + +// Validate checks that all required configuration values are present. +func (c Config) Validate() error { + required := map[string]string{ + "PREAUTH_ADD_PORT": c.API.AddPort, + "PREAUTH_ADD_PSK": c.API.AddPSK, + "PREAUTH_GET_PORT": c.API.GetPort, + "PREAUTH_GET_PSK": c.API.GetPSK, + "TOKEN_HOST": c.TokenServer.Host, + "TOKEN_PORT": c.TokenServer.Port, + } + for name, val := range required { + if val == "" { + return errors.New("required env var not set: " + name) + } + } + if c.Redis.Addr == "" && (len(c.Redis.Addrs) == 0 || c.Redis.Addrs[0] == "") { + return errors.New("required env var not set: REDIS_ADDR or REDIS_ADDRESSES") + } + if c.API.PreauthTTL <= 0 { + return errors.New("PREAUTH_TTL must be a positive duration") + } + if c.TokenServer.TLSEnabled { + if c.TokenServer.TLSCACertFile == "" { + return errors.New("required env var not set: TOKEN_TLS_CLIENT_CA_FILE") + } + if c.TokenServer.TLSCertFile == "" { + return errors.New("required env var not set: TOKEN_TLS_CLIENT_CERT_FILE") + } + if c.TokenServer.TLSKeyFile == "" { + return errors.New("required env var not set: TOKEN_TLS_CLIENT_KEY_FILE") + } + } + return nil +} diff --git a/services/preauth/main.go b/services/preauth/main.go index a5057e0..1ab4a49 100644 --- a/services/preauth/main.go +++ b/services/preauth/main.go @@ -16,6 +16,10 @@ func main() { log.Fatal(err) } + if err := cfg.Validate(); err != nil { + log.Fatal("configuration error: ", err) + } + redis, err := repository.New(cfg.Redis) if err != nil { log.Fatal(err) diff --git a/services/preauth/middleware/auth/auth.go b/services/preauth/middleware/auth/auth.go index a95c621..98d2c43 100644 --- a/services/preauth/middleware/auth/auth.go +++ b/services/preauth/middleware/auth/auth.go @@ -1,6 +1,7 @@ package auth import ( + "crypto/subtle" "slices" "strings" @@ -15,7 +16,7 @@ func NewIPFilter(allowedIPs []string) fiber.Handler { } clientIP := c.IP() - if slices.Contains(allowedIPs, clientIP) || c.IP() == "" { + if clientIP != "" && slices.Contains(allowedIPs, clientIP) { return c.Next() } @@ -26,7 +27,7 @@ func NewIPFilter(allowedIPs []string) fiber.Handler { func NewPSK(psk string) fiber.Handler { return func(c *fiber.Ctx) error { - if GetToken(c) == psk { + if subtle.ConstantTimeCompare([]byte(GetToken(c)), []byte(psk)) == 1 { return c.Next() } diff --git a/services/preauth/repository/redis.go b/services/preauth/repository/redis.go index 2d905fd..da4dc5b 100644 --- a/services/preauth/repository/redis.go +++ b/services/preauth/repository/redis.go @@ -43,7 +43,33 @@ func New(cfg config.RedisConfig) (*Redis, error) { func newClient(cfg config.RedisConfig) (*redis.Client, error) { log.Println("creating Redis client") options := &redis.Options{ - Addr: cfg.Addr, + Addr: cfg.Addr, + Username: cfg.Username, + Password: cfg.Password, + } + + if cfg.TLSEnabled { + log.Println("using TLS to connect to Redis") + + cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate: %v", err) + } + + caCert, err := os.ReadFile(cfg.CACertFile) + if err != nil { + return nil, fmt.Errorf("failed to load CA certificate: %v", err) + } + + caCertPool := x509.NewCertPool() + if ok := caCertPool.AppendCertsFromPEM(caCert); !ok { + return nil, fmt.Errorf("failed to append CA certificate") + } + + options.TLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + } } return redis.NewClient(options), nil @@ -80,9 +106,8 @@ func newFailoverClient(cfg config.RedisConfig) (*redis.Client, error) { } options.TLSConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: caCertPool, - InsecureSkipVerify: cfg.TLSInsecureSkipVerify, // Only for testing, use false in production + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, } } diff --git a/services/preauth/service/service.go b/services/preauth/service/service.go index bb9f865..d956ab5 100644 --- a/services/preauth/service/service.go +++ b/services/preauth/service/service.go @@ -5,6 +5,8 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "errors" + "fmt" "log" "time" @@ -94,6 +96,7 @@ func (s *Service) AddPreAuth(ctx context.Context, accountId string, isActive boo // Post session to webhooks services := make([]model.SessionService, len(s.Cfg.API.SessionURLs)) + var webhookErrs []error for i, url := range s.Cfg.API.SessionURLs { session := model.Session{ ID: uuid.New().String(), @@ -102,9 +105,10 @@ func (s *Service) AddPreAuth(ctx context.Context, accountId string, isActive boo } psk := s.Cfg.API.SessionPSKs[i] - err = s.Http.PostSession(session, url, psk) - if err != nil { + if err = s.Http.PostSession(session, url, psk); err != nil { log.Println("failed to post session to ", url, ", error:", err) + webhookErrs = append(webhookErrs, fmt.Errorf("%s: %w", url, err)) + continue } services[i] = model.SessionService{ @@ -113,5 +117,13 @@ func (s *Service) AddPreAuth(ctx context.Context, accountId string, isActive boo } } + if len(webhookErrs) > 0 { + // Roll back the PreAuth stored in Redis to avoid orphaned entries + if delErr := s.Cache.Del(ctx, "preauth_"+pa.ID); delErr != nil { + log.Println("failed to rollback pre-auth from cache:", delErr) + } + return nil, errors.Join(webhookErrs...) + } + return services, nil } From f4540c9bbb8607d6d6cd44d9985e6acfb227be9c Mon Sep 17 00:00:00 2001 From: Juraj Hilje Date: Sun, 31 May 2026 09:02:40 +0200 Subject: [PATCH 6/8] feat(distributor): update service.go --- services/distributor/api/routes.go | 9 --------- services/distributor/config/config.go | 12 ++++++++++++ services/distributor/main.go | 8 ++++++-- services/distributor/middleware/auth/auth.go | 3 ++- services/distributor/service/service.go | 10 +++++----- 5 files changed, 25 insertions(+), 17 deletions(-) diff --git a/services/distributor/api/routes.go b/services/distributor/api/routes.go index 041c9a0..4f64738 100644 --- a/services/distributor/api/routes.go +++ b/services/distributor/api/routes.go @@ -1,7 +1,6 @@ package api import ( - "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/healthcheck" "github.com/gofiber/fiber/v2/middleware/helmet" "ivpn.net/auth/services/distributor/config" @@ -10,14 +9,6 @@ import ( ) func (h *Handler) SetupRoutes(cfg config.APIConfig) { - h.Server.Get("/debug-ip", func(c *fiber.Ctx) error { - return c.JSON(fiber.Map{ - "ip": c.IP(), - "x-forwarded-for": c.Get("X-Forwarded-For"), - "x-real-ip": c.Get("X-Real-IP"), - }) - }) - h.Server.Use(helmet.New()) h.Server.Use(healthcheck.New()) h.Server.Use(auth.NewIPFilter(cfg.ApiAllowIPs)) diff --git a/services/distributor/config/config.go b/services/distributor/config/config.go index a21f7c7..3bf64b2 100644 --- a/services/distributor/config/config.go +++ b/services/distributor/config/config.go @@ -1,6 +1,7 @@ package config import ( + "errors" "os" "strings" ) @@ -29,3 +30,14 @@ func New() (Config, error) { }, }, nil } + +// Validate checks that all required configuration values are present. +func (c Config) Validate() error { + if c.API.Port == "" { + return errors.New("required env var not set: DISTRIBUTOR_PORT") + } + if c.API.PSK == "" { + return errors.New("required env var not set: DISTRIBUTOR_PSK") + } + return nil +} diff --git a/services/distributor/main.go b/services/distributor/main.go index b04e16e..e8accd9 100644 --- a/services/distributor/main.go +++ b/services/distributor/main.go @@ -11,13 +11,17 @@ import ( func main() { cfg, err := config.New() if err != nil { - log.Println(err) + log.Fatal(err) + } + + if err := cfg.Validate(); err != nil { + log.Fatal("configuration error: ", err) } service := service.New(cfg) err = api.Start(cfg.API, service) if err != nil { - log.Println(err) + log.Fatal(err) } } diff --git a/services/distributor/middleware/auth/auth.go b/services/distributor/middleware/auth/auth.go index dd90323..e78df84 100644 --- a/services/distributor/middleware/auth/auth.go +++ b/services/distributor/middleware/auth/auth.go @@ -1,6 +1,7 @@ package auth import ( + "crypto/subtle" "slices" "strings" @@ -26,7 +27,7 @@ func NewIPFilter(allowedIPs []string) fiber.Handler { func NewPSK(psk string) fiber.Handler { return func(c *fiber.Ctx) error { - if GetToken(c) == psk { + if subtle.ConstantTimeCompare([]byte(GetToken(c)), []byte(psk)) == 1 { return c.Next() } diff --git a/services/distributor/service/service.go b/services/distributor/service/service.go index 3451397..ec09156 100644 --- a/services/distributor/service/service.go +++ b/services/distributor/service/service.go @@ -2,6 +2,7 @@ package service import ( "encoding/json" + "fmt" "io" "log" "os" @@ -31,21 +32,20 @@ func (s *Service) GetManifest() (model.Manifest, error) { // Open the JSON file file, err := os.Open(path) if err != nil { - log.Println("failed to open file:", err) + return model.Manifest{}, fmt.Errorf("failed to open manifest: %w", err) } defer file.Close() // Read file contents bytes, err := io.ReadAll(file) if err != nil { - log.Println("failed to read file:", err) + return model.Manifest{}, fmt.Errorf("failed to read manifest: %w", err) } // Unmarshal JSON into Manifest struct var manifest model.Manifest - err = json.Unmarshal(bytes, &manifest) - if err != nil { - log.Println("failed to unmarshal JSON:", err) + if err = json.Unmarshal(bytes, &manifest); err != nil { + return model.Manifest{}, fmt.Errorf("failed to unmarshal manifest: %w", err) } return manifest, nil From 2cf5259f3e2953a8236fc551e4f1841b9f4633a4 Mon Sep 17 00:00:00 2001 From: Juraj Hilje Date: Sun, 31 May 2026 09:16:45 +0200 Subject: [PATCH 7/8] feat(verifier): update service.go --- services/verifier/client/verifier_aws.go | 11 ++++- services/verifier/client/verifier_fortanix.go | 13 +++++- services/verifier/config/config.go | 23 +++++++++++ services/verifier/main.go | 8 +++- services/verifier/repository/db.go | 18 ++++++-- services/verifier/repository/postgres.go | 41 ++++++++----------- services/verifier/repository/subscription.go | 39 ++++++++---------- services/verifier/service/service.go | 8 ++-- 8 files changed, 101 insertions(+), 60 deletions(-) diff --git a/services/verifier/client/verifier_aws.go b/services/verifier/client/verifier_aws.go index ffe58a7..c49179b 100644 --- a/services/verifier/client/verifier_aws.go +++ b/services/verifier/client/verifier_aws.go @@ -73,7 +73,10 @@ func (s *VerifierAWS) Verify(signature string, data []byte) error { MacAlgorithm: types.MacAlgorithmSpecHmacSha256, } - verifyOut, _ := s.Client.VerifyMac(context.Background(), verifyInput) + verifyOut, err := s.Client.VerifyMac(context.Background(), verifyInput) + if err != nil { + return fmt.Errorf("KMS VerifyMac failed: %w", err) + } if verifyOut == nil { return fmt.Errorf("error verifying manifest signature: verifyOut is nil") } @@ -85,3 +88,9 @@ func (s *VerifierAWS) Verify(signature string, data []byte) error { return nil } + +// IsAuthError always returns false for AWS KMS — authentication is handled via IAM credentials, +// not per-request session tokens. +func (s *VerifierAWS) IsAuthError(_ error) bool { + return false +} diff --git a/services/verifier/client/verifier_fortanix.go b/services/verifier/client/verifier_fortanix.go index 50300d6..b4b222b 100644 --- a/services/verifier/client/verifier_fortanix.go +++ b/services/verifier/client/verifier_fortanix.go @@ -5,9 +5,11 @@ import ( "crypto/sha256" "crypto/sha512" "encoding/base64" + "errors" "fmt" "log" "net/http" + "time" "github.com/fortanix/sdkms-client-go/sdkms" "ivpn.net/auth/services/verifier/config" @@ -21,7 +23,7 @@ type VerifierFortanix struct { func NewVerifierFortanix(cfg config.Config) (*VerifierFortanix, error) { client := sdkms.Client{ Endpoint: cfg.Service.FortanixEndpoint, - HTTPClient: http.DefaultClient, + HTTPClient: &http.Client{Timeout: 30 * time.Second}, } _, err := client.AuthenticateWithAPIKey(context.Background(), cfg.Service.FortanixApiKey) @@ -86,3 +88,12 @@ func (s *VerifierFortanix) Authenticate() error { _, err := s.Client.AuthenticateWithAPIKey(context.Background(), s.Cfg.Service.FortanixApiKey) return err } + +// IsAuthError returns true when err is a Fortanix BackendError with HTTP status 401 or 403. +func (s *VerifierFortanix) IsAuthError(err error) bool { + var be *sdkms.BackendError + if errors.As(err, &be) { + return be.StatusCode == 401 || be.StatusCode == 403 + } + return false +} diff --git a/services/verifier/config/config.go b/services/verifier/config/config.go index 0bd65a0..7da7911 100644 --- a/services/verifier/config/config.go +++ b/services/verifier/config/config.go @@ -1,6 +1,7 @@ package config import ( + "errors" "os" ) @@ -103,3 +104,25 @@ func New() (Config, error) { }, }, nil } + +// Validate checks that all required configuration values are present. +func (c Config) Validate() error { + if c.API.ManifestURL == "" { + return errors.New("required env var not set: MANIFEST_URL") + } + if c.API.ManifestPSK == "" { + return errors.New("required env var not set: MANIFEST_PSK") + } + if !c.Service.Mock { + if c.Service.FortanixEndpoint == "" { + return errors.New("required env var not set: FORTANIX_ENDPOINT") + } + if c.Service.FortanixApiKey == "" { + return errors.New("required env var not set: FORTANIX_API_KEY") + } + if c.Service.FortanixKeyId == "" { + return errors.New("required env var not set: FORTANIX_KEY_ID") + } + } + return nil +} diff --git a/services/verifier/main.go b/services/verifier/main.go index 5bad1b3..c6cc9d4 100644 --- a/services/verifier/main.go +++ b/services/verifier/main.go @@ -16,6 +16,10 @@ func main() { log.Fatal(err) } + if err := cfg.Validate(); err != nil { + log.Fatal("configuration error: ", err) + } + var stores []service.Store if cfg.DB.Host != "" { @@ -62,7 +66,7 @@ func main() { switch args[1] { case "sync": if err := svc.SyncManifest(); err != nil { - log.Println(err) + log.Fatal(err) } return @@ -76,6 +80,6 @@ func main() { } if err := svc.Start(); err != nil { - log.Println(err) + log.Fatal(err) } } diff --git a/services/verifier/repository/db.go b/services/verifier/repository/db.go index 6ac3711..0527194 100644 --- a/services/verifier/repository/db.go +++ b/services/verifier/repository/db.go @@ -3,6 +3,7 @@ package repository import ( "log" + mysqldrv "github.com/go-sql-driver/mysql" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -44,13 +45,24 @@ func (d *Database) Close() error { } func connect(cfg config.DBConfig) (*gorm.DB, error) { - config := &gorm.Config{ + gormCfg := &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), } - dsn := cfg.User + ":" + cfg.Password + "@tcp(" + cfg.Host + ":" + cfg.Port + ")/" + cfg.Name + "?charset=utf8mb4&parseTime=True&loc=Local" + dsnCfg := mysqldrv.Config{ + User: cfg.User, + Passwd: cfg.Password, + Net: "tcp", + Addr: cfg.Host + ":" + cfg.Port, + DBName: cfg.Name, + Params: map[string]string{"charset": "utf8mb4"}, + ParseTime: true, + Loc: nil, + AllowNativePasswords: true, + } + dsn := dsnCfg.FormatDSN() - db, err := gorm.Open(mysql.Open(dsn), config) + db, err := gorm.Open(mysql.Open(dsn), gormCfg) if err != nil { return nil, err } diff --git a/services/verifier/repository/postgres.go b/services/verifier/repository/postgres.go index 3e73142..da84af8 100644 --- a/services/verifier/repository/postgres.go +++ b/services/verifier/repository/postgres.go @@ -3,7 +3,6 @@ package repository import ( "fmt" "log" - "strings" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -47,7 +46,7 @@ func (d *PostgresDB) Close() error { func connectPostgres(cfg config.PGDBConfig) (*gorm.DB, error) { sslMode := cfg.SSLMode if sslMode == "" { - sslMode = "disable" + sslMode = "require" } dsn := fmt.Sprintf( @@ -86,27 +85,19 @@ func (d *PostgresDB) UpdateSubscriptions(subs []model.Subscription) error { return nil } - var ids []string - var isActiveCases, activeUntilCases, tierCases strings.Builder - - for _, sub := range subs { - id := sub.ID - ids = append(ids, fmt.Sprintf("'%s'", id)) - - isActiveCases.WriteString(fmt.Sprintf("WHEN '%s' THEN %t ", id, sub.IsActive)) - activeUntilCases.WriteString(fmt.Sprintf("WHEN '%s' THEN '%s'::timestamp ", id, sub.ActiveUntil.Format("2006-01-02 15:04:05"))) - tierCases.WriteString(fmt.Sprintf("WHEN '%s' THEN '%s' ", id, sub.Tier)) - } - - sql := fmt.Sprintf(` - UPDATE %s - SET - updated_at = NOW(), - is_active = CASE id %s END, - active_until = CASE id %s END, - tier = CASE id %s END - WHERE id IN (%s); - `, d.TableName, isActiveCases.String(), activeUntilCases.String(), tierCases.String(), strings.Join(ids, ",")) - - return d.Client.Exec(sql).Error + return d.Client.Transaction(func(tx *gorm.DB) error { + for _, sub := range subs { + result := tx.Table(d.TableName). + Where("id = ?", sub.ID). + Updates(map[string]interface{}{ + "is_active": sub.IsActive, + "active_until": sub.ActiveUntil, + "tier": sub.Tier, + }) + if result.Error != nil { + return result.Error + } + } + return nil + }) } diff --git a/services/verifier/repository/subscription.go b/services/verifier/repository/subscription.go index 8c05a72..5eeb667 100644 --- a/services/verifier/repository/subscription.go +++ b/services/verifier/repository/subscription.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "gorm.io/gorm" "ivpn.net/auth/services/verifier/model" ) @@ -18,29 +19,21 @@ func (d *Database) UpdateSubscriptions(subs []model.Subscription) error { return nil } - var ids []string - var isActiveCases, activeUntilCases, tierCases strings.Builder - - for _, sub := range subs { - id := sub.ID // assuming this is a string (e.g., UUID) - ids = append(ids, fmt.Sprintf("'%s'", id)) // quote the string for SQL - - isActiveCases.WriteString(fmt.Sprintf("WHEN '%s' THEN %t ", id, sub.IsActive)) - activeUntilCases.WriteString(fmt.Sprintf("WHEN '%s' THEN '%s' ", id, sub.ActiveUntil.Format("2006-01-02 15:04:05"))) - tierCases.WriteString(fmt.Sprintf("WHEN '%s' THEN '%s' ", id, sub.Tier)) - } - - sql := fmt.Sprintf(` - UPDATE %s - SET - updated_at = NOW(), - is_active = CASE id %s END, - active_until = CASE id %s END, - tier = CASE id %s END - WHERE id IN (%s); - `, d.TableName, isActiveCases.String(), activeUntilCases.String(), tierCases.String(), strings.Join(ids, ",")) - - return d.Client.Exec(sql).Error + return d.Client.Transaction(func(tx *gorm.DB) error { + for _, sub := range subs { + result := tx.Table(d.TableName). + Where("id = ?", sub.ID). + Updates(map[string]interface{}{ + "is_active": sub.IsActive, + "active_until": sub.ActiveUntil, + "tier": sub.Tier, + }) + if result.Error != nil { + return result.Error + } + } + return nil + }) } func joinInt64s(ids []int64) string { diff --git a/services/verifier/service/service.go b/services/verifier/service/service.go index 98bda5c..05be16e 100644 --- a/services/verifier/service/service.go +++ b/services/verifier/service/service.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "log" - "strings" "time" "github.com/jasonlvhit/gocron" @@ -21,6 +20,7 @@ type Store interface { type Verifier interface { Verify(signature string, data []byte) error Authenticate() error + IsAuthError(err error) bool } type Service struct { @@ -105,16 +105,14 @@ func (s *Service) VerifyManifest(m model.Manifest) error { err = s.Verifier.Verify(signature, data) if err != nil { - if strings.Contains(err.Error(), "Status: 401") || strings.Contains(err.Error(), "Status: 403") { + if s.Verifier.IsAuthError(err) { log.Println("re-authenticating verifier session...") - err = s.Verifier.Authenticate() - if err == nil { + if authErr := s.Verifier.Authenticate(); authErr == nil { err = s.Verifier.Verify(signature, data) if err != nil { log.Println(err) return err } - return nil } } From 23cc41aa3e02cf599ff11f00e9eeb71d4ac366d9 Mon Sep 17 00:00:00 2001 From: Juraj Hilje Date: Wed, 3 Jun 2026 10:45:16 +0200 Subject: [PATCH 8/8] tests: update Taskfile.yml --- Taskfile.yml | 4 ++-- services/verifier/repository/postgres.go | 2 +- services/verifier/repository/subscription.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Taskfile.yml b/Taskfile.yml index 6b29e4c..9c6f3d0 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -12,14 +12,14 @@ tasks: - | find . -name 'go.mod' -exec dirname {} \; | while read module_dir; do echo "Running go modernize in $module_dir" - (cd "$module_dir" && go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -test ./...) + (cd "$module_dir" && go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@v0.21.1 -test ./...) done modernize_fix: cmds: - | find . -name 'go.mod' -exec dirname {} \; | while read module_dir; do echo "Running go modernize_fix in $module_dir" - (cd "$module_dir" && go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./...) + (cd "$module_dir" && go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@v0.21.1 -fix -test ./...) done test: cmds: diff --git a/services/verifier/repository/postgres.go b/services/verifier/repository/postgres.go index da84af8..ffc5cbe 100644 --- a/services/verifier/repository/postgres.go +++ b/services/verifier/repository/postgres.go @@ -89,7 +89,7 @@ func (d *PostgresDB) UpdateSubscriptions(subs []model.Subscription) error { for _, sub := range subs { result := tx.Table(d.TableName). Where("id = ?", sub.ID). - Updates(map[string]interface{}{ + Updates(map[string]any{ "is_active": sub.IsActive, "active_until": sub.ActiveUntil, "tier": sub.Tier, diff --git a/services/verifier/repository/subscription.go b/services/verifier/repository/subscription.go index 5eeb667..8d73d5a 100644 --- a/services/verifier/repository/subscription.go +++ b/services/verifier/repository/subscription.go @@ -23,7 +23,7 @@ func (d *Database) UpdateSubscriptions(subs []model.Subscription) error { for _, sub := range subs { result := tx.Table(d.TableName). Where("id = ?", sub.ID). - Updates(map[string]interface{}{ + Updates(map[string]any{ "is_active": sub.IsActive, "active_until": sub.ActiveUntil, "tier": sub.Tier,