40 changes: 40 additions & 0 deletions docs/source/en/api/models/wan_transformer_3d.md
@@ -25,6 +25,46 @@ transformer = WanTransformer3DModel.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diff

[[autodoc]] WanTransformer3DModel

## Rolling KV cache

For autoregressive video generation that produces one chunk at a time, [`WanTransformer3DModel.forward`] accepts a `WanKVCache` instance via `attention_kwargs={"kv_cache": cache}`. The cache holds the post-norm, post-RoPE self-attention K/V tensors from prior chunks, so each new chunk attends over the full prefix without recomputing it. The chunk's temporal RoPE positions, measured in post-patch frames, are selected with the `frame_offset` argument on `forward`.

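The `patch_frames_per_chunk` value used in the example below is not provided by the library; here is a minimal sketch of one way to derive it, assuming each chunk carries `latent_frames_per_chunk` latent frames:

```python
# Hypothetical helper, not part of the diffusers API: `frame_offset` indexes
# post-patch temporal positions, so divide the chunk's latent frame count by
# the temporal patch size p_t.
p_t = transformer.config.patch_size[0]
patch_frames_per_chunk = latent_frames_per_chunk // p_t
```
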
The cache exposes two write modes that the caller toggles between denoising steps:

- `enable_append_mode()` — the next forward pass appends the chunk's K/V to the cache; once the cache reaches `window_size`, the oldest tokens are evicted from the front. Use this for the first denoising step of every new chunk.
- `enable_overwrite_mode()` — the next forward pass replaces the newest `chunk_size` tokens in place. Use this for subsequent denoising steps within the same chunk so re-running the chunk doesn't grow the cache.

```python
from diffusers import WanKVCache, WanTransformer3DModel

transformer = WanTransformer3DModel.from_pretrained(...)
cache = WanKVCache(num_blocks=len(transformer.blocks))

for chunk_idx, latent_chunk in enumerate(chunks):
    for step_idx, t in enumerate(denoising_steps):
        if step_idx == 0:
            cache.enable_append_mode()
        else:
            cache.enable_overwrite_mode()
        transformer(
            hidden_states=latent_chunk,
            timestep=t,
            encoder_hidden_states=prompt_embeds,
            frame_offset=chunk_idx * patch_frames_per_chunk,
            attention_kwargs={"kv_cache": cache},
        )

cache.reset()  # between videos
```
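
For long videos you can additionally cap per-block memory with `window_size`. The value below is only a sketch, assuming `tokens_per_chunk` is the number of tokens one chunk contributes after patch embedding (post-patch frames × height × width), which you would compute yourself:

```python
# Keep at most the last three chunks' worth of K/V per block (illustrative);
# once the cache is full, the oldest tokens are evicted from the front on the
# next append.
cache = WanKVCache(num_blocks=len(transformer.blocks), window_size=3 * tokens_per_chunk)
```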

## WanKVCache

[[autodoc]] WanKVCache

## WanKVBlockCache

[[autodoc]] WanKVBlockCache

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -301,6 +301,8 @@
"UVit2DModel",
"VQModel",
"WanAnimateTransformer3DModel",
"WanKVBlockCache",
"WanKVCache",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"ZImageControlNetModel",
@@ -1117,6 +1119,8 @@
UVit2DModel,
VQModel,
WanAnimateTransformer3DModel,
WanKVBlockCache,
WanKVCache,
WanTransformer3DModel,
WanVACETransformer3DModel,
ZImageControlNetModel,
8 changes: 7 additions & 1 deletion src/diffusers/models/__init__.py
@@ -129,7 +129,11 @@
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan"] = [
"WanKVBlockCache",
"WanKVCache",
"WanTransformer3DModel",
]
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"]
@@ -261,6 +265,8 @@
Transformer2DModel,
TransformerTemporalModel,
WanAnimateTransformer3DModel,
WanKVBlockCache,
WanKVCache,
WanTransformer3DModel,
WanVACETransformer3DModel,
ZImageTransformer2DModel,
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/__init__.py
@@ -52,7 +52,7 @@
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan import WanKVBlockCache, WanKVCache, WanTransformer3DModel
from .transformer_wan_animate import WanAnimateTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
from .transformer_z_image import ZImageTransformer2DModel
167 changes: 155 additions & 12 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any

import torch
@@ -36,7 +39,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
def _get_qkv_projections(attn: WanAttention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
# encoder_hidden_states is only passed for cross-attention
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
@@ -56,7 +59,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
return query, key, value


def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
def _get_added_kv_projections(attn: WanAttention, encoder_hidden_states_img: torch.Tensor):
if attn.fused_projections:
key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
else:
@@ -65,6 +68,115 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
return key_img, value_img


@dataclass
class WanKVBlockCache:
"""Per-block rolling KV cache state for autoregressive WAN inference.

``cached_key`` and ``cached_value`` hold the post-norm, post-RoPE K/V from prior chunks
with shape ``(batch_size, cached_seq_len, num_heads, head_dim)``.
"""

cached_key: torch.Tensor | None = None
cached_value: torch.Tensor | None = None

def reset(self) -> None:
self.__init__()


class WanKVCache:
"""Rolling KV cache for autoregressive WAN video generation.

Holds a per-block ``WanKVBlockCache`` for every transformer block, plus shared
write-control state. Pass an instance via ``attention_kwargs`` on each transformer forward
call. ``WanAttnProcessor`` calls [`~WanKVCache.update`] to merge the current chunk's K/V into
the cache and get back the (possibly trimmed) attention K/V.

TODO: cross-attention K/V projections are currently recomputed on every forward pass even
though the text embeddings are constant across chunks. A future change can add cross-attn
caching alongside the existing self-attn cache.

Args:
num_blocks (`int`): Number of transformer blocks (``len(transformer.blocks)``).
window_size (`int`, defaults to ``-1``): Maximum cached tokens per block. ``-1`` keeps
the full prefix.

Example:

```python
>>> cache = WanKVCache(num_blocks=len(transformer.blocks))
>>> transformer(..., attention_kwargs={"kv_cache": cache})
```
"""

def __init__(self, num_blocks: int, window_size: int = -1):
self.block_caches: list[WanKVBlockCache] = [WanKVBlockCache() for _ in range(num_blocks)]
self.window_size: int = window_size
self.overwrite_newest: bool = False

def enable_append_mode(self) -> None:
"""Next forward pass appends the new chunk's K/V to the cache (cache grows, or oldest gets evicted)."""
self.overwrite_newest = False

def enable_overwrite_mode(self) -> None:
"""Next forward pass replaces the newest ``chunk_size`` tokens in place (cache size unchanged)."""
self.overwrite_newest = True

def reset(self) -> None:
"""Clear all cached K/V tensors and reset write-control state."""
for bc in self.block_caches:
bc.reset()
self.overwrite_newest = False

def update(
self,
block_idx: int,
new_key: torch.Tensor,
new_value: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Merge the current chunk's K/V into block ``block_idx``'s cache and return the
K/V that the self-attention should attend over.

Two paths:
- **Overwrite-newest** (``overwrite_newest=True`` and the cache already holds at
least ``new_key.shape[1]`` tokens): write the new K/V *in place* into the trailing
positions of the existing tensor. No allocation, no concat.
- **Append** (default): concatenate the existing prefix with the new K/V, then trim
the oldest tokens from the front if the result exceeds ``window_size``.
"""
block_cache = self.block_caches[block_idx]
prefix_k = block_cache.cached_key
prefix_v = block_cache.cached_value
n = new_key.shape[1]

if self.window_size > 0 and n > self.window_size:
raise RuntimeError(f"new chunk has {n} tokens, which exceeds window_size={self.window_size}.")

if self.overwrite_newest:
if prefix_k is None or prefix_k.shape[1] < n:
raise RuntimeError(
"overwrite_newest requires the cache to already hold at least one chunk's worth of tokens "
f"(>= {n}); cached length is {0 if prefix_k is None else prefix_k.shape[1]}. "
"Use enable_append_mode() for the first write of a new chunk."
)
# In-place update of the cached tensors; block_cache already references them.
prefix_k[:, -n:] = new_key
prefix_v[:, -n:] = new_value
return prefix_k, prefix_v

if prefix_k is None:
block_cache.cached_key = new_key
block_cache.cached_value = new_value
return new_key, new_value

keep_prefix = self.window_size - n if self.window_size > 0 else prefix_k.shape[1]
if keep_prefix > 0:
new_key = torch.cat([prefix_k[:, -keep_prefix:], new_key], dim=1)
new_value = torch.cat([prefix_v[:, -keep_prefix:], new_value], dim=1)
block_cache.cached_key = new_key
block_cache.cached_value = new_value
return new_key, new_value


class WanAttnProcessor:
_attention_backend = None
_parallel_config = None
@@ -77,11 +189,13 @@ def __init__(self):

def __call__(
self,
attn: "WanAttention",
attn: WanAttention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
kv_cache: WanKVCache | None = None,
block_idx: int | None = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
@@ -117,6 +231,11 @@ def apply_rotary_emb(
query = apply_rotary_emb(query, *rotary_emb)
key = apply_rotary_emb(key, *rotary_emb)

# Self-attention rolling KV cache: merge the current chunk's K/V into the per-block
# cache and use the (possibly trimmed) result for attention.
if kv_cache is not None and encoder_hidden_states is None:
key, value = kv_cache.update(block_idx, key, value)

# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
@@ -392,7 +511,7 @@ def __init__(
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, frame_offset: int = 0) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
@@ -402,11 +521,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_f = freqs_cos[0][frame_offset : frame_offset + ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_f = freqs_sin[0][frame_offset : frame_offset + ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

@@ -465,6 +584,8 @@ def forward(
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
kv_cache: WanKVCache | None = None,
block_idx: int | None = None,
) -> torch.Tensor:
if temb.ndim == 4:
# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
@@ -486,7 +607,14 @@

# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
attn_output = self.attn1(
norm_hidden_states,
None,
None,
rotary_emb,
kv_cache=kv_cache,
block_idx=block_idx,
)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

# 2. Cross-attention
@@ -634,14 +762,16 @@ def forward(
encoder_hidden_states_image: torch.Tensor | None = None,
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
frame_offset: int = 0,
) -> torch.Tensor | dict[str, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w

rotary_emb = self.rope(hidden_states)
rotary_emb = self.rope(hidden_states, frame_offset=frame_offset)
kv_cache: WanKVCache | None = (attention_kwargs or {}).pop("kv_cache", None)

hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
@@ -668,13 +798,26 @@

# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.blocks:
for block_idx, block in enumerate(self.blocks):
hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
block,
hidden_states,
encoder_hidden_states,
timestep_proj,
rotary_emb,
kv_cache,
block_idx,
)
else:
for block in self.blocks:
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
for block_idx, block in enumerate(self.blocks):
hidden_states = block(
hidden_states,
encoder_hidden_states,
timestep_proj,
rotary_emb,
kv_cache=kv_cache,
block_idx=block_idx,
)

# 5. Output norm, projection & unpatchify
if temb.ndim == 3: