diff --git a/benchmark/reimFFT.ts b/benchmark/reimFFT.ts index 1ac7d00f..f3799718 100644 --- a/benchmark/reimFFT.ts +++ b/benchmark/reimFFT.ts @@ -1,99 +1,95 @@ /* eslint-disable no-console */ +import FFT from 'fft.js'; +import { XSadd } from 'ml-xsadd'; + import { reimFFT } from '../src/reim/reimFFT.ts'; import { reimArrayFFT } from '../src/reimArray/reimArrayFFT.ts'; +import type { DataReIm } from '../src/types/index.ts'; -const size = 2 ** 16; -const count = 10; // number of spectra in the array benchmark +const size = 2 ** 16; // 64k-point transform: FFT setup dominates the cost +const count = 10; // number of spectra processed per round +const targetMs = 5000; -// Build input data +// Deterministic, reproducible input so every section runs on identical data. +const { random } = new XSadd(42); const spectra = Array.from({ length: count }, () => { const re = new Float64Array(size); const im = new Float64Array(size); for (let i = 0; i < size; i++) { - re[i] = Math.random(); - im[i] = Math.random(); + re[i] = random(); + im[i] = random(); } return { re, im }; }); -// Warmup -for (const s of spectra) reimFFT(s); -for (const s of spectra) reimFFT(s, { inPlace: true }); -reimArrayFFT(spectra); -reimArrayFFT(spectra, { inPlace: true }); - -const targetMs = 5000; +/** + * `reimFFT` as it was *before* the cache fix: a fresh `FFT` instance is built on + * every call. Kept here as the baseline to confirm the cached version is faster. + * @param data - complex spectrum. + * @returns FFT of the complex spectrum. + */ +function reimFFTNoCache(data: DataReIm): DataReIm { + const { re, im } = data; + const length = re.length; + const csize = length << 1; -// --- reimFFT (loop over each spectrum individually) --- -{ - let iterations = 0; - const start = performance.now(); - console.time('reimFFT (loop)'); - while (performance.now() - start < targetMs) { - for (const s of spectra) reimFFT(s); - iterations++; + const complexArray = new Float64Array(csize); + for (let i = 0; i < csize; i += 2) { + complexArray[i] = re[i >>> 1]; + complexArray[i + 1] = im[i >>> 1]; } - const elapsed = performance.now() - start; - console.timeEnd('reimFFT (loop)'); - console.log( - ` ${iterations * count} total FFTs, ${count} spectra × ${iterations} rounds`, - ); - console.log(` ${(elapsed / (iterations * count)).toFixed(3)} ms per FFT`); -} -console.log(''); + const fft = new FFT(length); + const output = new Float64Array(csize); + fft.transform(output, complexArray); -// --- reimFFT inPlace (loop over each spectrum individually) --- -{ - let iterations = 0; - const start = performance.now(); - console.time('reimFFT inPlace (loop)'); - while (performance.now() - start < targetMs) { - for (const s of spectra) reimFFT(s, { inPlace: true }); - iterations++; + const newRe = new Float64Array(length); + const newIm = new Float64Array(length); + for (let i = 0; i < csize; i += 2) { + newRe[i >>> 1] = output[i]; + newIm[i >>> 1] = output[i + 1]; } - const elapsed = performance.now() - start; - console.timeEnd('reimFFT inPlace (loop)'); - console.log( - ` ${iterations * count} total FFTs, ${count} spectra × ${iterations} rounds`, - ); - console.log(` ${(elapsed / (iterations * count)).toFixed(3)} ms per FFT`); + return { re: newRe, im: newIm }; } -console.log(''); - -// --- reimArrayFFT (single call for the whole array) --- -{ - let iterations = 0; +/** + * Run `task` repeatedly for `targetMs` and report the time per FFT. Each round + * performs `count` transforms. + * @param label - section name. + * @param task - one round of work (transforms all `count` spectra). + */ +function bench(label: string, task: () => void): void { + task(); // warmup + let rounds = 0; const start = performance.now(); - console.time('reimArrayFFT'); while (performance.now() - start < targetMs) { - reimArrayFFT(spectra); - iterations++; + task(); + rounds++; } const elapsed = performance.now() - start; - console.timeEnd('reimArrayFFT'); + const totalFFTs = rounds * count; + console.log(label); console.log( - ` ${iterations * count} total FFTs, ${count} spectra × ${iterations} rounds`, + ` ${(elapsed / totalFFTs).toFixed(3)} ms per FFT (${totalFFTs} FFTs over ${rounds} rounds)`, ); - console.log(` ${(elapsed / (iterations * count)).toFixed(3)} ms per FFT`); + console.log(''); } +console.log(`FFT size: ${size} (2^16), ${count} spectra per round`); console.log(''); -// --- reimArrayFFT inPlace (single call for the whole array) --- -{ - let iterations = 0; - const start = performance.now(); - console.time('reimArrayFFT inPlace'); - while (performance.now() - start < targetMs) { - reimArrayFFT(spectra, { inPlace: true }); - iterations++; - } - const elapsed = performance.now() - start; - console.timeEnd('reimArrayFFT inPlace'); - console.log( - ` ${iterations * count} total FFTs, ${count} spectra × ${iterations} rounds`, - ); - console.log(` ${(elapsed / (iterations * count)).toFixed(3)} ms per FFT`); -} +// Before the fix: new FFT instance per call. +bench('reimFFT — before fix (new FFT per call)', () => { + for (const spectrum of spectra) reimFFTNoCache(spectrum); +}); + +// After the fix: FFT instance cached per size and reused across calls. +bench('reimFFT — after fix (cached FFT instance)', () => { + for (const spectrum of spectra) reimFFT(spectrum); +}); + +// reimArrayFFT reuses a single FFT instance (and working buffers) for the whole +// array in one call. +bench('reimArrayFFT (single shared FFT instance)', () => { + reimArrayFFT(spectra); +}); diff --git a/benchmark/xHilbertTransform.ts b/benchmark/xHilbertTransform.ts new file mode 100644 index 00000000..048e6ab2 --- /dev/null +++ b/benchmark/xHilbertTransform.ts @@ -0,0 +1,94 @@ +/* eslint-disable no-console */ +import FFT from 'fft.js'; +import { XSadd } from 'ml-xsadd'; + +import { xHilbertTransform } from '../src/x/xHilbertTransform.ts'; + +const size = 2 ** 16; // power of two => the FFT path (hilbertTransformWithFFT) +const count = 10; // signals processed per round +const targetMs = 5000; + +// Deterministic, reproducible input. +const { random } = new XSadd(42); +const signals = Array.from({ length: count }, () => { + const array = new Float64Array(size); + for (let i = 0; i < size; i++) array[i] = random() * 2 - 1; + return array; +}); + +/** + * Hilbert transform via FFT as it was *before* the shared cache: a fresh `FFT` + * instance is built on every call. Mirrors `hilbertTransformWithFFT`, used as + * the baseline to confirm the cached version is faster. + * @param array - real input signal whose length is a power of two. + * @returns the Hilbert transform (90° phase-shifted signal). + */ +function hilbertNoCache(array: Float64Array): Float64Array { + const length = array.length; + const fft = new FFT(length); + + const spectrum = new Float64Array(length * 2); + fft.realTransform(spectrum, array); + fft.completeSpectrum(spectrum); + + const half = length >> 1; + const nyquist = half << 1; + spectrum[nyquist] = 0; + spectrum[nyquist + 1] = 0; + for (let j = (half + 1) << 1; j < spectrum.length; j += 2) { + spectrum[j] = -spectrum[j]; + spectrum[j + 1] = -spectrum[j + 1]; + } + + const hilbertSignal = new Float64Array(length * 2); + fft.inverseTransform(hilbertSignal, spectrum); + + const result = new Float64Array(length); + for (let i = 0; i < length; i++) result[i] = hilbertSignal[i * 2 + 1]; + return result; +} + +/** + * Run `task` for `targetMs` and report the time per transform. Each round + * processes `count` signals. + * @param label - section name. + * @param task - one round of work (transforms all `count` signals). + */ +function bench(label: string, task: () => void): void { + task(); // warmup + let rounds = 0; + const start = performance.now(); + while (performance.now() - start < targetMs) { + task(); + rounds++; + } + const elapsed = performance.now() - start; + const total = rounds * count; + console.log(label); + console.log( + ` ${(elapsed / total).toFixed(3)} ms per transform (${total} transforms over ${rounds} rounds)`, + ); + console.log(''); +} + +// Sanity check: the cached path and the baseline must compute the same thing. +const reference = hilbertNoCache(signals[0]); +const cached = xHilbertTransform(signals[0]); +let maxDiff = 0; +for (let i = 0; i < reference.length; i++) { + const diff = Math.abs(reference[i] - cached[i]); + if (diff > maxDiff) maxDiff = diff; +} + +console.log( + `xHilbertTransform: size ${size} (2^16), ${count} signals per round`, +); +console.log(`equivalence check: max abs diff ${maxDiff.toExponential(2)}\n`); + +bench('before shared cache (new FFT per call)', () => { + for (const signal of signals) hilbertNoCache(signal); +}); + +bench('after shared cache (reused FFT instance)', () => { + for (const signal of signals) xHilbertTransform(signal); +}); diff --git a/src/__tests__/__snapshots__/index.test.ts.snap b/src/__tests__/__snapshots__/index.test.ts.snap index f87c96fe..ea521371 100644 --- a/src/__tests__/__snapshots__/index.test.ts.snap +++ b/src/__tests__/__snapshots__/index.test.ts.snap @@ -181,6 +181,8 @@ exports[`existence of exported functions 1`] = ` "matrixZRescale", "matrixZRescalePerColumn", "matrixTranspose", + "clearFFTCache", + "setFFTCacheMaxSize", "createNumberArray", "createDoubleArray", "createFromToArray", diff --git a/src/matrix/matrixHilbertTransform.ts b/src/matrix/matrixHilbertTransform.ts index ac1b1460..6f8f243c 100644 --- a/src/matrix/matrixHilbertTransform.ts +++ b/src/matrix/matrixHilbertTransform.ts @@ -1,5 +1,4 @@ -import FFT from 'fft.js'; - +import { getFFT } from '../utils/fftCache.ts'; import { isPowerOfTwo } from '../utils/index.ts'; import { matrixCreateEmpty } from './matrixCreateEmpty.ts'; @@ -41,7 +40,7 @@ export function matrixHilbertTransform( } // Single FFT instance reused across all rows - const fft = new FFT(size); + const fft = getFFT(size); // Multiplier computed once — identical for every row of the same length const half = size >> 1; diff --git a/src/reim/reimFFT.ts b/src/reim/reimFFT.ts index ee7a057a..be40b07f 100644 --- a/src/reim/reimFFT.ts +++ b/src/reim/reimFFT.ts @@ -1,6 +1,5 @@ -import FFT from 'fft.js'; - import type { DataReIm } from '../types/index.ts'; +import { getFFT } from '../utils/fftCache.ts'; import { zeroShift } from './zeroShift.ts'; @@ -36,7 +35,7 @@ export function reimFFT( complexArray[i + 1] = im[i >>> 1]; } - const fft = new FFT(size); + const fft = getFFT(size); let output = new Float64Array(csize); if (inverse) { if (applyZeroShift) complexArray = zeroShift(complexArray, true); diff --git a/src/reimArray/reimArrayFFT.ts b/src/reimArray/reimArrayFFT.ts index 3645376d..91e897cb 100644 --- a/src/reimArray/reimArrayFFT.ts +++ b/src/reimArray/reimArrayFFT.ts @@ -1,7 +1,6 @@ -import FFT from 'fft.js'; - import { zeroShift } from '../reim/zeroShift.ts'; import type { DataReIm } from '../types/index.ts'; +import { getFFT } from '../utils/fftCache.ts'; export interface ReimArrayFFTOptions { inverse?: boolean; @@ -41,7 +40,7 @@ export function reimArrayFFT( } // Single FFT instance and working buffers reused across all spectra - const fft = new FFT(size); + const fft = getFFT(size); const complexArray = new Float64Array(csize); const output = new Float64Array(csize); diff --git a/src/reimMatrix/reimMatrixFFT.ts b/src/reimMatrix/reimMatrixFFT.ts index 7ae3d5d1..fc1c944a 100644 --- a/src/reimMatrix/reimMatrixFFT.ts +++ b/src/reimMatrix/reimMatrixFFT.ts @@ -1,7 +1,6 @@ -import FFT from 'fft.js'; - import { zeroShift } from '../reim/zeroShift.ts'; import type { DataReImMatrix } from '../types/index.ts'; +import { getFFT } from '../utils/fftCache.ts'; export interface ReimMatrixFFTOptions { inverse?: boolean; @@ -44,7 +43,7 @@ export function reimMatrixFFT( } // Single FFT instance and working buffers reused across all rows - const fft = new FFT(size); + const fft = getFFT(size); const complexArray = new Float64Array(csize); const output = new Float64Array(csize); diff --git a/src/reimMatrix/reimMatrixFFTByColumns.ts b/src/reimMatrix/reimMatrixFFTByColumns.ts index 4c29dda1..d9a70262 100644 --- a/src/reimMatrix/reimMatrixFFTByColumns.ts +++ b/src/reimMatrix/reimMatrixFFTByColumns.ts @@ -1,7 +1,6 @@ -import FFT from 'fft.js'; - import { zeroShift } from '../reim/zeroShift.ts'; import type { DataReImMatrix } from '../types/index.ts'; +import { getFFT } from '../utils/fftCache.ts'; export interface ReimMatrixFFTByColumnsOptions { inverse?: boolean; @@ -49,7 +48,7 @@ export function reimMatrixFFTByColumns( } // Single FFT instance and working buffers reused across all columns - const fft = new FFT(numRows); + const fft = getFFT(numRows); const complexArray = new Float64Array(csize); const output = new Float64Array(csize); diff --git a/src/utils/__tests__/fftCache.test.ts b/src/utils/__tests__/fftCache.test.ts new file mode 100644 index 00000000..c186cb41 --- /dev/null +++ b/src/utils/__tests__/fftCache.test.ts @@ -0,0 +1,50 @@ +import { expect, test } from 'vitest'; + +import { clearFFTCache, getFFT, setFFTCacheMaxSize } from '../fftCache.ts'; + +test('getFFT returns the same cached instance for the same size', () => { + const first = getFFT(1024); + const second = getFFT(1024); + + expect(second).toBe(first); +}); + +test('getFFT returns distinct instances for distinct sizes', () => { + expect(getFFT(512)).not.toBe(getFFT(2048)); +}); + +test('clearFFTCache forces a fresh instance to be built', () => { + const before = getFFT(256); + clearFFTCache(); + const after = getFFT(256); + + expect(after).not.toBe(before); +}); + +test('cache is bounded: a fresh instance is built after the cap is exceeded', () => { + clearFFTCache(); + const original = getFFT(2); + // Insert 10 further distinct sizes; the cap is 10, so reaching it clears the + // whole cache and evicts the original size-2 instance. + for (let power = 2; power <= 11; power++) getFFT(2 ** power); + + expect(getFFT(2)).not.toBe(original); +}); + +test('setFFTCacheMaxSize lowers the bound and clears when over it', () => { + setFFTCacheMaxSize(10); + const kept = getFFT(64); + getFFT(128); + getFFT(256); + // Shrinking below the current count (3) clears the cache immediately. + setFFTCacheMaxSize(2); + + expect(getFFT(64)).not.toBe(kept); + + setFFTCacheMaxSize(10); // restore for any later tests in this file +}); + +test('setFFTCacheMaxSize rejects non-positive or non-integer sizes', () => { + expect(() => setFFTCacheMaxSize(0)).toThrow(RangeError); + expect(() => setFFTCacheMaxSize(1.5)).toThrow(RangeError); +}); diff --git a/src/utils/fftCache.ts b/src/utils/fftCache.ts new file mode 100644 index 00000000..acc4ff40 --- /dev/null +++ b/src/utils/fftCache.ts @@ -0,0 +1,87 @@ +import FFT from 'fft.js'; + +// An FFT instance precomputes size-dependent twiddle factors and a bit-reversal +// table; for a 64k transform that setup dominates the cost. A single shared +// cache keeps one instance per size so every FFT-based function in the library +// (spectra, matrices, Hilbert transforms…) reuses it instead of rebuilding the +// tables on each call. +class FFTCache { + #maxSize: number; + readonly #instances = new Map(); + + constructor(maxSize = 10) { + this.#maxSize = maxSize; + } + + get maxSize(): number { + return this.#maxSize; + } + + set maxSize(value: number) { + if (!Number.isInteger(value) || value < 1) { + throw new RangeError( + `FFT cache size must be a positive integer, got ${value}.`, + ); + } + this.#maxSize = value; + // Honour the new bound right away: if already over it, drop everything, + // matching the clear-when-full strategy used in get(). + if (this.#instances.size > value) this.#instances.clear(); + } + + get(size: number): FFT { + let fft = this.#instances.get(size); + if (fft === undefined) { + // Bound the cache so transforming many different sizes cannot grow it + // without limit. Rather than track insertion order to evict a single + // entry, drop everything once full: distinct sizes are rare, so the + // common workload never hits this and stays at full speed. + if (this.#instances.size >= this.#maxSize) this.#instances.clear(); + fft = new FFT(size); + this.#instances.set(size, fft); + } + return fft; + } + + clear(): void { + this.#instances.clear(); + } +} + +// Process-wide singleton: a single shared cache, like a static global, but kept +// as an instance so its state stays fully encapsulated. +const fftCache = new FFTCache(); + +/** + * Returns a cached `fft.js` instance for the given transform size, building it + * on first use. Shared by every FFT-based function in the library so the + * size-dependent lookup tables are computed once and reused across calls. + * @param size - number of points of the transform. + * @returns a reusable FFT instance for that size. + */ +export function getFFT(size: number): FFT { + return fftCache.get(size); +} + +/** + * Releases every cached FFT instance. The library keeps one FFT instance per + * transform size to avoid rebuilding its lookup tables on each call; each + * instance holds tables proportional to its size (≈1 MB for a 64k transform). + * Call this to free that memory once you are done transforming — for example + * after processing a batch of large spectra. Subsequent transforms simply + * rebuild whatever sizes they need. + */ +export function clearFFTCache(): void { + fftCache.clear(); +} + +/** + * Sets the maximum number of distinct transform sizes kept in the shared FFT + * cache (default 10). When the cache is full and a new size is requested, the + * whole cache is dropped and rebuilt on demand. Lowering the limit below the + * current number of cached sizes clears the cache immediately. + * @param maxSize - maximum number of cached sizes; must be a positive integer. + */ +export function setFFTCacheMaxSize(maxSize: number): void { + fftCache.maxSize = maxSize; +} diff --git a/src/utils/index.ts b/src/utils/index.ts index 73ab454f..58b8388c 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -2,6 +2,7 @@ export * from './createArray.ts'; export * from './createFromToArray.ts'; export * from './createRandomArray.ts'; export * from './createStepArray.ts'; +export { clearFFTCache, setFFTCacheMaxSize } from './fftCache.ts'; export * from './getCombinations.ts'; export * from './getCombinationsIterator.ts'; export * from './getRescaler.ts'; diff --git a/src/x/xHilbertTransform.ts b/src/x/xHilbertTransform.ts index c9a64d00..919ec7b6 100644 --- a/src/x/xHilbertTransform.ts +++ b/src/x/xHilbertTransform.ts @@ -1,6 +1,6 @@ import type { NumberArray } from 'cheminfo-types'; -import FFT from 'fft.js'; +import { getFFT } from '../utils/fftCache.ts'; import { isPowerOfTwo, nextPowerOfTwo } from '../utils/index.ts'; import { xCheck } from './xCheck.ts'; @@ -44,7 +44,7 @@ export function xHilbertTransform( */ function hilbertTransformWithFFT(array: NumberArray) { const length = array.length; - const fft = new FFT(length); + const fft = getFFT(length); // Single reusable buffer for FFT spectrum const spectrum = new Float64Array(length * 2);