Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions apps/code/src/main/db/migrations/0004_auth_preferences.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE TABLE `auth_preferences` (
`account_key` text NOT NULL,
`cloud_region` text NOT NULL,
`last_selected_project_id` integer,
`created_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
`updated_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL
);
Comment on lines +1 to +7
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The auth_preferences table is missing a PRIMARY KEY or UNIQUE constraint on (account_key, cloud_region). This will allow duplicate rows to be inserted for the same account and region combination, causing unpredictable behavior in the get and save methods.

Impact: Multiple preference records can exist for the same account/region, and LIMIT 1 queries will return arbitrary rows. The update logic in save() may also fail to work correctly.

Fix: Add a composite primary key:

CREATE TABLE `auth_preferences` (
  `account_key` text NOT NULL,
  `cloud_region` text NOT NULL,
  `last_selected_project_id` integer,
  `created_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
  `updated_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
  PRIMARY KEY (`account_key`, `cloud_region`)
);
Suggested change
CREATE TABLE `auth_preferences` (
`account_key` text NOT NULL,
`cloud_region` text NOT NULL,
`last_selected_project_id` integer,
`created_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
`updated_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL
);
CREATE TABLE `auth_preferences` (
`account_key` text NOT NULL,
`cloud_region` text NOT NULL,
`last_selected_project_id` integer,
`created_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
`updated_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
PRIMARY KEY (`account_key`, `cloud_region`)
);

Spotted by Graphite

Fix in Graphite


Is this helpful? React 👍 or 👎 to let us know.

--> statement-breakpoint
CREATE INDEX `auth_preferences_account_region_idx` ON `auth_preferences` (`account_key`,`cloud_region`);
7 changes: 7 additions & 0 deletions apps/code/src/main/db/migrations/meta/_journal.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@
"when": 1774890000000,
"tag": "0003_fair_whiplash",
"breakpoints": true
},
{
"idx": 4,
"version": "7",
"when": 1774891000000,
"tag": "0004_auth_preferences",
"breakpoints": true
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import type {
AuthPreference,
IAuthPreferenceRepository,
PersistAuthPreferenceInput,
} from "./auth-preference-repository";

export interface MockAuthPreferenceRepository
extends IAuthPreferenceRepository {
_preferences: AuthPreference[];
}

export function createMockAuthPreferenceRepository(): MockAuthPreferenceRepository {
let preferences: AuthPreference[] = [];

const clone = (value: AuthPreference): AuthPreference => ({ ...value });

return {
get _preferences() {
return preferences.map(clone);
},
set _preferences(value) {
preferences = value.map(clone);
},
get: (accountKey, cloudRegion) => {
const preference = preferences.find(
(entry) =>
entry.accountKey === accountKey && entry.cloudRegion === cloudRegion,
);
return preference ? clone(preference) : null;
},
save: (input: PersistAuthPreferenceInput) => {
const timestamp = new Date().toISOString();
const existingIndex = preferences.findIndex(
(entry) =>
entry.accountKey === input.accountKey &&
entry.cloudRegion === input.cloudRegion,
);

const row: AuthPreference = {
accountKey: input.accountKey,
cloudRegion: input.cloudRegion,
lastSelectedProjectId: input.lastSelectedProjectId,
createdAt:
existingIndex >= 0 ? preferences[existingIndex].createdAt : timestamp,
updatedAt: timestamp,
};

if (existingIndex >= 0) {
preferences[existingIndex] = row;
} else {
preferences.push(row);
}

return clone(row);
},
};
}
89 changes: 89 additions & 0 deletions apps/code/src/main/db/repositories/auth-preference-repository.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import { and, eq } from "drizzle-orm";
import { inject, injectable } from "inversify";
import { MAIN_TOKENS } from "../../di/tokens";
import { authPreferences } from "../schema";
import type { DatabaseService } from "../service";

export type AuthPreference = typeof authPreferences.$inferSelect;
export type NewAuthPreference = typeof authPreferences.$inferInsert;

export interface PersistAuthPreferenceInput {
accountKey: string;
cloudRegion: "us" | "eu" | "dev";
lastSelectedProjectId: number | null;
}

export interface IAuthPreferenceRepository {
get(
accountKey: string,
cloudRegion: "us" | "eu" | "dev",
): AuthPreference | null;
save(input: PersistAuthPreferenceInput): AuthPreference;
}

const now = () => new Date().toISOString();

@injectable()
export class AuthPreferenceRepository implements IAuthPreferenceRepository {
constructor(
@inject(MAIN_TOKENS.DatabaseService)
private readonly databaseService: DatabaseService,
) {}

private get db() {
return this.databaseService.db;
}

get(
accountKey: string,
cloudRegion: "us" | "eu" | "dev",
): AuthPreference | null {
return (
this.db
.select()
.from(authPreferences)
.where(
and(
eq(authPreferences.accountKey, accountKey),
eq(authPreferences.cloudRegion, cloudRegion),
),
)
.limit(1)
.get() ?? null
);
}

save(input: PersistAuthPreferenceInput): AuthPreference {
const timestamp = now();
const existing = this.get(input.accountKey, input.cloudRegion);

const row: NewAuthPreference = {
accountKey: input.accountKey,
cloudRegion: input.cloudRegion,
lastSelectedProjectId: input.lastSelectedProjectId,
createdAt: existing?.createdAt ?? timestamp,
updatedAt: timestamp,
};

if (existing) {
this.db
.update(authPreferences)
.set(row)
.where(
and(
eq(authPreferences.accountKey, input.accountKey),
eq(authPreferences.cloudRegion, input.cloudRegion),
),
)
.run();
} else {
this.db.insert(authPreferences).values(row).run();
}

const saved = this.get(input.accountKey, input.cloudRegion);
if (!saved) {
throw new Error("Failed to persist auth preference");
}
return saved;
}
}
17 changes: 17 additions & 0 deletions apps/code/src/main/db/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,20 @@ export const authSessions = sqliteTable("auth_sessions", {
createdAt: createdAt(),
updatedAt: updatedAt(),
});

export const authPreferences = sqliteTable(
"auth_preferences",
{
accountKey: text().notNull(),
cloudRegion: text({ enum: ["us", "eu", "dev"] }).notNull(),
lastSelectedProjectId: integer(),
createdAt: createdAt(),
updatedAt: updatedAt(),
},
(t) => [
index("auth_preferences_account_region_idx").on(
t.accountKey,
t.cloudRegion,
),
],
);
4 changes: 4 additions & 0 deletions apps/code/src/main/di/container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import "reflect-metadata";

import { Container } from "inversify";
import { ArchiveRepository } from "../db/repositories/archive-repository";
import { AuthPreferenceRepository } from "../db/repositories/auth-preference-repository";
import { AuthSessionRepository } from "../db/repositories/auth-session-repository";
import { RepositoryRepository } from "../db/repositories/repository-repository";
import { SuspensionRepositoryImpl } from "../db/repositories/suspension-repository";
Expand Down Expand Up @@ -52,6 +53,9 @@ export const container = new Container({
});

container.bind(MAIN_TOKENS.DatabaseService).to(DatabaseService);
container
.bind(MAIN_TOKENS.AuthPreferenceRepository)
.to(AuthPreferenceRepository);
container.bind(MAIN_TOKENS.AuthSessionRepository).to(AuthSessionRepository);
container.bind(MAIN_TOKENS.RepositoryRepository).to(RepositoryRepository);
container.bind(MAIN_TOKENS.WorkspaceRepository).to(WorkspaceRepository);
Expand Down
1 change: 1 addition & 0 deletions apps/code/src/main/di/tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export const MAIN_TOKENS = Object.freeze({
SettingsStore: Symbol.for("Main.SettingsStore"),

// Database
AuthPreferenceRepository: Symbol.for("Main.AuthPreferenceRepository"),
DatabaseService: Symbol.for("Main.DatabaseService"),
AuthSessionRepository: Symbol.for("Main.AuthSessionRepository"),
RepositoryRepository: Symbol.for("Main.RepositoryRepository"),
Expand Down
114 changes: 95 additions & 19 deletions apps/code/src/main/services/auth/service.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { OAUTH_SCOPE_VERSION } from "@shared/constants/oauth";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { createMockAuthPreferenceRepository } from "../../db/repositories/auth-preference-repository.mock";
import { createMockAuthSessionRepository } from "../../db/repositories/auth-session-repository.mock";
import { decrypt, encrypt } from "../../utils/encryption";
import type { ConnectivityService } from "../connectivity/service";
Expand All @@ -18,6 +19,7 @@ vi.mock("../../utils/logger.js", () => ({
}));

describe("AuthService", () => {
const preferenceRepository = createMockAuthPreferenceRepository();
const repository = createMockAuthSessionRepository();
const oauthService = {
refreshToken: vi.fn(),
Expand All @@ -31,16 +33,43 @@ describe("AuthService", () => {
let service: AuthService;

beforeEach(() => {
preferenceRepository._preferences = [];
repository.clearCurrent();
vi.clearAllMocks();
service = new AuthService(repository, oauthService, connectivityService);
service = new AuthService(
preferenceRepository,
repository,
oauthService,
connectivityService,
);
});

afterEach(async () => {
vi.unstubAllGlobals();
await service.logout();
});

const stubAuthFetch = (accountKey = "user-1") => {
vi.stubGlobal(
"fetch",
vi.fn(async (input: string | Request) => {
const url = typeof input === "string" ? input : input.url;

if (url.includes("/api/users/@me/")) {
return {
ok: true,
json: vi.fn().mockResolvedValue({ uuid: accountKey }),
} as unknown as Response;
}

return {
ok: true,
json: vi.fn().mockResolvedValue({ has_access: true }),
} as unknown as Response;
}) as typeof fetch,
);
};

it("bootstraps to anonymous when there is no stored session", async () => {
await service.initialize();

Expand Down Expand Up @@ -99,12 +128,7 @@ describe("AuthService", () => {
},
});

vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue({ has_access: true }),
}) as unknown as typeof fetch,
);
stubAuthFetch();

await service.initialize();

Expand Down Expand Up @@ -151,12 +175,7 @@ describe("AuthService", () => {
scoped_organizations: ["org-1"],
},
});
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue({ has_access: true }),
}) as unknown as typeof fetch,
);
stubAuthFetch();

await service.login("us");

Expand Down Expand Up @@ -211,12 +230,7 @@ describe("AuthService", () => {
},
});

vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue({ has_access: true }),
}) as unknown as typeof fetch,
);
stubAuthFetch();

await service.login("us");
await service.selectProject(84);
Expand All @@ -237,4 +251,66 @@ describe("AuthService", () => {
availableProjectIds: [42, 84],
});
});

it("restores the selected project after app restart while logged out", async () => {
vi.mocked(oauthService.startFlow)
.mockResolvedValueOnce({
success: true,
data: {
access_token: "initial-access-token",
refresh_token: "initial-refresh-token",
expires_in: 3600,
token_type: "Bearer",
scope: "",
scoped_teams: [42, 84],
scoped_organizations: ["org-1"],
},
})
.mockResolvedValueOnce({
success: true,
data: {
access_token: "second-access-token",
refresh_token: "second-refresh-token",
expires_in: 3600,
token_type: "Bearer",
scope: "",
scoped_teams: [42, 84],
scoped_organizations: ["org-1"],
},
});
vi.mocked(oauthService.refreshToken).mockResolvedValue({
success: true,
data: {
access_token: "refreshed-access-token",
refresh_token: "refreshed-refresh-token",
expires_in: 3600,
token_type: "Bearer",
scope: "",
scoped_teams: [42, 84],
scoped_organizations: ["org-1"],
},
});

stubAuthFetch();

await service.login("us");
await service.selectProject(84);
await service.logout();

service = new AuthService(
preferenceRepository,
repository,
oauthService,
connectivityService,
);

await service.login("us");

expect(service.getState()).toMatchObject({
status: "authenticated",
cloudRegion: "us",
projectId: 84,
availableProjectIds: [42, 84],
});
});
});
Loading
Loading