diff --git a/apps/code/src/main/deep-links.ts b/apps/code/src/main/deep-links.ts index 40559f11cb..d10e82dcda 100644 --- a/apps/code/src/main/deep-links.ts +++ b/apps/code/src/main/deep-links.ts @@ -1,7 +1,7 @@ import { getDeeplinkProtocol } from "@posthog/shared"; import { app } from "electron"; import { container } from "./di/container"; -import { MAIN_TOKENS } from "./di/tokens"; +import { DEEP_LINK_SERVICE } from "./di/tokens"; import type { DeepLinkService } from "./services/deep-link/service"; import { isDevBuild } from "./utils/env"; import { logger } from "./utils/logger"; @@ -12,7 +12,7 @@ const log = logger.scope("deep-links"); let pendingDeepLinkUrl: string | null = null; function getDeepLinkService(): DeepLinkService { - return container.get(MAIN_TOKENS.DeepLinkService); + return container.get(DEEP_LINK_SERVICE); } function findDeepLinkUrlInArgs(args: string[]): string | undefined { diff --git a/apps/code/src/main/di/container.ts b/apps/code/src/main/di/container.ts index cf712efdca..3c3a7b4e07 100644 --- a/apps/code/src/main/di/container.ts +++ b/apps/code/src/main/di/container.ts @@ -257,8 +257,10 @@ import { DATABASE_SERVICE as MAIN_DATABASE_SERVICE, DEEP_LINK_SERVICE as MAIN_DEEP_LINK_SERVICE, DEFAULT_ADDITIONAL_DIRECTORY_REPOSITORY as MAIN_DEFAULT_ADDITIONAL_DIRECTORY_REPOSITORY, + DISCORD_PRESENCE_SERVICE as MAIN_DISCORD_PRESENCE_SERVICE, ENCRYPTION_SERVICE as MAIN_ENCRYPTION_SERVICE, EXTERNAL_APPS_SERVICE as MAIN_EXTERNAL_APPS_SERVICE, + FILE_WATCHER_SERVICE as MAIN_FILE_WATCHER_SERVICE, FS_SERVICE as MAIN_FS_SERVICE, INBOX_LINK_SERVICE as MAIN_INBOX_LINK_SERVICE, LLM_GATEWAY_SERVICE as MAIN_LLM_GATEWAY_SERVICE, @@ -276,9 +278,9 @@ import { SUSPENSION_REPOSITORY as MAIN_SUSPENSION_REPOSITORY, SUSPENSION_SERVICE as MAIN_SUSPENSION_SERVICE, TASK_LINK_SERVICE as MAIN_TASK_LINK_SERVICE, - MAIN_TOKENS, UPDATES_SERVICE as MAIN_UPDATES_SERVICE, WATCHER_REGISTRY_SERVICE as MAIN_WATCHER_REGISTRY_SERVICE, + WORKSPACE_CLIENT as MAIN_WORKSPACE_CLIENT, WORKSPACE_REPOSITORY as MAIN_WORKSPACE_REPOSITORY, WORKSPACE_SERVER_SERVICE as MAIN_WORKSPACE_SERVER_SERVICE, WORKSPACE_SERVICE as MAIN_WORKSPACE_SERVICE, @@ -340,17 +342,17 @@ container .bind(AUTH_TOKEN_OVERRIDE) .toConstantValue(process.env.VITE_POSTHOG_ACCESS_TOKEN_OVERRIDE ?? null); container.bind(MAIN_AUTH_SERVICE).to(AuthService); -container.bind(AUTH_SERVICE).toService(MAIN_TOKENS.AuthService); +container.bind(AUTH_SERVICE).toService(MAIN_AUTH_SERVICE); container.load(authProxyModule); container.bind(AUTH_PROXY_AUTH).toDynamicValue((ctx) => ({ authenticatedFetch: (url: string, init?: RequestInit) => ctx - .get(MAIN_TOKENS.AuthService) + .get(MAIN_AUTH_SERVICE) .authenticatedFetch(fetch, url, init), })); container.load(mcpProxyModule); container.bind(MCP_PROXY_AUTH).toDynamicValue((ctx) => { - const auth = () => ctx.get(MAIN_TOKENS.AuthService); + const auth = () => ctx.get(MAIN_AUTH_SERVICE); return { authenticatedFetch: (url: string, init?: RequestInit) => auth().authenticatedFetch(fetch, url, init), @@ -365,7 +367,7 @@ container.bind(ARCHIVE_SESSION_CANCELLER).toDynamicValue((ctx) => ({ container.bind(ARCHIVE_FILE_WATCHER).toDynamicValue((ctx) => ({ stopWatching: async (worktreePath: string) => { ctx - .get(MAIN_TOKENS.FileWatcherService) + .get(MAIN_FILE_WATCHER_SERVICE) .stopWatching(worktreePath); }, })); @@ -377,7 +379,7 @@ container.bind(SUSPENSION_SESSION_CANCELLER).toDynamicValue((ctx) => ({ container.bind(SUSPENSION_FILE_WATCHER).toDynamicValue((ctx) => ({ stopWatching: async (worktreePath: string) => { ctx - .get(MAIN_TOKENS.FileWatcherService) + .get(MAIN_FILE_WATCHER_SERVICE) .stopWatching(worktreePath); }, })); @@ -387,20 +389,20 @@ container.load(cloudTaskModule); container.bind(CLOUD_TASK_AUTH).toDynamicValue((ctx) => ({ authenticatedFetch: (url: string, init?: RequestInit) => ctx - .get(MAIN_TOKENS.AuthService) + .get(MAIN_AUTH_SERVICE) .authenticatedFetch(fetch, url, init), })); container.bind(MAIN_CLOUD_TASK_SERVICE).toService(CLOUD_TASK_SERVICE); container.load(contextMenuCoreModule); container .bind(CONTEXT_MENU_EXTERNAL_APPS_SERVICE) - .toService(MAIN_TOKENS.ExternalAppsService); + .toService(MAIN_EXTERNAL_APPS_SERVICE); container.bind(MAIN_CONTEXT_MENU_SERVICE).toService(CONTEXT_MENU_CONTROLLER); container.bind(MAIN_DEEP_LINK_SERVICE).to(DeepLinkService); -container.bind(DEEP_LINK_SERVICE).toService(MAIN_TOKENS.DeepLinkService); +container.bind(DEEP_LINK_SERVICE).toService(MAIN_DEEP_LINK_SERVICE); container.load(enrichmentModule); container.bind(ENRICHMENT_AUTH).toDynamicValue((ctx) => { - const auth = () => ctx.get(MAIN_TOKENS.AuthService); + const auth = () => ctx.get(MAIN_AUTH_SERVICE); return { getState: () => { const state = auth().getState(); @@ -423,7 +425,7 @@ container.bind(ENRICHMENT_FILE_READER).toConstantValue({ listFilesContainingText(repoPath, text), }); container.bind(MAIN_PROVISIONING_SERVICE).to(ProvisioningService); -container.bind(PROVISIONING_SERVICE).toService(MAIN_TOKENS.ProvisioningService); +container.bind(PROVISIONING_SERVICE).toService(MAIN_PROVISIONING_SERVICE); const externalAppsPrefsStore = new ExternalAppsStoreImpl<{ externalAppsPrefs: ExternalAppsPreferences; @@ -441,7 +443,7 @@ container.load(externalAppsModule); container.bind(MAIN_EXTERNAL_APPS_SERVICE).toService(EXTERNAL_APPS_SERVICE); container.load(llmGatewayModule); container.bind(LLM_GATEWAY_HOST).toDynamicValue((ctx) => { - const auth = () => ctx.get(MAIN_TOKENS.AuthService); + const auth = () => ctx.get(MAIN_AUTH_SERVICE); return { getValidAccessToken: () => auth().getValidAccessToken(), authenticatedFetch: (url: string, init?: RequestInit) => @@ -460,9 +462,8 @@ container.bind(MAIN_MCP_APPS_SERVICE).toService(MCP_APPS_SERVICE); container.load(foldersModule); container.load(integrationsModule); container.load(gitPrModule); -container.bind(GIT_DIFF_SOURCE).toDynamicValue(() => { - const wsClient = () => - container.get(GIT_WORKSPACE_CLIENT); +container.bind(GIT_DIFF_SOURCE).toDynamicValue((ctx) => { + const wsClient = () => ctx.get(GIT_WORKSPACE_CLIENT); const git = () => wsClient().git; return { getStagedDiff: (directoryPath: string) => @@ -523,7 +524,7 @@ container container.load(handoffModule); container.bind(HANDOFF_HOST).to(HandoffHostService).inSingletonScope(); container.bind(HANDOFF_GIT_GATEWAY).toDynamicValue((ctx): HandoffGitGateway => { - const workspace = ctx.get(MAIN_TOKENS.WorkspaceClient); + const workspace = ctx.get(MAIN_WORKSPACE_CLIENT); return { async getChangedFiles(repoPath) { const files = await workspace.git.getChangedFilesHead.query({ @@ -548,7 +549,7 @@ container.bind(HANDOFF_GIT_GATEWAY).toDynamicValue((ctx): HandoffGitGateway => { }; }); container.bind(HANDOFF_LOG_GATEWAY).toDynamicValue((ctx) => { - const ws = ctx.get(MAIN_TOKENS.WorkspaceClient); + const ws = ctx.get(MAIN_WORKSPACE_CLIENT); return { seedLocalLogs: (taskRunId: string, content: string) => ws.localLogs.seed.mutate({ taskRunId, content }), @@ -584,26 +585,22 @@ container.load(skillsMarketplaceModule); container.load(onboardingImportModule); container.load(additionalDirectoriesModule); container.bind(MAIN_SLEEP_SERVICE).to(SleepService); -container.bind(SLEEP_SERVICE).toService(MAIN_TOKENS.SleepService); +container.bind(SLEEP_SERVICE).toService(MAIN_SLEEP_SERVICE); container.load(shellModule); container.load(uiModule); container.bind(UI_AUTH).toDynamicValue((ctx) => ({ invalidateAccessTokenForTest: () => - ctx - .get(MAIN_TOKENS.AuthService) - .invalidateAccessTokenForTest(), + ctx.get(MAIN_AUTH_SERVICE).invalidateAccessTokenForTest(), })); container.load(updatesCoreModule); -container - .bind(UPDATE_LIFECYCLE_SERVICE) - .toService(MAIN_TOKENS.AppLifecycleService); +container.bind(UPDATE_LIFECYCLE_SERVICE).toService(MAIN_APP_LIFECYCLE_SERVICE); container.bind(MAIN_UPDATES_SERVICE).toService(UPDATES_SERVICE); container.load(usageMonitorModule); container.bind(USAGE_HOST).toDynamicValue((ctx) => { const agent = () => ctx.get(AGENT_SERVICE); return { fetchUsage: () => - ctx.get(MAIN_TOKENS.LlmGatewayService).fetchUsage(), + ctx.get(MAIN_LLM_GATEWAY_SERVICE).fetchUsage(), onLlmActivity: (listener: () => void) => agent().on(AgentServiceEvent.LlmActivity, listener), offLlmActivity: (listener: () => void) => @@ -615,17 +612,15 @@ container.bind(USAGE_HOST).toDynamicValue((ctx) => { }; }); container.bind(MAIN_TASK_LINK_SERVICE).to(TaskLinkService); -container.bind(TASK_LINK_SERVICE).toService(MAIN_TOKENS.TaskLinkService); +container.bind(TASK_LINK_SERVICE).toService(MAIN_TASK_LINK_SERVICE); container.bind(MAIN_INBOX_LINK_SERVICE).to(InboxLinkService); -container.bind(INBOX_LINK_SERVICE).toService(MAIN_TOKENS.InboxLinkService); +container.bind(INBOX_LINK_SERVICE).toService(MAIN_INBOX_LINK_SERVICE); container.bind(MAIN_SCOUT_LINK_SERVICE).to(ScoutLinkService); -container.bind(SCOUT_LINK_SERVICE).toService(MAIN_TOKENS.ScoutLinkService); +container.bind(SCOUT_LINK_SERVICE).toService(MAIN_SCOUT_LINK_SERVICE); container.bind(MAIN_NEW_TASK_LINK_SERVICE).to(NewTaskLinkService); -container.bind(NEW_TASK_LINK_SERVICE).toService(MAIN_TOKENS.NewTaskLinkService); +container.bind(NEW_TASK_LINK_SERVICE).toService(MAIN_NEW_TASK_LINK_SERVICE); container.bind(MAIN_APPROVAL_LINK_SERVICE).to(ApprovalLinkService); -container - .bind(APPROVAL_LINK_SERVICE) - .toService(MAIN_TOKENS.ApprovalLinkService); +container.bind(APPROVAL_LINK_SERVICE).toService(MAIN_APPROVAL_LINK_SERVICE); container.load(watcherRegistryModule); container .bind(MAIN_WATCHER_REGISTRY_SERVICE) @@ -642,9 +637,7 @@ container.bind(WORKSPACE_AGENT).toDynamicValue((ctx): WorkspaceAgent => { container .bind(WORKSPACE_FILE_WATCHER) .toDynamicValue((ctx): WorkspaceFileWatcher => { - const fileWatcher = ctx.get( - MAIN_TOKENS.FileWatcherService, - ); + const fileWatcher = ctx.get(MAIN_FILE_WATCHER_SERVICE); return { stopWatching: async (worktreePath) => { fileWatcher.stopWatching(worktreePath); @@ -666,7 +659,7 @@ container .bind(WORKSPACE_PROVISIONING) .toDynamicValue((ctx): WorkspaceProvisioning => { const provisioning = ctx.get( - MAIN_TOKENS.ProvisioningService, + MAIN_PROVISIONING_SERVICE, ); return { emitOutput: (taskId, data) => provisioning.emitOutput(taskId, data), @@ -685,9 +678,9 @@ container .bind(MAIN_SECURE_STORE_SERVICE) .to(SecureStoreService) .inSingletonScope(); -container.bind(SECURE_STORE_SERVICE).toService(MAIN_TOKENS.SecureStoreService); +container.bind(SECURE_STORE_SERVICE).toService(MAIN_SECURE_STORE_SERVICE); container.bind(LOGS_SERVICE).toDynamicValue((ctx) => { - const ws = ctx.get(MAIN_TOKENS.WorkspaceClient); + const ws = ctx.get(MAIN_WORKSPACE_CLIENT); return { fetchS3Logs: async (logUrl: string) => { try { @@ -706,7 +699,7 @@ container.bind(LOGS_SERVICE).toDynamicValue((ctx) => { }; }); container.bind(MAIN_ENCRYPTION_SERVICE).to(EncryptionService); -container.bind(MAIN_TOKENS.DiscordPresenceService).to(DiscordPresenceService); +container.bind(MAIN_DISCORD_PRESENCE_SERVICE).to(DiscordPresenceService); // Canvas / dashboards (project-bluebird). The host-agnostic dashboard services // live in @posthog/core (bound via canvasCoreModule) and resolve through diff --git a/apps/code/src/main/di/tokens.ts b/apps/code/src/main/di/tokens.ts index d653053eee..fc491b27e3 100644 --- a/apps/code/src/main/di/tokens.ts +++ b/apps/code/src/main/di/tokens.ts @@ -121,50 +121,3 @@ export const WORKSPACE_SERVER_SERVICE = Symbol.for( export const DISCORD_PRESENCE_SERVICE = Symbol.for( "posthog.host.main.discord-presence.service", ); - -export const MAIN_TOKENS = Object.freeze({ - WorkspaceClient: WORKSPACE_CLIENT, - - SettingsStore: SETTINGS_STORE, - SecureStoreService: SECURE_STORE_SERVICE, - SecureStoreBackend: SECURE_STORE_BACKEND, - EncryptionService: ENCRYPTION_SERVICE, - - AuthPreferenceRepository: AUTH_PREFERENCE_REPOSITORY, - DatabaseService: DATABASE_SERVICE, - AuthSessionRepository: AUTH_SESSION_REPOSITORY, - RepositoryRepository: REPOSITORY_REPOSITORY, - WorkspaceRepository: WORKSPACE_REPOSITORY, - WorktreeRepository: WORKTREE_REPOSITORY, - ArchiveRepository: ARCHIVE_REPOSITORY, - SuspensionRepository: SUSPENSION_REPOSITORY, - DefaultAdditionalDirectoryRepository: DEFAULT_ADDITIONAL_DIRECTORY_REPOSITORY, - - AuthService: AUTH_SERVICE, - SuspensionService: SUSPENSION_SERVICE, - AppLifecycleService: APP_LIFECYCLE_SERVICE, - CloudTaskService: CLOUD_TASK_SERVICE, - ContextMenuService: CONTEXT_MENU_SERVICE, - DiscordPresenceService: DISCORD_PRESENCE_SERVICE, - - ExternalAppsService: EXTERNAL_APPS_SERVICE, - LlmGatewayService: LLM_GATEWAY_SERVICE, - McpAppsService: MCP_APPS_SERVICE, - FileWatcherService: FILE_WATCHER_SERVICE, - FsService: FS_SERVICE, - GitService: GIT_SERVICE, - DeepLinkService: DEEP_LINK_SERVICE, - ProcessTrackingService: PROCESS_TRACKING_SERVICE, - SleepService: SLEEP_SERVICE, - PosthogPluginService: POSTHOG_PLUGIN_SERVICE, - UpdatesService: UPDATES_SERVICE, - TaskLinkService: TASK_LINK_SERVICE, - InboxLinkService: INBOX_LINK_SERVICE, - ScoutLinkService: SCOUT_LINK_SERVICE, - NewTaskLinkService: NEW_TASK_LINK_SERVICE, - ApprovalLinkService: APPROVAL_LINK_SERVICE, - WatcherRegistryService: WATCHER_REGISTRY_SERVICE, - ProvisioningService: PROVISIONING_SERVICE, - WorkspaceService: WORKSPACE_SERVICE, - WorkspaceServerService: WORKSPACE_SERVER_SERVICE, -}); diff --git a/apps/code/src/main/index.ts b/apps/code/src/main/index.ts index cb6e30c2fe..74ce1d7553 100644 --- a/apps/code/src/main/index.ts +++ b/apps/code/src/main/index.ts @@ -48,7 +48,25 @@ import type { SuspensionService } from "@posthog/workspace-server/services/suspe import type { WorkspaceService } from "@posthog/workspace-server/services/workspace/workspace"; import { initializeDeepLinks, registerDeepLinkHandlers } from "./deep-links"; import { container } from "./di/container"; -import { MAIN_TOKENS } from "./di/tokens"; +import { + APP_LIFECYCLE_SERVICE, + APPROVAL_LINK_SERVICE, + AUTH_SERVICE, + DATABASE_SERVICE, + DISCORD_PRESENCE_SERVICE, + EXTERNAL_APPS_SERVICE, + FILE_WATCHER_SERVICE, + INBOX_LINK_SERVICE, + FS_SERVICE as MAIN_FS_SERVICE, + NEW_TASK_LINK_SERVICE, + POSTHOG_PLUGIN_SERVICE, + SCOUT_LINK_SERVICE, + TASK_LINK_SERVICE, + UPDATES_SERVICE, + WORKSPACE_CLIENT, + WORKSPACE_SERVER_SERVICE, + WORKSPACE_SERVICE, +} from "./di/tokens"; import { posthogNodeAnalytics } from "./platform-adapters/posthog-analytics"; import { registerMcpSandboxProtocol } from "./protocols/mcp-sandbox"; import type { AppLifecycleService } from "./services/app-lifecycle/service"; @@ -219,29 +237,27 @@ app.on("child-process-gone", (_event, details) => { }); async function initializeServices(): Promise { - container.get(MAIN_TOKENS.DatabaseService); + container.get(DATABASE_SERVICE); container.get(OAUTH_SERVICE); - const authService = container.get(MAIN_TOKENS.AuthService); + const authService = container.get(AUTH_SERVICE); container.get(NOTIFICATION_SERVICE); - container.get(MAIN_TOKENS.UpdatesService); - container.get(MAIN_TOKENS.TaskLinkService); - container.get(MAIN_TOKENS.InboxLinkService); - container.get(MAIN_TOKENS.ScoutLinkService); - container.get(MAIN_TOKENS.NewTaskLinkService); - container.get(MAIN_TOKENS.ApprovalLinkService); + container.get(UPDATES_SERVICE); + container.get(TASK_LINK_SERVICE); + container.get(INBOX_LINK_SERVICE); + container.get(SCOUT_LINK_SERVICE); + container.get(NEW_TASK_LINK_SERVICE); + container.get(APPROVAL_LINK_SERVICE); container.get(GITHUB_INTEGRATION_SERVICE); container.get(SLACK_INTEGRATION_SERVICE); - container.get(MAIN_TOKENS.ExternalAppsService); - container.get(MAIN_TOKENS.PosthogPluginService); + container.get(EXTERNAL_APPS_SERVICE); + container.get(POSTHOG_PLUGIN_SERVICE); // Eagerly start the Discord presence service so it connects when enabled. - container.get(MAIN_TOKENS.DiscordPresenceService); + container.get(DISCORD_PRESENCE_SERVICE); await authService.initialize(); // Initialize workspace branch watcher for live branch rename detection - const workspaceService = container.get( - MAIN_TOKENS.WorkspaceService, - ); + const workspaceService = container.get(WORKSPACE_SERVICE); workspaceService.initBranchWatcher(); const suspensionService = @@ -310,18 +326,16 @@ app.whenReady().then(async () => { createWindow(); const wsServer = container.get( - MAIN_TOKENS.WorkspaceServerService, + WORKSPACE_SERVER_SERVICE, ); const connection = await wsServer.start(); const workspaceClient = createWorkspaceClient(connection); - container.bind(MAIN_TOKENS.WorkspaceClient).toConstantValue(workspaceClient); + container.bind(WORKSPACE_CLIENT).toConstantValue(workspaceClient); container.bind(GIT_WORKSPACE_CLIENT).toConstantValue(workspaceClient); container.bind(CONNECTIVITY_CLIENT).toConstantValue(workspaceClient); container.bind(ENVIRONMENT_CLIENT).toConstantValue(workspaceClient); const fileWatcherBridge = new FileWatcherBridge(workspaceClient); - container - .bind(MAIN_TOKENS.FileWatcherService) - .toConstantValue(fileWatcherBridge); + container.bind(FILE_WATCHER_SERVICE).toConstantValue(fileWatcherBridge); container.bind(FILE_WATCHER_CONTROL).toConstantValue(fileWatcherBridge); container.bind(FOCUS_WORKSPACE_CLIENT).toConstantValue(workspaceClient); container.bind(FOCUS_SESSION_STORE).toConstantValue(focusSessionStore); @@ -358,8 +372,8 @@ app.whenReady().then(async () => { }); }, }; - container.bind(MAIN_TOKENS.FsService).toConstantValue(fsCapability); - container.bind(FS_SERVICE).toService(MAIN_TOKENS.FsService); + container.bind(MAIN_FS_SERVICE).toConstantValue(fsCapability); + container.bind(FS_SERVICE).toService(MAIN_FS_SERVICE); await initializeServices(); initializeDeepLinks(); }); @@ -368,16 +382,22 @@ app.on("window-all-closed", () => { app.quit(); }); +const teardownContainer = async (): Promise => { + try { + await container.unbindAll(); + } catch (error) { + log.warn("Failed to unbind container", error); + } +}; + app.on("before-quit", async (event) => { try { - container - .get(MAIN_TOKENS.WorkspaceServerService) - .stop(); + container.get(WORKSPACE_SERVER_SERVICE).stop(); } catch {} let lifecycleService: AppLifecycleService; try { lifecycleService = container.get( - MAIN_TOKENS.AppLifecycleService, + APP_LIFECYCLE_SERVICE, ); } catch { // Container already torn down (e.g. second quit during shutdown), let Electron quit @@ -397,20 +417,21 @@ app.on("before-quit", async (event) => { event.preventDefault(); - await lifecycleService.gracefulExit(); + await lifecycleService.gracefulExit(teardownContainer); }); const handleShutdownSignal = async (signal: string) => { log.info(`Received ${signal}, starting shutdown`); try { const lifecycleService = container.get( - MAIN_TOKENS.AppLifecycleService, + APP_LIFECYCLE_SERVICE, ); if (lifecycleService.isShuttingDown) { log.warn(`${signal} received during shutdown, forcing exit`); process.exit(1); } await lifecycleService.shutdown(); + await teardownContainer(); } catch (_err) { // Container torn down or shutdown failed } diff --git a/apps/code/src/main/menu.ts b/apps/code/src/main/menu.ts index be80db42b3..29276b4ecd 100644 --- a/apps/code/src/main/menu.ts +++ b/apps/code/src/main/menu.ts @@ -17,7 +17,7 @@ import { shell, } from "electron"; import { container } from "./di/container"; -import { MAIN_TOKENS } from "./di/tokens"; +import { AUTH_SERVICE, UPDATES_SERVICE } from "./di/tokens"; import { isDevBuild } from "./utils/env"; import { getLogFilePath } from "./utils/logger"; @@ -113,7 +113,7 @@ function buildAppMenu(): MenuItemConstructorOptions { label: "Check for Updates...", click: () => { container - .get(MAIN_TOKENS.UpdatesService) + .get(UPDATES_SERVICE) .triggerMenuCheck(); }, }, @@ -222,7 +222,7 @@ function buildFileMenu(): MenuItemConstructorOptions { label: "Force refresh of OAuth token", click: () => { container - .get(MAIN_TOKENS.AuthService) + .get(AUTH_SERVICE) .refreshAccessToken() .then(() => { dialog.showMessageBox({ diff --git a/apps/code/src/main/services/app-lifecycle/service.test.ts b/apps/code/src/main/services/app-lifecycle/service.test.ts index ddfa5dc576..217e8cd992 100644 --- a/apps/code/src/main/services/app-lifecycle/service.test.ts +++ b/apps/code/src/main/services/app-lifecycle/service.test.ts @@ -4,7 +4,6 @@ import { AppLifecycleService } from "./service"; const { mockAppLifecycle, - mockContainer, mockDatabaseService, mockSuspensionService, mockWatcherRegistry, @@ -40,10 +39,6 @@ const { onQuit: vi.fn(() => () => {}), registerDeepLinkScheme: vi.fn(), }, - mockContainer: { - unbindAll: vi.fn(() => Promise.resolve()), - get: vi.fn(() => mockDatabaseService), - }, mockDatabaseService, mockTrackAppEvent: vi.fn(), mockShutdownPostHog: vi.fn(() => Promise.resolve()), @@ -74,10 +69,6 @@ vi.mock("../../platform-adapters/posthog-analytics.js", () => ({ }, })); -vi.mock("../../di/container.js", () => ({ - container: mockContainer, -})); - vi.mock("@posthog/shared/analytics-events", () => ({ ANALYTICS_EVENTS: { APP_QUIT: "app_quit", @@ -137,13 +128,6 @@ describe("AppLifecycleService", () => { }); describe("shutdown", () => { - it("unbinds all container services", async () => { - const promise = service.shutdown(); - await vi.runAllTimersAsync(); - await promise; - expect(mockContainer.unbindAll).toHaveBeenCalled(); - }); - it("tracks app quit event", async () => { const promise = service.shutdown(); await vi.runAllTimersAsync(); @@ -164,9 +148,6 @@ describe("AppLifecycleService", () => { mockDatabaseService.close.mockImplementation(() => { callOrder.push("dbClose"); }); - mockContainer.unbindAll.mockImplementation(async () => { - callOrder.push("unbindAll"); - }); mockTrackAppEvent.mockImplementation(() => { callOrder.push("trackAppEvent"); }); @@ -183,7 +164,6 @@ describe("AppLifecycleService", () => { expect(callOrder).toEqual([ "dbClose", - "unbindAll", "trackAppEvent", "shutdownOtelTransport", "shutdownPostHog", @@ -197,17 +177,6 @@ describe("AppLifecycleService", () => { expect(mockDatabaseService.close).toHaveBeenCalled(); }); - it("continues shutdown if container unbind fails", async () => { - mockContainer.unbindAll.mockRejectedValue(new Error("unbind failed")); - - const promise = service.shutdown(); - await vi.runAllTimersAsync(); - await promise; - - expect(mockTrackAppEvent).toHaveBeenCalled(); - expect(mockShutdownPostHog).toHaveBeenCalled(); - }); - it("continues shutdown if PostHog shutdown fails", async () => { mockShutdownPostHog.mockRejectedValue(new Error("posthog failed")); @@ -225,7 +194,7 @@ describe("AppLifecycleService", () => { }); it("force-exits when shutdown times out", async () => { - mockContainer.unbindAll.mockReturnValue(new Promise(() => {})); + mockShutdownOtelTransport.mockReturnValue(new Promise(() => {})); const promise = service.shutdown(); @@ -243,9 +212,6 @@ describe("AppLifecycleService", () => { mockDatabaseService.close.mockImplementation(() => { callOrder.push("dbClose"); }); - mockContainer.unbindAll.mockImplementation(async () => { - callOrder.push("unbindAll"); - }); mockAppLifecycle.exit.mockImplementation(() => { callOrder.push("exit"); }); @@ -264,5 +230,26 @@ describe("AppLifecycleService", () => { await promise; expect(mockAppLifecycle.exit).toHaveBeenCalledWith(0); }); + + it("runs the beforeExit hook after shutdown and before exit", async () => { + const callOrder: string[] = []; + + mockDatabaseService.close.mockImplementation(() => { + callOrder.push("dbClose"); + }); + mockAppLifecycle.exit.mockImplementation(() => { + callOrder.push("exit"); + }); + const beforeExit = vi.fn(async () => { + callOrder.push("beforeExit"); + }); + + const promise = service.gracefulExit(beforeExit); + await vi.runAllTimersAsync(); + await promise; + + expect(beforeExit).toHaveBeenCalledTimes(1); + expect(callOrder).toEqual(["dbClose", "beforeExit", "exit"]); + }); }); }); diff --git a/apps/code/src/main/services/app-lifecycle/service.ts b/apps/code/src/main/services/app-lifecycle/service.ts index cb111c0c1a..9037a7e966 100644 --- a/apps/code/src/main/services/app-lifecycle/service.ts +++ b/apps/code/src/main/services/app-lifecycle/service.ts @@ -11,8 +11,7 @@ import { SUSPENSION_SERVICE } from "@posthog/workspace-server/services/suspensio import type { SuspensionService } from "@posthog/workspace-server/services/suspension/suspension"; import type { WatcherRegistryService } from "@posthog/workspace-server/services/watcher-registry/watcher-registry"; import { inject, injectable } from "inversify"; -import { container } from "../../di/container"; -import { MAIN_TOKENS } from "../../di/tokens"; +import { WATCHER_REGISTRY_SERVICE } from "../../di/tokens"; import { posthogNodeAnalytics } from "../../platform-adapters/posthog-analytics"; import { withTimeout } from "../../utils/async"; import { logger } from "../../utils/logger"; @@ -34,7 +33,7 @@ export class AppLifecycleService { private readonly db: DatabaseService, @inject(SUSPENSION_SERVICE) private readonly suspensionService: SuspensionService, - @inject(MAIN_TOKENS.WatcherRegistryService) + @inject(WATCHER_REGISTRY_SERVICE) private readonly watcherRegistry: WatcherRegistryService, @inject(PROCESS_TRACKING_SERVICE) private readonly processTracking: ProcessTrackingService, @@ -103,15 +102,20 @@ export class AppLifecycleService { } /** - * Runs a full shutdown then exits the Electron app. + * Runs a full shutdown then exits the Electron app. The optional + * `beforeExit` hook lets the composition root tear down the DI container + * after shutdown completes but before the process exits. */ - async gracefulExit(): Promise { + async gracefulExit(beforeExit?: () => Promise): Promise { await this.shutdown(); + if (beforeExit) { + await beforeExit(); + } this.appLifecycle.exit(0); } /** - * Runs the full shutdown sequence: native resources, container, analytics. + * Runs the full shutdown sequence: native resources, database, analytics. */ private async doShutdown(): Promise { log.info("Shutdown started"); @@ -130,12 +134,6 @@ export class AppLifecycleService { log.warn("Failed to close database during shutdown", error); } - try { - await container.unbindAll(); - } catch (error) { - log.warn("Failed to unbind container", error); - } - posthogNodeAnalytics.track(ANALYTICS_EVENTS.APP_QUIT); try { diff --git a/apps/code/src/main/services/auth/port-adapters.ts b/apps/code/src/main/services/auth/port-adapters.ts index 32cfbb4d7d..9747fb9423 100644 --- a/apps/code/src/main/services/auth/port-adapters.ts +++ b/apps/code/src/main/services/auth/port-adapters.ts @@ -22,7 +22,11 @@ import type { WorkspaceClient } from "@posthog/workspace-client/client"; import type { IAuthPreferenceRepository } from "@posthog/workspace-server/db/repositories/auth-preference-repository"; import type { IAuthSessionRepository } from "@posthog/workspace-server/db/repositories/auth-session-repository"; import { inject, injectable } from "inversify"; -import { MAIN_TOKENS } from "../../di/tokens"; +import { + AUTH_PREFERENCE_REPOSITORY, + AUTH_SESSION_REPOSITORY, + WORKSPACE_CLIENT, +} from "../../di/tokens"; import { decrypt, encrypt } from "../../utils/encryption"; @injectable() @@ -66,7 +70,7 @@ export class OAuthFlowPortAdapter implements IAuthOAuthFlowService { @injectable() export class AuthSessionPortAdapter implements IAuthSessionStore { constructor( - @inject(MAIN_TOKENS.AuthSessionRepository) + @inject(AUTH_SESSION_REPOSITORY) private readonly repository: IAuthSessionRepository, ) {} @@ -95,7 +99,7 @@ export class AuthSessionPortAdapter implements IAuthSessionStore { @injectable() export class AuthPreferencePortAdapter implements IAuthPreferenceStore { constructor( - @inject(MAIN_TOKENS.AuthPreferenceRepository) + @inject(AUTH_PREFERENCE_REPOSITORY) private readonly repository: IAuthPreferenceRepository, ) {} @@ -147,7 +151,7 @@ export class ConnectivityPortAdapter implements IAuthConnectivity { private readonly handlers = new Set<(status: ConnectivityStatus) => void>(); constructor( - @inject(MAIN_TOKENS.WorkspaceClient) + @inject(WORKSPACE_CLIENT) private readonly workspace: WorkspaceClient, ) { this.workspace.connectivity.onStatusChange.subscribe(undefined, { diff --git a/apps/code/src/main/services/secure-store/service.ts b/apps/code/src/main/services/secure-store/service.ts index 8bfec2beea..297b688e2d 100644 --- a/apps/code/src/main/services/secure-store/service.ts +++ b/apps/code/src/main/services/secure-store/service.ts @@ -1,4 +1,4 @@ -import { MAIN_TOKENS } from "@main/di/tokens"; +import { SECURE_STORE_BACKEND } from "@main/di/tokens"; import { decrypt, encrypt } from "@main/utils/encryption"; import { logger } from "@main/utils/logger"; import { inject, injectable } from "inversify"; @@ -28,7 +28,7 @@ export interface SecureStoreBackend { @injectable() export class SecureStoreService { constructor( - @inject(MAIN_TOKENS.SecureStoreBackend) + @inject(SECURE_STORE_BACKEND) private readonly store: SecureStoreBackend, ) {} diff --git a/apps/code/src/main/services/workspace-server/service.test.ts b/apps/code/src/main/services/workspace-server/service.test.ts new file mode 100644 index 0000000000..8f9f69afd4 --- /dev/null +++ b/apps/code/src/main/services/workspace-server/service.test.ts @@ -0,0 +1,157 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +vi.mock("../../utils/logger.js", () => ({ + logger: { + scope: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), + }, +})); + +import type { WorkspaceConnection } from "@posthog/workspace-client/client"; +import { + WorkspaceServerEvent, + WorkspaceServerService, + WorkspaceServerStatus, +} from "./service"; + +const CONNECTION: WorkspaceConnection = { + url: "http://127.0.0.1:9999", + secret: "test-secret", +}; + +type Internals = { + spawnChild: () => Promise; + connection: WorkspaceConnection | null; +}; + +function internals(service: WorkspaceServerService): Internals { + return service as unknown as Internals; +} + +function withHealthySpawn(service: WorkspaceServerService) { + const spawn = vi.fn(async () => { + // Defer like the real spawnChild, which only sets the connection after the + // async health poll, so concurrent start() callers coalesce on pendingStart. + await Promise.resolve(); + internals(service).connection = CONNECTION; + return CONNECTION; + }); + internals(service).spawnChild = spawn; + return spawn; +} + +function withFailingSpawn(service: WorkspaceServerService) { + const spawn = vi.fn(async (): Promise => { + throw new Error("unhealthy"); + }); + internals(service).spawnChild = spawn; + return spawn; +} + +function trackStatuses( + service: WorkspaceServerService, +): WorkspaceServerStatus[] { + const statuses: WorkspaceServerStatus[] = []; + service.on(WorkspaceServerEvent.StatusChanged, (event) => { + statuses.push(event.status); + }); + return statuses; +} + +describe("WorkspaceServerService", () => { + afterEach(() => { + vi.useRealTimers(); + vi.restoreAllMocks(); + }); + + describe("start", () => { + it("transitions idle -> starting -> ready and exposes the connection", async () => { + const service = new WorkspaceServerService(); + withHealthySpawn(service); + const statuses = trackStatuses(service); + + const result = await service.start(); + + expect(result).toEqual(CONNECTION); + expect(service.getConnection()).toEqual(CONNECTION); + expect(service.getStatus()).toBe(WorkspaceServerStatus.Ready); + expect(statuses).toEqual([ + WorkspaceServerStatus.Starting, + WorkspaceServerStatus.Ready, + ]); + expect(service.getStatusSnapshot()).toEqual({ + status: WorkspaceServerStatus.Ready, + attempt: 0, + }); + }); + + it("coalesces concurrent callers and does not respawn once connected", async () => { + const service = new WorkspaceServerService(); + const spawn = withHealthySpawn(service); + + const first = service.start(); + const second = service.start(); + expect(first).toBe(second); + await first; + + await expect(service.start()).resolves.toEqual(CONNECTION); + expect(spawn).toHaveBeenCalledTimes(1); + }); + }); + + describe("supervised restart", () => { + it("backs off, caps the attempts, then settles in failed", async () => { + vi.useFakeTimers(); + const service = new WorkspaceServerService(); + const spawn = withFailingSpawn(service); + const statuses = trackStatuses(service); + + service.start().catch(() => {}); + await vi.runAllTimersAsync(); + + expect(service.getStatus()).toBe(WorkspaceServerStatus.Failed); + // initial attempt + MAX_RESTART_ATTEMPTS (5) supervised retries + expect(spawn).toHaveBeenCalledTimes(6); + expect(statuses[0]).toBe(WorkspaceServerStatus.Starting); + expect(statuses).toContain(WorkspaceServerStatus.Retrying); + expect(statuses[statuses.length - 1]).toBe(WorkspaceServerStatus.Failed); + }); + + it("restart() resets the attempt budget after a failure", async () => { + vi.useFakeTimers(); + const service = new WorkspaceServerService(); + const spawn = withFailingSpawn(service); + + service.start().catch(() => {}); + await vi.runAllTimersAsync(); + expect(service.getStatus()).toBe(WorkspaceServerStatus.Failed); + + spawn.mockImplementation(async () => { + internals(service).connection = CONNECTION; + return CONNECTION; + }); + const result = await service.restart(); + + expect(result).toEqual(CONNECTION); + expect(service.getStatus()).toBe(WorkspaceServerStatus.Ready); + expect(service.getStatusSnapshot().attempt).toBe(0); + }); + }); + + describe("stop", () => { + it("goes idle, clears the connection and suppresses restarts", async () => { + const service = new WorkspaceServerService(); + withHealthySpawn(service); + await service.start(); + + service.stop(); + + expect(service.getStatus()).toBe(WorkspaceServerStatus.Idle); + expect(service.getConnection()).toBeNull(); + }); + }); +}); diff --git a/apps/code/src/main/services/workspace-server/service.ts b/apps/code/src/main/services/workspace-server/service.ts index ecbd7e402a..c1e0cb4a98 100644 --- a/apps/code/src/main/services/workspace-server/service.ts +++ b/apps/code/src/main/services/workspace-server/service.ts @@ -11,10 +11,26 @@ const HEALTH_POLL_INTERVAL_MS = 100; const HEALTH_POLL_TIMEOUT_MS = 5_000; const SHUTDOWN_GRACE_MS = 3_000; +const MAX_RESTART_ATTEMPTS = 5; +const RESTART_BASE_DELAY_MS = 500; +const RESTART_MAX_DELAY_MS = 30_000; + const log = logger.scope("workspace-server"); +export const WorkspaceServerStatus = { + Idle: "idle", + Starting: "starting", + Ready: "ready", + Retrying: "retrying", + Failed: "failed", +} as const; + +export type WorkspaceServerStatus = + (typeof WorkspaceServerStatus)[keyof typeof WorkspaceServerStatus]; + export const WorkspaceServerEvent = { ConnectionLost: "connectionLost", + StatusChanged: "statusChanged", } as const; export interface WorkspaceServerEvents { @@ -22,6 +38,11 @@ export interface WorkspaceServerEvents { code: number | null; signal: NodeJS.Signals | null; }; + [WorkspaceServerEvent.StatusChanged]: { + status: WorkspaceServerStatus; + attempt: number; + error?: string; + }; } @injectable() @@ -30,26 +51,125 @@ export class WorkspaceServerService extends TypedEventEmitter | null = null; + private status: WorkspaceServerStatus = WorkspaceServerStatus.Idle; + private restartAttempts = 0; + private restartTimer: NodeJS.Timeout | null = null; + private stopping = false; getConnection(): WorkspaceConnection | null { return this.connection; } + getStatus(): WorkspaceServerStatus { + return this.status; + } + + getStatusSnapshot(): { status: WorkspaceServerStatus; attempt: number } { + return { status: this.status, attempt: this.restartAttempts }; + } + start(): Promise { if (this.connection) return Promise.resolve(this.connection); if (this.pendingStart) return this.pendingStart; - this.pendingStart = this.spawnChild().finally(() => { - this.pendingStart = null; - }); + this.stopping = false; + this.clearRestartTimer(); + this.pendingStart = this.runStart(); return this.pendingStart; } stop(): void { - if (!this.child) return; + this.stopping = true; + this.clearRestartTimer(); + this.restartAttempts = 0; + this.setStatus(WorkspaceServerStatus.Idle); + const c = this.child; this.child = null; this.connection = null; + if (c) this.killChild(c); + } + + /** + * User-initiated restart, e.g. the renderer "Retry" action from the failed + * state. Resets the attempt budget so the supervisor gets a fresh set of + * retries, unlike the automatic restart path which keeps counting toward the + * cap. + */ + restart(): Promise { + this.stopping = false; + this.clearRestartTimer(); + this.restartAttempts = 0; + return this.start(); + } + + private async runStart(): Promise { + if (this.restartAttempts === 0) { + this.setStatus(WorkspaceServerStatus.Starting); + } + try { + const connection = await this.spawnChild(); + this.restartAttempts = 0; + this.pendingStart = null; + this.setStatus(WorkspaceServerStatus.Ready); + return connection; + } catch (error) { + this.pendingStart = null; + this.scheduleRestart(error); + throw error; + } + } + + /** + * Supervises restarts after an unexpected child exit or a failed start. + * Backs off exponentially and gives up after MAX_RESTART_ATTEMPTS, leaving + * the service in the "failed" state for the renderer to surface. The attempt + * budget resets only once a child becomes healthy again. + */ + private scheduleRestart(error?: unknown): void { + if (this.stopping) return; + if (this.pendingStart || this.restartTimer) return; + + if (this.restartAttempts >= MAX_RESTART_ATTEMPTS) { + this.setStatus(WorkspaceServerStatus.Failed, errorMessage(error)); + return; + } + + this.restartAttempts++; + const delay = Math.min( + RESTART_BASE_DELAY_MS * 2 ** (this.restartAttempts - 1), + RESTART_MAX_DELAY_MS, + ); + this.setStatus(WorkspaceServerStatus.Retrying, errorMessage(error)); + log.info("scheduling workspace-server restart", { + attempt: this.restartAttempts, + delayMs: delay, + }); + this.restartTimer = setTimeout(() => { + this.restartTimer = null; + // A failed restart re-enters scheduleRestart through runStart's catch. + void this.start().catch(() => {}); + }, delay); + this.restartTimer.unref(); + } + + private clearRestartTimer(): void { + if (this.restartTimer) { + clearTimeout(this.restartTimer); + this.restartTimer = null; + } + } + + private setStatus(status: WorkspaceServerStatus, error?: string): void { + this.status = status; + this.emit(WorkspaceServerEvent.StatusChanged, { + status, + attempt: this.restartAttempts, + error, + }); + } + + private killChild(c: ChildProcess): void { try { c.kill("SIGTERM"); } catch {} @@ -81,19 +201,22 @@ export class WorkspaceServerService extends TypedEventEmitter process.stdout.write(chunk)); c.stderr?.on("data", (chunk) => process.stderr.write(chunk)); c.once("exit", (code, signal) => { + if (this.child !== c) return; const wasConnected = this.connection !== null; this.child = null; this.connection = null; log.info("child exited", { code, signal }); - if (wasConnected) { + if (wasConnected && !this.stopping) { this.emit(WorkspaceServerEvent.ConnectionLost, { code, signal }); + this.scheduleRestart(); } }); this.child = c; if (!(await pollHealth(url))) { - this.stop(); + this.child = null; + this.killChild(c); throw new Error( `workspace-server failed to become healthy within ${HEALTH_POLL_TIMEOUT_MS}ms`, ); @@ -104,6 +227,10 @@ export class WorkspaceServerService extends TypedEventEmitter { return new Promise((resolve, reject) => { const s = createServer(); diff --git a/apps/code/src/main/trpc/routers/discord-presence.ts b/apps/code/src/main/trpc/routers/discord-presence.ts index bd120fb98e..2344726841 100644 --- a/apps/code/src/main/trpc/routers/discord-presence.ts +++ b/apps/code/src/main/trpc/routers/discord-presence.ts @@ -1,6 +1,6 @@ import { z } from "zod"; import { container } from "../../di/container"; -import { MAIN_TOKENS } from "../../di/tokens"; +import { DISCORD_PRESENCE_SERVICE } from "../../di/tokens"; import { DiscordPresenceServiceEvent, discordPresenceStateSchema, @@ -10,7 +10,7 @@ import type { DiscordPresenceService } from "../../services/discord-presence/ser import { publicProcedure, router } from "../trpc"; const getService = () => - container.get(MAIN_TOKENS.DiscordPresenceService); + container.get(DISCORD_PRESENCE_SERVICE); export const discordPresenceRouter = router({ getState: publicProcedure diff --git a/apps/code/src/main/trpc/routers/encryption.ts b/apps/code/src/main/trpc/routers/encryption.ts index 9f423e170e..6f28546eda 100644 --- a/apps/code/src/main/trpc/routers/encryption.ts +++ b/apps/code/src/main/trpc/routers/encryption.ts @@ -1,11 +1,10 @@ import { container } from "@main/di/container"; -import { MAIN_TOKENS } from "@main/di/tokens"; +import { ENCRYPTION_SERVICE } from "@main/di/tokens"; import type { EncryptionService } from "@main/services/encryption/service"; import { z } from "zod"; import { publicProcedure, router } from "../trpc"; -const getService = () => - container.get(MAIN_TOKENS.EncryptionService); +const getService = () => container.get(ENCRYPTION_SERVICE); export const encryptionRouter = router({ encrypt: publicProcedure diff --git a/apps/code/src/main/trpc/routers/workspace-server.ts b/apps/code/src/main/trpc/routers/workspace-server.ts index d2e442e8b5..26d214aefa 100644 --- a/apps/code/src/main/trpc/routers/workspace-server.ts +++ b/apps/code/src/main/trpc/routers/workspace-server.ts @@ -1,6 +1,6 @@ import { z } from "zod"; import { container } from "../../di/container"; -import { MAIN_TOKENS } from "../../di/tokens"; +import { WORKSPACE_SERVER_SERVICE } from "../../di/tokens"; import { WorkspaceServerEvent, type WorkspaceServerService, @@ -13,7 +13,7 @@ const connectionSchema = z.object({ }); const getService = () => - container.get(MAIN_TOKENS.WorkspaceServerService); + container.get(WORKSPACE_SERVER_SERVICE); export const workspaceServerRouter = router({ getConnection: publicProcedure.output(connectionSchema).query(async () => { @@ -21,6 +21,10 @@ export const workspaceServerRouter = router({ return service.getConnection() ?? service.start(); }), + restart: publicProcedure.mutation(async () => { + await getService().restart(); + }), + onConnectionLost: publicProcedure.subscription(async function* (opts) { const service = getService(); const iterable = service.toIterable(WorkspaceServerEvent.ConnectionLost, { @@ -30,4 +34,24 @@ export const workspaceServerRouter = router({ yield data; } }), + + onStatusChanged: publicProcedure.subscription(async function* (opts) { + const service = getService(); + const iterable = service.toIterable(WorkspaceServerEvent.StatusChanged, { + signal: opts.signal, + }); + // toIterable attaches its listener on the first pull. Prime it before + // reading the snapshot so a transition in between is buffered, not dropped. + const firstEvent = iterable.next(); + yield service.getStatusSnapshot(); + try { + let result = await firstEvent; + while (!result.done) { + yield result.value; + result = await iterable.next(); + } + } finally { + await iterable.return?.(undefined); + } + }), }); diff --git a/apps/code/src/renderer/components/Providers.tsx b/apps/code/src/renderer/components/Providers.tsx index 2dc1cf9a79..d7c50d58d4 100644 --- a/apps/code/src/renderer/components/Providers.tsx +++ b/apps/code/src/renderer/components/Providers.tsx @@ -9,14 +9,41 @@ import { } from "@renderer/trpc/client"; import { QueryClientProvider, + useMutation, useQuery, useQueryClient, } from "@tanstack/react-query"; import { useSubscription } from "@trpc/tanstack-react-query"; import { queryClient } from "@utils/queryClient"; import type React from "react"; +import { useCallback, useState } from "react"; import { HotkeysProvider } from "react-hotkeys-hook"; +function WorkspaceServerErrorBanner({ + onRetry, + disabled, +}: { + onRetry: () => void; + disabled?: boolean; +}) { + return ( +
+ The workspace server stopped and could not be restarted. + +
+ ); +} + function ConnectedWorkspaceProvider({ children, }: { @@ -24,22 +51,52 @@ function ConnectedWorkspaceProvider({ }) { const trpc = useTRPC(); const rqClient = useQueryClient(); + const [serverStatus, setServerStatus] = useState("ready"); const { data: connection } = useQuery( trpc.workspaceServer.getConnection.queryOptions(undefined, { staleTime: 30_000, }), ); + + const invalidateConnection = useCallback(() => { + rqClient.invalidateQueries({ + queryKey: trpc.workspaceServer.getConnection.queryKey(), + }); + }, [rqClient, trpc]); + + const restartServer = useMutation( + trpc.workspaceServer.restart.mutationOptions(), + ); + useSubscription( trpc.workspaceServer.onConnectionLost.subscriptionOptions(undefined, { - onData: () => { - rqClient.invalidateQueries({ - queryKey: trpc.workspaceServer.getConnection.queryKey(), - }); + onData: invalidateConnection, + }), + ); + + useSubscription( + trpc.workspaceServer.onStatusChanged.subscriptionOptions(undefined, { + onData: (data) => { + setServerStatus(data.status); + if (data.status === "ready") { + invalidateConnection(); + } }, }), ); + return ( + {serverStatus === "failed" ? ( + + restartServer.mutate(undefined, { + onSettled: () => invalidateConnection(), + }) + } + disabled={restartServer.isPending} + /> + ) : null} {children} ); diff --git a/apps/code/src/renderer/di/tokens.ts b/apps/code/src/renderer/di/tokens.ts index dead53490a..256d64cc13 100644 --- a/apps/code/src/renderer/di/tokens.ts +++ b/apps/code/src/renderer/di/tokens.ts @@ -10,11 +10,3 @@ export const TRPC_CLIENT = Symbol.for("posthog.host.renderer.trpc-client"); // Services export const TASK_SERVICE = Symbol.for("posthog.host.renderer.task-service"); - -export const RENDERER_TOKENS = Object.freeze({ - // Infrastructure - TRPCClient: TRPC_CLIENT, - - // Services - TaskService: TASK_SERVICE, -}); diff --git a/biome.jsonc b/biome.jsonc index 2e9ab97442..f2bf55073a 100644 --- a/biome.jsonc +++ b/biome.jsonc @@ -226,7 +226,6 @@ "electron", "node:*", "@posthog/*", - "!@posthog/core", "!@posthog/shared", "!@posthog/shared/*" ], @@ -256,7 +255,6 @@ "electron", "node:*", "@posthog/*", - "!@posthog/core", "!@posthog/api-client", "!@posthog/workspace-server" ], @@ -282,12 +280,7 @@ "options": { "patterns": [ { - "group": [ - "electron", - "node:*", - "@posthog/*", - "!@posthog/core" - ], + "group": ["electron", "node:*", "@posthog/*"], "message": "platform is interface-only." } ] @@ -313,8 +306,14 @@ }, "patterns": [ { - "group": ["@posthog/ui", "@posthog/api-client"], - "message": "workspace-server must not depend on UI or the PostHog API." + "group": [ + "@posthog/ui", + "@posthog/api-client", + "@posthog/core", + "@posthog/workspace-client", + "@posthog/host-router" + ], + "message": "workspace-server must not depend on UI, the PostHog API, core, workspace-client, or host-router." } ] } diff --git a/packages/agent/src/adapters/acp-connection.ts b/packages/agent/src/adapters/acp-connection.ts index c271e03b6c..f86b057914 100644 --- a/packages/agent/src/adapters/acp-connection.ts +++ b/packages/agent/src/adapters/acp-connection.ts @@ -8,6 +8,7 @@ import { type StreamPair, } from "../utils/streams"; import { ClaudeAcpAgent } from "./claude/claude-agent"; +import type { GatewayEnv } from "./claude/session/options"; import { CodexAcpAgent } from "./codex/codex-agent"; import type { CodexProcessOptions } from "./codex/spawn"; @@ -30,6 +31,8 @@ export type AcpConnectionConfig = { posthogApiConfig?: PostHogAPIConfig; /** Defaults to true when posthogApiConfig is set. Set to false to disable enrichment. */ enricherEnabled?: boolean; + /** Explicit gateway config for the Claude adapter — prevents global process.env mutation. */ + claudeGatewayEnv?: GatewayEnv; }; export type AcpConnection = { @@ -114,6 +117,7 @@ function createClaudeConnection(config: AcpConnectionConfig): AcpConnection { ...config.processCallbacks, onStructuredOutput: config.onStructuredOutput, posthogApiConfig: resolveEnricherApiConfig(config), + gatewayEnv: config.claudeGatewayEnv, }); return agent; }, agentStream); diff --git a/packages/agent/src/adapters/base-acp-agent.ts b/packages/agent/src/adapters/base-acp-agent.ts index e5d2e6b415..aaea9fb965 100644 --- a/packages/agent/src/adapters/base-acp-agent.ts +++ b/packages/agent/src/adapters/base-acp-agent.ts @@ -131,11 +131,16 @@ export abstract class BaseAcpAgent implements Agent { throw new Error("Method not implemented."); } - async getModelConfigOptions(currentModelOverride?: string): Promise<{ + async getModelConfigOptions( + currentModelOverride?: string, + gatewayUrl?: string, + ): Promise<{ currentModelId: string; options: SessionConfigSelectOption[]; }> { - this.gatewayModels = await fetchGatewayModels(); + this.gatewayModels = await fetchGatewayModels( + gatewayUrl ? { gatewayUrl } : undefined, + ); const options = this.gatewayModels .filter((model) => isAnthropicModel(model)) diff --git a/packages/agent/src/adapters/claude/claude-agent.ts b/packages/agent/src/adapters/claude/claude-agent.ts index 7c6d8220f4..9d1456b398 100644 --- a/packages/agent/src/adapters/claude/claude-agent.ts +++ b/packages/agent/src/adapters/claude/claude-agent.ts @@ -122,6 +122,7 @@ import { import { buildSessionOptions, buildSystemPrompt, + type GatewayEnv, type ProcessSpawnedInfo, } from "./session/options"; import { SettingsManager } from "./session/settings"; @@ -236,6 +237,8 @@ export interface ClaudeAcpAgentOptions { onMcpServersReady?: (serverNames: string[]) => void; onStructuredOutput?: (output: Record) => Promise; posthogApiConfig?: PostHogAPIConfig; + /** Explicit gateway config — avoids global process.env mutation across concurrent sessions. */ + gatewayEnv?: GatewayEnv; } export class ClaudeAcpAgent extends BaseAcpAgent { @@ -1740,6 +1743,7 @@ export class ClaudeAcpAgent extends BaseAcpAgent { onEnsureLocalToolsConnected: () => this.ensureLocalToolsConnected("guard-hook"), taskState, + gatewayEnv: this.options?.gatewayEnv, onTaskStateChange: async () => { await this.client.sessionUpdate({ sessionId, @@ -1844,6 +1848,7 @@ export class ClaudeAcpAgent extends BaseAcpAgent { const [rawModelOptions] = await Promise.all([ this.getModelConfigOptions( settingsManager.getSettings().model || meta?.model || undefined, + this.options?.gatewayEnv?.anthropicBaseUrl, ), ...(meta?.taskRunId ? [ diff --git a/packages/agent/src/adapters/claude/session/options.ts b/packages/agent/src/adapters/claude/session/options.ts index ece3813296..1e3a520a30 100644 --- a/packages/agent/src/adapters/claude/session/options.ts +++ b/packages/agent/src/adapters/claude/session/options.ts @@ -37,6 +37,22 @@ export interface ProcessSpawnedInfo { sessionId: string; } +/** + * Gateway config threaded explicitly through session creation so that + * concurrent Agent instances do not clobber each other's values via + * global `process.env` mutation. + */ +export type GatewayEnv = { + anthropicBaseUrl: string; + anthropicAuthToken: string; + openaiBaseUrl: string; + openaiApiKey: string; + /** Task-specific custom headers forwarded to the gateway (e.g. task_id, run_id). */ + anthropicCustomHeaders?: string; + /** PostHog project ID for per-team attribution headers. */ + posthogProjectId?: string; +}; + export interface BuildOptionsParams { cwd: string; mcpServers: Record; @@ -70,6 +86,8 @@ export interface BuildOptionsParams { /** Called after createTaskHook mutates taskState so callers can emit a plan * sessionUpdate to the client. */ onTaskStateChange?: () => Promise; + /** Explicit gateway config — prevents global process.env mutation. */ + gatewayEnv?: GatewayEnv; } export function buildSystemPrompt( @@ -116,13 +134,16 @@ function buildMcpServers( }; } -function buildEnvironment(): Record { +function buildEnvironment(gateway?: GatewayEnv): Record { // Custom HTTP headers reach the model only through the Claude CLI subprocess, // which reads them from this env var (newline-delimited `name: value` lines) // — the SDK has no direct header option. We finalize them here, the single // chokepoint every session (desktop and cloud) funnels through. const headerLines: string[] = []; - const existingCustomHeaders = process.env.ANTHROPIC_CUSTOM_HEADERS; + // Prefer explicit gateway config over process.env so concurrent sessions + // do not clobber each other's task-specific headers. + const existingCustomHeaders = + gateway?.anthropicCustomHeaders ?? process.env.ANTHROPIC_CUSTOM_HEADERS; if (existingCustomHeaders) { headerLines.push(existingCustomHeaders); } @@ -132,7 +153,7 @@ function buildEnvironment(): Record { // the event; both entrypoints export POSTHOG_PROJECT_ID before this runs // (workspace-server auth-adapter.ts, server/agent-server.ts). Mirrors django's // get_llm_client(team_id=...). - const projectId = process.env.POSTHOG_PROJECT_ID; + const projectId = gateway?.posthogProjectId ?? process.env.POSTHOG_PROJECT_ID; if (projectId) { headerLines.push(`x-posthog-property-team_id: ${projectId}`); } @@ -149,6 +170,18 @@ function buildEnvironment(): Record { return { ...process.env, + // Explicit gateway values win over whatever happens to be in process.env. + // This prevents concurrent Agent instances from clobbering each other's + // gateway config when process.env was mutated globally. + ...(gateway?.anthropicBaseUrl && { + ANTHROPIC_BASE_URL: gateway.anthropicBaseUrl, + }), + ...(gateway?.anthropicAuthToken && { + ANTHROPIC_AUTH_TOKEN: gateway.anthropicAuthToken, + ANTHROPIC_API_KEY: gateway.anthropicAuthToken, + }), + ...(gateway?.openaiBaseUrl && { OPENAI_BASE_URL: gateway.openaiBaseUrl }), + ...(gateway?.openaiApiKey && { OPENAI_API_KEY: gateway.openaiApiKey }), ELECTRON_RUN_AS_NODE: "1", CLAUDE_CODE_ENABLE_ASK_USER_QUESTION_TOOL: "true", // Offload all MCP tools by default @@ -410,7 +443,7 @@ export function buildSessionOptions(params: BuildOptionsParams): Options { params.mcpServers, loadUserClaudeJsonMcpServers(params.cwd, params.logger), ), - env: buildEnvironment(), + env: buildEnvironment(params.gatewayEnv), hooks: buildHooks( params.userProvidedOptions?.hooks, params.onModeChange, diff --git a/packages/agent/src/agent.ts b/packages/agent/src/agent.ts index 0717efb514..28d26d627c 100644 --- a/packages/agent/src/agent.ts +++ b/packages/agent/src/agent.ts @@ -2,6 +2,7 @@ import { createAcpConnection, type InProcessAcpConnection, } from "./adapters/acp-connection"; +import type { GatewayEnv } from "./adapters/claude/session/options"; import { DEFAULT_CODEX_MODEL, DEFAULT_GATEWAY_MODEL, @@ -50,7 +51,7 @@ export class Agent { } } - private async _configureLlmGateway(overrideUrl?: string): Promise<{ + private async _resolveGatewayConfig(overrideUrl?: string): Promise<{ gatewayUrl: string; apiKey: string; } | null> { @@ -61,15 +62,9 @@ export class Agent { try { const gatewayUrl = overrideUrl ?? this.posthogAPI.getLlmGatewayUrl(); const apiKey = await this.posthogAPI.getApiKey(); - - process.env.OPENAI_BASE_URL = `${gatewayUrl}/v1`; - process.env.OPENAI_API_KEY = apiKey; - process.env.ANTHROPIC_BASE_URL = gatewayUrl; - process.env.ANTHROPIC_AUTH_TOKEN = apiKey; - return { gatewayUrl, apiKey }; } catch (error) { - this.logger.error("Failed to configure LLM gateway", error); + this.logger.error("Failed to resolve LLM gateway config", error); throw error; } } @@ -79,7 +74,7 @@ export class Agent { taskRunId: string, options: TaskExecutionOptions = {}, ): Promise { - const gatewayConfig = await this._configureLlmGateway(options.gatewayUrl); + const gatewayConfig = await this._resolveGatewayConfig(options.gatewayUrl); this.taskRunId = taskRunId; let allowedModelIds: Set | undefined; @@ -115,6 +110,16 @@ export class Agent { sanitizedModel = DEFAULT_GATEWAY_MODEL; } + const claudeGatewayEnv: GatewayEnv | undefined = + options.adapter !== "codex" && gatewayConfig + ? { + anthropicBaseUrl: gatewayConfig.gatewayUrl, + anthropicAuthToken: gatewayConfig.apiKey, + openaiBaseUrl: `${gatewayConfig.gatewayUrl}/v1`, + openaiApiKey: gatewayConfig.apiKey, + } + : undefined; + this.acpConnection = createAcpConnection({ adapter: options.adapter, logWriter: this.sessionLogWriter, @@ -127,6 +132,7 @@ export class Agent { allowedModelIds, posthogApiConfig: this.posthogApiConfig, enricherEnabled: this.enricherEnabled, + claudeGatewayEnv, codexOptions: options.adapter === "codex" && gatewayConfig ? { diff --git a/packages/agent/src/server/agent-server.configure-environment.test.ts b/packages/agent/src/server/agent-server.configure-environment.test.ts index db67f98703..b62f93383d 100644 --- a/packages/agent/src/server/agent-server.configure-environment.test.ts +++ b/packages/agent/src/server/agent-server.configure-environment.test.ts @@ -1,4 +1,5 @@ import { afterEach, beforeEach, describe, expect, it } from "vitest"; +import type { GatewayEnv } from "../adapters/claude/session/options"; import type { Task } from "../types"; import { AgentServer } from "./agent-server"; @@ -12,16 +13,10 @@ interface TestableServer { taskRunId?: string | null; taskUserId?: number | null; taskTitle?: string | null; - }): void; + }): GatewayEnv; } -const ENV_KEYS_UNDER_TEST = [ - "LLM_GATEWAY_URL", - "ANTHROPIC_BASE_URL", - "OPENAI_BASE_URL", - "ANTHROPIC_CUSTOM_HEADERS", - "POSTHOG_PROJECT_ID", -] as const; +const ENV_KEYS_UNDER_TEST = ["LLM_GATEWAY_URL", "POSTHOG_PROJECT_ID"] as const; describe("AgentServer.configureEnvironment", () => { const originalEnv: Partial> = {}; @@ -57,101 +52,100 @@ describe("AgentServer.configureEnvironment", () => { }) as unknown as TestableServer; it("tags as background_agents when the task is internal", () => { - buildServer("interactive").configureEnvironment({ isInternal: true }); + const env = buildServer("interactive").configureEnvironment({ + isInternal: true, + }); - expect(process.env.LLM_GATEWAY_URL).toBe( + expect(env.anthropicBaseUrl).toBe( "https://gateway.us.posthog.com/background_agents", ); - expect(process.env.ANTHROPIC_BASE_URL).toBe( - "https://gateway.us.posthog.com/background_agents", - ); - expect(process.env.OPENAI_BASE_URL).toBe( + expect(env.openaiBaseUrl).toBe( "https://gateway.us.posthog.com/background_agents/v1", ); }); it("tags as posthog_code when the task is not internal", () => { - buildServer("background").configureEnvironment({ isInternal: false }); + const env = buildServer("background").configureEnvironment({ + isInternal: false, + }); - expect(process.env.LLM_GATEWAY_URL).toBe( + expect(env.anthropicBaseUrl).toBe( "https://gateway.us.posthog.com/posthog_code", ); }); - // The Claude session builder reads POSTHOG_PROJECT_ID to emit the - // `x-posthog-property-team_id` attribution header (see - // adapters/claude/session/options.ts), so the cloud path must export it. - it("exports POSTHOG_PROJECT_ID for the team_id attribution header", () => { + // The Claude session builder reads posthogProjectId from GatewayEnv to emit + // the `x-posthog-property-team_id` attribution header (see + // adapters/claude/session/options.ts), so the cloud path must include it. + it("includes posthogProjectId for the team_id attribution header", () => { + const env = buildServer("background").configureEnvironment({ + isInternal: false, + }); + + expect(env.posthogProjectId).toBe("1"); + }); + + // POSTHOG_PROJECT_ID is a server-level constant, safe to keep in process.env. + it("exports POSTHOG_PROJECT_ID to process.env for tools that inherit it", () => { buildServer("background").configureEnvironment({ isInternal: false }); expect(process.env.POSTHOG_PROJECT_ID).toBe("1"); }); it("tags as posthog_code when isInternal is omitted (getTask failure fallback)", () => { - buildServer("background").configureEnvironment(); + const env = buildServer("background").configureEnvironment(); - expect(process.env.LLM_GATEWAY_URL).toBe( + expect(env.anthropicBaseUrl).toBe( "https://gateway.us.posthog.com/posthog_code", ); }); it("ignores mode when picking the gateway product", () => { - buildServer("background").configureEnvironment({ isInternal: false }); - const fromBackground = process.env.LLM_GATEWAY_URL; - - // Clear the env var the first call wrote — resolveLlmGatewayUrl now treats - // a set LLM_GATEWAY_URL as an override base and appends the product on top - // of it, which would double up the product slug across back-to-back calls - // in the same process. - delete process.env.LLM_GATEWAY_URL; - buildServer("interactive").configureEnvironment({ isInternal: false }); - const fromInteractive = process.env.LLM_GATEWAY_URL; - - expect(fromBackground).toBe(fromInteractive); - expect(fromBackground).toBe("https://gateway.us.posthog.com/posthog_code"); + const fromBackground = buildServer("background").configureEnvironment({ + isInternal: false, + }); + const fromInteractive = buildServer("interactive").configureEnvironment({ + isInternal: false, + }); + + expect(fromBackground.anthropicBaseUrl).toBe( + fromInteractive.anthropicBaseUrl, + ); + expect(fromBackground.anthropicBaseUrl).toBe( + "https://gateway.us.posthog.com/posthog_code", + ); }); it("tags as signals when an internal task has origin_product 'signal_report'", () => { - buildServer("background").configureEnvironment({ + const env = buildServer("background").configureEnvironment({ isInternal: true, originProduct: "signal_report", }); - expect(process.env.LLM_GATEWAY_URL).toBe( - "https://gateway.us.posthog.com/signals", - ); - expect(process.env.ANTHROPIC_BASE_URL).toBe( - "https://gateway.us.posthog.com/signals", - ); - expect(process.env.OPENAI_BASE_URL).toBe( - "https://gateway.us.posthog.com/signals/v1", - ); + expect(env.anthropicBaseUrl).toBe("https://gateway.us.posthog.com/signals"); + expect(env.openaiBaseUrl).toBe("https://gateway.us.posthog.com/signals/v1"); }); it("tags as signals when origin_product is 'signal_report' even if the task is not internal", () => { - buildServer("background").configureEnvironment({ + const env = buildServer("background").configureEnvironment({ isInternal: false, originProduct: "signal_report", }); - expect(process.env.LLM_GATEWAY_URL).toBe( - "https://gateway.us.posthog.com/signals", - ); + expect(env.anthropicBaseUrl).toBe("https://gateway.us.posthog.com/signals"); }); it("tags as signals for scout runs (origin_product 'signals_scout'), internal or not", () => { - buildServer("background").configureEnvironment({ + const env = buildServer("background").configureEnvironment({ isInternal: false, originProduct: "signals_scout", }); - expect(process.env.LLM_GATEWAY_URL).toBe( - "https://gateway.us.posthog.com/signals", - ); + expect(env.anthropicBaseUrl).toBe("https://gateway.us.posthog.com/signals"); }); - it("forwards task metadata as ANTHROPIC_CUSTOM_HEADERS", () => { - buildServer("background").configureEnvironment({ + it("forwards task metadata as anthropicCustomHeaders", () => { + const env = buildServer("background").configureEnvironment({ isInternal: true, originProduct: "signal_report", signalReportId: "report-123", @@ -162,7 +156,7 @@ describe("AgentServer.configureEnvironment", () => { taskTitle: "Fix the bug", }); - expect(process.env.ANTHROPIC_CUSTOM_HEADERS).toBe( + expect(env.anthropicCustomHeaders).toBe( [ "x-posthog-property-task_origin_product: signal_report", "x-posthog-property-task_internal: true", @@ -176,96 +170,90 @@ describe("AgentServer.configureEnvironment", () => { ); }); - it("omits ai_stage from ANTHROPIC_CUSTOM_HEADERS when not provided", () => { - buildServer("background").configureEnvironment({ + it("omits ai_stage from anthropicCustomHeaders when not provided", () => { + const env = buildServer("background").configureEnvironment({ isInternal: false, taskId: "task-abc", }); - expect(process.env.ANTHROPIC_CUSTOM_HEADERS).not.toContain("ai_stage"); + expect(env.anthropicCustomHeaders).not.toContain("ai_stage"); }); // A signals_scout title is multi-line; it must not inject extra header lines. it("collapses newlines in the task title", () => { - buildServer("background").configureEnvironment({ + const env = buildServer("background").configureEnvironment({ isInternal: false, taskId: "task-abc", taskTitle: "[sandbox_prompt:signals_scout:signals-scout-logs]\nLine two", }); - expect(process.env.ANTHROPIC_CUSTOM_HEADERS).toContain( + expect(env.anthropicCustomHeaders).toContain( "x-posthog-property-task_title: [sandbox_prompt:signals_scout:signals-scout-logs] Line two", ); }); - it("omits signal_report_id from ANTHROPIC_CUSTOM_HEADERS for non-report tasks", () => { - buildServer("background").configureEnvironment({ + it("omits signal_report_id from anthropicCustomHeaders for non-report tasks", () => { + const env = buildServer("background").configureEnvironment({ isInternal: false, taskId: "task-abc", }); - expect(process.env.ANTHROPIC_CUSTOM_HEADERS).not.toContain( - "signal_report_id", - ); + expect(env.anthropicCustomHeaders).not.toContain("signal_report_id"); }); - it("omits optional task metadata from ANTHROPIC_CUSTOM_HEADERS when not provided", () => { - buildServer("background").configureEnvironment({ isInternal: false }); + it("omits optional task metadata from anthropicCustomHeaders when not provided", () => { + const env = buildServer("background").configureEnvironment({ + isInternal: false, + }); - expect(process.env.ANTHROPIC_CUSTOM_HEADERS).toBe( + expect(env.anthropicCustomHeaders).toBe( "x-posthog-property-task_internal: false", ); }); it("tags as slack_app when the task was initiated from Slack", () => { - buildServer("interactive").configureEnvironment({ + const env = buildServer("interactive").configureEnvironment({ originProduct: "slack", }); - expect(process.env.LLM_GATEWAY_URL).toBe( - "https://gateway.us.posthog.com/slack_app", - ); - expect(process.env.ANTHROPIC_BASE_URL).toBe( + expect(env.anthropicBaseUrl).toBe( "https://gateway.us.posthog.com/slack_app", ); - expect(process.env.OPENAI_BASE_URL).toBe( + expect(env.openaiBaseUrl).toBe( "https://gateway.us.posthog.com/slack_app/v1", ); }); it("prefers slack_app over background_agents when both signals are present", () => { - buildServer("interactive").configureEnvironment({ + const env = buildServer("interactive").configureEnvironment({ isInternal: true, originProduct: "slack", }); - expect(process.env.LLM_GATEWAY_URL).toBe( + expect(env.anthropicBaseUrl).toBe( "https://gateway.us.posthog.com/slack_app", ); }); it("falls back to posthog_code for non-slack origin products", () => { - buildServer("background").configureEnvironment({ + const env = buildServer("background").configureEnvironment({ originProduct: "user_created", }); - expect(process.env.LLM_GATEWAY_URL).toBe( + expect(env.anthropicBaseUrl).toBe( "https://gateway.us.posthog.com/posthog_code", ); }); it("routes PostHog AI origin through the posthog_ai product", () => { - buildServer("interactive").configureEnvironment({ + const env = buildServer("interactive").configureEnvironment({ originProduct: "posthog_ai", }); - expect(process.env.LLM_GATEWAY_URL).toBe( - "https://gateway.us.posthog.com/posthog_ai", - ); - expect(process.env.ANTHROPIC_BASE_URL).toBe( + expect(env.anthropicBaseUrl).toBe( "https://gateway.us.posthog.com/posthog_ai", ); - expect(process.env.OPENAI_BASE_URL).toBe( + expect(env.openaiBaseUrl).toBe( "https://gateway.us.posthog.com/posthog_ai/v1", ); }); @@ -277,15 +265,14 @@ describe("AgentServer.configureEnvironment", () => { // llm_gateway product, which OAuth tokens cannot use. process.env.LLM_GATEWAY_URL = "http://ngrok.test/proxy"; - buildServer("background").configureEnvironment({ isInternal: true }); + const env = buildServer("background").configureEnvironment({ + isInternal: true, + }); - expect(process.env.LLM_GATEWAY_URL).toBe( - "http://ngrok.test/proxy/background_agents", - ); - expect(process.env.ANTHROPIC_BASE_URL).toBe( + expect(env.anthropicBaseUrl).toBe( "http://ngrok.test/proxy/background_agents", ); - expect(process.env.OPENAI_BASE_URL).toBe( + expect(env.openaiBaseUrl).toBe( "http://ngrok.test/proxy/background_agents/v1", ); }); diff --git a/packages/agent/src/server/agent-server.ts b/packages/agent/src/server/agent-server.ts index 5b01dc9968..cd05523174 100644 --- a/packages/agent/src/server/agent-server.ts +++ b/packages/agent/src/server/agent-server.ts @@ -22,6 +22,7 @@ import { createAcpConnection, type InProcessAcpConnection, } from "../adapters/acp-connection"; +import type { GatewayEnv } from "../adapters/claude/session/options"; import { type AgentErrorClassification, classifyAgentError, @@ -938,7 +939,7 @@ export class AgentServer { }), ]); - this.configureEnvironment({ + const gatewayEnv = this.configureEnvironment({ isInternal: preTask?.internal === true, originProduct: preTask?.origin_product, signalReportId: preTask?.signal_report, @@ -988,11 +989,12 @@ export class AgentServer { deviceType: deviceInfo.type, logWriter, logger: this.logger, + claudeGatewayEnv: runtimeAdapter !== "codex" ? gatewayEnv : undefined, codexOptions: runtimeAdapter === "codex" ? { cwd: this.config.repositoryPath ?? "/tmp/workspace", - apiBaseUrl: process.env.OPENAI_BASE_URL, + apiBaseUrl: gatewayEnv.openaiBaseUrl, apiKey: this.config.apiKey, model: this.config.model ?? DEFAULT_CODEX_MODEL, reasoningEffort: this.config.reasoningEffort, @@ -2041,7 +2043,7 @@ ${signedCommitInstructions} taskRunId?: string | null; taskUserId?: number | null; taskTitle?: string | null; - } = {}): void { + } = {}): GatewayEnv { const { apiKey, apiUrl, projectId } = this.config; const product = resolveGatewayProduct({ isInternal, originProduct }); const gatewayUrl = resolveLlmGatewayUrl( @@ -2069,24 +2071,28 @@ ${signedCommitInstructions} task_title: taskTitle, }); + // Server-level constants that don't vary per task — safe to keep in + // process.env so spawned tools (PostHog MCP, workspace-server, etc.) can + // reach the PostHog API without explicit threading. Object.assign(process.env, { - // PostHog POSTHOG_API_KEY: apiKey, POSTHOG_API_URL: apiUrl, POSTHOG_API_HOST: apiUrl, POSTHOG_AUTH_HEADER: `Bearer ${apiKey}`, POSTHOG_PROJECT_ID: String(projectId), - // Anthropic - ANTHROPIC_API_KEY: apiKey, - ANTHROPIC_AUTH_TOKEN: apiKey, - ANTHROPIC_BASE_URL: gatewayUrl, - ANTHROPIC_CUSTOM_HEADERS: customHeaders, - // OpenAI (for models like GPT-4, o1, etc.) - OPENAI_API_KEY: apiKey, - OPENAI_BASE_URL: openaiBaseUrl, - // Generic gateway - LLM_GATEWAY_URL: gatewayUrl, }); + + // Task-specific gateway config is returned rather than written to + // process.env so that concurrent sessions do not clobber each other's + // gateway URL, auth token, or custom headers. + return { + anthropicBaseUrl: gatewayUrl, + anthropicAuthToken: apiKey, + openaiBaseUrl, + openaiApiKey: apiKey, + anthropicCustomHeaders: customHeaders, + posthogProjectId: String(projectId), + }; } private buildSlackQuestionRelayResponse( diff --git a/packages/agent/src/server/event-stream-sender.test.ts b/packages/agent/src/server/event-stream-sender.test.ts index 60b97e9cae..e4e936362d 100644 --- a/packages/agent/src/server/event-stream-sender.test.ts +++ b/packages/agent/src/server/event-stream-sender.test.ts @@ -593,7 +593,7 @@ describe("TaskRunEventStreamSender", () => { } }, }), - retryDelayMs: 5, + retryDelayMs: 10_000, stopTimeoutMs: 1, }); diff --git a/packages/core/src/sessions/sessionService.ts b/packages/core/src/sessions/sessionService.ts index 1cfcd02747..28993a584a 100644 --- a/packages/core/src/sessions/sessionService.ts +++ b/packages/core/src/sessions/sessionService.ts @@ -514,7 +514,7 @@ export class SessionService { */ private previewConfigOptionsCache = new Map< string, - Promise + { promise: Promise; fetchedAt: number } >(); /** * Initial cloud prompt text (user message + any channel CONTEXT.md block), @@ -3136,9 +3136,10 @@ export class SessionService { initialModel?: string, ): Promise { const cacheKey = `${apiHost}::${adapter}`; - let pending = this.previewConfigOptionsCache.get(cacheKey); - if (!pending) { - pending = this.d.trpc.agent.getPreviewConfigOptions + let entry = this.previewConfigOptionsCache.get(cacheKey); + if (!entry || Date.now() - entry.fetchedAt > 300_000) { + if (entry) this.previewConfigOptionsCache.delete(cacheKey); + const promise = this.d.trpc.agent.getPreviewConfigOptions .query({ apiHost, adapter }) .catch((err: unknown) => { this.d.log.warn( @@ -3149,13 +3150,18 @@ export class SessionService { error: err, }, ); - this.previewConfigOptionsCache.delete(cacheKey); + // Only evict if this entry is still the cached one; a concurrent + // refresh may have replaced it and we must not drop the fresh entry. + if (this.previewConfigOptionsCache.get(cacheKey) === entry) { + this.previewConfigOptionsCache.delete(cacheKey); + } return [] as SessionConfigOption[]; }); - this.previewConfigOptionsCache.set(cacheKey, pending); + entry = { promise, fetchedAt: Date.now() }; + this.previewConfigOptionsCache.set(cacheKey, entry); } - const previewOptions = await pending; + const previewOptions = await entry.promise; const extras = previewOptions .filter( (opt) => opt.category === "model" || opt.category === "thought_level", diff --git a/packages/core/src/sessions/sessionStore.ts b/packages/core/src/sessions/sessionStore.ts new file mode 100644 index 0000000000..2ea1cbad09 --- /dev/null +++ b/packages/core/src/sessions/sessionStore.ts @@ -0,0 +1,282 @@ +import type { ContentBlock } from "@agentclientprotocol/sdk"; +import type { + AcpMessage, + AgentSession, + OptimisticItem, + PermissionRequest, + QueuedMessage, + TaskRunStatus, +} from "@posthog/shared"; +import { immer } from "zustand/middleware/immer"; +import { createStore } from "zustand/vanilla"; + +export interface SessionState { + /** Sessions indexed by taskRunId */ + sessions: Record; + /** Index mapping taskId -> taskRunId for O(1) lookups */ + taskIdIndex: Record; +} + +export const sessionStore = createStore()( + immer(() => ({ + sessions: {}, + taskIdIndex: {}, + })), +); + +export const sessionStoreSetters = { + setSession: (session: AgentSession) => { + sessionStore.setState((state) => { + // Clean up old session if taskId already has a different taskRunId + const existingTaskRunId = state.taskIdIndex[session.taskId]; + if (existingTaskRunId && existingTaskRunId !== session.taskRunId) { + delete state.sessions[existingTaskRunId]; + } + + state.sessions[session.taskRunId] = session; + state.taskIdIndex[session.taskId] = session.taskRunId; + }); + }, + + removeSession: (taskRunId: string) => { + sessionStore.setState((state) => { + const session = state.sessions[taskRunId]; + if (session) { + delete state.taskIdIndex[session.taskId]; + } + delete state.sessions[taskRunId]; + }); + }, + + updateSession: (taskRunId: string, updates: Partial) => { + sessionStore.setState((state) => { + if (state.sessions[taskRunId]) { + Object.assign(state.sessions[taskRunId], updates); + } + }); + }, + + appendEvents: ( + taskRunId: string, + events: AcpMessage[], + newLineCount?: number, + ) => { + sessionStore.setState((state) => { + const session = state.sessions[taskRunId]; + if (session) { + session.events.push(...events); + if (newLineCount !== undefined) { + session.processedLineCount = newLineCount; + } + } + }); + }, + + updateCloudStatus: ( + taskRunId: string, + fields: { + status?: TaskRunStatus; + stage?: string | null; + output?: Record | null; + errorMessage?: string | null; + branch?: string | null; + }, + ) => { + sessionStore.setState((state) => { + const session = state.sessions[taskRunId]; + if (!session) return; + if (fields.status !== undefined) session.cloudStatus = fields.status; + if (fields.stage !== undefined) session.cloudStage = fields.stage; + if (fields.output !== undefined) session.cloudOutput = fields.output; + if (fields.errorMessage !== undefined) + session.cloudErrorMessage = fields.errorMessage; + if (fields.branch !== undefined) session.cloudBranch = fields.branch; + }); + }, + + setPendingPermissions: ( + taskRunId: string, + permissions: Map, + ) => { + sessionStore.setState((state) => { + if (state.sessions[taskRunId]) { + state.sessions[taskRunId].pendingPermissions = permissions; + } + }); + }, + + enqueueMessage: ( + taskId: string, + content: string, + rawPrompt?: string | ContentBlock[], + ) => { + const id = `queue-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`; + sessionStore.setState((state) => { + const taskRunId = state.taskIdIndex[taskId]; + if (!taskRunId) return; + + const session = state.sessions[taskRunId]; + if (session) { + session.messageQueue.push({ + id, + content, + rawPrompt, + queuedAt: Date.now(), + }); + } + }); + }, + + removeQueuedMessage: (taskId: string, messageId: string) => { + sessionStore.setState((state) => { + const taskRunId = state.taskIdIndex[taskId]; + if (!taskRunId) return; + const session = state.sessions[taskRunId]; + if (session) { + session.messageQueue = session.messageQueue.filter( + (msg) => msg.id !== messageId, + ); + } + }); + }, + + clearMessageQueue: (taskId: string) => { + sessionStore.setState((state) => { + const taskRunId = state.taskIdIndex[taskId]; + if (!taskRunId) return; + + const session = state.sessions[taskRunId]; + if (session) { + session.messageQueue = []; + } + }); + }, + + dequeueMessagesAsText: (taskId: string): string | null => { + // Read the queue from the frozen committed state BEFORE entering the + // immer draft — same rationale as `dequeueMessages`: anything captured + // through a draft proxy can be revoked when setState exits. + const state = sessionStore.getState(); + const taskRunId = state.taskIdIndex[taskId]; + if (!taskRunId) return null; + const session = state.sessions[taskRunId]; + if (!session || session.messageQueue.length === 0) return null; + + const combined = session.messageQueue + .map((msg) => msg.content) + .join("\n\n"); + sessionStore.setState((draft) => { + const trid = draft.taskIdIndex[taskId]; + if (!trid) return; + const draftSession = draft.sessions[trid]; + if (draftSession) draftSession.messageQueue = []; + }); + return combined; + }, + + dequeueMessages: (taskId: string): QueuedMessage[] => { + // Read the queue from the frozen committed state BEFORE entering the + // immer draft, otherwise the items returned are proxies that get + // revoked when setState exits and any later access throws + // "Cannot perform 'get' on a proxy that has been revoked". + const state = sessionStore.getState(); + const taskRunId = state.taskIdIndex[taskId]; + if (!taskRunId) return []; + const session = state.sessions[taskRunId]; + if (!session || session.messageQueue.length === 0) return []; + + const queuedMessages = [...session.messageQueue]; + + sessionStore.setState((draft) => { + const trid = draft.taskIdIndex[taskId]; + if (!trid) return; + const draftSession = draft.sessions[trid]; + if (draftSession) { + draftSession.messageQueue = []; + } + }); + + return queuedMessages; + }, + + /** + * Splice messages back at the head of the queue. Used to roll back a + * dispatch attempt that drained the queue but failed before delivery. + */ + prependQueuedMessages: (taskId: string, messages: QueuedMessage[]) => { + if (messages.length === 0) return; + sessionStore.setState((state) => { + const taskRunId = state.taskIdIndex[taskId]; + if (!taskRunId) return; + const session = state.sessions[taskRunId]; + if (!session) return; + session.messageQueue = [...messages, ...session.messageQueue]; + }); + }, + + appendOptimisticItem: ( + taskRunId: string, + item: OptimisticItem extends infer T + ? T extends { id: string } + ? Omit + : never + : never, + ): void => { + const id = `optimistic-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`; + sessionStore.setState((state) => { + const session = state.sessions[taskRunId]; + if (session) { + session.optimisticItems.push({ ...item, id } as OptimisticItem); + } + }); + }, + + clearOptimisticItems: (taskRunId: string): void => { + sessionStore.setState((state) => { + const session = state.sessions[taskRunId]; + if (session) { + session.optimisticItems = []; + } + }); + }, + + clearTailOptimisticItems: (taskRunId: string): void => { + sessionStore.setState((state) => { + const session = state.sessions[taskRunId]; + if (session) { + session.optimisticItems = session.optimisticItems.filter( + (item) => item.type !== "user_message" || item.pinToTop !== false, + ); + } + }); + }, + + replaceOptimisticWithEvent: (taskRunId: string, event: AcpMessage): void => { + sessionStore.setState((state) => { + const session = state.sessions[taskRunId]; + if (session) { + session.events.push(event); + session.optimisticItems = []; + } + }); + }, + + /** O(1) lookup using taskIdIndex */ + getSessionByTaskId: (taskId: string): AgentSession | undefined => { + const state = sessionStore.getState(); + const taskRunId = state.taskIdIndex[taskId]; + if (!taskRunId) return undefined; + return state.sessions[taskRunId]; + }, + + getSessions: (): Record => { + return sessionStore.getState().sessions; + }, + + clearAll: () => { + sessionStore.setState((state) => { + state.sessions = {}; + state.taskIdIndex = {}; + }); + }, +}; diff --git a/packages/ui/src/features/sessions/sessionStore.ts b/packages/ui/src/features/sessions/sessionStore.ts index 7d91221d6c..b9432901f0 100644 --- a/packages/ui/src/features/sessions/sessionStore.ts +++ b/packages/ui/src/features/sessions/sessionStore.ts @@ -3,7 +3,11 @@ import type { SessionConfigOption, } from "@agentclientprotocol/sdk"; import { - type AcpMessage, + type SessionState, + sessionStore, + sessionStoreSetters, +} from "@posthog/core/sessions/sessionStore"; +import { type Adapter, type AgentSession, cycleModeOption, @@ -19,10 +23,9 @@ import { type SessionStatus, type TaskRunStatus, } from "@posthog/shared"; -import { create } from "zustand"; -import { immer } from "zustand/middleware/immer"; +import { useStore } from "zustand"; -// --- Types --- +// --- Type re-exports --- export type { Adapter, @@ -35,6 +38,8 @@ export type { SessionStatus, TaskRunStatus, }; +export type { ContentBlock }; +export type { SessionState }; export { cycleModeOption, flattenSelectOptions, @@ -44,21 +49,28 @@ export { mergeConfigOptions, }; -export interface SessionState { - /** Sessions indexed by taskRunId */ - sessions: Record; - /** Index mapping taskId -> taskRunId for O(1) lookups */ - taskIdIndex: Record; -} +// --- Setter re-export --- + +export { sessionStoreSetters }; -// --- Store --- +// --- React hook backed by the core vanilla store --- + +function useSessionStoreHook( + selector: (s: SessionState) => T, + equalityFn?: (a: T, b: T) => boolean, +): T { + return useStore(sessionStore, selector, equalityFn); +} -export const useSessionStore = create()( - immer(() => ({ - sessions: {}, - taskIdIndex: {}, - })), -); +export const useSessionStore: typeof useSessionStoreHook & { + getState: typeof sessionStore.getState; + setState: typeof sessionStore.setState; + subscribe: typeof sessionStore.subscribe; +} = Object.assign(useSessionStoreHook, { + getState: () => sessionStore.getState(), + setState: sessionStore.setState.bind(sessionStore), + subscribe: sessionStore.subscribe.bind(sessionStore), +}); // --- Re-exports --- @@ -77,262 +89,3 @@ export { useSessions, useThoughtLevelConfigOptionForTask, } from "./useSession"; - -// --- Setters --- - -export const sessionStoreSetters = { - setSession: (session: AgentSession) => { - useSessionStore.setState((state) => { - // Clean up old session if taskId already has a different taskRunId - const existingTaskRunId = state.taskIdIndex[session.taskId]; - if (existingTaskRunId && existingTaskRunId !== session.taskRunId) { - delete state.sessions[existingTaskRunId]; - } - - state.sessions[session.taskRunId] = session; - state.taskIdIndex[session.taskId] = session.taskRunId; - }); - }, - - removeSession: (taskRunId: string) => { - useSessionStore.setState((state) => { - const session = state.sessions[taskRunId]; - if (session) { - delete state.taskIdIndex[session.taskId]; - } - delete state.sessions[taskRunId]; - }); - }, - - updateSession: (taskRunId: string, updates: Partial) => { - useSessionStore.setState((state) => { - if (state.sessions[taskRunId]) { - Object.assign(state.sessions[taskRunId], updates); - } - }); - }, - - appendEvents: ( - taskRunId: string, - events: AcpMessage[], - newLineCount?: number, - ) => { - useSessionStore.setState((state) => { - const session = state.sessions[taskRunId]; - if (session) { - session.events.push(...events); - if (newLineCount !== undefined) { - session.processedLineCount = newLineCount; - } - } - }); - }, - - updateCloudStatus: ( - taskRunId: string, - fields: { - status?: TaskRunStatus; - stage?: string | null; - output?: Record | null; - errorMessage?: string | null; - branch?: string | null; - }, - ) => { - useSessionStore.setState((state) => { - const session = state.sessions[taskRunId]; - if (!session) return; - if (fields.status !== undefined) session.cloudStatus = fields.status; - if (fields.stage !== undefined) session.cloudStage = fields.stage; - if (fields.output !== undefined) session.cloudOutput = fields.output; - if (fields.errorMessage !== undefined) - session.cloudErrorMessage = fields.errorMessage; - if (fields.branch !== undefined) session.cloudBranch = fields.branch; - }); - }, - - setPendingPermissions: ( - taskRunId: string, - permissions: Map, - ) => { - useSessionStore.setState((state) => { - if (state.sessions[taskRunId]) { - state.sessions[taskRunId].pendingPermissions = permissions; - } - }); - }, - - enqueueMessage: ( - taskId: string, - content: string, - rawPrompt?: string | ContentBlock[], - ) => { - const id = `queue-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`; - useSessionStore.setState((state) => { - const taskRunId = state.taskIdIndex[taskId]; - if (!taskRunId) return; - - const session = state.sessions[taskRunId]; - if (session) { - session.messageQueue.push({ - id, - content, - rawPrompt, - queuedAt: Date.now(), - }); - } - }); - }, - - removeQueuedMessage: (taskId: string, messageId: string) => { - useSessionStore.setState((state) => { - const taskRunId = state.taskIdIndex[taskId]; - if (!taskRunId) return; - const session = state.sessions[taskRunId]; - if (session) { - session.messageQueue = session.messageQueue.filter( - (msg) => msg.id !== messageId, - ); - } - }); - }, - - clearMessageQueue: (taskId: string) => { - useSessionStore.setState((state) => { - const taskRunId = state.taskIdIndex[taskId]; - if (!taskRunId) return; - - const session = state.sessions[taskRunId]; - if (session) { - session.messageQueue = []; - } - }); - }, - - dequeueMessagesAsText: (taskId: string): string | null => { - // Read the queue from the frozen committed state BEFORE entering the - // immer draft — same rationale as `dequeueMessages`: anything captured - // through a draft proxy can be revoked when setState exits. - const state = useSessionStore.getState(); - const taskRunId = state.taskIdIndex[taskId]; - if (!taskRunId) return null; - const session = state.sessions[taskRunId]; - if (!session || session.messageQueue.length === 0) return null; - - const combined = session.messageQueue - .map((msg) => msg.content) - .join("\n\n"); - useSessionStore.setState((draft) => { - const trid = draft.taskIdIndex[taskId]; - if (!trid) return; - const draftSession = draft.sessions[trid]; - if (draftSession) draftSession.messageQueue = []; - }); - return combined; - }, - - dequeueMessages: (taskId: string): QueuedMessage[] => { - // Read the queue from the frozen committed state BEFORE entering the - // immer draft, otherwise the items returned are proxies that get - // revoked when setState exits and any later access throws - // "Cannot perform 'get' on a proxy that has been revoked". - const state = useSessionStore.getState(); - const taskRunId = state.taskIdIndex[taskId]; - if (!taskRunId) return []; - const session = state.sessions[taskRunId]; - if (!session || session.messageQueue.length === 0) return []; - - const queuedMessages = [...session.messageQueue]; - - useSessionStore.setState((draft) => { - const trid = draft.taskIdIndex[taskId]; - if (!trid) return; - const draftSession = draft.sessions[trid]; - if (draftSession) { - draftSession.messageQueue = []; - } - }); - - return queuedMessages; - }, - - /** - * Splice messages back at the head of the queue. Used to roll back a - * dispatch attempt that drained the queue but failed before delivery. - */ - prependQueuedMessages: (taskId: string, messages: QueuedMessage[]) => { - if (messages.length === 0) return; - useSessionStore.setState((state) => { - const taskRunId = state.taskIdIndex[taskId]; - if (!taskRunId) return; - const session = state.sessions[taskRunId]; - if (!session) return; - session.messageQueue = [...messages, ...session.messageQueue]; - }); - }, - - appendOptimisticItem: ( - taskRunId: string, - item: OptimisticItem extends infer T - ? T extends { id: string } - ? Omit - : never - : never, - ): void => { - const id = `optimistic-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`; - useSessionStore.setState((state) => { - const session = state.sessions[taskRunId]; - if (session) { - session.optimisticItems.push({ ...item, id } as OptimisticItem); - } - }); - }, - - clearOptimisticItems: (taskRunId: string): void => { - useSessionStore.setState((state) => { - const session = state.sessions[taskRunId]; - if (session) { - session.optimisticItems = []; - } - }); - }, - - clearTailOptimisticItems: (taskRunId: string): void => { - useSessionStore.setState((state) => { - const session = state.sessions[taskRunId]; - if (session) { - session.optimisticItems = session.optimisticItems.filter( - (item) => item.type !== "user_message" || item.pinToTop !== false, - ); - } - }); - }, - - replaceOptimisticWithEvent: (taskRunId: string, event: AcpMessage): void => { - useSessionStore.setState((state) => { - const session = state.sessions[taskRunId]; - if (session) { - session.events.push(event); - session.optimisticItems = []; - } - }); - }, - - /** O(1) lookup using taskIdIndex */ - getSessionByTaskId: (taskId: string): AgentSession | undefined => { - const state = useSessionStore.getState(); - const taskRunId = state.taskIdIndex[taskId]; - if (!taskRunId) return undefined; - return state.sessions[taskRunId]; - }, - - getSessions: (): Record => { - return useSessionStore.getState().sessions; - }, - - clearAll: () => { - useSessionStore.setState((state) => { - state.sessions = {}; - state.taskIdIndex = {}; - }); - }, -}; diff --git a/packages/ui/src/utils/clearStorage.ts b/packages/ui/src/utils/clearStorage.ts index b9f3aca301..9ef2b4b148 100644 --- a/packages/ui/src/utils/clearStorage.ts +++ b/packages/ui/src/utils/clearStorage.ts @@ -14,14 +14,23 @@ export function clearApplicationStorage(): void { if (!confirmed) return; - resolveService(HOST_TRPC_CLIENT) - .folders.clearAllData.mutate() - .then(() => { - localStorage.clear(); - window.location.reload(); - }) - .catch((error: unknown) => { - log.error("Failed to clear storage:", error); + const client = resolveService(HOST_TRPC_CLIENT); + + Promise.allSettled([ + client.folders.clearAllData.mutate(), + client.secureStore.clear.query(), + ]).then((results) => { + const rejected = results.filter( + (result): result is PromiseRejectedResult => result.status === "rejected", + ); + if (rejected.length > 0) { + for (const failure of rejected) { + log.error("Failed to clear application storage:", failure.reason); + } alert("Failed to clear storage. Please try again."); - }); + return; + } + localStorage.clear(); + window.location.reload(); + }); } diff --git a/packages/workspace-server/src/app.ts b/packages/workspace-server/src/app.ts index d0c08963bf..6af24af8a8 100644 --- a/packages/workspace-server/src/app.ts +++ b/packages/workspace-server/src/app.ts @@ -3,12 +3,13 @@ import { trpcServer } from "@hono/trpc-server"; import { Hono } from "hono"; import { createMiddleware } from "hono/factory"; import { HTTPException } from "hono/http-exception"; -import { appRouter } from "./trpc"; +import type { AppRouter } from "./trpc"; const SECRET_HEADER = "x-workspace-secret"; export interface CreateAppOptions { sharedSecret: string; + router: AppRouter; } export function createApp(options: CreateAppOptions): Hono { @@ -34,7 +35,7 @@ export function createApp(options: CreateAppOptions): Hono { }); app.use("/trpc/*", requireSecret); - app.use("/trpc/*", trpcServer({ router: appRouter })); + app.use("/trpc/*", trpcServer({ router: options.router })); return app; } diff --git a/packages/workspace-server/src/di/tokens.ts b/packages/workspace-server/src/di/tokens.ts index d591f13673..cc5cfaa92a 100644 --- a/packages/workspace-server/src/di/tokens.ts +++ b/packages/workspace-server/src/di/tokens.ts @@ -14,14 +14,3 @@ export const CONNECTIVITY_SERVICE = Symbol.for( export const ENVIRONMENT_SERVICE = Symbol.for( "posthog.workspace.environment-service", ); - -export const TOKENS = Object.freeze({ - FocusService: FOCUS_SERVICE, - FocusSyncService: FOCUS_SYNC_SERVICE, - GitService: GIT_SERVICE, - FsService: FS_SERVICE, - WatcherService: WATCHER_SERVICE, - LocalLogsService: LOCAL_LOGS_SERVICE, - ConnectivityService: CONNECTIVITY_SERVICE, - EnvironmentService: ENVIRONMENT_SERVICE, -}); diff --git a/packages/workspace-server/src/serve.ts b/packages/workspace-server/src/serve.ts index a1118bde4a..7e4b9c187d 100644 --- a/packages/workspace-server/src/serve.ts +++ b/packages/workspace-server/src/serve.ts @@ -3,6 +3,26 @@ import dns from "node:dns"; import net from "node:net"; import { serve } from "@hono/node-server"; import { createApp } from "./app"; +import { container } from "./di/container"; +import { + CONNECTIVITY_SERVICE, + ENVIRONMENT_SERVICE, + FOCUS_SERVICE, + FOCUS_SYNC_SERVICE, + FS_SERVICE, + GIT_SERVICE, + LOCAL_LOGS_SERVICE, + WATCHER_SERVICE, +} from "./di/tokens"; +import type { ConnectivityService } from "./services/connectivity/service"; +import type { EnvironmentService } from "./services/environment/service"; +import type { FocusService } from "./services/focus/service"; +import type { FocusSyncService } from "./services/focus/sync-service"; +import type { FsService } from "./services/fs/service"; +import type { GitService } from "./services/git/service"; +import type { LocalLogsService } from "./services/local-logs/service"; +import type { WatcherService } from "./services/watcher/service"; +import { createAppRouter } from "./trpc"; // Prefer IPv4 and disable "Happy Eyeballs" (mirrors apps/code main bootstrap). // This child makes all outbound HTTPS to PostHog/the gateway; its many-address @@ -33,7 +53,17 @@ if (!sharedSecret || !Number.isInteger(port) || port <= 0 || port > 65_535) { process.exit(2); } -const app = createApp({ sharedSecret }); +const router = createAppRouter({ + focusService: container.get(FOCUS_SERVICE), + focusSyncService: container.get(FOCUS_SYNC_SERVICE), + gitService: container.get(GIT_SERVICE), + fsService: container.get(FS_SERVICE), + watcherService: container.get(WATCHER_SERVICE), + localLogsService: container.get(LOCAL_LOGS_SERVICE), + connectivityService: container.get(CONNECTIVITY_SERVICE), + environmentService: container.get(ENVIRONMENT_SERVICE), +}); +const app = createApp({ sharedSecret, router }); let server: ReturnType | null = null; let shuttingDown = false; diff --git a/packages/workspace-server/src/services/agent/agent.ts b/packages/workspace-server/src/services/agent/agent.ts index ebf83a9d3a..d8ee377ef4 100644 --- a/packages/workspace-server/src/services/agent/agent.ts +++ b/packages/workspace-server/src/services/agent/agent.ts @@ -627,9 +627,6 @@ If a repository IS genuinely required, attach one in this priority order: this.validateSessionParams(params); const config = this.toSessionConfig(params); const session = await this.getOrCreateSession(config, false); - if (!session) { - throw new Error("Failed to create session"); - } return this.toSessionResponse(session); } @@ -648,6 +645,16 @@ If a repository IS genuinely required, attach one in this priority order: return session ? this.toSessionResponse(session) : null; } + private async getOrCreateSession( + config: SessionConfig, + isReconnect: false, + isRetry?: boolean, + ): Promise; + private async getOrCreateSession( + config: SessionConfig, + isReconnect: true, + isRetry?: boolean, + ): Promise; private async getOrCreateSession( config: SessionConfig, isReconnect: boolean, @@ -1024,7 +1031,10 @@ If a repository IS genuinely required, attach one in this priority order: `Auth error during ${isReconnect ? "reconnect" : "create"}, retrying`, { taskRunId }, ); - return this.getOrCreateSession(config, isReconnect, true); + if (isReconnect) { + return this.getOrCreateSession(config, true, true); + } + return this.getOrCreateSession(config, false, true); } // When the in-process ACP layer masks a thrown error as a generic // "Internal error", the real text survives in `data.details`. Surface it diff --git a/packages/workspace-server/src/services/git/task-pr-status.ts b/packages/workspace-server/src/services/git/task-pr-status.ts index c18749da87..77c39ab78f 100644 --- a/packages/workspace-server/src/services/git/task-pr-status.ts +++ b/packages/workspace-server/src/services/git/task-pr-status.ts @@ -2,7 +2,7 @@ import fs from "node:fs"; import { inject, injectable } from "inversify"; import { WORKSPACE_REPOSITORY } from "../../db/identifiers"; import type { IWorkspaceRepository } from "../../db/repositories/workspace-repository"; -import { TOKENS } from "../../di/tokens"; +import { GIT_SERVICE } from "../../di/tokens"; import { WORKSPACE_SERVICE } from "../workspace/identifiers"; import type { CachedPrUrlOutput, @@ -17,7 +17,7 @@ export class TaskPrStatusService { private readonly taskPrRevalidations = new Map>(); constructor( - @inject(TOKENS.GitService) + @inject(GIT_SERVICE) private readonly gitService: GitService, @inject(WORKSPACE_REPOSITORY) private readonly workspaceRepo: IWorkspaceRepository, diff --git a/packages/workspace-server/src/trpc.ts b/packages/workspace-server/src/trpc.ts index 87728e93da..629923dd30 100644 --- a/packages/workspace-server/src/trpc.ts +++ b/packages/workspace-server/src/trpc.ts @@ -1,8 +1,6 @@ import { initTRPC } from "@trpc/server"; import superjson from "superjson"; import { z } from "zod"; -import { container } from "./di/container"; -import { TOKENS } from "./di/tokens"; import { connectivityStatusOutput } from "./services/connectivity/schemas"; import type { ConnectivityService } from "./services/connectivity/service"; import { @@ -148,20 +146,6 @@ import type { WatcherService } from "./services/watcher/service"; const t = initTRPC.create({ transformer: superjson }); -const focusService = () => container.get(TOKENS.FocusService); -const focusSyncService = () => - container.get(TOKENS.FocusSyncService); -const gitService = () => container.get(TOKENS.GitService); -const fsService = () => container.get(TOKENS.FsService); -const watcherService = () => - container.get(TOKENS.WatcherService); -const localLogsService = () => - container.get(TOKENS.LocalLogsService); -const connectivityService = () => - container.get(TOKENS.ConnectivityService); -const environmentService = () => - container.get(TOKENS.EnvironmentService); - export { type FocusBranchRenamedEvent, type FocusForeignBranchCheckoutEvent, @@ -180,724 +164,776 @@ export { FileWatcherEventKind, } from "./services/watcher/schemas"; -export const appRouter = t.router({ - focus: t.router({ - getSession: t.procedure - .input(mainRepoPathInput) - .output(focusSessionSchema.nullable()) - .query(({ input }) => focusService().getSession(input.mainRepoPath)), - - saveSession: t.procedure - .input(focusSessionSchema) - .mutation(({ input }) => focusService().saveSession(input)), - - deleteSession: t.procedure - .input(mainRepoPathInput) - .mutation(({ input }) => - focusService().deleteSession(input.mainRepoPath), - ), - - isFocusActive: t.procedure - .input(mainRepoPathInput) - .output(z.boolean()) - .query(({ input }) => focusService().isFocusActive(input.mainRepoPath)), - - isDirty: t.procedure - .input(repoPathInput) - .output(z.boolean()) - .query(({ input }) => focusService().isDirty(input.repoPath)), - - getCommitSha: t.procedure - .input(repoPathInput) - .output(z.string()) - .query(({ input }) => focusService().getCommitSha(input.repoPath)), - - findWorktreeByBranch: t.procedure - .input(findWorktreeInput) - .output(z.string().nullable()) - .query(({ input }) => - focusService().findWorktreeByBranch(input.mainRepoPath, input.branch), - ), - - stash: t.procedure - .input(stashInput) - .output(stashResultSchema) - .mutation(({ input }) => - focusService().stash(input.repoPath, input.message), - ), +export interface WorkspaceServerServices { + focusService: FocusService; + focusSyncService: FocusSyncService; + gitService: GitService; + fsService: FsService; + watcherService: WatcherService; + localLogsService: LocalLogsService; + connectivityService: ConnectivityService; + environmentService: EnvironmentService; +} + +export function createAppRouter({ + focusService: focusServiceInst, + focusSyncService: focusSyncServiceInst, + gitService: gitServiceInst, + fsService: fsServiceInst, + watcherService: watcherServiceInst, + localLogsService: localLogsServiceInst, + connectivityService: connectivityServiceInst, + environmentService: environmentServiceInst, +}: WorkspaceServerServices) { + const focusService = () => focusServiceInst; + const focusSyncService = () => focusSyncServiceInst; + const gitService = () => gitServiceInst; + const fsService = () => fsServiceInst; + const watcherService = () => watcherServiceInst; + const localLogsService = () => localLogsServiceInst; + const connectivityService = () => connectivityServiceInst; + const environmentService = () => environmentServiceInst; + + return t.router({ + focus: t.router({ + getSession: t.procedure + .input(mainRepoPathInput) + .output(focusSessionSchema.nullable()) + .query(({ input }) => focusService().getSession(input.mainRepoPath)), + + saveSession: t.procedure + .input(focusSessionSchema) + .mutation(({ input }) => focusService().saveSession(input)), + + deleteSession: t.procedure + .input(mainRepoPathInput) + .mutation(({ input }) => + focusService().deleteSession(input.mainRepoPath), + ), - stashPop: t.procedure - .input(repoPathInput) - .output(focusResultSchema) - .mutation(({ input }) => focusService().stashPop(input.repoPath)), + isFocusActive: t.procedure + .input(mainRepoPathInput) + .output(z.boolean()) + .query(({ input }) => focusService().isFocusActive(input.mainRepoPath)), + + isDirty: t.procedure + .input(repoPathInput) + .output(z.boolean()) + .query(({ input }) => focusService().isDirty(input.repoPath)), + + getCommitSha: t.procedure + .input(repoPathInput) + .output(z.string()) + .query(({ input }) => focusService().getCommitSha(input.repoPath)), + + findWorktreeByBranch: t.procedure + .input(findWorktreeInput) + .output(z.string().nullable()) + .query(({ input }) => + focusService().findWorktreeByBranch(input.mainRepoPath, input.branch), + ), - stashApply: t.procedure - .input(z.object({ repoPath: z.string(), stashRef: z.string() })) - .output(focusResultSchema) - .mutation(({ input }) => - focusService().stashApply(input.repoPath, input.stashRef), - ), + stash: t.procedure + .input(stashInput) + .output(stashResultSchema) + .mutation(({ input }) => + focusService().stash(input.repoPath, input.message), + ), - checkout: t.procedure - .input(checkoutInput) - .output(focusResultSchema) - .mutation(({ input }) => - focusService().checkout(input.repoPath, input.branch), - ), + stashPop: t.procedure + .input(repoPathInput) + .output(focusResultSchema) + .mutation(({ input }) => focusService().stashPop(input.repoPath)), - detachWorktree: t.procedure - .input(worktreeInput) - .output(focusResultSchema) - .mutation(({ input }) => - focusService().detachWorktree(input.worktreePath), - ), + stashApply: t.procedure + .input(z.object({ repoPath: z.string(), stashRef: z.string() })) + .output(focusResultSchema) + .mutation(({ input }) => + focusService().stashApply(input.repoPath, input.stashRef), + ), - reattachWorktree: t.procedure - .input(reattachInput) - .output(focusResultSchema) - .mutation(({ input }) => - focusService().reattachWorktree(input.worktreePath, input.branch), - ), + checkout: t.procedure + .input(checkoutInput) + .output(focusResultSchema) + .mutation(({ input }) => + focusService().checkout(input.repoPath, input.branch), + ), - cleanWorkingTree: t.procedure - .input(repoPathInput) - .mutation(({ input }) => focusService().cleanWorkingTree(input.repoPath)), + detachWorktree: t.procedure + .input(worktreeInput) + .output(focusResultSchema) + .mutation(({ input }) => + focusService().detachWorktree(input.worktreePath), + ), - startSync: t.procedure - .input(syncInput) - .mutation(({ input }) => - focusSyncService().startSync(input.mainRepoPath, input.worktreePath), - ), + reattachWorktree: t.procedure + .input(reattachInput) + .output(focusResultSchema) + .mutation(({ input }) => + focusService().reattachWorktree(input.worktreePath, input.branch), + ), - stopSync: t.procedure.mutation(() => focusSyncService().stopSync()), + cleanWorkingTree: t.procedure + .input(repoPathInput) + .mutation(({ input }) => + focusService().cleanWorkingTree(input.repoPath), + ), - startWatchingMainRepo: t.procedure - .input(mainRepoPathInput) - .mutation(({ input }) => - focusService().startWatchingMainRepo(input.mainRepoPath), - ), + startSync: t.procedure + .input(syncInput) + .mutation(({ input }) => + focusSyncService().startSync(input.mainRepoPath, input.worktreePath), + ), - stopWatchingMainRepo: t.procedure.mutation(() => - focusService().stopWatchingMainRepo(), - ), + stopSync: t.procedure.mutation(() => focusSyncService().stopSync()), - onBranchRenamed: t.procedure.subscription(async function* (opts) { - for await (const event of focusService().branchRenamedEvents( - opts.signal, - )) { - yield event; - } - }), + startWatchingMainRepo: t.procedure + .input(mainRepoPathInput) + .mutation(({ input }) => + focusService().startWatchingMainRepo(input.mainRepoPath), + ), - onForeignBranchCheckout: t.procedure.subscription(async function* (opts) { - for await (const event of focusService().foreignBranchCheckoutEvents( - opts.signal, - )) { - yield event; - } - }), - }), - git: t.router({ - detectRepo: t.procedure - .input(directoryPathInput) - .output(detectRepoResultSchema) - .query(({ input }) => gitService().detectRepo(input.directoryPath)), - - validateRepo: t.procedure - .input(directoryPathInput) - .output(z.boolean()) - .query(({ input }) => gitService().validateRepo(input.directoryPath)), - - getRemoteUrl: t.procedure - .input(directoryPathInput) - .output(stringNullableOutput) - .query(({ input }) => gitService().getRemoteUrl(input.directoryPath)), - - getCurrentBranch: t.procedure - .input(directoryPathInput) - .output(stringNullableOutput) - .query(({ input, signal }) => - gitService().getCurrentBranch(input.directoryPath, signal), + stopWatchingMainRepo: t.procedure.mutation(() => + focusService().stopWatchingMainRepo(), ), - getDefaultBranch: t.procedure - .input(directoryPathInput) - .output(stringOutput) - .query(({ input }) => gitService().getDefaultBranch(input.directoryPath)), + onBranchRenamed: t.procedure.subscription(async function* (opts) { + for await (const event of focusService().branchRenamedEvents( + opts.signal, + )) { + yield event; + } + }), - getAllBranches: t.procedure - .input(directoryPathInput) - .output(stringArrayOutput) - .query(({ input, signal }) => - gitService().getAllBranches(input.directoryPath, signal), - ), + onForeignBranchCheckout: t.procedure.subscription(async function* (opts) { + for await (const event of focusService().foreignBranchCheckoutEvents( + opts.signal, + )) { + yield event; + } + }), + }), + git: t.router({ + detectRepo: t.procedure + .input(directoryPathInput) + .output(detectRepoResultSchema) + .query(({ input }) => gitService().detectRepo(input.directoryPath)), + + validateRepo: t.procedure + .input(directoryPathInput) + .output(z.boolean()) + .query(({ input }) => gitService().validateRepo(input.directoryPath)), + + getRemoteUrl: t.procedure + .input(directoryPathInput) + .output(stringNullableOutput) + .query(({ input }) => gitService().getRemoteUrl(input.directoryPath)), + + getCurrentBranch: t.procedure + .input(directoryPathInput) + .output(stringNullableOutput) + .query(({ input, signal }) => + gitService().getCurrentBranch(input.directoryPath, signal), + ), - getChangedFilesHead: t.procedure - .input(directoryPathInput) - .output(changedFilesOutput) - .query(({ input, signal }) => - gitService().getChangedFilesHead(input.directoryPath, signal), - ), + getDefaultBranch: t.procedure + .input(directoryPathInput) + .output(stringOutput) + .query(({ input }) => + gitService().getDefaultBranch(input.directoryPath), + ), - getFileAtHead: t.procedure - .input(filePathInput) - .output(stringNullableOutput) - .query(({ input, signal }) => - gitService().getFileAtHead(input.directoryPath, input.filePath, signal), - ), + getAllBranches: t.procedure + .input(directoryPathInput) + .output(stringArrayOutput) + .query(({ input, signal }) => + gitService().getAllBranches(input.directoryPath, signal), + ), - getDiffHead: t.procedure - .input(diffInput) - .output(stringOutput) - .query(({ input, signal }) => - gitService().getDiffHead( - input.directoryPath, - input.ignoreWhitespace, - signal, + getChangedFilesHead: t.procedure + .input(directoryPathInput) + .output(changedFilesOutput) + .query(({ input, signal }) => + gitService().getChangedFilesHead(input.directoryPath, signal), ), - ), - getDiffCached: t.procedure - .input(diffInput) - .output(stringOutput) - .query(({ input, signal }) => - gitService().getDiffCached( - input.directoryPath, - input.ignoreWhitespace, - signal, + getFileAtHead: t.procedure + .input(filePathInput) + .output(stringNullableOutput) + .query(({ input, signal }) => + gitService().getFileAtHead( + input.directoryPath, + input.filePath, + signal, + ), ), - ), - getDiffUnstaged: t.procedure - .input(diffInput) - .output(stringOutput) - .query(({ input, signal }) => - gitService().getDiffUnstaged( - input.directoryPath, - input.ignoreWhitespace, - signal, + getDiffHead: t.procedure + .input(diffInput) + .output(stringOutput) + .query(({ input, signal }) => + gitService().getDiffHead( + input.directoryPath, + input.ignoreWhitespace, + signal, + ), ), - ), - getLatestCommit: t.procedure - .input(directoryPathInput) - .output(gitCommitInfoNullableOutput) - .query(({ input, signal }) => - gitService().getLatestCommit(input.directoryPath, signal), - ), + getDiffCached: t.procedure + .input(diffInput) + .output(stringOutput) + .query(({ input, signal }) => + gitService().getDiffCached( + input.directoryPath, + input.ignoreWhitespace, + signal, + ), + ), - getGitRepoInfo: t.procedure - .input(directoryPathInput) - .output(gitRepoInfoNullableOutput) - .query(({ input }) => gitService().getGitRepoInfo(input.directoryPath)), + getDiffUnstaged: t.procedure + .input(diffInput) + .output(stringOutput) + .query(({ input, signal }) => + gitService().getDiffUnstaged( + input.directoryPath, + input.ignoreWhitespace, + signal, + ), + ), - getGitBusyState: t.procedure - .input(gitBusyStateInput) - .output(gitBusyStateSchema) - .query(({ input, signal }) => - gitService().getGitBusyState(input.directoryPath, signal), - ), + getLatestCommit: t.procedure + .input(directoryPathInput) + .output(gitCommitInfoNullableOutput) + .query(({ input, signal }) => + gitService().getLatestCommit(input.directoryPath, signal), + ), - getGitSyncStatus: t.procedure - .input(getGitSyncStatusInput) - .output(gitSyncStatusSchema) - .query(({ input }) => - gitService().getGitSyncStatus(input.directoryPath, input.forceRefresh), - ), + getGitRepoInfo: t.procedure + .input(directoryPathInput) + .output(gitRepoInfoNullableOutput) + .query(({ input }) => gitService().getGitRepoInfo(input.directoryPath)), - createBranch: t.procedure - .input(createBranchInput) - .mutation(({ input }) => - gitService().createBranch(input.directoryPath, input.branchName), - ), + getGitBusyState: t.procedure + .input(gitBusyStateInput) + .output(gitBusyStateSchema) + .query(({ input, signal }) => + gitService().getGitBusyState(input.directoryPath, signal), + ), - checkoutBranch: t.procedure - .input(checkoutBranchInput) - .output(checkoutBranchOutput) - .mutation(({ input }) => - gitService().checkoutBranch(input.directoryPath, input.branchName), - ), + getGitSyncStatus: t.procedure + .input(getGitSyncStatusInput) + .output(gitSyncStatusSchema) + .query(({ input }) => + gitService().getGitSyncStatus( + input.directoryPath, + input.forceRefresh, + ), + ), - stageFiles: t.procedure - .input(stageFilesInput) - .output(gitStateSnapshotSchema) - .mutation(({ input }) => - gitService().stageFiles(input.directoryPath, input.paths), - ), + createBranch: t.procedure + .input(createBranchInput) + .mutation(({ input }) => + gitService().createBranch(input.directoryPath, input.branchName), + ), - unstageFiles: t.procedure - .input(stageFilesInput) - .output(gitStateSnapshotSchema) - .mutation(({ input }) => - gitService().unstageFiles(input.directoryPath, input.paths), - ), + checkoutBranch: t.procedure + .input(checkoutBranchInput) + .output(checkoutBranchOutput) + .mutation(({ input }) => + gitService().checkoutBranch(input.directoryPath, input.branchName), + ), - discardFileChanges: t.procedure - .input(discardFileChangesInput) - .output(discardFileChangesOutput) - .mutation(({ input }) => - gitService().discardFileChanges( - input.directoryPath, - input.filePath, - input.fileStatus, + stageFiles: t.procedure + .input(stageFilesInput) + .output(gitStateSnapshotSchema) + .mutation(({ input }) => + gitService().stageFiles(input.directoryPath, input.paths), ), - ), - push: t.procedure - .input(pushInput) - .output(pushOutput) - .mutation(({ input, signal }) => - gitService().push( - input.directoryPath, - input.remote, - input.branch, - input.setUpstream, - signal, - input.env, + unstageFiles: t.procedure + .input(stageFilesInput) + .output(gitStateSnapshotSchema) + .mutation(({ input }) => + gitService().unstageFiles(input.directoryPath, input.paths), ), - ), - commit: t.procedure - .input(commitInput) - .output(commitOutput) - .mutation(({ input }) => - gitService().commit(input.directoryPath, input.message, { - paths: input.paths, - allowEmpty: input.allowEmpty, - stagedOnly: input.stagedOnly, - env: input.env, - }), - ), + discardFileChanges: t.procedure + .input(discardFileChangesInput) + .output(discardFileChangesOutput) + .mutation(({ input }) => + gitService().discardFileChanges( + input.directoryPath, + input.filePath, + input.fileStatus, + ), + ), - pull: t.procedure - .input(pullInput) - .output(pullOutput) - .mutation(({ input, signal }) => - gitService().pull( - input.directoryPath, - input.remote, - input.branch, - signal, + push: t.procedure + .input(pushInput) + .output(pushOutput) + .mutation(({ input, signal }) => + gitService().push( + input.directoryPath, + input.remote, + input.branch, + input.setUpstream, + signal, + input.env, + ), ), - ), - publish: t.procedure - .input(publishInput) - .output(publishOutput) - .mutation(({ input, signal }) => - gitService().publish( - input.directoryPath, - input.remote, - signal, - input.env, + commit: t.procedure + .input(commitInput) + .output(commitOutput) + .mutation(({ input }) => + gitService().commit(input.directoryPath, input.message, { + paths: input.paths, + allowEmpty: input.allowEmpty, + stagedOnly: input.stagedOnly, + env: input.env, + }), ), - ), - sync: t.procedure - .input(gitSyncInput) - .output(gitSyncOutput) - .mutation(({ input, signal }) => - gitService().sync(input.directoryPath, input.remote, signal), - ), + pull: t.procedure + .input(pullInput) + .output(pullOutput) + .mutation(({ input, signal }) => + gitService().pull( + input.directoryPath, + input.remote, + input.branch, + signal, + ), + ), - getGhStatus: t.procedure - .output(ghStatusOutput) - .query(() => gitService().getGhStatus()), + publish: t.procedure + .input(publishInput) + .output(publishOutput) + .mutation(({ input, signal }) => + gitService().publish( + input.directoryPath, + input.remote, + signal, + input.env, + ), + ), - getGhAuthToken: t.procedure - .output(ghAuthTokenOutput) - .query(() => gitService().getGhAuthToken()), + sync: t.procedure + .input(gitSyncInput) + .output(gitSyncOutput) + .mutation(({ input, signal }) => + gitService().sync(input.directoryPath, input.remote, signal), + ), - getPrStatus: t.procedure - .input(directoryPathInput) - .output(prStatusOutput) - .query(({ input }) => gitService().getPrStatus(input.directoryPath)), + getGhStatus: t.procedure + .output(ghStatusOutput) + .query(() => gitService().getGhStatus()), - getPrUrlForBranch: t.procedure - .input(getPrUrlForBranchInput) - .output(getPrUrlForBranchOutput) - .query(({ input }) => - gitService().getPrUrlForBranch(input.directoryPath, input.branchName), - ), + getGhAuthToken: t.procedure + .output(ghAuthTokenOutput) + .query(() => gitService().getGhAuthToken()), - openPr: t.procedure - .input(openPrInput) - .output(openPrOutput) - .mutation(({ input }) => gitService().openPr(input.directoryPath)), - - getPrDetailsByUrl: t.procedure - .input(getPrDetailsByUrlInput) - .output(getPrDetailsByUrlOutput.nullable()) - .query(({ input }) => gitService().getPrDetailsByUrl(input.prUrl)), - - getPrChangedFiles: t.procedure - .input(getPrChangedFilesInput) - .output(changedFilesOutput) - .query(({ input }) => gitService().getPrChangedFiles(input.prUrl)), - - getPrDiffStatsBatch: t.procedure - .input(getPrDiffStatsBatchInput) - .output(getPrDiffStatsBatchOutput) - .query(({ input }) => gitService().getPrDiffStatsBatch(input.prUrls)), - - getBranchChangedFiles: t.procedure - .input(getBranchChangedFilesInput) - .output(changedFilesOutput) - .query(({ input }) => - gitService().getBranchChangedFiles(input.repo, input.branch), - ), + getPrStatus: t.procedure + .input(directoryPathInput) + .output(prStatusOutput) + .query(({ input }) => gitService().getPrStatus(input.directoryPath)), - getLocalBranchChangedFiles: t.procedure - .input(getLocalBranchChangedFilesInput) - .output(changedFilesOutput) - .query(({ input }) => - gitService().getLocalBranchChangedFiles( - input.directoryPath, - input.branch, + getPrUrlForBranch: t.procedure + .input(getPrUrlForBranchInput) + .output(getPrUrlForBranchOutput) + .query(({ input }) => + gitService().getPrUrlForBranch(input.directoryPath, input.branchName), ), - ), - updatePrByUrl: t.procedure - .input(updatePrByUrlInput) - .output(updatePrByUrlOutput) - .mutation(({ input }) => - gitService().updatePrByUrl(input.prUrl, input.action), - ), - - getPrReviewComments: t.procedure - .input(getPrReviewCommentsInput) - .output(getPrReviewCommentsOutput) - .query(({ input }) => gitService().getPrReviewComments(input.prUrl)), + openPr: t.procedure + .input(openPrInput) + .output(openPrOutput) + .mutation(({ input }) => gitService().openPr(input.directoryPath)), + + getPrDetailsByUrl: t.procedure + .input(getPrDetailsByUrlInput) + .output(getPrDetailsByUrlOutput.nullable()) + .query(({ input }) => gitService().getPrDetailsByUrl(input.prUrl)), + + getPrChangedFiles: t.procedure + .input(getPrChangedFilesInput) + .output(changedFilesOutput) + .query(({ input }) => gitService().getPrChangedFiles(input.prUrl)), + + getPrDiffStatsBatch: t.procedure + .input(getPrDiffStatsBatchInput) + .output(getPrDiffStatsBatchOutput) + .query(({ input }) => gitService().getPrDiffStatsBatch(input.prUrls)), + + getBranchChangedFiles: t.procedure + .input(getBranchChangedFilesInput) + .output(changedFilesOutput) + .query(({ input }) => + gitService().getBranchChangedFiles(input.repo, input.branch), + ), - resolveReviewThread: t.procedure - .input(resolveReviewThreadInput) - .output(resolveReviewThreadOutput) - .mutation(({ input }) => - gitService().resolveReviewThread(input.threadNodeId, input.resolved), - ), + getLocalBranchChangedFiles: t.procedure + .input(getLocalBranchChangedFilesInput) + .output(changedFilesOutput) + .query(({ input }) => + gitService().getLocalBranchChangedFiles( + input.directoryPath, + input.branch, + ), + ), - replyToPrComment: t.procedure - .input(replyToPrCommentInput) - .output(replyToPrCommentOutput) - .mutation(({ input }) => - gitService().replyToPrComment(input.prUrl, input.commentId, input.body), - ), + updatePrByUrl: t.procedure + .input(updatePrByUrlInput) + .output(updatePrByUrlOutput) + .mutation(({ input }) => + gitService().updatePrByUrl(input.prUrl, input.action), + ), - getPrTemplate: t.procedure - .input(getPrTemplateInput) - .output(getPrTemplateOutput) - .query(({ input }) => gitService().getPrTemplate(input.directoryPath)), + getPrReviewComments: t.procedure + .input(getPrReviewCommentsInput) + .output(getPrReviewCommentsOutput) + .query(({ input }) => gitService().getPrReviewComments(input.prUrl)), - getCommitConventions: t.procedure - .input(getCommitConventionsInput) - .output(getCommitConventionsOutput) - .query(({ input }) => - gitService().getCommitConventions( - input.directoryPath, - input.sampleSize, + resolveReviewThread: t.procedure + .input(resolveReviewThreadInput) + .output(resolveReviewThreadOutput) + .mutation(({ input }) => + gitService().resolveReviewThread(input.threadNodeId, input.resolved), ), - ), - searchGithubRefs: t.procedure - .input(searchGithubRefsInput) - .output(searchGithubRefsOutput) - .query(({ input }) => - gitService().searchGithubRefs( - input.directoryPath, - input.query, - input.limit, - input.kinds, + replyToPrComment: t.procedure + .input(replyToPrCommentInput) + .output(replyToPrCommentOutput) + .mutation(({ input }) => + gitService().replyToPrComment( + input.prUrl, + input.commentId, + input.body, + ), ), - ), - getGithubIssue: t.procedure - .input(getGithubIssueInput) - .output(getGithubIssueOutput) - .query(({ input }) => - gitService().getGithubIssue(input.owner, input.repo, input.number), - ), - - getGithubPullRequest: t.procedure - .input(getGithubPullRequestInput) - .output(getGithubPullRequestOutput) - .query(({ input }) => - gitService().getGithubPullRequest( - input.owner, - input.repo, - input.number, + getPrTemplate: t.procedure + .input(getPrTemplateInput) + .output(getPrTemplateOutput) + .query(({ input }) => gitService().getPrTemplate(input.directoryPath)), + + getCommitConventions: t.procedure + .input(getCommitConventionsInput) + .output(getCommitConventionsOutput) + .query(({ input }) => + gitService().getCommitConventions( + input.directoryPath, + input.sampleSize, + ), ), - ), - readHandoffLocalGitState: t.procedure - .input(readHandoffLocalGitStateInput) - .output(readHandoffLocalGitStateOutput) - .query(({ input }) => - gitService().readHandoffLocalGitState(input.directoryPath), - ), + searchGithubRefs: t.procedure + .input(searchGithubRefsInput) + .output(searchGithubRefsOutput) + .query(({ input }) => + gitService().searchGithubRefs( + input.directoryPath, + input.query, + input.limit, + input.kinds, + ), + ), - cleanupAfterCloudHandoff: t.procedure - .input(cleanupAfterCloudHandoffInput) - .output(cleanupAfterCloudHandoffOutput) - .mutation(({ input }) => - gitService().cleanupAfterCloudHandoff( - input.directoryPath, - input.branchName, + getGithubIssue: t.procedure + .input(getGithubIssueInput) + .output(getGithubIssueOutput) + .query(({ input }) => + gitService().getGithubIssue(input.owner, input.repo, input.number), ), - ), - getDiffStats: t.procedure - .input(diffStatsInput) - .output(diffStatsSchema) - .query(({ input }) => gitService().getDiffStats(input.directoryPath)), + getGithubPullRequest: t.procedure + .input(getGithubPullRequestInput) + .output(getGithubPullRequestOutput) + .query(({ input }) => + gitService().getGithubPullRequest( + input.owner, + input.repo, + input.number, + ), + ), - getGitStatus: t.procedure - .output(gitStatusOutput) - .query(() => gitService().getGitStatus()), + readHandoffLocalGitState: t.procedure + .input(readHandoffLocalGitStateInput) + .output(readHandoffLocalGitStateOutput) + .query(({ input }) => + gitService().readHandoffLocalGitState(input.directoryPath), + ), - getHeadSha: t.procedure - .input(directoryPathInput) - .output(getHeadShaOutput) - .query(({ input }) => gitService().getHeadSha(input.directoryPath)), + cleanupAfterCloudHandoff: t.procedure + .input(cleanupAfterCloudHandoffInput) + .output(cleanupAfterCloudHandoffOutput) + .mutation(({ input }) => + gitService().cleanupAfterCloudHandoff( + input.directoryPath, + input.branchName, + ), + ), - getDiffAgainstRemote: t.procedure - .input(getDiffAgainstRemoteInput) - .output(stringOutput) - .query(({ input }) => - gitService().getDiffAgainstRemote( - input.directoryPath, - input.baseBranch, + getDiffStats: t.procedure + .input(diffStatsInput) + .output(diffStatsSchema) + .query(({ input }) => gitService().getDiffStats(input.directoryPath)), + + getGitStatus: t.procedure + .output(gitStatusOutput) + .query(() => gitService().getGitStatus()), + + getHeadSha: t.procedure + .input(directoryPathInput) + .output(getHeadShaOutput) + .query(({ input }) => gitService().getHeadSha(input.directoryPath)), + + getDiffAgainstRemote: t.procedure + .input(getDiffAgainstRemoteInput) + .output(stringOutput) + .query(({ input }) => + gitService().getDiffAgainstRemote( + input.directoryPath, + input.baseBranch, + ), ), - ), - getCommitsBetweenBranches: t.procedure - .input(getCommitsBetweenBranchesInput) - .output(getCommitsBetweenBranchesOutput) - .query(({ input }) => - gitService().getCommitsBetweenBranches( - input.directoryPath, - input.baseBranch, - input.head, - input.limit, + getCommitsBetweenBranches: t.procedure + .input(getCommitsBetweenBranchesInput) + .output(getCommitsBetweenBranchesOutput) + .query(({ input }) => + gitService().getCommitsBetweenBranches( + input.directoryPath, + input.baseBranch, + input.head, + input.limit, + ), ), - ), - resetSoft: t.procedure - .input(resetSoftInput) - .mutation(({ input }) => - gitService().resetSoft(input.directoryPath, input.sha), - ), + resetSoft: t.procedure + .input(resetSoftInput) + .mutation(({ input }) => + gitService().resetSoft(input.directoryPath, input.sha), + ), - createPrViaGh: t.procedure - .input(createPrViaGhInput) - .output(createPrViaGhOutput) - .mutation(({ input }) => - gitService().createPrViaGh( - input.directoryPath, - input.title, - input.body, - input.draft, - input.env, + createPrViaGh: t.procedure + .input(createPrViaGhInput) + .output(createPrViaGhOutput) + .mutation(({ input }) => + gitService().createPrViaGh( + input.directoryPath, + input.title, + input.body, + input.draft, + input.env, + ), ), - ), - cloneRepository: t.procedure - .input(cloneRepositoryInput) - .output(cloneRepositoryOutput) - .mutation(({ input }) => - gitService().cloneRepository( - input.repoUrl, - input.targetPath, - input.cloneId, + cloneRepository: t.procedure + .input(cloneRepositoryInput) + .output(cloneRepositoryOutput) + .mutation(({ input }) => + gitService().cloneRepository( + input.repoUrl, + input.targetPath, + input.cloneId, + ), ), - ), - onCloneProgress: t.procedure.subscription(async function* (opts) { - for await (const data of gitService().toIterable("cloneProgress", { - signal: opts.signal, - })) { - yield data; - } + onCloneProgress: t.procedure.subscription(async function* (opts) { + for await (const data of gitService().toIterable("cloneProgress", { + signal: opts.signal, + })) { + yield data; + } + }), }), - }), - diffStats: t.router({ - getDiffStats: t.procedure - .input(diffStatsInput) - .output(diffStatsSchema) - .query(({ input }) => gitService().getDiffStats(input.directoryPath)), - }), - fs: t.router({ - listDirectory: t.procedure - .input(listDirectoryInput) - .output(listDirectoryOutput) - .query(({ input }) => fsService().listDirectory(input.dirPath)), - - listRepoFiles: t.procedure - .input(listRepoFilesInput) - .output(listRepoFilesOutput) - .query(({ input }) => - fsService().listRepoFiles(input.repoPath, input.query, input.limit), - ), - - readRepoFile: t.procedure - .input(readRepoFileInput) - .output(readRepoFileOutput) - .query(({ input }) => - fsService().readRepoFile(input.repoPath, input.filePath), - ), + diffStats: t.router({ + getDiffStats: t.procedure + .input(diffStatsInput) + .output(diffStatsSchema) + .query(({ input }) => gitService().getDiffStats(input.directoryPath)), + }), + fs: t.router({ + listDirectory: t.procedure + .input(listDirectoryInput) + .output(listDirectoryOutput) + .query(({ input }) => fsService().listDirectory(input.dirPath)), + + listRepoFiles: t.procedure + .input(listRepoFilesInput) + .output(listRepoFilesOutput) + .query(({ input }) => + fsService().listRepoFiles(input.repoPath, input.query, input.limit), + ), - readRepoFiles: t.procedure - .input(readRepoFilesInput) - .output(readRepoFilesOutput) - .query(({ input }) => - fsService().readRepoFiles(input.repoPath, input.filePaths), - ), + readRepoFile: t.procedure + .input(readRepoFileInput) + .output(readRepoFileOutput) + .query(({ input }) => + fsService().readRepoFile(input.repoPath, input.filePath), + ), - readRepoFileBounded: t.procedure - .input(readRepoFileBoundedInput) - .output(boundedReadResult) - .query(({ input }) => - fsService().readRepoFileBounded( - input.repoPath, - input.filePath, - input.maxLines, + readRepoFiles: t.procedure + .input(readRepoFilesInput) + .output(readRepoFilesOutput) + .query(({ input }) => + fsService().readRepoFiles(input.repoPath, input.filePaths), ), - ), - readRepoFilesBounded: t.procedure - .input(readRepoFilesBoundedInput) - .output(readRepoFilesBoundedOutput) - .query(({ input }) => - fsService().readRepoFilesBounded( - input.repoPath, - input.filePaths, - input.maxLines, + readRepoFileBounded: t.procedure + .input(readRepoFileBoundedInput) + .output(boundedReadResult) + .query(({ input }) => + fsService().readRepoFileBounded( + input.repoPath, + input.filePath, + input.maxLines, + ), ), - ), - readAbsoluteFile: t.procedure - .input(readAbsoluteFileInput) - .output(readRepoFileOutput) - .query(({ input }) => fsService().readAbsoluteFile(input.filePath)), + readRepoFilesBounded: t.procedure + .input(readRepoFilesBoundedInput) + .output(readRepoFilesBoundedOutput) + .query(({ input }) => + fsService().readRepoFilesBounded( + input.repoPath, + input.filePaths, + input.maxLines, + ), + ), - readFileAsBase64: t.procedure - .input(readAbsoluteFileInput) - .output(readRepoFileOutput) - .query(({ input }) => fsService().readFileAsBase64(input.filePath)), + readAbsoluteFile: t.procedure + .input(readAbsoluteFileInput) + .output(readRepoFileOutput) + .query(({ input }) => fsService().readAbsoluteFile(input.filePath)), + + readFileAsBase64: t.procedure + .input(readAbsoluteFileInput) + .output(readRepoFileOutput) + .query(({ input }) => fsService().readFileAsBase64(input.filePath)), + + writeRepoFile: t.procedure + .input(writeRepoFileInput) + .mutation(({ input }) => + fsService().writeRepoFile( + input.repoPath, + input.filePath, + input.content, + ), + ), + }), + watcher: t.router({ + resolveGitDirs: t.procedure + .input(resolveGitDirsInput) + .output(resolveGitDirsOutput) + .query(({ input }) => watcherService().resolveGitDirs(input.repoPath)), + + watch: t.procedure + .input(watchInput) + .subscription(({ input, signal }) => + watcherService().watch( + input.dirPath, + { ignore: input.ignore }, + signal, + ), + ), + }), + fileWatcher: t.router({ + watch: t.procedure + .input(watchRepoInput) + .subscription(({ input, signal }) => + watcherService().watchRepo(input.repoPath, signal), + ), + }), + localLogs: t.router({ + read: t.procedure + .input(readLocalLogsInput) + .output(readLocalLogsOutput) + .query(({ input }) => + localLogsService().readLocalLogs(input.taskRunId), + ), - writeRepoFile: t.procedure - .input(writeRepoFileInput) - .mutation(({ input }) => - fsService().writeRepoFile( - input.repoPath, - input.filePath, - input.content, + write: t.procedure + .input(writeLocalLogsInput) + .mutation(({ input }) => + localLogsService().writeLocalLogs(input.taskRunId, input.content), ), - ), - }), - watcher: t.router({ - resolveGitDirs: t.procedure - .input(resolveGitDirsInput) - .output(resolveGitDirsOutput) - .query(({ input }) => watcherService().resolveGitDirs(input.repoPath)), - - watch: t.procedure - .input(watchInput) - .subscription(({ input, signal }) => - watcherService().watch(input.dirPath, { ignore: input.ignore }, signal), - ), - }), - fileWatcher: t.router({ - watch: t.procedure - .input(watchRepoInput) - .subscription(({ input, signal }) => - watcherService().watchRepo(input.repoPath, signal), - ), - }), - localLogs: t.router({ - read: t.procedure - .input(readLocalLogsInput) - .output(readLocalLogsOutput) - .query(({ input }) => localLogsService().readLocalLogs(input.taskRunId)), - - write: t.procedure - .input(writeLocalLogsInput) - .mutation(({ input }) => - localLogsService().writeLocalLogs(input.taskRunId, input.content), - ), - seed: t.procedure - .input(seedLocalLogsInput) - .mutation(({ input }) => - localLogsService().seedLocalLogs(input.taskRunId, input.content), - ), + seed: t.procedure + .input(seedLocalLogsInput) + .mutation(({ input }) => + localLogsService().seedLocalLogs(input.taskRunId, input.content), + ), - count: t.procedure - .input(countLocalLogEntriesInput) - .output(countLocalLogEntriesOutput) - .query(({ input }) => - localLogsService().countLocalLogEntries(input.taskRunId), - ), + count: t.procedure + .input(countLocalLogEntriesInput) + .output(countLocalLogEntriesOutput) + .query(({ input }) => + localLogsService().countLocalLogEntries(input.taskRunId), + ), - delete: t.procedure - .input(deleteLocalLogCacheInput) - .mutation(({ input }) => - localLogsService().deleteLocalLogCache(input.taskRunId), - ), - }), - connectivity: t.router({ - getStatus: t.procedure - .output(connectivityStatusOutput) - .query(() => connectivityService().getStatus()), - - checkNow: t.procedure - .output(connectivityStatusOutput) - .mutation(() => connectivityService().checkNow()), - - onStatusChange: t.procedure.subscription(async function* (opts) { - for await (const status of connectivityService().statusChangeEvents( - opts.signal, - )) { - yield status; - } + delete: t.procedure + .input(deleteLocalLogCacheInput) + .mutation(({ input }) => + localLogsService().deleteLocalLogCache(input.taskRunId), + ), }), - }), - environment: t.router({ - list: t.procedure - .input(listEnvironmentsInput) - .output(environmentSchema.array()) - .query(({ input }) => - environmentService().listEnvironments(input.repoPath), - ), + connectivity: t.router({ + getStatus: t.procedure + .output(connectivityStatusOutput) + .query(() => connectivityService().getStatus()), + + checkNow: t.procedure + .output(connectivityStatusOutput) + .mutation(() => connectivityService().checkNow()), + + onStatusChange: t.procedure.subscription(async function* (opts) { + for await (const status of connectivityService().statusChangeEvents( + opts.signal, + )) { + yield status; + } + }), + }), + environment: t.router({ + list: t.procedure + .input(listEnvironmentsInput) + .output(environmentSchema.array()) + .query(({ input }) => + environmentService().listEnvironments(input.repoPath), + ), - get: t.procedure - .input(getEnvironmentInput) - .output(environmentSchema.nullable()) - .query(({ input }) => - environmentService().getEnvironment(input.repoPath, input.id), - ), + get: t.procedure + .input(getEnvironmentInput) + .output(environmentSchema.nullable()) + .query(({ input }) => + environmentService().getEnvironment(input.repoPath, input.id), + ), - create: t.procedure - .input(createEnvironmentInput) - .output(environmentSchema) - .mutation(({ input }) => { - const { repoPath, ...rest } = input; - return environmentService().createEnvironment(rest, repoPath); - }), + create: t.procedure + .input(createEnvironmentInput) + .output(environmentSchema) + .mutation(({ input }) => { + const { repoPath, ...rest } = input; + return environmentService().createEnvironment(rest, repoPath); + }), - update: t.procedure - .input(updateEnvironmentInput) - .output(environmentSchema) - .mutation(({ input }) => { - const { repoPath, ...rest } = input; - return environmentService().updateEnvironment(rest, repoPath); - }), + update: t.procedure + .input(updateEnvironmentInput) + .output(environmentSchema) + .mutation(({ input }) => { + const { repoPath, ...rest } = input; + return environmentService().updateEnvironment(rest, repoPath); + }), - delete: t.procedure - .input(deleteEnvironmentInput) - .mutation(({ input }) => - environmentService().deleteEnvironment(input.repoPath, input.id), - ), - }), -}); + delete: t.procedure + .input(deleteEnvironmentInput) + .mutation(({ input }) => + environmentService().deleteEnvironment(input.repoPath, input.id), + ), + }), + }); +} -export type AppRouter = typeof appRouter; +export type AppRouter = ReturnType;