From 5d28b7a1f1f079526125df758aedfddd99afc680 Mon Sep 17 00:00:00 2001 From: Fabian Guera Date: Wed, 6 May 2026 20:54:28 +0100 Subject: [PATCH] Add WanKVCache for autoregressive Wan video generation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WanKVCache is a per-block self-attention KV cache that lets a Wan transformer generate video chunk by chunk while reusing the K/V tensors computed for prior chunks instead of re-running the full attention over the whole prefix on every step. API: - ``WanKVCache(num_blocks, window_size=-1)`` — one cache per transformer instance. ``window_size=-1`` keeps the full prefix; a finite window evicts the oldest tokens once the cap is reached. - ``cache.enable_append_mode()`` / ``cache.enable_overwrite_mode()`` — pick the write semantics for the next forward pass. Append grows the cache (or rolls when full); overwrite replaces the newest chunk in place — used for additional denoising steps that re-do the most recent chunk. - ``cache.update(block_idx, key, value)`` — called from ``WanAttnProcessor`` during self-attention to merge the current chunk into the per-block cache and return the K/V to attend over. - ``cache.reset()`` — clear all blocks between videos. Wan plumbing: - ``WanTransformer3DModel.forward`` accepts ``frame_offset: int = 0`` and forwards ``kv_cache`` (extracted from ``attention_kwargs``) plus ``block_idx`` to each transformer block. - ``WanRotaryPosEmbed.forward`` takes ``frame_offset`` so RoPE can address positions in the original (uncached) sequence even when the latent input is just one chunk. - ``WanAttnProcessor.__call__`` receives ``kv_cache`` / ``block_idx``; on self-attention it calls ``cache.update(...)`` and uses the returned K/V for SDPA. Cross-attention is unaffected. Caller usage:: cache = WanKVCache(num_blocks=len(transformer.blocks)) for chunk_idx, latent_chunk in enumerate(chunks): cache.enable_append_mode() for step_idx, t in enumerate(denoising_steps): if step_idx > 0: cache.enable_overwrite_mode() transformer( hidden_states=latent_chunk, timestep=t, encoder_hidden_states=prompt_embeds, frame_offset=chunk_idx * patch_frames_per_chunk, attention_kwargs={"kv_cache": cache}, ) Tests cover unbounded append, windowed append (with eviction across one and multiple chunks), in-place overwrite of the newest chunk, the read-from-prior-context contract, reset, and frame_offset's effect on RoPE. --- .../en/api/models/wan_transformer_3d.md | 40 ++++ src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 8 +- src/diffusers/models/transformers/__init__.py | 2 +- .../models/transformers/transformer_wan.py | 167 +++++++++++++-- .../test_models_transformer_wan.py | 190 +++++++++++++++++- 6 files changed, 396 insertions(+), 15 deletions(-) diff --git a/docs/source/en/api/models/wan_transformer_3d.md b/docs/source/en/api/models/wan_transformer_3d.md index c218166584c6..a6516b88e33c 100644 --- a/docs/source/en/api/models/wan_transformer_3d.md +++ b/docs/source/en/api/models/wan_transformer_3d.md @@ -25,6 +25,46 @@ transformer = WanTransformer3DModel.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diff [[autodoc]] WanTransformer3DModel +## Rolling KV cache + +For autoregressive video generation that produces one chunk at a time, [`WanTransformer3DModel.forward`] accepts a `WanKVCache` instance via `attention_kwargs={"kv_cache": cache}`. The cache holds post-norm, post-RoPE self-attention K/V tensors from prior chunks so subsequent chunks attend over the full prefix without recomputing it. 
The chunk's RoPE positions are picked via the `frame_offset` argument on `forward`. + +The cache exposes two write modes that the caller toggles between denoising steps: + +- `enable_append_mode()` — the next forward pass appends the chunk's K/V to the cache; once the cache reaches `window_size`, the oldest tokens are evicted from the front. Use this for the first denoising step of every new chunk. +- `enable_overwrite_mode()` — the next forward pass replaces the newest `chunk_size` tokens in place. Use this for subsequent denoising steps within the same chunk so re-running the chunk doesn't grow the cache. + +```python +from diffusers import WanKVCache, WanTransformer3DModel + +transformer = WanTransformer3DModel.from_pretrained(...) +cache = WanKVCache(num_blocks=len(transformer.blocks)) + +for chunk_idx, latent_chunk in enumerate(chunks): + for step_idx, t in enumerate(denoising_steps): + if step_idx == 0: + cache.enable_append_mode() + else: + cache.enable_overwrite_mode() + transformer( + hidden_states=latent_chunk, + timestep=t, + encoder_hidden_states=prompt_embeds, + frame_offset=chunk_idx * patch_frames_per_chunk, + attention_kwargs={"kv_cache": cache}, + ) + +cache.reset() # between videos +``` + +## WanKVCache + +[[autodoc]] WanKVCache + +## WanKVBlockCache + +[[autodoc]] WanKVBlockCache + ## Transformer2DModelOutput [[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0c6083cafd0a..0d23d5d877ca 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -301,6 +301,8 @@ "UVit2DModel", "VQModel", "WanAnimateTransformer3DModel", + "WanKVBlockCache", + "WanKVCache", "WanTransformer3DModel", "WanVACETransformer3DModel", "ZImageControlNetModel", @@ -1117,6 +1119,8 @@ UVit2DModel, VQModel, WanAnimateTransformer3DModel, + WanKVBlockCache, + WanKVCache, WanTransformer3DModel, WanVACETransformer3DModel, ZImageControlNetModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index dc772fcc6d0c..e7b6e996d452 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -129,7 +129,11 @@ _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] - _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] + _import_structure["transformers.transformer_wan"] = [ + "WanKVBlockCache", + "WanKVCache", + "WanTransformer3DModel", + ] _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"] @@ -261,6 +265,8 @@ Transformer2DModel, TransformerTemporalModel, WanAnimateTransformer3DModel, + WanKVBlockCache, + WanKVCache, WanTransformer3DModel, WanVACETransformer3DModel, ZImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index bbd7ecfa911b..fa983b166234 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -52,7 +52,7 @@ from .transformer_sd3 import SD3Transformer2DModel from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel from .transformer_temporal import TransformerTemporalModel - 
from .transformer_wan import WanTransformer3DModel + from .transformer_wan import WanKVBlockCache, WanKVCache, WanTransformer3DModel from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel from .transformer_z_image import ZImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 5926bbb8e713..d6d058eb682c 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math +from dataclasses import dataclass from typing import Any import torch @@ -36,7 +39,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): +def _get_qkv_projections(attn: WanAttention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): # encoder_hidden_states is only passed for cross-attention if encoder_hidden_states is None: encoder_hidden_states = hidden_states @@ -56,7 +59,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco return query, key, value -def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor): +def _get_added_kv_projections(attn: WanAttention, encoder_hidden_states_img: torch.Tensor): if attn.fused_projections: key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) else: @@ -65,6 +68,115 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t return key_img, value_img +@dataclass +class WanKVBlockCache: + """Per-block rolling KV cache state for autoregressive WAN inference. + + ``cached_key`` and ``cached_value`` hold the post-norm, post-RoPE K/V from prior chunks + with shape ``(batch_size, cached_seq_len, num_heads, head_dim)``. + """ + + cached_key: torch.Tensor | None = None + cached_value: torch.Tensor | None = None + + def reset(self) -> None: + self.__init__() + + +class WanKVCache: + """Rolling KV cache for autoregressive WAN video generation. + + Holds a per-block ``WanKVBlockCache`` for every transformer block, plus shared + write-control state. Pass an instance via ``attention_kwargs`` on each transformer forward + call. ``WanAttnProcessor`` calls :py:meth:`update` to merge the current chunk's K/V into + the cache and get back the (possibly trimmed) attention K/V. + + TODO: cross-attention K/V projections are currently recomputed on every forward pass even + though the text embeddings are constant across chunks. A future change can add cross-attn + caching alongside the existing self-attn cache. + + Args: + num_blocks (`int`): Number of transformer blocks (``len(transformer.blocks)``). + window_size (`int`, defaults to ``-1``): Maximum cached tokens per block. ``-1`` keeps + the full prefix. 
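+            Any other non-positive value behaves the same as ``-1`` (windowing disabled).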
+ + Example: + + ```python + >>> cache = WanKVCache(num_blocks=len(transformer.blocks)) + >>> transformer(..., attention_kwargs={"kv_cache": cache}) + ``` + """ + + def __init__(self, num_blocks: int, window_size: int = -1): + self.block_caches: list[WanKVBlockCache] = [WanKVBlockCache() for _ in range(num_blocks)] + self.window_size: int = window_size + self.overwrite_newest: bool = False + + def enable_append_mode(self) -> None: + """Next forward pass appends the new chunk's K/V to the cache (cache grows, or oldest gets evicted).""" + self.overwrite_newest = False + + def enable_overwrite_mode(self) -> None: + """Next forward pass replaces the newest ``chunk_size`` tokens in place (cache size unchanged).""" + self.overwrite_newest = True + + def reset(self) -> None: + """Clear all cached K/V tensors and reset write-control state.""" + for bc in self.block_caches: + bc.reset() + self.overwrite_newest = False + + def update( + self, + block_idx: int, + new_key: torch.Tensor, + new_value: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Merge the current chunk's K/V into block ``block_idx``'s cache and return the + K/V that the self-attention should attend over. + + Two paths: + - **Overwrite-newest** (``overwrite_newest=True`` and the cache already holds at + least ``new_key.shape[1]`` tokens): write the new K/V *in place* into the trailing + positions of the existing tensor. No allocation, no concat. + - **Append** (default): concatenate the existing prefix with the new K/V, then trim + the oldest tokens from the front if the result exceeds ``window_size``. + """ + block_cache = self.block_caches[block_idx] + prefix_k = block_cache.cached_key + prefix_v = block_cache.cached_value + n = new_key.shape[1] + + if self.window_size > 0 and n > self.window_size: + raise RuntimeError(f"new chunk has {n} tokens, which exceeds window_size={self.window_size}.") + + if self.overwrite_newest: + if prefix_k is None or prefix_k.shape[1] < n: + raise RuntimeError( + "overwrite_newest requires the cache to already hold at least one chunk's worth of tokens " + f"(>= {n}); cached length is {0 if prefix_k is None else prefix_k.shape[1]}. " + "Use enable_append_mode() for the first write of a new chunk." + ) + # In-place update of the cached tensors; block_cache already references them. 
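+            # NOTE: the slice assignment below mutates the cached tensors; this assumes
+            # inference-style usage where no autograd graph still depends on them.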
+ prefix_k[:, -n:] = new_key + prefix_v[:, -n:] = new_value + return prefix_k, prefix_v + + if prefix_k is None: + block_cache.cached_key = new_key + block_cache.cached_value = new_value + return new_key, new_value + + keep_prefix = self.window_size - n if self.window_size > 0 else prefix_k.shape[1] + if keep_prefix > 0: + new_key = torch.cat([prefix_k[:, -keep_prefix:], new_key], dim=1) + new_value = torch.cat([prefix_v[:, -keep_prefix:], new_value], dim=1) + block_cache.cached_key = new_key + block_cache.cached_value = new_value + return new_key, new_value + + class WanAttnProcessor: _attention_backend = None _parallel_config = None @@ -77,11 +189,13 @@ def __init__(self): def __call__( self, - attn: "WanAttention", + attn: WanAttention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + kv_cache: WanKVCache | None = None, + block_idx: int | None = None, ) -> torch.Tensor: encoder_hidden_states_img = None if attn.add_k_proj is not None: @@ -117,6 +231,11 @@ def apply_rotary_emb( query = apply_rotary_emb(query, *rotary_emb) key = apply_rotary_emb(key, *rotary_emb) + # Self-attention rolling KV cache: merge the current chunk's K/V into the per-block + # cache and use the (possibly trimmed) result for attention. + if kv_cache is not None and encoder_hidden_states is None: + key, value = kv_cache.update(block_idx, key, value) + # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: @@ -392,7 +511,7 @@ def __init__( self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, frame_offset: int = 0) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w @@ -402,11 +521,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: freqs_cos = self.freqs_cos.split(split_sizes, dim=1) freqs_sin = self.freqs_sin.split(split_sizes, dim=1) - freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_f = freqs_cos[0][frame_offset : frame_offset + ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) - freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_f = freqs_sin[0][frame_offset : frame_offset + ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) @@ -465,6 +584,8 @@ def forward( encoder_hidden_states: torch.Tensor, temb: torch.Tensor, rotary_emb: torch.Tensor, + kv_cache: WanKVCache | None = None, + block_idx: int | None = None, ) -> torch.Tensor: if temb.ndim == 4: # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) @@ -486,7 +607,14 @@ def forward( # 1. 
Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) + attn_output = self.attn1( + norm_hidden_states, + None, + None, + rotary_emb, + kv_cache=kv_cache, + block_idx=block_idx, + ) hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention @@ -634,6 +762,7 @@ def forward( encoder_hidden_states_image: torch.Tensor | None = None, return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, + frame_offset: int = 0, ) -> torch.Tensor | dict[str, torch.Tensor]: batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size @@ -641,7 +770,8 @@ def forward( post_patch_height = height // p_h post_patch_width = width // p_w - rotary_emb = self.rope(hidden_states) + rotary_emb = self.rope(hidden_states, frame_offset=frame_offset) + kv_cache: WanKVCache | None = (attention_kwargs or {}).pop("kv_cache", None) hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) @@ -668,13 +798,26 @@ def forward( # 4. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: - for block in self.blocks: + for block_idx, block in enumerate(self.blocks): hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + kv_cache, + block_idx, ) else: - for block in self.blocks: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + for block_idx, block in enumerate(self.blocks): + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + kv_cache=kv_cache, + block_idx=block_idx, + ) # 5. Output norm, projection & unpatchify if temb.ndim == 3: diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py index 60bba9dfbe18..e917148e779d 100644 --- a/tests/models/transformers/test_models_transformer_wan.py +++ b/tests/models/transformers/test_models_transformer_wan.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ import pytest import torch -from diffusers import WanTransformer3DModel +from diffusers import WanKVCache, WanTransformer3DModel from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device @@ -235,3 +236,190 @@ def get_dummy_inputs(self): ), "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype), } + + +class TestWanKVCache: + NUM_BLOCKS = 2 + TOKENS_PER_CHUNK = 4 + _CONFIG = { + "patch_size": [1, 2, 2], + "num_attention_heads": 2, + "attention_head_dim": 16, + "in_channels": 16, + "out_channels": 16, + "text_dim": 32, + "freq_dim": 32, + "ffn_dim": 64, + "num_layers": NUM_BLOCKS, + "cross_attn_norm": False, + "qk_norm": "rms_norm_across_heads", + "eps": 1e-6, + "rope_max_seq_len": 32, + } + + def setup_method(self): + self.transformer = WanTransformer3DModel.from_config(self._CONFIG).eval() + + def _make_chunk(self, seed): + # centered around zero so RoPE has a measurable effect (all-positive inputs → uniform attn) + n_lat, n_enc = 16 * 4 * 4, 10 * 32 + lat = (torch.arange(n_lat, dtype=torch.float32) - n_lat // 2 + seed * 7).reshape(1, 16, 1, 4, 4) / 50 + enc = (torch.arange(n_enc, dtype=torch.float32) - n_enc // 2 + seed * 7).reshape(1, 10, 32) / 50 + return lat, torch.zeros(1, dtype=torch.long), enc + + def _denoise_chunk(self, latents, timestep, encoder_hidden_states, *, cache, frame_offset=0): + with torch.no_grad(): + return self.transformer( + latents, + timestep, + encoder_hidden_states, + frame_offset=frame_offset, + return_dict=False, + attention_kwargs={"kv_cache": cache}, + )[0] + + def _cached_len(self, cache, block=0): + k = cache.block_caches[block].cached_key + return 0 if k is None else k.shape[1] + + def _cached_keys(self, cache, block=0): + return cache.block_caches[block].cached_key + + def _assert_equal(self, a, b): + assert torch.equal(a, b) + + def _assert_not_equal(self, a, b): + assert not torch.equal(a, b) + + def test_append_unbounded(self): + """mode=append, window_size=-1: cache grows by TOKENS_PER_CHUNK each call; existing prefix is never disturbed.""" + T = self.TOKENS_PER_CHUNK + cache = WanKVCache(num_blocks=self.NUM_BLOCKS, window_size=-1) + cache.enable_append_mode() + + self._denoise_chunk(*self._make_chunk(1), cache=cache) + assert self._cached_len(cache) == T + snap1 = self._cached_keys(cache).clone() + + self._denoise_chunk(*self._make_chunk(2), cache=cache) + assert self._cached_len(cache) == T * 2 + snap2 = self._cached_keys(cache).clone() + self._assert_equal(snap1[:, :T], snap2[:, :T]) + + self._denoise_chunk(*self._make_chunk(3), cache=cache) + assert self._cached_len(cache) == T * 3 + snap3 = self._cached_keys(cache).clone() + self._assert_equal(snap2[:, : 2 * T], snap3[:, : 2 * T]) + + self._denoise_chunk(*self._make_chunk(4), cache=cache) + assert self._cached_len(cache) == T * 4 + snap4 = self._cached_keys(cache).clone() + self._assert_equal(snap3[:, : 3 * T], snap4[:, : 3 * T]) + + def test_append_windowed_single_chunk(self): + """mode=append, window_size=T: each new chunk fully evicts the previous; cache stays at T tokens.""" + T = self.TOKENS_PER_CHUNK + cache = WanKVCache(num_blocks=self.NUM_BLOCKS, window_size=T) + cache.enable_append_mode() + + self._denoise_chunk(*self._make_chunk(1), cache=cache) + assert self._cached_len(cache) == T + snap1 = self._cached_keys(cache).clone() + + self._denoise_chunk(*self._make_chunk(2), cache=cache) + assert self._cached_len(cache) == T # eviction kept size at T + snap2 = self._cached_keys(cache).clone() + 
self._assert_not_equal(snap1[:, :T], snap2[:, :T])  # chunk 1 fully evicted
+
+        self._denoise_chunk(*self._make_chunk(3), cache=cache)
+        assert self._cached_len(cache) == T  # eviction kept size at T
+        snap3 = self._cached_keys(cache).clone()
+        self._assert_not_equal(snap2[:, :T], snap3[:, :T])  # chunk 2 fully evicted
+
+    def test_append_windowed_three_chunks(self):
+        """mode=append, window_size=3*T: cache fills to 3 chunks then rolls — surviving chunks shift left by T per step."""
+        T = self.TOKENS_PER_CHUNK
+        cache = WanKVCache(num_blocks=self.NUM_BLOCKS, window_size=3 * T)
+        cache.enable_append_mode()
+
+        # Fill the window with chunks 1, 2, 3 (cache grows; no eviction yet)
+        self._denoise_chunk(*self._make_chunk(1), cache=cache)
+        assert self._cached_len(cache) == T
+        self._denoise_chunk(*self._make_chunk(2), cache=cache)
+        assert self._cached_len(cache) == 2 * T
+        self._denoise_chunk(*self._make_chunk(3), cache=cache)
+        assert self._cached_len(cache) == 3 * T
+        snap_full = self._cached_keys(cache).clone()  # [chunk1, chunk2, chunk3]
+
+        # Chunk 4: window full → chunk 1 evicted; chunks 2-3 shift left by T (size stays 3T)
+        self._denoise_chunk(*self._make_chunk(4), cache=cache)
+        assert self._cached_len(cache) == 3 * T
+        snap_after_4 = self._cached_keys(cache).clone()  # [chunk2, chunk3, chunk4]
+        # chunks 2-3 (snap_full[T:3T]) now sit at positions [0:2T] in the new cache
+        self._assert_equal(snap_after_4[:, : 2 * T], snap_full[:, T : 3 * T])
+        self._assert_not_equal(
+            snap_full[:, 2 * T : 3 * T], snap_after_4[:, 2 * T : 3 * T]
+        )  # last slot is chunk 4 (≠ chunk 3)
+
+        # Chunk 5: chunk 2 evicted; chunks 3-4 shift left
+        self._denoise_chunk(*self._make_chunk(5), cache=cache)
+        assert self._cached_len(cache) == 3 * T
+        snap_after_5 = self._cached_keys(cache).clone()  # [chunk3, chunk4, chunk5]
+        self._assert_equal(snap_after_5[:, : 2 * T], snap_after_4[:, T : 3 * T])
+        self._assert_not_equal(snap_after_4[:, 2 * T : 3 * T], snap_after_5[:, 2 * T : 3 * T])  # last slot is chunk 5
+
+    def test_overwrite_end_replaces_last_chunk(self):
+        """mode=overwrite, window_size=-1: simulates Self-Forcing multi-step denoising — append a chunk, then re-write it in place."""
+        T = self.TOKENS_PER_CHUNK
+        cache = WanKVCache(num_blocks=self.NUM_BLOCKS, window_size=-1)
+
+        # Chunk 0: append (cache empty)
+        cache.enable_append_mode()
+        self._denoise_chunk(*self._make_chunk(1), cache=cache, frame_offset=0)
+        assert self._cached_len(cache) == T
+        snap_chunk0_v1 = self._cached_keys(cache).clone()
+
+        # Re-run chunk 0 with different content in overwrite mode (subsequent denoising step)
+        cache.enable_overwrite_mode()
+        self._denoise_chunk(*self._make_chunk(99), cache=cache, frame_offset=0)
+        assert self._cached_len(cache) == T
+        snap_chunk0_v2 = self._cached_keys(cache).clone()
+        self._assert_not_equal(snap_chunk0_v1[:, :T], snap_chunk0_v2[:, :T])  # last (only) chunk replaced
+
+        # Chunk 1: append (extends cache)
+        cache.enable_append_mode()
+        self._denoise_chunk(*self._make_chunk(2), cache=cache, frame_offset=1)
+        assert self._cached_len(cache) == 2 * T
+        snap_after_append = self._cached_keys(cache).clone()
+        self._assert_equal(snap_chunk0_v2[:, :T], snap_after_append[:, :T])  # chunk 0 untouched
+
+        # Re-run chunk 1 with different content in overwrite mode
+        cache.enable_overwrite_mode()
+        self._denoise_chunk(*self._make_chunk(98), cache=cache, frame_offset=1)
+        assert self._cached_len(cache) == 2 * T
+        self._assert_equal(snap_after_append[:, :T], self._cached_keys(cache)[:, :T])  # chunk 0 untouched
+        self._assert_not_equal(
+            snap_after_append[:, T : 2 * T], self._cached_keys(cache)[:, T : 2 * T]
+        )  # chunk 1 replaced
+
+    def test_uses_prior_context(self):
+        """The cache is read during the forward pass: the same input yields different output with vs. without prior context."""
+        cache = WanKVCache(num_blocks=self.NUM_BLOCKS)
+        self._denoise_chunk(*self._make_chunk(1), cache=cache)
+        out_with = self._denoise_chunk(*self._make_chunk(2), cache=cache)
+        out_without = self._denoise_chunk(*self._make_chunk(2), cache=WanKVCache(num_blocks=self.NUM_BLOCKS))
+        self._assert_not_equal(out_with, out_without)
+
+    def test_reset(self):
+        """reset() clears the cached keys and values of every block and restores the write-control state."""
+        cache = WanKVCache(num_blocks=self.NUM_BLOCKS)
+        self._denoise_chunk(*self._make_chunk(1), cache=cache)
+        cache.reset()
+        assert all(bc.cached_key is None and bc.cached_value is None for bc in cache.block_caches)
+
+    def test_frame_offset_affects_rope(self):
+        """frame_offset shifts RoPE positions; the same chunk at different offsets produces different output."""
+        chunk = self._make_chunk(42)
+        out_0 = self._denoise_chunk(*chunk, cache=WanKVCache(num_blocks=self.NUM_BLOCKS), frame_offset=0)
+        out_1 = self._denoise_chunk(*chunk, cache=WanKVCache(num_blocks=self.NUM_BLOCKS), frame_offset=1)
+        self._assert_not_equal(out_0, out_1)
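
For completeness, the rolling and overwrite semantics asserted by the tests can be reproduced without a transformer by driving ``update`` directly. A minimal sketch, assuming this patch is applied; the shapes are illustrative, and in real use ``WanAttnProcessor`` issues these calls during self-attention:

```python
import torch

from diffusers import WanKVCache

B, H, D, T = 1, 2, 8, 4  # batch, heads, head_dim, tokens per chunk
cache = WanKVCache(num_blocks=1, window_size=2 * T)


def chunk(fill):
    # K/V in the layout the attention processor passes in: (batch, seq, heads, head_dim)
    return torch.full((B, T, H, D), float(fill))


# Append mode: the cache grows until the window is full, then evicts from the front.
cache.enable_append_mode()
k, _ = cache.update(0, chunk(1), chunk(1))  # cache = [chunk1], 4 tokens
k, _ = cache.update(0, chunk(2), chunk(2))  # cache = [chunk1, chunk2], 8 tokens
k, _ = cache.update(0, chunk(3), chunk(3))  # cache = [chunk2, chunk3]; chunk1 evicted
assert k.shape[1] == 2 * T and k[0, 0, 0, 0] == 2.0

# Overwrite mode: a re-denoised chunk replaces the newest T tokens in place.
cache.enable_overwrite_mode()
k, _ = cache.update(0, chunk(9), chunk(9))  # cache = [chunk2, chunk9]; size unchanged
assert k.shape[1] == 2 * T and k[0, -1, 0, 0] == 9.0
```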