Skip to content

Commit 0e0a378

Browse files
authored
mla ps support paged 64 and 3-buffer layout for ds3.2 (ROCm#1917)
* mla ps support paged 64 and 3-buffer layout for ds3.2
1 parent 7229d74 commit 0e0a378

14 files changed

Lines changed: 726 additions & 270 deletions

File tree

aiter/mla.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ def mla_decode_fwd(
150150
kv_indices,
151151
kv_last_page_lens,
152152
max_seqlen_q,
153+
page_size=1,
154+
nhead_kv=1,
153155
sm_scale=None, # 1.0 / (qk_head_dim**0.5)
154156
logit_cap=0.0,
155157
num_kv_splits=None, # for experts only!!!
@@ -168,7 +170,11 @@ def mla_decode_fwd(
168170
):
169171
device = q.device
170172
assert logit_cap <= 0, f"{logit_cap=} is not support yet"
171-
num_page, page_size, nhead_kv, qk_head_dim = kv_buffer.shape
173+
if kv_buffer.dtype != torch.uint8:
174+
_, _, _, qk_head_dim = kv_buffer.shape
175+
else:
176+
_, _, qk_head_dim = q.shape
177+
172178
if sm_scale is None:
173179
sm_scale = 1.0 / (qk_head_dim**0.5)
174180

@@ -227,6 +233,8 @@ def mla_decode_fwd(
227233
None,
228234
None,
229235
max_seqlen_q,
236+
page_size,
237+
nhead_kv,
230238
sm_scale,
231239
logits,
232240
attn_lse,
@@ -319,6 +327,8 @@ def mla_decode_fwd(
319327
work_indptr,
320328
work_info_set,
321329
max_seqlen_q,
330+
page_size,
331+
nhead_kv,
322332
sm_scale,
323333
logits,
324334
attn_lse,

aiter/ops/attention.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,8 @@ def mla_decode_stage1_asm_fwd(
566566
work_indptr: Optional[torch.Tensor],
567567
work_info_set: Optional[torch.Tensor],
568568
max_seqlen_q: int,
569+
page_size: int,
570+
nhead_kv: int,
569571
softmax_scale: float,
570572
# [batch_size, num_kv_splits, num_heads, v_head_dim]
571573
splitData: torch.Tensor,
@@ -854,6 +856,7 @@ def get_mla_metadata_info_v1(
854856
def get_mla_metadata_v1(
855857
seqlens_qo_indptr: torch.Tensor,
856858
seqlens_kv_indptr: torch.Tensor,
859+
kv_last_page_lens: torch.Tensor,
857860
num_heads_per_head_k: int,
858861
num_heads_k: int,
859862
is_causal: bool,
@@ -863,6 +866,7 @@ def get_mla_metadata_v1(
863866
reduce_indptr: torch.Tensor,
864867
reduce_final_map: torch.Tensor,
865868
reduce_partial_map: torch.Tensor,
869+
page_size: int = 1,
866870
kv_granularity: int = 16,
867871
max_seqlen_qo: int = -1,
868872
uni_seqlen_qo: int = -1,
@@ -876,12 +880,14 @@ def get_mla_metadata_v1(
876880
"""
877881
Inputs:
878882
cumulated seqlens of q/o: (batch_size + 1), dtype torch.int32.
879-
cumulated seqlens of k/v: (batch_size + 1), dtype torch.int32.
883+
cumulated page indices of k/v: (batch_size + 1), dtype torch.int32.
884+
Length of last page of k/v: (batch_size), dtype torch.int32.
880885
num_heads_per_head_k: Equals to num_heads_q // num_heads_k.
881886
num_heads_k: num_heads_k.
882887
is_causal: Whether causal mask is enabled.
883888
Options: Detailed settings for spliting. All of them are optional.
884-
kv_granularity: default=16. The granularity on kv sequence length when cutting batch.
889+
page_size: default=1. The size of a page.
890+
kv_granularity: default=16. The granularity on kv page nums when cutting batch.
885891
max_seqlen_qo: default=-1. Used to check lds usage and save time. value less than 1 means unknown.
886892
uni_seqlen_qo: default=-1. Sequence length of qo is uniform across batches. value less than 1 means the
887893
length is not fixed.
@@ -899,11 +905,11 @@ def get_mla_metadata_v1(
899905
[2.2] q_start: (#work), The global index in seq where q/o starts. Use global index here can
900906
reduce memory access count in kernel.
901907
[2.3] q_end: (#work), The global index in seq where q/o ends (not included).
902-
[2.4] kv_start: (#work), The global index in seq where k/v starts.
903-
[2.5] kv_end: (#work), The global index in seq where k/v ends (not included). Note that
908+
[2.4] kv_start: (#work), The global index in page where k/v starts.
909+
[2.5] kv_end: (#work), The global index in page where k/v ends (not included). Note that
904910
this value indicates the end of last qo sequence if there are
905911
multiple qo sequences included in the current work and causal mask
906-
is enabled.
912+
is enabled when page_size is 1.
907913
[2.6] kv_offset: (#work), Remaining length in seq from kv_end to the end of current batch.
908914
[2.7] pad (#work, 1), Pad to 8 DWs.
909915
[3] reduce_indptr: (sum(qo_seqlen_blk_count) + 1),

csrc/include/attention_asm_mla.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ void mla_decode_stage1_asm_fwd(
1515
std::optional<torch::Tensor>& work_indptr, // metadata
1616
std::optional<torch::Tensor>& work_info_set, // [batch_size+1]
1717
int max_seqlen_q,
18+
int page_size,
19+
int nhead_kv,
1820
float softmax_scale,
1921
// following are output
2022
torch::Tensor& splitData, //[batch_size, num_kv_splits, num_heads, v_head_dim]

csrc/include/mla.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
33

44
#pragma once
55

@@ -37,6 +37,7 @@ static_assert(kSizeMlaPartialTileInfoInDw == 2);
3737

3838
void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1]
3939
const torch::Tensor& seqlens_kv_indptr, // [batch size + 1]
40+
const torch::Tensor& kv_last_page_lens, // [batch size]
4041
const int32_t num_heads_per_head_k,
4142
const int32_t num_heads_k,
4243
const bool is_causal,
@@ -46,13 +47,14 @@ void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size
4647
torch::Tensor& reduce_indptr,
4748
torch::Tensor& reduce_final_map,
4849
torch::Tensor& reduce_partial_map,
50+
const int32_t page_size,
4951
const int32_t kv_granularity,
5052
const int32_t max_seqlen_qo,
5153
const int32_t uni_seqlen_qo,
5254
const bool fast_mode,
5355
const int32_t topk,
5456
const int32_t max_split_per_batch,
55-
const bool intra_batch_mode,
57+
const bool intra_batch_mode,
5658
const std::optional<at::ScalarType> dtype_q,
5759
const std::optional<at::ScalarType> dtype_kv);
5860

csrc/include/rocm_ops.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ namespace py = pybind11;
5757
py::arg("work_indptr"), \
5858
py::arg("work_info_set"), \
5959
py::arg("max_seqlen_q"), \
60+
py::arg("page_size"), \
61+
py::arg("nhead_kv"), \
6062
py::arg("softmax_scale"), \
6163
py::arg("splitData"), \
6264
py::arg("splitLse"), \
@@ -1654,6 +1656,7 @@ namespace py = pybind11;
16541656
"get_mla_metadata_v1", \
16551657
py::arg("seqlens_qo_indptr"), \
16561658
py::arg("seqlens_kv_indptr"), \
1659+
py::arg("kv_last_page_lens"), \
16571660
py::arg("num_heads_per_head_k"), \
16581661
py::arg("num_heads_k"), \
16591662
py::arg("is_causal"), \
@@ -1663,6 +1666,7 @@ namespace py = pybind11;
16631666
py::arg("reduce_indptr"), \
16641667
py::arg("reduce_final_map"), \
16651668
py::arg("reduce_partial_map"), \
1669+
py::arg("page_size") = 1, \
16661670
py::arg("kv_granularity") = 16, \
16671671
py::arg("max_seqlen_qo") = -1, \
16681672
py::arg("uni_seqlen_qo") = -1, \

csrc/kernels/mla/metadata.cu

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
33

44
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
55
#include "metadata/v1_0_device.cuh"
@@ -40,6 +40,7 @@
4040
void get_mla_metadata_v1(
4141
const torch::Tensor& seqlens_qo_indptr, // [batch size + 1]
4242
const torch::Tensor& seqlens_kv_indptr, // [batch size + 1]
43+
const torch::Tensor& kv_last_page_lens, // [batch size]
4344
const int32_t num_heads_per_head_k,
4445
const int32_t num_heads_k,
4546
const bool is_causal,
@@ -49,6 +50,7 @@ void get_mla_metadata_v1(
4950
torch::Tensor& reduce_indptr,
5051
torch::Tensor& reduce_final_map,
5152
torch::Tensor& reduce_partial_map,
53+
const int32_t page_size,
5254
const int32_t kv_granularity,
5355
const int32_t max_seqlen_qo,
5456
const int32_t uni_seqlen_qo,
@@ -63,6 +65,8 @@ void get_mla_metadata_v1(
6365

6466
TORCH_CHECK((kv_granularity & (kv_granularity - 1)) == 0,
6567
__func__, ": kv_granularity Must be power of 2!");
68+
TORCH_CHECK((page_size & (page_size - 1)) == 0,
69+
__func__, ": page_size Must be power of 2!");
6670
TORCH_CHECK(seqlens_qo_indptr.stride(0) == 1,
6771
__func__, ": seqlens_qo_indptr should be continuous!");
6872
TORCH_CHECK(seqlens_qo_indptr.scalar_type() == at::ScalarType::Int,
@@ -71,6 +75,10 @@ void get_mla_metadata_v1(
7175
__func__, ": seqlens_kv_indptr should be continuous!");
7276
TORCH_CHECK(seqlens_kv_indptr.scalar_type() == at::ScalarType::Int,
7377
__func__, ": seqlens_kv_indptr's element type should be int!");
78+
TORCH_CHECK(kv_last_page_lens.stride(0) == 1,
79+
__func__, ": kv_last_page_lens should be continuous!");
80+
TORCH_CHECK(kv_last_page_lens.scalar_type() == at::ScalarType::Int,
81+
__func__, ": kv_last_page_lens's element type should be int!");
7482

7583
at::ScalarType q_dtype = dtype_q.has_value() ? dtype_q.value() : at::ScalarType::BFloat16;
7684
at::ScalarType kv_dtype = dtype_kv.has_value() ? dtype_kv.value() : at::ScalarType::BFloat16;
@@ -80,9 +88,11 @@ void get_mla_metadata_v1(
8088
get_mla_metadata_v1_2_device(
8189
seqlens_qo_indptr,
8290
seqlens_kv_indptr,
91+
kv_last_page_lens,
8392
num_heads_per_head_k,
8493
num_heads_k,
8594
is_causal,
95+
page_size,
8696
kv_granularity,
8797
max_seqlen_qo,
8898
uni_seqlen_qo,

0 commit comments

Comments
 (0)