
Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing-style").#13681

Open
gueraf wants to merge 1 commit into huggingface:main from gueraf:wan-rolling-kv-cache

Conversation


@gueraf gueraf commented May 5, 2026

What does this PR do?

  • Implements a simple (rolling) KV cache for Wan models to enable autoregressive generation (a minimal sketch follows this list).
  • Tries to mirror the KV cache pattern in transformer_flux2.py as well as transformers' DynamicCache as closely as possible.
  • Videos and byte-level equivalence against upstream Self Forcing were tested in https://github.com/gueraf/self-forcing-diffusers/ (see the videos attached to the release, and the inference script here).
  • This initial PR does not yet implement sink-frame pinning, lacks some model-level adjustments (Self Forcing has cross-attention QK norms and per-frame timestep modulation), and does not implement cross-attention caching (easy to add, but in practice it saves negligible GPU time and is often a small regression).
  • Adds tests for cache append/overwrite and window-eviction behavior.
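
For context, the rolling update at the heart of such a cache can be sketched as follows (a minimal sketch, not the PR's actual code; the (batch, heads, seq_len, head_dim) K/V layout is an assumption):

    import torch

    class RollingKVCacheSketch:
        """Illustrative only: per-block K/V storage with a fixed token window."""

        def __init__(self, num_blocks, window_size=-1):
            self.window_size = window_size
            self.keys = [None] * num_blocks
            self.values = [None] * num_blocks

        def update(self, block_idx, key, value):
            if self.keys[block_idx] is not None:
                # Append the new chunk along the sequence dimension...
                key = torch.cat([self.keys[block_idx], key], dim=-2)
                value = torch.cat([self.values[block_idx], value], dim=-2)
            # ...and evict the oldest tokens once a finite window is full.
            if self.window_size > 0 and key.shape[-2] > self.window_size:
                key = key[..., -self.window_size :, :]
                value = value[..., -self.window_size :, :]
            self.keys[block_idx] = key
            self.values[block_idx] = value
            return key, value

        def reset(self):
            # Clear all blocks between videos.
            self.keys = [None] * len(self.keys)
            self.values = [None] * len(self.values)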

Motivation

This is a tightly scoped follow-up to #12773 and a first step toward #12600. The previous draft explored similar functionality but also included Krea-specific experiments and broader integration work.

As for practical use, we (https://odyssey.ml/) would like to rely on the Hugging Face Diffusers ecosystem to ship Self-Forcing-like models with as few custom modules as possible, ideally none.

Progresses #12600


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul thanks for offering help with this :)

@github-actions github-actions bot added the models, tests, and size/L (PR with diff > 200 LOC) labels and removed the size/L label on May 5, 2026
@gueraf gueraf changed the title Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing"). Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing-style"). May 6, 2026
@gueraf gueraf force-pushed the wan-rolling-kv-cache branch 2 times, most recently from fb97f37 to b2f85fa on May 6, 2026 at 19:58
@github-actions github-actions bot added the documentation (Improvements or additions to documentation) label on May 6, 2026
@gueraf gueraf force-pushed the wan-rolling-kv-cache branch 5 times, most recently from 123314b to 7c6255e on May 6, 2026 at 20:25
WanKVCache is a per-block self-attention KV cache that lets a Wan
transformer generate video chunk by chunk while reusing the K/V tensors
computed for prior chunks instead of re-running the full attention over the
whole prefix on every step.

API:
- ``WanKVCache(num_blocks, window_size=-1)`` — one cache per transformer
  instance. ``window_size=-1`` keeps the full prefix; a finite window
  evicts the oldest tokens once the cap is reached.
- ``cache.enable_append_mode()`` / ``cache.enable_overwrite_mode()`` — pick
  the write semantics for the next forward pass. Append grows the cache
  (or rolls when full); overwrite replaces the newest chunk in place — used
  for additional denoising steps that re-do the most recent chunk (sketched
  after this list).
- ``cache.update(block_idx, key, value)`` — called from ``WanAttnProcessor``
  during self-attention to merge the current chunk into the per-block
  cache and return the K/V to attend over.
- ``cache.reset()`` — clear all blocks between videos.
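
Overwrite mode, for example, can be pictured roughly like this (an
illustrative helper, not the actual implementation; the sequence axis is
assumed to be the second-to-last dimension)::

    import torch

    def overwrite_newest_chunk(cached_k, cached_v, key, value):
        # Drop the newest cached chunk and write the re-denoised chunk
        # in its place; the prefix in front of it is left untouched.
        chunk_len = key.shape[-2]
        new_k = torch.cat([cached_k[..., :-chunk_len, :], key], dim=-2)
        new_v = torch.cat([cached_v[..., :-chunk_len, :], value], dim=-2)
        return new_k, new_v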

Wan plumbing:
- ``WanTransformer3DModel.forward`` accepts ``frame_offset: int = 0`` and
  forwards ``kv_cache`` (extracted from ``attention_kwargs``) plus
  ``block_idx`` to each transformer block.
- ``WanRotaryPosEmbed.forward`` takes ``frame_offset`` so RoPE can address
  positions in the original (uncached) sequence even when the latent input
  is just one chunk.
- ``WanAttnProcessor.__call__`` receives ``kv_cache`` / ``block_idx``;
  on self-attention it calls ``cache.update(...)`` and uses the returned
  K/V for SDPA (see the sketch below). Cross-attention is unaffected.
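
Condensed, the cached self-attention path and the RoPE offset described
above behave roughly like this (an illustrative sketch; helper names and
tensor layout are assumptions)::

    import torch
    import torch.nn.functional as F

    def self_attention_with_cache(query, key, value, kv_cache=None, block_idx=0):
        # On self-attention, merge the current chunk's K/V into the
        # per-block cache and attend over whatever the cache returns.
        if kv_cache is not None:
            key, value = kv_cache.update(block_idx, key, value)
        # query covers only the current chunk; key/value may span the whole
        # (windowed) prefix, so new frames read from prior context.
        return F.scaled_dot_product_attention(query, key, value)

    def rope_frame_positions(num_frames, frame_offset=0):
        # frame_offset shifts positions so a lone chunk is rotary-embedded
        # at its true place in the original (uncached) sequence.
        return torch.arange(frame_offset, frame_offset + num_frames)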

Caller usage::

    cache = WanKVCache(num_blocks=len(transformer.blocks))
    for chunk_idx, latent_chunk in enumerate(chunks):
        # The first denoising step of each chunk appends new K/V entries.
        cache.enable_append_mode()
        for step_idx, t in enumerate(denoising_steps):
            if step_idx > 0:
                # Later steps re-do the same chunk, so overwrite it in place.
                cache.enable_overwrite_mode()
            transformer(
                hidden_states=latent_chunk,
                timestep=t,
                encoder_hidden_states=prompt_embeds,
                # Address RoPE positions in the original (uncached) sequence.
                frame_offset=chunk_idx * patch_frames_per_chunk,
                attention_kwargs={"kv_cache": cache},
            )

Tests cover unbounded append, windowed append (with eviction across one and
multiple chunks), in-place overwrite of the newest chunk, the
read-from-prior-context contract, reset, and frame_offset's effect on RoPE.
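
For example, the windowed-eviction check has roughly this shape (a hedged
sketch; tensor shapes and the test name are assumptions, and the real tests
may differ)::

    import torch

    def test_windowed_append_evicts_oldest():
        # WanKVCache as introduced by this PR (import path omitted).
        # Window of 4 tokens, appended in chunks of 3: 3 + 3 = 6 tokens
        # in, only the newest 4 survive.
        cache = WanKVCache(num_blocks=1, window_size=4)
        k1, v1 = torch.randn(1, 2, 3, 8), torch.randn(1, 2, 3, 8)
        k2, v2 = torch.randn(1, 2, 3, 8), torch.randn(1, 2, 3, 8)

        cache.enable_append_mode()
        cache.update(0, k1, v1)
        cache.enable_append_mode()
        k, v = cache.update(0, k2, v2)

        assert k.shape[-2] == 4
        # Survivors: the last token of chunk 1 plus all of chunk 2.
        assert torch.equal(k, torch.cat([k1[..., -1:, :], k2], dim=-2))
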
@gueraf gueraf force-pushed the wan-rolling-kv-cache branch from 7c6255e to 5d28b7a on May 6, 2026 at 20:35
@gueraf gueraf marked this pull request as ready for review May 6, 2026 20:35