diff --git a/packages/shared/src/dpop.test.ts b/packages/shared/src/dpop.test.ts index 58bd161e2a3..fa951bc74f1 100644 --- a/packages/shared/src/dpop.test.ts +++ b/packages/shared/src/dpop.test.ts @@ -1,6 +1,6 @@ import * as NodeCrypto from "node:crypto"; -import { describe, expect, it } from "@effect/vitest"; +import { assert, describe, it } from "@effect/vitest"; import { computeDpopAccessTokenHash, @@ -56,59 +56,59 @@ describe("verifyDpopProof", () => { it("verifies an ES256 DPoP proof and returns the RFC 7638 thumbprint", () => { const thumbprint = computeDpopJwkThumbprint(publicJwk); - expect( - verifyDpopProof({ - proof, - method: "POST", - url: "https://example.com/oauth/token", - nowEpochSeconds: 101, - expectedThumbprint: thumbprint, - }), - ).toMatchObject({ - ok: true, - thumbprint, - jti: "proof-1", + const result = verifyDpopProof({ + proof, + method: "POST", + url: "https://example.com/oauth/token", + nowEpochSeconds: 101, + expectedThumbprint: thumbprint, }); + + if (!result.ok) { + assert.fail(result.reason); + } + assert.equal(result.thumbprint, thumbprint); + assert.equal(result.jti, "proof-1"); }); it("rejects method, URL, thumbprint, and time-window mismatches", () => { const thumbprint = computeDpopJwkThumbprint(publicJwk); - expect( + assert.isFalse( verifyDpopProof({ proof, method: "GET", url: "https://example.com/oauth/token", nowEpochSeconds: 101, expectedThumbprint: thumbprint, - }), - ).toMatchObject({ ok: false }); - expect( + }).ok, + ); + assert.isFalse( verifyDpopProof({ proof, method: "POST", url: "https://example.com/other", nowEpochSeconds: 101, expectedThumbprint: thumbprint, - }), - ).toMatchObject({ ok: false }); - expect( + }).ok, + ); + assert.isFalse( verifyDpopProof({ proof, method: "POST", url: "https://example.com/oauth/token", nowEpochSeconds: 101, expectedThumbprint: "other-thumbprint", - }), - ).toMatchObject({ ok: false }); - expect( + }).ok, + ); + assert.isFalse( verifyDpopProof({ proof, method: "POST", url: "https://example.com/oauth/token", nowEpochSeconds: 1_000, expectedThumbprint: thumbprint, - }), - ).toMatchObject({ ok: false }); + }).ok, + ); }); it("requires the RFC 9449 access token hash when an access token is expected", () => { @@ -122,7 +122,7 @@ describe("verifyDpopProof", () => { accessToken: "clerk-access-token", }); - expect( + assert.isTrue( verifyDpopProof({ proof: accessTokenProof, method: "POST", @@ -130,32 +130,39 @@ describe("verifyDpopProof", () => { nowEpochSeconds: 101, expectedThumbprint: thumbprint, expectedAccessToken: "clerk-access-token", - }), - ).toMatchObject({ ok: true }); - expect( - verifyDpopProof({ - proof, - method: "POST", - url: "https://example.com/oauth/token", - nowEpochSeconds: 101, - expectedThumbprint: thumbprint, - expectedAccessToken: "clerk-access-token", - }), - ).toMatchObject({ ok: false, reason: "DPoP access token hash mismatch." }); - expect( - verifyDpopProof({ - proof: accessTokenProof, - method: "POST", - url: "https://example.com/v1/environments/env/connect", - nowEpochSeconds: 101, - expectedThumbprint: thumbprint, - expectedAccessToken: "other-access-token", - }), - ).toMatchObject({ ok: false, reason: "DPoP access token hash mismatch." }); + }).ok, + ); + + const missingAccessTokenHash = verifyDpopProof({ + proof, + method: "POST", + url: "https://example.com/oauth/token", + nowEpochSeconds: 101, + expectedThumbprint: thumbprint, + expectedAccessToken: "clerk-access-token", + }); + if (missingAccessTokenHash.ok) { + assert.fail("Expected missing access token hash to be rejected."); + } + assert.equal(missingAccessTokenHash.reason, "DPoP access token hash mismatch."); + + const mismatchedAccessTokenHash = verifyDpopProof({ + proof: accessTokenProof, + method: "POST", + url: "https://example.com/v1/environments/env/connect", + nowEpochSeconds: 101, + expectedThumbprint: thumbprint, + expectedAccessToken: "other-access-token", + }); + if (mismatchedAccessTokenHash.ok) { + assert.fail("Expected mismatched access token hash to be rejected."); + } + assert.equal(mismatchedAccessTokenHash.reason, "DPoP access token hash mismatch."); }); it("normalizes htu by excluding query and fragment components per RFC 9449", () => { - expect(normalizeDpopHtu("https://example.com/v1/environments/env/connect?foo=bar#frag")).toBe( + assert.equal( + normalizeDpopHtu("https://example.com/v1/environments/env/connect?foo=bar#frag"), "https://example.com/v1/environments/env/connect", ); @@ -168,15 +175,15 @@ describe("verifyDpopProof", () => { publicJwk, }); - expect( + assert.isTrue( verifyDpopProof({ proof: queryProof, method: "POST", url: "https://example.com/v1/environments/env/connect?foo=bar#frag", nowEpochSeconds: 101, expectedThumbprint: thumbprint, - }), - ).toMatchObject({ ok: true }); + }).ok, + ); }); it("rejects DPoP public JWK headers that expose private key material", () => { @@ -192,14 +199,16 @@ describe("verifyDpopProof", () => { publicJwk: privateJwk, }); - expect( - verifyDpopProof({ - proof: proofWithPrivateJwk, - method: "POST", - url: "https://example.com/oauth/token", - nowEpochSeconds: 101, - expectedThumbprint: thumbprint, - }), - ).toMatchObject({ ok: false, reason: "Invalid DPoP JWT header." }); + const result = verifyDpopProof({ + proof: proofWithPrivateJwk, + method: "POST", + url: "https://example.com/oauth/token", + nowEpochSeconds: 101, + expectedThumbprint: thumbprint, + }); + if (result.ok) { + assert.fail("Expected private JWK material to be rejected."); + } + assert.equal(result.reason, "Invalid DPoP JWT header."); }); }); diff --git a/packages/shared/src/dpop.ts b/packages/shared/src/dpop.ts index 34210679007..3b45226b8f5 100644 --- a/packages/shared/src/dpop.ts +++ b/packages/shared/src/dpop.ts @@ -1,6 +1,7 @@ import { p256 } from "@noble/curves/nist"; import { sha256 } from "@noble/hashes/sha2"; import * as Encoding from "effect/Encoding"; +import * as Option from "effect/Option"; import * as Result from "effect/Result"; import * as Schema from "effect/Schema"; @@ -17,21 +18,27 @@ export const DpopPublicJwk = Schema.Struct({ y: Schema.String.check(Schema.isNonEmpty()), }); export type DpopPublicJwk = typeof DpopPublicJwk.Type; -const isDpopPublicJwk = Schema.is(DpopPublicJwk); -interface DpopJwtHeader { - readonly typ: string; - readonly alg: string; - readonly jwk: DpopPublicJwk; -} +const DpopJwtHeader = Schema.Struct({ + typ: Schema.Literal(DPOP_TYP), + alg: Schema.Literal(DPOP_ALG), + jwk: DpopPublicJwk, +}); +type DpopJwtHeader = typeof DpopJwtHeader.Type; + +const DpopJwtPayload = Schema.Struct({ + htm: Schema.String.check(Schema.isNonEmpty()), + htu: Schema.String.check(Schema.isNonEmpty()), + jti: Schema.String.check(Schema.isNonEmpty()), + iat: Schema.Int, + ath: Schema.optional(Schema.String), +}); +type DpopJwtPayload = typeof DpopJwtPayload.Type; -interface DpopJwtPayload { - readonly htm: string; - readonly htu: string; - readonly jti: string; - readonly iat: number; - readonly ath?: string; -} +const decodeDpopJwtHeaderJson = Schema.decodeUnknownOption(Schema.fromJsonString(DpopJwtHeader), { + onExcessProperty: "preserve", +}); +const decodeDpopJwtPayloadJson = Schema.decodeUnknownOption(Schema.fromJsonString(DpopJwtPayload)); export type DpopVerificationResult = | { @@ -49,40 +56,24 @@ function base64UrlToBytes(value: string): Uint8Array { return Result.getOrThrow(Encoding.decodeBase64Url(value)); } -function decodeBase64UrlJson(value: string): unknown { - return JSON.parse(Result.getOrThrow(Encoding.decodeBase64UrlString(value))) as unknown; +function decodeBase64UrlJsonOption( + value: string, + decode: (input: unknown) => Option.Option, +): Option.Option { + const decoded = Encoding.decodeBase64UrlString(value); + return Result.isFailure(decoded) ? Option.none() : decode(decoded.success); } -function isDpopJwtHeader(value: unknown): value is DpopJwtHeader { - if (typeof value !== "object" || value === null) { - return false; +function decodeDpopJwtHeader(value: string): Option.Option { + const header = decodeBase64UrlJsonOption(value, decodeDpopJwtHeaderJson); + if (Option.isNone(header)) { + return Option.none(); } - const record = value as Record; - return ( - record.typ === DPOP_TYP && - record.alg === DPOP_ALG && - typeof record.jwk === "object" && - record.jwk !== null && - !("d" in record.jwk) && - isDpopPublicJwk(record.jwk) - ); + return "d" in header.value.jwk ? Option.none() : header; } -function isDpopJwtPayload(value: unknown): value is DpopJwtPayload { - if (typeof value !== "object" || value === null) { - return false; - } - const record = value as Record; - return ( - typeof record.htm === "string" && - record.htm.length > 0 && - typeof record.htu === "string" && - record.htu.length > 0 && - typeof record.jti === "string" && - record.jti.length > 0 && - typeof record.iat === "number" && - Number.isInteger(record.iat) - ); +function decodeDpopJwtPayload(value: string): Option.Option { + return decodeBase64UrlJsonOption(value, decodeDpopJwtPayloadJson); } function dpopThumbprintInput(jwk: DpopPublicJwk): string { @@ -145,14 +136,16 @@ export function verifyDpopProof(input: { } try { - const header = decodeBase64UrlJson(parts[0]); - const payload = decodeBase64UrlJson(parts[1]); - if (!isDpopJwtHeader(header)) { + const headerOption = decodeDpopJwtHeader(parts[0]); + if (Option.isNone(headerOption)) { return { ok: false, reason: "Invalid DPoP JWT header." }; } - if (!isDpopJwtPayload(payload)) { + const payloadOption = decodeDpopJwtPayload(parts[1]); + if (Option.isNone(payloadOption)) { return { ok: false, reason: "Invalid DPoP JWT payload." }; } + const header = headerOption.value; + const payload = payloadOption.value; const thumbprint = computeDpopJwkThumbprint(header.jwk); if (input.expectedThumbprint && thumbprint !== input.expectedThumbprint) { diff --git a/packages/ssh/src/errors.ts b/packages/ssh/src/errors.ts index f1ba40b560c..1c2014cb596 100644 --- a/packages/ssh/src/errors.ts +++ b/packages/ssh/src/errors.ts @@ -1,4 +1,5 @@ import * as Data from "effect/Data"; +import * as Schema from "effect/Schema"; export class SshHostDiscoveryError extends Data.TaggedError("SshHostDiscoveryError")<{ readonly message: string; @@ -36,10 +37,56 @@ export class SshHttpBridgeError extends Data.TaggedError("SshHttpBridgeError")<{ readonly cause?: unknown; }> {} -export class SshReadinessError extends Data.TaggedError("SshReadinessError")<{ - readonly message: string; - readonly cause?: unknown; -}> {} +export class SshReadinessProbeFailedError extends Schema.TaggedErrorClass()( + "SshReadinessProbeFailedError", + { + requestUrl: Schema.String, + attempt: Schema.Number, + cause: Schema.Defect(), + }, +) { + override get message(): string { + return `Backend readiness probe failed at ${this.requestUrl}.`; + } +} + +export class SshReadinessProbeTimedOutError extends Schema.TaggedErrorClass()( + "SshReadinessProbeTimedOutError", + { + requestUrl: Schema.String, + attempt: Schema.Number, + probeTimeoutMs: Schema.Number, + }, +) { + override get message(): string { + return `Backend readiness probe exceeded ${this.probeTimeoutMs}ms at ${this.requestUrl}.`; + } +} + +export class SshReadinessTimedOutError extends Schema.TaggedErrorClass()( + "SshReadinessTimedOutError", + { + baseUrl: Schema.String, + requestUrl: Schema.String, + timeoutMs: Schema.Number, + intervalMs: Schema.Number, + probeTimeoutMs: Schema.Number, + attempts: Schema.Number, + lastFailure: Schema.optional(Schema.Defect()), + }, +) { + override get message(): string { + return `Timed out waiting ${this.timeoutMs}ms for backend readiness at ${this.baseUrl}.`; + } +} + +export const SshReadinessError = Schema.Union([ + SshReadinessProbeFailedError, + SshReadinessProbeTimedOutError, + SshReadinessTimedOutError, +]); +export type SshReadinessError = typeof SshReadinessError.Type; +export const isSshReadinessError = Schema.is(SshReadinessError); export class SshPasswordPromptError extends Data.TaggedError("SshPasswordPromptError")<{ readonly message: string; diff --git a/packages/ssh/src/tunnel.test.ts b/packages/ssh/src/tunnel.test.ts index 2e5c1a69904..fb998b608a0 100644 --- a/packages/ssh/src/tunnel.test.ts +++ b/packages/ssh/src/tunnel.test.ts @@ -25,6 +25,7 @@ import { SshEnvironmentManager, waitForHttpReady, } from "./tunnel.ts"; +import { SshReadinessTimedOutError } from "./errors.ts"; const TEST_NODE_ENGINE_RANGE = "^22.16 || ^23.11 || >=24.10"; @@ -255,7 +256,12 @@ describe("ssh tunnel scripts", () => { assert.isTrue(Result.isFailure(result)); if (Result.isFailure(result)) { + assert.instanceOf(result.failure, SshReadinessTimedOutError); + assert.equal(result.failure._tag, "SshReadinessTimedOutError"); + assert.equal(result.failure.timeoutMs, 1_000); + assert.equal(result.failure.probeTimeoutMs, 250); assert.include(result.failure.message, "Timed out waiting 1000ms"); + assert.isDefined(result.failure.lastFailure); } }).pipe( Effect.provide( diff --git a/packages/ssh/src/tunnel.ts b/packages/ssh/src/tunnel.ts index 029b7644897..a9cebd6e1de 100644 --- a/packages/ssh/src/tunnel.ts +++ b/packages/ssh/src/tunnel.ts @@ -47,6 +47,10 @@ import { SshPairingError, SshPasswordPromptError, SshReadinessError, + SshReadinessProbeFailedError, + SshReadinessProbeTimedOutError, + SshReadinessTimedOutError, + isSshReadinessError, } from "./errors.ts"; export const DEFAULT_REMOTE_PORT = 3773; @@ -233,12 +237,39 @@ function applyScriptPlaceholders( } export function describeReadinessCause(cause: unknown): unknown { - if (cause instanceof SshReadinessError) { - return { - _tag: cause._tag, - message: cause.message, - ...(cause.cause === undefined ? {} : { cause: describeReadinessCause(cause.cause) }), - }; + if (isSshReadinessError(cause)) { + switch (cause._tag) { + case "SshReadinessProbeFailedError": + return { + _tag: cause._tag, + message: cause.message, + requestUrl: cause.requestUrl, + attempt: cause.attempt, + cause: describeReadinessCause(cause.cause), + }; + case "SshReadinessProbeTimedOutError": + return { + _tag: cause._tag, + message: cause.message, + requestUrl: cause.requestUrl, + attempt: cause.attempt, + probeTimeoutMs: cause.probeTimeoutMs, + }; + case "SshReadinessTimedOutError": + return { + _tag: cause._tag, + message: cause.message, + baseUrl: cause.baseUrl, + requestUrl: cause.requestUrl, + timeoutMs: cause.timeoutMs, + intervalMs: cause.intervalMs, + probeTimeoutMs: cause.probeTimeoutMs, + attempts: cause.attempts, + ...(cause.lastFailure === undefined + ? {} + : { lastFailure: describeReadinessCause(cause.lastFailure) }), + }; + } } if (cause instanceof Error) { return { @@ -879,7 +910,7 @@ export const waitForHttpReady = Effect.fn("ssh/tunnel.waitForHttpReady")(functio ); const requestUrl = new URL(input.path ?? "/", input.baseUrl).toString(); const client = yield* HttpClient.HttpClient; - const lastProbeFailure = yield* Ref.make(null); + const lastProbeFailure = yield* Ref.make>(Option.none()); let attempt = 0; yield* Effect.logDebug("ssh.tunnel.httpReady.start", { @@ -899,8 +930,9 @@ export const waitForHttpReady = Effect.fn("ssh/tunnel.waitForHttpReady")(functio Effect.timeoutOption(Duration.millis(probeTimeoutMs)), Effect.mapError( (cause) => - new SshReadinessError({ - message: `Backend readiness probe failed at ${requestUrl}.`, + new SshReadinessProbeFailedError({ + requestUrl, + attempt, cause, }), ), @@ -909,30 +941,25 @@ export const waitForHttpReady = Effect.fn("ssh/tunnel.waitForHttpReady")(functio onSome: Effect.succeed, onNone: () => Effect.fail( - new SshReadinessError({ - message: `Backend readiness probe exceeded ${probeTimeoutMs}ms at ${requestUrl}.`, - cause: { - kind: "probe-timeout", - attempt, - probeTimeoutMs, - }, + new SshReadinessProbeTimedOutError({ + requestUrl, + attempt, + probeTimeoutMs, }), ), }); }).pipe( Effect.mapError((cause) => - cause instanceof SshReadinessError + isSshReadinessError(cause) ? cause - : new SshReadinessError({ - message: `Backend readiness probe failed at ${requestUrl}.`, + : new SshReadinessProbeFailedError({ + requestUrl, + attempt, cause, }), ), Effect.tapError((cause) => - Ref.set(lastProbeFailure, { - attempt, - cause: describeReadinessCause(cause), - }), + Ref.set(lastProbeFailure, Option.some(describeReadinessCause(cause))), ), ), ), @@ -942,10 +969,11 @@ export const waitForHttpReady = Effect.fn("ssh/tunnel.waitForHttpReady")(functio const result = yield* readinessClient.execute(HttpClientRequest.get(requestUrl)).pipe( Effect.mapError((cause) => - cause instanceof SshReadinessError + isSshReadinessError(cause) ? cause - : new SshReadinessError({ - message: `Backend readiness probe failed at ${requestUrl}.`, + : new SshReadinessProbeFailedError({ + requestUrl, + attempt, cause, }), ), @@ -961,7 +989,7 @@ export const waitForHttpReady = Effect.fn("ssh/tunnel.waitForHttpReady")(functio }), onNone: () => Effect.gen(function* () { - const lastFailure = yield* Ref.get(lastProbeFailure); + const lastFailure = Option.getOrUndefined(yield* Ref.get(lastProbeFailure)); yield* Effect.logWarning("ssh.tunnel.httpReady.timedOut", { baseUrl: input.baseUrl, requestUrl, @@ -971,9 +999,14 @@ export const waitForHttpReady = Effect.fn("ssh/tunnel.waitForHttpReady")(functio attempts: attempt, lastFailure, }); - return yield* new SshReadinessError({ - message: `Timed out waiting ${timeoutMs}ms for backend readiness at ${input.baseUrl}.`, - cause: lastFailure, + return yield* new SshReadinessTimedOutError({ + baseUrl: input.baseUrl, + requestUrl, + timeoutMs, + intervalMs, + probeTimeoutMs, + attempts: attempt, + ...(lastFailure === undefined ? {} : { lastFailure }), }); }), });