Add a (rolling) KV cache for Wan models to enable autoregressive rollouts ("SelfForcing-style") #13681
Open
gueraf wants to merge 1 commit into huggingface:main
Conversation
WanKVCache is a per-block self-attention KV cache that lets a Wan
transformer generate video chunk by chunk while reusing the K/V tensors
computed for prior chunks instead of re-running the full attention over the
whole prefix on every step.
API (see the code sketch after this list):
- ``WanKVCache(num_blocks, window_size=-1)`` — one cache per transformer
instance. ``window_size=-1`` keeps the full prefix; a finite window
evicts the oldest tokens once the cap is reached.
- ``cache.enable_append_mode()`` / ``cache.enable_overwrite_mode()`` — pick
the write semantics for the next forward pass. Append grows the cache
(or rolls when full); overwrite replaces the newest chunk in place — used
for additional denoising steps that re-do the most recent chunk.
- ``cache.update(block_idx, key, value)`` — called from ``WanAttnProcessor``
during self-attention to merge the current chunk into the per-block
cache and return the K/V to attend over.
- ``cache.reset()`` — clear all blocks between videos.
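A minimal sketch of these semantics, assuming K/V tensors shaped ``(batch, heads, seq, head_dim)``; the class below is illustrative only, not the PR's actual implementation::

    import torch

    class KVCacheSketch:
        """Illustrative per-block rolling K/V cache mirroring the API above."""

        def __init__(self, num_blocks, window_size=-1):
            self.window_size = window_size
            self._mode = "append"
            self._k = [None] * num_blocks
            self._v = [None] * num_blocks
            self._last_chunk_len = 0

        def enable_append_mode(self):
            self._mode = "append"

        def enable_overwrite_mode(self):
            self._mode = "overwrite"

        def reset(self):
            self._k = [None] * len(self._k)
            self._v = [None] * len(self._v)
            self._last_chunk_len = 0

        def update(self, block_idx, key, value):
            chunk_len = key.shape[2]
            k, v = self._k[block_idx], self._v[block_idx]
            if k is None:
                k, v = key, value
            elif self._mode == "overwrite":
                # Replace the newest chunk in place (extra denoising steps re-do it).
                keep = k.shape[2] - self._last_chunk_len
                k = torch.cat([k[:, :, :keep], key], dim=2)
                v = torch.cat([v[:, :, :keep], value], dim=2)
            else:
                # Append the current chunk after the cached prefix.
                k = torch.cat([k, key], dim=2)
                v = torch.cat([v, value], dim=2)
            if self.window_size > 0 and k.shape[2] > self.window_size:
                # Roll: evict the oldest tokens once the window cap is reached.
                k = k[:, :, -self.window_size :]
                v = v[:, :, -self.window_size :]
            self._k[block_idx], self._v[block_idx] = k, v
            self._last_chunk_len = chunk_len
            return k, v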
Wan plumbing (see the sketch after this list):
- ``WanTransformer3DModel.forward`` accepts ``frame_offset: int = 0`` and
forwards ``kv_cache`` (extracted from ``attention_kwargs``) plus
``block_idx`` to each transformer block.
- ``WanRotaryPosEmbed.forward`` takes ``frame_offset`` so RoPE can address
positions in the original (uncached) sequence even when the latent input
is just one chunk.
- ``WanAttnProcessor.__call__`` receives ``kv_cache`` / ``block_idx``;
on self-attention it calls ``cache.update(...)`` and uses the returned
K/V for SDPA. Cross-attention is unaffected.
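Roughly, the cached self-attention path and the shifted RoPE positions fit together as sketched below (hypothetical internals; only ``kv_cache``, ``block_idx`` and ``frame_offset`` are the PR's actual arguments)::

    import torch
    import torch.nn.functional as F

    def self_attention_with_cache(query, key, value, kv_cache=None, block_idx=None):
        # query/key/value hold only the current chunk's tokens.
        if kv_cache is not None:
            # Merge this chunk into the per-block cache and attend over the
            # returned (cached prefix + current chunk) K/V.
            key, value = kv_cache.update(block_idx, key, value)
        return F.scaled_dot_product_attention(query, key, value)

    def temporal_rope_positions(num_chunk_frames, frame_offset=0):
        # RoPE addresses positions in the original (uncached) sequence: the
        # chunk's first latent frame sits at ``frame_offset``, not at zero.
        return torch.arange(frame_offset, frame_offset + num_chunk_frames)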
Caller usage::

    cache = WanKVCache(num_blocks=len(transformer.blocks))
    for chunk_idx, latent_chunk in enumerate(chunks):
        cache.enable_append_mode()
        for step_idx, t in enumerate(denoising_steps):
            if step_idx > 0:
                cache.enable_overwrite_mode()
            transformer(
                hidden_states=latent_chunk,
                timestep=t,
                encoder_hidden_states=prompt_embeds,
                frame_offset=chunk_idx * patch_frames_per_chunk,
                attention_kwargs={"kv_cache": cache},
            )
Tests cover unbounded append, windowed append (with eviction across one and
multiple chunks), in-place overwrite of the newest chunk, the
read-from-prior-context contract, reset, and ``frame_offset``'s effect on RoPE.
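For instance, the windowed-append/eviction contract could be exercised along these lines (hypothetical test with made-up shapes, assuming K/V laid out as ``(batch, heads, seq, head_dim)``)::

    import torch

    def test_windowed_append_evicts_oldest():
        cache = WanKVCache(num_blocks=1, window_size=4)

        # First chunk: 3 tokens fit inside the window.
        cache.enable_append_mode()
        k1 = torch.randn(1, 2, 3, 8)
        cache.update(0, k1, k1.clone())

        # Second chunk: 3 more tokens exceed the cap, so the oldest are evicted.
        cache.enable_append_mode()
        k2 = torch.randn(1, 2, 3, 8)
        out_k, out_v = cache.update(0, k2, k2.clone())
        assert out_k.shape[2] == 4
        # The newest chunk is always retained in full.
        assert torch.equal(out_k[:, :, -3:], k2)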
What does this PR do?
Motivation
This is a tightly scoped follow-up to #12773 and a first step toward #12600. The previous draft explored similar functionality but also included Krea-specific experiments and broader integration work.
In practical terms, we (https://odyssey.ml/) would like to rely on the Hugging Face Diffusers ecosystem to ship Self-Forcing-like models with as few custom modules as possible, ideally none.
Progresses #12600
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul thanks for offering help with this :)