diff --git a/Makefile b/Makefile index 085db3a..979a849 100644 --- a/Makefile +++ b/Makefile @@ -39,10 +39,10 @@ deps: go mod tidy build: deps - go build $(LDFLAGS) -o $(BINARY_NAME) cmd/agent/main.go + go build $(LDFLAGS) -o $(BINARY_NAME) ./cmd/agent run: deps - go run cmd/agent/main.go --config config.example.yml + go run ./cmd/agent --config config.example.yml test: test-unit diff --git a/internal/api/cert_renewal_test.go b/internal/api/cert_renewal_test.go new file mode 100644 index 0000000..1336cdc --- /dev/null +++ b/internal/api/cert_renewal_test.go @@ -0,0 +1,256 @@ +package api + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "math/big" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/flatrun/agent/internal/docker" + "github.com/flatrun/agent/internal/nginx" + "github.com/flatrun/agent/internal/proxy" + "github.com/flatrun/agent/internal/ssl" + "github.com/flatrun/agent/pkg/config" + "github.com/flatrun/agent/pkg/models" + "github.com/gin-gonic/gin" + "gopkg.in/yaml.v3" +) + +type fakeCertbotExecutor struct{} + +func (fakeCertbotExecutor) Execute(_ *config.ServiceExecConfig, _ []string) ([]byte, error) { + return []byte("ok"), nil +} + +func writeSelfSignedCert(t *testing.T, certsDir, domain string) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: domain}, + Issuer: pkix.Name{CommonName: "flatrun-test-ca"}, + NotBefore: time.Now().Add(-24 * time.Hour), + NotAfter: time.Now().Add(60 * 24 * time.Hour), + DNSNames: []string{domain}, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + t.Fatalf("create cert: %v", err) + } + + dir := filepath.Join(certsDir, domain) + if err := os.MkdirAll(dir, 0755); err != nil { + 
t.Fatalf("mkdir: %v", err) + } + f, err := os.Create(filepath.Join(dir, "cert.pem")) + if err != nil { + t.Fatalf("create file: %v", err) + } + defer f.Close() + if err := pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil { + t.Fatalf("encode: %v", err) + } +} + +func writeDeploymentWithDomains(t *testing.T, deploymentsPath, name string, domains []models.DomainConfig) { + t.Helper() + + dir := filepath.Join(deploymentsPath, name) + if err := os.MkdirAll(dir, 0755); err != nil { + t.Fatalf("mkdir deployment: %v", err) + } + compose := "name: " + name + "\nservices:\n web:\n image: nginx:latest\n" + if err := os.WriteFile(filepath.Join(dir, "docker-compose.yml"), []byte(compose), 0644); err != nil { + t.Fatalf("write compose: %v", err) + } + + meta := &models.ServiceMetadata{ + Name: name, + Type: "web", + Domains: domains, + } + data, err := yaml.Marshal(meta) + if err != nil { + t.Fatalf("marshal metadata: %v", err) + } + if err := os.WriteFile(filepath.Join(dir, "service.yml"), data, 0644); err != nil { + t.Fatalf("write metadata: %v", err) + } +} + +func setupRenewalTestServer(t *testing.T) (*Server, string, string) { + t.Helper() + gin.SetMode(gin.TestMode) + + tmpDir := t.TempDir() + deploymentsPath := filepath.Join(tmpDir, "deployments") + certsPath := filepath.Join(tmpDir, "certs", "live") + if err := os.MkdirAll(deploymentsPath, 0755); err != nil { + t.Fatalf("mkdir deployments: %v", err) + } + if err := os.MkdirAll(certsPath, 0755); err != nil { + t.Fatalf("mkdir certs: %v", err) + } + + cfg := &config.Config{ + DeploymentsPath: deploymentsPath, + Certbot: config.CertbotConfig{CertsPath: certsPath}, + } + + nginxMgr := nginx.NewManager(&cfg.Nginx, deploymentsPath, "") + sslMgr := ssl.NewManager(&cfg.Certbot, deploymentsPath, fakeCertbotExecutor{}) + orch := proxy.NewOrchestratorWithManagers(nginxMgr, sslMgr) + + server := &Server{ + config: cfg, + manager: docker.NewManager(deploymentsPath), + proxyOrchestrator: orch, + } + + return 
server, deploymentsPath, certsPath +} + +func TestListCertificates_AnnotatesDeploymentID(t *testing.T) { + server, deploymentsPath, certsPath := setupRenewalTestServer(t) + + writeSelfSignedCert(t, certsPath, "app.example.com") + writeSelfSignedCert(t, certsPath, "alias.example.com") + writeSelfSignedCert(t, certsPath, "orphan.example.com") + + writeDeploymentWithDomains(t, deploymentsPath, "my-app", []models.DomainConfig{ + { + Domain: "app.example.com", + Aliases: []string{"alias.example.com"}, + }, + }) + + router := gin.New() + router.GET("/certificates", server.listCertificates) + + req := httptest.NewRequest(http.MethodGet, "/certificates", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (%s)", w.Code, w.Body.String()) + } + + var resp struct { + Certificates []models.Certificate `json:"certificates"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + byDomain := make(map[string]string) + for _, c := range resp.Certificates { + byDomain[c.Domain] = c.DeploymentID + } + + if byDomain["app.example.com"] != "my-app" { + t.Errorf("app.example.com DeploymentID = %q, want my-app", byDomain["app.example.com"]) + } + if byDomain["alias.example.com"] != "my-app" { + t.Errorf("alias.example.com DeploymentID = %q, want my-app (via alias)", byDomain["alias.example.com"]) + } + if byDomain["orphan.example.com"] != "" { + t.Errorf("orphan cert should have empty DeploymentID, got %q", byDomain["orphan.example.com"]) + } +} + +func TestRenewDeploymentCertificates_CollectsAllDomainsAndAliases(t *testing.T) { + server, deploymentsPath, certsPath := setupRenewalTestServer(t) + + writeSelfSignedCert(t, certsPath, "primary.example.com") + writeSelfSignedCert(t, certsPath, "alt.example.com") + writeSelfSignedCert(t, certsPath, "other.example.com") + + writeDeploymentWithDomains(t, deploymentsPath, "multi-app", []models.DomainConfig{ + { + 
Domain: "primary.example.com", + Aliases: []string{"alt.example.com"}, + }, + {Domain: "other.example.com"}, + }) + + router := gin.New() + router.POST("/deployments/:name/certificates/renew", server.renewDeploymentCertificates) + + req := httptest.NewRequest( + http.MethodPost, + "/deployments/multi-app/certificates/renew", + bytes.NewReader(nil), + ) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (%s)", w.Code, w.Body.String()) + } + + var resp struct { + Deployment string `json:"deployment"` + Result struct { + Success bool `json:"success"` + Results []struct { + Domain string `json:"domain"` + Success bool `json:"success"` + } `json:"results"` + } `json:"result"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if resp.Deployment != "multi-app" { + t.Errorf("deployment = %q, want multi-app", resp.Deployment) + } + if len(resp.Result.Results) != 3 { + t.Errorf("expected 3 domains renewed, got %d: %+v", len(resp.Result.Results), resp.Result.Results) + } + seen := make(map[string]bool) + for _, r := range resp.Result.Results { + seen[r.Domain] = true + } + for _, want := range []string{"primary.example.com", "alt.example.com", "other.example.com"} { + if !seen[want] { + t.Errorf("expected result for %s, got %+v", want, resp.Result.Results) + } + } +} + +func TestRenewDeploymentCertificates_NoDomains(t *testing.T) { + server, deploymentsPath, _ := setupRenewalTestServer(t) + + writeDeploymentWithDomains(t, deploymentsPath, "empty-app", nil) + + router := gin.New() + router.POST("/deployments/:name/certificates/renew", server.renewDeploymentCertificates) + + req := httptest.NewRequest( + http.MethodPost, + "/deployments/empty-app/certificates/renew", + bytes.NewReader(nil), + ) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 for deployment with no domains, got %d 
(%s)", w.Code, w.Body.String()) + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 1f7e2f4..441fa2a 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -35,6 +35,7 @@ import ( "github.com/flatrun/agent/internal/proxy" "github.com/flatrun/agent/internal/scheduler" "github.com/flatrun/agent/internal/security" + "github.com/flatrun/agent/internal/ssl" "github.com/flatrun/agent/internal/setup" "github.com/flatrun/agent/internal/system" "github.com/flatrun/agent/internal/traffic" @@ -78,6 +79,7 @@ type Server struct { clusterManager *cluster.Manager setupManager *setup.Manager setupHandlers *setup.Handlers + certRenewer *ssl.Renewer statsMu sync.RWMutex statsCache gin.H @@ -339,7 +341,11 @@ func (s *Server) setupRoutes() { protected.GET("/certificates", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.listCertificates) protected.POST("/certificates", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.requestCertificate) protected.POST("/certificates/renew", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.renewCertificates) + protected.GET("/certificates/:domain", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.getCertificate) + protected.POST("/certificates/:domain/renew", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.renewCertificate) + protected.PATCH("/certificates/:domain/auto-renew", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.setCertificateAutoRenew) protected.DELETE("/certificates/:domain", s.authMiddleware.RequirePermission(auth.PermCertificatesDelete), s.deleteCertificate) + protected.POST("/deployments/:name/certificates/renew", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.renewDeploymentCertificates) // Proxy endpoints protected.GET("/proxy/status/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), 
s.getProxyStatus) @@ -625,11 +631,29 @@ func (s *Server) Start() error { Handler: s.router, } + if s.config.Certbot.Enabled && s.config.Certbot.AutoRenewalEnabled { + s.certRenewer = ssl.NewRenewer( + s.proxyOrchestrator.SSLManager(), + s.config.Certbot.RenewalThresholdDays, + s.config.Certbot.RenewalCheckInterval, + func(domain string) { + if err := s.proxyOrchestrator.NginxManager().Reload(); err != nil { + log.Printf("auto-renew: failed to reload nginx after %s: %v", domain, err) + } + }, + ) + s.certRenewer.Start(context.Background()) + log.Printf("auto-renew: enabled (threshold=%d days, interval=%s)", s.config.Certbot.RenewalThresholdDays, s.config.Certbot.RenewalCheckInterval) + } + return s.server.ListenAndServe() } func (s *Server) Stop() error { + if s.certRenewer != nil { + s.certRenewer.Stop() + } if s.clusterManager != nil { s.clusterManager.Stop() } @@ -3327,11 +3351,52 @@ func (s *Server) listCertificates(c *gin.Context) { return } + s.annotateCertificatesWithDeployment(certificates) + c.JSON(http.StatusOK, gin.H{ "certificates": certificates, }) } +func (s *Server) annotateCertificatesWithDeployment(certs []models.Certificate) { + if len(certs) == 0 { + return + } + + deployments, err := s.manager.ListDeployments() + if err != nil { + log.Printf("warning: failed to list deployments for cert annotation: %v", err) + return + } + + domainToDeployment := make(map[string]string) + for _, d := range deployments { + if d.Metadata == nil { + continue + } + for _, dom := range d.Metadata.Domains { + if dom.Domain != "" { + if _, exists := domainToDeployment[dom.Domain]; !exists { + domainToDeployment[dom.Domain] = d.Name + } + } + for _, alias := range dom.Aliases { + if alias != "" { + if _, exists := domainToDeployment[alias]; !exists { + domainToDeployment[alias] = d.Name + } + } + } + } + } + + for i := range certs { + if name, ok := domainToDeployment[certs[i].Domain]; ok { + certs[i].DeploymentID = name + } + } +} + func (s *Server) requestCertificate(c 
*gin.Context) { var req struct { Domain string `json:"domain" binding:"required"` @@ -3415,6 +3480,104 @@ func (s *Server) renewCertificates(c *gin.Context) { }) } +func (s *Server) getCertificate(c *gin.Context) { + domain := c.Param("domain") + cert, err := s.proxyOrchestrator.SSLManager().GetCertificate(domain) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + annotated := []models.Certificate{*cert} + s.annotateCertificatesWithDeployment(annotated) + c.JSON(http.StatusOK, gin.H{"certificate": annotated[0]}) +} + +func (s *Server) renewCertificate(c *gin.Context) { + domain := c.Param("domain") + + result, err := s.proxyOrchestrator.RenewCertificate(domain) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "Certificate renewed", + "domain": domain, + "result": result, + }) +} + +func (s *Server) setCertificateAutoRenew(c *gin.Context) { + domain := c.Param("domain") + + var req struct { + AutoRenew bool `json:"auto_renew"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := s.proxyOrchestrator.SSLManager().SetAutoRenew(domain, req.AutoRenew); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "domain": domain, + "auto_renew": req.AutoRenew, + }) +} + +func (s *Server) renewDeploymentCertificates(c *gin.Context) { + name := c.Param("name") + + dep, err := s.manager.GetDeployment(name) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + if dep.Metadata == nil { + c.JSON(http.StatusOK, gin.H{ + "message": "no domains configured for deployment", + "result": nil, + }) + return + } + + seen := make(map[string]bool) + var domains []string + for _, d := range dep.Metadata.Domains { + if d.Domain != "" && !seen[d.Domain] { + 
domains = append(domains, d.Domain) + seen[d.Domain] = true + } + for _, alias := range d.Aliases { + if alias != "" && !seen[alias] { + domains = append(domains, alias) + seen[alias] = true + } + } + } + + if len(domains) == 0 { + c.JSON(http.StatusOK, gin.H{ + "message": "no domains configured for deployment", + "result": nil, + }) + return + } + + result := s.proxyOrchestrator.RenewCertificatesForDomains(domains) + c.JSON(http.StatusOK, gin.H{ + "message": "Deployment certificate renewal completed", + "deployment": name, + "result": result, + }) +} + func (s *Server) deleteCertificate(c *gin.Context) { domain := c.Param("domain") force := c.DefaultQuery("force", "false") == "true" diff --git a/internal/proxy/interfaces.go b/internal/proxy/interfaces.go index ae3d2b3..aa04fb9 100644 --- a/internal/proxy/interfaces.go +++ b/internal/proxy/interfaces.go @@ -22,6 +22,8 @@ type SSLManager interface { RequestCertificate(domain string) (*ssl.CertificateResult, error) GetCertificate(domain string) (*models.Certificate, error) RenewCertificates() (*ssl.RenewalResult, error) + RenewCertificate(domain string) (*ssl.RenewalResult, error) ListCertificates() ([]models.Certificate, error) GetExpiringCertificates(days int) ([]models.Certificate, error) + SetAutoRenew(domain string, enabled bool) error } diff --git a/internal/proxy/orchestrator.go b/internal/proxy/orchestrator.go index 93b3cb9..a765ae4 100644 --- a/internal/proxy/orchestrator.go +++ b/internal/proxy/orchestrator.go @@ -219,6 +219,56 @@ func (o *Orchestrator) RenewCertificates() (*ssl.RenewalResult, error) { return result, nil } +func (o *Orchestrator) RenewCertificate(domain string) (*ssl.RenewalResult, error) { + if err := o.ssl.ValidateDomain(domain); err != nil { + return nil, err + } + + result, err := o.ssl.RenewCertificate(domain) + if err != nil { + return nil, err + } + + if err := o.nginx.Reload(); err != nil { + log.Printf("warning: failed to reload nginx after renewal of %s: %v", domain, err) + } + + 
return result, nil +} + +func (o *Orchestrator) RenewCertificatesForDomains(domains []string) *ssl.MultiCertificateResult { + result := &ssl.MultiCertificateResult{ + Results: make([]*ssl.CertificateResult, 0, len(domains)), + Success: true, + } + + for _, domain := range domains { + if !o.ssl.CertificateExists(domain) { + continue + } + if _, err := o.ssl.RenewCertificate(domain); err != nil { + result.Results = append(result.Results, &ssl.CertificateResult{ + Domain: domain, + Success: false, + Message: err.Error(), + }) + result.Success = false + continue + } + result.Results = append(result.Results, &ssl.CertificateResult{ + Domain: domain, + Success: true, + Message: "renewed", + }) + } + + if err := o.nginx.Reload(); err != nil { + log.Printf("warning: failed to reload nginx after deployment renewal: %v", err) + } + + return result +} + func (o *Orchestrator) GetDeploymentProxyStatus(deployment *models.Deployment) *ProxyStatus { status := &ProxyStatus{ DeploymentName: deployment.Name, diff --git a/internal/proxy/orchestrator_test.go b/internal/proxy/orchestrator_test.go index 6811af4..3d2aaeb 100644 --- a/internal/proxy/orchestrator_test.go +++ b/internal/proxy/orchestrator_test.go @@ -117,6 +117,10 @@ func (m *mockSSLManager) RenewCertificates() (*ssl.RenewalResult, error) { return &ssl.RenewalResult{Success: true}, nil } +func (m *mockSSLManager) RenewCertificate(domain string) (*ssl.RenewalResult, error) { + return &ssl.RenewalResult{Success: true, RenewedDomains: []string{domain}}, nil +} + func (m *mockSSLManager) ListCertificates() ([]models.Certificate, error) { return nil, nil } @@ -125,6 +129,10 @@ func (m *mockSSLManager) GetExpiringCertificates(days int) ([]models.Certificate return nil, nil } +func (m *mockSSLManager) SetAutoRenew(domain string, enabled bool) error { + return nil +} + type testableOrchestrator struct { nginx NginxManager ssl SSLManager diff --git a/internal/ssl/autorenew.go b/internal/ssl/autorenew.go new file mode 100644 index 
0000000..45599d4 --- /dev/null +++ b/internal/ssl/autorenew.go @@ -0,0 +1,105 @@ +package ssl + +import ( + "context" + "log" + "sync" + "time" +) + +// Renewer runs periodic auto-renewal checks for certificates that have +// auto-renew enabled and are within the configured expiry threshold. +type Renewer struct { + manager *Manager + thresholdDays int + interval time.Duration + onRenew func(domain string) + + mu sync.Mutex + cancel context.CancelFunc +} + +func NewRenewer(manager *Manager, thresholdDays int, interval time.Duration, onRenew func(domain string)) *Renewer { + if thresholdDays <= 0 { + thresholdDays = 30 + } + if interval <= 0 { + interval = 12 * time.Hour + } + return &Renewer{ + manager: manager, + thresholdDays: thresholdDays, + interval: interval, + onRenew: onRenew, + } +} + +func (r *Renewer) Start(ctx context.Context) { + r.mu.Lock() + if r.cancel != nil { + r.mu.Unlock() + return + } + ctx, cancel := context.WithCancel(ctx) + r.cancel = cancel + r.mu.Unlock() + + go r.loop(ctx) +} + +func (r *Renewer) Stop() { + r.mu.Lock() + defer r.mu.Unlock() + if r.cancel != nil { + r.cancel() + r.cancel = nil + } +} + +func (r *Renewer) loop(ctx context.Context) { + // Run shortly after start so certs near expiry get picked up promptly. + initial := time.NewTimer(30 * time.Second) + defer initial.Stop() + + select { + case <-ctx.Done(): + return + case <-initial.C: + r.Run() + } + + ticker := time.NewTicker(r.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + r.Run() + } + } +} + +// Run performs one pass of the renewal check. Exported for tests and manual triggers. 
+func (r *Renewer) Run() { + expiring, err := r.manager.GetExpiringCertificates(r.thresholdDays) + if err != nil { + log.Printf("auto-renew: failed to list expiring certificates: %v", err) + return + } + + for _, cert := range expiring { + if !cert.AutoRenew { + continue + } + if _, err := r.manager.RenewCertificate(cert.Domain); err != nil { + log.Printf("auto-renew: failed to renew %s (days_left=%d): %v", cert.Domain, cert.DaysLeft, err) + continue + } + log.Printf("auto-renew: renewed %s (was %d days from expiry)", cert.Domain, cert.DaysLeft) + if r.onRenew != nil { + r.onRenew(cert.Domain) + } + } +} diff --git a/internal/ssl/manager.go b/internal/ssl/manager.go index 57acf82..930200a 100644 --- a/internal/ssl/manager.go +++ b/internal/ssl/manager.go @@ -176,6 +176,65 @@ func (m *Manager) RenewCertificates() (*RenewalResult, error) { }, nil } +func (m *Manager) RenewCertificate(domain string) (*RenewalResult, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.certificateExistsLocked(domain) { + return nil, fmt.Errorf("certificate for domain %q not found", domain) + } + + output, err := m.executeCertbot([]string{ + "renew", + "--non-interactive", + "--cert-name", domain, + }) + if err != nil { + return nil, fmt.Errorf("renewal failed for %s: %s - %w", domain, string(output), err) + } + + return &RenewalResult{ + Success: true, + Message: string(output), + RenewedDomains: []string{domain}, + }, nil +} + +func (m *Manager) certificateExistsLocked(domain string) bool { + certPath := filepath.Join(m.certsPath, domain, "cert.pem") + _, err := os.Stat(certPath) + return err == nil +} + +func (m *Manager) SetAutoRenew(domain string, enabled bool) error { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.certificateExistsLocked(domain) { + return fmt.Errorf("certificate for domain %q not found", domain) + } + + marker := filepath.Join(m.certsPath, domain, ".flatrun-auto-renew-disabled") + if enabled { + if err := os.Remove(marker); err != nil && !os.IsNotExist(err) { + 
return fmt.Errorf("failed to enable auto-renew: %w", err) + } + return nil + } + + f, err := os.Create(marker) + if err != nil { + return fmt.Errorf("failed to disable auto-renew: %w", err) + } + return f.Close() +} + +func (m *Manager) isAutoRenewEnabled(domain string) bool { + marker := filepath.Join(m.certsPath, domain, ".flatrun-auto-renew-disabled") + _, err := os.Stat(marker) + return os.IsNotExist(err) +} + func (m *Manager) RevokeCertificate(domain string) error { m.mu.Lock() defer m.mu.Unlock() @@ -302,7 +361,7 @@ func (m *Manager) parseCertificate(certPath, domain string) (*models.Certificate DaysLeft: daysLeft, Status: status, Path: certPath, - AutoRenew: true, + AutoRenew: m.isAutoRenewEnabled(domain), }, nil } diff --git a/internal/ssl/renewal_test.go b/internal/ssl/renewal_test.go new file mode 100644 index 0000000..5743358 --- /dev/null +++ b/internal/ssl/renewal_test.go @@ -0,0 +1,255 @@ +package ssl + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "github.com/flatrun/agent/pkg/config" +) + +// writeTestCert generates a self-signed certificate and writes it into the +// layout that ssl.Manager expects: {certsPath}/{domain}/cert.pem. 
+func writeTestCert(t *testing.T, certsPath, domain string, notAfter time.Time) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: domain}, + Issuer: pkix.Name{CommonName: "flatrun-test-ca"}, + NotBefore: time.Now().Add(-24 * time.Hour), + NotAfter: notAfter, + DNSNames: []string{domain}, + } + + der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("create cert: %v", err) + } + + domainDir := filepath.Join(certsPath, domain) + if err := os.MkdirAll(domainDir, 0755); err != nil { + t.Fatalf("mkdir: %v", err) + } + + certPath := filepath.Join(domainDir, "cert.pem") + f, err := os.Create(certPath) + if err != nil { + t.Fatalf("create cert file: %v", err) + } + defer f.Close() + + if err := pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil { + t.Fatalf("encode pem: %v", err) + } +} + +func newTestManager(t *testing.T) (*Manager, *mockExecutor, string) { + t.Helper() + tmpDir := t.TempDir() + certsDir := filepath.Join(tmpDir, "live") + if err := os.MkdirAll(certsDir, 0755); err != nil { + t.Fatalf("mkdir certs: %v", err) + } + + mock := &mockExecutor{} + cfg := &config.CertbotConfig{CertsPath: certsDir} + m := NewManager(cfg, tmpDir, mock) + return m, mock, certsDir +} + +func TestRenewCertificate_PassesCertName(t *testing.T) { + m, mock, certsDir := newTestManager(t) + writeTestCert(t, certsDir, "example.com", time.Now().Add(10*24*time.Hour)) + + result, err := m.RenewCertificate("example.com") + if err != nil { + t.Fatalf("RenewCertificate: %v", err) + } + if !result.Success { + t.Error("expected Success=true") + } + if len(result.RenewedDomains) != 1 || result.RenewedDomains[0] != "example.com" { + t.Errorf("RenewedDomains = %v, want [example.com]", result.RenewedDomains) + } + + if len(mock.calls) != 1 { + 
t.Fatalf("expected 1 executor call, got %d", len(mock.calls)) + } + + args := mock.calls[0].args + var hasRenew, hasNonInteractive, hasCertName bool + for i, arg := range args { + switch arg { + case "renew": + hasRenew = true + case "--non-interactive": + hasNonInteractive = true + case "--cert-name": + if i+1 < len(args) && args[i+1] == "example.com" { + hasCertName = true + } + } + } + if !hasRenew || !hasNonInteractive || !hasCertName { + t.Errorf("unexpected certbot args: %v", args) + } +} + +func TestRenewCertificate_ErrorsWhenMissing(t *testing.T) { + m, mock, _ := newTestManager(t) + + _, err := m.RenewCertificate("missing.example.com") + if err == nil { + t.Fatal("expected error for missing certificate") + } + if len(mock.calls) != 0 { + t.Errorf("expected no executor calls when cert missing, got %d", len(mock.calls)) + } +} + +func TestSetAutoRenew_TogglesMarkerAndReflectsInCertificate(t *testing.T) { + m, _, certsDir := newTestManager(t) + writeTestCert(t, certsDir, "auto.example.com", time.Now().Add(20*24*time.Hour)) + + cert, err := m.GetCertificate("auto.example.com") + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if !cert.AutoRenew { + t.Error("new cert should default to auto_renew=true") + } + + if err := m.SetAutoRenew("auto.example.com", false); err != nil { + t.Fatalf("SetAutoRenew(false): %v", err) + } + + cert, err = m.GetCertificate("auto.example.com") + if err != nil { + t.Fatalf("GetCertificate after disable: %v", err) + } + if cert.AutoRenew { + t.Error("cert should report auto_renew=false after disable") + } + + if err := m.SetAutoRenew("auto.example.com", true); err != nil { + t.Fatalf("SetAutoRenew(true): %v", err) + } + + cert, err = m.GetCertificate("auto.example.com") + if err != nil { + t.Fatalf("GetCertificate after re-enable: %v", err) + } + if !cert.AutoRenew { + t.Error("cert should report auto_renew=true after re-enable") + } +} + +func TestSetAutoRenew_ErrorsWhenCertMissing(t *testing.T) { + m, _, _ := 
newTestManager(t) + + if err := m.SetAutoRenew("nope.example.com", false); err == nil { + t.Error("expected error for missing certificate") + } +} + +func TestGetExpiringCertificates_FiltersByThreshold(t *testing.T) { + m, _, certsDir := newTestManager(t) + writeTestCert(t, certsDir, "fresh.example.com", time.Now().Add(90*24*time.Hour)) + writeTestCert(t, certsDir, "soon.example.com", time.Now().Add(5*24*time.Hour)) + writeTestCert(t, certsDir, "expired.example.com", time.Now().Add(-24*time.Hour)) + + expiring, err := m.GetExpiringCertificates(30) + if err != nil { + t.Fatalf("GetExpiringCertificates: %v", err) + } + + got := make(map[string]bool) + for _, c := range expiring { + got[c.Domain] = true + } + + if got["fresh.example.com"] { + t.Error("fresh cert should not be reported as expiring") + } + if !got["soon.example.com"] { + t.Error("soon cert should be reported as expiring") + } + if !got["expired.example.com"] { + t.Error("expired cert should be reported as expiring") + } +} + +func TestRenewer_RenewsOnlyExpiringAutoRenewCerts(t *testing.T) { + m, mock, certsDir := newTestManager(t) + + // Within threshold, auto-renew on (default) — should be renewed. + writeTestCert(t, certsDir, "renew.example.com", time.Now().Add(10*24*time.Hour)) + + // Within threshold, auto-renew off — should be skipped. + writeTestCert(t, certsDir, "manual.example.com", time.Now().Add(10*24*time.Hour)) + if err := m.SetAutoRenew("manual.example.com", false); err != nil { + t.Fatalf("SetAutoRenew: %v", err) + } + + // Outside threshold — should be skipped. 
+ writeTestCert(t, certsDir, "fresh.example.com", time.Now().Add(120*24*time.Hour)) + + renewed := make(map[string]bool) + r := NewRenewer(m, 30, time.Hour, func(domain string) { + renewed[domain] = true + }) + r.Run() + + var renewCalls []string + for _, call := range mock.calls { + if len(call.args) > 0 && call.args[0] == "renew" { + for i, a := range call.args { + if a == "--cert-name" && i+1 < len(call.args) { + renewCalls = append(renewCalls, call.args[i+1]) + } + } + } + } + + if len(renewCalls) != 1 || renewCalls[0] != "renew.example.com" { + t.Errorf("expected exactly one renewal for renew.example.com, got %v", renewCalls) + } + if !renewed["renew.example.com"] { + t.Error("onRenew callback was not invoked for renew.example.com") + } + if renewed["manual.example.com"] { + t.Error("onRenew should not fire for auto-renew=false cert") + } + if renewed["fresh.example.com"] { + t.Error("onRenew should not fire for cert outside threshold") + } +} + +func TestRenewer_StartStop(t *testing.T) { + m, _, _ := newTestManager(t) + r := NewRenewer(m, 30, time.Hour, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r.Start(ctx) + // Calling Start twice should be a no-op and not panic. + r.Start(ctx) + r.Stop() + // Stop twice should be safe. 
+ r.Stop() +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 95a3a57..0b10872 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -68,14 +68,17 @@ type NginxConfig struct { } type CertbotConfig struct { - Enabled bool `yaml:"enabled" json:"enabled"` - Image string `yaml:"image" json:"image"` - Email string `yaml:"email" json:"email"` - Staging bool `yaml:"staging" json:"staging"` - CertsPath string `yaml:"certs_path" json:"certs_path"` - WebrootPath string `yaml:"webroot_path" json:"webroot_path"` - ContainerWebrootPath string `yaml:"container_webroot_path" json:"container_webroot_path"` - DNSProvider string `yaml:"dns_provider" json:"dns_provider"` + Enabled bool `yaml:"enabled" json:"enabled"` + Image string `yaml:"image" json:"image"` + Email string `yaml:"email" json:"email"` + Staging bool `yaml:"staging" json:"staging"` + CertsPath string `yaml:"certs_path" json:"certs_path"` + WebrootPath string `yaml:"webroot_path" json:"webroot_path"` + ContainerWebrootPath string `yaml:"container_webroot_path" json:"container_webroot_path"` + DNSProvider string `yaml:"dns_provider" json:"dns_provider"` + AutoRenewalEnabled bool `yaml:"auto_renewal_enabled" json:"auto_renewal_enabled"` + RenewalThresholdDays int `yaml:"renewal_threshold_days" json:"renewal_threshold_days"` + RenewalCheckInterval time.Duration `yaml:"renewal_check_interval" json:"renewal_check_interval"` } type ServiceExecConfig struct { @@ -285,6 +288,12 @@ func setDefaults(cfg *Config) { if cfg.Certbot.Image == "" { cfg.Certbot.Image = "certbot/certbot" } + if cfg.Certbot.RenewalThresholdDays == 0 { + cfg.Certbot.RenewalThresholdDays = 30 + } + if cfg.Certbot.RenewalCheckInterval == 0 { + cfg.Certbot.RenewalCheckInterval = 12 * time.Hour + } // Security defaults if cfg.Security.ScanInterval == 0 { cfg.Security.ScanInterval = 30 * time.Second