From 6dee6dc972fe8ad50957ed8c89fff58e6a404927 Mon Sep 17 00:00:00 2001 From: Fabian Guera Date: Mon, 27 Apr 2026 13:51:51 -0700 Subject: [PATCH 1/4] Add rolling KV cache hook --- src/diffusers/__init__.py | 6 + src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/rolling_kv_cache.py | 345 ++++++++++++++++++++++ src/diffusers/models/cache_utils.py | 8 + tests/hooks/test_rolling_kv_cache.py | 377 ++++++++++++++++++++++++ 5 files changed, 737 insertions(+) create mode 100644 src/diffusers/hooks/rolling_kv_cache.py create mode 100644 tests/hooks/test_rolling_kv_cache.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2cbfd6e29305..d82b57b461bd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -167,6 +167,7 @@ "LayerSkipConfig", "MagCacheConfig", "PyramidAttentionBroadcastConfig", + "RollingKVCacheConfig", "SmoothedEnergyGuidanceConfig", "TaylorSeerCacheConfig", "TextKVCacheConfig", @@ -175,6 +176,8 @@ "apply_layer_skip", "apply_mag_cache", "apply_pyramid_attention_broadcast", + "apply_rolling_kv_cache", + "get_rolling_kv_cache_state", "apply_taylorseer_cache", "apply_text_kv_cache", ] @@ -979,6 +982,7 @@ LayerSkipConfig, MagCacheConfig, PyramidAttentionBroadcastConfig, + RollingKVCacheConfig, SmoothedEnergyGuidanceConfig, TaylorSeerCacheConfig, TextKVCacheConfig, @@ -987,6 +991,8 @@ apply_layer_skip, apply_mag_cache, apply_pyramid_attention_broadcast, + apply_rolling_kv_cache, + get_rolling_kv_cache_state, apply_taylorseer_cache, apply_text_kv_cache, ) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 2a9aa81608e7..0d4305a24874 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -25,6 +25,7 @@ from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .mag_cache import MagCacheConfig, apply_mag_cache from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .rolling_kv_cache import RollingKVCacheConfig, apply_rolling_kv_cache, get_rolling_kv_cache_state from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache diff --git a/src/diffusers/hooks/rolling_kv_cache.py b/src/diffusers/hooks/rolling_kv_cache.py new file mode 100644 index 000000000000..13d31b45aa19 --- /dev/null +++ b/src/diffusers/hooks/rolling_kv_cache.py @@ -0,0 +1,345 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
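+"""Rolling self-attention KV cache for chunked autoregressive inference.
+
+A minimal usage sketch follows. The transformer, its inputs, the chunk loop, and the
+chosen `window_size` below are illustrative placeholders, not a prescribed recipe:
+
+```py
+from diffusers.hooks import RollingKVCacheConfig, apply_rolling_kv_cache, get_rolling_kv_cache_state
+
+apply_rolling_kv_cache(transformer, RollingKVCacheConfig(window_size=1024))
+state = get_rolling_kv_cache_state(transformer)
+
+for chunk in chunks:
+    # Each forward pass appends the chunk's self-attention K/V to the per-block
+    # caches, then trims them to the most recent `window_size` tokens.
+    out = transformer(chunk)
+
+# To re-generate previously cached tokens: rewind to an absolute offset, overwrite,
+# then restore plain append mode.
+state.configure_cache_write(write_mode="overwrite", absolute_token_offset=0)
+out = transformer(new_chunk)
+state.clear_cache_write()
+```
+
+Models that inherit `CacheMixin` can equivalently call
+`transformer.enable_cache(RollingKVCacheConfig(...))`.
+"""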
+ +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from ..models.attention_dispatch import dispatch_attention_fn +from ..utils import logging +from .hooks import BaseState, HookRegistry, ModelHook, StateManager + + +logger = logging.get_logger(__name__) + + +_ROLLING_KV_CACHE_HOOK = "rolling_kv_cache" +_ROLLING_KV_WRITE_MODES = {"append", "overwrite"} +_TESTED_ATTENTION_CLASSES = frozenset({"WanAttention"}) + + +@dataclass +class RollingKVCacheConfig: + r"""Configuration for rolling self-attention KV caching during autoregressive inference. + + Args: + window_size (`int`, defaults to `-1`): + Maximum number of cached self-attention tokens to keep. Set to `-1` to keep the full prefix. + """ + + window_size: int = -1 + + +class RollingKVAttentionProcessor: + r"""Default attention preprocessor used by the rolling KV cache hook. + + The defaults target Wan-style self-attention modules. To support a model with a different + attention layout — most often a different rotary embedding form — subclass and override the + relevant method, then pass the instance via `apply_rolling_kv_cache(..., attention_processor=...)`. + """ + + def prepare_qkv( + self, + attn: torch.nn.Module, + hidden_states: torch.Tensor, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if getattr(attn, "fused_projections", False): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + query = self.apply_rotary_emb(query, *rotary_emb) + key = self.apply_rotary_emb(key, *rotary_emb) + + return query, key, value + + def apply_rotary_emb( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ) -> torch.Tensor: + hidden_states_complex = torch.view_as_complex( + hidden_states.to(torch.float64).reshape(*hidden_states.shape[:-1], -1, 2) + ) + freqs_complex = torch.complex( + freqs_cos[..., 0::2].to(torch.float64), + freqs_sin[..., 0::2].to(torch.float64), + ) + out = torch.view_as_real(hidden_states_complex * freqs_complex).flatten(-2) + return out.type_as(hidden_states) + + def post_attention( + self, + attn: torch.nn.Module, + attn_output: torch.Tensor, + query: torch.Tensor, + ) -> torch.Tensor: + out = attn_output.flatten(2, 3).type_as(query) + out = attn.to_out[0](out) + out = attn.to_out[1](out) + return out + + def get_attention_backend(self, attn: torch.nn.Module): + processor = getattr(attn, "processor", None) + return getattr(processor, "_attention_backend", None) + + +class RollingKVCacheState(BaseState): + r"""Shared state controlling how the rolling KV cache is updated.""" + + def __init__(self): + self.should_update_cache = True + self.write_mode = "append" + self.absolute_token_offset: int | None = None + + def configure_cache_write(self, write_mode: str = "append", absolute_token_offset: int | None = None) -> None: + if write_mode not in _ROLLING_KV_WRITE_MODES: + raise ValueError( + f"`write_mode` must be one of {sorted(_ROLLING_KV_WRITE_MODES)}, but received {write_mode!r}." 
+ ) + if write_mode == "append" and absolute_token_offset is not None: + raise ValueError("`absolute_token_offset` is only supported with `write_mode='overwrite'`.") + if write_mode == "overwrite" and absolute_token_offset is None: + raise ValueError("`absolute_token_offset` must be provided when `write_mode='overwrite'`.") + if absolute_token_offset is not None and absolute_token_offset < 0: + raise ValueError("`absolute_token_offset` must be >= 0.") + + self.write_mode = write_mode + self.absolute_token_offset = absolute_token_offset + + def clear_cache_write(self) -> None: + self.write_mode = "append" + self.absolute_token_offset = None + + def reset(self): + self.should_update_cache = True + self.clear_cache_write() + + +class RollingKVCacheBlockState(BaseState): + r"""Per-attention-block self-attention cache state.""" + + def __init__(self): + self.cached_key: torch.Tensor | None = None + self.cached_value: torch.Tensor | None = None + self.cache_start_token_offset = 0 + + def reset(self): + self.cached_key = None + self.cached_value = None + self.cache_start_token_offset = 0 + + +def _ensure_state(state_manager: StateManager): + if state_manager._current_context is None: + state_manager.set_context("inference") + return state_manager.get_state() + + +def _slice_cache_for_overwrite( + block_state: RollingKVCacheBlockState, + absolute_token_offset: int, +) -> tuple[torch.Tensor | None, torch.Tensor | None, int]: + cached_key = block_state.cached_key + cached_value = block_state.cached_value + cache_start = block_state.cache_start_token_offset + + if cached_key is None: + return None, None, absolute_token_offset + + cache_end = cache_start + cached_key.shape[1] + if absolute_token_offset > cache_end: + raise ValueError( + "`absolute_token_offset` points beyond the retained cache prefix. Reset the cache or prefill the " + "missing chunks before appending new ones." 
+ ) + if absolute_token_offset < cache_start: + return None, None, absolute_token_offset + + prefix_length = absolute_token_offset - cache_start + return cached_key[:, :prefix_length], cached_value[:, :prefix_length], cache_start + + +def _trim_cache_to_window( + key: torch.Tensor, + value: torch.Tensor, + cache_start_token_offset: int, + window_size: int, +) -> tuple[torch.Tensor, torch.Tensor, int]: + if window_size > 0 and key.shape[1] > window_size: + trim = key.shape[1] - window_size + key = key[:, trim:] + value = value[:, trim:] + cache_start_token_offset += trim + + return key.detach(), value.detach(), cache_start_token_offset + + +def _is_self_attention_module(module: torch.nn.Module) -> bool: + if getattr(module, "is_cross_attention", False): + return False + + required_attrs = ("to_q", "to_k", "to_v", "to_out", "heads", "norm_q", "norm_k") + return all(hasattr(module, attr) for attr in required_attrs) + + +class RollingKVCacheHook(ModelHook): + _is_stateful = True + + def __init__( + self, + config: RollingKVCacheConfig, + state_manager: StateManager, + block_state_manager: StateManager, + attention_processor: RollingKVAttentionProcessor, + ): + super().__init__() + self.config = config + self.state_manager = state_manager + self.block_state_manager = block_state_manager + self.attention_processor = attention_processor + + def new_forward( + self, + module: torch.nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + if encoder_hidden_states is not None: + raise ValueError("Rolling KV cache only supports self-attention modules.") + + shared_state: RollingKVCacheState = _ensure_state(self.state_manager) + block_state: RollingKVCacheBlockState = _ensure_state(self.block_state_manager) + proc = self.attention_processor + + query, key, value = proc.prepare_qkv(module, hidden_states, rotary_emb) + + if shared_state.write_mode == "overwrite": + cached_key, cached_value, prefix_start = _slice_cache_for_overwrite( + block_state, shared_state.absolute_token_offset + ) + else: + cached_key = block_state.cached_key + cached_value = block_state.cached_value + prefix_start = block_state.cache_start_token_offset + + if cached_key is not None: + if cached_key.shape[0] != key.shape[0]: + raise ValueError( + f"Rolling KV cache batch size mismatch (cached={cached_key.shape[0]}, current={key.shape[0]}). " + "Use `cache_context` to isolate cond/uncond passes or reset the cache before changing batch size." 
+ ) + full_key = torch.cat([cached_key, key], dim=1) + full_value = torch.cat([cached_value, value], dim=1) + else: + full_key = key + full_value = value + + if shared_state.should_update_cache: + ( + block_state.cached_key, + block_state.cached_value, + block_state.cache_start_token_offset, + ) = _trim_cache_to_window( + full_key, + full_value, + prefix_start, + self.config.window_size, + ) + + attn_output = dispatch_attention_fn( + query, + full_key, + full_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=proc.get_attention_backend(module), + ) + return proc.post_attention(module, attn_output, query) + + def reset_state(self, module: torch.nn.Module): + self.state_manager.reset() + self.block_state_manager.reset() + return module + + +def apply_rolling_kv_cache( + module: torch.nn.Module, + config: RollingKVCacheConfig | None = None, + attention_processor: RollingKVAttentionProcessor | None = None, +) -> None: + r"""Apply rolling KV cache hooks to compatible self-attention modules. + + The default `attention_processor` targets Wan-style attention modules. Pass a custom + `RollingKVAttentionProcessor` subclass for models with a different rotary embedding form + or projection layout. + """ + if config is None: + config = RollingKVCacheConfig() + if attention_processor is None: + attention_processor = RollingKVAttentionProcessor() + + state_manager = StateManager(RollingKVCacheState) + HookRegistry.check_if_exists_or_initialize(module) + + warned_classes: set[str] = set() + for submodule in module.modules(): + if not _is_self_attention_module(submodule): + continue + + cls_name = type(submodule).__name__ + if cls_name not in _TESTED_ATTENTION_CLASSES and cls_name not in warned_classes: + warned_classes.add(cls_name) + logger.warning( + "apply_rolling_kv_cache: attaching to '%s' which is untested. 
The default " + "RollingKVAttentionProcessor targets Wan-style attention; if outputs look wrong, " + "subclass it (in particular `apply_rotary_emb`) and pass via `attention_processor=`.", + cls_name, + ) + + block_state_manager = StateManager(RollingKVCacheBlockState) + hook = RollingKVCacheHook(config, state_manager, block_state_manager, attention_processor) + registry = HookRegistry.check_if_exists_or_initialize(submodule) + registry.register_hook(hook, _ROLLING_KV_CACHE_HOOK) + + +def get_rolling_kv_cache_state(module: torch.nn.Module) -> RollingKVCacheState | None: + r"""Return the shared rolling KV cache state for a hooked module.""" + for submodule in module.modules(): + if not hasattr(submodule, "_diffusers_hook"): + continue + + hook = submodule._diffusers_hook.get_hook(_ROLLING_KV_CACHE_HOOK) + if hook is not None: + return _ensure_state(hook.state_manager) + + return None diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 161fcf426f21..6f62a92c72d6 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -71,12 +71,14 @@ def enable_cache(self, config) -> None: FirstBlockCacheConfig, MagCacheConfig, PyramidAttentionBroadcastConfig, + RollingKVCacheConfig, TaylorSeerCacheConfig, TextKVCacheConfig, apply_faster_cache, apply_first_block_cache, apply_mag_cache, apply_pyramid_attention_broadcast, + apply_rolling_kv_cache, apply_taylorseer_cache, apply_text_kv_cache, ) @@ -96,6 +98,8 @@ def enable_cache(self, config) -> None: apply_text_kv_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) + elif isinstance(config, RollingKVCacheConfig): + apply_rolling_kv_cache(self, config) elif isinstance(config, TaylorSeerCacheConfig): apply_taylorseer_cache(self, config) else: @@ -110,6 +114,7 @@ def disable_cache(self) -> None: HookRegistry, MagCacheConfig, PyramidAttentionBroadcastConfig, + RollingKVCacheConfig, TaylorSeerCacheConfig, TextKVCacheConfig, ) @@ -117,6 +122,7 @@ def disable_cache(self) -> None: from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK + from ..hooks.rolling_kv_cache import _ROLLING_KV_CACHE_HOOK from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK from ..hooks.text_kv_cache import _TEXT_KV_CACHE_BLOCK_HOOK, _TEXT_KV_CACHE_TRANSFORMER_HOOK @@ -136,6 +142,8 @@ def disable_cache(self) -> None: registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, RollingKVCacheConfig): + registry.remove_hook(_ROLLING_KV_CACHE_HOOK, recurse=True) elif isinstance(self._cache_config, TextKVCacheConfig): registry.remove_hook(_TEXT_KV_CACHE_TRANSFORMER_HOOK, recurse=True) registry.remove_hook(_TEXT_KV_CACHE_BLOCK_HOOK, recurse=True) diff --git a/tests/hooks/test_rolling_kv_cache.py b/tests/hooks/test_rolling_kv_cache.py new file mode 100644 index 000000000000..41aaf46d1e5c --- /dev/null +++ b/tests/hooks/test_rolling_kv_cache.py @@ -0,0 +1,377 @@ +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import unittest + +import torch + +from diffusers.hooks import RollingKVCacheConfig, apply_rolling_kv_cache, get_rolling_kv_cache_state +from diffusers.hooks.rolling_kv_cache import ( + _ROLLING_KV_CACHE_HOOK, + RollingKVAttentionProcessor, + RollingKVCacheBlockState, + RollingKVCacheState, +) +from diffusers.models.cache_utils import CacheMixin + + +_DEVICE = torch.device("cpu") +_HEAD_DIM = 4 + + +# ---------- Fake self-attention so cache contents read out as input tokens ---------- + + +class _IdentitySelfAttention(torch.nn.Module): + """Self-attention stub with identity Q/K/V/norm/out projections. + + Because every projection is identity, whatever scalar token value goes in comes back out + inside `cached_key` / `cached_value`. That makes the cache directly inspectable in tests. + """ + + def __init__(self): + super().__init__() + self.heads = 1 + self.is_cross_attention = False + self.to_q = torch.nn.Identity() + self.to_k = torch.nn.Identity() + self.to_v = torch.nn.Identity() + self.norm_q = torch.nn.Identity() + self.norm_k = torch.nn.Identity() + self.to_out = torch.nn.ModuleList([torch.nn.Identity(), torch.nn.Identity()]) + + +class _FakeTransformer(torch.nn.Module, CacheMixin): + """Tiny CacheMixin wrapper so `cache_context(...)` works in the tests.""" + + def __init__(self): + super().__init__() + self.attn = _IdentitySelfAttention() + + def forward(self, hidden_states): + return self.attn(hidden_states) + + +def _make_transformer(window_size: int = -1) -> _FakeTransformer: + transformer = _FakeTransformer().to(_DEVICE).eval() + # Silence the "untested attention class" warning the hook emits for non-Wan modules — + # exercising that warning is covered by `test_warns_when_attaching_to_untested_class`. 
+ logging.getLogger("diffusers.hooks.rolling_kv_cache").setLevel(logging.ERROR) + apply_rolling_kv_cache(transformer, RollingKVCacheConfig(window_size=window_size)) + logging.getLogger("diffusers.hooks.rolling_kv_cache").setLevel(logging.WARNING) + return transformer + + +def _ramp(values) -> torch.Tensor: + """Build (1, len(values), HEAD_DIM) where token i has every entry equal to values[i].""" + return ( + torch.tensor(values, dtype=torch.float32, device=_DEVICE) + .reshape(1, -1, 1) + .expand(1, -1, _HEAD_DIM) + .contiguous() + ) + + +def _cached_token_values(cached: torch.Tensor) -> list[float]: + """Read scalar token values from a (B, S, H, D) cache tensor along S.""" + return cached[0, :, 0, 0].tolist() + + +def _block_state(transformer: _FakeTransformer) -> RollingKVCacheBlockState: + hook = transformer.attn._diffusers_hook.get_hook(_ROLLING_KV_CACHE_HOOK) + if hook.block_state_manager._current_context is None: + hook.block_state_manager.set_context("inference") + return hook.block_state_manager.get_state() + + +# ---------- State-class unit tests ---------- + + +class TestStateClasses(unittest.TestCase): + def test_rolling_kv_cache_state_defaults(self): + state = RollingKVCacheState() + self.assertTrue(state.should_update_cache) + self.assertEqual(state.write_mode, "append") + self.assertIsNone(state.absolute_token_offset) + + def test_rolling_kv_cache_state_reset(self): + state = RollingKVCacheState() + state.should_update_cache = False + state.configure_cache_write(write_mode="overwrite", absolute_token_offset=8) + state.reset() + self.assertTrue(state.should_update_cache) + self.assertEqual(state.write_mode, "append") + self.assertIsNone(state.absolute_token_offset) + + def test_configure_cache_write_rejects_offset_in_append_mode(self): + with self.assertRaises(ValueError): + RollingKVCacheState().configure_cache_write(write_mode="append", absolute_token_offset=4) + + def test_configure_cache_write_requires_offset_in_overwrite_mode(self): + with self.assertRaises(ValueError): + RollingKVCacheState().configure_cache_write(write_mode="overwrite") + + def test_block_state_reset(self): + state = RollingKVCacheBlockState() + state.cached_key = torch.randn(1, 4, 2, _HEAD_DIM) + state.cached_value = torch.randn(1, 4, 2, _HEAD_DIM) + state.cache_start_token_offset = 16 + state.reset() + self.assertIsNone(state.cached_key) + self.assertIsNone(state.cached_value) + self.assertEqual(state.cache_start_token_offset, 0) + + +# ---------- Rotary helper unit tests ---------- + + +class TestRotaryEmb(unittest.TestCase): + def setUp(self): + self.processor = RollingKVAttentionProcessor() + + def test_output_shape_and_dtype_preserved(self): + x = torch.randn(1, 4, 2, 16, dtype=torch.bfloat16) + freqs_cos = torch.ones(1, 4, 1, 16) + freqs_sin = torch.zeros(1, 4, 1, 16) + + out = self.processor.apply_rotary_emb(x, freqs_cos, freqs_sin) + + self.assertEqual(out.shape, x.shape) + self.assertEqual(out.dtype, x.dtype) + torch.testing.assert_close(out, x) + + def test_matches_complex_reference(self): + x = torch.randn(1, 4, 2, 16, dtype=torch.bfloat16) + freqs_cos = torch.randn(1, 4, 1, 16, dtype=torch.float32) + freqs_sin = torch.randn(1, 4, 1, 16, dtype=torch.float32) + + expected = torch.view_as_real( + torch.view_as_complex(x.to(torch.float64).reshape(*x.shape[:-1], -1, 2)) + * torch.complex(freqs_cos[..., 0::2].to(torch.float64), freqs_sin[..., 0::2].to(torch.float64)) + ).flatten(-2) + expected = expected.to(x.dtype) + + out = self.processor.apply_rotary_emb(x, freqs_cos, freqs_sin) + + 
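+        # Both paths share the same float64 intermediates, so the bfloat16 results should agree exactly.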
torch.testing.assert_close(out, expected) + + +# ---------- Cache mechanics, expressed as token values ---------- + + +class TestRollingKVCacheMechanics(unittest.TestCase): + def test_first_pass_caches_input_tokens_verbatim(self): + transformer = _make_transformer() + + with torch.no_grad(): + transformer(_ramp([0, 1, 2])) + + state = _block_state(transformer) + self.assertEqual(_cached_token_values(state.cached_key), [0.0, 1.0, 2.0]) + self.assertEqual(_cached_token_values(state.cached_value), [0.0, 1.0, 2.0]) + self.assertEqual(state.cache_start_token_offset, 0) + + def test_second_pass_appends_to_existing_cache(self): + transformer = _make_transformer() + + with torch.no_grad(): + transformer(_ramp([0, 1, 2])) + transformer(_ramp([3, 4, 5])) + + state = _block_state(transformer) + self.assertEqual(_cached_token_values(state.cached_key), [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) + self.assertEqual(state.cache_start_token_offset, 0) + + def test_window_trims_oldest_tokens_and_advances_start_offset(self): + transformer = _make_transformer(window_size=4) + + with torch.no_grad(): + transformer(_ramp([0, 1, 2])) + transformer(_ramp([3, 4, 5])) + + # Cache reached length 6, window keeps the most recent 4. Start offset moves to 2 so that + # `cache_start_token_offset + cached_len` is still the absolute end of the stream (=6). + state = _block_state(transformer) + self.assertEqual(_cached_token_values(state.cached_key), [2.0, 3.0, 4.0, 5.0]) + self.assertEqual(state.cache_start_token_offset, 2) + + def test_should_update_cache_false_freezes_cache(self): + transformer = _make_transformer() + cache_state = get_rolling_kv_cache_state(transformer) + + with torch.no_grad(): + transformer(_ramp([0, 1, 2])) + + cache_state.should_update_cache = False + with torch.no_grad(): + transformer(_ramp([3, 4, 5])) + + # Forward still ran — but the new chunk did NOT enter the cache. + state = _block_state(transformer) + self.assertEqual(_cached_token_values(state.cached_key), [0.0, 1.0, 2.0]) + + def test_overwrite_replaces_suffix_at_absolute_offset(self): + transformer = _make_transformer() + cache_state = get_rolling_kv_cache_state(transformer) + + # Three chunks → cache holds tokens 0..8 at absolute positions 0..8. + with torch.no_grad(): + transformer(_ramp([0, 1, 2])) + transformer(_ramp([3, 4, 5])) + transformer(_ramp([6, 7, 8])) + + # Rewind to absolute offset 3 and replace the [3,4,5] + [6,7,8] suffix with new values. + cache_state.configure_cache_write(write_mode="overwrite", absolute_token_offset=3) + try: + with torch.no_grad(): + transformer(_ramp([90, 91, 92])) + finally: + cache_state.clear_cache_write() + + state = _block_state(transformer) + # Tokens at absolute 6..8 are gone; new tokens land at absolute 3..5. 
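+        # The start offset stays 0: the retained prefix [0, 1, 2] still begins at absolute position 0.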
+        self.assertEqual(_cached_token_values(state.cached_key), [0.0, 1.0, 2.0, 90.0, 91.0, 92.0])
+        self.assertEqual(state.cache_start_token_offset, 0)
+
+    def test_overwrite_at_cache_start_drops_entire_prefix(self):
+        transformer = _make_transformer()
+        cache_state = get_rolling_kv_cache_state(transformer)
+
+        with torch.no_grad():
+            transformer(_ramp([0, 1, 2]))
+            transformer(_ramp([3, 4, 5]))
+
+        cache_state.configure_cache_write(write_mode="overwrite", absolute_token_offset=0)
+        try:
+            with torch.no_grad():
+                transformer(_ramp([7, 8]))
+        finally:
+            cache_state.clear_cache_write()
+
+        state = _block_state(transformer)
+        self.assertEqual(_cached_token_values(state.cached_key), [7.0, 8.0])
+        self.assertEqual(state.cache_start_token_offset, 0)
+
+    def test_overwrite_past_cache_end_raises(self):
+        transformer = _make_transformer()
+        cache_state = get_rolling_kv_cache_state(transformer)
+
+        with torch.no_grad():
+            transformer(_ramp([0, 1, 2]))
+
+        # Cache covers absolute positions 0..2 (cache_end=3). Asking to overwrite at 5 leaves a hole.
+        cache_state.configure_cache_write(write_mode="overwrite", absolute_token_offset=5)
+        with self.assertRaisesRegex(ValueError, "beyond the retained cache prefix"):
+            with torch.no_grad():
+                transformer(_ramp([9, 10]))
+
+    def test_cache_context_isolates_cond_and_uncond(self):
+        transformer = _make_transformer()
+
+        with torch.no_grad(), transformer.cache_context("cond"):
+            transformer(_ramp([0, 1, 2]))
+        self.assertEqual(_cached_token_values(_block_state(transformer).cached_key), [0.0, 1.0, 2.0])
+
+        with torch.no_grad(), transformer.cache_context("uncond"):
+            transformer(_ramp([10, 20, 30]))
+        self.assertEqual(
+            _cached_token_values(_block_state(transformer).cached_key), [10.0, 20.0, 30.0]
+        )
+
+        # Re-entering "cond" sees the cond cache untouched, then appends.
+        with torch.no_grad(), transformer.cache_context("cond"):
+            transformer(_ramp([3, 4, 5]))
+        self.assertEqual(
+            _cached_token_values(_block_state(transformer).cached_key),
+            [0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
+        )
+
+    def test_reset_stateful_hooks_clears_cache(self):
+        transformer = _make_transformer()
+
+        with torch.no_grad():
+            transformer(_ramp([0, 1, 2]))
+        self.assertIsNotNone(_block_state(transformer).cached_key)
+
+        transformer._diffusers_hook.reset_stateful_hooks()
+
+        state = _block_state(transformer)
+        self.assertIsNone(state.cached_key)
+        self.assertIsNone(state.cached_value)
+        self.assertEqual(state.cache_start_token_offset, 0)
+
+    def test_batch_size_mismatch_raises(self):
+        transformer = _make_transformer()
+
+        with torch.no_grad():
+            transformer(_ramp([0, 1, 2]))  # cached batch = 1
+
+        # Build a batch=2 chunk by stacking the same ramp; the cache was filled with batch=1.
+ chunk = _ramp([3, 4, 5]).expand(2, -1, -1).contiguous() + with self.assertRaisesRegex(ValueError, "batch size mismatch"): + with torch.no_grad(): + transformer(chunk) + + +# ---------- Integration: real Wan attention selection + warning ---------- + + +class TestApplyRollingKVCacheOnWan(unittest.TestCase): + """One sanity check that the duck-typed self-attn detection actually picks WanAttention.""" + + def test_hooks_attach_to_self_attention_only(self): + from diffusers import WanTransformer3DModel + from diffusers.models.transformers.transformer_wan import WanTransformerBlock + + config = { + "patch_size": [1, 2, 2], + "num_attention_heads": 2, + "attention_head_dim": 16, + "in_channels": 16, + "out_channels": 16, + "text_dim": 32, + "freq_dim": 32, + "ffn_dim": 64, + "num_layers": 2, + "cross_attn_norm": False, + "qk_norm": "rms_norm_across_heads", + "eps": 1e-6, + "image_dim": None, + "added_kv_proj_dim": None, + "rope_max_seq_len": 32, + } + torch.manual_seed(0) + transformer = WanTransformer3DModel.from_config(config).to(_DEVICE).eval() + apply_rolling_kv_cache(transformer, RollingKVCacheConfig(window_size=-1)) + + blocks = [m for m in transformer.modules() if isinstance(m, WanTransformerBlock)] + self.assertEqual(len(blocks), config["num_layers"]) + for block in blocks: + self.assertIsNotNone(block.attn1._diffusers_hook.get_hook(_ROLLING_KV_CACHE_HOOK)) + if hasattr(block.attn2, "_diffusers_hook"): + self.assertIsNone(block.attn2._diffusers_hook.get_hook(_ROLLING_KV_CACHE_HOOK)) + + def test_warns_when_attaching_to_untested_class(self): + # _IdentitySelfAttention is not in the tested set, so apply_rolling_kv_cache must warn. + with self.assertLogs("diffusers.hooks.rolling_kv_cache", level="WARNING") as captured: + apply_rolling_kv_cache(_IdentitySelfAttention(), RollingKVCacheConfig(window_size=-1)) + self.assertTrue( + any("_IdentitySelfAttention" in msg and "untested" in msg for msg in captured.output), + captured.output, + ) + + +if __name__ == "__main__": + unittest.main() From b128cd9ee6b2f7cb07ea204ab342eb1e1fc85a2a Mon Sep 17 00:00:00 2001 From: Fabian Guera Date: Mon, 27 Apr 2026 22:39:15 +0100 Subject: [PATCH 2/4] Note pinned sink frame follow-up --- src/diffusers/hooks/rolling_kv_cache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/hooks/rolling_kv_cache.py b/src/diffusers/hooks/rolling_kv_cache.py index 13d31b45aa19..e98fa7295356 100644 --- a/src/diffusers/hooks/rolling_kv_cache.py +++ b/src/diffusers/hooks/rolling_kv_cache.py @@ -192,6 +192,7 @@ def _trim_cache_to_window( window_size: int, ) -> tuple[torch.Tensor, torch.Tensor, int]: if window_size > 0 and key.shape[1] > window_size: + # TODO: support pinned sink frames when rolling the cache window. trim = key.shape[1] - window_size key = key[:, trim:] value = value[:, trim:] From 5302dd185c75e7ab91d507823bd4805825131b1f Mon Sep 17 00:00:00 2001 From: Fabian Guera Date: Mon, 27 Apr 2026 22:52:35 +0100 Subject: [PATCH 3/4] Align rolling KV RoPE with Wan --- src/diffusers/hooks/rolling_kv_cache.py | 10 ++++++---- tests/hooks/test_rolling_kv_cache.py | 18 +++++++++++++++++- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/diffusers/hooks/rolling_kv_cache.py b/src/diffusers/hooks/rolling_kv_cache.py index e98fa7295356..afc63db707f7 100644 --- a/src/diffusers/hooks/rolling_kv_cache.py +++ b/src/diffusers/hooks/rolling_kv_cache.py @@ -46,9 +46,11 @@ class RollingKVCacheConfig: class RollingKVAttentionProcessor: r"""Default attention preprocessor used by the rolling KV cache hook. 
- The defaults target Wan-style self-attention modules. To support a model with a different - attention layout — most often a different rotary embedding form — subclass and override the - relevant method, then pass the instance via `apply_rolling_kv_cache(..., attention_processor=...)`. + The defaults target Wan-style self-attention modules. The default rotary embedding path mirrors + WanAttnProcessor while staying local to avoid a hook dependency on Wan private helpers; override + it for other layouts. To support a model with a different attention layout — most often a + different rotary embedding form — subclass and override the relevant method, then pass the + instance via `apply_rolling_kv_cache(..., attention_processor=...)`. """ def prepare_qkv( @@ -88,7 +90,7 @@ def apply_rotary_emb( ) freqs_complex = torch.complex( freqs_cos[..., 0::2].to(torch.float64), - freqs_sin[..., 0::2].to(torch.float64), + freqs_sin[..., 1::2].to(torch.float64), ) out = torch.view_as_real(hidden_states_complex * freqs_complex).flatten(-2) return out.type_as(hidden_states) diff --git a/tests/hooks/test_rolling_kv_cache.py b/tests/hooks/test_rolling_kv_cache.py index 41aaf46d1e5c..14001ac2b0a8 100644 --- a/tests/hooks/test_rolling_kv_cache.py +++ b/tests/hooks/test_rolling_kv_cache.py @@ -159,7 +159,7 @@ def test_matches_complex_reference(self): expected = torch.view_as_real( torch.view_as_complex(x.to(torch.float64).reshape(*x.shape[:-1], -1, 2)) - * torch.complex(freqs_cos[..., 0::2].to(torch.float64), freqs_sin[..., 0::2].to(torch.float64)) + * torch.complex(freqs_cos[..., 0::2].to(torch.float64), freqs_sin[..., 1::2].to(torch.float64)) ).flatten(-2) expected = expected.to(x.dtype) @@ -324,6 +324,22 @@ def test_batch_size_mismatch_raises(self): with torch.no_grad(): transformer(chunk) + def test_cache_mixin_enable_disable_cache(self): + transformer = _FakeTransformer().to(_DEVICE).eval() + + logging.getLogger("diffusers.hooks.rolling_kv_cache").setLevel(logging.ERROR) + transformer.enable_cache(RollingKVCacheConfig(window_size=4)) + logging.getLogger("diffusers.hooks.rolling_kv_cache").setLevel(logging.WARNING) + + self.assertTrue(transformer.is_cache_enabled) + self.assertIsNotNone(transformer.attn._diffusers_hook.get_hook(_ROLLING_KV_CACHE_HOOK)) + self.assertIsNotNone(get_rolling_kv_cache_state(transformer)) + + transformer.disable_cache() + + self.assertFalse(transformer.is_cache_enabled) + self.assertIsNone(transformer.attn._diffusers_hook.get_hook(_ROLLING_KV_CACHE_HOOK)) + # ---------- Integration: real Wan attention selection + warning ---------- From 5e76115ba625d060a69bd47e0561f7232b631e7d Mon Sep 17 00:00:00 2001 From: Fabian Guera Date: Mon, 27 Apr 2026 22:54:18 +0100 Subject: [PATCH 4/4] Clarify rolling KV rotary helper --- src/diffusers/hooks/rolling_kv_cache.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/hooks/rolling_kv_cache.py b/src/diffusers/hooks/rolling_kv_cache.py index afc63db707f7..7042b11c5247 100644 --- a/src/diffusers/hooks/rolling_kv_cache.py +++ b/src/diffusers/hooks/rolling_kv_cache.py @@ -46,11 +46,9 @@ class RollingKVCacheConfig: class RollingKVAttentionProcessor: r"""Default attention preprocessor used by the rolling KV cache hook. - The defaults target Wan-style self-attention modules. The default rotary embedding path mirrors - WanAttnProcessor while staying local to avoid a hook dependency on Wan private helpers; override - it for other layouts. 
To support a model with a different attention layout — most often a - different rotary embedding form — subclass and override the relevant method, then pass the - instance via `apply_rolling_kv_cache(..., attention_processor=...)`. + The defaults target Wan-style self-attention modules. To support a model with a different + attention layout — most often a different rotary embedding form — subclass and override the + relevant method, then pass the instance via `apply_rolling_kv_cache(..., attention_processor=...)`. """ def prepare_qkv( @@ -85,6 +83,7 @@ def apply_rotary_emb( freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, ) -> torch.Tensor: + r"""Apply Wan-style rotary embeddings without depending on Wan private helpers.""" hidden_states_complex = torch.view_as_complex( hidden_states.to(torch.float64).reshape(*hidden_states.shape[:-1], -1, 2) )