diff --git a/src/maxtext/inference/kvcache.py b/src/maxtext/inference/kvcache.py index ba2266060f..5b7a3e4871 100644 --- a/src/maxtext/inference/kvcache.py +++ b/src/maxtext/inference/kvcache.py @@ -175,6 +175,9 @@ def kv_cache_as_linen( use_chunked_prefill: bool = False, model_mode: str = MODEL_MODE_PREFILL, is_gdn: bool = False, + is_deepseek_v4: bool = False, + compress_rate: int = 1, + is_indexer: bool = False, conv_kernel_size: int = 0, conv_dim: int = 0, name: str | None = None, @@ -228,6 +231,9 @@ def kv_cache_as_linen( use_chunked_prefill=use_chunked_prefill, model_mode=model_mode, is_gdn=is_gdn, + is_deepseek_v4=is_deepseek_v4, + compress_rate=compress_rate, + is_indexer=is_indexer, conv_kernel_size=conv_kernel_size, conv_dim=conv_dim, metadata_fn=variable_to_logically_partitioned, @@ -274,6 +280,9 @@ def __init__( is_gdn: bool = False, conv_kernel_size: int = 0, conv_dim: int = 0, + is_deepseek_v4: bool = False, + compress_rate: int = 1, + is_indexer: bool = False, *, # Not used in KVCache but passed in by nnx_wrappers.to_linen. 
def kv_cache_autoregressive_v4(
    self,
    key: Array,
    value: Array,
    gate: Optional[Array] = None,
    use_ragged_attention: bool = False,
):
    """Buffer one decode token; emit a pooled block when the window fills.

    When ``compress_rate == 1`` this simply defers to
    ``kv_cache_autoregressive``. Otherwise the incoming ``key`` and ``gate``
    are written into the ``leftover_buffer_*`` slot selected by
    ``accumulator_index``. If the slot just written is the last one of the
    window, the buffered tokens are pooled with softmax-over-window gate
    weights, the pooled block is appended to the autoregressive cache (used
    as both key and value), the AR bookkeeping is advanced, and the
    accumulator restarts at 0. In every case the current prefill and AR
    cache contents are returned.

    Args:
      key: Projection of the current token to buffer.
      value: Only its dtype is read on the compressed path (the pooled key
        doubles as the value).
      gate: Gate logits for the current token. Required when
        ``compress_rate > 1``.
      use_ragged_attention: Forwarded only on the ``compress_rate == 1``
        path; the compressed path always updates with ``False``.

    Returns:
      ``(cached_prefill, cached_ar)`` where ``cached_prefill`` is
      ``(key, value, segment_ids)`` and ``cached_ar`` is
      ``(key, value, segment_ids, lengths)``.

    Raises:
      ValueError: If ``gate`` is None while ``compress_rate > 1``.
    """
    if self.compress_rate == 1:
        return self.kv_cache_autoregressive(key, value, use_ragged_attention)
    if gate is None:
        # Fail early with a clear message instead of an opaque transpose error.
        raise ValueError("kv_cache_autoregressive_v4 requires `gate` when compress_rate > 1.")

    # NOTE(review): squeeze() only yields the scalar index required by
    # dynamic_update_index_in_dim when the accumulator holds one element
    # (batch == 1). Per-row window phases are not supported here - confirm.
    current_index = jnp.squeeze(self.accumulator_index.get_value())

    # NOTE(review): presumably ar_cache_axis_order maps the token onto the
    # [batch, slot, heads, head_size] buffer layout - TODO confirm.
    self.leftover_buffer_kv.set_value(
        jax.lax.dynamic_update_index_in_dim(
            self.leftover_buffer_kv.get_value(),
            jnp.transpose(key, self.ar_cache_axis_order),
            current_index,
            1,
        )
    )
    self.leftover_buffer_gate.set_value(
        jax.lax.dynamic_update_index_in_dim(
            self.leftover_buffer_gate.get_value(),
            jnp.transpose(gate, self.ar_cache_axis_order),
            current_index,
            1,
        )
    )

    next_index = current_index + 1
    window_complete = next_index == self.compress_rate

    def flush_window_block(cache):
        """Pool the full window and append it to the AR cache."""
        kv_chunk = cache.leftover_buffer_kv.get_value()
        gate_chunk = cache.leftover_buffer_gate.get_value()

        # NOTE(review): only softmax gating is applied here; no position
        # bias, norm or rotary embedding - confirm this matches prefill.
        gate_weights = jax.nn.softmax(gate_chunk, axis=1).astype(kv_chunk.dtype)
        compressed_block = jnp.sum(kv_chunk * gate_weights, axis=1, keepdims=True)
        update_key = jnp.transpose(compressed_block, cache.key_axis_order)

        # Fetch the AR cache variables once instead of once per use.
        ar_key_vars, ar_value_vars, ar_segment_id_var, _, ar_lengths_var = cache._get_ar_cache_vars()

        ar_index = cache.cache_ar_index.get_value()
        # The pooled block serves as both key and value.
        cache.update_ar_key_value(update_key, update_key, ar_key_vars, ar_value_vars, ar_index, None, False)

        cache.entry_count.set_value(cache.entry_count.get_value() + 1)

        # Mark the freshly written AR slot as active so masking sees it.
        active_indicator = jnp.zeros((cache.batch, 1), dtype=jnp.int32) + DECODING_ACTIVE_SEQUENCE_INDICATOR
        ar_segment_id_var.set_value(
            jax.lax.dynamic_update_index_in_dim(
                ar_segment_id_var.get_value(), active_indicator, jnp.squeeze(ar_index), 1
            )
        )

        cache.cache_ar_index.set_value(
            jnp.mod(ar_index + 1, cache.max_target_length - cache.max_prefill_length)
        )
        ar_lengths_var.set_value(ar_lengths_var.get_value().at[:].add(1))

        # Restart the accumulator; same shape/dtype as the other branch.
        return jnp.zeros_like(next_index)

    def hold_window_block(cache):
        """Window still filling: leave the caches alone, advance the slot."""
        del cache
        return next_index

    # jax.lax.cond traces both branches, so cache writes made inside a branch
    # would run at trace time regardless of the predicate. nnx.cond threads
    # the module state through the conditional; the module is therefore
    # passed as an operand rather than captured by closure.
    updated_index = nnx.cond(window_complete, flush_window_block, hold_window_block, self)

    # Write back in the accumulator's existing shape so per-slot copies that
    # index its batch axis keep working.
    self.accumulator_index.set_value(
        jnp.broadcast_to(
            jnp.reshape(updated_index, (-1, 1)),
            self.accumulator_index.get_value().shape,
        )
    )

    # Return both caches in the layout the standard attention path expects.
    cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars()
    cached_ar_key_vars, cached_ar_value_vars, cached_ar_segment_id_var, _, cache_ar_lengths_var = self._get_ar_cache_vars()

    cached_prefill = (
        self.get_cached_values(cached_prefill_key_vars, key.dtype, self.prefill_cache_axis_order),
        self.get_cached_values(cached_prefill_value_vars, value.dtype, self.prefill_cache_axis_order),
        cached_prefill_segment_id_var.get_value(),
    )

    cached_ar = (
        self.get_cached_values(cached_ar_key_vars, key.dtype, self.ar_cache_axis_order),
        self.get_cached_values(cached_ar_value_vars, value.dtype, self.ar_cache_axis_order),
        cached_ar_segment_id_var.get_value(),
        cache_ar_lengths_var.get_value(),
    )
    return cached_prefill, cached_ar
{model_mode=}") @@ -1128,6 +1260,7 @@ def __call__( model_mode: str, use_ragged_attention: bool = False, previous_chunk: Any = None, + gate: Optional[Array] = None, ) -> tuple[ None | tuple[Array, Array, Array], None | tuple[Array, Array, Array, Array], diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index d9b686b182..4c5e7a276b 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -1442,6 +1442,17 @@ def copy(path, partial_cache, full_cache, annotations): if batch_idx < 0: raise ValueError(f"Batch index {batch_idx=} shouldn't be less than zero for {path_key}, got {annotations=}") + if path_key in [ + "entry_count", + "accumulator_index", + "leftover_buffer_kv", + "leftover_buffer_gate", + "overlap_kv", + "overlap_gate" + ]: + # Copy these states by explicitly overwriting the target slot matching current request id + return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) + for slot in slots: if path_key == "cache_ar_segment_id": ### goal: zero this out in case there is existing data @@ -1556,6 +1567,17 @@ def copy(path, partial_cache, full_cache, annotations): if batch_idx < 0: raise ValueError(f"Batch index {batch_idx=} shouldn't be less than zero for {path_key}, got {annotations=}") + + if path_key in [ + "entry_count", + "accumulator_index", + "leftover_buffer_kv", + "leftover_buffer_gate", + "overlap_kv", + "overlap_gate" + ]: + # Copy these states by explicitly overwriting the target slot matching current request id + return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) if path_key == "cache_ar_segment_id": s = list(full_cache.shape) @@ -1690,6 +1712,17 @@ def copy(path, partial_cache, full_cache, annotations): if batch_idx < 0: raise ValueError(f"Batch index {batch_idx=} shouldn't be less than zero for {path_key}, got {annotations=}") + + if path_key in [ + "entry_count", + 
"accumulator_index", + "leftover_buffer_kv", + "leftover_buffer_gate", + "overlap_kv", + "overlap_gate" + ]: + # Direct batch slot index overwrite for fixed-size metadata trackers + return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) if path_key == "cache_ar_segment_id": ### goal: zero this out in case there is existing data diff --git a/src/maxtext/layers/attention_compressed.py b/src/maxtext/layers/attention_compressed.py index e9a25f46b5..3e8b30748a 100644 --- a/src/maxtext/layers/attention_compressed.py +++ b/src/maxtext/layers/attention_compressed.py @@ -40,90 +40,128 @@ from maxtext.layers.normalizations import RMSNorm from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.inference.kvcache import KVQuant - +from maxtext.inference import kvcache + + +# def csa_overlap_pooling( +# hidden_states: Array, +# kv_proj: Any, +# gate_proj: Any, +# position_bias: Array, +# kv_norm: Any, +# compress_rate: int, +# head_dim: int, +# ) -> Array: +# """Shared utility for Compressed Sparse Attention (CSA) overlap pooling. + +# Implements the overlapping Ca/Cb pooling logic shared by both the CSA Compressor +# and the CSA Indexer. It splits the projected states into two halves (Ca and Cb), +# shifts the first half forward by one window, and concatenates them to form +# overlapping windows over which softmax gating is applied. + +# Args: +# hidden_states: Input token embeddings. Shape: `[batch, seq_len, emb_dim]`. +# kv_proj: Dense layer projecting to `2 * head_dim`. +# gate_proj: Dense layer projecting to `2 * head_dim`. +# position_bias: Bias tensor. Shape: `[compress_rate, 2 * head_dim]`. +# kv_norm: RMSNorm instance. +# compress_rate: Compression rate for CSA. +# head_dim: Standard head dimension. + +# Returns: +# compressed: The pooled overlapping states. Shape: `[batch, n_windows, head_dim]`. + +# Shape Transformations: +# 1. Projections: `[batch, seq_len, emb_dim]` -> `[batch, seq_len, 2 * head_dim]` +# 2. 
Reshape: -> `[batch, n_windows, compress_rate, 2 * head_dim]` +# 3. Split: -> 2x `[batch, n_windows, compress_rate, head_dim]` +# 4. Shift: Ca shifted forward by one window. +# 5. Concat (Ca + Cb): -> `[batch, n_windows, 2 * compress_rate, head_dim]` +# 6. Gating & Sum: -> `[batch, n_windows, head_dim]` +# """ +# batch_size, seq_len, _ = hidden_states.shape + +# # [batch, seq_len, emb_dim] -> [batch, seq_len, 2 * head_dim] +# kv = kv_proj(hidden_states) +# # [batch, seq_len, emb_dim] -> [batch, seq_len, 2 * head_dim] +# gate = gate_proj(hidden_states) + +# usable = (seq_len // compress_rate) * compress_rate +# chunk_kv = kv[:, :usable] +# chunk_gate = gate[:, :usable] + +# # Return zero tensor if there are no full windows available for pooling +# if chunk_kv.shape[1] == 0: +# return jnp.zeros((batch_size, 0, head_dim), dtype=hidden_states.dtype) + +# n_windows = chunk_kv.shape[1] // compress_rate + +# # Reshape flat sequence into discrete compression windows +# # -> [batch, n_windows, compress_rate, 2 * head_dim] +# chunk_kv = chunk_kv.reshape((batch_size, n_windows, compress_rate, 2 * head_dim)) +# chunk_gate = chunk_gate.reshape((batch_size, n_windows, compress_rate, 2 * head_dim)) + position_bias + +# # Split the projections into Ca and Cb components for overlapping +# # 2x [batch, n_windows, compress_rate, head_dim] +# a_kv, b_kv = jnp.split(chunk_kv, 2, axis=-1) +# a_gate, b_gate = jnp.split(chunk_gate, 2, axis=-1) + +# # Shift Ca forward by one window to align with the next Cb +# a_kv_shifted = jnp.concatenate( +# [jnp.zeros((batch_size, 1, compress_rate, head_dim), dtype=a_kv.dtype), a_kv[:, :-1]], axis=1 +# ) +# a_gate_shifted = jnp.concatenate( +# [jnp.full((batch_size, 1, compress_rate, head_dim), -jnp.inf, dtype=a_gate.dtype), a_gate[:, :-1]], axis=1 +# ) + +# # Concatenate shifted Ca and unshifted Cb to form the final overlapping window +# # -> [batch, n_windows, 2 * compress_rate, head_dim] +# new_kv = jnp.concatenate([a_kv_shifted, b_kv], axis=2) +# 
def csa_overlap_pooling(
    chunk_kv_reshaped: Array,
    chunk_gate_reshaped: Array,
    kv_norm: Any,
    head_dim: int,
    prior_kv: Optional[Array] = None,
    prior_gate: Optional[Array] = None,
) -> Tuple[Array, Array, Array]:
    """Pool overlapping Ca/Cb windows and return the carry for the next call.

    Each input window is split along the feature axis into a Ca half and a
    Cb half. Window ``i`` is pooled over the Ca half of window ``i - 1``
    together with its own Cb half, using softmax gate weights over the
    ``2 * compress_rate`` stacked entries. The first window uses
    ``prior_kv`` / ``prior_gate`` as its "previous" Ca half.

    Args:
      chunk_kv_reshaped: ``[batch, n_windows, compress_rate, 2 * head_dim]``.
      chunk_gate_reshaped: Gate logits, same shape as ``chunk_kv_reshaped``.
      kv_norm: Callable applied to the pooled ``[batch, n_windows, head_dim]``
        result.
      head_dim: Size of one Ca/Cb half; used to build default priors and the
        empty result.
      prior_kv: Ca values carried from the previous call, either
        ``[batch, 1, compress_rate, head_dim]`` or, as sliced out of a cache,
        ``[batch, compress_rate, head_dim]``. Defaults to zeros.
      prior_gate: Ca gate logits carried from the previous call, same shapes
        as ``prior_kv``. Defaults to ``-inf``, which gives the prior zero
        softmax weight.

    Returns:
      ``(compressed, next_prior_kv, next_prior_gate)``: the pooled windows
      ``[batch, n_windows, head_dim]`` and the Ca halves of the last window,
      each ``[batch, 1, compress_rate, head_dim]``. With zero input windows
      the result is empty and the (rank-4) priors are passed through.
    """
    batch_size, n_windows, compress_rate, _ = chunk_kv_reshaped.shape

    # Ca feeds the *next* window, Cb belongs to the current one.
    a_kv, b_kv = jnp.split(chunk_kv_reshaped, 2, axis=-1)
    a_gate, b_gate = jnp.split(chunk_gate_reshaped, 2, axis=-1)

    prior_shape = (batch_size, 1, compress_rate, head_dim)
    if prior_kv is None:
        prior_kv = jnp.zeros(prior_shape, dtype=a_kv.dtype)
    elif prior_kv.ndim == 3:
        # A prior without the window axis is a single window; add the axis
        # so it can be concatenated with the rank-4 Ca slices below.
        prior_kv = prior_kv[:, None]
    if prior_gate is None:
        prior_gate = jnp.full(prior_shape, -jnp.inf, dtype=a_gate.dtype)
    elif prior_gate.ndim == 3:
        prior_gate = prior_gate[:, None]

    if n_windows == 0:
        # Nothing to pool: keep the carry unchanged instead of failing on a
        # window-count mismatch between the shifted Ca and the empty Cb.
        empty = jnp.zeros((batch_size, 0, head_dim), dtype=b_kv.dtype)
        return empty, prior_kv, prior_gate

    # Shift Ca forward by one window by prepending the carried prior.
    a_kv_shifted = jnp.concatenate([prior_kv, a_kv[:, :-1]], axis=1)
    a_gate_shifted = jnp.concatenate([prior_gate, a_gate[:, :-1]], axis=1)

    # -> [batch, n_windows, 2 * compress_rate, head_dim]
    new_kv = jnp.concatenate([a_kv_shifted, b_kv], axis=2)
    new_gate = jnp.concatenate([a_gate_shifted, b_gate], axis=2)

    gate_weights = jax.nn.softmax(new_gate, axis=2).astype(new_kv.dtype)
    compressed = kv_norm(jnp.sum(new_kv * gate_weights, axis=2))

    # The next call needs the Ca half of the last window processed here.
    return compressed, a_kv[:, -1:], a_gate[:, -1:]
mechanism over each compression window gate_weights = jax.nn.softmax(chunk_gate, axis=2).astype(chunk_kv.dtype) - # -> [batch, n_windows, head_dim] compressed = self.kv_norm(jnp.sum(chunk_kv * gate_weights, axis=2)) - - # Calculate positions for the compressed blocks positions = jnp.arange(n_windows) * self.compress_rate + first_window_position - - # Apply Rotary Positional Embeddings to the pooled representations - # compressed is [batch, n_windows, head_dim] compressed = self.rotary_emb(compressed, positions, unsqueeze_dim=None) else: - # Provide an empty tensor when the sequence is shorter than the compression rate compressed = jnp.zeros((batch_size, 0, self.head_dim), dtype=self.dtype) - # Expand the feature dimension to match the standard KV projection shape - # -> [batch, n_windows, 1, head_dim] compressed_kv = jnp.expand_dims(compressed, axis=2) compressed_len = compressed_kv.shape[1] - # Skip causal mask generation during decoding (seq_len == 1) or if no blocks were pooled + # --- PREFILL CACHE PRIMING --- + if cache is not None: + remainder = seq_len % self.compress_rate + if remainder > 0: + leftover_kv = kv[:, usable:] + leftover_gate = gate[:, usable:] + pad_len = self.compress_rate - remainder + padded_kv = jnp.expand_dims(jnp.pad(leftover_kv, ((0, 0), (0, pad_len), (0, 0))), 2) + padded_gate = jnp.expand_dims(jnp.pad(leftover_gate, ((0, 0), (0, pad_len), (0, 0))), 2) + cache.leftover_buffer_kv.set_value(padded_kv) + cache.leftover_buffer_gate.set_value(padded_gate) + cache.accumulator_index.set_value(jnp.full((batch_size, 1), remainder, dtype=jnp.int32)) + + if compressed_len > 0: + cache_key_var = cache.cached_prefill_key + # Update the prefill array with the generated blocks [B, N, H, D] + update_blocks = jnp.transpose(compressed_kv, (0, 1, 3, 2)) + cache_key_var.set_value( + jax.lax.dynamic_update_slice_in_dim(cache_key_var.get_value(), update_blocks, 0, axis=1) + ) + cache.entry_count.set_value(jnp.full((batch_size, 1), compressed_len, 
dtype=jnp.int32)) + if seq_len == 1 or compressed_len == 0: return compressed_kv, None - # Construct a causal mask preventing early queries from attending to future compressed blocks entry_indices = jnp.arange(compressed_len) causal_threshold = (position_ids + 1) // self.compress_rate - future_mask = entry_indices[None, None, None, :] >= jnp.expand_dims(causal_threshold, axis=(1, 3)) compressed_causal_mask = jnp.where(future_mask, DEFAULT_MASK_VALUE, 0.0).astype(self.dtype) @@ -435,84 +491,109 @@ def __call__( q_latent: Array, position_ids: Array, attention_mask: Optional[Array] = None, + model_mode: str = MODEL_MODE_TRAIN, + cache: Optional[Any] = None, ) -> Array: batch_size, seq_len, _ = hidden_states.shape - # Process overlapping pooling independently for the Indexer using its own head dimension - # -> [batch, n_windows, index_head_dim] - compressed = csa_overlap_pooling( - hidden_states, - self.kv_proj, - self.gate_proj, - self.position_bias.value, - self.kv_norm, - self.compress_rate, - self.index_head_dim, - ) - compressed_len = compressed.shape[1] - - # Apply rotary positional embeddings to the compressed blocks if valid windows exist - if compressed_len > 0: - first_window_position = position_ids[:, 0:1] - positions = jnp.arange(compressed_len) * self.compress_rate + first_window_position + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) - compressed = self.rotary_emb(compressed, positions, unsqueeze_dim=None) + # --- AUTOREGRESSIVE DELEGATION --- + if model_mode == MODEL_MODE_AUTOREGRESSIVE and cache is not None: + kv_exp = jnp.expand_dims(kv, 2) + gate_exp = jnp.expand_dims(gate, 2) + cached_prefill, cached_ar = cache( + key=kv_exp, value=kv_exp, gate=gate_exp, decoder_segment_ids=None, model_mode=model_mode + ) + compressed = jnp.concatenate([cached_prefill[0], cached_ar[0]], axis=1)[:, :, 0, :] + compressed_len = compressed.shape[1] + + # --- PREFILL CHUNKING & PRIMING --- else: - # Return empty top-k selections when sequence is 
too short to form any windows + usable = (seq_len // self.compress_rate) * self.compress_rate + chunk_kv = kv[:, :usable] + chunk_gate = gate[:, :usable] + + # Extract staggered overlap states if cache is available + prior_kv = cache.overlap_kv.get_value()[:, :, 0, :] if cache is not None else None + prior_gate = cache.overlap_gate.get_value()[:, :, 0, :] if cache is not None else None + + if chunk_kv.shape[1] > 0: + n_windows = chunk_kv.shape[1] // self.compress_rate + chunk_kv_reshaped = chunk_kv.reshape((batch_size, n_windows, self.compress_rate, -1)) + chunk_gate_reshaped = chunk_gate.reshape((batch_size, n_windows, self.compress_rate, -1)) + self.position_bias.value + + compressed, next_prior_kv, next_prior_gate = csa_overlap_pooling( + chunk_kv_reshaped, chunk_gate_reshaped, self.kv_norm, self.index_head_dim, prior_kv, prior_gate + ) + compressed_len = compressed.shape[1] + + positions = jnp.arange(compressed_len) * self.compress_rate + position_ids[:, 0:1] + compressed = self.rotary_emb(compressed, positions, unsqueeze_dim=None) + else: + compressed = jnp.zeros((batch_size, 0, self.index_head_dim), dtype=self.dtype) + compressed_len = 0 + next_prior_kv = prior_kv + next_prior_gate = prior_gate + + # Prefill Cache Insertion + if cache is not None: + remainder = seq_len % self.compress_rate + if remainder > 0: + leftover_kv = kv[:, usable:] + leftover_gate = gate[:, usable:] + pad_len = self.compress_rate - remainder + padded_kv = jnp.expand_dims(jnp.pad(leftover_kv, ((0, 0), (0, pad_len), (0, 0))), 2) + padded_gate = jnp.expand_dims(jnp.pad(leftover_gate, ((0, 0), (0, pad_len), (0, 0))), 2) + cache.leftover_buffer_kv.set_value(padded_kv) + cache.leftover_buffer_gate.set_value(padded_gate) + cache.accumulator_index.set_value(jnp.full((batch_size, 1), remainder, dtype=jnp.int32)) + + if compressed_len > 0: + cache_key_var = cache.cached_prefill_key + update_blocks = jnp.transpose(jnp.expand_dims(compressed, 2), (0, 1, 3, 2)) + cache_key_var.set_value( + 
jax.lax.dynamic_update_slice_in_dim(cache_key_var.get_value(), update_blocks, 0, axis=1) + ) + cache.entry_count.set_value(jnp.full((batch_size, 1), compressed_len, dtype=jnp.int32)) + + # Save the new trailing Ca slices to the overlap registers! + cache.overlap_kv.set_value(jnp.expand_dims(next_prior_kv, 2)) + cache.overlap_gate.set_value(jnp.expand_dims(next_prior_gate, 2)) + + if compressed_len == 0: return jnp.zeros((batch_size, seq_len, min(self.index_topk, compressed_len)), dtype=jnp.int32) - # Broadcast the compressed KV representations across all indexer heads - # -> [batch, 1, n_windows, index_head_dim] + # --- TOP-K ROUTING MATH (Executes in both Prefill and AR) --- compressed_kv = jnp.expand_dims(compressed, axis=1) - # -> [batch, index_n_heads, n_windows, index_head_dim] compressed_kv = jnp.broadcast_to(compressed_kv, (batch_size, self.index_n_heads, compressed_len, self.index_head_dim)) - # Project the latent query to match the Indexer's dimensions - # [batch, seq_len, index_n_heads * index_head_dim] -> [batch, seq_len, index_n_heads, index_head_dim] q = self.q_proj(q_latent).reshape((batch_size, seq_len, self.index_n_heads, self.index_head_dim)) - # -> [batch, index_n_heads, seq_len, index_head_dim] q = jnp.transpose(q, (0, 2, 1, 3)) - - # Apply standard Rotary Positional Embeddings to queries q = self.rotary_emb(q, position_ids, unsqueeze_dim=1) q = q.astype(jnp.float32) compressed_kv = compressed_kv.astype(jnp.float32) - # Compute dot product between Queries and Compressed KV Blocks - # -> [batch, index_n_heads, seq_len, n_windows] scores = jnp.einsum("bhsd,bhwd->bhsw", q, compressed_kv) scores = jax.nn.relu(scores) * self.softmax_scale - - # Compute routing weights to combine scores across indexer heads - # [batch, seq_len, emb_dim] -> [batch, seq_len, index_n_heads] weights = self.weights_proj(hidden_states).astype(jnp.float32) * self.weights_scaling - - # Combine individual head scores according to routing weights - # -> [batch, seq_len, 
n_windows] index_scores = jnp.einsum("bhsw,bsh->bsw", scores, weights) k = min(self.index_topk, compressed_len) - - # Mask out future compressed blocks to ensure causal routing causal_threshold = (position_ids + 1) // self.compress_rate entry_indices = jnp.arange(compressed_len) future_mask = entry_indices[None, None, :] >= jnp.expand_dims(causal_threshold, axis=-1) index_scores = jnp.where(future_mask, jnp.full_like(index_scores, -jnp.inf), index_scores) - # Apply standard segment attention mask (additive 0 and -inf) if attention_mask is not None: index_scores += attention_mask[:, :, :compressed_len] - # Retrieve the top-k highest scoring block indices for each token top_k_indices = jax.lax.top_k(index_scores, k)[1] - - # Invalidate any top-k selections that point to future blocks (edge case safety) invalid = top_k_indices >= jnp.expand_dims(causal_threshold, axis=-1) - top_k_indices = jnp.where(invalid, jnp.full_like(top_k_indices, -1), top_k_indices) - - return top_k_indices + return jnp.where(invalid, jnp.full_like(top_k_indices, -1), top_k_indices) class DeepseekV4CSACompressor(BaseDeepseekCompressor): @@ -568,58 +649,87 @@ def __call__( q_latent: Array, position_ids: Array, attention_mask: Optional[Array] = None, + model_mode: str = MODEL_MODE_TRAIN, + cache: Optional[Any] = None, + indexer_cache: Optional[Any] = None, ) -> Tuple[Array, Array]: - """Forward pass for the CSA compressor. - - Args: - hidden_states: Input token embeddings. Shape: `[batch, seq_len, emb_dim]`. - q_latent: Latent query representation. Shape: `[batch, seq_len, emb_dim]`. - position_ids: Absolute token positions. Shape: `[batch, seq_len]`. - - Returns: - compressed_kv: The pooled KV tensors. Shape: `[batch, n_windows, 1, head_dim]`. - compressed_mask: Causal and routing mask dynamically selected by the Indexer. - Shape: `[batch, 1, seq_len, n_windows]`. 
- """ batch_size, seq_len, _ = hidden_states.shape - # Retrieve top-k blocks dynamically chosen for each query - # -> [batch, seq_len, index_topk] - top_k_indices = self.indexer(hidden_states, q_latent, position_ids, attention_mask) - - # Perform overlapping pooling over the sequence - # -> [batch, n_windows, head_dim] - compressed = csa_overlap_pooling( - hidden_states, - self.kv_proj, - self.gate_proj, - self.position_bias.value, - self.kv_norm, - self.compress_rate, - self.head_dim, + # 1. ALWAYS Run Indexer (It fetches its own history inside AR) + top_k_indices = self.indexer( + hidden_states, q_latent, position_ids, attention_mask, model_mode, indexer_cache ) - compressed_len = compressed.shape[1] - - # Apply rotary positional embeddings to the pooled blocks if there are any full windows - if compressed_len > 0: - first_window_position = position_ids[:, 0:1] - positions = jnp.arange(compressed_len) * self.compress_rate + first_window_position - compressed = self.rotary_emb(compressed, positions, unsqueeze_dim=None) + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) - # Expand to standard KV format - # -> [batch, n_windows, 1, head_dim] - compressed_kv = jnp.expand_dims(compressed, axis=2) + # --- AUTOREGRESSIVE DELEGATION --- + if model_mode == MODEL_MODE_AUTOREGRESSIVE and cache is not None: + kv_exp = jnp.expand_dims(kv, 2) + gate_exp = jnp.expand_dims(gate, 2) + cached_prefill, cached_ar = cache( + key=kv_exp, value=kv_exp, gate=gate_exp, decoder_segment_ids=None, model_mode=model_mode + ) + compressed = jnp.concatenate([cached_prefill[0], cached_ar[0]], axis=1)[:, :, 0, :] + compressed_len = compressed.shape[1] + compressed_kv = jnp.expand_dims(compressed, 2) + + # --- PREFILL CHUNKING & PRIMING --- + else: + usable = (seq_len // self.compress_rate) * self.compress_rate + chunk_kv = kv[:, :usable] + chunk_gate = gate[:, :usable] + + prior_kv = cache.overlap_kv.get_value()[:, :, 0, :] if cache is not None else None + prior_gate = 
cache.overlap_gate.get_value()[:, :, 0, :] if cache is not None else None + + if chunk_kv.shape[1] > 0: + n_windows = chunk_kv.shape[1] // self.compress_rate + chunk_kv_reshaped = chunk_kv.reshape((batch_size, n_windows, self.compress_rate, -1)) + chunk_gate_reshaped = chunk_gate.reshape((batch_size, n_windows, self.compress_rate, -1)) + self.position_bias.value + + compressed, next_prior_kv, next_prior_gate = csa_overlap_pooling( + chunk_kv_reshaped, chunk_gate_reshaped, self.kv_norm, self.head_dim, prior_kv, prior_gate + ) + compressed_len = compressed.shape[1] + + positions = jnp.arange(compressed_len) * self.compress_rate + position_ids[:, 0:1] + compressed = self.rotary_emb(compressed, positions, unsqueeze_dim=None) + else: + compressed = jnp.zeros((batch_size, 0, self.head_dim), dtype=self.dtype) + compressed_len = 0 + next_prior_kv = prior_kv + next_prior_gate = prior_gate + + compressed_kv = jnp.expand_dims(compressed, 2) + + if cache is not None: + remainder = seq_len % self.compress_rate + if remainder > 0: + leftover_kv = kv[:, usable:] + leftover_gate = gate[:, usable:] + pad_len = self.compress_rate - remainder + padded_kv = jnp.expand_dims(jnp.pad(leftover_kv, ((0, 0), (0, pad_len), (0, 0))), 2) + padded_gate = jnp.expand_dims(jnp.pad(leftover_gate, ((0, 0), (0, pad_len), (0, 0))), 2) + cache.leftover_buffer_kv.set_value(padded_kv) + cache.leftover_buffer_gate.set_value(padded_gate) + cache.accumulator_index.set_value(jnp.full((batch_size, 1), remainder, dtype=jnp.int32)) + + if compressed_len > 0: + cache_key_var = cache.cached_prefill_key + update_blocks = jnp.transpose(compressed_kv, (0, 1, 3, 2)) + cache_key_var.set_value( + jax.lax.dynamic_update_slice_in_dim(cache_key_var.get_value(), update_blocks, 0, axis=1) + ) + cache.entry_count.set_value(jnp.full((batch_size, 1), compressed_len, dtype=jnp.int32)) + cache.overlap_kv.set_value(jnp.expand_dims(next_prior_kv, 2)) + cache.overlap_gate.set_value(jnp.expand_dims(next_prior_gate, 2)) - # Return 
early if no compressed blocks could be formed (e.g. sequence too short) if compressed_len == 0: return compressed_kv, jnp.zeros((batch_size, 1, seq_len, 0), dtype=self.dtype) - # Construct the final dynamic mask applying the Indexer's selections - # -> [batch, 1, seq_len, n_windows] + # 3. Apply Dynamic Masking Logic k = top_k_indices.shape[-1] - - # Only compute and apply the complex block mask if top-k selections exist if k > 0: valid = top_k_indices >= 0 entry_indices = jnp.arange(compressed_len)[None, None, :] @@ -872,6 +982,51 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No rngs=self.rngs, ) + self.compressor_cache = None + self.indexer_cache = None + + if self.model_mode != MODEL_MODE_TRAIN: + batch_size = inputs_q_shape[0] + max_prefill_comp = max_prefill_predict_length // self.compress_ratio if self.compress_ratio > 0 else 0 + max_target_comp = max_target_length // self.compress_ratio if self.compress_ratio > 0 else 0 + + if self.compress_ratio > 0: + self.compressor_cache = kvcache.KVCache( + max_prefill_length=max_prefill_comp, + max_target_length=max_target_comp, + batch=batch_size, + key_seq_len=1, + value_seq_len=1, + key_heads=1, + value_heads=1, + key_head_size=self.head_dim, + value_head_size=self.head_dim, + dtype=self.dtype, + model_mode=self.model_mode, + is_deepseek_v4=True, + compress_rate=self.compress_ratio, + rngs=rngs, + ) + + if self.compress_ratio == 4: + self.indexer_cache = kvcache.KVCache( + max_prefill_length=max_prefill_comp, + max_target_length=max_target_comp, + batch=batch_size, + key_seq_len=1, + value_seq_len=1, + key_heads=1, + value_heads=1, + key_head_size=config.indexer_head_dim, + value_head_size=config.indexer_head_dim, + dtype=self.dtype, + model_mode=self.model_mode, + is_deepseek_v4=True, + compress_rate=self.compress_ratio, + is_indexer=True, + rngs=rngs, + ) + @property def out_head_dim(self) -> int: """Returns the head dimension used prior to the output projection.""" @@ -992,9 
+1147,13 @@ def __call__( # Route to the appropriate compressor depending on the layer's role in the architecture if self.compress_ratio > 4: - compressed_kv, compressed_mask = self.hca_compressor(inputs_kv, q_normed, inputs_positions) + compressed_kv, compressed_mask = self.hca_compressor( + inputs_kv, q_normed, inputs_positions, model_mode, self.compressor_cache + ) elif self.compress_ratio == 4: - compressed_kv, compressed_mask = self.csa_compressor(inputs_kv, q_normed, inputs_positions, compressed_segment_mask) + compressed_kv, compressed_mask = self.csa_compressor( + inputs_kv, q_normed, inputs_positions, compressed_segment_mask, model_mode, self.compressor_cache, self.indexer_cache + ) # Apply segment masking to the compressed blocks if compressed_segment_mask is not None and compressed_mask is not None: