From 8077d8d3bd4e55a3a28c53ac7098b649a79a4354 Mon Sep 17 00:00:00 2001 From: JonathanLab Date: Mon, 30 Mar 2026 16:18:12 +0200 Subject: [PATCH] feat: persist preferred project by account key --- .../db/migrations/0004_auth_preferences.sql | 9 ++ .../src/main/db/migrations/meta/_journal.json | 7 ++ .../auth-preference-repository.mock.ts | 57 +++++++++ .../auth-preference-repository.ts | 89 ++++++++++++++ apps/code/src/main/db/schema.ts | 17 +++ apps/code/src/main/di/container.ts | 4 + apps/code/src/main/di/tokens.ts | 1 + .../src/main/services/auth/service.test.ts | 114 +++++++++++++++--- apps/code/src/main/services/auth/service.ts | 83 +++++++++++-- 9 files changed, 355 insertions(+), 26 deletions(-) create mode 100644 apps/code/src/main/db/migrations/0004_auth_preferences.sql create mode 100644 apps/code/src/main/db/repositories/auth-preference-repository.mock.ts create mode 100644 apps/code/src/main/db/repositories/auth-preference-repository.ts diff --git a/apps/code/src/main/db/migrations/0004_auth_preferences.sql b/apps/code/src/main/db/migrations/0004_auth_preferences.sql new file mode 100644 index 000000000..d1b5c0d2d --- /dev/null +++ b/apps/code/src/main/db/migrations/0004_auth_preferences.sql @@ -0,0 +1,9 @@ +CREATE TABLE `auth_preferences` ( + `account_key` text NOT NULL, + `cloud_region` text NOT NULL, + `last_selected_project_id` integer, + `created_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL, + `updated_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL +); +--> statement-breakpoint +CREATE INDEX `auth_preferences_account_region_idx` ON `auth_preferences` (`account_key`,`cloud_region`); diff --git a/apps/code/src/main/db/migrations/meta/_journal.json b/apps/code/src/main/db/migrations/meta/_journal.json index ab1209f9a..791d110c9 100644 --- a/apps/code/src/main/db/migrations/meta/_journal.json +++ b/apps/code/src/main/db/migrations/meta/_journal.json @@ -29,6 +29,13 @@ "when": 1774890000000, "tag": "0003_fair_whiplash", "breakpoints": true + }, + { + "idx": 4, + "version": "7", + "when": 1774891000000, + "tag": "0004_auth_preferences", + "breakpoints": true } ] } diff --git a/apps/code/src/main/db/repositories/auth-preference-repository.mock.ts b/apps/code/src/main/db/repositories/auth-preference-repository.mock.ts new file mode 100644 index 000000000..ae99875b6 --- /dev/null +++ b/apps/code/src/main/db/repositories/auth-preference-repository.mock.ts @@ -0,0 +1,57 @@ +import type { + AuthPreference, + IAuthPreferenceRepository, + PersistAuthPreferenceInput, +} from "./auth-preference-repository"; + +export interface MockAuthPreferenceRepository + extends IAuthPreferenceRepository { + _preferences: AuthPreference[]; +} + +export function createMockAuthPreferenceRepository(): MockAuthPreferenceRepository { + let preferences: AuthPreference[] = []; + + const clone = (value: AuthPreference): AuthPreference => ({ ...value }); + + return { + get _preferences() { + return preferences.map(clone); + }, + set _preferences(value) { + preferences = value.map(clone); + }, + get: (accountKey, cloudRegion) => { + const preference = preferences.find( + (entry) => + entry.accountKey === accountKey && entry.cloudRegion === cloudRegion, + ); + return preference ? clone(preference) : null; + }, + save: (input: PersistAuthPreferenceInput) => { + const timestamp = new Date().toISOString(); + const existingIndex = preferences.findIndex( + (entry) => + entry.accountKey === input.accountKey && + entry.cloudRegion === input.cloudRegion, + ); + + const row: AuthPreference = { + accountKey: input.accountKey, + cloudRegion: input.cloudRegion, + lastSelectedProjectId: input.lastSelectedProjectId, + createdAt: + existingIndex >= 0 ? preferences[existingIndex].createdAt : timestamp, + updatedAt: timestamp, + }; + + if (existingIndex >= 0) { + preferences[existingIndex] = row; + } else { + preferences.push(row); + } + + return clone(row); + }, + }; +} diff --git a/apps/code/src/main/db/repositories/auth-preference-repository.ts b/apps/code/src/main/db/repositories/auth-preference-repository.ts new file mode 100644 index 000000000..6962e03e9 --- /dev/null +++ b/apps/code/src/main/db/repositories/auth-preference-repository.ts @@ -0,0 +1,89 @@ +import { and, eq } from "drizzle-orm"; +import { inject, injectable } from "inversify"; +import { MAIN_TOKENS } from "../../di/tokens"; +import { authPreferences } from "../schema"; +import type { DatabaseService } from "../service"; + +export type AuthPreference = typeof authPreferences.$inferSelect; +export type NewAuthPreference = typeof authPreferences.$inferInsert; + +export interface PersistAuthPreferenceInput { + accountKey: string; + cloudRegion: "us" | "eu" | "dev"; + lastSelectedProjectId: number | null; +} + +export interface IAuthPreferenceRepository { + get( + accountKey: string, + cloudRegion: "us" | "eu" | "dev", + ): AuthPreference | null; + save(input: PersistAuthPreferenceInput): AuthPreference; +} + +const now = () => new Date().toISOString(); + +@injectable() +export class AuthPreferenceRepository implements IAuthPreferenceRepository { + constructor( + @inject(MAIN_TOKENS.DatabaseService) + private readonly databaseService: DatabaseService, + ) {} + + private get db() { + return this.databaseService.db; + } + + get( + accountKey: string, + cloudRegion: "us" | "eu" | "dev", + ): AuthPreference | null { + return ( + this.db + .select() + .from(authPreferences) + .where( + and( + eq(authPreferences.accountKey, accountKey), + eq(authPreferences.cloudRegion, cloudRegion), + ), + ) + .limit(1) + .get() ?? null + ); + } + + save(input: PersistAuthPreferenceInput): AuthPreference { + const timestamp = now(); + const existing = this.get(input.accountKey, input.cloudRegion); + + const row: NewAuthPreference = { + accountKey: input.accountKey, + cloudRegion: input.cloudRegion, + lastSelectedProjectId: input.lastSelectedProjectId, + createdAt: existing?.createdAt ?? timestamp, + updatedAt: timestamp, + }; + + if (existing) { + this.db + .update(authPreferences) + .set(row) + .where( + and( + eq(authPreferences.accountKey, input.accountKey), + eq(authPreferences.cloudRegion, input.cloudRegion), + ), + ) + .run(); + } else { + this.db.insert(authPreferences).values(row).run(); + } + + const saved = this.get(input.accountKey, input.cloudRegion); + if (!saved) { + throw new Error("Failed to persist auth preference"); + } + return saved; + } +} diff --git a/apps/code/src/main/db/schema.ts b/apps/code/src/main/db/schema.ts index 00849018a..86ec6c432 100644 --- a/apps/code/src/main/db/schema.ts +++ b/apps/code/src/main/db/schema.ts @@ -86,3 +86,20 @@ export const authSessions = sqliteTable("auth_sessions", { createdAt: createdAt(), updatedAt: updatedAt(), }); + +export const authPreferences = sqliteTable( + "auth_preferences", + { + accountKey: text().notNull(), + cloudRegion: text({ enum: ["us", "eu", "dev"] }).notNull(), + lastSelectedProjectId: integer(), + createdAt: createdAt(), + updatedAt: updatedAt(), + }, + (t) => [ + index("auth_preferences_account_region_idx").on( + t.accountKey, + t.cloudRegion, + ), + ], +); diff --git a/apps/code/src/main/di/container.ts b/apps/code/src/main/di/container.ts index 0b08821a7..f45163c53 100644 --- a/apps/code/src/main/di/container.ts +++ b/apps/code/src/main/di/container.ts @@ -2,6 +2,7 @@ import "reflect-metadata"; import { Container } from "inversify"; import { ArchiveRepository } from "../db/repositories/archive-repository"; +import { AuthPreferenceRepository } from "../db/repositories/auth-preference-repository"; import { AuthSessionRepository } from "../db/repositories/auth-session-repository"; import { RepositoryRepository } from "../db/repositories/repository-repository"; import { SuspensionRepositoryImpl } from "../db/repositories/suspension-repository"; @@ -52,6 +53,9 @@ export const container = new Container({ }); container.bind(MAIN_TOKENS.DatabaseService).to(DatabaseService); +container + .bind(MAIN_TOKENS.AuthPreferenceRepository) + .to(AuthPreferenceRepository); container.bind(MAIN_TOKENS.AuthSessionRepository).to(AuthSessionRepository); container.bind(MAIN_TOKENS.RepositoryRepository).to(RepositoryRepository); container.bind(MAIN_TOKENS.WorkspaceRepository).to(WorkspaceRepository); diff --git a/apps/code/src/main/di/tokens.ts b/apps/code/src/main/di/tokens.ts index 8080584b4..27bdbcafc 100644 --- a/apps/code/src/main/di/tokens.ts +++ b/apps/code/src/main/di/tokens.ts @@ -9,6 +9,7 @@ export const MAIN_TOKENS = Object.freeze({ SettingsStore: Symbol.for("Main.SettingsStore"), // Database + AuthPreferenceRepository: Symbol.for("Main.AuthPreferenceRepository"), DatabaseService: Symbol.for("Main.DatabaseService"), AuthSessionRepository: Symbol.for("Main.AuthSessionRepository"), RepositoryRepository: Symbol.for("Main.RepositoryRepository"), diff --git a/apps/code/src/main/services/auth/service.test.ts b/apps/code/src/main/services/auth/service.test.ts index dc822fe71..8caf9ac69 100644 --- a/apps/code/src/main/services/auth/service.test.ts +++ b/apps/code/src/main/services/auth/service.test.ts @@ -1,5 +1,6 @@ import { OAUTH_SCOPE_VERSION } from "@shared/constants/oauth"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { createMockAuthPreferenceRepository } from "../../db/repositories/auth-preference-repository.mock"; import { createMockAuthSessionRepository } from "../../db/repositories/auth-session-repository.mock"; import { decrypt, encrypt } from "../../utils/encryption"; import type { ConnectivityService } from "../connectivity/service"; @@ -18,6 +19,7 @@ vi.mock("../../utils/logger.js", () => ({ })); describe("AuthService", () => { + const preferenceRepository = createMockAuthPreferenceRepository(); const repository = createMockAuthSessionRepository(); const oauthService = { refreshToken: vi.fn(), @@ -31,9 +33,15 @@ describe("AuthService", () => { let service: AuthService; beforeEach(() => { + preferenceRepository._preferences = []; repository.clearCurrent(); vi.clearAllMocks(); - service = new AuthService(repository, oauthService, connectivityService); + service = new AuthService( + preferenceRepository, + repository, + oauthService, + connectivityService, + ); }); afterEach(async () => { @@ -41,6 +49,27 @@ describe("AuthService", () => { await service.logout(); }); + const stubAuthFetch = (accountKey = "user-1") => { + vi.stubGlobal( + "fetch", + vi.fn(async (input: string | Request) => { + const url = typeof input === "string" ? input : input.url; + + if (url.includes("/api/users/@me/")) { + return { + ok: true, + json: vi.fn().mockResolvedValue({ uuid: accountKey }), + } as unknown as Response; + } + + return { + ok: true, + json: vi.fn().mockResolvedValue({ has_access: true }), + } as unknown as Response; + }) as typeof fetch, + ); + }; + it("bootstraps to anonymous when there is no stored session", async () => { await service.initialize(); @@ -99,12 +128,7 @@ describe("AuthService", () => { }, }); - vi.stubGlobal( - "fetch", - vi.fn().mockResolvedValue({ - json: vi.fn().mockResolvedValue({ has_access: true }), - }) as unknown as typeof fetch, - ); + stubAuthFetch(); await service.initialize(); @@ -151,12 +175,7 @@ describe("AuthService", () => { scoped_organizations: ["org-1"], }, }); - vi.stubGlobal( - "fetch", - vi.fn().mockResolvedValue({ - json: vi.fn().mockResolvedValue({ has_access: true }), - }) as unknown as typeof fetch, - ); + stubAuthFetch(); await service.login("us"); @@ -211,12 +230,7 @@ describe("AuthService", () => { }, }); - vi.stubGlobal( - "fetch", - vi.fn().mockResolvedValue({ - json: vi.fn().mockResolvedValue({ has_access: true }), - }) as unknown as typeof fetch, - ); + stubAuthFetch(); await service.login("us"); await service.selectProject(84); @@ -237,4 +251,66 @@ describe("AuthService", () => { availableProjectIds: [42, 84], }); }); + + it("restores the selected project after app restart while logged out", async () => { + vi.mocked(oauthService.startFlow) + .mockResolvedValueOnce({ + success: true, + data: { + access_token: "initial-access-token", + refresh_token: "initial-refresh-token", + expires_in: 3600, + token_type: "Bearer", + scope: "", + scoped_teams: [42, 84], + scoped_organizations: ["org-1"], + }, + }) + .mockResolvedValueOnce({ + success: true, + data: { + access_token: "second-access-token", + refresh_token: "second-refresh-token", + expires_in: 3600, + token_type: "Bearer", + scope: "", + scoped_teams: [42, 84], + scoped_organizations: ["org-1"], + }, + }); + vi.mocked(oauthService.refreshToken).mockResolvedValue({ + success: true, + data: { + access_token: "refreshed-access-token", + refresh_token: "refreshed-refresh-token", + expires_in: 3600, + token_type: "Bearer", + scope: "", + scoped_teams: [42, 84], + scoped_organizations: ["org-1"], + }, + }); + + stubAuthFetch(); + + await service.login("us"); + await service.selectProject(84); + await service.logout(); + + service = new AuthService( + preferenceRepository, + repository, + oauthService, + connectivityService, + ); + + await service.login("us"); + + expect(service.getState()).toMatchObject({ + status: "authenticated", + cloudRegion: "us", + projectId: 84, + availableProjectIds: [42, 84], + }); + }); }); diff --git a/apps/code/src/main/services/auth/service.ts b/apps/code/src/main/services/auth/service.ts index 4b9808e43..41cadfdf9 100644 --- a/apps/code/src/main/services/auth/service.ts +++ b/apps/code/src/main/services/auth/service.ts @@ -4,6 +4,7 @@ import { } from "@shared/constants/oauth"; import type { CloudRegion } from "@shared/types/oauth"; import { inject, injectable } from "inversify"; +import type { IAuthPreferenceRepository } from "../../db/repositories/auth-preference-repository"; import type { IAuthSessionRepository, PersistAuthSessionInput, @@ -30,6 +31,7 @@ type FetchLike = ( ) => Promise; interface InMemorySession { + accountKey: string | null; accessToken: string; accessTokenExpiresAt: number; refreshToken: string; @@ -67,6 +69,8 @@ export class AuthService extends TypedEventEmitter { private refreshPromise: Promise | null = null; constructor( + @inject(MAIN_TOKENS.AuthPreferenceRepository) + private readonly authPreferenceRepository: IAuthPreferenceRepository, @inject(MAIN_TOKENS.AuthSessionRepository) private readonly authSessionRepository: IAuthSessionRepository, @inject(MAIN_TOKENS.OAuthService) @@ -208,6 +212,7 @@ export class AuthService extends TypedEventEmitter { projectId, }; + this.persistProjectPreference(this.session); this.persistSession({ refreshToken: this.session.refreshToken, cloudRegion: this.session.cloudRegion, @@ -358,22 +363,32 @@ export class AuthService extends TypedEventEmitter { throw new Error(result.error || "Token refresh failed"); } - return this.createSessionFromTokenResponse(result.data, input); + return await this.createSessionFromTokenResponse(result.data, input); } - private createSessionFromTokenResponse( + private async createSessionFromTokenResponse( tokenResponse: AuthTokenResponse, options: TokenResponseOptions, - ): InMemorySession { + ): Promise { const availableProjectIds = tokenResponse.scoped_teams ?? []; const availableOrgIds = tokenResponse.scoped_organizations ?? []; + const accountKey = await this.fetchAccountKey( + tokenResponse.access_token, + options.cloudRegion, + ); + const preferredProjectId = + options.selectedProjectId ?? + (accountKey + ? (this.authPreferenceRepository.get(accountKey, options.cloudRegion) + ?.lastSelectedProjectId ?? null) + : null); const projectId = - options.selectedProjectId && - availableProjectIds.includes(options.selectedProjectId) - ? options.selectedProjectId + preferredProjectId && availableProjectIds.includes(preferredProjectId) + ? preferredProjectId : (availableProjectIds[0] ?? null); const session: InMemorySession = { + accountKey, accessToken: tokenResponse.access_token, accessTokenExpiresAt: Date.now() + tokenResponse.expires_in * 1000, refreshToken: tokenResponse.refresh_token, @@ -400,7 +415,7 @@ export class AuthService extends TypedEventEmitter { throw new Error(result.error || fallbackError); } - const session = this.createSessionFromTokenResponse(result.data, { + const session = await this.createSessionFromTokenResponse(result.data, { cloudRegion: region, selectedProjectId: this.state.projectId, }); @@ -417,6 +432,7 @@ export class AuthService extends TypedEventEmitter { private async syncAuthenticatedSession( session: InMemorySession, ): Promise { + this.persistProjectPreference(session); this.persistSession({ refreshToken: session.refreshToken, cloudRegion: session.cloudRegion, @@ -451,6 +467,18 @@ export class AuthService extends TypedEventEmitter { this.authSessionRepository.saveCurrent(row); } + private persistProjectPreference(session: InMemorySession): void { + if (!session.accountKey) { + return; + } + + this.authPreferenceRepository.save({ + accountKey: session.accountKey, + cloudRegion: session.cloudRegion, + lastSelectedProjectId: session.projectId, + }); + } + private isSessionExpiring(session: InMemorySession): boolean { return session.accessTokenExpiresAt - Date.now() <= TOKEN_EXPIRY_SKEW_MS; } @@ -470,6 +498,47 @@ export class AuthService extends TypedEventEmitter { }; } + private async fetchAccountKey( + accessToken: string, + cloudRegion: "us" | "eu" | "dev", + ): Promise { + try { + const response = await fetch( + `${getCloudUrlFromRegion(cloudRegion)}/api/users/@me/`, + { + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }, + ); + + if (!response.ok) { + return null; + } + + const data = (await response.json().catch(() => ({}))) as { + uuid?: unknown; + distinct_id?: unknown; + email?: unknown; + }; + + if (typeof data.uuid === "string" && data.uuid.length > 0) { + return data.uuid; + } + if (typeof data.distinct_id === "string" && data.distinct_id.length > 0) { + return data.distinct_id; + } + if (typeof data.email === "string" && data.email.length > 0) { + return data.email; + } + + return null; + } catch (error) { + log.warn("Failed to resolve auth account key", { error }); + return null; + } + } + private requireSession(): InMemorySession { if (!this.session) { throw new Error("Not authenticated");