diff --git a/src/stream.ts b/src/stream.ts index 54a25b5..97658a1 100644 --- a/src/stream.ts +++ b/src/stream.ts @@ -2,13 +2,26 @@ import { Semaphore } from "./sync/semaphore"; import { unboundedChannel } from "./sync/mpsc"; import { sleep as timeSleep, TimeoutError } from "./time"; -export async function* map( +export function map( source: AsyncIterable, fn: (item: T) => U, ): AsyncIterable { - for await (const item of source) { - yield fn(item); - } + return { + [Symbol.asyncIterator]() { + const iter = source[Symbol.asyncIterator](); + return { + async next() { + const { done, value } = await iter.next(); + if (done) return { done: true, value: undefined } as IteratorReturnResult; + return { done: false, value: fn(value) }; + }, + async return(val?: any) { + await iter.return?.(val); + return { done: true as const, value: undefined }; + }, + }; + }, + }; } /** Named `andThen` to avoid JS thenable conflicts with `then`. */ @@ -113,9 +126,15 @@ export function bufferUnordered( let inFlight = 0; const drainSource = async () => { + const iter = source[Symbol.asyncIterator](); try { - for await (const promise of source) { + while (true) { const permit = await sem.acquire(); + const { done, value: promise } = await iter.next(); + if (done) { + permit.release(); + break; + } inFlight++; Promise.resolve(promise).then( (value: T) => { @@ -163,9 +182,15 @@ export function buffered( let nextIndex = 0; const drainSource = async () => { + const iter = source[Symbol.asyncIterator](); try { - for await (const promise of source) { + while (true) { const permit = await sem.acquire(); + const { done, value: promise } = await iter.next(); + if (done) { + permit.release(); + break; + } const idx = nextIndex++; inFlight++; Promise.resolve(promise).then( diff --git a/tests/stream.test.ts b/tests/stream.test.ts index 9cad98d..736a1bf 100644 --- a/tests/stream.test.ts +++ b/tests/stream.test.ts @@ -330,6 +330,68 @@ describe("map - edge cases", () => { const result = await collect(map(fromArray([42]), (x) => x.toString())); expect(result).toEqual(["42"]); }); + + it("does not auto-await promises returned by fn", async () => { + const mapped = map(fromArray([1, 2, 3]), (x) => Promise.resolve(x * 10)); + const results: unknown[] = []; + for await (const item of mapped) { + expect(item).toBeInstanceOf(Promise); + results.push(await item); + } + expect(results).toEqual([10, 20, 30]); + }); + + it("composes with bufferUnordered for concurrent async mapping", async () => { + const delays = [30, 10, 20]; + const source = fromArray([0, 1, 2]); + + const mapped = map(source, (i) => + new Promise((resolve) => setTimeout(() => resolve(i), delays[i])), + ); + const results = await collect(bufferUnordered(mapped, 3)); + + expect(results.sort()).toEqual([0, 1, 2]); + }); + + it("composes with bufferUnordered respecting concurrency", async () => { + let maxConcurrent = 0; + let current = 0; + + const source = fromArray([0, 1, 2, 3, 4, 5]); + const mapped = map(source, (i) => + new Promise((resolve) => { + current++; + if (current > maxConcurrent) maxConcurrent = current; + setTimeout(() => { + current--; + resolve(i); + }, 20); + }), + ); + + const results = await collect(bufferUnordered(mapped, 2)); + expect(results.sort()).toEqual([0, 1, 2, 3, 4, 5]); + expect(maxConcurrent).toBeLessThanOrEqual(2); + }); + + it("delegates return to source iterator on early break", async () => { + let cleanedUp = false; + async function* source(): AsyncIterable { + try { + yield 1; + yield 2; + yield 3; + } finally { + cleanedUp = true; + } + } + + const mapped = map(source(), (x) => x * 2); + for await (const item of mapped) { + if (item === 2) break; + } + expect(cleanedUp).toBe(true); + }); }); describe("filter - edge cases", () => {