@@ -566,6 +566,8 @@ def mla_decode_stage1_asm_fwd(
566566 work_indptr : Optional [torch .Tensor ],
567567 work_info_set : Optional [torch .Tensor ],
568568 max_seqlen_q : int ,
569+ page_size : int ,
570+ nhead_kv : int ,
569571 softmax_scale : float ,
570572 # [batch_size, num_kv_splits, num_heads, v_head_dim]
571573 splitData : torch .Tensor ,
@@ -854,6 +856,7 @@ def get_mla_metadata_info_v1(
854856def get_mla_metadata_v1 (
855857 seqlens_qo_indptr : torch .Tensor ,
856858 seqlens_kv_indptr : torch .Tensor ,
859+ kv_last_page_lens : torch .Tensor ,
857860 num_heads_per_head_k : int ,
858861 num_heads_k : int ,
859862 is_causal : bool ,
@@ -863,6 +866,7 @@ def get_mla_metadata_v1(
863866 reduce_indptr : torch .Tensor ,
864867 reduce_final_map : torch .Tensor ,
865868 reduce_partial_map : torch .Tensor ,
869+ page_size : int = 1 ,
866870 kv_granularity : int = 16 ,
867871 max_seqlen_qo : int = - 1 ,
868872 uni_seqlen_qo : int = - 1 ,
@@ -876,12 +880,14 @@ def get_mla_metadata_v1(
876880 """
877881 Inputs:
878882 cumulated seqlens of q/o: (batch_size + 1), dtype torch.int32.
879- cumulated seqlens of k/v: (batch_size + 1), dtype torch.int32.
883+ cumulated page indices of k/v: (batch_size + 1), dtype torch.int32.
884+ Length of last page of k/v: (batch_size), dtype torch.int32.
880885 num_heads_per_head_k: Equals to num_heads_q // num_heads_k.
881886 num_heads_k: num_heads_k.
882887 is_causal: Whether causal mask is enabled.
883888 Options: Detailed settings for spliting. All of them are optional.
884- kv_granularity: default=16. The granularity on kv sequence length when cutting batch.
889+ page_size: default=1. The size of a page.
890+ kv_granularity: default=16. The granularity on kv page nums when cutting batch.
885891 max_seqlen_qo: default=-1. Used to check lds usage and save time. value less than 1 means unknown.
886892 uni_seqlen_qo: default=-1. Sequence length of qo is uniform across batches. value less than 1 means the
887893 length is not fixed.
@@ -899,11 +905,11 @@ def get_mla_metadata_v1(
899905 [2.2] q_start: (#work), The global index in seq where q/o starts. Use global index here can
900906 reduce memory access count in kernel.
901907 [2.3] q_end: (#work), The global index in seq where q/o ends (not included).
902- [2.4] kv_start: (#work), The global index in seq where k/v starts.
903- [2.5] kv_end: (#work), The global index in seq where k/v ends (not included). Note that
908+ [2.4] kv_start: (#work), The global index in page where k/v starts.
909+ [2.5] kv_end: (#work), The global index in page where k/v ends (not included). Note that
904910 this value indicates the end of last qo sequence if there are
905911 multiple qo sequences included in the current work and causal mask
906- is enabled.
912+ is enabled when page_size is 1 .
907913 [2.6] kv_offset: (#work), Remaining length in seq from kv_end to the end of current batch.
908914 [2.7] pad (#work, 1), Pad to 8 DWs.
909915 [3] reduce_indptr: (sum(qo_seqlen_blk_count) + 1),
0 commit comments