diff --git a/src/cli/commands/import/phase2-import.ts b/src/cli/commands/import/phase2-import.ts index d898785f3..c980110bd 100644 --- a/src/cli/commands/import/phase2-import.ts +++ b/src/cli/commands/import/phase2-import.ts @@ -1,3 +1,4 @@ +import { PollExhaustedError, PollTimeoutError, isThrottlingError, poll } from '../../../lib/utils/polling'; import { getCredentialProvider } from '../../aws/account'; import type { CfnTemplate } from './template-utils'; import { buildImportTemplate } from './template-utils'; @@ -141,64 +142,61 @@ async function waitForChangeSetReady( stackName: string, changeSetName: string ): Promise { - const maxAttempts = 60; - const delay = 5000; // 5 seconds - - for (let attempt = 0; attempt < maxAttempts; attempt++) { - const response = await cfn.send( - new DescribeChangeSetCommand({ - StackName: stackName, - ChangeSetName: changeSetName, - }) - ); - - const status = response.Status; - - if (status === 'CREATE_COMPLETE') { - return; - } - - if (status === 'FAILED') { - throw new Error(`Change set creation failed: ${response.StatusReason ?? 'Unknown reason'}`); + try { + await poll({ + fn: async () => { + const response = await cfn.send( + new DescribeChangeSetCommand({ + StackName: stackName, + ChangeSetName: changeSetName, + }) + ); + const status = response.Status; + if (status === 'CREATE_COMPLETE') return { done: true, value: undefined }; + if (status === 'FAILED') { + throw new Error(`Change set creation failed: ${response.StatusReason ?? 'Unknown reason'}`); + } + return { done: false }; + }, + maxAttempts: 60, + delayMs: 5000, + onError: (err: unknown) => (isThrottlingError(err) ? 'retry' : 'abort'), + }); + } catch (err) { + if (err instanceof PollExhaustedError || err instanceof PollTimeoutError) { + throw new Error('Timed out waiting for change set creation', { cause: err }); } - - // CREATE_PENDING, CREATE_IN_PROGRESS — keep waiting - await new Promise(resolve => setTimeout(resolve, delay)); + throw err; } - - throw new Error('Timed out waiting for change set creation'); } /** * Wait for stack to reach IMPORT_COMPLETE status. */ async function waitForStackImportComplete(cfn: CloudFormationClient, stackName: string): Promise { - const maxAttempts = 120; - const delay = 5000; // 5 seconds - - for (let attempt = 0; attempt < maxAttempts; attempt++) { - const response = await cfn.send(new DescribeStacksCommand({ StackName: stackName })); - const stack = response.Stacks?.[0]; - - if (!stack) { - throw new Error(`Stack ${stackName} not found during import wait`); - } - - const status = stack.StackStatus ?? ''; - - if (status === 'IMPORT_COMPLETE') { - return; - } - - if (status.includes('FAILED') || status.includes('ROLLBACK')) { - throw new Error(`Import failed with status: ${status}. Reason: ${stack.StackStatusReason ?? 'Unknown'}`); + try { + await poll({ + fn: async () => { + const response = await cfn.send(new DescribeStacksCommand({ StackName: stackName })); + const stack = response.Stacks?.[0]; + if (!stack) throw new Error(`Stack ${stackName} not found during import wait`); + const status = stack.StackStatus ?? ''; + if (status === 'IMPORT_COMPLETE') return { done: true, value: undefined }; + if (status.includes('FAILED') || status.includes('ROLLBACK')) { + throw new Error(`Import failed with status: ${status}. Reason: ${stack.StackStatusReason ?? 'Unknown'}`); + } + return { done: false }; + }, + maxAttempts: 120, + delayMs: 5000, + onError: (err: unknown) => (isThrottlingError(err) ? 'retry' : 'abort'), + }); + } catch (err) { + if (err instanceof PollExhaustedError || err instanceof PollTimeoutError) { + throw new Error('Timed out waiting for import to complete', { cause: err }); } - - // IMPORT_IN_PROGRESS — keep waiting - await new Promise(resolve => setTimeout(resolve, delay)); + throw err; } - - throw new Error('Timed out waiting for import to complete'); } /** diff --git a/src/lib/utils/__tests__/polling.test.ts b/src/lib/utils/__tests__/polling.test.ts new file mode 100644 index 000000000..c3440b669 --- /dev/null +++ b/src/lib/utils/__tests__/polling.test.ts @@ -0,0 +1,217 @@ +import { PollExhaustedError, PollTimeoutError, isThrottlingError, poll } from '../polling.js'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +/* eslint-disable @typescript-eslint/require-await */ + +describe('poll', () => { + it('returns immediately on first success', async () => { + const result = await poll({ fn: async () => ({ done: true, value: 42 }), maxAttempts: 5 }); + expect(result).toBe(42); + }); + + it('polls until success', async () => { + let count = 0; + const result = await poll({ + fn: async () => { + count++; + return count === 3 ? { done: true, value: 'ok' } : { done: false }; + }, + maxAttempts: 5, + delayMs: 1, + }); + expect(result).toBe('ok'); + expect(count).toBe(3); + }); + + it('throws PollExhaustedError when maxAttempts exceeded', async () => { + await expect(poll({ fn: async () => ({ done: false }), maxAttempts: 3, delayMs: 1 })).rejects.toThrow( + PollExhaustedError + ); + }); + + it('throws PollTimeoutError when timeout exceeded', async () => { + await expect(poll({ fn: async () => ({ done: false }), timeoutMs: 50, delayMs: 20 })).rejects.toThrow( + PollTimeoutError + ); + }); + + describe('backoff', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it('applies exponential backoff', async () => { + let count = 0; + const promise = poll({ + fn: async () => { + count++; + return count === 4 ? { done: true, value: 'done' } : { done: false }; + }, + maxAttempts: 5, + delayMs: 100, + backoffFactor: 2, + }); + await vi.advanceTimersByTimeAsync(100); // 1st delay: 100 + await vi.advanceTimersByTimeAsync(200); // 2nd delay: 200 + await vi.advanceTimersByTimeAsync(400); // 3rd delay: 400 + const result = await promise; + expect(result).toBe('done'); + }); + + it('caps delay at maxDelayMs', async () => { + let count = 0; + const promise = poll({ + fn: async () => { + count++; + return count === 4 ? { done: true, value: 'done' } : { done: false }; + }, + maxAttempts: 5, + delayMs: 100, + backoffFactor: 10, + maxDelayMs: 500, + }); + await vi.advanceTimersByTimeAsync(100); // 1st: 100 + await vi.advanceTimersByTimeAsync(500); // 2nd: capped at 500 + await vi.advanceTimersByTimeAsync(500); // 3rd: capped at 500 + const result = await promise; + expect(result).toBe('done'); + }); + }); + + it('retries on error by default', async () => { + let count = 0; + const result = await poll({ + fn: async () => { + count++; + if (count < 3) throw new Error('transient'); + return { done: true, value: 'ok' }; + }, + maxAttempts: 5, + delayMs: 1, + }); + expect(result).toBe('ok'); + expect(count).toBe(3); + }); + + it('aborts on error when onError returns abort', async () => { + const err = new Error('fatal'); + await expect( + poll({ + fn: async () => { + throw err; + }, + maxAttempts: 5, + delayMs: 1, + onError: () => 'abort', + }) + ).rejects.toThrow('fatal'); + }); + + it('throws PollExhaustedError after maxConsecutiveErrors', async () => { + await expect( + poll({ + fn: async () => { + throw new Error('fail'); + }, + maxAttempts: 10, + delayMs: 1, + maxConsecutiveErrors: 3, + }) + ).rejects.toThrow(PollExhaustedError); + }); + + it('PollExhaustedError includes cause with the last error', async () => { + const err = await poll({ + fn: async () => { + throw new Error('Rate exceeded'); + }, + maxAttempts: 3, + delayMs: 1, + }).catch((e: unknown) => e); + expect(err).toBeInstanceOf(PollExhaustedError); + expect((err as PollExhaustedError).cause).toBeInstanceOf(Error); + expect(((err as PollExhaustedError).cause as Error).message).toBe('Rate exceeded'); + }); + + it('PollTimeoutError includes cause with the last error', async () => { + const err = await poll({ + fn: async () => { + throw new Error('service unavailable'); + }, + timeoutMs: 50, + delayMs: 10, + }).catch((e: unknown) => e); + expect(err).toBeInstanceOf(PollTimeoutError); + expect((err as PollTimeoutError).cause).toBeInstanceOf(Error); + expect(((err as PollTimeoutError).cause as Error).message).toBe('service unavailable'); + }); + + it('cause is undefined when no errors occurred during polling', async () => { + const err = await poll({ + fn: async () => ({ done: false }), + maxAttempts: 2, + delayMs: 1, + }).catch((e: unknown) => e); + expect(err).toBeInstanceOf(PollExhaustedError); + expect((err as PollExhaustedError).cause).toBeUndefined(); + }); + + it('resets consecutive error count on success', async () => { + let count = 0; + const result = await poll({ + fn: async () => { + count++; + if (count === 1) throw new Error('err1'); + if (count === 2) throw new Error('err2'); + if (count === 3) return { done: false }; // success resets counter + if (count === 4) throw new Error('err3'); + if (count === 5) throw new Error('err4'); + return { done: true, value: 'ok' }; + }, + maxAttempts: 10, + delayMs: 1, + maxConsecutiveErrors: 3, + }); + expect(result).toBe('ok'); + }); + + it('throws if neither maxAttempts nor timeoutMs provided', async () => { + await expect(poll({ fn: async () => ({ done: true, value: 1 }) })).rejects.toThrow( + 'poll() requires at least one of maxAttempts or timeoutMs' + ); + }); + + it('supports both maxAttempts and timeoutMs together', async () => { + // maxAttempts hit first + await expect( + poll({ fn: async () => ({ done: false }), maxAttempts: 2, timeoutMs: 10000, delayMs: 1 }) + ).rejects.toThrow(PollExhaustedError); + }); +}); + +describe('isThrottlingError', () => { + it('detects ThrottlingException by name', () => { + expect(isThrottlingError({ name: 'ThrottlingException', message: '' })).toBe(true); + }); + + it('detects Rate exceeded in message', () => { + expect(isThrottlingError(new Error('Rate exceeded'))).toBe(true); + }); + + it('detects TooManyRequestsException', () => { + expect(isThrottlingError({ name: 'TooManyRequestsException', message: '' })).toBe(true); + }); + + it('returns false for non-throttle errors', () => { + expect(isThrottlingError(new Error('Stack not found'))).toBe(false); + }); + + it('returns false for null/undefined', () => { + expect(isThrottlingError(null)).toBe(false); + expect(isThrottlingError(undefined)).toBe(false); + }); +}); diff --git a/src/lib/utils/index.ts b/src/lib/utils/index.ts index d595d9bf7..8902fca44 100644 --- a/src/lib/utils/index.ts +++ b/src/lib/utils/index.ts @@ -11,4 +11,5 @@ export { } from './subprocess'; export { parseTimeString } from './time-parser'; export { parseJsonRpcResponse } from './json-rpc'; +export { poll, isThrottlingError, PollTimeoutError, PollExhaustedError } from './polling'; export { validateAgentSchema, validateProjectSchema } from './zod'; diff --git a/src/lib/utils/polling.ts b/src/lib/utils/polling.ts new file mode 100644 index 000000000..3486e634f --- /dev/null +++ b/src/lib/utils/polling.ts @@ -0,0 +1,109 @@ +/** + * Shared polling/retry utility for async operations. + */ + +export type PollResult = { done: true; value: T } | { done: false }; + +export interface PollOptions { + /** Async function called each iteration. Return {done: true, value} when complete, {done: false} to keep polling. */ + fn: () => Promise>; + /** Max number of attempts before throwing PollExhaustedError. */ + maxAttempts?: number; + /** Max total time in ms before throwing PollTimeoutError. */ + timeoutMs?: number; + /** Delay between iterations in ms. Default 5000. */ + delayMs?: number; + /** Multiply delay by this factor each iteration. Default 1 (fixed). */ + backoffFactor?: number; + /** Cap on delay in ms. */ + maxDelayMs?: number; + /** Abort after this many consecutive errors. */ + maxConsecutiveErrors?: number; + /** Called when fn throws. Return 'retry' to continue or 'abort' to rethrow. Default: 'retry'. */ + onError?: (err: unknown) => 'retry' | 'abort'; +} + +export class PollTimeoutError extends Error { + constructor(timeoutMs: number, options?: { cause?: unknown }) { + super(`Polling timed out after ${timeoutMs}ms`, options); + this.name = 'PollTimeoutError'; + } +} + +export class PollExhaustedError extends Error { + constructor(maxAttempts: number, options?: { cause?: unknown }) { + super(`Polling exhausted after ${maxAttempts} attempts`, options); + this.name = 'PollExhaustedError'; + } +} + +export async function poll(options: PollOptions): Promise { + const { + fn, + maxAttempts, + timeoutMs, + delayMs = 5000, + backoffFactor = 1, + maxDelayMs, + maxConsecutiveErrors, + onError, + } = options; + + if (maxAttempts === undefined && timeoutMs === undefined) { + throw new Error('poll() requires at least one of maxAttempts or timeoutMs'); + } + + const start = Date.now(); + let attempts = 0; + let consecutiveErrors = 0; + let currentDelay = delayMs; + let lastError: unknown = undefined; + + while (true) { + if (maxAttempts !== undefined && attempts >= maxAttempts) { + throw new PollExhaustedError(maxAttempts, { cause: lastError }); + } + if (timeoutMs !== undefined && Date.now() - start >= timeoutMs) { + throw new PollTimeoutError(timeoutMs, { cause: lastError }); + } + + attempts++; + + try { + const result = await fn(); + consecutiveErrors = 0; + if (result.done) return result.value; + } catch (err: unknown) { + const action = onError ? onError(err) : 'retry'; + if (action === 'abort') throw err; + lastError = err; + consecutiveErrors++; + if (maxConsecutiveErrors && consecutiveErrors >= maxConsecutiveErrors) { + throw new PollExhaustedError(attempts, { cause: lastError }); + } + } + + // Don't sleep if we're about to exceed timeout + if (timeoutMs !== undefined && Date.now() - start + currentDelay >= timeoutMs) { + throw new PollTimeoutError(timeoutMs, { cause: lastError }); + } + + await new Promise(resolve => setTimeout(resolve, currentDelay)); + currentDelay = maxDelayMs ? Math.min(currentDelay * backoffFactor, maxDelayMs) : currentDelay * backoffFactor; + } +} + +/** Check if an error is an AWS throttling/rate-limit error. */ +export function isThrottlingError(err: unknown): boolean { + if (err == null || typeof err !== 'object') return false; + const name = (err as { name?: string }).name ?? ''; + const message = (err as { message?: string }).message ?? ''; + return ( + name === 'ThrottlingException' || + name === 'Throttling' || + name === 'TooManyRequestsException' || + name === 'RequestLimitExceeded' || + message.includes('Rate exceeded') || + message.includes('Throttling') + ); +}