diff --git a/apps/code/src/main/db/migrations/0003_fair_whiplash.sql b/apps/code/src/main/db/migrations/0003_fair_whiplash.sql new file mode 100644 index 000000000..9fa93d5e5 --- /dev/null +++ b/apps/code/src/main/db/migrations/0003_fair_whiplash.sql @@ -0,0 +1,9 @@ +CREATE TABLE `auth_sessions` ( + `id` integer PRIMARY KEY NOT NULL CHECK (`id` = 1), + `refresh_token_encrypted` text NOT NULL, + `cloud_region` text NOT NULL, + `selected_project_id` integer, + `scope_version` integer NOT NULL, + `created_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL, + `updated_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL +); diff --git a/apps/code/src/main/db/migrations/meta/_journal.json b/apps/code/src/main/db/migrations/meta/_journal.json index 3583bdfb3..ab1209f9a 100644 --- a/apps/code/src/main/db/migrations/meta/_journal.json +++ b/apps/code/src/main/db/migrations/meta/_journal.json @@ -22,6 +22,13 @@ "when": 1773335630838, "tag": "0002_massive_bishop", "breakpoints": true + }, + { + "idx": 3, + "version": "6", + "when": 1774890000000, + "tag": "0003_fair_whiplash", + "breakpoints": true } ] } diff --git a/apps/code/src/main/db/repositories/auth-session-repository.mock.ts b/apps/code/src/main/db/repositories/auth-session-repository.mock.ts new file mode 100644 index 000000000..8bf82de7c --- /dev/null +++ b/apps/code/src/main/db/repositories/auth-session-repository.mock.ts @@ -0,0 +1,42 @@ +import type { + AuthSession, + IAuthSessionRepository, + PersistAuthSessionInput, +} from "./auth-session-repository"; + +export interface MockAuthSessionRepository extends IAuthSessionRepository { + _session: AuthSession | null; +} + +export function createMockAuthSessionRepository(): MockAuthSessionRepository { + let session: AuthSession | null = null; + + const clone = (value: AuthSession | null): AuthSession | null => + value ? { ...value } : null; + + return { + get _session() { + return clone(session); + }, + set _session(value) { + session = clone(value); + }, + getCurrent: () => clone(session), + saveCurrent: (input: PersistAuthSessionInput) => { + const timestamp = new Date().toISOString(); + session = { + id: 1, + refreshTokenEncrypted: input.refreshTokenEncrypted, + cloudRegion: input.cloudRegion, + selectedProjectId: input.selectedProjectId, + scopeVersion: input.scopeVersion, + createdAt: session?.createdAt ?? timestamp, + updatedAt: timestamp, + }; + return { ...session }; + }, + clearCurrent: () => { + session = null; + }, + }; +} diff --git a/apps/code/src/main/db/repositories/auth-session-repository.ts b/apps/code/src/main/db/repositories/auth-session-repository.ts new file mode 100644 index 000000000..77abb45bd --- /dev/null +++ b/apps/code/src/main/db/repositories/auth-session-repository.ts @@ -0,0 +1,75 @@ +import type { CloudRegion } from "@shared/types/oauth"; +import { eq } from "drizzle-orm"; +import { inject, injectable } from "inversify"; +import { MAIN_TOKENS } from "../../di/tokens"; +import { authSessions } from "../schema"; +import type { DatabaseService } from "../service"; + +export type AuthSession = typeof authSessions.$inferSelect; +export type NewAuthSession = typeof authSessions.$inferInsert; + +export interface PersistAuthSessionInput { + refreshTokenEncrypted: string; + cloudRegion: CloudRegion; + selectedProjectId: number | null; + scopeVersion: number; +} + +export interface IAuthSessionRepository { + getCurrent(): AuthSession | null; + saveCurrent(input: PersistAuthSessionInput): AuthSession; + clearCurrent(): void; +} + +const CURRENT_AUTH_SESSION_ID = 1; +const byId = eq(authSessions.id, CURRENT_AUTH_SESSION_ID); +const now = () => new Date().toISOString(); + +@injectable() +export class AuthSessionRepository implements IAuthSessionRepository { + constructor( + @inject(MAIN_TOKENS.DatabaseService) + private readonly databaseService: DatabaseService, + ) {} + + private get db() { + return this.databaseService.db; + } + + getCurrent(): AuthSession | null { + return ( + this.db.select().from(authSessions).where(byId).limit(1).get() ?? null + ); + } + + saveCurrent(input: PersistAuthSessionInput): AuthSession { + const timestamp = now(); + const existing = this.getCurrent(); + + const row: NewAuthSession = { + id: CURRENT_AUTH_SESSION_ID, + refreshTokenEncrypted: input.refreshTokenEncrypted, + cloudRegion: input.cloudRegion, + selectedProjectId: input.selectedProjectId, + scopeVersion: input.scopeVersion, + createdAt: existing?.createdAt ?? timestamp, + updatedAt: timestamp, + }; + + if (existing) { + this.db.update(authSessions).set(row).where(byId).run(); + } else { + this.db.insert(authSessions).values(row).run(); + } + + const saved = this.getCurrent(); + if (!saved) { + throw new Error("Failed to persist current auth session"); + } + return saved; + } + + clearCurrent(): void { + this.db.delete(authSessions).where(byId).run(); + } +} diff --git a/apps/code/src/main/db/schema.ts b/apps/code/src/main/db/schema.ts index 677932b39..00849018a 100644 --- a/apps/code/src/main/db/schema.ts +++ b/apps/code/src/main/db/schema.ts @@ -1,5 +1,5 @@ import { sql } from "drizzle-orm"; -import { index, sqliteTable, text } from "drizzle-orm/sqlite-core"; +import { index, integer, sqliteTable, text } from "drizzle-orm/sqlite-core"; const id = () => text() @@ -76,3 +76,13 @@ export const suspensions = sqliteTable("suspensions", { createdAt: createdAt(), updatedAt: updatedAt(), }); + +export const authSessions = sqliteTable("auth_sessions", { + id: integer().primaryKey(), + refreshTokenEncrypted: text().notNull(), + cloudRegion: text({ enum: ["us", "eu", "dev"] }).notNull(), + selectedProjectId: integer(), + scopeVersion: integer().notNull(), + createdAt: createdAt(), + updatedAt: updatedAt(), +}); diff --git a/apps/code/src/main/di/container.ts b/apps/code/src/main/di/container.ts index 4f94a1ca0..87eb47635 100644 --- a/apps/code/src/main/di/container.ts +++ b/apps/code/src/main/di/container.ts @@ -2,6 +2,7 @@ import "reflect-metadata"; import { Container } from "inversify"; import { ArchiveRepository } from "../db/repositories/archive-repository"; +import { AuthSessionRepository } from "../db/repositories/auth-session-repository"; import { RepositoryRepository } from "../db/repositories/repository-repository"; import { SuspensionRepositoryImpl } from "../db/repositories/suspension-repository"; import { WorkspaceRepository } from "../db/repositories/workspace-repository"; @@ -10,6 +11,7 @@ import { DatabaseService } from "../db/service"; import { AgentService } from "../services/agent/service"; import { AppLifecycleService } from "../services/app-lifecycle/service"; import { ArchiveService } from "../services/archive/service"; +import { AuthService } from "../services/auth/service"; import { AuthProxyService } from "../services/auth-proxy/service"; import { CloudTaskService } from "../services/cloud-task/service"; import { ConnectivityService } from "../services/connectivity/service"; @@ -49,12 +51,14 @@ export const container = new Container({ }); container.bind(MAIN_TOKENS.DatabaseService).to(DatabaseService); +container.bind(MAIN_TOKENS.AuthSessionRepository).to(AuthSessionRepository); container.bind(MAIN_TOKENS.RepositoryRepository).to(RepositoryRepository); container.bind(MAIN_TOKENS.WorkspaceRepository).to(WorkspaceRepository); container.bind(MAIN_TOKENS.WorktreeRepository).to(WorktreeRepository); container.bind(MAIN_TOKENS.ArchiveRepository).to(ArchiveRepository); container.bind(MAIN_TOKENS.SuspensionRepository).to(SuspensionRepositoryImpl); container.bind(MAIN_TOKENS.AgentService).to(AgentService); +container.bind(MAIN_TOKENS.AuthService).to(AuthService); container.bind(MAIN_TOKENS.AuthProxyService).to(AuthProxyService); container.bind(MAIN_TOKENS.ArchiveService).to(ArchiveService); container.bind(MAIN_TOKENS.SuspensionService).to(SuspensionService); diff --git a/apps/code/src/main/di/tokens.ts b/apps/code/src/main/di/tokens.ts index a11400b67..ad8ee191b 100644 --- a/apps/code/src/main/di/tokens.ts +++ b/apps/code/src/main/di/tokens.ts @@ -10,6 +10,7 @@ export const MAIN_TOKENS = Object.freeze({ // Database DatabaseService: Symbol.for("Main.DatabaseService"), + AuthSessionRepository: Symbol.for("Main.AuthSessionRepository"), RepositoryRepository: Symbol.for("Main.RepositoryRepository"), WorkspaceRepository: Symbol.for("Main.WorkspaceRepository"), WorktreeRepository: Symbol.for("Main.WorktreeRepository"), @@ -18,6 +19,7 @@ export const MAIN_TOKENS = Object.freeze({ // Services AgentService: Symbol.for("Main.AgentService"), + AuthService: Symbol.for("Main.AuthService"), AuthProxyService: Symbol.for("Main.AuthProxyService"), ArchiveService: Symbol.for("Main.ArchiveService"), SuspensionService: Symbol.for("Main.SuspensionService"), diff --git a/apps/code/src/main/index.ts b/apps/code/src/main/index.ts index 09c683ef1..d881aafa7 100644 --- a/apps/code/src/main/index.ts +++ b/apps/code/src/main/index.ts @@ -11,6 +11,7 @@ import { container } from "./di/container"; import { MAIN_TOKENS } from "./di/tokens"; import { registerMcpSandboxProtocol } from "./protocols/mcp-sandbox"; import type { AppLifecycleService } from "./services/app-lifecycle/service"; +import type { AuthService } from "./services/auth/service"; import type { ExternalAppsService } from "./services/external-apps/service"; import type { NotificationService } from "./services/notification/service"; import type { OAuthService } from "./services/oauth/service"; @@ -35,15 +36,18 @@ if (!gotTheLock) { process.exit(0); } -function initializeServices(): void { +async function initializeServices(): Promise { container.get(MAIN_TOKENS.DatabaseService); container.get(MAIN_TOKENS.OAuthService); + const authService = container.get(MAIN_TOKENS.AuthService); container.get(MAIN_TOKENS.NotificationService); container.get(MAIN_TOKENS.UpdatesService); container.get(MAIN_TOKENS.TaskLinkService); container.get(MAIN_TOKENS.ExternalAppsService); container.get(MAIN_TOKENS.PosthogPluginService); + await authService.initialize(); + // Initialize workspace branch watcher for live branch rename detection const workspaceService = container.get( MAIN_TOKENS.WorkspaceService, @@ -69,7 +73,7 @@ registerDeepLinkHandlers(); // Initialize PostHog analytics initializePostHog(); -app.whenReady().then(() => { +app.whenReady().then(async () => { const commit = __BUILD_COMMIT__ ?? "dev"; const buildDate = __BUILD_DATE__ ?? "dev"; log.info( @@ -87,8 +91,9 @@ app.whenReady().then(() => { ensureClaudeConfigDir(); registerMcpSandboxProtocol(); createWindow(); - initializeServices(); + await initializeServices(); initializeDeepLinks(); + await initializeServices(); powerMonitor.on("suspend", () => { log.info("System entering sleep"); }); diff --git a/apps/code/src/main/services/auth/schemas.ts b/apps/code/src/main/services/auth/schemas.ts new file mode 100644 index 000000000..f165e6a22 --- /dev/null +++ b/apps/code/src/main/services/auth/schemas.ts @@ -0,0 +1,51 @@ +import { z } from "zod"; +import { cloudRegion, type oAuthTokenResponse } from "../oauth/schemas"; + +export const authStatusSchema = z.enum(["anonymous", "authenticated"]); +export type AuthStatus = z.infer; + +export const authStateSchema = z.object({ + status: authStatusSchema, + bootstrapComplete: z.boolean(), + cloudRegion: cloudRegion.nullable(), + projectId: z.number().nullable(), + availableProjectIds: z.array(z.number()), + availableOrgIds: z.array(z.string()), + hasCodeAccess: z.boolean().nullable(), + needsScopeReauth: z.boolean(), +}); +export type AuthState = z.infer; + +export const loginInput = z.object({ + region: cloudRegion, +}); +export type LoginInput = z.infer; + +export const loginOutput = z.object({ + state: authStateSchema, +}); +export type LoginOutput = z.infer; + +export const redeemInviteCodeInput = z.object({ + code: z.string().min(1), +}); + +export const selectProjectInput = z.object({ + projectId: z.number(), +}); + +export const validAccessTokenOutput = z.object({ + accessToken: z.string(), + apiHost: z.string(), +}); +export type ValidAccessTokenOutput = z.infer; + +export const AuthServiceEvent = { + StateChanged: "state-changed", +} as const; + +export interface AuthServiceEvents { + [AuthServiceEvent.StateChanged]: AuthState; +} + +export type AuthTokenResponse = z.infer; diff --git a/apps/code/src/main/services/auth/service.test.ts b/apps/code/src/main/services/auth/service.test.ts new file mode 100644 index 000000000..c90c932c0 --- /dev/null +++ b/apps/code/src/main/services/auth/service.test.ts @@ -0,0 +1,174 @@ +import { OAUTH_SCOPE_VERSION } from "@shared/constants/oauth"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { createMockAuthSessionRepository } from "../../db/repositories/auth-session-repository.mock"; +import { decrypt, encrypt } from "../../utils/encryption"; +import type { ConnectivityService } from "../connectivity/service"; +import type { OAuthService } from "../oauth/service"; +import { AuthService } from "./service"; + +vi.mock("../../utils/logger.js", () => ({ + logger: { + scope: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), + }, +})); + +describe("AuthService", () => { + const repository = createMockAuthSessionRepository(); + const oauthService = { + refreshToken: vi.fn(), + startFlow: vi.fn(), + startSignupFlow: vi.fn(), + } as unknown as OAuthService; + const connectivityService = { + getStatus: vi.fn(() => ({ isOnline: true })), + } as unknown as ConnectivityService; + + let service: AuthService; + + beforeEach(() => { + repository.clearCurrent(); + vi.clearAllMocks(); + service = new AuthService(repository, oauthService, connectivityService); + }); + + afterEach(async () => { + vi.unstubAllGlobals(); + await service.logout(); + }); + + it("bootstraps to anonymous when there is no stored session", async () => { + await service.initialize(); + + expect(service.getState()).toEqual({ + status: "anonymous", + bootstrapComplete: true, + cloudRegion: null, + projectId: null, + availableProjectIds: [], + availableOrgIds: [], + hasCodeAccess: null, + needsScopeReauth: false, + }); + }); + + it("requires scope reauthentication when the stored scope version is stale", async () => { + repository.saveCurrent({ + refreshTokenEncrypted: encrypt("refresh-token"), + cloudRegion: "us", + selectedProjectId: 123, + scopeVersion: OAUTH_SCOPE_VERSION - 1, + }); + + await service.initialize(); + + expect(service.getState()).toEqual({ + status: "anonymous", + bootstrapComplete: true, + cloudRegion: "us", + projectId: 123, + availableProjectIds: [], + availableOrgIds: [], + hasCodeAccess: null, + needsScopeReauth: true, + }); + }); + + it("restores an authenticated session by refreshing the stored refresh token", async () => { + repository.saveCurrent({ + refreshTokenEncrypted: encrypt("stored-refresh-token"), + cloudRegion: "us", + selectedProjectId: 42, + scopeVersion: OAUTH_SCOPE_VERSION, + }); + + vi.mocked(oauthService.refreshToken).mockResolvedValue({ + success: true, + data: { + access_token: "new-access-token", + refresh_token: "rotated-refresh-token", + expires_in: 3600, + token_type: "Bearer", + scope: "", + scoped_teams: [42, 84], + scoped_organizations: ["org-1"], + }, + }); + + vi.stubGlobal( + "fetch", + vi.fn().mockResolvedValue({ + json: vi.fn().mockResolvedValue({ has_access: true }), + }) as unknown as typeof fetch, + ); + + await service.initialize(); + + expect(service.getState()).toMatchObject({ + status: "authenticated", + bootstrapComplete: true, + cloudRegion: "us", + projectId: 42, + availableProjectIds: [42, 84], + availableOrgIds: ["org-1"], + hasCodeAccess: true, + needsScopeReauth: false, + }); + + const persisted = repository.getCurrent(); + expect(persisted).not.toBeNull(); + expect(decrypt(persisted?.refreshTokenEncrypted ?? "")).toBe( + "rotated-refresh-token", + ); + }); + + it("forces a token refresh when explicitly requested", async () => { + vi.mocked(oauthService.startFlow).mockResolvedValue({ + success: true, + data: { + access_token: "initial-access-token", + refresh_token: "initial-refresh-token", + expires_in: 3600, + token_type: "Bearer", + scope: "", + scoped_teams: [42], + scoped_organizations: ["org-1"], + }, + }); + vi.mocked(oauthService.refreshToken).mockResolvedValue({ + success: true, + data: { + access_token: "refreshed-access-token", + refresh_token: "rotated-refresh-token", + expires_in: 3600, + token_type: "Bearer", + scope: "", + scoped_teams: [42], + scoped_organizations: ["org-1"], + }, + }); + vi.stubGlobal( + "fetch", + vi.fn().mockResolvedValue({ + json: vi.fn().mockResolvedValue({ has_access: true }), + }) as unknown as typeof fetch, + ); + + await service.login("us"); + + const token = await service.refreshAccessToken(); + + expect(token.accessToken).toBe("refreshed-access-token"); + expect(oauthService.refreshToken).toHaveBeenCalledWith( + "initial-refresh-token", + "us", + ); + expect(decrypt(repository.getCurrent()?.refreshTokenEncrypted ?? "")).toBe( + "rotated-refresh-token", + ); + }); +}); diff --git a/apps/code/src/main/services/auth/service.ts b/apps/code/src/main/services/auth/service.ts new file mode 100644 index 000000000..51faa8318 --- /dev/null +++ b/apps/code/src/main/services/auth/service.ts @@ -0,0 +1,468 @@ +import { + getCloudUrlFromRegion, + OAUTH_SCOPE_VERSION, +} from "@shared/constants/oauth"; +import type { CloudRegion } from "@shared/types/oauth"; +import { inject, injectable } from "inversify"; +import type { + IAuthSessionRepository, + PersistAuthSessionInput, +} from "../../db/repositories/auth-session-repository"; +import { MAIN_TOKENS } from "../../di/tokens"; +import { decrypt, encrypt } from "../../utils/encryption"; +import { logger } from "../../utils/logger"; +import { TypedEventEmitter } from "../../utils/typed-event-emitter"; +import type { ConnectivityService } from "../connectivity/service"; +import type { OAuthService } from "../oauth/service"; +import { + AuthServiceEvent, + type AuthServiceEvents, + type AuthState, + type AuthTokenResponse, + type ValidAccessTokenOutput, +} from "./schemas"; + +const log = logger.scope("auth-service"); +const TOKEN_EXPIRY_SKEW_MS = 60_000; + +interface InMemorySession { + accessToken: string; + accessTokenExpiresAt: number; + refreshToken: string; + cloudRegion: CloudRegion; + projectId: number | null; + availableProjectIds: number[]; + availableOrgIds: string[]; +} + +interface StoredSessionInput { + refreshToken: string; + cloudRegion: CloudRegion; + selectedProjectId: number | null; +} + +interface TokenResponseOptions { + cloudRegion: CloudRegion; + selectedProjectId: number | null; +} + +@injectable() +export class AuthService extends TypedEventEmitter { + private state: AuthState = { + status: "anonymous", + bootstrapComplete: false, + cloudRegion: null, + projectId: null, + availableProjectIds: [], + availableOrgIds: [], + hasCodeAccess: null, + needsScopeReauth: false, + }; + private session: InMemorySession | null = null; + private initializePromise: Promise | null = null; + private refreshPromise: Promise | null = null; + + constructor( + @inject(MAIN_TOKENS.AuthSessionRepository) + private readonly authSessionRepository: IAuthSessionRepository, + @inject(MAIN_TOKENS.OAuthService) + private readonly oauthService: OAuthService, + @inject(MAIN_TOKENS.ConnectivityService) + private readonly connectivityService: ConnectivityService, + ) { + super(); + } + + async initialize(): Promise { + if (this.initializePromise) { + return this.initializePromise; + } + + this.initializePromise = this.doInitialize(); + return this.initializePromise; + } + + getState(): AuthState { + return { ...this.state }; + } + + async login(region: CloudRegion): Promise { + await this.authenticateWithFlow( + () => this.oauthService.startFlow(region), + region, + "OAuth flow failed", + ); + return this.getState(); + } + + async signup(region: CloudRegion): Promise { + await this.authenticateWithFlow( + () => this.oauthService.startSignupFlow(region), + region, + "Signup failed", + ); + return this.getState(); + } + + async getValidAccessToken(): Promise { + await this.initialize(); + + const session = await this.ensureValidSession(); + return { + accessToken: session.accessToken, + apiHost: getCloudUrlFromRegion(session.cloudRegion), + }; + } + + async refreshAccessToken(): Promise { + await this.initialize(); + + const session = await this.ensureValidSession(true); + return { + accessToken: session.accessToken, + apiHost: getCloudUrlFromRegion(session.cloudRegion), + }; + } + + async redeemInviteCode(code: string): Promise { + const { accessToken, apiHost } = await this.getValidAccessToken(); + const response = await fetch(`${apiHost}/api/code/invites/redeem/`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${accessToken}`, + }, + body: JSON.stringify({ code }), + }); + + const data = (await response.json().catch(() => ({}))) as { + success?: boolean; + error?: string; + }; + + if (!response.ok || !data.success) { + throw new Error(data.error || "Failed to redeem invite code"); + } + + this.updateState({ hasCodeAccess: true }); + return this.getState(); + } + + async selectProject(projectId: number): Promise { + await this.initialize(); + + const session = this.requireSession(); + + if (!session.availableProjectIds.includes(projectId)) { + throw new Error("Invalid project selection"); + } + + this.session = { + ...session, + projectId, + }; + + this.persistSession({ + refreshToken: this.session.refreshToken, + cloudRegion: this.session.cloudRegion, + selectedProjectId: projectId, + }); + + this.updateState({ projectId }); + return this.getState(); + } + + async logout(): Promise { + this.authSessionRepository.clearCurrent(); + this.session = null; + this.setAnonymousState(); + return this.getState(); + } + + private async doInitialize(): Promise { + const stored = this.authSessionRepository.getCurrent(); + + if (!stored) { + this.setAnonymousState({ bootstrapComplete: true }); + return; + } + + if (stored.scopeVersion < OAUTH_SCOPE_VERSION) { + this.session = null; + this.setAnonymousState({ + bootstrapComplete: true, + cloudRegion: stored.cloudRegion, + projectId: stored.selectedProjectId, + needsScopeReauth: true, + }); + return; + } + + const storedSession = this.getStoredSessionInput( + stored.refreshTokenEncrypted, + { + cloudRegion: stored.cloudRegion, + selectedProjectId: stored.selectedProjectId, + }, + ); + if (!storedSession) { + log.warn("Stored auth session could not be decrypted"); + this.authSessionRepository.clearCurrent(); + this.setAnonymousState({ bootstrapComplete: true }); + return; + } + + try { + await this.refreshAndSyncSession(storedSession); + } catch (error) { + log.warn("Failed to restore stored auth session", { error }); + this.session = null; + this.setAnonymousState({ + bootstrapComplete: true, + cloudRegion: storedSession.cloudRegion, + projectId: storedSession.selectedProjectId, + }); + } + } + + private async ensureValidSession( + forceRefresh = false, + ): Promise { + if ( + this.session && + !forceRefresh && + !this.isSessionExpiring(this.session) + ) { + return this.session; + } + + if (this.refreshPromise) { + return this.refreshPromise; + } + + const sessionInput = this.getSessionInputForRefresh(); + + this.refreshPromise = this.refreshSession(sessionInput).finally(() => { + this.refreshPromise = null; + }); + + const session = await this.refreshPromise; + await this.syncAuthenticatedSession(session); + return session; + } + + private getSessionInputForRefresh(): StoredSessionInput { + if (this.session) { + return { + refreshToken: this.session.refreshToken, + cloudRegion: this.session.cloudRegion, + selectedProjectId: this.session.projectId, + }; + } + + const stored = this.authSessionRepository.getCurrent(); + if (!stored) { + throw new Error("Not authenticated"); + } + + const storedSession = this.getStoredSessionInput( + stored.refreshTokenEncrypted, + { + cloudRegion: stored.cloudRegion, + selectedProjectId: stored.selectedProjectId, + }, + ); + if (!storedSession) { + throw new Error("Stored session is invalid"); + } + + return storedSession; + } + + private async refreshSession( + input: StoredSessionInput, + ): Promise { + if (!this.connectivityService.getStatus().isOnline) { + throw new Error("Offline"); + } + + const result = await this.oauthService.refreshToken( + input.refreshToken, + input.cloudRegion, + ); + + if (!result.success || !result.data) { + throw new Error(result.error || "Token refresh failed"); + } + + return this.createSessionFromTokenResponse(result.data, input); + } + + private createSessionFromTokenResponse( + tokenResponse: AuthTokenResponse, + options: TokenResponseOptions, + ): InMemorySession { + const availableProjectIds = tokenResponse.scoped_teams ?? []; + const availableOrgIds = tokenResponse.scoped_organizations ?? []; + const projectId = + options.selectedProjectId && + availableProjectIds.includes(options.selectedProjectId) + ? options.selectedProjectId + : (availableProjectIds[0] ?? null); + + const session: InMemorySession = { + accessToken: tokenResponse.access_token, + accessTokenExpiresAt: Date.now() + tokenResponse.expires_in * 1000, + refreshToken: tokenResponse.refresh_token, + cloudRegion: options.cloudRegion, + projectId, + availableProjectIds, + availableOrgIds, + }; + + return session; + } + + private async authenticateWithFlow( + runFlow: () => Promise<{ + success: boolean; + data?: AuthTokenResponse; + error?: string; + }>, + region: CloudRegion, + fallbackError: string, + ): Promise { + const result = await runFlow(); + if (!result.success || !result.data) { + throw new Error(result.error || fallbackError); + } + + const session = this.createSessionFromTokenResponse(result.data, { + cloudRegion: region, + selectedProjectId: this.state.projectId, + }); + await this.syncAuthenticatedSession(session); + } + + private async refreshAndSyncSession( + input: StoredSessionInput, + ): Promise { + const session = await this.refreshSession(input); + await this.syncAuthenticatedSession(session); + } + + private async syncAuthenticatedSession( + session: InMemorySession, + ): Promise { + this.persistSession({ + refreshToken: session.refreshToken, + cloudRegion: session.cloudRegion, + selectedProjectId: session.projectId, + }); + + this.session = session; + this.updateState({ + status: "authenticated", + bootstrapComplete: true, + cloudRegion: session.cloudRegion, + projectId: session.projectId, + availableProjectIds: session.availableProjectIds, + availableOrgIds: session.availableOrgIds, + needsScopeReauth: false, + }); + await this.updateCodeAccessFromSession(); + } + + private persistSession(input: { + refreshToken: string; + cloudRegion: CloudRegion; + selectedProjectId: number | null; + }): void { + const row: PersistAuthSessionInput = { + refreshTokenEncrypted: encrypt(input.refreshToken), + cloudRegion: input.cloudRegion, + selectedProjectId: input.selectedProjectId, + scopeVersion: OAUTH_SCOPE_VERSION, + }; + + this.authSessionRepository.saveCurrent(row); + } + + private isSessionExpiring(session: InMemorySession): boolean { + return session.accessTokenExpiresAt - Date.now() <= TOKEN_EXPIRY_SKEW_MS; + } + + private getStoredSessionInput( + refreshTokenEncrypted: string, + options: Omit, + ): StoredSessionInput | null { + const refreshToken = decrypt(refreshTokenEncrypted); + if (!refreshToken) { + return null; + } + + return { + refreshToken, + ...options, + }; + } + + private requireSession(): InMemorySession { + if (!this.session) { + throw new Error("Not authenticated"); + } + return this.session; + } + + private setAnonymousState( + partial: Pick< + Partial, + "bootstrapComplete" | "cloudRegion" | "projectId" | "needsScopeReauth" + > = {}, + ): void { + this.updateState({ + status: "anonymous", + bootstrapComplete: partial.bootstrapComplete ?? true, + cloudRegion: partial.cloudRegion ?? null, + projectId: partial.projectId ?? null, + availableProjectIds: [], + availableOrgIds: [], + hasCodeAccess: null, + needsScopeReauth: partial.needsScopeReauth ?? false, + }); + } + + private async updateCodeAccessFromSession(): Promise { + if (!this.session) { + this.updateState({ hasCodeAccess: null }); + return; + } + + try { + const response = await fetch( + `${getCloudUrlFromRegion(this.session.cloudRegion)}/api/code/invites/check-access/`, + { + headers: { + Authorization: `Bearer ${this.session.accessToken}`, + }, + }, + ); + const data = (await response.json().catch(() => ({}))) as { + has_access?: boolean; + }; + + this.updateState({ hasCodeAccess: data.has_access === true }); + } catch (error) { + log.warn("Failed to update code access state", { error }); + this.updateState({ hasCodeAccess: false }); + } + } + + private updateState(partial: Partial): void { + this.state = { + ...this.state, + ...partial, + }; + this.emit(AuthServiceEvent.StateChanged, this.getState()); + } +} diff --git a/apps/code/src/main/trpc/router.ts b/apps/code/src/main/trpc/router.ts index 73a960569..75ae41d5f 100644 --- a/apps/code/src/main/trpc/router.ts +++ b/apps/code/src/main/trpc/router.ts @@ -1,6 +1,7 @@ import { agentRouter } from "./routers/agent"; import { analyticsRouter } from "./routers/analytics"; import { archiveRouter } from "./routers/archive"; +import { authRouter } from "./routers/auth"; import { cloudTaskRouter } from "./routers/cloud-task"; import { connectivityRouter } from "./routers/connectivity"; import { contextMenuRouter } from "./routers/context-menu"; @@ -38,6 +39,7 @@ export const trpcRouter = router({ agent: agentRouter, analytics: analyticsRouter, archive: archiveRouter, + auth: authRouter, cloudTask: cloudTaskRouter, connectivity: connectivityRouter, contextMenu: contextMenuRouter, diff --git a/apps/code/src/main/trpc/routers/auth.ts b/apps/code/src/main/trpc/routers/auth.ts new file mode 100644 index 000000000..161d07114 --- /dev/null +++ b/apps/code/src/main/trpc/routers/auth.ts @@ -0,0 +1,67 @@ +import { container } from "../../di/container"; +import { MAIN_TOKENS } from "../../di/tokens"; +import { + AuthServiceEvent, + authStateSchema, + loginInput, + loginOutput, + redeemInviteCodeInput, + selectProjectInput, + validAccessTokenOutput, +} from "../../services/auth/schemas"; +import type { AuthService } from "../../services/auth/service"; +import { publicProcedure, router } from "../trpc"; + +const getService = () => container.get(MAIN_TOKENS.AuthService); + +export const authRouter = router({ + getState: publicProcedure.output(authStateSchema).query(() => { + return getService().getState(); + }), + + onStateChanged: publicProcedure.subscription(async function* (opts) { + const service = getService(); + const iterable = service.toIterable(AuthServiceEvent.StateChanged, { + signal: opts.signal, + }); + for await (const state of iterable) { + yield state; + } + }), + + login: publicProcedure + .input(loginInput) + .output(loginOutput) + .mutation(async ({ input }) => ({ + state: await getService().login(input.region), + })), + + signup: publicProcedure + .input(loginInput) + .output(loginOutput) + .mutation(async ({ input }) => ({ + state: await getService().signup(input.region), + })), + + getValidAccessToken: publicProcedure + .output(validAccessTokenOutput) + .query(async () => getService().getValidAccessToken()), + + refreshAccessToken: publicProcedure + .output(validAccessTokenOutput) + .mutation(async () => getService().refreshAccessToken()), + + selectProject: publicProcedure + .input(selectProjectInput) + .output(authStateSchema) + .mutation(async ({ input }) => getService().selectProject(input.projectId)), + + redeemInviteCode: publicProcedure + .input(redeemInviteCodeInput) + .output(authStateSchema) + .mutation(async ({ input }) => getService().redeemInviteCode(input.code)), + + logout: publicProcedure.output(authStateSchema).mutation(async () => { + return getService().logout(); + }), +}); diff --git a/apps/code/src/renderer/App.tsx b/apps/code/src/renderer/App.tsx index 7cea1b131..3ea3e1d85 100644 --- a/apps/code/src/renderer/App.tsx +++ b/apps/code/src/renderer/App.tsx @@ -114,18 +114,13 @@ function App() { }), ); - // Wait for authStore to hydrate, then restore session from stored tokens + // Initialize auth state from main process useEffect(() => { const initialize = async () => { - if (!useAuthStore.persist.hasHydrated()) { - await new Promise((resolve) => { - useAuthStore.persist.onFinishHydration(() => resolve()); - }); - } await useAuthStore.getState().initializeOAuth(); setIsLoading(false); }; - initialize(); + void initialize(); }, []); // Handle transition into main app — only show the dark overlay if dark mode is active diff --git a/apps/code/src/renderer/api/fetcher.test.ts b/apps/code/src/renderer/api/fetcher.test.ts index 0a7b4e740..e3890a7ae 100644 --- a/apps/code/src/renderer/api/fetcher.test.ts +++ b/apps/code/src/renderer/api/fetcher.test.ts @@ -24,79 +24,77 @@ describe("buildApiFetcher", () => { vi.stubGlobal("fetch", mockFetch); }); - it("makes request with bearer token", async () => { + it("makes request with a token fetched from the provider", async () => { + const getAccessToken = vi.fn().mockResolvedValue("my-token"); + const refreshAccessToken = vi.fn().mockResolvedValue("new-token"); mockFetch.mockResolvedValueOnce(ok()); - const fetcher = buildApiFetcher({ apiToken: "my-token" }); + const fetcher = buildApiFetcher({ getAccessToken, refreshAccessToken }); await fetcher.fetch(mockInput); + expect(getAccessToken).toHaveBeenCalledTimes(1); + expect(refreshAccessToken).not.toHaveBeenCalled(); expect(mockFetch.mock.calls[0][1].headers.get("Authorization")).toBe( "Bearer my-token", ); }); - it("retries with new token on 401", async () => { - const onTokenRefresh = vi.fn().mockResolvedValue("new-token"); + it("retries once with a freshly fetched token on 401", async () => { + const getAccessToken = vi.fn().mockResolvedValue("old-token"); + const refreshAccessToken = vi.fn().mockResolvedValue("new-token"); mockFetch.mockResolvedValueOnce(err(401)).mockResolvedValueOnce(ok()); - const fetcher = buildApiFetcher({ apiToken: "old-token", onTokenRefresh }); + const fetcher = buildApiFetcher({ getAccessToken, refreshAccessToken }); const response = await fetcher.fetch(mockInput); expect(response.ok).toBe(true); - expect(onTokenRefresh).toHaveBeenCalledTimes(1); + expect(getAccessToken).toHaveBeenCalledTimes(1); + expect(refreshAccessToken).toHaveBeenCalledTimes(1); expect(mockFetch.mock.calls[1][1].headers.get("Authorization")).toBe( "Bearer new-token", ); }); - it("uses refreshed token for subsequent requests", async () => { - const onTokenRefresh = vi.fn().mockResolvedValue("refreshed-token"); - mockFetch - .mockResolvedValueOnce(err(401)) - .mockResolvedValueOnce(ok()) - .mockResolvedValueOnce(ok()); - - const fetcher = buildApiFetcher({ - apiToken: "initial-token", - onTokenRefresh, - }); - await fetcher.fetch(mockInput); - await fetcher.fetch(mockInput); - - expect(mockFetch.mock.calls[2][1].headers.get("Authorization")).toBe( - "Bearer refreshed-token", - ); - }); - - it("does not refresh on non-401 errors", async () => { - const onTokenRefresh = vi.fn(); + it("does not retry on non-401 errors", async () => { + const getAccessToken = vi.fn().mockResolvedValue("token"); + const refreshAccessToken = vi.fn().mockResolvedValue("new-token"); mockFetch.mockResolvedValueOnce(err(403)); - const fetcher = buildApiFetcher({ apiToken: "token", onTokenRefresh }); + const fetcher = buildApiFetcher({ getAccessToken, refreshAccessToken }); await expect(fetcher.fetch(mockInput)).rejects.toThrow("[403]"); - expect(onTokenRefresh).not.toHaveBeenCalled(); + expect(getAccessToken).toHaveBeenCalledTimes(1); + expect(refreshAccessToken).not.toHaveBeenCalled(); }); - it("throws on 401 without refresh callback", async () => { - mockFetch.mockResolvedValueOnce(err(401)); - const fetcher = buildApiFetcher({ apiToken: "token" }); + it("throws when the retry still returns 401", async () => { + const getAccessToken = vi.fn().mockResolvedValue("token-1"); + const refreshAccessToken = vi.fn().mockResolvedValue("token-2"); + mockFetch.mockResolvedValueOnce(err(401)).mockResolvedValueOnce(err(401)); + + const fetcher = buildApiFetcher({ getAccessToken, refreshAccessToken }); await expect(fetcher.fetch(mockInput)).rejects.toThrow("[401]"); }); - it("throws when refresh fails", async () => { - const onTokenRefresh = vi.fn().mockRejectedValue(new Error("failed")); + it("throws when refetching a token fails during retry", async () => { + const getAccessToken = vi.fn().mockResolvedValue("token"); + const refreshAccessToken = vi + .fn() + .mockRejectedValueOnce(new Error("failed")); mockFetch.mockResolvedValueOnce(err(401)); - const fetcher = buildApiFetcher({ apiToken: "token", onTokenRefresh }); + const fetcher = buildApiFetcher({ getAccessToken, refreshAccessToken }); await expect(fetcher.fetch(mockInput)).rejects.toThrow("[401]"); }); it("handles network errors", async () => { mockFetch.mockRejectedValueOnce(new Error("Network failure")); - const fetcher = buildApiFetcher({ apiToken: "token" }); + const fetcher = buildApiFetcher({ + getAccessToken: vi.fn().mockResolvedValue("token"), + refreshAccessToken: vi.fn().mockResolvedValue("new-token"), + }); await expect(fetcher.fetch(mockInput)).rejects.toThrow( "Network request failed", diff --git a/apps/code/src/renderer/api/fetcher.ts b/apps/code/src/renderer/api/fetcher.ts index 14a69051b..cccefd7eb 100644 --- a/apps/code/src/renderer/api/fetcher.ts +++ b/apps/code/src/renderer/api/fetcher.ts @@ -3,11 +3,9 @@ import type { createApiClient } from "./generated"; const USER_AGENT = `posthog/desktop.hog.dev; version: ${typeof __APP_VERSION__ !== "undefined" ? __APP_VERSION__ : "unknown"}`; export const buildApiFetcher: (config: { - apiToken: string; - onTokenRefresh?: () => Promise; + getAccessToken: () => Promise; + refreshAccessToken: () => Promise; }) => Parameters[0] = (config) => { - let currentToken = config.apiToken; - const makeRequest = async ( input: Parameters[0]["fetch"]>[0], token: string, @@ -56,16 +54,16 @@ export const buildApiFetcher: (config: { return { fetch: async (input) => { - let response = await makeRequest(input, currentToken); + let response = await makeRequest(input, await config.getAccessToken()); - // Handle 401 with automatic token refresh - if (!response.ok && response.status === 401 && config.onTokenRefresh) { + // Retry once on 401 after asking main for a fresh valid token again. + if (!response.ok && response.status === 401) { try { - const newToken = await config.onTokenRefresh(); - currentToken = newToken; - response = await makeRequest(input, currentToken); + response = await makeRequest( + input, + await config.refreshAccessToken(), + ); } catch { - // Token refresh failed - throw the original 401 error const errorResponse = await response.json(); throw new Error( `Failed request: [${response.status}] ${JSON.stringify(errorResponse)}`, diff --git a/apps/code/src/renderer/api/posthogClient.ts b/apps/code/src/renderer/api/posthogClient.ts index c870611b6..ca863b679 100644 --- a/apps/code/src/renderer/api/posthogClient.ts +++ b/apps/code/src/renderer/api/posthogClient.ts @@ -138,16 +138,16 @@ export class PostHogAPIClient { private _teamId: number | null = null; constructor( - accessToken: string, apiHost: string, - onTokenRefresh?: () => Promise, + getAccessToken: () => Promise, + refreshAccessToken: () => Promise, teamId?: number, ) { const baseUrl = apiHost.endsWith("/") ? apiHost.slice(0, -1) : apiHost; this.api = createApiClient( buildApiFetcher({ - apiToken: accessToken, - onTokenRefresh, + getAccessToken, + refreshAccessToken, }), baseUrl, ); diff --git a/apps/code/src/renderer/components/ScopeReauthPrompt.test.tsx b/apps/code/src/renderer/components/ScopeReauthPrompt.test.tsx index 3e8706807..8dfe8a4b1 100644 --- a/apps/code/src/renderer/components/ScopeReauthPrompt.test.tsx +++ b/apps/code/src/renderer/components/ScopeReauthPrompt.test.tsx @@ -4,15 +4,50 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; vi.mock("@renderer/trpc/client", () => ({ trpcClient: { - secureStore: { - getItem: { query: vi.fn() }, - setItem: { query: vi.fn() }, - removeItem: { query: vi.fn() }, - }, - oauth: { - refreshToken: { mutate: vi.fn() }, - startFlow: { mutate: vi.fn() }, - startSignupFlow: { mutate: vi.fn() }, + auth: { + getState: { query: vi.fn() }, + onStateChanged: { subscribe: vi.fn(() => ({ unsubscribe: vi.fn() })) }, + getValidAccessToken: { + query: vi.fn().mockResolvedValue({ + accessToken: "token", + apiHost: "https://us.posthog.com", + }), + }, + refreshAccessToken: { + mutate: vi.fn().mockResolvedValue({ + accessToken: "token", + apiHost: "https://us.posthog.com", + }), + }, + login: { + mutate: vi.fn().mockResolvedValue({ + state: { + status: "authenticated", + bootstrapComplete: true, + cloudRegion: "us", + projectId: 1, + availableProjectIds: [1], + availableOrgIds: [], + hasCodeAccess: true, + needsScopeReauth: false, + }, + }), + }, + signup: { mutate: vi.fn() }, + selectProject: { mutate: vi.fn() }, + redeemInviteCode: { mutate: vi.fn() }, + logout: { + mutate: vi.fn().mockResolvedValue({ + status: "anonymous", + bootstrapComplete: true, + cloudRegion: null, + projectId: null, + availableProjectIds: [], + availableOrgIds: [], + hasCodeAccess: null, + needsScopeReauth: false, + }), + }, }, agent: { updateToken: { mutate: vi.fn().mockResolvedValue(undefined) }, @@ -58,7 +93,10 @@ vi.mock("@stores/navigationStore", () => ({ }, })); -import { useAuthStore } from "@features/auth/stores/authStore"; +import { + resetAuthStoreModuleStateForTest, + useAuthStore, +} from "@features/auth/stores/authStore"; import { Theme } from "@radix-ui/themes"; import type { ReactElement } from "react"; import { ScopeReauthPrompt } from "./ScopeReauthPrompt"; @@ -70,6 +108,7 @@ function renderWithTheme(ui: ReactElement) { describe("ScopeReauthPrompt", () => { beforeEach(() => { localStorage.clear(); + resetAuthStoreModuleStateForTest(); useAuthStore.setState({ needsScopeReauth: false, cloudRegion: null, diff --git a/apps/code/src/renderer/features/auth/components/AuthScreen.tsx b/apps/code/src/renderer/features/auth/components/AuthScreen.tsx index 93d58e234..51e06bd49 100644 --- a/apps/code/src/renderer/features/auth/components/AuthScreen.tsx +++ b/apps/code/src/renderer/features/auth/components/AuthScreen.tsx @@ -42,7 +42,7 @@ export const getErrorMessage = (error: unknown) => { type AuthMode = "login" | "signup"; export function AuthScreen() { - const staleRegion = useAuthStore((s) => s.staleTokens?.cloudRegion); + const staleRegion = useAuthStore((s) => s.staleCloudRegion); const [region, setRegion] = useState(staleRegion ?? "us"); const [authMode, setAuthMode] = useState("login"); const { loginWithOAuth, signupWithOAuth } = useAuthStore(); diff --git a/apps/code/src/renderer/features/auth/stores/authStore.test.ts b/apps/code/src/renderer/features/auth/stores/authStore.test.ts index b03f2c19d..0c1fd3fb6 100644 --- a/apps/code/src/renderer/features/auth/stores/authStore.test.ts +++ b/apps/code/src/renderer/features/auth/stores/authStore.test.ts @@ -1,27 +1,28 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; +const mockGetState = vi.hoisted(() => ({ query: vi.fn() })); +const mockOnStateChangedSubscribe = vi.hoisted(() => vi.fn()); +const mockGetValidAccessToken = vi.hoisted(() => ({ query: vi.fn() })); +const mockRefreshAccessToken = vi.hoisted(() => ({ mutate: vi.fn() })); +const mockLogin = vi.hoisted(() => ({ mutate: vi.fn() })); +const mockSignup = vi.hoisted(() => ({ mutate: vi.fn() })); +const mockSelectProject = vi.hoisted(() => ({ mutate: vi.fn() })); +const mockRedeemInviteCode = vi.hoisted(() => ({ mutate: vi.fn() })); +const mockLogout = vi.hoisted(() => ({ mutate: vi.fn() })); const mockGetCurrentUser = vi.fn(); -const { getItem, setItem } = vi.hoisted(() => ({ - getItem: vi.fn(), - setItem: vi.fn(), -})); - -const mockRefreshToken = vi.hoisted(() => ({ mutate: vi.fn() })); -const mockStartFlow = vi.hoisted(() => ({ mutate: vi.fn() })); -const mockStartSignupFlow = vi.hoisted(() => ({ mutate: vi.fn() })); - vi.mock("@renderer/trpc/client", () => ({ trpcClient: { - secureStore: { - getItem: { query: getItem }, - setItem: { query: setItem }, - removeItem: { query: vi.fn() }, - }, - oauth: { - refreshToken: mockRefreshToken, - startFlow: mockStartFlow, - startSignupFlow: mockStartSignupFlow, + auth: { + getState: mockGetState, + onStateChanged: { subscribe: mockOnStateChangedSubscribe }, + getValidAccessToken: mockGetValidAccessToken, + refreshAccessToken: mockRefreshAccessToken, + login: mockLogin, + signup: mockSignup, + selectProject: mockSelectProject, + redeemInviteCode: mockRedeemInviteCode, + logout: mockLogout, }, agent: { updateToken: { mutate: vi.fn().mockResolvedValue(undefined) }, @@ -36,13 +37,19 @@ vi.mock("@renderer/trpc/client", () => ({ }, })); +vi.mock("@renderer/api/posthogClient", () => ({ + PostHogAPIClient: vi.fn().mockImplementation(function ( + this: Record, + ) { + this.getCurrentUser = mockGetCurrentUser; + this.setTeamId = vi.fn(); + }), +})); + vi.mock("@utils/analytics", () => ({ identifyUser: vi.fn(), resetUser: vi.fn(), track: vi.fn(), - isFeatureFlagEnabled: vi.fn().mockReturnValue(false), - onFeatureFlagsLoaded: vi.fn(), - reloadFeatureFlags: vi.fn(), })); vi.mock("@utils/logger", () => ({ @@ -64,57 +71,58 @@ vi.mock("@utils/queryClient", () => ({ }, })); -vi.mock("@renderer/api/posthogClient", () => ({ - PostHogAPIClient: vi.fn().mockImplementation(function ( - this: Record, - ) { - this.getCurrentUser = mockGetCurrentUser; - this.setTeamId = vi.fn(); - }), -})); - vi.mock("@stores/navigationStore", () => ({ useNavigationStore: { getState: () => ({ navigateToTaskInput: vi.fn() }), }, })); -import { OAUTH_SCOPE_VERSION } from "@shared/constants/oauth"; -import { useAuthStore } from "./authStore"; - -function makeStoredTokens(overrides: Record = {}) { - return { - accessToken: "test-access-token", - refreshToken: "test-refresh-token", - expiresAt: Date.now() + 3600 * 1000, - cloudRegion: "us" as const, - scopedTeams: [1], - scopeVersion: OAUTH_SCOPE_VERSION, - ...overrides, - }; -} - -const mockUser = { - distinct_id: "user-123", - email: "test@example.com", - uuid: "uuid-123", - team: { id: 1 }, +import { resetAuthStoreModuleStateForTest, useAuthStore } from "./authStore"; + +const authenticatedState = { + status: "authenticated" as const, + bootstrapComplete: true, + cloudRegion: "us" as const, + projectId: 1, + availableProjectIds: [1, 2], + availableOrgIds: ["org-1"], + hasCodeAccess: true, + needsScopeReauth: false, }; -describe("authStore - scope version", () => { +describe("authStore", () => { beforeEach(() => { vi.clearAllMocks(); - getItem.mockResolvedValue(null); - setItem.mockResolvedValue(undefined); - mockGetCurrentUser.mockResolvedValue(mockUser); + resetAuthStoreModuleStateForTest(); + mockGetCurrentUser.mockResolvedValue({ + distinct_id: "user-123", + email: "test@example.com", + uuid: "uuid-123", + }); + mockGetValidAccessToken.query.mockResolvedValue({ + accessToken: "test-access-token", + apiHost: "https://us.posthog.com", + }); + mockRefreshAccessToken.mutate.mockResolvedValue({ + accessToken: "fresh-access-token", + apiHost: "https://us.posthog.com", + }); + mockGetState.query.mockResolvedValue({ + status: "anonymous", + bootstrapComplete: true, + cloudRegion: null, + projectId: null, + availableProjectIds: [], + availableOrgIds: [], + hasCodeAccess: null, + needsScopeReauth: false, + }); + mockOnStateChangedSubscribe.mockReturnValue({ unsubscribe: vi.fn() }); useAuthStore.setState({ oauthAccessToken: null, - oauthRefreshToken: null, - tokenExpiry: null, cloudRegion: null, - storedTokens: null, - staleTokens: null, + staleCloudRegion: null, isAuthenticated: false, client: null, projectId: null, @@ -122,135 +130,42 @@ describe("authStore - scope version", () => { availableOrgIds: [], needsProjectSelection: false, needsScopeReauth: false, + hasCodeAccess: null, + hasCompletedOnboarding: false, + selectedPlan: null, + selectedOrgId: null, }); }); - describe("initializeOAuth", () => { - async function initializeWithTokens( - tokenOverrides: Record, - ) { - const tokens = makeStoredTokens(tokenOverrides); - useAuthStore.setState({ storedTokens: tokens }); - // Ensure hasHydrated returns true - await useAuthStore.persist.rehydrate(); - return useAuthStore.getState().initializeOAuth(); - } - - it("sets needsScopeReauth when scopeVersion is missing (treated as 0)", async () => { - const result = await initializeWithTokens({ scopeVersion: undefined }); - - expect(result).toBe(true); - expect(useAuthStore.getState().needsScopeReauth).toBe(true); - expect(useAuthStore.getState().isAuthenticated).toBe(true); - expect(useAuthStore.getState().cloudRegion).toBe("us"); - expect(useAuthStore.getState().storedTokens).not.toBeNull(); - // Should NOT create a client or call getCurrentUser — early return avoids - // racing with loginWithOAuth when the user clicks Sign In. - expect(mockGetCurrentUser).not.toHaveBeenCalled(); - expect(useAuthStore.getState().client).toBeNull(); - }); - - it("sets needsScopeReauth when scopeVersion is less than OAUTH_SCOPE_VERSION", async () => { - const result = await initializeWithTokens({ - scopeVersion: OAUTH_SCOPE_VERSION - 1, - }); + it("initializes from main auth state", async () => { + mockGetState.query.mockResolvedValue(authenticatedState); - expect(result).toBe(true); - expect(useAuthStore.getState().needsScopeReauth).toBe(true); - expect(useAuthStore.getState().isAuthenticated).toBe(true); - expect(useAuthStore.getState().cloudRegion).toBe("us"); - expect(useAuthStore.getState().storedTokens).not.toBeNull(); - expect(mockGetCurrentUser).not.toHaveBeenCalled(); - expect(useAuthStore.getState().client).toBeNull(); - }); - - it("does not set needsScopeReauth when scopeVersion matches", async () => { - const result = await initializeWithTokens({ - scopeVersion: OAUTH_SCOPE_VERSION, - }); + const result = await useAuthStore.getState().initializeOAuth(); - expect(result).toBe(true); - expect(useAuthStore.getState().needsScopeReauth).toBe(false); - expect(useAuthStore.getState().isAuthenticated).toBe(true); - expect(useAuthStore.getState().storedTokens).not.toBeNull(); - }); + expect(result).toBe(true); + expect(useAuthStore.getState().isAuthenticated).toBe(true); + expect(useAuthStore.getState().projectId).toBe(1); + expect(useAuthStore.getState().oauthAccessToken).toBe("test-access-token"); }); - describe("loginWithOAuth", () => { - it("clears needsScopeReauth after successful login", async () => { - useAuthStore.setState({ needsScopeReauth: true }); + it("logs in through the main auth service", async () => { + mockLogin.mutate.mockResolvedValue({ state: authenticatedState }); + mockGetState.query.mockResolvedValue(authenticatedState); - mockStartFlow.mutate.mockResolvedValue({ - success: true, - data: { - access_token: "new-access-token", - refresh_token: "new-refresh-token", - expires_in: 3600, - scoped_teams: [1], - scoped_organizations: ["org-1"], - }, - }); + await useAuthStore.getState().loginWithOAuth("us"); - await useAuthStore.getState().loginWithOAuth("us"); - - expect(useAuthStore.getState().needsScopeReauth).toBe(false); - expect(useAuthStore.getState().isAuthenticated).toBe(true); - }); + expect(mockLogin.mutate).toHaveBeenCalledWith({ region: "us" }); + expect(useAuthStore.getState().isAuthenticated).toBe(true); + expect(useAuthStore.getState().needsScopeReauth).toBe(false); }); - describe("refreshAccessToken", () => { - it("preserves existing scopeVersion on refreshed tokens", async () => { - const staleVersion = OAUTH_SCOPE_VERSION - 1; - useAuthStore.setState({ - oauthAccessToken: "old-token", - oauthRefreshToken: "old-refresh-token", - cloudRegion: "us", - storedTokens: makeStoredTokens({ scopeVersion: staleVersion }), - isAuthenticated: true, - }); - - mockRefreshToken.mutate.mockResolvedValue({ - success: true, - data: { - access_token: "new-access-token", - refresh_token: "new-refresh-token", - expires_in: 3600, - scoped_teams: [1], - }, - }); + it("deduplicates expensive renderer auth sync for repeated auth-state events", async () => { + mockGetState.query.mockResolvedValue(authenticatedState); - await useAuthStore.getState().refreshAccessToken(); + await useAuthStore.getState().initializeOAuth(); + await useAuthStore.getState().checkCodeAccess(); - const tokens = useAuthStore.getState().storedTokens; - expect(tokens).not.toBeNull(); - expect(tokens?.scopeVersion).toBe(staleVersion); - expect(tokens?.accessToken).toBe("new-access-token"); - }); - - it("defaults scopeVersion to 0 when storedTokens is null", async () => { - useAuthStore.setState({ - oauthAccessToken: "old-token", - oauthRefreshToken: "old-refresh-token", - cloudRegion: "us", - storedTokens: null, - isAuthenticated: true, - }); - - mockRefreshToken.mutate.mockResolvedValue({ - success: true, - data: { - access_token: "new-access-token", - refresh_token: "new-refresh-token", - expires_in: 3600, - scoped_teams: [1], - }, - }); - - await useAuthStore.getState().refreshAccessToken(); - - const tokens = useAuthStore.getState().storedTokens; - expect(tokens).not.toBeNull(); - expect(tokens?.scopeVersion).toBe(0); - }); + expect(mockGetCurrentUser).toHaveBeenCalledTimes(1); + expect(mockGetValidAccessToken.query).toHaveBeenCalledTimes(1); }); }); diff --git a/apps/code/src/renderer/features/auth/stores/authStore.ts b/apps/code/src/renderer/features/auth/stores/authStore.ts index 6c7de067f..167f02262 100644 --- a/apps/code/src/renderer/features/auth/stores/authStore.ts +++ b/apps/code/src/renderer/features/auth/stores/authStore.ts @@ -1,927 +1,330 @@ import { PostHogAPIClient } from "@renderer/api/posthogClient"; import { trpcClient } from "@renderer/trpc/client"; -import { - getCloudUrlFromRegion, - OAUTH_SCOPE_VERSION, - OAUTH_SCOPES, - TOKEN_REFRESH_BUFFER_MS, -} from "@shared/constants/oauth"; +import { getCloudUrlFromRegion } from "@shared/constants/oauth"; import { ANALYTICS_EVENTS } from "@shared/types/analytics"; import type { CloudRegion } from "@shared/types/oauth"; -import { sleepWithBackoff } from "@shared/utils/backoff"; import { useNavigationStore } from "@stores/navigationStore"; -import { - identifyUser, - isFeatureFlagEnabled, - reloadFeatureFlags, - resetUser, - track, -} from "@utils/analytics"; -import { electronStorage } from "@utils/electronStorage"; +import { identifyUser, resetUser, track } from "@utils/analytics"; import { logger } from "@utils/logger"; import { queryClient } from "@utils/queryClient"; import { create } from "zustand"; -import { persist, subscribeWithSelector } from "zustand/middleware"; const log = logger.scope("auth-store"); -let refreshPromise: Promise | null = null; let initializePromise: Promise | null = null; - +let authStateSubscription: { unsubscribe: () => void } | null = null; let sessionResetCallback: (() => void) | null = null; +let inFlightAuthSync: Promise | null = null; +let inFlightAuthSyncKey: string | null = null; +let lastCompletedAuthSyncKey: string | null = null; export function setSessionResetCallback(callback: () => void) { sessionResetCallback = callback; } -const REFRESH_MAX_RETRIES = 3; -const REFRESH_INITIAL_DELAY_MS = 1000; - -function updateServiceTokens(token: string): void { - trpcClient.agent.updateToken - .mutate({ token }) - .catch((err) => log.warn("Failed to update agent token", err)); - trpcClient.cloudTask.updateToken - .mutate({ token }) - .catch((err) => log.warn("Failed to update cloud task token", err)); -} - -interface StoredTokens { - accessToken: string; - refreshToken: string; - expiresAt: number; - cloudRegion: CloudRegion; - scopedTeams?: number[]; - scopeVersion?: number; +export function resetAuthStoreModuleStateForTest(): void { + initializePromise = null; + authStateSubscription = null; + sessionResetCallback = null; + inFlightAuthSync = null; + inFlightAuthSyncKey = null; + lastCompletedAuthSyncKey = null; } -interface AuthState { - // OAuth state +interface AuthStoreState { oauthAccessToken: string | null; - oauthRefreshToken: string | null; - tokenExpiry: number | null; // Unix timestamp in milliseconds cloudRegion: CloudRegion | null; - storedTokens: StoredTokens | null; - staleTokens: StoredTokens | null; + staleCloudRegion: CloudRegion | null; - // PostHog client isAuthenticated: boolean; client: PostHogAPIClient | null; - projectId: number | null; // Current team/project ID - - // Multi-project state - availableProjectIds: number[]; // All projects from scoped_teams - availableOrgIds: string[]; // All orgs from scoped_organizations - needsProjectSelection: boolean; // True when multiple projects and no selection stored - - needsScopeReauth: boolean; // True when stored token scope version is stale + projectId: number | null; + availableProjectIds: number[]; + availableOrgIds: string[]; + needsProjectSelection: boolean; + needsScopeReauth: boolean; + hasCodeAccess: boolean | null; - // Access gate state - hasCodeAccess: boolean | null; // null = not yet checked - - // Onboarding state hasCompletedOnboarding: boolean; selectedPlan: "free" | "pro" | null; selectedOrgId: string | null; - // Access gate methods - checkCodeAccess: () => void; + checkCodeAccess: () => Promise; redeemInviteCode: (code: string) => Promise; - - // OAuth methods loginWithOAuth: (region: CloudRegion) => Promise; + signupWithOAuth: (region: CloudRegion) => Promise; refreshAccessToken: () => Promise; - scheduleTokenRefresh: () => void; initializeOAuth: () => Promise; - - // Signup method - signupWithOAuth: (region: CloudRegion) => Promise; - - // Project selection - selectProject: (projectId: number) => void; - - // Onboarding methods + selectProject: (projectId: number) => Promise; completeOnboarding: () => void; selectPlan: (plan: "free" | "pro") => void; selectOrg: (orgId: string) => void; - - // Other methods - logout: () => void; + logout: () => Promise; } -let refreshTimeoutId: number | null = null; - -export const useAuthStore = create()( - subscribeWithSelector( - persist( - (set, get) => ({ - // OAuth state - oauthAccessToken: null, - oauthRefreshToken: null, - tokenExpiry: null, - cloudRegion: null, - storedTokens: null, - staleTokens: null, - - // PostHog client - isAuthenticated: false, - client: null, - projectId: null, - - // Multi-project state - availableProjectIds: [], - availableOrgIds: [], - needsProjectSelection: false, - // Scope re-auth state - needsScopeReauth: false, - - // Access gate state - hasCodeAccess: null, - - // Onboarding state - hasCompletedOnboarding: false, - selectedPlan: null, - selectedOrgId: null, - - checkCodeAccess: () => { - const state = get(); - if (!state.cloudRegion || !state.oauthAccessToken) { - set({ hasCodeAccess: false }); - return; - } - - set({ hasCodeAccess: null }); - - const baseUrl = getCloudUrlFromRegion(state.cloudRegion); - fetch(`${baseUrl}/api/code/invites/check-access/`, { - headers: { - Authorization: `Bearer ${state.oauthAccessToken}`, - }, - }) - .then((res) => res.json()) - .then((data) => { - set({ hasCodeAccess: data.has_access === true }); - }) - .catch((err) => { - log.error("Failed to check code access", err); - // On network error, fall back to feature flag check - set({ hasCodeAccess: isFeatureFlagEnabled("tasks") }); - }); - }, - - redeemInviteCode: async (code: string) => { - const state = get(); - if (!state.cloudRegion || !state.oauthAccessToken) { - throw new Error("Not authenticated"); - } - - const baseUrl = getCloudUrlFromRegion(state.cloudRegion); - const response = await fetch(`${baseUrl}/api/code/invites/redeem/`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${state.oauthAccessToken}`, - }, - body: JSON.stringify({ code }), - }); - - const data = await response.json(); - - if (!response.ok || !data.success) { - throw new Error(data.error || "Failed to redeem invite code"); - } - - // Optimistically grant access — the flag will catch up on next launch - set({ hasCodeAccess: true }); - reloadFeatureFlags(); - }, - - loginWithOAuth: async (region: CloudRegion) => { - const result = await trpcClient.oauth.startFlow.mutate({ region }); - - if (!result.success || !result.data) { - throw new Error(result.error || "OAuth flow failed"); - } - - const tokenResponse = result.data; - const expiresAt = Date.now() + tokenResponse.expires_in * 1000; - - const scopedTeams = tokenResponse.scoped_teams ?? []; - const scopedOrgs = tokenResponse.scoped_organizations ?? []; - - if (scopedTeams.length === 0) { - throw new Error("No team found in OAuth scopes"); - } - - const storedTokens: StoredTokens = { - accessToken: tokenResponse.access_token, - refreshToken: tokenResponse.refresh_token, - expiresAt, - cloudRegion: region, - scopedTeams, - scopeVersion: OAUTH_SCOPE_VERSION, - }; - - const apiHost = getCloudUrlFromRegion(region); - - const client = new PostHogAPIClient( - tokenResponse.access_token, - apiHost, - async () => { - await get().refreshAccessToken(); - const token = get().oauthAccessToken; - if (!token) { - throw new Error("No access token after refresh"); - } - return token; - }, - scopedTeams[0], - ); - - try { - const user = await client.getCurrentUser(); - - // Determine project: prefer user's current PostHog project, then previously stored, then first available - const userCurrentTeam = user?.team?.id; - const storedProjectId = get().projectId; - const selectedProjectId = - userCurrentTeam != null && scopedTeams.includes(userCurrentTeam) - ? userCurrentTeam - : storedProjectId !== null && - scopedTeams.includes(storedProjectId) - ? storedProjectId - : scopedTeams[0]; - - // Update client's teamId to match selected project - client.setTeamId(selectedProjectId); - - set({ - oauthAccessToken: tokenResponse.access_token, - oauthRefreshToken: tokenResponse.refresh_token, - tokenExpiry: expiresAt, - cloudRegion: region, - storedTokens, - isAuthenticated: true, - client, - projectId: selectedProjectId, - availableProjectIds: scopedTeams, - availableOrgIds: scopedOrgs, - needsProjectSelection: false, - needsScopeReauth: false, - }); - - updateServiceTokens(tokenResponse.access_token); - - // Clear any cached data from previous sessions AFTER setting new auth - queryClient.clear(); - queryClient.setQueryData(["currentUser"], user); - - get().scheduleTokenRefresh(); - - // Track user login - use distinct_id to match web sessions (same as PostHog web app) - const distinctId = user.distinct_id || user.email; - identifyUser(distinctId, { - email: user.email, - uuid: user.uuid, - project_id: selectedProjectId.toString(), - region, - }); - track(ANALYTICS_EVENTS.USER_LOGGED_IN, { - project_id: selectedProjectId.toString(), - region, - }); - - trpcClient.analytics.setUserId.mutate({ - userId: distinctId, - properties: { - email: user.email, - uuid: user.uuid, - project_id: selectedProjectId.toString(), - region, - }, - }); - - get().checkCodeAccess(); - } catch (error) { - log.error("Failed to authenticate with PostHog", error); - throw new Error("Failed to authenticate with PostHog"); - } - }, - - refreshAccessToken: async () => { - // If a refresh is already in progress, wait for it - if (refreshPromise) { - log.debug("Token refresh already in progress, waiting..."); - return refreshPromise; - } - - const doRefresh = async () => { - const state = get(); - - if (!state.oauthRefreshToken || !state.cloudRegion) { - throw new Error("No refresh token available"); - } - - // Retry with exponential backoff - let lastError: Error | null = null; - for (let attempt = 0; attempt < REFRESH_MAX_RETRIES; attempt++) { - try { - if (attempt > 0) { - log.debug( - `Retrying token refresh (attempt ${ - attempt + 1 - }/${REFRESH_MAX_RETRIES})`, - ); - await sleepWithBackoff(attempt - 1, { - initialDelayMs: REFRESH_INITIAL_DELAY_MS, - }); - } - - const result = await trpcClient.oauth.refreshToken.mutate({ - refreshToken: state.oauthRefreshToken, - region: state.cloudRegion, - }); - - if (!result.success || !result.data) { - // Network/server errors should retry, auth errors should logout immediately - if ( - result.errorCode === "network_error" || - result.errorCode === "server_error" - ) { - log.warn( - `Token refresh ${result.errorCode} (attempt ${ - attempt + 1 - }/${REFRESH_MAX_RETRIES}): ${result.error}`, - ); - lastError = new Error( - result.error || "Token refresh failed", - ); - continue; // Retry - } - - // Auth error or unknown - logout - log.error( - `Token refresh failed with ${result.errorCode}: ${result.error}`, - ); - get().logout(); - throw new Error(result.error || "Token refresh failed"); - } - - const tokenResponse = result.data; - const expiresAt = Date.now() + tokenResponse.expires_in * 1000; - - const storedTokens: StoredTokens = { - accessToken: tokenResponse.access_token, - refreshToken: tokenResponse.refresh_token, - expiresAt, - cloudRegion: state.cloudRegion, - scopedTeams: tokenResponse.scoped_teams, - scopeVersion: state.storedTokens?.scopeVersion ?? 0, - }; - - const apiHost = getCloudUrlFromRegion(state.cloudRegion); - const scopedTeams = tokenResponse.scoped_teams ?? []; - const storedProjectId = state.projectId; - const projectId = - storedProjectId && scopedTeams.includes(storedProjectId) - ? storedProjectId - : (scopedTeams[0] ?? storedProjectId ?? undefined); - - const client = new PostHogAPIClient( - tokenResponse.access_token, - apiHost, - async () => { - await get().refreshAccessToken(); - const token = get().oauthAccessToken; - if (!token) { - throw new Error("No access token after refresh"); - } - return token; - }, - projectId, - ); - - set({ - oauthAccessToken: tokenResponse.access_token, - oauthRefreshToken: tokenResponse.refresh_token, - tokenExpiry: expiresAt, - storedTokens, - client, - ...(projectId && { projectId }), - availableProjectIds: - scopedTeams.length > 0 - ? scopedTeams - : state.availableProjectIds, - }); - - updateServiceTokens(tokenResponse.access_token); - - get().scheduleTokenRefresh(); - return; // Success - } catch (error) { - lastError = - error instanceof Error ? error : new Error(String(error)); - - // Check if this is a permanent failure (logout already called) - if (!get().oauthRefreshToken) { - throw lastError; - } - - // tRPC exceptions are typically IPC failures - retry them - log.warn( - `Token refresh exception (attempt ${attempt + 1}): ${ - lastError.message - }`, - ); - } - } - - // All retries exhausted - log.error( - `Token refresh failed after all retries: ${ - lastError?.message || "Unknown error" - }`, - ); - get().logout(); - throw lastError || new Error("Token refresh failed"); - }; - - refreshPromise = doRefresh().finally(() => { - refreshPromise = null; - }); - - return refreshPromise; - }, - - scheduleTokenRefresh: () => { - const state = get(); - - if (refreshTimeoutId) { - window.clearTimeout(refreshTimeoutId); - refreshTimeoutId = null; - } - - if (!state.tokenExpiry) { - return; - } - - const timeUntilRefresh = - state.tokenExpiry - Date.now() - TOKEN_REFRESH_BUFFER_MS; - - if (timeUntilRefresh > 0) { - refreshTimeoutId = window.setTimeout(() => { - get() - .refreshAccessToken() - .catch((error) => { - log.error("Proactive token refresh failed:", error); - }); - }, timeUntilRefresh); - } else { - get() - .refreshAccessToken() - .catch((error) => { - log.error("Immediate token refresh failed:", error); - }); - } - }, - - initializeOAuth: async () => { - // If initialization is already in progress, wait for it - if (initializePromise) { - log.debug("OAuth initialization already in progress, waiting..."); - return initializePromise; - } - - const doInitialize = async (): Promise => { - // Wait for zustand hydration from async storage - if (!useAuthStore.persist.hasHydrated()) { - await new Promise((resolve) => { - useAuthStore.persist.onFinishHydration(() => resolve()); - }); - } - - const state = get(); - - if (state.storedTokens) { - const tokens = state.storedTokens; - const tokenScopeVersion = tokens.scopeVersion ?? 0; - if (tokenScopeVersion < OAUTH_SCOPE_VERSION) { - log.info("OAuth scopes updated, re-authentication required", { - tokenVersion: tokenScopeVersion, - requiredVersion: OAUTH_SCOPE_VERSION, - requiredScopes: OAUTH_SCOPES, - }); - set({ - needsScopeReauth: true, - oauthAccessToken: tokens.accessToken, - oauthRefreshToken: tokens.refreshToken, - tokenExpiry: tokens.expiresAt, - cloudRegion: tokens.cloudRegion, - isAuthenticated: true, - }); - return true; - } - const now = Date.now(); - const isExpired = tokens.expiresAt <= now; - - set({ - oauthAccessToken: tokens.accessToken, - oauthRefreshToken: tokens.refreshToken, - tokenExpiry: tokens.expiresAt, - cloudRegion: tokens.cloudRegion, - }); - - if (isExpired) { - try { - await get().refreshAccessToken(); - } catch (error) { - log.error("Failed to refresh expired token:", error); - set({ - storedTokens: null, - isAuthenticated: false, - needsScopeReauth: false, - }); - return false; - } - } - - // Re-fetch tokens after potential refresh to get updated values - const currentTokens = get().storedTokens; - if (!currentTokens) { - return false; - } - - const apiHost = getCloudUrlFromRegion(currentTokens.cloudRegion); - const scopedTeams = currentTokens.scopedTeams ?? []; - - if (scopedTeams.length === 0) { - log.error("No projects found in stored tokens"); - get().logout(); - return false; - } - - const storedProjectId = get().projectId; - const availableProjects = - get().availableProjectIds.length > 0 - ? get().availableProjectIds - : scopedTeams; - const hasValidStoredProject = - storedProjectId !== null && - availableProjects.includes(storedProjectId); - - const client = new PostHogAPIClient( - currentTokens.accessToken, - apiHost, - async () => { - await get().refreshAccessToken(); - const token = get().oauthAccessToken; - if (!token) { - throw new Error("No access token after refresh"); - } - return token; - }, - hasValidStoredProject ? storedProjectId : scopedTeams[0], - ); - - try { - const user = await client.getCurrentUser(); - - // Prefer stored project, then user's current PostHog project, then first available - const userCurrentTeam = user?.team?.id; - const selectedProjectId = hasValidStoredProject - ? storedProjectId - : userCurrentTeam != null && - scopedTeams.includes(userCurrentTeam) - ? userCurrentTeam - : scopedTeams[0]; - - // Update client's teamId to match selected project - client.setTeamId(selectedProjectId); - - set({ - isAuthenticated: true, - client, - projectId: selectedProjectId, - availableProjectIds: scopedTeams, - needsProjectSelection: false, - }); - - queryClient.setQueryData(["currentUser"], user); - - updateServiceTokens(currentTokens.accessToken); - - get().scheduleTokenRefresh(); - - // Use distinct_id to match web sessions (same as PostHog web app) - const distinctId = user.distinct_id || user.email; - identifyUser(distinctId, { - email: user.email, - uuid: user.uuid, - project_id: selectedProjectId.toString(), - region: tokens.cloudRegion, - }); - - trpcClient.analytics.setUserId.mutate({ - userId: distinctId, - properties: { - email: user.email, - uuid: user.uuid, - project_id: selectedProjectId.toString(), - region: tokens.cloudRegion, - }, - }); - - get().checkCodeAccess(); - - return true; - } catch (error) { - log.error("Failed to validate OAuth session:", error); - - // Network errors from fetch are TypeError, wrapped by fetcher.ts as cause - const isNetworkError = - error instanceof Error && error.cause instanceof TypeError; - - if (isNetworkError) { - log.warn( - "Network error during session validation - keeping session active", - ); - const fallbackProjectId = hasValidStoredProject - ? storedProjectId - : scopedTeams[0]; - set({ - isAuthenticated: true, - client, - projectId: fallbackProjectId, - availableProjectIds: scopedTeams, - needsProjectSelection: false, - }); - get().scheduleTokenRefresh(); - return true; - } - - // For auth errors (401/403) or unknown errors, clear the session - set({ - storedTokens: null, - isAuthenticated: false, - needsScopeReauth: false, - }); - return false; - } - } - - return state.isAuthenticated; - }; - - initializePromise = doInitialize().finally(() => { - initializePromise = null; - }); - - return initializePromise; - }, +async function getValidAccessToken(): Promise { + const { accessToken } = await trpcClient.auth.getValidAccessToken.query(); + useAuthStore.setState({ oauthAccessToken: accessToken }); + return accessToken; +} - signupWithOAuth: async (region: CloudRegion) => { - const result = await trpcClient.oauth.startSignupFlow.mutate({ - region, - }); - - if (!result.success || !result.data) { - throw new Error(result.error || "Signup failed"); - } - - const tokenResponse = result.data; - const expiresAt = Date.now() + tokenResponse.expires_in * 1000; - - const scopedTeams = tokenResponse.scoped_teams ?? []; - const scopedOrgs = tokenResponse.scoped_organizations ?? []; - - if (scopedTeams.length === 0) { - throw new Error("No team found in OAuth scopes"); - } - - const storedTokens: StoredTokens = { - accessToken: tokenResponse.access_token, - refreshToken: tokenResponse.refresh_token, - expiresAt, - cloudRegion: region, - scopedTeams, - scopeVersion: OAUTH_SCOPE_VERSION, - }; - - const apiHost = getCloudUrlFromRegion(region); - const selectedProjectId = scopedTeams[0]; - - const client = new PostHogAPIClient( - tokenResponse.access_token, - apiHost, - async () => { - await get().refreshAccessToken(); - const token = get().oauthAccessToken; - if (!token) { - throw new Error("No access token after refresh"); - } - return token; - }, - selectedProjectId, - ); - - try { - const user = await client.getCurrentUser(); - - set({ - oauthAccessToken: tokenResponse.access_token, - oauthRefreshToken: tokenResponse.refresh_token, - tokenExpiry: expiresAt, - cloudRegion: region, - storedTokens, - isAuthenticated: true, - client, - projectId: selectedProjectId, - availableProjectIds: scopedTeams, - availableOrgIds: scopedOrgs, - needsProjectSelection: false, - needsScopeReauth: false, - }); - - updateServiceTokens(tokenResponse.access_token); - - queryClient.clear(); - queryClient.setQueryData(["currentUser"], user); - - get().scheduleTokenRefresh(); - - const distinctId = user.distinct_id || user.email; - identifyUser(distinctId, { - email: user.email, - uuid: user.uuid, - project_id: selectedProjectId.toString(), - region, - }); - track(ANALYTICS_EVENTS.USER_LOGGED_IN, { - project_id: selectedProjectId.toString(), - region, - }); - - trpcClient.analytics.setUserId.mutate({ - userId: distinctId, - properties: { - email: user.email, - uuid: user.uuid, - project_id: selectedProjectId.toString(), - region, - }, - }); - - get().checkCodeAccess(); - } catch (error) { - log.error("Failed to authenticate with PostHog", error); - throw new Error("Failed to authenticate with PostHog"); - } - }, +async function refreshAccessToken(): Promise { + const { accessToken } = await trpcClient.auth.refreshAccessToken.mutate(); + useAuthStore.setState({ oauthAccessToken: accessToken }); + return accessToken; +} - selectProject: (projectId: number) => { - const state = get(); - - // Validate that the project is in the available list - if (!state.availableProjectIds.includes(projectId)) { - log.error("Attempted to select invalid project", { projectId }); - throw new Error("Invalid project selection"); - } - - const cloudRegion = state.cloudRegion; - if (!cloudRegion) { - throw new Error("No cloud region available"); - } - - const accessToken = state.oauthAccessToken; - if (!accessToken) { - throw new Error("No access token available"); - } - - // Clean up all existing sessions before switching projects - sessionResetCallback?.(); - - const apiHost = getCloudUrlFromRegion(cloudRegion); - - // Create a new client with the selected project - const client = new PostHogAPIClient( - accessToken, - apiHost, - async () => { - await get().refreshAccessToken(); - const token = get().oauthAccessToken; - if (!token) { - throw new Error("No access token after refresh"); - } - return token; - }, - projectId, - ); - - // Update stored tokens with the selected project - const updatedTokens = state.storedTokens - ? { ...state.storedTokens, scopedTeams: state.availableProjectIds } - : null; - - set({ - projectId, - client, - needsProjectSelection: false, - storedTokens: updatedTokens, - }); - - // Clear project-scoped queries, but keep project list/user for the switcher - queryClient.removeQueries({ - predicate: (query) => { - const key = Array.isArray(query.queryKey) - ? query.queryKey[0] - : query.queryKey; - return key !== "currentUser"; - }, - }); - - // Navigate to task input after project selection - useNavigationStore.getState().navigateToTaskInput(); - - // Update analytics with the selected project - updateServiceTokens(accessToken); - - track(ANALYTICS_EVENTS.USER_LOGGED_IN, { - project_id: projectId.toString(), - region: cloudRegion, - }); - - log.info("Project selected", { projectId }); - }, +function updateServiceTokens(token: string): void { + trpcClient.agent.updateToken + .mutate({ token }) + .catch((err) => log.warn("Failed to update agent token", err)); + trpcClient.cloudTask.updateToken + .mutate({ token }) + .catch((err) => log.warn("Failed to update cloud task token", err)); +} - completeOnboarding: () => { - set({ hasCompletedOnboarding: true }); - }, +function createClient( + cloudRegion: CloudRegion, + projectId: number | null, +): PostHogAPIClient { + const client = new PostHogAPIClient( + getCloudUrlFromRegion(cloudRegion), + getValidAccessToken, + refreshAccessToken, + projectId ?? undefined, + ); + if (projectId) { + client.setTeamId(projectId); + } + return client; +} - selectPlan: (plan: "free" | "pro") => { - set({ selectedPlan: plan }); +async function syncAuthState(): Promise { + const authState = await trpcClient.auth.getState.query(); + const isAuthenticated = authState.status === "authenticated"; + + useAuthStore.setState((state) => { + const regionChanged = authState.cloudRegion !== state.cloudRegion; + const projectChanged = authState.projectId !== state.projectId; + const client = + isAuthenticated && authState.cloudRegion + ? regionChanged || projectChanged || !state.client + ? createClient(authState.cloudRegion, authState.projectId) + : state.client + : null; + + return { + ...state, + isAuthenticated, + cloudRegion: authState.cloudRegion, + staleCloudRegion: isAuthenticated + ? null + : (authState.cloudRegion ?? state.staleCloudRegion), + client, + projectId: authState.projectId, + availableProjectIds: authState.availableProjectIds, + availableOrgIds: authState.availableOrgIds, + needsProjectSelection: + isAuthenticated && + authState.availableProjectIds.length > 1 && + authState.projectId === null, + needsScopeReauth: authState.needsScopeReauth, + hasCodeAccess: authState.hasCodeAccess, + }; + }); + + const client = useAuthStore.getState().client; + + if (!isAuthenticated || !authState.cloudRegion || !client) { + inFlightAuthSync = null; + inFlightAuthSyncKey = null; + lastCompletedAuthSyncKey = null; + return; + } + + const authSyncKey = JSON.stringify({ + status: authState.status, + cloudRegion: authState.cloudRegion, + projectId: authState.projectId, + }); + + if (authSyncKey === lastCompletedAuthSyncKey) { + return; + } + + if (inFlightAuthSync && inFlightAuthSyncKey === authSyncKey) { + await inFlightAuthSync; + return; + } + + inFlightAuthSyncKey = authSyncKey; + inFlightAuthSync = (async () => { + try { + const user = await client.getCurrentUser(); + queryClient.setQueryData(["currentUser"], user); + + const token = await getValidAccessToken(); + updateServiceTokens(token); + + const distinctId = user.distinct_id || user.email; + identifyUser(distinctId, { + email: user.email, + uuid: user.uuid, + project_id: authState.projectId?.toString() ?? "", + region: authState.cloudRegion ?? "", + }); + + trpcClient.analytics.setUserId.mutate({ + userId: distinctId, + properties: { + email: user.email, + uuid: user.uuid, + project_id: authState.projectId?.toString() ?? "", + region: authState.cloudRegion ?? "", }, + }); + + lastCompletedAuthSyncKey = authSyncKey; + } catch (error) { + log.warn("Failed to synchronize authenticated renderer state", { error }); + } finally { + if (inFlightAuthSyncKey === authSyncKey) { + inFlightAuthSync = null; + inFlightAuthSyncKey = null; + } + } + })(); + + await inFlightAuthSync; +} - selectOrg: (orgId: string) => { - set({ selectedOrgId: orgId }); - }, +function ensureAuthSubscription(): void { + if (authStateSubscription) { + return; + } + + authStateSubscription = trpcClient.auth.onStateChanged.subscribe(undefined, { + onData: () => { + void syncAuthState(); + }, + onError: (error) => { + log.error("Auth state subscription error", { error }); + }, + }); +} - logout: () => { - track(ANALYTICS_EVENTS.USER_LOGGED_OUT); - resetUser(); - - // Clean up session service subscriptions before clearing auth state - sessionResetCallback?.(); - - trpcClient.analytics.resetUser.mutate(); - - if (refreshTimeoutId) { - window.clearTimeout(refreshTimeoutId); - refreshTimeoutId = null; - } - - queryClient.clear(); - - useNavigationStore.getState().navigateToTaskInput(); - - const currentTokens = get().storedTokens; - - set({ - oauthAccessToken: null, - oauthRefreshToken: null, - tokenExpiry: null, - cloudRegion: null, - storedTokens: null, - staleTokens: currentTokens, - isAuthenticated: false, - client: null, - projectId: null, - availableProjectIds: [], - availableOrgIds: [], - needsProjectSelection: false, - needsScopeReauth: false, - hasCodeAccess: null, - selectedPlan: null, - selectedOrgId: null, - }); - }, - }), - { - // TODO: Migrate to posthog-code - name: "array-auth", - storage: electronStorage, - partialize: (state) => ({ - cloudRegion: state.cloudRegion, - storedTokens: state.storedTokens, - staleTokens: state.staleTokens, - projectId: state.projectId, - availableProjectIds: state.availableProjectIds, - availableOrgIds: state.availableOrgIds, - hasCodeAccess: state.hasCodeAccess, - hasCompletedOnboarding: state.hasCompletedOnboarding, - selectedPlan: state.selectedPlan, - selectedOrgId: state.selectedOrgId, - }), - }, - ), - ), -); +export const useAuthStore = create((set, get) => ({ + oauthAccessToken: null, + cloudRegion: null, + staleCloudRegion: null, + + isAuthenticated: false, + client: null, + projectId: null, + availableProjectIds: [], + availableOrgIds: [], + needsProjectSelection: false, + needsScopeReauth: false, + hasCodeAccess: null, + + hasCompletedOnboarding: false, + selectedPlan: null, + selectedOrgId: null, + + checkCodeAccess: async () => { + await syncAuthState(); + }, + + redeemInviteCode: async (code: string) => { + await trpcClient.auth.redeemInviteCode.mutate({ code }); + await syncAuthState(); + }, + + loginWithOAuth: async (region: CloudRegion) => { + const result = await trpcClient.auth.login.mutate({ region }); + await syncAuthState(); + track(ANALYTICS_EVENTS.USER_LOGGED_IN, { + project_id: result.state.projectId?.toString() ?? "", + region, + }); + }, + + signupWithOAuth: async (region: CloudRegion) => { + const result = await trpcClient.auth.signup.mutate({ region }); + await syncAuthState(); + track(ANALYTICS_EVENTS.USER_LOGGED_IN, { + project_id: result.state.projectId?.toString() ?? "", + region, + }); + }, + + refreshAccessToken: async () => { + const token = await refreshAccessToken(); + updateServiceTokens(token); + }, + + initializeOAuth: async () => { + if (initializePromise) { + return initializePromise; + } + + initializePromise = (async () => { + ensureAuthSubscription(); + await syncAuthState(); + return get().isAuthenticated || get().needsScopeReauth; + })().finally(() => { + initializePromise = null; + }); + + return initializePromise; + }, + + selectProject: async (projectId: number) => { + sessionResetCallback?.(); + await trpcClient.auth.selectProject.mutate({ projectId }); + await syncAuthState(); + useNavigationStore.getState().navigateToTaskInput(); + }, + + completeOnboarding: () => { + set({ hasCompletedOnboarding: true }); + }, + + selectPlan: (plan: "free" | "pro") => { + set({ selectedPlan: plan }); + }, + + selectOrg: (orgId: string) => { + set({ selectedOrgId: orgId }); + }, + + logout: async () => { + track(ANALYTICS_EVENTS.USER_LOGGED_OUT); + resetUser(); + sessionResetCallback?.(); + queryClient.clear(); + await trpcClient.auth.logout.mutate(); + trpcClient.analytics.resetUser.mutate(); + useNavigationStore.getState().navigateToTaskInput(); + + set((state) => ({ + ...state, + oauthAccessToken: null, + cloudRegion: null, + staleCloudRegion: state.cloudRegion ?? null, + isAuthenticated: false, + client: null, + projectId: null, + availableProjectIds: [], + availableOrgIds: [], + needsProjectSelection: false, + needsScopeReauth: false, + hasCodeAccess: null, + selectedPlan: null, + selectedOrgId: null, + })); + inFlightAuthSync = null; + inFlightAuthSyncKey = null; + lastCompletedAuthSyncKey = null; + }, +})); diff --git a/scripts/clean-posthog-code-macos.sh b/scripts/clean-posthog-code-macos.sh index 3a8946a66..814cd6167 100755 --- a/scripts/clean-posthog-code-macos.sh +++ b/scripts/clean-posthog-code-macos.sh @@ -1,10 +1,10 @@ #!/bin/bash -# Clean Twig app data from macOS +# Clean PostHog Code app data from macOS # # Usage: -# ./scripts/clean-twig-macos.sh # Clean data only -# ./scripts/clean-twig-macos.sh --app # Clean data and delete app +# ./scripts/clean-posthog-code-macos.sh # Clean data only +# ./scripts/clean-posthog-code-macos.sh --app # Clean data and delete app set -e @@ -20,25 +20,39 @@ for arg in "$@"; do echo "Usage: $0 [--app]" echo "" echo "Options:" - echo " --app Also delete Twig.app from /Applications" + echo " --app Also delete PostHog Code.app from /Applications" echo "" echo "This script removes:" - echo " - ~/Library/Application Support/@posthog/Array" - echo " - ~/Library/Application Support/@posthog/Twig" - echo " - ~/Library/Application Support/@posthog/twig-dev" + echo " - ~/Library/Application Support/@posthog/posthog-code" + echo " - ~/Library/Application Support/@posthog/posthog-code-dev" + echo " - ~/Library/Application Support/@posthog/Array (legacy)" + echo " - ~/Library/Application Support/@posthog/Twig (legacy)" + echo " - ~/Library/Application Support/@posthog/twig-dev (legacy)" echo " - ~/Library/Preferences/com.posthog.array.plist" echo " - ~/Library/Caches/com.posthog.array" - echo " - ~/Library/Logs/Twig" + echo " - ~/.posthog-code (logs and cache)" + echo " - ~/Library/Logs/PostHog Code" echo " - ~/Library/Saved Application State/com.posthog.array.savedState" exit 0 ;; esac done -echo "Cleaning Twig data from macOS..." +echo "Cleaning PostHog Code data from macOS..." echo "" -# Application Support - actual electron data locations +# Application Support - current electron data locations +if [ -d "$HOME/Library/Application Support/@posthog/posthog-code" ]; then + echo "Removing ~/Library/Application Support/@posthog/posthog-code" + rm -rf "$HOME/Library/Application Support/@posthog/posthog-code" +fi + +if [ -d "$HOME/Library/Application Support/@posthog/posthog-code-dev" ]; then + echo "Removing ~/Library/Application Support/@posthog/posthog-code-dev" + rm -rf "$HOME/Library/Application Support/@posthog/posthog-code-dev" +fi + +# Application Support - legacy locations if [ -d "$HOME/Library/Application Support/@posthog/Array" ]; then echo "Removing ~/Library/Application Support/@posthog/Array" rm -rf "$HOME/Library/Application Support/@posthog/Array" @@ -102,7 +116,18 @@ if [ -d "$HOME/Library/Caches/Twig" ]; then rm -rf "$HOME/Library/Caches/Twig" fi +# Home directory data (logs and cache) +if [ -d "$HOME/.posthog-code" ]; then + echo "Removing ~/.posthog-code" + rm -rf "$HOME/.posthog-code" +fi + # Logs +if [ -d "$HOME/Library/Logs/PostHog Code" ]; then + echo "Removing ~/Library/Logs/PostHog Code" + rm -rf "$HOME/Library/Logs/PostHog Code" +fi + if [ -d "$HOME/Library/Logs/twig" ]; then echo "Removing ~/Library/Logs/twig" rm -rf "$HOME/Library/Logs/twig" @@ -126,6 +151,10 @@ fi # App (optional) if [ "$DELETE_APP" = true ]; then + if [ -d "/Applications/PostHog Code.app" ]; then + echo "Removing /Applications/PostHog Code.app" + rm -rf "/Applications/PostHog Code.app" + fi if [ -d "/Applications/Twig.app" ]; then echo "Removing /Applications/Twig.app" rm -rf "/Applications/Twig.app" @@ -141,5 +170,5 @@ echo "Done!" if [ "$DELETE_APP" = false ]; then echo "" - echo "Note: Twig.app was not deleted. Use --app flag to also remove the app." + echo "Note: PostHog Code.app was not deleted. Use --app flag to also remove the app." fi