Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions src/maxtext/inference/kvcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def kv_cache_as_linen(
use_chunked_prefill: bool = False,
model_mode: str = MODEL_MODE_PREFILL,
is_gdn: bool = False,
is_deepseek_v4: bool = False,
compress_rate: int = 1,
is_indexer: bool = False,
conv_kernel_size: int = 0,
conv_dim: int = 0,
name: str | None = None,
Expand Down Expand Up @@ -228,6 +231,9 @@ def kv_cache_as_linen(
use_chunked_prefill=use_chunked_prefill,
model_mode=model_mode,
is_gdn=is_gdn,
is_deepseek_v4=is_deepseek_v4,
compress_rate=compress_rate,
is_indexer=is_indexer,
conv_kernel_size=conv_kernel_size,
conv_dim=conv_dim,
metadata_fn=variable_to_logically_partitioned,
Expand Down Expand Up @@ -274,6 +280,9 @@ def __init__(
is_gdn: bool = False,
conv_kernel_size: int = 0,
conv_dim: int = 0,
is_deepseek_v4: bool = False,
compress_rate: int = 1,
is_indexer: bool = False,
*,
# Not used in KVCache but passed in by nnx_wrappers.to_linen.
# TODO: Remove when bridge no longer needed
Expand Down Expand Up @@ -326,11 +335,42 @@ def __init__(
self.is_gdn = is_gdn
self.conv_kernel_size = conv_kernel_size
self.conv_dim = conv_dim
self.is_deepseek_v4 = is_deepseek_v4
self.compress_rate = compress_rate
self.is_indexer = is_indexer

if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
self._initialize_prefill_caches(model_mode)
self._initialize_ar_cache_vars(model_mode)

if self.is_deepseek_v4 and self.compress_rate > 1:
cache_batch_axis_name = CACHE_BATCH_PREFILL if model_mode == MODEL_MODE_PREFILL else CACHE_BATCH

self.entry_count = nnx.Cache(
jnp.zeros((self.batch, 1), dtype=jnp.int32),
out_sharding=(cache_batch_axis_name, None)
)
self.accumulator_index = nnx.Cache(
jnp.zeros((self.batch, 1), dtype=jnp.int32),
out_sharding=(cache_batch_axis_name, None)
)
self.leftover_buffer_kv = nnx.Cache(
jnp.zeros((self.batch, self.compress_rate, self.key_heads, self.key_head_size), dtype=dtype),
out_sharding=(cache_batch_axis_name, None, None, None)
)
self.leftover_buffer_gate = nnx.Cache(
jnp.zeros((self.batch, self.compress_rate, self.key_heads, self.key_head_size), dtype=dtype),
out_sharding=(cache_batch_axis_name, None, None, None)
)
self.overlap_kv = nnx.Cache(
jnp.zeros((self.batch, self.compress_rate, self.key_heads, self.key_head_size), dtype=dtype),
out_sharding=(cache_batch_axis_name, None, None, None)
)
self.overlap_gate = nnx.Cache(
jnp.zeros((self.batch, self.compress_rate, self.key_heads, self.key_head_size), dtype=dtype),
out_sharding=(cache_batch_axis_name, None, None, None)
)

@property
def prefill_key_vars(self):
  """Return the prefill key cache entry paired with its scale entry.

  Returns:
    A 2-tuple `(cached_prefill_key, cached_prefill_key_scale)` holding the two
    attributes exactly as they are stored on this instance.
  """
  key_entry = self.cached_prefill_key
  key_scale_entry = self.cached_prefill_key_scale
  return key_entry, key_scale_entry
Expand Down Expand Up @@ -923,6 +963,95 @@ def kv_cache_autoregressive(
cache_ar_lengths_var.get_value(),
)
return cached_prefill, cached_ar


def kv_cache_autoregressive_v4(
    self,
    key: Array,
    value: Array,
    gate: Optional[Array] = None,
    use_ragged_attention: bool = False,
):
  """Stage one decode step in the compression window and flush completed windows.

  When `compress_rate == 1` this simply delegates to `kv_cache_autoregressive`.
  Otherwise the incoming `key` and `gate` (both transposed with
  `ar_cache_axis_order`) are written into the leftover buffers at the slot given
  by `accumulator_index`. Once `compress_rate` entries have been staged, the
  buffered keys are combined into a single block using a softmax over the
  buffered gates (along axis 1), and that block is written to the autoregressive
  cache as both key and value. The flush also bumps `entry_count`, marks the
  written position in the AR segment ids, advances `cache_ar_index` modulo
  `max_target_length - max_prefill_length`, and increments the AR lengths.

  Args:
    key: Key tensor for the current decode step.
    value: Value tensor for the current decode step. On the compressed path it
      is only used for its dtype when reading the caches back.
    gate: Gate tensor for the current decode step. Required whenever
      `compress_rate != 1`.
    use_ragged_attention: Forwarded to `kv_cache_autoregressive` when
      `compress_rate == 1`; not consulted on the compressed path.

  Returns:
    A pair `(cached_prefill, cached_ar)`. `cached_prefill` is
    `(key, value, segment_ids)` and `cached_ar` is
    `(key, value, segment_ids, lengths)`, all read from the cache variables.

  Raises:
    ValueError: If `gate` is None while `compress_rate != 1`.
  """
  if self.compress_rate == 1:
    return self.kv_cache_autoregressive(key, value, use_ragged_attention)

  # Fail with a clear message instead of an opaque error from jnp.transpose(None, ...).
  if gate is None:
    raise ValueError(
        f"`gate` is required for compressed autoregressive caching (compress_rate={self.compress_rate})."
    )

  # 1. Read the window slot from the dedicated accumulator.
  # NOTE(review): the index handed to dynamic_update_index_in_dim must be a
  # scalar, so this squeeze presumably only works when the accumulator holds a
  # single element (batch == 1) -- confirm the intended batch > 1 behaviour.
  accumulator_value = self.accumulator_index.get_value()
  current_index = jnp.squeeze(accumulator_value)

  # 2. Stage the current step's key and gate in the leftover buffers (axis 1 is the window axis).
  buffer_kv = jax.lax.dynamic_update_index_in_dim(
      self.leftover_buffer_kv.get_value(), jnp.transpose(key, self.ar_cache_axis_order), current_index, 1
  )
  buffer_gate = jax.lax.dynamic_update_index_in_dim(
      self.leftover_buffer_gate.get_value(), jnp.transpose(gate, self.ar_cache_axis_order), current_index, 1
  )
  self.leftover_buffer_kv.set_value(buffer_kv)
  self.leftover_buffer_gate.set_value(buffer_gate)

  next_index = current_index + 1
  window_complete = next_index == self.compress_rate

  # NOTE(review): flush_window_block writes cache variables from inside a
  # jax.lax.cond branch. lax.cond traces both branches, so please confirm these
  # writes neither leak tracers nor take effect unconditionally; a state-aware
  # conditional (or jnp.where selects) may be required.
  def flush_window_block(_operand):
    # Gate-weighted sum over the window axis collapses the window to one entry.
    gate_weights = jax.nn.softmax(buffer_gate, axis=1).astype(buffer_kv.dtype)
    compressed_block = jnp.sum(buffer_kv * gate_weights, axis=1, keepdims=True)
    update_key = jnp.transpose(compressed_block, self.key_axis_order)

    ar_key_vars, ar_value_vars, ar_segment_id_var, _, ar_lengths_var = self._get_ar_cache_vars()

    # Use the AR index for the cache update; the compressed block is stored as both key and value.
    ar_index = self.cache_ar_index.get_value()
    self.update_ar_key_value(update_key, update_key, ar_key_vars, ar_value_vars, ar_index, None, False)

    self.entry_count.set_value(self.entry_count.get_value() + 1)

    # Update AR metadata so the attention mask recognizes the new block.
    active_indicator = jnp.zeros((self.batch, 1), dtype=jnp.int32) + DECODING_ACTIVE_SEQUENCE_INDICATOR
    ar_segment_id_var.set_value(
        jax.lax.dynamic_update_index_in_dim(
            ar_segment_id_var.get_value(), active_indicator, jnp.squeeze(ar_index), 1
        )
    )

    self.cache_ar_index.set_value(jnp.mod(ar_index + 1, self.max_target_length - self.max_prefill_length))
    ar_lengths_var.set_value(ar_lengths_var.get_value().at[:].add(1))

    # Reset the accumulator; zeros_like keeps both cond branches structurally identical.
    return jnp.zeros_like(next_index)

  def hold_window_block(_operand):
    return next_index

  updated_index = jax.lax.cond(window_complete, flush_window_block, hold_window_block, None)

  # Write the accumulator back with the shape and dtype it was created with, so
  # the cache state keeps a stable structure across decode steps.
  self.accumulator_index.set_value(
      jnp.broadcast_to(updated_index.astype(accumulator_value.dtype), accumulator_value.shape)
  )

  # Unpack the cache variables into the tuples the standard attention pipeline expects.
  cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars()
  cached_ar_key_vars, cached_ar_value_vars, cached_ar_segment_id_var, _, cache_ar_lengths_var = (
      self._get_ar_cache_vars()
  )

  cached_prefill = (
      self.get_cached_values(cached_prefill_key_vars, key.dtype, self.prefill_cache_axis_order),
      self.get_cached_values(cached_prefill_value_vars, value.dtype, self.prefill_cache_axis_order),
      cached_prefill_segment_id_var.get_value(),
  )

  cached_ar = (
      self.get_cached_values(cached_ar_key_vars, key.dtype, self.ar_cache_axis_order),
      self.get_cached_values(cached_ar_value_vars, value.dtype, self.ar_cache_axis_order),
      cached_ar_segment_id_var.get_value(),
      cache_ar_lengths_var.get_value(),
  )
  return cached_prefill, cached_ar


def __call__(
self,
Expand All @@ -932,6 +1061,7 @@ def __call__(
model_mode: str,
use_ragged_attention: bool = False,
previous_chunk: Any = None,
gate: Optional[Array] = None,
) -> tuple:
"""KV cache takes the current state and updates the state accordingly.

Expand All @@ -956,6 +1086,8 @@ def __call__(
else:
return self.kv_cache_prefill(key, value, decoder_segment_ids), None
elif model_mode == MODEL_MODE_AUTOREGRESSIVE:
if self.is_deepseek_v4 and self.compress_rate > 1:
return self.kv_cache_autoregressive_v4(key, value, gate, use_ragged_attention)
return self.kv_cache_autoregressive(key, value, use_ragged_attention)
else:
raise ValueError(f"Model Mode isn't supported! {model_mode=}")
Expand Down Expand Up @@ -1128,6 +1260,7 @@ def __call__(
model_mode: str,
use_ragged_attention: bool = False,
previous_chunk: Any = None,
gate: Optional[Array] = None,
) -> tuple[
None | tuple[Array, Array, Array],
None | tuple[Array, Array, Array, Array],
Expand Down
33 changes: 33 additions & 0 deletions src/maxtext/inference/maxengine/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,17 @@ def copy(path, partial_cache, full_cache, annotations):
if batch_idx < 0:
raise ValueError(f"Batch index {batch_idx=} shouldn't be less than zero for {path_key}, got {annotations=}")

if path_key in [
"entry_count",
"accumulator_index",
"leftover_buffer_kv",
"leftover_buffer_gate",
"overlap_kv",
"overlap_gate"
]:
# Copy these states by explicitly overwriting the target slot matching current request id
return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)

for slot in slots:
if path_key == "cache_ar_segment_id":
### goal: zero this out in case there is existing data
Expand Down Expand Up @@ -1556,6 +1567,17 @@ def copy(path, partial_cache, full_cache, annotations):

if batch_idx < 0:
raise ValueError(f"Batch index {batch_idx=} shouldn't be less than zero for {path_key}, got {annotations=}")

if path_key in [
"entry_count",
"accumulator_index",
"leftover_buffer_kv",
"leftover_buffer_gate",
"overlap_kv",
"overlap_gate"
]:
# Copy these states by explicitly overwriting the target slot matching current request id
return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)

if path_key == "cache_ar_segment_id":
s = list(full_cache.shape)
Expand Down Expand Up @@ -1690,6 +1712,17 @@ def copy(path, partial_cache, full_cache, annotations):

if batch_idx < 0:
raise ValueError(f"Batch index {batch_idx=} shouldn't be less than zero for {path_key}, got {annotations=}")

if path_key in [
"entry_count",
"accumulator_index",
"leftover_buffer_kv",
"leftover_buffer_gate",
"overlap_kv",
"overlap_gate"
]:
# Direct batch slot index overwrite for fixed-size metadata trackers
return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)

if path_key == "cache_ar_segment_id":
### goal: zero this out in case there is existing data
Expand Down
Loading
Loading