diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index c36f5a70358..9ea1bcf52f1 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -286,6 +286,11 @@ def _from_config( self.freqs_sin = freqs[1].to(dtype) split_mha = config.attention_type in ("static", "static_shas") + # YOCO: skip cache creation for KV-shared layers (2nd half) + num_kv_shared = getattr(config, "num_kv_shared_layers", 0) + first_kv_shared = ( + config.n_layers - num_kv_shared if num_kv_shared > 0 else config.n_layers + ) if split_mha: self.k_caches = { StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros( @@ -296,7 +301,7 @@ def _from_config( ) for layer_id in range(config.n_layers) for head_id in range(none_throws(config.n_kv_heads)) - if cache_lens[layer_id] > 0 + if cache_lens[layer_id] > 0 and layer_id < first_kv_shared } self.v_caches = { StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros( @@ -307,7 +312,7 @@ def _from_config( ) for layer_id in range(config.n_layers) for head_id in range(none_throws(config.n_kv_heads)) - if cache_lens[layer_id] > 0 + if cache_lens[layer_id] > 0 and layer_id < first_kv_shared } else: self.k_caches = { @@ -319,7 +324,7 @@ def _from_config( dtype=dtype, ) for layer_id in range(config.n_layers) - if cache_lens[layer_id] > 0 + if cache_lens[layer_id] > 0 and layer_id < first_kv_shared } self.v_caches = { StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros( @@ -330,7 +335,7 @@ def _from_config( dtype=dtype, ) for layer_id in range(config.n_layers) - if cache_lens[layer_id] > 0 + if cache_lens[layer_id] > 0 and layer_id < first_kv_shared } self.generate_full_logits = config.generate_full_logits @@ -360,6 +365,8 @@ def _from_model( self.k_caches = {} self.v_caches = {} for attn in static_attentions: + if attn.is_kv_shared_layer: + continue if attn.split_mha: for head_id in range(attn.n_heads): cache_key = 
StaticKVCache.calculate_cache_key( @@ -761,6 +768,7 @@ def __init__( layer_id: int, rope: Rope, split_mha: bool = True, + is_kv_shared_layer: bool = False, **kwargs: Any, ): super().__init__() @@ -778,6 +786,8 @@ def __init__( self.use_qk_norm = config.use_qk_norm self.qk_norm_before_rope = config.qk_norm_before_rope self.split_mha = split_mha + self.is_kv_shared_layer = is_kv_shared_layer + self.num_kv_shared_layers = config.num_kv_shared_layers self.use_conv2d = False self.enable_qnn_masked_softmax = kwargs.get("enable_qnn_masked_softmax", False) @@ -792,25 +802,31 @@ def __init__( for _ in range(self.n_heads) ] ) - self.wks = nn.ModuleList( - [ - nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias) - for _ in range(self.n_kv_heads) - ] - ) - self.wvs = nn.ModuleList( - [ - nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias) - for _ in range(self.n_kv_heads) - ] - ) + if is_kv_shared_layer: + self.wks = nn.ModuleList() + self.wvs = nn.ModuleList() + self.k_caches = nn.ModuleList() + self.v_caches = nn.ModuleList() + else: + self.wks = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias) + for _ in range(self.n_kv_heads) + ] + ) + self.wvs = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias) + for _ in range(self.n_kv_heads) + ] + ) - self.k_caches = nn.ModuleList( - [StaticKCache(layer_id, i) for i in range(self.n_kv_heads)] - ) - self.v_caches = nn.ModuleList( - [StaticVCache(layer_id, i) for i in range(self.n_kv_heads)] - ) + self.k_caches = nn.ModuleList( + [StaticKCache(layer_id, i) for i in range(self.n_kv_heads)] + ) + self.v_caches = nn.ModuleList( + [StaticVCache(layer_id, i) for i in range(self.n_kv_heads)] + ) else: has_lora = config.target_modules is not None _PROJ_TARGET = { @@ -819,6 +835,9 @@ def __init__( "wvs": ("v_proj", self.dim, self.head_dim * self.n_kv_heads), } for attr, (target, in_dim, out_dim) in _PROJ_TARGET.items(): + if is_kv_shared_layer and 
attr in ("wks", "wvs"): + setattr(self, attr, nn.ModuleList()) + continue if has_lora and target in config.target_modules: proj = LoRALinear( in_dim=in_dim, @@ -831,9 +850,19 @@ def __init__( proj = nn.Linear(in_dim, out_dim, bias=self.attention_qkv_bias) setattr(self, attr, nn.ModuleList([proj])) - self.k_caches = nn.ModuleList([StaticKCache(layer_id, 0)]) - self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)]) + if is_kv_shared_layer: + self.k_caches = nn.ModuleList() + self.v_caches = nn.ModuleList() + else: + self.k_caches = nn.ModuleList([StaticKCache(layer_id, 0)]) + self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)]) + + self._init_wo(config) + self.rope = _Rope(rope.params) + self.layer_id = layer_id + self._init_qk_norms(config, is_kv_shared_layer) + def _init_wo(self, config: ModelArgs) -> None: wo_use_lora = ( not self.split_mha and config.target_modules is not None @@ -852,12 +881,14 @@ def __init__( ) else: self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) - self.rope = _Rope(rope.params) - self.layer_id = layer_id + def _init_qk_norms(self, config: ModelArgs, is_kv_shared_layer: bool) -> None: if self.use_qk_norm: self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps) - self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps) + if is_kv_shared_layer: + self.k_norm = nn.Identity() + else: + self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps) else: self.q_norm = torch.nn.Identity() self.k_norm = torch.nn.Identity() @@ -870,10 +901,14 @@ def from_attention_mha( rms_norm_class=torch.nn.RMSNorm, **kwargs: Any, ) -> "StaticAttention": - has_lora = any( - isinstance(proj, LoRALinear) - for proj in [other.wq, other.wk, other.wv, other.wo] - ) + is_kv_shared = getattr(other, "is_kv_shared_layer", False) + + lora_projs = [other.wq, other.wo] + if other.wk is not None: + lora_projs.append(other.wk) + if other.wv is not None: + lora_projs.append(other.wv) + has_lora = any(isinstance(proj, LoRALinear) for 
proj in lora_projs) if has_lora and split_mha: raise ValueError( @@ -893,6 +928,7 @@ def from_attention_mha( use_qk_norm=other.use_qk_norm, qk_norm_before_rope=other.qk_norm_before_rope, norm_eps=other.q_norm_fn.eps if other.use_qk_norm else 1e-5, + num_kv_shared_layers=getattr(other, "num_kv_shared_layers", 0), ) instance = cls( @@ -900,42 +936,36 @@ def from_attention_mha( layer_id=other.layer_id, rope=other.rope, split_mha=split_mha, + is_kv_shared_layer=is_kv_shared, **kwargs, ) # Replace nn.Linear with LoRALinear where the source uses LoRA. if has_lora: - for attr, proj, in_dim, out_dim, bias in [ - ( - "wqs", - other.wq, - other.dim, - other.n_heads * other.head_dim, - other.attention_qkv_bias, - ), - ( - "wks", - other.wk, - other.dim, - other.n_kv_heads * other.head_dim, - other.attention_qkv_bias, - ), - ( - "wvs", - other.wv, - other.dim, - other.n_kv_heads * other.head_dim, - other.attention_qkv_bias, - ), - ]: - if isinstance(proj, LoRALinear): - getattr(instance, attr)[0] = LoRALinear( - in_dim=in_dim, - out_dim=out_dim, - rank=proj.rank, - alpha=proj.alpha, - use_bias=bias, - ) + # Always handle wq LoRA + if isinstance(other.wq, LoRALinear): + instance.wqs[0] = LoRALinear( + in_dim=other.dim, + out_dim=other.n_heads * other.head_dim, + rank=other.wq.rank, + alpha=other.wq.alpha, + use_bias=other.attention_qkv_bias, + ) + # Only handle wk/wv LoRA for non-shared layers + if not is_kv_shared: + for attr, proj, out_dim in [ + ("wks", other.wk, other.n_kv_heads * other.head_dim), + ("wvs", other.wv, other.n_kv_heads * other.head_dim), + ]: + if isinstance(proj, LoRALinear): + getattr(instance, attr)[0] = LoRALinear( + in_dim=other.dim, + out_dim=out_dim, + rank=proj.rank, + alpha=proj.alpha, + use_bias=other.attention_qkv_bias, + ) + # Always handle wo LoRA if isinstance(other.wo, LoRALinear): instance.wo = LoRALinear( in_dim=other.n_heads * other.head_dim, @@ -966,8 +996,15 @@ def forward( x = x.reshape(bsz, -1, 1, dim).transpose(1, 3) new_qs = [wq(x) for 
wq in self.wqs] - new_ks = [wk(x) for wk in self.wks] - new_vs = [wv(x) for wv in self.wvs] + + shared_kv = kwargs.get("shared_kv") + if self.is_kv_shared_layer: + assert shared_kv is not None + new_ks = [] + new_vs = [] + else: + new_ks = [wk(x) for wk in self.wks] + new_vs = [wv(x) for wv in self.wvs] if self.use_conv2d: @@ -975,11 +1012,13 @@ def from_conv2ds(ts): return [t.reshape(bsz, self.head_dim, -1).transpose(1, 2) for t in ts] new_qs = from_conv2ds(new_qs) - new_ks = from_conv2ds(new_ks) - new_vs = from_conv2ds(new_vs) + if new_ks: + new_ks = from_conv2ds(new_ks) + if new_vs: + new_vs = from_conv2ds(new_vs) if self.split_mha: - y, out_cache_state = self._forward_sha( + y, out_cache_state, kv_to_share = self._forward_sha( new_qs, new_ks, new_vs, @@ -989,10 +1028,10 @@ def from_conv2ds(ts): **kwargs, ) else: - y, out_cache_state = self._forward_mha( + y, out_cache_state, kv_to_share = self._forward_mha( new_qs[0], - new_ks[0], - new_vs[0], + new_ks[0] if new_ks else None, + new_vs[0] if new_vs else None, freqs_cos, freqs_sin, bsz, @@ -1011,36 +1050,48 @@ def from_conv2ds(ts): else: y = self.wo(y) - return y, {"out_cache_state": out_cache_state} + update = {"out_cache_state": out_cache_state} + if kv_to_share is not None: + update["kv_to_share"] = kv_to_share + return y, update - def _forward_sha( - self, - new_qs, - new_ks, - new_vs, - freqs_cos, - freqs_sin, - seq_len, - **kwargs: ForwardOptions, - ): - if (freqs_cos_override := kwargs.get("freqs_cos_override")) is not None: - freqs_cos = freqs_cos_override # pyre-ignore - if (freqs_sin_override := kwargs.get("freqs_sin_override")) is not None: - freqs_sin = freqs_sin_override # pyre-ignore - in_cache_state = kwargs.get("in_cache_state") - out_cache_state = kwargs.get("out_cache_state") - - if self.use_qk_norm and self.qk_norm_before_rope: - new_qs = [self.q_norm(q) for q in new_qs] - new_ks = [self.k_norm(k) for k in new_ks] - - new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs] - new_ks = 
[self.rope(k, freqs_cos, freqs_sin) for k in new_ks] - - if self.use_qk_norm and not self.qk_norm_before_rope: - new_qs = [self.q_norm(q) for q in new_qs] - new_ks = [self.k_norm(k) for k in new_ks] + def _apply_qk_norm(self, qs, ks=None, before_rope=False): + """Apply QK normalization before or after RoPE. + Args: + qs (list): List of queries. + ks (list, optional): List of keys. Defaults to None. + before_rope (bool, optional): Whether to apply normalization before RoPE. Defaults to False. + """ + if self.use_qk_norm and before_rope == self.qk_norm_before_rope: + qs = [self.q_norm(q) for q in qs] + if ks is not None: + ks = [self.k_norm(k) for k in ks] + return qs, ks + + def _apply_rope(self, qs, ks, freqs_cos, freqs_sin): + """Apply RoPE to queries and keys. + + Args: + qs (list): List of queries. + ks (list, optional): List of keys. Defaults to None. + freqs_cos (list): List of cosine frequencies. + freqs_sin (list): List of sine frequencies. + """ + qs = [self.rope(q, freqs_cos, freqs_sin) for q in qs] + if ks is not None: + ks = [self.rope(k, freqs_cos, freqs_sin) for k in ks] + return qs, ks + + def _update_kv_cache(self, new_ks, new_vs, in_cache_state, out_cache_state): + """Update KV cache. + + Args: + new_ks (list): List of new keys. + new_vs (list): List of new values. + in_cache_state (object): Initial cache state. + out_cache_state (object): Output cache state. + """ all_ks = [] all_vs = [] for i in range(self.n_kv_heads if self.split_mha else 1): @@ -1052,27 +1103,175 @@ def _forward_sha( new_vs[i], in_cache_state, out_cache_state ) all_vs.append(vs) + return all_ks, all_vs, out_cache_state + + def _compute_attention(self, q, k, v, mask, enable_qnn_masked_softmax=False): + """Compute attention. + + Args: + q (torch.Tensor): Query. + k (torch.Tensor): Key. + v (torch.Tensor): Value. + mask (torch.Tensor): Mask. + enable_qnn_masked_softmax (bool, optional): Whether to enable QNN masked softmax. Defaults to False. 
+ """ + attn = q @ k.transpose(-2, -1) + attn = attn * self.inv_scale + if enable_qnn_masked_softmax: + attn_min = torch.amin(attn, dim=-1, keepdim=True) + minus_value = -20 + attn = torch.where(mask == 0, attn, attn_min + minus_value) # pyre-ignore + else: + attn = attn + mask + attn = F.softmax(attn, dim=-1) + return attn @ v + + def _process_shared_kv( + self, new_qs, shared_kv, freqs_cos, freqs_sin, seq_len, masks + ): + """Process shared KV. + + Args: + new_qs (list): List of new queries. + shared_kv (tuple): Shared KV. + freqs_cos (list): List of cosine frequencies. + freqs_sin (list): List of sine frequencies. + seq_len (int): Sequence length. + masks (list): List of masks. + """ + # Apply normalization before RoPE if configured + new_qs, _ = self._apply_qk_norm(new_qs, None, before_rope=True) + + # Apply RoPE to queries + new_qs, _ = self._apply_rope(new_qs, None, freqs_cos, freqs_sin) + + # Apply normalization after RoPE if configured + new_qs, _ = self._apply_qk_norm(new_qs, None, before_rope=False) + + k_shared, v_shared = shared_kv + cache_len = k_shared.size(-2) - seq_len + mask = masks[cache_len] + + heads = [] + for i in range(self.n_heads): + kv_idx = i // self.n_heads_per_kv_group + k_head = k_shared[:, kv_idx : kv_idx + 1, :, :].squeeze(1) + v_head = v_shared[:, kv_idx : kv_idx + 1, :, :].squeeze(1) + heads.append( + self._compute_attention( + new_qs[i], k_head, v_head, mask, self.enable_qnn_masked_softmax + ) + ) + + return torch.cat(heads, dim=-1) + + def _process_normal_kv( + self, + new_qs, + new_ks, + new_vs, + freqs_cos, + freqs_sin, + in_cache_state, + out_cache_state, + masks, + seq_len, + ): + """Process normal KV. + + Args: + new_qs (list): List of new queries. + new_ks (list): List of new keys. + new_vs (list): List of new values. + freqs_cos (list): List of cosine frequencies. + freqs_sin (list): List of sine frequencies. + in_cache_state (object): Initial cache state. + out_cache_state (object): Output cache state. 
+ masks (list): List of masks. + seq_len (int): Sequence length. + """ + # Apply normalization before RoPE if configured + new_qs, new_ks = self._apply_qk_norm(new_qs, new_ks, before_rope=True) + + # Apply RoPE + new_qs, new_ks = self._apply_rope(new_qs, new_ks, freqs_cos, freqs_sin) + + # Apply normalization after RoPE if configured + new_qs, new_ks = self._apply_qk_norm(new_qs, new_ks, before_rope=False) + + # Update KV cache + all_ks, all_vs, out_cache_state = self._update_kv_cache( + new_ks, new_vs, in_cache_state, out_cache_state + ) cache_len = all_ks[0].size(-2) - seq_len - mask = kwargs["masks"][cache_len] + mask = masks[cache_len] heads = [] for i in range(self.n_heads): kv_idx = i // self.n_heads_per_kv_group - attn = new_qs[i] @ all_ks[kv_idx].transpose(-2, -1) - attn = attn * self.inv_scale - if self.enable_qnn_masked_softmax: - attn_min = torch.amin(attn, dim=-1, keepdim=True) - minus_value = -20 - attn = torch.where( - mask == 0, attn, attn_min + minus_value - ) # prye-ignore - else: - attn = attn + mask - attn = F.softmax(attn, dim=-1) - heads.append(attn @ all_vs[kv_idx]) + heads.append( + self._compute_attention( + new_qs[i], + all_ks[kv_idx], + all_vs[kv_idx], + mask, + self.enable_qnn_masked_softmax, + ) + ) + + kv_to_share = None + if self.num_kv_shared_layers > 0: + kv_to_share = (torch.cat(all_ks, dim=1), torch.cat(all_vs, dim=1)) + + return torch.cat(heads, dim=-1), out_cache_state, kv_to_share + + def _forward_sha( + self, + new_qs, + new_ks, + new_vs, + freqs_cos, + freqs_sin, + seq_len, + **kwargs: ForwardOptions, + ): + """Forward pass for SHA. + + Args: + new_qs (list): List of new queries. + new_ks (list): List of new keys. + new_vs (list): List of new values. + freqs_cos (list): List of cosine frequencies. + freqs_sin (list): List of sine frequencies. + seq_len (int): Sequence length. + **kwargs: ForwardOptions. 
+ """ + if (freqs_cos_override := kwargs.get("freqs_cos_override")) is not None: + freqs_cos = freqs_cos_override # pyre-ignore + if (freqs_sin_override := kwargs.get("freqs_sin_override")) is not None: + freqs_sin = freqs_sin_override # pyre-ignore + in_cache_state = kwargs.get("in_cache_state") + out_cache_state = kwargs.get("out_cache_state") + shared_kv = kwargs.get("shared_kv") - return torch.cat(heads, dim=-1), out_cache_state + if shared_kv is not None: + result = self._process_shared_kv( + new_qs, shared_kv, freqs_cos, freqs_sin, seq_len, kwargs["masks"] + ) + return result, out_cache_state, None + else: + return self._process_normal_kv( + new_qs, + new_ks, + new_vs, + freqs_cos, + freqs_sin, + in_cache_state, + out_cache_state, + kwargs["masks"], + seq_len, + ) def _forward_mha( self, @@ -1087,24 +1286,47 @@ def _forward_mha( ): in_cache_state = kwargs.get("in_cache_state") out_cache_state = kwargs.get("out_cache_state") + shared_kv = kwargs.get("shared_kv") - q = q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) - v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) + if shared_kv is not None: + q = q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) - if self.use_qk_norm and self.qk_norm_before_rope: - q = self.q_norm(q) - k = self.k_norm(k) + if self.use_qk_norm and self.qk_norm_before_rope: + q = self.q_norm(q) - q = self.rope(q, freqs_cos, freqs_sin) - k = self.rope(k, freqs_cos, freqs_sin) + q = self.rope(q, freqs_cos, freqs_sin) - if self.use_qk_norm and not self.qk_norm_before_rope: - q = self.q_norm(q) - k = self.k_norm(k) + if self.use_qk_norm and not self.qk_norm_before_rope: + q = self.q_norm(q) - k, out_cache_state = self.k_caches[0].update(k, in_cache_state, out_cache_state) - v, out_cache_state = self.v_caches[0].update(v, in_cache_state, out_cache_state) + k, v = shared_kv + else: + q = q.view(bsz, seq_len, 
self.n_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) + + if self.use_qk_norm and self.qk_norm_before_rope: + q = self.q_norm(q) + k = self.k_norm(k) + + q = self.rope(q, freqs_cos, freqs_sin) + k = self.rope(k, freqs_cos, freqs_sin) + + if self.use_qk_norm and not self.qk_norm_before_rope: + q = self.q_norm(q) + k = self.k_norm(k) + + k, out_cache_state = self.k_caches[0].update( + k, in_cache_state, out_cache_state + ) + v, out_cache_state = self.v_caches[0].update( + v, in_cache_state, out_cache_state + ) + + # YOCO: Store KV for sharing if this is a non-shared layer and YOCO is enabled + kv_to_share = ( + (k, v) if shared_kv is None and self.num_kv_shared_layers > 0 else None + ) mask = None masks = kwargs.get("masks") @@ -1157,7 +1379,11 @@ def _forward_mha( # Ungroup y y = y_grouped.view(1, self.n_heads, Tq, D) - return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state + return ( + y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), + out_cache_state, + kv_to_share, + ) def load_weights_from_attention_mha( self, other: AttentionMHA, rms_norm_class=torch.nn.RMSNorm @@ -1169,19 +1395,21 @@ def load_weights_from_attention_mha( other.wq.weight[i * self.head_dim : (i + 1) * self.head_dim, :] ) - for i in range(self.n_kv_heads): - self.wks[i].weight.data.copy_( - # pyre-ignore[29] - other.wk.weight[i * self.head_dim : (i + 1) * self.head_dim, :] - ) - self.wvs[i].weight.data.copy_( - # pyre-ignore[29] - other.wv.weight[i * self.head_dim : (i + 1) * self.head_dim, :] - ) + if not self.is_kv_shared_layer: + for i in range(self.n_kv_heads): + self.wks[i].weight.data.copy_( + # pyre-ignore[29] + other.wk.weight[i * self.head_dim : (i + 1) * self.head_dim, :] + ) + self.wvs[i].weight.data.copy_( + # pyre-ignore[29] + other.wv.weight[i * self.head_dim : (i + 1) * self.head_dim, :] + ) else: 
self.wqs[0].load_state_dict(other.wq.state_dict()) - self.wks[0].load_state_dict(other.wk.state_dict()) - self.wvs[0].load_state_dict(other.wv.state_dict()) + if not self.is_kv_shared_layer: + self.wks[0].load_state_dict(other.wk.state_dict()) + self.wvs[0].load_state_dict(other.wv.state_dict()) self.wo.load_state_dict(other.wo.state_dict()) @@ -1192,10 +1420,15 @@ def load_weights_from_attention_mha( other.q_norm_fn.weight.dtype ) self.q_norm.load_state_dict(other.q_norm_fn.state_dict()) - self.k_norm = rms_norm_class(other.k_norm_fn.dim, other.k_norm_fn.eps).to( - other.k_norm_fn.weight.dtype - ) - self.k_norm.load_state_dict(other.k_norm_fn.state_dict()) + if ( + not self.is_kv_shared_layer + and hasattr(other, "k_norm_fn") + and other.k_norm_fn is not None + ): + self.k_norm = rms_norm_class( + other.k_norm_fn.dim, other.k_norm_fn.eps + ).to(other.k_norm_fn.weight.dtype) + self.k_norm.load_state_dict(other.k_norm_fn.state_dict()) def adopt_hf_rope(self): if self.rope.use_hf_rope: diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py index 9c5554ab299..b3af06593fa 100644 --- a/examples/models/llama/tests/test_static_attention.py +++ b/examples/models/llama/tests/test_static_attention.py @@ -458,3 +458,167 @@ def test_lora_partial_projections(self): ) y, _ = static_attn(x, freqs_cos, freqs_sin, masks={0: mask}) self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all()) + + # --- YOCO tests --- + + def _make_yoco_args(self, n_layers=4, num_kv_shared_layers=2): + return ModelArgs( + dim=64, + n_heads=4, + n_kv_heads=2, + head_dim=16, + max_batch_size=1, + max_context_len=32, + max_seq_len=8, + enable_dynamic_shape=True, + n_layers=n_layers, + num_kv_shared_layers=num_kv_shared_layers, + ) + + def test_yoco_shared_layer_no_wk_wv(self): + config = self._make_yoco_args(n_layers=4, num_kv_shared_layers=2) + rope = Rope(config) + attn_mha = AttentionMHA(config, layer_id=2, rope=rope) + static_attn = 
StaticAttention.from_attention_mha(attn_mha, split_mha=False) + + self.assertTrue(static_attn.is_kv_shared_layer) + self.assertEqual(len(static_attn.wks), 0) + self.assertEqual(len(static_attn.wvs), 0) + self.assertEqual(len(static_attn.k_caches), 0) + self.assertEqual(len(static_attn.v_caches), 0) + self.assertEqual(len(static_attn.wqs), 1) + self.assertIsNotNone(static_attn.wo) + + def test_yoco_donor_layer_has_wk_wv(self): + config = self._make_yoco_args(n_layers=4, num_kv_shared_layers=2) + rope = Rope(config) + attn_mha = AttentionMHA(config, layer_id=1, rope=rope) + static_attn = StaticAttention.from_attention_mha(attn_mha, split_mha=False) + + self.assertFalse(static_attn.is_kv_shared_layer) + self.assertEqual(len(static_attn.wks), 1) + self.assertEqual(len(static_attn.wvs), 1) + self.assertEqual(len(static_attn.k_caches), 1) + self.assertEqual(len(static_attn.v_caches), 1) + + def test_yoco_shared_layer_forward_with_shared_kv(self): + config = self._make_yoco_args(n_layers=4, num_kv_shared_layers=2) + rope = Rope(config) + attn_mha = AttentionMHA(config, layer_id=2, rope=rope) + static_attn = StaticAttention.from_attention_mha( + attn_mha, split_mha=False + ).eval() + + x = torch.rand(1, config.max_seq_len, config.dim) + freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len) + shared_kv = ( + torch.randn(1, config.n_kv_heads, config.max_seq_len, config.head_dim), + torch.randn(1, config.n_kv_heads, config.max_seq_len, config.head_dim), + ) + mask = torch.triu( + torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")), + diagonal=1, + ) + + y, update = static_attn( + x, freqs_cos, freqs_sin, masks={0: mask}, shared_kv=shared_kv + ) + + self.assertEqual(y.shape, (1, config.max_seq_len, config.dim)) + self.assertIsNone(update["out_cache_state"]) + + def test_yoco_lora_with_shared_layer(self): + config = self._make_yoco_args(n_layers=4, num_kv_shared_layers=2) + config.r = 4 + config.lora_alpha = 8 + config.target_modules = ["q_proj", 
"o_proj"] + + rope = Rope(config) + attn_mha = AttentionMHA(config, layer_id=2, rope=rope) + + self.assertIsInstance(attn_mha.wq, LoRALinear) + self.assertIsNone(attn_mha.wk) + self.assertIsNone(attn_mha.wv) + + static_attn = StaticAttention.from_attention_mha(attn_mha, split_mha=False) + + self.assertIsInstance(static_attn.wqs[0], LoRALinear) + self.assertIsInstance(static_attn.wo, LoRALinear) + self.assertEqual(len(static_attn.wks), 0) + self.assertEqual(len(static_attn.wvs), 0) + + def test_yoco_static_vs_mha_numerics(self): + torch.manual_seed(42) + config = self._make_yoco_args(n_layers=4, num_kv_shared_layers=2) + rope = Rope(config) + + # Donor layer (layer_id=1, not shared) + attn_mha = AttentionMHA(config, layer_id=1, rope=rope).eval() + static_attn = StaticAttention.from_attention_mha( + attn_mha, split_mha=False + ).eval() + + x = torch.rand(1, config.max_seq_len, config.dim) + freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len) + expected, _ = attn_mha(x, freqs_cos, freqs_sin) + + mask = torch.triu( + torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")), + diagonal=1, + ) + y, _ = static_attn(x, freqs_cos, freqs_sin, masks={0: mask}) + + self.assertTrue( + torch.isclose(y, expected, rtol=1e-3).all(), + "YOCO donor layer: StaticAttention vs AttentionMHA mismatch", + ) + + def test_yoco_io_manager_skips_shared_caches(self): + config = self._make_yoco_args(n_layers=4, num_kv_shared_layers=2) + rope = Rope(config) + + layers = [] + for layer_id in range(4): + attn_mha = AttentionMHA(config, layer_id=layer_id, rope=rope) + static_attn = StaticAttention.from_attention_mha(attn_mha, split_mha=False) + layers.append(static_attn) + + model = torch.nn.Sequential(*layers) + io_mgr = StaticAttentionIOManager( + model, + input_len=config.max_seq_len, + cache_lens=[config.max_context_len] * 4, + ) + + # Donor layers (0, 1) should have cache entries + for layer_id in range(2): + cache_key = StaticKVCache.calculate_cache_key(layer_id, 0) + 
self.assertIn(cache_key, io_mgr.k_caches) + self.assertIn(cache_key, io_mgr.v_caches) + + # Shared layers (2, 3) should NOT have cache entries + for layer_id in range(2, 4): + cache_key = StaticKVCache.calculate_cache_key(layer_id, 0) + self.assertNotIn(cache_key, io_mgr.k_caches) + self.assertNotIn(cache_key, io_mgr.v_caches) + + def test_yoco_from_config_skips_shared_caches(self): + config = self._make_yoco_args(n_layers=4, num_kv_shared_layers=2) + config.attention_type = "static_mha" + io_mgr = StaticAttentionIOManager( + config, + input_len=config.max_seq_len, + cache_lens=[config.max_context_len] * 4, + ) + + # Donor layers (0, 1) should have cache entries + for layer_id in range(2): + cache_key = StaticKVCache.calculate_cache_key(layer_id, 0) + self.assertIn(cache_key, io_mgr.k_caches) + self.assertIn(cache_key, io_mgr.v_caches) + + # Shared layers (2, 3) should NOT have cache entries + for layer_id in range(2, 4): + cache_key = StaticKVCache.calculate_cache_key(layer_id, 0) + self.assertNotIn(cache_key, io_mgr.k_caches) + self.assertNotIn(cache_key, io_mgr.v_caches)