diff --git a/backend/gateway/internal/auth/auth.go b/backend/gateway/internal/auth/auth.go index 3b7a206..3ba7c77 100644 --- a/backend/gateway/internal/auth/auth.go +++ b/backend/gateway/internal/auth/auth.go @@ -11,9 +11,14 @@ package auth import ( "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" "errors" "net/http" "strings" + "time" ) // PlayerID is the Supabase user ID extracted from the JWT. @@ -36,6 +41,12 @@ var ErrMissingAuth = errors.New("missing Authorization header") // signature, wrong issuer). var ErrInvalidJWT = errors.New("invalid JWT") +const ( + maxAuthorizationHeaderBytes = 8 * 1024 + maxBearerTokenBytes = 6 * 1024 + maxJWTPartBytes = 3 * 1024 +) + // Verifier validates Supabase JWTs and extracts the player ID. // The concrete impl uses the Supabase project's JWT secret (HS256) - // no network call to Supabase per request, the secret is enough to @@ -44,6 +55,15 @@ type Verifier interface { Verify(ctx context.Context, jwt string) (Identity, error) } +// NewHS256Verifier returns a local Supabase-compatible JWT verifier. +func NewHS256Verifier(secret string) Verifier { + secret = strings.TrimSpace(secret) + if secret == "" { + return nil + } + return hs256Verifier{secret: []byte(secret), now: func() time.Time { return time.Now().UTC() }} +} + // FromRequest extracts the bearer token from an HTTP request. // Returns ErrMissingAuth if the header is absent or malformed. func FromRequest(r *http.Request) (string, error) { @@ -51,6 +71,9 @@ func FromRequest(r *http.Request) (string, error) { if authHdr == "" { return "", ErrMissingAuth } + if len(authHdr) > maxAuthorizationHeaderBytes { + return "", ErrInvalidJWT + } const prefix = "Bearer " if !strings.HasPrefix(authHdr, prefix) { return "", ErrMissingAuth @@ -59,5 +82,87 @@ func FromRequest(r *http.Request) (string, error) { if token == "" { return "", ErrMissingAuth } + if len(token) > maxBearerTokenBytes { + return "", ErrInvalidJWT + } return token, nil } + +type hs256Verifier struct { + secret []byte + now func() time.Time +} + +type jwtHeader struct { + Algorithm string `json:"alg"` + Type string `json:"typ"` +} + +type supabaseClaims struct { + Subject string `json:"sub"` + Role string `json:"role"` + Email string `json:"email"` + IsAnonymous bool `json:"is_anonymous"` + ExpiresAt int64 `json:"exp"` +} + +func (v hs256Verifier) Verify(_ context.Context, jwt string) (Identity, error) { + if len(jwt) > maxBearerTokenBytes { + return Identity{}, ErrInvalidJWT + } + parts := strings.Split(jwt, ".") + if len(parts) != 3 { + return Identity{}, ErrInvalidJWT + } + for _, part := range parts { + if part == "" || len(part) > maxJWTPartBytes { + return Identity{}, ErrInvalidJWT + } + } + + var header jwtHeader + if err := decodeJWTPart(parts[0], &header); err != nil { + return Identity{}, ErrInvalidJWT + } + if header.Algorithm != "HS256" { + return Identity{}, ErrInvalidJWT + } + + mac := hmac.New(sha256.New, v.secret) + mac.Write([]byte(parts[0])) + mac.Write([]byte(".")) + mac.Write([]byte(parts[1])) + expected := mac.Sum(nil) + + signature, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil || !hmac.Equal(signature, expected) { + return Identity{}, ErrInvalidJWT + } + + var claims supabaseClaims + if err := decodeJWTPart(parts[1], &claims); err != nil { + return Identity{}, ErrInvalidJWT + } + if strings.TrimSpace(claims.Subject) == "" { + return Identity{}, ErrInvalidJWT + } + if claims.ExpiresAt > 0 && v.now().Unix() >= claims.ExpiresAt { + return Identity{}, ErrInvalidJWT + } + + return Identity{ + PlayerID: PlayerID(claims.Subject), + Role: claims.Role, + Email: claims.Email, + IsAnonymous: claims.IsAnonymous, + ExpiresAt: claims.ExpiresAt, + }, nil +} + +func decodeJWTPart(part string, target any) error { + payload, err := base64.RawURLEncoding.DecodeString(part) + if err != nil { + return err + } + return json.Unmarshal(payload, target) +} diff --git a/backend/gateway/internal/auth/auth_test.go b/backend/gateway/internal/auth/auth_test.go new file mode 100644 index 0000000..151427a --- /dev/null +++ b/backend/gateway/internal/auth/auth_test.go @@ -0,0 +1,86 @@ +package auth + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestHS256VerifierAcceptsValidSupabaseToken(t *testing.T) { + verifier := hs256Verifier{ + secret: []byte("test-secret"), + now: func() time.Time { return time.Unix(100, 0).UTC() }, + } + token := signTestJWT(t, "test-secret", `{"alg":"HS256","typ":"JWT"}`, `{"sub":"player-1","role":"authenticated","email":"p@example.test","exp":200}`) + + identity, err := verifier.Verify(context.Background(), token) + if err != nil { + t.Fatalf("expected token to verify: %v", err) + } + if identity.PlayerID != "player-1" { + t.Fatalf("expected player id from sub, got %q", identity.PlayerID) + } + if identity.Role != "authenticated" { + t.Fatalf("expected role from token, got %q", identity.Role) + } +} + +func TestHS256VerifierRejectsInvalidSignature(t *testing.T) { + verifier := hs256Verifier{ + secret: []byte("test-secret"), + now: func() time.Time { return time.Unix(100, 0).UTC() }, + } + token := signTestJWT(t, "other-secret", `{"alg":"HS256","typ":"JWT"}`, `{"sub":"player-1","exp":200}`) + + if _, err := verifier.Verify(context.Background(), token); err != ErrInvalidJWT { + t.Fatalf("expected ErrInvalidJWT, got %v", err) + } +} + +func TestHS256VerifierRejectsExpiredToken(t *testing.T) { + verifier := hs256Verifier{ + secret: []byte("test-secret"), + now: func() time.Time { return time.Unix(300, 0).UTC() }, + } + token := signTestJWT(t, "test-secret", `{"alg":"HS256","typ":"JWT"}`, `{"sub":"player-1","exp":200}`) + + if _, err := verifier.Verify(context.Background(), token); err != ErrInvalidJWT { + t.Fatalf("expected ErrInvalidJWT, got %v", err) + } +} + +func TestFromRequestRejectsOversizedBearerToken(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/agent/decide", nil) + req.Header.Set("Authorization", "Bearer "+strings.Repeat("a", maxBearerTokenBytes+1)) + + if _, err := FromRequest(req); err != ErrInvalidJWT { + t.Fatalf("expected ErrInvalidJWT, got %v", err) + } +} + +func TestHS256VerifierRejectsOversizedJWTPart(t *testing.T) { + verifier := hs256Verifier{ + secret: []byte("test-secret"), + now: func() time.Time { return time.Unix(100, 0).UTC() }, + } + token := strings.Repeat("a", maxJWTPartBytes+1) + ".payload.signature" + + if _, err := verifier.Verify(context.Background(), token); err != ErrInvalidJWT { + t.Fatalf("expected ErrInvalidJWT, got %v", err) + } +} + +func signTestJWT(t *testing.T, secret string, header string, claims string) string { + t.Helper() + encodedHeader := base64.RawURLEncoding.EncodeToString([]byte(header)) + encodedClaims := base64.RawURLEncoding.EncodeToString([]byte(claims)) + unsigned := encodedHeader + "." + encodedClaims + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write([]byte(unsigned)) + return unsigned + "." + base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} diff --git a/backend/gateway/internal/server/agent_decision_limiter.go b/backend/gateway/internal/server/agent_decision_limiter.go new file mode 100644 index 0000000..13345e6 --- /dev/null +++ b/backend/gateway/internal/server/agent_decision_limiter.go @@ -0,0 +1,168 @@ +package server + +import ( + "strings" + "sync" + "time" + + "github.com/DOS/Second-Spawn/backend/gateway/internal/config" +) + +type agentDecisionLimitResult struct { + Error string `json:"error"` + Reason string `json:"reason"` + PlayerID string `json:"player_id"` + RetryAfterSeconds int64 `json:"retry_after_seconds,omitempty"` + TokenEstimate int `json:"token_estimate,omitempty"` + TokenBudgetPerDay int `json:"token_budget_per_day,omitempty"` + TokenBudgetUsedToday int `json:"token_budget_used_today,omitempty"` + TokenBudgetRemaining int `json:"token_budget_remaining,omitempty"` +} + +type agentDecisionLimiter struct { + mu sync.Mutex + cfg *config.Config + now func() time.Time + lastPruned time.Time + // TODO(#13): Move limiter state to Redis or another shared store before + // running more than one gateway instance. + players map[string]*agentDecisionLimitState +} + +type agentDecisionLimitState struct { + minuteStart time.Time + minuteCount int + day string + tokensUsed int + lastSeen time.Time +} + +const agentDecisionLimitStateTTL = 25 * time.Hour +const agentDecisionLimitPruneScanLimit = 64 + +func newAgentDecisionLimiter(cfg *config.Config, now func() time.Time) *agentDecisionLimiter { + if cfg == nil { + cfg = &config.Config{} + } + if now == nil { + now = func() time.Time { return time.Now().UTC() } + } + return &agentDecisionLimiter{ + cfg: cfg, + now: now, + players: map[string]*agentDecisionLimitState{}, + } +} + +func (l *agentDecisionLimiter) Allow(playerID string, tokenEstimate int) (bool, agentDecisionLimitResult) { + if l == nil || l.cfg == nil || !l.enabled() { + return true, agentDecisionLimitResult{} + } + + playerID = normalizeLimitPlayerID(playerID) + tokenEstimate = max(tokenEstimate, 1) + now := l.now().UTC() + minuteStart := now.Truncate(time.Minute) + day := now.Format("2006-01-02") + + l.mu.Lock() + defer l.mu.Unlock() + + l.pruneExpiredIfDue(now) + state := l.playerState(playerID, minuteStart, day) + state.resetWindows(minuteStart, day) + state.lastSeen = now + if result, blocked := state.rateLimitResult(playerID, l.cfg.LLMRateLimitPerPlayerPerMin, now); blocked { + return false, result + } + if result, blocked := state.tokenBudgetResult(playerID, tokenEstimate, l.cfg.LLMTokenBudgetPerPlayerDay); blocked { + return false, result + } + + state.minuteCount++ + state.tokensUsed += tokenEstimate + return true, agentDecisionLimitResult{} +} + +func (l *agentDecisionLimiter) enabled() bool { + return l.cfg.LLMRateLimitPerPlayerPerMin > 0 || l.cfg.LLMTokenBudgetPerPlayerDay > 0 +} + +func (l *agentDecisionLimiter) playerState(playerID string, minuteStart time.Time, day string) *agentDecisionLimitState { + state := l.players[playerID] + if state != nil { + return state + } + state = &agentDecisionLimitState{minuteStart: minuteStart, day: day} + l.players[playerID] = state + return state +} + +func (l *agentDecisionLimiter) pruneExpiredIfDue(now time.Time) { + if !l.lastPruned.IsZero() && now.Sub(l.lastPruned) < time.Minute { + return + } + l.lastPruned = now + + cutoff := now.Add(-agentDecisionLimitStateTTL) + scanned := 0 + for playerID, state := range l.players { + if !state.lastSeen.IsZero() && state.lastSeen.Before(cutoff) { + delete(l.players, playerID) + } + scanned++ + if scanned >= agentDecisionLimitPruneScanLimit { + return + } + } +} + +func (s *agentDecisionLimitState) resetWindows(minuteStart time.Time, day string) { + if !s.minuteStart.Equal(minuteStart) { + s.minuteStart = minuteStart + s.minuteCount = 0 + } + if s.day != day { + s.day = day + s.tokensUsed = 0 + } +} + +func (s *agentDecisionLimitState) rateLimitResult(playerID string, rateLimit int, now time.Time) (agentDecisionLimitResult, bool) { + if rateLimit <= 0 || s.minuteCount < rateLimit { + return agentDecisionLimitResult{}, false + } + retryAfter := s.minuteStart.Add(time.Minute).Sub(now) + if retryAfter < time.Second { + retryAfter = time.Second + } + return agentDecisionLimitResult{ + Error: "agent decision rate limit exceeded", + Reason: "rate_limit_exceeded", + PlayerID: playerID, + RetryAfterSeconds: int64(retryAfter.Seconds()), + }, true +} + +func (s *agentDecisionLimitState) tokenBudgetResult(playerID string, tokenEstimate int, tokenBudget int) (agentDecisionLimitResult, bool) { + if tokenBudget <= 0 || s.tokensUsed+tokenEstimate <= tokenBudget { + return agentDecisionLimitResult{}, false + } + return agentDecisionLimitResult{ + Error: "agent decision token budget exhausted", + Reason: "token_budget_exhausted", + PlayerID: playerID, + TokenEstimate: tokenEstimate, + TokenBudgetPerDay: tokenBudget, + TokenBudgetUsedToday: s.tokensUsed, + TokenBudgetRemaining: max(tokenBudget-s.tokensUsed, 0), + }, true +} + +func normalizeLimitPlayerID(playerID string) string { + playerID = strings.TrimSpace(playerID) + if playerID == "" { + return "unknown" + } + return playerID +} diff --git a/backend/gateway/internal/server/server.go b/backend/gateway/internal/server/server.go index 1730257..9b27ccb 100644 --- a/backend/gateway/internal/server/server.go +++ b/backend/gateway/internal/server/server.go @@ -4,10 +4,12 @@ import ( "encoding/json" "errors" "net/http" + "strconv" "strings" "time" "github.com/DOS/Second-Spawn/backend/gateway/internal/agent" + "github.com/DOS/Second-Spawn/backend/gateway/internal/auth" "github.com/DOS/Second-Spawn/backend/gateway/internal/character" "github.com/DOS/Second-Spawn/backend/gateway/internal/config" "github.com/DOS/Second-Spawn/backend/gateway/internal/llm" @@ -22,6 +24,9 @@ type Server struct { cfg *config.Config store character.Store decider agent.Decider + limiter *agentDecisionLimiter + auth auth.Verifier + now func() time.Time } func New(cfg *config.Config) *Server { @@ -46,7 +51,15 @@ func NewWithDependencies(cfg *config.Config, store character.Store, decider agen if decider == nil { decider = agent.PrototypeDecider{} } - return &Server{cfg: cfg, store: store, decider: decider} + srv := &Server{ + cfg: cfg, + store: store, + decider: decider, + auth: auth.NewHS256Verifier(cfg.SupabaseJWTSecret), + now: func() time.Time { return time.Now().UTC() }, + } + srv.limiter = newAgentDecisionLimiter(cfg, srv.now) + return srv } // Routes registers all HTTP handlers. Keep this file small - real handler @@ -135,8 +148,18 @@ func (s *Server) handleAgentDecide(w http.ResponseWriter, r *http.Request) { return } + trustedPlayerID, err := s.resolveTrustedPlayerID(r) + if err != nil { + writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()}) + return + } + if strings.TrimSpace(req.Context.Player.PlayerID) == "" { - ctx, err := s.store.GetOrCreateContext(r.Context(), "dev-player") + playerID := "dev-player" + if trustedPlayerID != "" { + playerID = trustedPlayerID + } + ctx, err := s.store.GetOrCreateContext(r.Context(), playerID) if err != nil { writeError(w, err) return @@ -144,7 +167,17 @@ func (s *Server) handleAgentDecide(w http.ResponseWriter, r *http.Request) { req.Context = ctx } req.Allowed = ensureStopAllowed(req.Allowed) - // TODO(#6): enforce per-player decision rate limits and daily token budgets here. + limitPlayerID := req.Context.Player.PlayerID + if trustedPlayerID != "" { + limitPlayerID = trustedPlayerID + } + if allowed, result := s.limiter.Allow(limitPlayerID, estimateAgentDecisionTokens(req)); !allowed { + if result.RetryAfterSeconds > 0 { + w.Header().Set("Retry-After", strconv.FormatInt(result.RetryAfterSeconds, 10)) + } + writeJSON(w, http.StatusTooManyRequests, result) + return + } decision, err := s.decider.Decide(r.Context(), req) if err != nil { writeJSON(w, http.StatusBadGateway, map[string]any{"error": err.Error()}) @@ -158,6 +191,38 @@ func (s *Server) handleAgentDecide(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, decision) } +const agentDecisionOutputTokenReserve = 400 + +func (s *Server) resolveTrustedPlayerID(r *http.Request) (string, error) { + if s.auth == nil { + return "", nil + } + token, err := auth.FromRequest(r) + if err != nil { + return "", err + } + identity, err := s.auth.Verify(r.Context(), token) + if err != nil { + return "", err + } + playerID := strings.TrimSpace(string(identity.PlayerID)) + if playerID == "" { + return "", auth.ErrInvalidJWT + } + return playerID, nil +} + +func estimateAgentDecisionTokens(req agent.DecisionRequest) int { + payload, err := json.Marshal(req) + if err != nil { + return agentDecisionOutputTokenReserve + } + + // Rough JSON estimate: four bytes per token plus the completion reserve + // used by the model-backed decision path. + return max(len(payload)/4+agentDecisionOutputTokenReserve, agentDecisionOutputTokenReserve) +} + type npcChatRequest struct { PlayerID string `json:"player_id"` NPCID string `json:"npc_id"` diff --git a/backend/gateway/internal/server/server_test.go b/backend/gateway/internal/server/server_test.go index 089e8fb..9f98fce 100644 --- a/backend/gateway/internal/server/server_test.go +++ b/backend/gateway/internal/server/server_test.go @@ -6,9 +6,12 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" + "time" "github.com/DOS/Second-Spawn/backend/gateway/internal/agent" + "github.com/DOS/Second-Spawn/backend/gateway/internal/auth" "github.com/DOS/Second-Spawn/backend/gateway/internal/config" ) @@ -208,6 +211,187 @@ func TestAgentDecideRejectsConfiguredDeciderActionOutsideAllowed(t *testing.T) { } } +func TestAgentDecideRateLimitPerPlayer(t *testing.T) { + decider := &staticAgentDecider{ + decision: agent.Decision{ + Action: agent.ActionSay, + Say: "First request is allowed.", + Reason: "say is allowed for this request", + Confidence: 0.8, + Source: agent.DecisionSourceModel, + SourceReason: "validated_model_intent", + }, + } + srv := NewWithDependencies(&config.Config{ + Env: "test", + LLMRateLimitPerPlayerPerMin: 1, + LLMTokenBudgetPerPlayerDay: 0, + }, nil, decider) + + first := httptest.NewRecorder() + srv.Routes().ServeHTTP(first, newAgentDecideRequest("rate-user")) + if first.Code != http.StatusOK { + t.Fatalf("expected first decision 200, got %d: %s", first.Code, first.Body.String()) + } + + second := httptest.NewRecorder() + srv.Routes().ServeHTTP(second, newAgentDecideRequest("rate-user")) + if second.Code != http.StatusTooManyRequests { + t.Fatalf("expected second decision 429, got %d: %s", second.Code, second.Body.String()) + } + if !bytes.Contains(second.Body.Bytes(), []byte(`"reason":"rate_limit_exceeded"`)) { + t.Fatalf("expected rate limit reason, got %s", second.Body.String()) + } + if second.Header().Get("Retry-After") == "" { + t.Fatal("expected Retry-After header on rate-limited response") + } + if decider.calls != 1 { + t.Fatalf("expected decider to be called once, got %d", decider.calls) + } +} + +func TestAgentDecideRateLimitUsesTrustedAuthSubject(t *testing.T) { + decider := &staticAgentDecider{ + decision: agent.Decision{ + Action: agent.ActionSay, + Say: "Authenticated subject owns the limiter key.", + Reason: "say is allowed for this request", + Confidence: 0.8, + Source: agent.DecisionSourceModel, + SourceReason: "validated_model_intent", + }, + } + srv := NewWithDependencies(&config.Config{ + Env: "test", + LLMRateLimitPerPlayerPerMin: 1, + LLMTokenBudgetPerPlayerDay: 0, + }, nil, decider) + srv.auth = staticAuthVerifier{playerID: "auth-user"} + + first := httptest.NewRecorder() + srv.Routes().ServeHTTP(first, withBearer(newAgentDecideRequest("body-profile-1"))) + if first.Code != http.StatusOK { + t.Fatalf("expected first decision 200, got %d: %s", first.Code, first.Body.String()) + } + + second := httptest.NewRecorder() + srv.Routes().ServeHTTP(second, withBearer(newAgentDecideRequest("body-profile-2"))) + if second.Code != http.StatusTooManyRequests { + t.Fatalf("expected second decision to use auth subject limit key, got %d: %s", second.Code, second.Body.String()) + } + if !bytes.Contains(second.Body.Bytes(), []byte(`"player_id":"auth-user"`)) { + t.Fatalf("expected limit response to name trusted auth subject, got %s", second.Body.String()) + } + if decider.calls != 1 { + t.Fatalf("expected decider to be called once, got %d", decider.calls) + } +} + +func TestAgentDecideRequiresAuthWhenVerifierConfigured(t *testing.T) { + srv := NewWithDependencies(&config.Config{ + Env: "test", + LLMRateLimitPerPlayerPerMin: 1, + }, nil, &staticAgentDecider{}) + srv.auth = staticAuthVerifier{playerID: "auth-user"} + + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, newAgentDecideRequest("body-profile-1")) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected missing auth to return 401, got %d: %s", rec.Code, rec.Body.String()) + } +} + +func TestAgentDecideRejectsOversizedBearerBeforeDecider(t *testing.T) { + decider := &staticAgentDecider{} + srv := NewWithDependencies(&config.Config{Env: "test", SupabaseJWTSecret: "test-secret"}, nil, decider) + req := newAgentDecideRequest("body-profile-1") + req.Header.Set("Authorization", "Bearer "+strings.Repeat("a", 7*1024)) + + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected oversized auth to return 401, got %d: %s", rec.Code, rec.Body.String()) + } + if decider.calls != 0 { + t.Fatalf("expected decider not to be called, got %d", decider.calls) + } +} + +func TestAgentDecideTokenBudgetPerPlayer(t *testing.T) { + decider := &staticAgentDecider{ + decision: agent.Decision{ + Action: agent.ActionSay, + Say: "This should not be reached.", + Confidence: 0.8, + }, + } + srv := NewWithDependencies(&config.Config{ + Env: "test", + LLMRateLimitPerPlayerPerMin: 0, + LLMTokenBudgetPerPlayerDay: agentDecisionOutputTokenReserve - 1, + }, nil, decider) + + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, newAgentDecideRequest("budget-user")) + if rec.Code != http.StatusTooManyRequests { + t.Fatalf("expected decision 429, got %d: %s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte(`"reason":"token_budget_exhausted"`)) { + t.Fatalf("expected token budget reason, got %s", rec.Body.String()) + } + if decider.calls != 0 { + t.Fatalf("expected decider not to be called, got %d", decider.calls) + } +} + +func TestAgentDecisionLimiterResetsWindows(t *testing.T) { + now := time.Date(2026, 5, 17, 12, 0, 30, 0, time.UTC) + limiter := newAgentDecisionLimiter(&config.Config{ + LLMRateLimitPerPlayerPerMin: 1, + LLMTokenBudgetPerPlayerDay: 500, + }, func() time.Time { return now }) + + if allowed, result := limiter.Allow("reset-user", 400); !allowed { + t.Fatalf("expected first request allowed, got %+v", result) + } + if allowed, result := limiter.Allow("reset-user", 1); allowed || result.Reason != "rate_limit_exceeded" { + t.Fatalf("expected same-minute rate limit, allowed=%t result=%+v", allowed, result) + } + + now = now.Add(time.Minute) + if allowed, result := limiter.Allow("reset-user", 100); !allowed { + t.Fatalf("expected next-minute request allowed, got %+v", result) + } + + now = now.Add(time.Minute) + if allowed, result := limiter.Allow("reset-user", 1); allowed || result.Reason != "token_budget_exhausted" { + t.Fatalf("expected same-day budget exhaustion, allowed=%t result=%+v", allowed, result) + } + + now = now.Add(24 * time.Hour) + if allowed, result := limiter.Allow("reset-user", 400); !allowed { + t.Fatalf("expected next-day budget reset, got %+v", result) + } +} + +func TestAgentDecisionLimiterPrunesExpiredPlayerState(t *testing.T) { + now := time.Date(2026, 5, 17, 12, 0, 0, 0, time.UTC) + limiter := newAgentDecisionLimiter(&config.Config{ + LLMRateLimitPerPlayerPerMin: 10, + }, func() time.Time { return now }) + + if allowed, result := limiter.Allow("old-user", 1); !allowed { + t.Fatalf("expected old user request allowed, got %+v", result) + } + now = now.Add(agentDecisionLimitStateTTL + time.Minute) + if allowed, result := limiter.Allow("new-user", 1); !allowed { + t.Fatalf("expected new user request allowed, got %+v", result) + } + if _, ok := limiter.players["old-user"]; ok { + t.Fatal("expected stale player limiter state to be pruned") + } +} + func TestNPCChatPrototype(t *testing.T) { srv := New(&config.Config{Env: "test"}) @@ -229,9 +413,11 @@ func TestNPCChatPrototype(t *testing.T) { type staticAgentDecider struct { decision agent.Decision stopWasAllowed bool + calls int } func (d *staticAgentDecider) Decide(_ context.Context, req agent.DecisionRequest) (agent.Decision, error) { + d.calls++ for _, action := range req.Allowed { if action == agent.ActionStop { d.stopWasAllowed = true @@ -240,3 +426,34 @@ func (d *staticAgentDecider) Decide(_ context.Context, req agent.DecisionRequest } return d.decision, nil } + +type staticAuthVerifier struct { + playerID string +} + +func (v staticAuthVerifier) Verify(_ context.Context, _ string) (auth.Identity, error) { + return auth.Identity{PlayerID: auth.PlayerID(v.playerID)}, nil +} + +func newAgentDecideRequest(playerID string) *http.Request { + return httptest.NewRequest(http.MethodPost, "/v1/agent/decide", bytes.NewReader([]byte(`{ + "context": { + "player": { + "player_id": "`+playerID+`", + "display_name": "Budget Test" + } + }, + "world_snapshot": { + "zone_id": "hub", + "position": {"x": 0, "z": 0}, + "safe_radius": 5, + "body_time_seconds": 3600 + }, + "allowed": ["say"] + }`))) +} + +func withBearer(req *http.Request) *http.Request { + req.Header.Set("Authorization", "Bearer test-token") + return req +}