From 5cb9976fc4cf64d591472dee31d4c9672c3615d9 Mon Sep 17 00:00:00 2001 From: JonathanLab Date: Mon, 30 Mar 2026 14:32:26 +0200 Subject: [PATCH] feat: pass auth tokens by reference, not value --- apps/code/src/main/di/container.ts | 2 + apps/code/src/main/di/tokens.ts | 1 + apps/code/src/main/menu.ts | 4 +- .../main/services/agent/auth-adapter.test.ts | 182 ++++++++++++++ .../src/main/services/agent/auth-adapter.ts | 211 ++++++++++++++++ apps/code/src/main/services/agent/schemas.ts | 9 - .../src/main/services/agent/service.test.ts | 92 ++----- apps/code/src/main/services/agent/service.ts | 233 ++---------------- .../src/main/services/auth-proxy/service.ts | 45 ++-- apps/code/src/main/services/auth/service.ts | 61 +++++ .../src/main/services/cloud-task/schemas.ts | 4 - .../src/main/services/cloud-task/service.ts | 111 +++------ apps/code/src/main/services/git/schemas.ts | 8 - apps/code/src/main/services/git/service.ts | 5 - .../src/main/services/llm-gateway/schemas.ts | 8 - .../src/main/services/llm-gateway/service.ts | 31 ++- apps/code/src/main/services/ui/service.ts | 14 +- apps/code/src/main/trpc/routers/agent.ts | 9 +- apps/code/src/main/trpc/routers/cloud-task.ts | 5 - apps/code/src/main/trpc/routers/git.ts | 10 +- .../code/src/main/trpc/routers/llm-gateway.ts | 2 +- apps/code/src/main/trpc/routers/oauth.ts | 31 +-- .../components/GlobalEventHandlers.tsx | 18 +- .../components/ScopeReauthPrompt.test.tsx | 6 - .../features/auth/stores/authStore.test.ts | 35 ++- .../features/auth/stores/authStore.ts | 48 ++-- .../hooks/useGitInteraction.ts | 62 ----- .../features/sessions/service/service.test.ts | 4 - .../features/sessions/service/service.ts | 18 +- .../sidebar/components/ProjectSwitcher.tsx | 6 +- apps/code/src/renderer/utils/generateTitle.ts | 8 +- packages/agent/src/agent.ts | 8 +- packages/agent/src/posthog-api.test.ts | 48 ++++ packages/agent/src/posthog-api.ts | 74 ++++-- packages/agent/src/types.ts | 3 +- 35 files changed, 778 insertions(+), 638 deletions(-) create mode 100644 apps/code/src/main/services/agent/auth-adapter.test.ts create mode 100644 apps/code/src/main/services/agent/auth-adapter.ts create mode 100644 packages/agent/src/posthog-api.test.ts diff --git a/apps/code/src/main/di/container.ts b/apps/code/src/main/di/container.ts index 87eb47635..0b08821a7 100644 --- a/apps/code/src/main/di/container.ts +++ b/apps/code/src/main/di/container.ts @@ -8,6 +8,7 @@ import { SuspensionRepositoryImpl } from "../db/repositories/suspension-reposito import { WorkspaceRepository } from "../db/repositories/workspace-repository"; import { WorktreeRepository } from "../db/repositories/worktree-repository"; import { DatabaseService } from "../db/service"; +import { AgentAuthAdapter } from "../services/agent/auth-adapter"; import { AgentService } from "../services/agent/service"; import { AppLifecycleService } from "../services/app-lifecycle/service"; import { ArchiveService } from "../services/archive/service"; @@ -57,6 +58,7 @@ container.bind(MAIN_TOKENS.WorkspaceRepository).to(WorkspaceRepository); container.bind(MAIN_TOKENS.WorktreeRepository).to(WorktreeRepository); container.bind(MAIN_TOKENS.ArchiveRepository).to(ArchiveRepository); container.bind(MAIN_TOKENS.SuspensionRepository).to(SuspensionRepositoryImpl); +container.bind(MAIN_TOKENS.AgentAuthAdapter).to(AgentAuthAdapter); container.bind(MAIN_TOKENS.AgentService).to(AgentService); container.bind(MAIN_TOKENS.AuthService).to(AuthService); container.bind(MAIN_TOKENS.AuthProxyService).to(AuthProxyService); diff --git a/apps/code/src/main/di/tokens.ts b/apps/code/src/main/di/tokens.ts index ad8ee191b..8080584b4 100644 --- a/apps/code/src/main/di/tokens.ts +++ b/apps/code/src/main/di/tokens.ts @@ -18,6 +18,7 @@ export const MAIN_TOKENS = Object.freeze({ SuspensionRepository: Symbol.for("Main.SuspensionRepository"), // Services + AgentAuthAdapter: Symbol.for("Main.AgentAuthAdapter"), AgentService: Symbol.for("Main.AgentService"), AuthService: Symbol.for("Main.AuthService"), AuthProxyService: Symbol.for("Main.AuthProxyService"), diff --git a/apps/code/src/main/menu.ts b/apps/code/src/main/menu.ts index 5fabab42e..48e475207 100644 --- a/apps/code/src/main/menu.ts +++ b/apps/code/src/main/menu.ts @@ -127,7 +127,9 @@ function buildFileMenu(): MenuItemConstructorOptions { { label: "Invalidate OAuth token", click: () => { - container.get(MAIN_TOKENS.UIService).invalidateToken(); + void container + .get(MAIN_TOKENS.UIService) + .invalidateToken(); }, }, { diff --git a/apps/code/src/main/services/agent/auth-adapter.test.ts b/apps/code/src/main/services/agent/auth-adapter.test.ts new file mode 100644 index 000000000..acb58fd45 --- /dev/null +++ b/apps/code/src/main/services/agent/auth-adapter.test.ts @@ -0,0 +1,182 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +const mockFetch = vi.hoisted(() => vi.fn()); + +vi.mock("../../utils/logger.js", () => ({ + logger: { + scope: () => ({ + info: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + debug: vi.fn(), + }), + }, +})); + +vi.mock("@posthog/agent/posthog-api", () => ({ + getLlmGatewayUrl: vi.fn(() => "https://gateway.example.com"), +})); + +vi.stubGlobal("fetch", mockFetch); + +import { AgentAuthAdapter } from "./auth-adapter"; + +const baseCredentials = { + apiHost: "https://app.posthog.com", + projectId: 1, +}; + +function createDependencies() { + return { + authService: { + getValidAccessToken: vi.fn().mockResolvedValue({ + accessToken: "test-access-token", + apiHost: "https://app.posthog.com", + }), + refreshAccessToken: vi.fn().mockResolvedValue({ + accessToken: "fresh-access-token", + apiHost: "https://app.posthog.com", + }), + authenticatedFetch: vi + .fn() + .mockImplementation( + async ( + fetchImpl: typeof fetch, + input: string | Request, + init?: RequestInit, + ) => fetchImpl(input, init), + ), + }, + authProxy: { + start: vi.fn().mockResolvedValue("http://127.0.0.1:9999"), + }, + }; +} + +describe("AgentAuthAdapter", () => { + let adapter: AgentAuthAdapter; + let deps: ReturnType; + + beforeEach(() => { + vi.clearAllMocks(); + mockFetch.mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ results: [] }), + }); + + deps = createDependencies(); + adapter = new AgentAuthAdapter( + deps.authService as never, + deps.authProxy as never, + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it("builds the default PostHog MCP server", async () => { + const servers = await adapter.buildMcpServers(baseCredentials); + + expect(servers).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + name: "posthog", + type: "http", + url: "https://mcp.posthog.com/mcp", + headers: expect.arrayContaining([ + { + name: "Authorization", + value: "Bearer test-access-token", + }, + ]), + }), + ]), + ); + }); + + it("includes enabled user-installed MCP servers from backend", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + results: [ + { + id: "inst-1", + url: "https://custom-mcp.example.com", + proxy_url: "https://proxy.posthog.com/inst-1/", + name: "custom-server", + display_name: "Custom Server", + auth_type: "none", + is_enabled: true, + pending_oauth: false, + needs_reauth: false, + }, + ], + }), + }); + + const servers = await adapter.buildMcpServers(baseCredentials); + + expect(servers).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + name: "custom-server", + url: "https://custom-mcp.example.com", + headers: [], + }), + ]), + ); + }); + + it("routes authenticated installed MCP servers through the proxy URL", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + results: [ + { + id: "inst-2", + url: "https://remote-mcp.example.com", + proxy_url: "https://proxy.posthog.com/inst-2/", + name: "secure-server", + display_name: "Secure Server", + auth_type: "oauth", + is_enabled: true, + pending_oauth: false, + needs_reauth: false, + }, + ], + }), + }); + + const servers = await adapter.buildMcpServers(baseCredentials); + + expect(servers).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + name: "secure-server", + url: "https://proxy.posthog.com/inst-2/", + headers: [ + { name: "Authorization", value: "Bearer test-access-token" }, + ], + }), + ]), + ); + }); + + it("configures environment using the gateway proxy and current token", async () => { + await adapter.configureProcessEnv({ + credentials: baseCredentials, + mockNodeDir: "/mock/node", + proxyUrl: "http://127.0.0.1:9999", + claudeCliPath: "/mock/claude-cli.js", + }); + + expect(process.env.POSTHOG_API_KEY).toBe("test-access-token"); + expect(process.env.POSTHOG_AUTH_HEADER).toBe("Bearer test-access-token"); + expect(process.env.LLM_GATEWAY_URL).toBe("http://127.0.0.1:9999"); + expect(process.env.CLAUDE_CODE_EXECUTABLE).toBe("/mock/claude-cli.js"); + expect(process.env.POSTHOG_PROJECT_ID).toBe("1"); + }); +}); diff --git a/apps/code/src/main/services/agent/auth-adapter.ts b/apps/code/src/main/services/agent/auth-adapter.ts new file mode 100644 index 000000000..ad635474a --- /dev/null +++ b/apps/code/src/main/services/agent/auth-adapter.ts @@ -0,0 +1,211 @@ +import { delimiter } from "node:path"; +import { getLlmGatewayUrl } from "@posthog/agent/posthog-api"; +import { inject, injectable } from "inversify"; +import { MAIN_TOKENS } from "../../di/tokens"; +import { logger } from "../../utils/logger"; +import type { AuthService } from "../auth/service"; +import type { AuthProxyService } from "../auth-proxy/service"; +import type { Credentials } from "./schemas"; + +const log = logger.scope("agent-auth-adapter"); + +export interface AcpMcpServer { + name: string; + type: "http"; + url: string; + headers: Array<{ name: string; value: string }>; +} + +export interface AgentPosthogConfig { + apiUrl: string; + getApiKey: () => Promise; + refreshApiKey: () => Promise; + projectId: number; +} + +interface ConfigureProcessEnvInput { + credentials: Credentials; + mockNodeDir: string; + proxyUrl: string; + claudeCliPath: string; +} + +@injectable() +export class AgentAuthAdapter { + constructor( + @inject(MAIN_TOKENS.AuthService) + private readonly authService: AuthService, + @inject(MAIN_TOKENS.AuthProxyService) + private readonly authProxy: AuthProxyService, + ) {} + + createPosthogConfig(credentials: Credentials): AgentPosthogConfig { + return { + apiUrl: credentials.apiHost, + getApiKey: () => this.getValidToken(), + refreshApiKey: () => this.refreshToken(), + projectId: credentials.projectId, + }; + } + + async buildMcpServers(credentials: Credentials): Promise { + const servers: AcpMcpServer[] = []; + const mcpUrl = this.getPostHogMcpUrl(credentials.apiHost); + const token = await this.getValidToken(); + + servers.push({ + name: "posthog", + type: "http", + url: mcpUrl, + headers: [ + { name: "Authorization", value: `Bearer ${token}` }, + { + name: "x-posthog-project-id", + value: String(credentials.projectId), + }, + { name: "x-posthog-mcp-version", value: "2" }, + ], + }); + + const installations = await this.fetchMcpInstallations(credentials); + + for (const installation of installations) { + if (installation.url === mcpUrl) continue; + + if (installation.auth_type === "none") { + servers.push({ + name: + installation.name || installation.display_name || installation.url, + type: "http", + url: installation.url, + headers: [], + }); + continue; + } + + servers.push({ + name: + installation.name || installation.display_name || installation.url, + type: "http", + url: installation.proxy_url, + headers: [{ name: "Authorization", value: `Bearer ${token}` }], + }); + } + + return servers; + } + + async ensureGatewayProxy(apiHost: string): Promise { + return this.authProxy.start(getLlmGatewayUrl(apiHost)); + } + + async configureProcessEnv({ + credentials, + mockNodeDir, + proxyUrl, + claudeCliPath, + }: ConfigureProcessEnvInput): Promise { + await this.getValidToken(); + + const currentPath = process.env.PATH || ""; + if (!currentPath.split(delimiter).includes(mockNodeDir)) { + process.env.PATH = `${mockNodeDir}${delimiter}${currentPath}`; + } + + process.env.LLM_GATEWAY_URL = proxyUrl; + process.env.CLAUDE_CODE_EXECUTABLE = claudeCliPath; + process.env.POSTHOG_API_URL = credentials.apiHost; + process.env.POSTHOG_PROJECT_ID = String(credentials.projectId); + } + + private syncTokenEnvironment(token: string): void { + process.env.POSTHOG_API_KEY = token; + process.env.POSTHOG_AUTH_HEADER = `Bearer ${token}`; + } + + private async getValidToken(): Promise { + const { accessToken } = await this.authService.getValidAccessToken(); + this.syncTokenEnvironment(accessToken); + return accessToken; + } + + private async refreshToken(): Promise { + const { accessToken } = await this.authService.refreshAccessToken(); + this.syncTokenEnvironment(accessToken); + return accessToken; + } + + private getPostHogMcpUrl(apiHost: string): string { + const overrideUrl = process.env.POSTHOG_MCP_URL; + if (overrideUrl) { + return overrideUrl; + } + if (apiHost.includes("localhost") || apiHost.includes("127.0.0.1")) { + return "http://localhost:8787/mcp"; + } + return "https://mcp.posthog.com/mcp"; + } + + private getPostHogApiBaseUrl(apiHost: string): string { + const host = process.env.POSTHOG_PROXY_BASE_URL || apiHost; + return host.endsWith("/") ? host.slice(0, -1) : host; + } + + private async fetchMcpInstallations(credentials: Credentials): Promise< + Array<{ + id: string; + url: string; + proxy_url: string; + name: string; + display_name: string; + auth_type: string; + }> + > { + const baseUrl = this.getPostHogApiBaseUrl(credentials.apiHost); + const url = `${baseUrl}/api/environments/${credentials.projectId}/mcp_server_installations/`; + + try { + const response = await this.authService.authenticatedFetch(fetch, url, { + headers: { + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + log.warn("Failed to fetch MCP installations", { + status: response.status, + }); + return []; + } + + const data = (await response.json()) as { + results?: Array<{ + id: string; + url: string; + proxy_url?: string; + name: string; + display_name: string; + auth_type: string; + is_enabled?: boolean; + pending_oauth: boolean; + needs_reauth: boolean; + }>; + }; + const installations = data.results ?? []; + + return installations + .filter( + (i) => !i.pending_oauth && !i.needs_reauth && i.is_enabled !== false, + ) + .map((i) => ({ + ...i, + proxy_url: + i.proxy_url ?? + `${baseUrl}/api/environments/${credentials.projectId}/mcp_server_installations/${i.id}/proxy/`, + })); + } catch (err) { + log.warn("Error fetching MCP installations", { error: err }); + return []; + } + } +} diff --git a/apps/code/src/main/services/agent/schemas.ts b/apps/code/src/main/services/agent/schemas.ts index 071436736..dafcaf169 100644 --- a/apps/code/src/main/services/agent/schemas.ts +++ b/apps/code/src/main/services/agent/schemas.ts @@ -10,7 +10,6 @@ export type { EffortLevel } from "@shared/types"; // Session credentials schema export const credentialsSchema = z.object({ - apiKey: z.string(), apiHost: z.string(), projectId: z.number(), }); @@ -41,7 +40,6 @@ export const startSessionInput = z.object({ taskId: z.string(), taskRunId: z.string(), repoPath: z.string(), - apiKey: z.string(), apiHost: z.string(), projectId: z.number(), permissionMode: z.string().optional(), @@ -174,7 +172,6 @@ export const reconnectSessionInput = z.object({ taskId: z.string(), taskRunId: z.string(), repoPath: z.string(), - apiKey: z.string(), apiHost: z.string(), projectId: z.number(), logUrl: z.string().optional(), @@ -189,11 +186,6 @@ export const reconnectSessionInput = z.object({ export type ReconnectSessionInput = z.infer; -// Token update input - updates the global token for all agent operations -export const tokenUpdateInput = z.object({ - token: z.string(), -}); - // Set config option input (for Codex reasoning level, etc.) export const setConfigOptionInput = z.object({ sessionId: z.string(), @@ -297,7 +289,6 @@ export const listSessionsOutput = z.array(sessionInfoSchema); export const getGatewayModelsInput = z.object({ apiHost: z.string(), - apiKey: z.string(), }); export const getGatewayModelsOutput = z.array(modelOptionSchema); diff --git a/apps/code/src/main/services/agent/service.test.ts b/apps/code/src/main/services/agent/service.test.ts index 8dd751391..23db2eba2 100644 --- a/apps/code/src/main/services/agent/service.test.ts +++ b/apps/code/src/main/services/agent/service.test.ts @@ -47,8 +47,6 @@ const mockAgentConstructor = vi.hoisted(() => }), ); -const mockFetch = vi.hoisted(() => vi.fn()); - // --- Module mocks --- const mockPowerMonitor = vi.hoisted(() => ({ @@ -127,8 +125,6 @@ vi.mock("node:fs", async (importOriginal) => { }; }); -vi.stubGlobal("fetch", mockFetch); - // --- Import after mocks --- import { AgentService } from "./service"; @@ -154,12 +150,31 @@ function createMockDependencies() { posthogPluginService: { getPluginPath: vi.fn(() => "/mock/plugin"), }, - authProxy: { - start: vi.fn().mockResolvedValue("http://127.0.0.1:9999"), - stop: vi.fn().mockResolvedValue(undefined), - updateToken: vi.fn(), - getProxyUrl: vi.fn(() => "http://127.0.0.1:9999"), - isRunning: vi.fn(() => false), + agentAuthAdapter: { + ensureGatewayProxy: vi.fn().mockResolvedValue("http://127.0.0.1:9999"), + configureProcessEnv: vi.fn().mockResolvedValue(undefined), + createPosthogConfig: vi.fn((credentials) => ({ + apiUrl: credentials.apiHost, + getApiKey: vi.fn().mockResolvedValue("test-access-token"), + refreshApiKey: vi.fn().mockResolvedValue("fresh-access-token"), + projectId: credentials.projectId, + })), + buildMcpServers: vi.fn().mockResolvedValue([ + { + name: "posthog", + type: "http", + url: "https://mcp.posthog.com/mcp", + headers: [], + }, + ]), + }, + mcpAppsService: { + setServerConfigs: vi.fn(), + handleDiscovery: vi.fn().mockResolvedValue(undefined), + cleanup: vi.fn().mockResolvedValue(undefined), + notifyToolInput: vi.fn(), + notifyToolResult: vi.fn(), + notifyToolCancelled: vi.fn(), }, }; } @@ -168,7 +183,6 @@ const baseSessionParams = { taskId: "task-1", taskRunId: "run-1", repoPath: "/mock/repo", - apiKey: "test-api-key", apiHost: "https://app.posthog.com", projectId: 1, }; @@ -179,27 +193,14 @@ describe("AgentService", () => { beforeEach(() => { vi.clearAllMocks(); - // MCP installations endpoint returns empty - mockFetch.mockResolvedValue({ - ok: true, - json: () => Promise.resolve({ results: [] }), - }); - const deps = createMockDependencies(); service = new AgentService( deps.processTracking as never, deps.sleepService as never, deps.fsService as never, deps.posthogPluginService as never, - deps.authProxy as never, - { - setServerConfigs: vi.fn(), - handleDiscovery: vi.fn().mockResolvedValue(undefined), - cleanup: vi.fn().mockResolvedValue(undefined), - notifyToolInput: vi.fn(), - notifyToolResult: vi.fn(), - notifyToolCancelled: vi.fn(), - } as never, + deps.agentAuthAdapter as never, + deps.mcpAppsService as never, ); }); @@ -263,45 +264,6 @@ describe("AgentService", () => { const codexMcp = mockNewSession.mock.calls[1][0].mcpServers; expect(codexMcp).toEqual(claudeMcp); }); - - it("includes user-installed MCP servers from backend", async () => { - mockFetch.mockResolvedValue({ - ok: true, - json: () => - Promise.resolve({ - results: [ - { - id: "inst-1", - url: "https://custom-mcp.example.com", - proxy_url: "https://proxy.posthog.com/inst-1/", - name: "custom-server", - display_name: "Custom Server", - auth_type: "none", - is_enabled: true, - pending_oauth: false, - needs_reauth: false, - }, - ], - }), - }); - - await service.startSession({ - ...baseSessionParams, - adapter: "codex", - }); - - const mcpServers = mockNewSession.mock.calls[0][0].mcpServers; - expect(mcpServers).toHaveLength(2); - expect(mcpServers).toEqual( - expect.arrayContaining([ - expect.objectContaining({ name: "posthog" }), - expect.objectContaining({ - name: "custom-server", - url: "https://custom-mcp.example.com", - }), - ]), - ); - }); }); describe("idle timeout", () => { diff --git a/apps/code/src/main/services/agent/service.ts b/apps/code/src/main/services/agent/service.ts index 933a815df..9d64c2ac9 100644 --- a/apps/code/src/main/services/agent/service.ts +++ b/apps/code/src/main/services/agent/service.ts @@ -1,6 +1,6 @@ import fs, { mkdirSync, symlinkSync } from "node:fs"; import { tmpdir } from "node:os"; -import { delimiter, isAbsolute, join, relative, resolve, sep } from "node:path"; +import { isAbsolute, join, relative, resolve, sep } from "node:path"; import { type Client, ClientSideConnection, @@ -30,12 +30,12 @@ import { MAIN_TOKENS } from "../../di/tokens"; import { isDevBuild } from "../../utils/env"; import { logger } from "../../utils/logger"; import { TypedEventEmitter } from "../../utils/typed-event-emitter"; -import type { AuthProxyService } from "../auth-proxy/service"; import type { FsService } from "../fs/service"; import type { McpAppsService } from "../mcp-apps/service"; import type { PosthogPluginService } from "../posthog-plugin/service"; import type { ProcessTrackingService } from "../process-tracking/service"; import type { SleepService } from "../sleep/service"; +import type { AgentAuthAdapter } from "./auth-adapter"; import { discoverExternalPlugins } from "./discover-plugins"; import { AgentServiceEvent, @@ -172,13 +172,6 @@ const onAgentLog: OnLogCallback = (level, scope, message, data) => { } }; -interface AcpMcpServer { - name: string; - type: "http"; - url: string; - headers: Array<{ name: string; value: string }>; -} - interface SessionConfig { taskId: string; taskRunId: string; @@ -251,7 +244,6 @@ export class AgentService extends TypedEventEmitter { private static readonly IDLE_TIMEOUT_MS = 15 * 60 * 1000; private sessions = new Map(); - private currentToken: string | null = null; private pendingPermissions = new Map(); private mockNodeReady = false; private idleTimeouts = new Map< @@ -262,7 +254,7 @@ export class AgentService extends TypedEventEmitter { private sleepService: SleepService; private fsService: FsService; private posthogPluginService: PosthogPluginService; - private authProxy: AuthProxyService; + private agentAuthAdapter: AgentAuthAdapter; private mcpAppsService: McpAppsService; constructor( @@ -274,8 +266,8 @@ export class AgentService extends TypedEventEmitter { fsService: FsService, @inject(MAIN_TOKENS.PosthogPluginService) posthogPluginService: PosthogPluginService, - @inject(MAIN_TOKENS.AuthProxyService) - authProxy: AuthProxyService, + @inject(MAIN_TOKENS.AgentAuthAdapter) + agentAuthAdapter: AgentAuthAdapter, @inject(MAIN_TOKENS.McpAppsService) mcpAppsService: McpAppsService, ) { @@ -284,31 +276,12 @@ export class AgentService extends TypedEventEmitter { this.sleepService = sleepService; this.fsService = fsService; this.posthogPluginService = posthogPluginService; - this.authProxy = authProxy; + this.agentAuthAdapter = agentAuthAdapter; this.mcpAppsService = mcpAppsService; powerMonitor.on("resume", () => this.checkIdleDeadlines()); } - public updateToken(newToken: string): void { - this.currentToken = newToken; - - if (this.authProxy.isRunning()) { - this.authProxy.updateToken(newToken); - } - - process.env.ANTHROPIC_API_KEY = newToken; - process.env.ANTHROPIC_AUTH_TOKEN = newToken; - process.env.OPENAI_API_KEY = newToken; - process.env.POSTHOG_API_KEY = newToken; - process.env.POSTHOG_AUTH_HEADER = `Bearer ${newToken}`; - - log.info("Token updated (proxy + env vars)", { - sessionCount: this.sessions.size, - proxyRunning: this.authProxy.isRunning(), - }); - } - /** * Respond to a pending permission request from the UI. * This resolves the promise that the agent is waiting on. @@ -436,123 +409,6 @@ export class AgentService extends TypedEventEmitter { } } - private getToken(fallback: string): string { - return this.currentToken || fallback; - } - - private async buildMcpServers( - credentials: Credentials, - ): Promise { - const servers: AcpMcpServer[] = []; - - const mcpUrl = this.getPostHogMcpUrl(credentials.apiHost); - const token = this.getToken(credentials.apiKey); - - servers.push({ - name: "posthog", - type: "http", - url: mcpUrl, - headers: [ - { name: "Authorization", value: `Bearer ${token}` }, - { - name: "x-posthog-project-id", - value: String(credentials.projectId), - }, - { name: "x-posthog-mcp-version", value: "2" }, - ], - }); - - // Fetch user-installed MCP servers from the PostHog backend - const installations = await this.fetchMcpInstallations(credentials); - - for (const installation of installations) { - // Skip the PostHog MCP server since it's already included above - if (installation.url === mcpUrl) continue; - - if (installation.auth_type === "none") { - servers.push({ - name: - installation.name || installation.display_name || installation.url, - type: "http", - url: installation.url, - headers: [], - }); - } else { - // Authenticated servers go through the PostHog proxy so credentials - // never leave the backend - servers.push({ - name: - installation.name || installation.display_name || installation.url, - type: "http", - url: installation.proxy_url, - headers: [{ name: "Authorization", value: `Bearer ${token}` }], - }); - } - } - - return servers; - } - - private async fetchMcpInstallations(credentials: Credentials): Promise< - Array<{ - id: string; - url: string; - proxy_url: string; - name: string; - display_name: string; - auth_type: string; - }> - > { - const token = this.getToken(credentials.apiKey); - const baseUrl = this.getPostHogApiBaseUrl(credentials.apiHost); - const url = `${baseUrl}/api/environments/${credentials.projectId}/mcp_server_installations/`; - - try { - const response = await fetch(url, { - headers: { - Authorization: `Bearer ${token}`, - "Content-Type": "application/json", - }, - }); - - if (!response.ok) { - log.warn("Failed to fetch MCP installations", { - status: response.status, - }); - return []; - } - - const data = (await response.json()) as { - results?: Array<{ - id: string; - url: string; - proxy_url?: string; - name: string; - display_name: string; - auth_type: string; - is_enabled?: boolean; - pending_oauth: boolean; - needs_reauth: boolean; - }>; - }; - const installations = data.results ?? []; - - return installations - .filter( - (i) => !i.pending_oauth && !i.needs_reauth && i.is_enabled !== false, - ) - .map((i) => ({ - ...i, - proxy_url: - i.proxy_url ?? - `${baseUrl}/api/environments/${credentials.projectId}/mcp_server_installations/${i.id}/proxy/`, - })); - } catch (err) { - log.warn("Error fetching MCP installations", { error: err }); - return []; - } - } - private buildSystemPrompt( credentials: Credentials, customInstructions?: string, @@ -568,22 +424,6 @@ export class AgentService extends TypedEventEmitter { return { append: prompt }; } - private getPostHogMcpUrl(apiHost: string): string { - const overrideUrl = process.env.POSTHOG_MCP_URL; - if (overrideUrl) { - return overrideUrl; - } - if (apiHost.includes("localhost") || apiHost.includes("127.0.0.1")) { - return "http://localhost:8787/mcp"; - } - return "https://mcp.posthog.com/mcp"; - } - - private getPostHogApiBaseUrl(apiHost: string): string { - const host = process.env.POSTHOG_PROXY_BASE_URL || apiHost; - return host.endsWith("/") ? host.slice(0, -1) : host; - } - async startSession(params: StartSessionInput): Promise { this.validateSessionParams(params); const config = this.toSessionConfig(params); @@ -648,16 +488,21 @@ export class AgentService extends TypedEventEmitter { const channel = `agent-event:${taskRunId}`; const mockNodeDir = this.setupMockNodeEnvironment(); - const proxyUrl = await this.ensureAuthProxy(credentials); - this.setupEnvironment(credentials, mockNodeDir, proxyUrl); + const proxyUrl = await this.agentAuthAdapter.ensureGatewayProxy( + credentials.apiHost, + ); + await this.agentAuthAdapter.configureProcessEnv({ + credentials, + mockNodeDir, + proxyUrl, + claudeCliPath: getClaudeCliPath(), + }); const isPreview = taskId === "__preview__"; const agent = new Agent({ posthog: { - apiUrl: credentials.apiHost, - getApiKey: () => this.getToken(credentials.apiKey), - projectId: credentials.projectId, + ...this.agentAuthAdapter.createPosthogConfig(credentials), userAgent: `posthog/desktop.hog.dev; version: ${app.getVersion()}`, }, skipLogPersistence: isPreview, @@ -716,7 +561,8 @@ export class AgentService extends TypedEventEmitter { }, }); - const mcpServers = await this.buildMcpServers(credentials); + const mcpServers = + await this.agentAuthAdapter.buildMcpServers(credentials); // Store server configs for lazy MCP connections — actual connections // are created on-demand when UI resources are first requested. @@ -1162,42 +1008,6 @@ For git operations while detached: log.info("All agent sessions cleaned up"); } - private async ensureAuthProxy(credentials: Credentials): Promise { - const token = this.getToken(credentials.apiKey); - const llmGatewayUrl = getLlmGatewayUrl(credentials.apiHost); - return this.authProxy.start(llmGatewayUrl, token); - } - - private setupEnvironment( - credentials: Credentials, - mockNodeDir: string, - proxyUrl: string, - ): void { - const token = this.getToken(credentials.apiKey); - const currentPath = process.env.PATH || ""; - if (!currentPath.split(delimiter).includes(mockNodeDir)) { - process.env.PATH = `${mockNodeDir}${delimiter}${currentPath}`; - } - process.env.POSTHOG_AUTH_HEADER = `Bearer ${token}`; - process.env.ANTHROPIC_API_KEY = token; - process.env.ANTHROPIC_AUTH_TOKEN = token; - - process.env.ANTHROPIC_BASE_URL = proxyUrl; - - const openaiBaseUrl = proxyUrl.endsWith("/v1") - ? proxyUrl - : `${proxyUrl}/v1`; - process.env.OPENAI_BASE_URL = openaiBaseUrl; - process.env.OPENAI_API_KEY = token; - process.env.LLM_GATEWAY_URL = proxyUrl; - - process.env.CLAUDE_CODE_EXECUTABLE = getClaudeCliPath(); - - process.env.POSTHOG_API_KEY = token; - process.env.POSTHOG_API_URL = credentials.apiHost; - process.env.POSTHOG_PROJECT_ID = String(credentials.projectId); - } - private setupMockNodeEnvironment(): string { const mockNodeDir = getMockNodeDir(); if (!this.mockNodeReady) { @@ -1494,8 +1304,8 @@ For git operations while detached: if (!params.taskId || !params.repoPath) { throw new Error("taskId and repoPath are required"); } - if (!params.apiKey || !params.apiHost) { - throw new Error("PostHog API credentials are required"); + if (!params.apiHost) { + throw new Error("PostHog API host is required"); } } @@ -1537,7 +1347,6 @@ For git operations while detached: taskRunId: params.taskRunId, repoPath: params.repoPath, credentials: { - apiKey: params.apiKey, apiHost: params.apiHost, projectId: params.projectId, }, @@ -1672,7 +1481,7 @@ For git operations while detached: } } - async getGatewayModels(apiHost: string, _apiKey: string) { + async getGatewayModels(apiHost: string) { const gatewayUrl = getLlmGatewayUrl(apiHost); const models = await fetchGatewayModels({ gatewayUrl }); diff --git a/apps/code/src/main/services/auth-proxy/service.ts b/apps/code/src/main/services/auth-proxy/service.ts index e40c06584..384ec1cc9 100644 --- a/apps/code/src/main/services/auth-proxy/service.ts +++ b/apps/code/src/main/services/auth-proxy/service.ts @@ -1,25 +1,29 @@ import http from "node:http"; -import { injectable } from "inversify"; +import { inject, injectable } from "inversify"; +import { MAIN_TOKENS } from "../../di/tokens"; import { logger } from "../../utils/logger"; +import type { AuthService } from "../auth/service"; const log = logger.scope("auth-proxy"); @injectable() export class AuthProxyService { private server: http.Server | null = null; - private currentToken: string | null = null; private gatewayUrl: string | null = null; private port: number | null = null; - async start(gatewayUrl: string, initialToken: string): Promise { + constructor( + @inject(MAIN_TOKENS.AuthService) + private readonly authService: AuthService, + ) {} + + async start(gatewayUrl: string): Promise { if (this.server) { - this.currentToken = initialToken; this.gatewayUrl = gatewayUrl; return this.getProxyUrl(); } this.gatewayUrl = gatewayUrl; - this.currentToken = initialToken; this.server = http.createServer((req, res) => { this.handleRequest(req, res); @@ -44,10 +48,6 @@ export class AuthProxyService { }); } - updateToken(token: string): void { - this.currentToken = token; - } - getProxyUrl(): string { if (!this.port) { throw new Error("Auth proxy not started"); @@ -76,7 +76,7 @@ export class AuthProxyService { req: http.IncomingMessage, res: http.ServerResponse, ): void { - if (!this.gatewayUrl || !this.currentToken) { + if (!this.gatewayUrl) { res.writeHead(503); res.end("Proxy not configured"); return; @@ -124,15 +124,26 @@ export class AuthProxyService { target: targetUrl.toString(), }); + const strippedAuthHeaders = new Set([ + "authorization", + "x-api-key", + "api-key", + "anthropic-auth-token", + "proxy-authorization", + ]); const headers: Record = {}; for (const [key, value] of Object.entries(req.headers)) { - if (key === "host" || key === "connection") continue; + if ( + key === "host" || + key === "connection" || + strippedAuthHeaders.has(key) + ) { + continue; + } if (typeof value === "string") { headers[key] = value; } } - headers.authorization = `Bearer ${this.currentToken}`; - const fetchOptions: RequestInit = { method: req.method ?? "GET", headers, @@ -156,7 +167,11 @@ export class AuthProxyService { res: http.ServerResponse, ): Promise { try { - const response = await fetch(url, options); + const response = await this.authService.authenticatedFetch( + fetch, + url, + options, + ); log.debug("Proxy response", { url, @@ -169,7 +184,7 @@ export class AuthProxyService { "content-encoding", "content-length", ]); - response.headers.forEach((value, key) => { + response.headers.forEach((value: string, key: string) => { if (stripHeaders.has(key)) return; responseHeaders[key] = value; }); diff --git a/apps/code/src/main/services/auth/service.ts b/apps/code/src/main/services/auth/service.ts index 51faa8318..6472f5f1f 100644 --- a/apps/code/src/main/services/auth/service.ts +++ b/apps/code/src/main/services/auth/service.ts @@ -24,6 +24,10 @@ import { const log = logger.scope("auth-service"); const TOKEN_EXPIRY_SKEW_MS = 60_000; +type FetchLike = ( + input: string | Request, + init?: RequestInit, +) => Promise; interface InMemorySession { accessToken: string; @@ -124,6 +128,48 @@ export class AuthService extends TypedEventEmitter { }; } + async invalidateAccessTokenForTest(): Promise { + await this.initialize(); + + if (!this.session) { + return; + } + + this.session = { + ...this.session, + accessToken: `${this.session.accessToken}_invalid`, + // Keep the token apparently fresh so the next authenticated request + // exercises the 401 -> refresh retry path instead of preemptive refresh. + accessTokenExpiresAt: Date.now() + 5 * 60 * 1000, + }; + } + + async authenticatedFetch( + fetchImpl: FetchLike, + input: string | Request, + init: RequestInit = {}, + ): Promise { + const initialAuth = await this.getValidAccessToken(); + let response = await this.executeAuthenticatedFetch( + fetchImpl, + input, + init, + initialAuth.accessToken, + ); + + if (response.status === 401 || response.status === 403) { + const refreshedAuth = await this.refreshAccessToken(); + response = await this.executeAuthenticatedFetch( + fetchImpl, + input, + init, + refreshedAuth.accessToken, + ); + } + + return response; + } + async redeemInviteCode(code: string): Promise { const { accessToken, apiHost } = await this.getValidAccessToken(); const response = await fetch(`${apiHost}/api/code/invites/redeem/`, { @@ -179,6 +225,21 @@ export class AuthService extends TypedEventEmitter { return this.getState(); } + private executeAuthenticatedFetch( + fetchImpl: FetchLike, + input: string | Request, + init: RequestInit, + accessToken: string, + ): Promise { + const headers = new Headers(init.headers); + headers.set("authorization", `Bearer ${accessToken}`); + + return fetchImpl(input, { + ...init, + headers, + }); + } + private async doInitialize(): Promise { const stored = this.authSessionRepository.getCurrent(); diff --git a/apps/code/src/main/services/cloud-task/schemas.ts b/apps/code/src/main/services/cloud-task/schemas.ts index d845ab71f..b7f1a811a 100644 --- a/apps/code/src/main/services/cloud-task/schemas.ts +++ b/apps/code/src/main/services/cloud-task/schemas.ts @@ -36,10 +36,6 @@ export const unwatchInput = z.object({ runId: z.string(), }); -export const updateTokenInput = z.object({ - token: z.string(), -}); - export const onUpdateInput = z.object({ taskId: z.string(), runId: z.string(), diff --git a/apps/code/src/main/services/cloud-task/service.ts b/apps/code/src/main/services/cloud-task/service.ts index 3a0efbc6b..a01b24792 100644 --- a/apps/code/src/main/services/cloud-task/service.ts +++ b/apps/code/src/main/services/cloud-task/service.ts @@ -1,8 +1,10 @@ import type { StoredLogEntry } from "@shared/types/session-events"; import { net } from "electron"; -import { injectable, preDestroy } from "inversify"; +import { inject, injectable, preDestroy } from "inversify"; +import { MAIN_TOKENS } from "../../di/tokens"; import { logger } from "../../utils/logger"; import { TypedEventEmitter } from "../../utils/typed-event-emitter"; +import type { AuthService } from "../auth/service"; import { CloudTaskEvent, type CloudTaskEvents, @@ -47,11 +49,6 @@ interface WatcherState { viewing: boolean; } -interface PendingWatchState { - input: WatchInput; - subscriberCount: number; -} - function watcherKey(taskId: string, runId: string): string { return `${taskId}:${runId}`; } @@ -59,8 +56,13 @@ function watcherKey(taskId: string, runId: string): string { @injectable() export class CloudTaskService extends TypedEventEmitter { private watchers = new Map(); - private pendingWatches = new Map(); - private apiKey: string | null = null; + + constructor( + @inject(MAIN_TOKENS.AuthService) + private readonly authService: AuthService, + ) { + super(); + } watch(input: WatchInput): void { const key = watcherKey(input.taskId, input.runId); @@ -79,18 +81,6 @@ export class CloudTaskService extends TypedEventEmitter { return; } - // If no token yet, queue (deduplicated by key) - if (!this.apiKey) { - const pending = this.pendingWatches.get(key); - if (pending) { - pending.subscriberCount++; - } else { - this.pendingWatches.set(key, { input, subscriberCount: 1 }); - } - log.info("Cloud task watch queued (no token yet)", { key }); - return; - } - this.startWatcher(input, 1); } @@ -98,13 +88,6 @@ export class CloudTaskService extends TypedEventEmitter { const key = watcherKey(taskId, runId); const watcher = this.watchers.get(key); if (!watcher) { - const pending = this.pendingWatches.get(key); - if (!pending) return; - - pending.subscriberCount--; - if (pending.subscriberCount <= 0) { - this.pendingWatches.delete(key); - } return; } @@ -119,22 +102,6 @@ export class CloudTaskService extends TypedEventEmitter { } } - updateToken(token: string): void { - this.apiKey = token; - - // Drain pending watches - if (this.pendingWatches.size > 0) { - const pending = [...this.pendingWatches.values()]; - this.pendingWatches.clear(); - for (const queued of pending) { - this.startWatcher(queued.input, queued.subscriberCount); - } - log.info("Drained pending cloud task watches", { - count: pending.length, - }); - } - } - setViewing(taskId: string, runId: string, viewing: boolean): void { const key = watcherKey(taskId, runId); const watcher = this.watchers.get(key); @@ -157,10 +124,6 @@ export class CloudTaskService extends TypedEventEmitter { } async sendCommand(input: SendCommandInput): Promise { - if (!this.apiKey) { - return { success: false, error: "No API token available" }; - } - const url = `${input.apiHost}/api/projects/${input.teamId}/tasks/${input.taskId}/runs/${input.runId}/command/`; const body = { jsonrpc: "2.0", @@ -170,14 +133,17 @@ export class CloudTaskService extends TypedEventEmitter { }; try { - const response = await net.fetch(url, { - method: "POST", - headers: { - Authorization: `Bearer ${this.apiKey}`, - "Content-Type": "application/json", + const response = await this.authService.authenticatedFetch( + net.fetch, + url, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(body), }, - body: JSON.stringify(body), - }); + ); if (!response.ok) { const errorText = await response.text().catch(() => ""); @@ -244,7 +210,6 @@ export class CloudTaskService extends TypedEventEmitter { for (const key of [...this.watchers.keys()]) { this.stopWatcher(key); } - this.pendingWatches.clear(); } // --- Private --- @@ -307,7 +272,7 @@ export class CloudTaskService extends TypedEventEmitter { private async poll(key: string, isSnapshot: boolean): Promise { const watcher = this.watchers.get(key); - if (!watcher || !this.apiKey) return; + if (!watcher) return; try { // Only fetch logs when the user is viewing the run @@ -457,20 +422,23 @@ export class CloudTaskService extends TypedEventEmitter { } try { - const response = await net.fetch(url.toString(), { - method: "GET", - headers: { Authorization: `Bearer ${this.apiKey}` }, - }); + const authedResponse = await this.authService.authenticatedFetch( + net.fetch, + url.toString(), + { + method: "GET", + }, + ); - if (!response.ok) { + if (!authedResponse.ok) { log.warn("Cloud task log fetch failed", { - status: response.status, + status: authedResponse.status, taskId: watcher.taskId, }); return { newEntries: [] }; } - const raw = await response.text(); + const raw = await authedResponse.text(); const entries = JSON.parse(raw) as StoredLogEntry[]; if (entries.length === 0) { @@ -549,20 +517,23 @@ export class CloudTaskService extends TypedEventEmitter { const url = `${watcher.apiHost}/api/projects/${watcher.teamId}/tasks/${watcher.taskId}/runs/${watcher.runId}/`; try { - const response = await net.fetch(url, { - method: "GET", - headers: { Authorization: `Bearer ${this.apiKey}` }, - }); + const authedResponse = await this.authService.authenticatedFetch( + net.fetch, + url, + { + method: "GET", + }, + ); - if (!response.ok) { + if (!authedResponse.ok) { log.warn("Cloud task status fetch failed", { - status: response.status, + status: authedResponse.status, taskId: watcher.taskId, }); return null; } - return (await response.json()) as TaskRunResponse; + return (await authedResponse.json()) as TaskRunResponse; } catch (error) { log.warn("Cloud task status fetch error", { taskId: watcher.taskId, diff --git a/apps/code/src/main/services/git/schemas.ts b/apps/code/src/main/services/git/schemas.ts index 310856ada..232f51039 100644 --- a/apps/code/src/main/services/git/schemas.ts +++ b/apps/code/src/main/services/git/schemas.ts @@ -302,10 +302,6 @@ export const getBranchChangedFilesOutput = z.array(changedFileSchema); export const generateCommitMessageInput = z.object({ directoryPath: z.string(), - credentials: z.object({ - apiKey: z.string(), - apiHost: z.string(), - }), }); export const generateCommitMessageOutput = z.object({ @@ -314,10 +310,6 @@ export const generateCommitMessageOutput = z.object({ export const generatePrTitleAndBodyInput = z.object({ directoryPath: z.string(), - credentials: z.object({ - apiKey: z.string(), - apiHost: z.string(), - }), }); export const generatePrTitleAndBodyOutput = z.object({ diff --git a/apps/code/src/main/services/git/service.ts b/apps/code/src/main/services/git/service.ts index 1d361373e..495c1c90c 100644 --- a/apps/code/src/main/services/git/service.ts +++ b/apps/code/src/main/services/git/service.ts @@ -31,7 +31,6 @@ import { inject, injectable } from "inversify"; import { MAIN_TOKENS } from "../../di/tokens"; import { logger } from "../../utils/logger"; import { TypedEventEmitter } from "../../utils/typed-event-emitter"; -import type { LlmCredentials } from "../llm-gateway/schemas"; import type { LlmGatewayService } from "../llm-gateway/service"; import type { ChangedFile, @@ -831,7 +830,6 @@ export class GitService extends TypedEventEmitter { public async generateCommitMessage( directoryPath: string, - credentials: LlmCredentials, ): Promise<{ message: string }> { const [stagedDiff, unstagedDiff, conventions, changedFiles] = await Promise.all([ @@ -890,7 +888,6 @@ ${truncatedDiff}`; }); const response = await this.llmGateway.prompt( - credentials, [{ role: "user", content: userMessage }], { system }, ); @@ -900,7 +897,6 @@ ${truncatedDiff}`; public async generatePrTitleAndBody( directoryPath: string, - credentials: LlmCredentials, ): Promise<{ title: string; body: string }> { await this.fetchIfStale(directoryPath); @@ -982,7 +978,6 @@ ${filesSummary || "(no file changes detected)"}`; }); const response = await this.llmGateway.prompt( - credentials, [{ role: "user", content: userMessage }], { system, maxTokens: 2000 }, ); diff --git a/apps/code/src/main/services/llm-gateway/schemas.ts b/apps/code/src/main/services/llm-gateway/schemas.ts index 15417278f..7b8c1ae3e 100644 --- a/apps/code/src/main/services/llm-gateway/schemas.ts +++ b/apps/code/src/main/services/llm-gateway/schemas.ts @@ -1,12 +1,5 @@ import { z } from "zod"; -export const llmCredentialsSchema = z.object({ - apiKey: z.string(), - apiHost: z.string(), -}); - -export type LlmCredentials = z.infer; - export const llmMessageSchema = z.object({ role: z.enum(["user", "assistant"]), content: z.string(), @@ -15,7 +8,6 @@ export const llmMessageSchema = z.object({ export type LlmMessage = z.infer; export const promptInput = z.object({ - credentials: llmCredentialsSchema, system: z.string().optional(), messages: z.array(llmMessageSchema), maxTokens: z.number().optional(), diff --git a/apps/code/src/main/services/llm-gateway/service.ts b/apps/code/src/main/services/llm-gateway/service.ts index 8e333c137..0fc92bb94 100644 --- a/apps/code/src/main/services/llm-gateway/service.ts +++ b/apps/code/src/main/services/llm-gateway/service.ts @@ -1,12 +1,13 @@ import { getLlmGatewayUrl } from "@posthog/agent/posthog-api"; import { net } from "electron"; -import { injectable } from "inversify"; +import { inject, injectable } from "inversify"; +import { MAIN_TOKENS } from "../../di/tokens"; import { logger } from "../../utils/logger"; +import type { AuthService } from "../auth/service"; import type { AnthropicErrorResponse, AnthropicMessagesRequest, AnthropicMessagesResponse, - LlmCredentials, LlmMessage, PromptOutput, } from "./schemas"; @@ -27,8 +28,12 @@ export class LlmGatewayError extends Error { @injectable() export class LlmGatewayService { + constructor( + @inject(MAIN_TOKENS.AuthService) + private readonly authService: AuthService, + ) {} + async prompt( - credentials: LlmCredentials, messages: LlmMessage[], options: { system?: string; @@ -38,7 +43,8 @@ export class LlmGatewayService { ): Promise { const { system, maxTokens, model = "claude-haiku-4-5" } = options; - const gatewayUrl = getLlmGatewayUrl(credentials.apiHost); + const auth = await this.authService.getValidAccessToken(); + const gatewayUrl = getLlmGatewayUrl(auth.apiHost); const messagesUrl = `${gatewayUrl}/v1/messages`; const requestBody: AnthropicMessagesRequest = { @@ -61,14 +67,17 @@ export class LlmGatewayService { messageCount: messages.length, }); - const response = await net.fetch(messagesUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${credentials.apiKey}`, + const response = await this.authService.authenticatedFetch( + net.fetch, + messagesUrl, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(requestBody), }, - body: JSON.stringify(requestBody), - }); + ); if (!response.ok) { const errorBody = await response.text(); diff --git a/apps/code/src/main/services/ui/service.ts b/apps/code/src/main/services/ui/service.ts index 4a2632d9c..f991d4ea8 100644 --- a/apps/code/src/main/services/ui/service.ts +++ b/apps/code/src/main/services/ui/service.ts @@ -1,9 +1,18 @@ -import { injectable } from "inversify"; +import { inject, injectable } from "inversify"; +import { MAIN_TOKENS } from "../../di/tokens"; import { TypedEventEmitter } from "../../utils/typed-event-emitter"; +import type { AuthService } from "../auth/service"; import { UIServiceEvent, type UIServiceEvents } from "./schemas"; @injectable() export class UIService extends TypedEventEmitter { + constructor( + @inject(MAIN_TOKENS.AuthService) + private readonly authService: AuthService, + ) { + super(); + } + openSettings(): void { this.emit(UIServiceEvent.OpenSettings, true); } @@ -20,7 +29,8 @@ export class UIService extends TypedEventEmitter { this.emit(UIServiceEvent.ClearStorage, true); } - invalidateToken(): void { + async invalidateToken(): Promise { + await this.authService.invalidateAccessTokenForTest(); this.emit(UIServiceEvent.InvalidateToken, true); } } diff --git a/apps/code/src/main/trpc/routers/agent.ts b/apps/code/src/main/trpc/routers/agent.ts index 75a4a7960..f6907f61a 100644 --- a/apps/code/src/main/trpc/routers/agent.ts +++ b/apps/code/src/main/trpc/routers/agent.ts @@ -20,7 +20,6 @@ import { setConfigOptionInput, startSessionInput, subscribeSessionInput, - tokenUpdateInput, } from "../../services/agent/schemas"; import type { AgentService } from "../../services/agent/service"; import type { ProcessTrackingService } from "../../services/process-tracking/service"; @@ -61,10 +60,6 @@ export const agentRouter = router({ .output(sessionResponseSchema.nullable()) .mutation(({ input }) => getService().reconnectSession(input)), - updateToken: publicProcedure.input(tokenUpdateInput).mutation(({ input }) => { - getService().updateToken(input.token); - }), - setConfigOption: publicProcedure .input(setConfigOptionInput) .mutation(({ input }) => @@ -197,7 +192,5 @@ export const agentRouter = router({ getGatewayModels: publicProcedure .input(getGatewayModelsInput) .output(getGatewayModelsOutput) - .query(({ input }) => - getService().getGatewayModels(input.apiHost, input.apiKey), - ), + .query(({ input }) => getService().getGatewayModels(input.apiHost)), }); diff --git a/apps/code/src/main/trpc/routers/cloud-task.ts b/apps/code/src/main/trpc/routers/cloud-task.ts index 8a7ab595e..7d82366ab 100644 --- a/apps/code/src/main/trpc/routers/cloud-task.ts +++ b/apps/code/src/main/trpc/routers/cloud-task.ts @@ -7,7 +7,6 @@ import { sendCommandOutput, setViewingInput, unwatchInput, - updateTokenInput, watchInput, } from "../../services/cloud-task/schemas"; import type { CloudTaskService } from "../../services/cloud-task/service"; @@ -25,10 +24,6 @@ export const cloudTaskRouter = router({ .input(unwatchInput) .mutation(({ input }) => getService().unwatch(input.taskId, input.runId)), - updateToken: publicProcedure - .input(updateTokenInput) - .mutation(({ input }) => getService().updateToken(input.token)), - setViewing: publicProcedure .input(setViewingInput) .mutation(({ input }) => diff --git a/apps/code/src/main/trpc/routers/git.ts b/apps/code/src/main/trpc/routers/git.ts index 384bc1c96..0629c60e9 100644 --- a/apps/code/src/main/trpc/routers/git.ts +++ b/apps/code/src/main/trpc/routers/git.ts @@ -275,20 +275,14 @@ export const gitRouter = router({ .input(generateCommitMessageInput) .output(generateCommitMessageOutput) .mutation(({ input }) => - getService().generateCommitMessage( - input.directoryPath, - input.credentials, - ), + getService().generateCommitMessage(input.directoryPath), ), generatePrTitleAndBody: publicProcedure .input(generatePrTitleAndBodyInput) .output(generatePrTitleAndBodyOutput) .mutation(({ input }) => - getService().generatePrTitleAndBody( - input.directoryPath, - input.credentials, - ), + getService().generatePrTitleAndBody(input.directoryPath), ), searchGithubIssues: publicProcedure diff --git a/apps/code/src/main/trpc/routers/llm-gateway.ts b/apps/code/src/main/trpc/routers/llm-gateway.ts index d115d355a..83c59ecac 100644 --- a/apps/code/src/main/trpc/routers/llm-gateway.ts +++ b/apps/code/src/main/trpc/routers/llm-gateway.ts @@ -12,7 +12,7 @@ export const llmGatewayRouter = router({ .input(promptInput) .output(promptOutput) .mutation(({ input }) => - getService().prompt(input.credentials, input.messages, { + getService().prompt(input.messages, { system: input.system, maxTokens: input.maxTokens, model: input.model, diff --git a/apps/code/src/main/trpc/routers/oauth.ts b/apps/code/src/main/trpc/routers/oauth.ts index f3f6f8d7d..e62a1a05f 100644 --- a/apps/code/src/main/trpc/routers/oauth.ts +++ b/apps/code/src/main/trpc/routers/oauth.ts @@ -1,42 +1,13 @@ import { container } from "../../di/container"; import { MAIN_TOKENS } from "../../di/tokens"; -import { - cancelFlowOutput, - openExternalUrlInput, - refreshTokenInput, - refreshTokenOutput, - startFlowInput, - startFlowOutput, - startSignupFlowInput, -} from "../../services/oauth/schemas"; +import { cancelFlowOutput } from "../../services/oauth/schemas"; import type { OAuthService } from "../../services/oauth/service"; import { publicProcedure, router } from "../trpc"; const getService = () => container.get(MAIN_TOKENS.OAuthService); export const oauthRouter = router({ - startFlow: publicProcedure - .input(startFlowInput) - .output(startFlowOutput) - .mutation(({ input }) => getService().startFlow(input.region)), - - startSignupFlow: publicProcedure - .input(startSignupFlowInput) - .output(startFlowOutput) - .mutation(({ input }) => getService().startSignupFlow(input.region)), - - refreshToken: publicProcedure - .input(refreshTokenInput) - .output(refreshTokenOutput) - .mutation(({ input }) => - getService().refreshToken(input.refreshToken, input.region), - ), - cancelFlow: publicProcedure .output(cancelFlowOutput) .mutation(() => getService().cancelFlow()), - - openExternalUrl: publicProcedure - .input(openExternalUrlInput) - .mutation(({ input }) => getService().openExternalUrl(input.url)), }); diff --git a/apps/code/src/renderer/components/GlobalEventHandlers.tsx b/apps/code/src/renderer/components/GlobalEventHandlers.tsx index e3fe442e2..de2aa27c1 100644 --- a/apps/code/src/renderer/components/GlobalEventHandlers.tsx +++ b/apps/code/src/renderer/components/GlobalEventHandlers.tsx @@ -1,4 +1,3 @@ -import { useAuthStore } from "@features/auth/stores/authStore"; import { useFolders } from "@features/folders/hooks/useFolders"; import { usePanelLayoutStore } from "@features/panels/store/panelLayoutStore"; import { useRightSidebarStore } from "@features/right-sidebar"; @@ -11,7 +10,6 @@ import { useFocusWorkspace } from "@features/workspace/hooks/useFocusWorkspace"; import { useWorkspaces } from "@features/workspace/hooks/useWorkspace"; import { SHORTCUTS } from "@renderer/constants/keyboard-shortcuts"; import { useTRPC } from "@renderer/trpc"; -import { trpcClient } from "@renderer/trpc/client"; import type { Task } from "@shared/types"; import { useCommandMenuStore } from "@stores/commandMenuStore"; import { useNavigationStore } from "@stores/navigationStore"; @@ -147,21 +145,7 @@ export function GlobalEventHandlers({ const handleInvalidateToken = useCallback((data?: unknown) => { if (!data) return; const log = logger.scope("global-event-handlers"); - const state = useAuthStore.getState(); - const currentToken = state.oauthAccessToken; - if (!currentToken) { - log.warn("No access token to invalidate"); - return; - } - const invalidToken = `${currentToken}_invalid`; - useAuthStore.setState({ oauthAccessToken: invalidToken }); - trpcClient.agent.updateToken - .mutate({ token: invalidToken }) - .catch((err) => log.warn("Failed to update agent token", err)); - trpcClient.cloudTask.updateToken - .mutate({ token: invalidToken }) - .catch((err) => log.warn("Failed to update cloud task token", err)); - log.info("OAuth access token invalidated for testing"); + log.info("Main access token invalidated for testing"); }, []); const globalOptions = { diff --git a/apps/code/src/renderer/components/ScopeReauthPrompt.test.tsx b/apps/code/src/renderer/components/ScopeReauthPrompt.test.tsx index 8dfe8a4b1..1253f3476 100644 --- a/apps/code/src/renderer/components/ScopeReauthPrompt.test.tsx +++ b/apps/code/src/renderer/components/ScopeReauthPrompt.test.tsx @@ -49,12 +49,6 @@ vi.mock("@renderer/trpc/client", () => ({ }), }, }, - agent: { - updateToken: { mutate: vi.fn().mockResolvedValue(undefined) }, - }, - cloudTask: { - updateToken: { mutate: vi.fn().mockResolvedValue(undefined) }, - }, analytics: { setUserId: { mutate: vi.fn().mockResolvedValue(undefined) }, resetUser: { mutate: vi.fn().mockResolvedValue(undefined) }, diff --git a/apps/code/src/renderer/features/auth/stores/authStore.test.ts b/apps/code/src/renderer/features/auth/stores/authStore.test.ts index 0c1fd3fb6..cd5ce4e05 100644 --- a/apps/code/src/renderer/features/auth/stores/authStore.test.ts +++ b/apps/code/src/renderer/features/auth/stores/authStore.test.ts @@ -24,12 +24,6 @@ vi.mock("@renderer/trpc/client", () => ({ redeemInviteCode: mockRedeemInviteCode, logout: mockLogout, }, - agent: { - updateToken: { mutate: vi.fn().mockResolvedValue(undefined) }, - }, - cloudTask: { - updateToken: { mutate: vi.fn().mockResolvedValue(undefined) }, - }, analytics: { setUserId: { mutate: vi.fn().mockResolvedValue(undefined) }, resetUser: { mutate: vi.fn().mockResolvedValue(undefined) }, @@ -77,6 +71,8 @@ vi.mock("@stores/navigationStore", () => ({ }, })); +import { resetUser } from "@utils/analytics"; +import { queryClient } from "@utils/queryClient"; import { resetAuthStoreModuleStateForTest, useAuthStore } from "./authStore"; const authenticatedState = { @@ -120,7 +116,6 @@ describe("authStore", () => { mockOnStateChangedSubscribe.mockReturnValue({ unsubscribe: vi.fn() }); useAuthStore.setState({ - oauthAccessToken: null, cloudRegion: null, staleCloudRegion: null, isAuthenticated: false, @@ -145,7 +140,6 @@ describe("authStore", () => { expect(result).toBe(true); expect(useAuthStore.getState().isAuthenticated).toBe(true); expect(useAuthStore.getState().projectId).toBe(1); - expect(useAuthStore.getState().oauthAccessToken).toBe("test-access-token"); }); it("logs in through the main auth service", async () => { @@ -166,6 +160,29 @@ describe("authStore", () => { await useAuthStore.getState().checkCodeAccess(); expect(mockGetCurrentUser).toHaveBeenCalledTimes(1); - expect(mockGetValidAccessToken.query).toHaveBeenCalledTimes(1); + }); + + it("clears user identity and cached current user on implicit auth loss", async () => { + mockGetState.query + .mockResolvedValueOnce(authenticatedState) + .mockResolvedValueOnce({ + status: "anonymous", + bootstrapComplete: true, + cloudRegion: null, + projectId: null, + availableProjectIds: [], + availableOrgIds: [], + hasCodeAccess: null, + needsScopeReauth: false, + }); + + await useAuthStore.getState().initializeOAuth(); + await useAuthStore.getState().checkCodeAccess(); + + expect(resetUser).toHaveBeenCalledTimes(1); + expect(queryClient.removeQueries).toHaveBeenCalledWith({ + queryKey: ["currentUser"], + exact: true, + }); }); }); diff --git a/apps/code/src/renderer/features/auth/stores/authStore.ts b/apps/code/src/renderer/features/auth/stores/authStore.ts index 167f02262..76f36966b 100644 --- a/apps/code/src/renderer/features/auth/stores/authStore.ts +++ b/apps/code/src/renderer/features/auth/stores/authStore.ts @@ -32,10 +32,8 @@ export function resetAuthStoreModuleStateForTest(): void { } interface AuthStoreState { - oauthAccessToken: string | null; cloudRegion: CloudRegion | null; staleCloudRegion: CloudRegion | null; - isAuthenticated: boolean; client: PostHogAPIClient | null; projectId: number | null; @@ -44,16 +42,13 @@ interface AuthStoreState { needsProjectSelection: boolean; needsScopeReauth: boolean; hasCodeAccess: boolean | null; - hasCompletedOnboarding: boolean; selectedPlan: "free" | "pro" | null; selectedOrgId: string | null; - checkCodeAccess: () => Promise; redeemInviteCode: (code: string) => Promise; loginWithOAuth: (region: CloudRegion) => Promise; signupWithOAuth: (region: CloudRegion) => Promise; - refreshAccessToken: () => Promise; initializeOAuth: () => Promise; selectProject: (projectId: number) => Promise; completeOnboarding: () => void; @@ -64,25 +59,14 @@ interface AuthStoreState { async function getValidAccessToken(): Promise { const { accessToken } = await trpcClient.auth.getValidAccessToken.query(); - useAuthStore.setState({ oauthAccessToken: accessToken }); return accessToken; } async function refreshAccessToken(): Promise { const { accessToken } = await trpcClient.auth.refreshAccessToken.mutate(); - useAuthStore.setState({ oauthAccessToken: accessToken }); return accessToken; } -function updateServiceTokens(token: string): void { - trpcClient.agent.updateToken - .mutate({ token }) - .catch((err) => log.warn("Failed to update agent token", err)); - trpcClient.cloudTask.updateToken - .mutate({ token }) - .catch((err) => log.warn("Failed to update cloud task token", err)); -} - function createClient( cloudRegion: CloudRegion, projectId: number | null, @@ -99,7 +83,22 @@ function createClient( return client; } +function clearAuthenticatedRendererState(options?: { + clearAllQueries?: boolean; +}): void { + resetUser(); + trpcClient.analytics.resetUser.mutate(); + + if (options?.clearAllQueries) { + queryClient.clear(); + return; + } + + queryClient.removeQueries({ queryKey: ["currentUser"], exact: true }); +} + async function syncAuthState(): Promise { + const previousState = useAuthStore.getState(); const authState = await trpcClient.auth.getState.query(); const isAuthenticated = authState.status === "authenticated"; @@ -136,6 +135,9 @@ async function syncAuthState(): Promise { const client = useAuthStore.getState().client; if (!isAuthenticated || !authState.cloudRegion || !client) { + if (previousState.isAuthenticated || lastCompletedAuthSyncKey !== null) { + clearAuthenticatedRendererState(); + } inFlightAuthSync = null; inFlightAuthSyncKey = null; lastCompletedAuthSyncKey = null; @@ -163,9 +165,6 @@ async function syncAuthState(): Promise { const user = await client.getCurrentUser(); queryClient.setQueryData(["currentUser"], user); - const token = await getValidAccessToken(); - updateServiceTokens(token); - const distinctId = user.distinct_id || user.email; identifyUser(distinctId, { email: user.email, @@ -214,7 +213,6 @@ function ensureAuthSubscription(): void { } export const useAuthStore = create((set, get) => ({ - oauthAccessToken: null, cloudRegion: null, staleCloudRegion: null, @@ -258,11 +256,6 @@ export const useAuthStore = create((set, get) => ({ }); }, - refreshAccessToken: async () => { - const token = await refreshAccessToken(); - updateServiceTokens(token); - }, - initializeOAuth: async () => { if (initializePromise) { return initializePromise; @@ -300,16 +293,13 @@ export const useAuthStore = create((set, get) => ({ logout: async () => { track(ANALYTICS_EVENTS.USER_LOGGED_OUT); - resetUser(); sessionResetCallback?.(); - queryClient.clear(); + clearAuthenticatedRendererState({ clearAllQueries: true }); await trpcClient.auth.logout.mutate(); - trpcClient.analytics.resetUser.mutate(); useNavigationStore.getState().navigateToTaskInput(); set((state) => ({ ...state, - oauthAccessToken: null, cloudRegion: null, staleCloudRegion: state.cloudRegion ?? null, isAuthenticated: false, diff --git a/apps/code/src/renderer/features/git-interaction/hooks/useGitInteraction.ts b/apps/code/src/renderer/features/git-interaction/hooks/useGitInteraction.ts index 783ad0954..0490c9609 100644 --- a/apps/code/src/renderer/features/git-interaction/hooks/useGitInteraction.ts +++ b/apps/code/src/renderer/features/git-interaction/hooks/useGitInteraction.ts @@ -1,4 +1,3 @@ -import { useAuthStore } from "@features/auth/stores/authStore"; import { useGitQueries } from "@features/git-interaction/hooks/useGitQueries"; import { computeGitInteractionState } from "@features/git-interaction/state/gitInteractionLogic"; import { @@ -132,21 +131,10 @@ export function useGitInteraction( modal.openPr("", ""); if (!repoPath) return; - const authState = useAuthStore.getState(); - const apiKey = authState.oauthAccessToken; - const cloudRegion = authState.cloudRegion; - if (!apiKey || !cloudRegion) return; - - const apiHost = - cloudRegion === "eu" - ? "https://eu.posthog.com" - : "https://us.posthog.com"; - modal.setIsGeneratingPr(true); try { const result = await trpcClient.git.generatePrTitleAndBody.mutate({ directoryPath: repoPath, - credentials: { apiKey, apiHost }, }); if (result.title || result.body) { modal.setPrTitle(result.title); @@ -201,27 +189,9 @@ export function useGitInteraction( let message = store.commitMessage.trim(); if (!message) { - const authState = useAuthStore.getState(); - const apiKey = authState.oauthAccessToken; - const cloudRegion = authState.cloudRegion; - - if (!apiKey || !cloudRegion) { - modal.setCommitError( - "Authentication required to generate commit message.", - ); - modal.setIsSubmitting(false); - return; - } - - const apiHost = - cloudRegion === "eu" - ? "https://eu.posthog.com" - : "https://us.posthog.com"; - try { const generated = await trpcClient.git.generateCommitMessage.mutate({ directoryPath: repoPath, - credentials: { apiKey, apiHost }, }); if (!generated.message) { @@ -389,29 +359,12 @@ export function useGitInteraction( const generateCommitMessage = async () => { if (!repoPath) return; - const authState = useAuthStore.getState(); - const apiKey = authState.oauthAccessToken; - const cloudRegion = authState.cloudRegion; - - if (!apiKey || !cloudRegion) { - modal.setCommitError( - "Authentication required to generate commit message.", - ); - return; - } - - const apiHost = - cloudRegion === "eu" - ? "https://eu.posthog.com" - : "https://us.posthog.com"; - modal.setIsGeneratingCommitMessage(true); modal.setCommitError(null); try { const result = await trpcClient.git.generateCommitMessage.mutate({ directoryPath: repoPath, - credentials: { apiKey, apiHost }, }); if (result.message) { @@ -436,27 +389,12 @@ export function useGitInteraction( const generatePrTitleAndBody = async () => { if (!repoPath) return; - const authState = useAuthStore.getState(); - const apiKey = authState.oauthAccessToken; - const cloudRegion = authState.cloudRegion; - - if (!apiKey || !cloudRegion) { - modal.setPrError("Authentication required to generate PR description."); - return; - } - - const apiHost = - cloudRegion === "eu" - ? "https://eu.posthog.com" - : "https://us.posthog.com"; - modal.setIsGeneratingPr(true); modal.setPrError(null); try { const result = await trpcClient.git.generatePrTitleAndBody.mutate({ directoryPath: repoPath, - credentials: { apiKey, apiHost }, }); if (result.title || result.body) { diff --git a/apps/code/src/renderer/features/sessions/service/service.test.ts b/apps/code/src/renderer/features/sessions/service/service.test.ts index 881e5e3f6..dbe19f2db 100644 --- a/apps/code/src/renderer/features/sessions/service/service.test.ts +++ b/apps/code/src/renderer/features/sessions/service/service.test.ts @@ -64,7 +64,6 @@ vi.mock("@features/sessions/stores/sessionStore", () => ({ const mockAuthStore = vi.hoisted(() => ({ useAuthStore: { getState: vi.fn(() => ({ - oauthAccessToken: "test-token", cloudRegion: "us", projectId: 123, client: { @@ -282,7 +281,6 @@ describe("SessionService", () => { // Track how many times createTaskRun is called const createTaskRunMock = vi.fn().mockResolvedValue({ id: "run-123" }); mockAuthStore.useAuthStore.getState.mockReturnValue({ - oauthAccessToken: "test-token", cloudRegion: "us", projectId: 123, client: { @@ -336,7 +334,6 @@ describe("SessionService", () => { const service = getSessionService(); mockAuthStore.useAuthStore.getState.mockReturnValue({ - oauthAccessToken: null, cloudRegion: null, projectId: null, client: null, @@ -418,7 +415,6 @@ describe("SessionService", () => { // Setup: create a task run to trigger subscription creation const createTaskRunMock = vi.fn().mockResolvedValue({ id: "run-456" }); mockAuthStore.useAuthStore.getState.mockReturnValue({ - oauthAccessToken: "test-token", cloudRegion: "us", projectId: 123, client: { diff --git a/apps/code/src/renderer/features/sessions/service/service.ts b/apps/code/src/renderer/features/sessions/service/service.ts index 0c2b9351b..d6c2a20a4 100644 --- a/apps/code/src/renderer/features/sessions/service/service.ts +++ b/apps/code/src/renderer/features/sessions/service/service.ts @@ -63,7 +63,6 @@ const log = logger.scope("session-service"); export const PREVIEW_TASK_ID = "__preview__"; interface AuthCredentials { - apiKey: string; apiHost: string; projectId: number; client: ReturnType["client"]; @@ -390,7 +389,6 @@ export class SessionService { taskId, taskRunId, repoPath, - apiKey: auth.apiKey, apiHost: auth.apiHost, projectId: auth.projectId, logUrl, @@ -543,7 +541,6 @@ export class SessionService { taskId, taskRunId: taskRun.id, repoPath, - apiKey: auth.apiKey, apiHost: auth.apiHost, projectId: auth.projectId, permissionMode: executionMode, @@ -688,7 +685,6 @@ export class SessionService { taskId: PREVIEW_TASK_ID, taskRunId, repoPath: "__preview__", - apiKey: auth.apiKey, apiHost: auth.apiHost, projectId: auth.projectId, adapter: params.adapter, @@ -1946,18 +1942,13 @@ export class SessionService { }); } - // Get auth for initial token + host info + // Get auth for host info const auth = useAuthStore.getState(); - if (!auth.oauthAccessToken || !auth.projectId || !auth.cloudRegion) { + if (!auth.projectId || !auth.cloudRegion) { log.warn("No auth for cloud task watcher", { taskId }); return () => {}; } - // Ensure main-process service has current token - trpcClient.cloudTask.updateToken - .mutate({ token: auth.oauthAccessToken }) - .catch(() => {}); - // Start main-process watcher trpcClient.cloudTask.watch .mutate({ @@ -2174,15 +2165,14 @@ export class SessionService { private getAuthCredentials(): AuthCredentials | null { const authState = useAuthStore.getState(); - const apiKey = authState.oauthAccessToken; const apiHost = authState.cloudRegion ? getCloudUrlFromRegion(authState.cloudRegion) : null; const projectId = authState.projectId; const client = authState.client; - if (!apiKey || !apiHost || !projectId) return null; - return { apiKey, apiHost, projectId, client }; + if (!apiHost || !projectId || !client) return null; + return { apiHost, projectId, client }; } private parseLogContent(content: string): { diff --git a/apps/code/src/renderer/features/sidebar/components/ProjectSwitcher.tsx b/apps/code/src/renderer/features/sidebar/components/ProjectSwitcher.tsx index 3a04d8330..1a79e44f1 100644 --- a/apps/code/src/renderer/features/sidebar/components/ProjectSwitcher.tsx +++ b/apps/code/src/renderer/features/sidebar/components/ProjectSwitcher.tsx @@ -74,7 +74,7 @@ export function ProjectSwitcher() { const handleCreateProject = async () => { if (cloudRegion) { const cloudUrl = getCloudUrlFromRegion(cloudRegion); - await trpcClient.oauth.openExternalUrl.mutate({ + await trpcClient.os.openExternal.mutate({ url: `${cloudUrl}/organization/create-project`, }); } @@ -99,12 +99,12 @@ export function ProjectSwitcher() { }; const handleOpenExternal = async (url: string) => { - await trpcClient.oauth.openExternalUrl.mutate({ url }); + await trpcClient.os.openExternal.mutate({ url }); setPopoverOpen(false); }; const handleDiscord = async () => { - await trpcClient.oauth.openExternalUrl.mutate({ + await trpcClient.os.openExternal.mutate({ url: "https://discord.gg/c3qYyJXSWp", }); setPopoverOpen(false); diff --git a/apps/code/src/renderer/utils/generateTitle.ts b/apps/code/src/renderer/utils/generateTitle.ts index 7ec56d031..dd3f33dcf 100644 --- a/apps/code/src/renderer/utils/generateTitle.ts +++ b/apps/code/src/renderer/utils/generateTitle.ts @@ -1,6 +1,5 @@ import { useAuthStore } from "@features/auth/stores/authStore"; import { trpcClient } from "@renderer/trpc"; -import { getCloudUrlFromRegion } from "@shared/constants/oauth"; import { logger } from "@utils/logger"; const log = logger.scope("title-generator"); @@ -43,14 +42,9 @@ Never wrap the title in quotes.`; export async function generateTitle(content: string): Promise { try { const authState = useAuthStore.getState(); - const apiKey = authState.oauthAccessToken; - const cloudRegion = authState.cloudRegion; - if (!apiKey || !cloudRegion) return null; - - const apiHost = getCloudUrlFromRegion(cloudRegion); + if (!authState.isAuthenticated) return null; const result = await trpcClient.llmGateway.prompt.mutate({ - credentials: { apiKey, apiHost }, system: SYSTEM_PROMPT, messages: [ { diff --git a/packages/agent/src/agent.ts b/packages/agent/src/agent.ts index d2adfd62c..26da11a88 100644 --- a/packages/agent/src/agent.ts +++ b/packages/agent/src/agent.ts @@ -45,17 +45,17 @@ export class Agent { } } - private _configureLlmGateway(overrideUrl?: string): { + private async _configureLlmGateway(overrideUrl?: string): Promise<{ gatewayUrl: string; apiKey: string; - } | null { + } | null> { if (!this.posthogAPI) { return null; } try { const gatewayUrl = overrideUrl ?? this.posthogAPI.getLlmGatewayUrl(); - const apiKey = this.posthogAPI.getApiKey(); + const apiKey = await this.posthogAPI.getApiKey(); process.env.OPENAI_BASE_URL = `${gatewayUrl}/v1`; process.env.OPENAI_API_KEY = apiKey; @@ -74,7 +74,7 @@ export class Agent { taskRunId: string, options: TaskExecutionOptions = {}, ): Promise { - const gatewayConfig = this._configureLlmGateway(options.gatewayUrl); + const gatewayConfig = await this._configureLlmGateway(options.gatewayUrl); this.logger.info("Configured LLM gateway", { adapter: options.adapter, }); diff --git a/packages/agent/src/posthog-api.test.ts b/packages/agent/src/posthog-api.test.ts new file mode 100644 index 000000000..4ab3f3bd2 --- /dev/null +++ b/packages/agent/src/posthog-api.test.ts @@ -0,0 +1,48 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { PostHogAPIClient } from "./posthog-api"; + +const mockFetch = vi.fn(); + +vi.stubGlobal("fetch", mockFetch); + +describe("PostHogAPIClient", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("refreshes once when fetching task run logs gets an auth failure", async () => { + const getApiKey = vi.fn().mockResolvedValue("stale-token"); + const refreshApiKey = vi.fn().mockResolvedValue("fresh-token"); + const client = new PostHogAPIClient({ + apiUrl: "https://app.posthog.com", + getApiKey, + refreshApiKey, + projectId: 1, + }); + + mockFetch + .mockResolvedValueOnce({ + ok: false, + status: 401, + statusText: "Unauthorized", + }) + .mockResolvedValueOnce({ + ok: true, + text: vi + .fn() + .mockResolvedValue( + `${JSON.stringify({ type: "notification", notification: { method: "foo" } })}\n`, + ), + }); + + const logs = await client.fetchTaskRunLogs({ + id: "run-1", + task: "task-1", + } as never); + + expect(logs).toHaveLength(1); + expect(getApiKey).toHaveBeenCalledTimes(1); + expect(refreshApiKey).toHaveBeenCalledTimes(1); + expect(mockFetch).toHaveBeenCalledTimes(2); + }); +}); diff --git a/packages/agent/src/posthog-api.ts b/packages/agent/src/posthog-api.ts index 75196b20c..c9d4b3e4f 100644 --- a/packages/agent/src/posthog-api.ts +++ b/packages/agent/src/posthog-api.ts @@ -47,27 +47,63 @@ export class PostHogAPIClient { return host; } - private get headers(): Record { - return { - Authorization: `Bearer ${this.config.getApiKey()}`, - "Content-Type": "application/json", - "User-Agent": this.config.userAgent ?? DEFAULT_USER_AGENT, - }; + private isAuthFailure(status: number): boolean { + return status === 401 || status === 403; } - private async apiRequest( + private async resolveApiKey(forceRefresh = false): Promise { + if (forceRefresh && this.config.refreshApiKey) { + return this.config.refreshApiKey(); + } + + return this.config.getApiKey(); + } + + private async buildHeaders( + options: RequestInit, + forceRefresh = false, + ): Promise { + const headers = new Headers(options.headers); + headers.set( + "Authorization", + `Bearer ${await this.resolveApiKey(forceRefresh)}`, + ); + headers.set("Content-Type", "application/json"); + headers.set("User-Agent", this.config.userAgent ?? DEFAULT_USER_AGENT); + return headers; + } + + private async performRequest( endpoint: string, - options: RequestInit = {}, - ): Promise { + options: RequestInit, + forceRefresh = false, + ): Promise { const url = `${this.baseUrl}${endpoint}`; - const response = await fetch(url, { + return fetch(url, { ...options, - headers: { - ...this.headers, - ...options.headers, - }, + headers: await this.buildHeaders(options, forceRefresh), }); + } + + private async performRequestWithRetry( + endpoint: string, + options: RequestInit = {}, + ): Promise { + let response = await this.performRequest(endpoint, options); + + if (!response.ok && this.isAuthFailure(response.status)) { + response = await this.performRequest(endpoint, options, true); + } + + return response; + } + + private async apiRequest( + endpoint: string, + options: RequestInit = {}, + ): Promise { + const response = await this.performRequestWithRetry(endpoint, options); if (!response.ok) { let errorMessage: string; @@ -87,8 +123,8 @@ export class PostHogAPIClient { return this.config.projectId; } - getApiKey(): string { - return this.config.getApiKey(); + async getApiKey(forceRefresh = false): Promise { + return this.resolveApiKey(forceRefresh); } getLlmGatewayUrl(): string { @@ -228,12 +264,10 @@ export class PostHogAPIClient { */ async fetchTaskRunLogs(taskRun: TaskRun): Promise { const teamId = this.getTeamId(); + const endpoint = `/api/projects/${teamId}/tasks/${taskRun.task}/runs/${taskRun.id}/logs`; try { - const response = await fetch( - `${this.baseUrl}/api/projects/${teamId}/tasks/${taskRun.task}/runs/${taskRun.id}/logs`, - { headers: this.headers }, - ); + const response = await this.performRequestWithRetry(endpoint); if (!response.ok) { if (response.status === 404) { diff --git a/packages/agent/src/types.ts b/packages/agent/src/types.ts index 8c137f63f..b463189e7 100644 --- a/packages/agent/src/types.ts +++ b/packages/agent/src/types.ts @@ -126,7 +126,8 @@ export type OnLogCallback = ( export interface PostHogAPIConfig { apiUrl: string; - getApiKey: () => string; + getApiKey: () => string | Promise; + refreshApiKey?: () => string | Promise; projectId: number; userAgent?: string; }