streaming support#16

Open
srao25 wants to merge 5 commits into main from sr/streaming

srao25 (Collaborator) commented Mar 19, 2026

Streaming audio generation

Adds end-to-end streaming to generate() via a generator/iterator API. Audio chunks are emitted during LLM generation, not after. No retraining required, fully backward compatible.

Usage

stream = model.generate(prompt=prompt, text="Hello world", stream=True)
for chunk, sample_rate in stream:
    play(chunk)  # chunk: torch.Tensor, 24kHz mono
output = stream.result  # GenerationOutput (audio=None — caller collects chunks)

When stream is left at its default (False), the existing non-streaming path runs unchanged.
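Because stream.result carries audio=None in streaming mode, the caller has to assemble the chunks. The following is a minimal sketch of one way to do that with only the standard library. It assumes each chunk has already been flattened to a sequence of floats in [-1.0, 1.0] (for a torch.Tensor, chunk.flatten().tolist() would do that); write_stream_to_wav is a hypothetical helper, not part of this PR.

```python
import struct
import wave


def write_stream_to_wav(stream, path):
    """Drain an iterable of (chunk, sample_rate) pairs into a 16-bit mono WAV.

    Assumption: each chunk is a flat sequence of floats in [-1.0, 1.0].
    Returns the total number of samples written.
    """
    total = 0
    wav = None
    try:
        for chunk, sample_rate in stream:
            if wav is None:
                # Open lazily so the header uses the rate reported by the stream.
                wav = wave.open(path, "wb")
                wav.setnchannels(1)
                wav.setsampwidth(2)  # 16-bit PCM
                wav.setframerate(sample_rate)
            # Clip to the valid range, then scale to signed 16-bit integers.
            samples = [max(-1.0, min(1.0, float(s))) for s in chunk]
            pcm = struct.pack(f"<{len(samples)}h", *(int(s * 32767) for s in samples))
            wav.writeframes(pcm)
            total += len(samples)
    finally:
        if wav is not None:
            wav.close()
    return total
```

Writing chunk by chunk keeps memory bounded, which matches the intent of the streaming path; collecting everything into one tensor first would also work for short clips.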

What changed

generate() gains two params:

| Param | Default | Description |
| --- | --- | --- |
| stream | False | When True, returns an AudioStream iterator yielding (chunk, sample_rate) tuples. |
| streaming_cnn_window_size | 100 | CNN sliding-window size in frames. Use 50 for <32 GB RAM. |

Under the hood, _generate() is now a Python generator. It yields one StepOutput per predicted token (carrying the new acoustic feature + time_before) and a final SyncTokGenerationOutput on completion. The non-streaming path simply drains the generator eagerly, so its semantics are unchanged.

How it works

The TADA decoder has two stages, both made streaming:

  1. Transformer (KV-cache, exact up to float noise). forward_with_cache() + _apply_rope_with_offset() on LocalSelfAttention / LocalAttentionEncoderLayer (encoder.py). Each new token computes only Q and attends to cached post-RoPE K,V. The v2 block attention mask restricts attention to current + previous block. Max diff vs. full-sequence pass: ~1.91e-06 (float noise only).

  2. CNN / DACDecoder (sliding window). New StreamingDecoder class in decoder.py. The CNN uses symmetric (non-causal) padding, so each frame needs both left and right context. The sliding window provides 20 frames of left context and 15 frames of right lookahead; measured empirically against the pretrained weights, this keeps the deviation from full-sequence decoding inaudible (max diff ≤ 0.0003). _all_hidden is capped at the window size so memory stays bounded.
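To make the sliding-window bookkeeping concrete, here is a small sketch of the emit schedule alone, with no CNN involved. It assumes a frame can be emitted once 15 later frames have arrived and that each decode call re-reads up to 20 already-emitted frames as left context. WindowScheduler and its methods are hypothetical names; the actual StreamingDecoder may organize this differently.

```python
LEFT_CONTEXT = 20     # frames of left context, from the description above
RIGHT_LOOKAHEAD = 15  # frames of right lookahead, from the description above


class WindowScheduler:
    """Track which frames can be emitted as new frames arrive (sketch only)."""

    def __init__(self, left=LEFT_CONTEXT, right=RIGHT_LOOKAHEAD):
        self.left = left
        self.right = right
        self.received = 0  # frames pushed so far
        self.emitted = 0   # frames whose audio has already been emitted

    def push(self, n_frames=1):
        """Register new frames.

        Returns (window_start, emit_start, emit_end) as half-open frame
        indices, or None while the lookahead is not yet satisfied.
        """
        self.received += n_frames
        emit_end = self.received - self.right  # frames with full lookahead
        if emit_end <= self.emitted:
            return None
        plan = (max(0, self.emitted - self.left), self.emitted, emit_end)
        self.emitted = emit_end
        return plan

    def flush(self):
        """At end of generation, emit the tail even without full lookahead."""
        if self.emitted >= self.received:
            return None
        plan = (max(0, self.emitted - self.left), self.emitted, self.received)
        self.emitted = self.received
        return plan
```

Under this schedule nothing is emitted until the 16th frame arrives, so the lookahead sets a floor on latency. If the "5 frames (100 ms)" figure elsewhere in this description applies to these frames, 15 frames of lookahead corresponds to roughly 300 ms.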

E2E wiring lives in AudioStream.__iter__ (tada.py): on each predicted token it denormalizes acoustic features, feeds them to StreamingDecoder.decode_block(), and yields any audio that the sliding window has been able to emit. On the first token, leading silence is skipped at the frame level. On the final SyncTokGenerationOutput, the streaming decoder is flushed and stream.result is populated.
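The wiring described above can be sketched as a thin iterator wrapper. Everything here is illustrative: AudioStreamSketch, the ("step", ...) / ("final", ...) event tuples, and the decode_block / flush callables are assumptions standing in for the real AudioStream and StreamingDecoder, and denormalization and leading-silence skipping are left out.

```python
SAMPLE_RATE = 24_000  # 24 kHz mono, per the usage example


class AudioStreamSketch:
    """Iterate over generation events and yield audio as soon as it is ready."""

    def __init__(self, events, decode_block, flush):
        self._events = events              # iterable of (kind, payload) tuples
        self._decode_block = decode_block  # features -> audio (may be empty)
        self._flush = flush                # () -> remaining audio
        self.result = None                 # populated when the final event arrives

    def __iter__(self):
        for kind, payload in self._events:
            if kind == "step":
                audio = self._decode_block(payload)
            else:  # "final": drain the decoder and record the generation output
                audio = self._flush()
                self.result = payload
            # The window may not have enough lookahead yet; skip empty output.
            if audio:
                yield audio, SAMPLE_RATE
```

In this sketch result stays None until the final event is consumed, so a caller that breaks out early should not rely on it.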

Other fixes bundled in

  • Tokenization boundary bug (pre-existing): generate() used to encode the prompt and generation text separately, causing BPE merges to drop characters at the seam (e.g. ".Hello" → ".H" + "ello"). Now uses joint encoding with a space separator. Fixes both streaming and non-streaming paths.
  • Leading-silence transition blip (pre-existing, ~1 in 4 runs): when the LLM predicts a very short time_before[0], the leading-silence trim removed almost nothing and exposed a CNN transition artifact. Now clamps minimum trim to 5 frames (100 ms) in both paths.
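The trim clamp in the second fix amounts to a single max(). A minimal sketch, assuming the trim length is taken directly from time_before[0] in frames and that one frame is 20 ms (inferred from "5 frames (100 ms)"); leading_trim_frames is a hypothetical name.

```python
MIN_TRIM_FRAMES = 5  # minimum trim, per the fix described above
FRAME_MS = 20        # assumption: 5 frames == 100 ms, so 20 ms per frame


def leading_trim_frames(predicted_time_before, min_frames=MIN_TRIM_FRAMES):
    """Return how many leading frames to drop, never fewer than min_frames."""
    return max(int(predicted_time_before), min_frames)


def frames_to_ms(n_frames, frame_ms=FRAME_MS):
    """Convert a frame count to milliseconds under the assumed frame length."""
    return n_frames * frame_ms
```

With a predicted time_before[0] of 1, the old behaviour would trim a single frame and expose the transition; the clamp raises that to 5 frames.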

Performance

True time-to-first-audio (TTFA) measured from generate() call to first chunk:

| Model | Device | Text | True TTFA | Streaming total | Non-streaming total |
| --- | --- | --- | --- | --- | --- |
| TADA-1B | GPU (H100, bf16) | Short | ~440 ms | 0.47 s | 0.77 s |
| TADA-1B | GPU (H100, bf16) | Medium | ~350 ms | 1.06 s | 0.86 s |
| TADA-1B | GPU (H100, bf16) | Long | ~330 ms | 3.02 s | 2.47 s |
| TADA-1B | CPU (fp32) | Short | ~2.3 s | 3.2 s | 2.9 s |
| TADA-1B | CPU (fp32) | Long | ~2.1 s | 25 s | 18 s |

Peak VRAM is lower in streaming mode for medium/long text (up to 3.6 GB less than non-streaming), because the streaming decoder operates on a bounded sliding window instead of materializing all hidden states at once.
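For anyone reproducing the TTFA column, a timing harness along these lines is one option. It is a sketch, not the script used for the table: measure_ttfa and make_stream are hypothetical names, and make_stream is assumed to be a zero-argument callable wrapping the generate(..., stream=True) call so that the clock starts before generation begins.

```python
import time


def measure_ttfa(make_stream):
    """Time from calling make_stream() to the first chunk, plus total time.

    Returns (ttfa_seconds, total_seconds, n_chunks); ttfa_seconds is None
    if the stream yields nothing.
    """
    start = time.perf_counter()
    stream = make_stream()
    ttfa = None
    n_chunks = 0
    for _chunk, _sample_rate in stream:
        if ttfa is None:
            ttfa = time.perf_counter() - start
        n_chunks += 1
    total = time.perf_counter() - start
    return ttfa, total, n_chunks
```

On GPU, kernel launches are asynchronous, so a wall-clock reading taken when a chunk is yielded may be optimistic unless the chunk has been synchronized or copied to the host first.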

Tests

  • Unit tests (no GPU, no weights): StepOutput, AudioStream (fake generator + tiny Decoder), StreamingDecoder (basic streaming, skip_leading_frames, reset, buffering, flush), segment attention mask.
  • Integration tests (@pytest.mark.integration, real TADA-1B on GPU):
    • test_non_streaming_unchanged — non-streaming path still produces audio.
    • test_streaming_produces_chunks — streaming yields chunks, stream.result is populated.
    • test_streaming_vs_nonstreaming_similar_length — streaming and non-streaming produce audio within 0.5×–2× length of each other on the same text.
    • test_streaming_early_break — breaking out of the iterator mid-stream is safe.
    • TestGenerateAudios — generates 5 streaming + 5 non-streaming WAVs for manual A/B listening.
  • Run: pytest tests/test_streaming.py -m integration -s (or sbatch tests/run_integration.sh).

All integration tests passed on TADA-1B (H100) on the latest run.

Files changed

  • tada/modules/decoder.py: new StreamingDecoder class (~365 lines). Existing Decoder untouched.
  • tada/modules/encoder.py: forward_with_cache() + _apply_rope_with_offset() added to attention layers. Existing forward() paths untouched.
  • tada/modules/tada.py: _generate() becomes a generator yielding StepOutput; new AudioStream class; generate() gains stream + streaming_cnn_window_size params; tokenization fix; 5-frame trim clamp.
  • tada/modules/__init__.py: exports AudioStream, StreamingDecoder.
  • README.md: new "Streaming Audio Generation" section with examples, perf table, parameter reference.
  • tests/test_streaming.py, tests/run_integration.sh: new test suite.

Future work (not in this PR)

  • Causal CNN to eliminate the 15-frame right-context lookahead (would require retraining).
  • torch.compile on the streaming CNN/transformer path.

srao25 and others added 2 commits May 11, 2026 17:02
Clears PytestUnknownMarkWarning when running streaming integration tests.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
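
For context, PytestUnknownMarkWarning goes away once the integration marker is registered with pytest. The commit message does not say how that was done here; one common approach, shown as a sketch, is a pytest_configure hook in conftest.py (an entry in the pytest configuration file works equally well). The marker description string is an assumption.

```python
# conftest.py (sketch; the actual commit may register the marker elsewhere)


def pytest_configure(config):
    """Register the custom 'integration' marker so pytest stops warning about it."""
    config.addinivalue_line(
        "markers",
        "integration: tests that need real TADA-1B weights and a GPU",
    )
```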