Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 117 additions & 73 deletions backends/webgpu/runtime/ops/sdpa/Sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ static WGPUBuffer record_update_cache_dispatch(
uint32_t kv_dst_offset,
uint64_t cache_numel,
uint32_t uc_wg,
bool dynamic_pos,
bool retain_uniform,
const char* label) {
const uint32_t wgc = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(kv_numel), uc_wg, label);
Expand All @@ -258,7 +258,7 @@ static WGPUBuffer record_update_cache_dispatch(
sizeof(uc),
wgc,
uc_wg,
dynamic_pos,
retain_uniform,
"update_cache");
return ubuf;
}
Expand Down Expand Up @@ -417,7 +417,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// Dynamic input_pos: the resize hook rewrites these per step.
WGPUBuffer uc_k_buf = nullptr, uc_v_buf = nullptr, qk_buf = nullptr,
softmax_buf = nullptr, av_buf = nullptr;
size_t qk_idx = 0;
size_t qk_idx = 0, uc_k_idx = 0, uc_v_idx = 0, softmax_idx = 0, av_idx = 0;

const WGPUDevice device = graph.device();
const uint32_t uc_wg =
Expand All @@ -442,7 +442,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
kv_dst_offset,
numel(k_cache),
uc_wg,
dynamic_pos,
true,
"update_cache(K)");
uc_v_buf = record_update_cache_dispatch(
graph,
Expand All @@ -453,8 +453,10 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
kv_dst_offset,
numel(v_cache),
uc_wg,
dynamic_pos,
true,
"update_cache(V)");
uc_k_idx = graph.num_dispatches() - 2;
uc_v_idx = graph.num_dispatches() - 1;

// FlashDecoding decode (S==1, static pos). Shapes FD can't handle (head dim
// > kSdpaFdMaxHeadDim) fall through to the materialized path below.
Expand Down Expand Up @@ -494,7 +496,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
sizeof(p),
wgc,
qk_wg,
dynamic_pos,
true,
"sdpa_compute_attn_weights");
qk_buf = ubuf;
qk_idx = graph.num_dispatches() - 1;
Expand All @@ -518,9 +520,10 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
sizeof(p),
wgc,
0,
dynamic_pos,
true,
"sdpa_softmax");
softmax_buf = ubuf;
softmax_idx = graph.num_dispatches() - 1;
}

// --- Dispatch 5: AV -> out. One thread per TM x TN tile.
Expand All @@ -544,77 +547,118 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
sizeof(p),
wgc,
av_wg,
dynamic_pos,
true,
"sdpa_compute_out");
av_buf = ubuf;
av_idx = graph.num_dispatches() - 1;
}

// Per-step recompute hook; mirrors Vulkan DynamicDispatchNode.
// Per-step recompute: live S (q resize) or input_pos (SymInt); inert if
// static.
const int64_t pos_const = input_pos;
// sdpa_resize: per-step recompute callback for the materialized SDPA path.
// Given the graph, it
//   1. reads the live sequence length `s` from q's current dims and the live
//      position `pos` (from the SymInt when dynamic_pos, else pos_const),
//   2. validates them against the build-time limits (S, Cmax, uint32 range),
//   3. rewrites the five uniform buffers (update_cache K/V, QK, softmax, AV)
//      and the matching dispatch workgroup counts,
//   4. sets the output's current dims to q's current dims.
// Throws std::runtime_error when any validation fails. All captures are by
// value, so the closure holds copies of the ids/indices/buffer handles.
// NOTE(review): the captured WGPUBuffer handles and dispatch indices are
// presumably the ones recorded earlier in the enclosing function; that code
// is only partially visible here -- confirm they are non-null on every path
// that registers this callback.
auto sdpa_resize = [q_id,
qn,
S,
out_id,
dynamic_pos,
input_pos_id,
pos_const,
Hq,
Hkv,
D,
Cmax,
g,
scale,
qk_idx,
uc_k_idx,
uc_v_idx,
softmax_idx,
av_idx,
uc_wg,
qk_wg,
av_wg,
uc_k_buf,
uc_v_buf,
qk_buf,
softmax_buf,
av_buf](WebGPUGraph& gr) {
// Live sequence length: the dim at index qn - 3 of q's current dims.
const int64_t s = gr.cur_dims(q_id)[qn - 3];
// Live position: read from the SymInt only when it is dynamic; otherwise use
// the value captured at build time.
const int64_t pos = dynamic_pos
? static_cast<int64_t>(gr.read_symint(input_pos_id))
: pos_const;
if (s <= 0 || pos < 0) {
throw std::runtime_error("WebGPU sdpa: invalid live S or input_pos");
}
// Scratch (attn_weights/softmax) is sized at build for S=max; a larger live
// S would overrun it. Make that invariant load-bearing.
if (s > S) {
throw std::runtime_error(
"WebGPU sdpa: live S exceeds the build-time max (scratch capacity)");
}
// Context length after this step's rows are appended; must fit in the cache.
const int64_t ctx = s + pos;
if (ctx <= 0 || ctx > Cmax) {
throw std::runtime_error(
"WebGPU sdpa: context_len exceeds cache capacity");
}
// Destination offset pos * Hkv * D, computed in 64 bits then narrowed.
// NOTE(review): the narrowing to uint32 is not range-checked here; presumably
// the cache size is bounded elsewhere -- confirm.
const uint32_t kv_off = static_cast<uint32_t>(
static_cast<uint64_t>(pos) * static_cast<uint64_t>(Hkv) *
static_cast<uint64_t>(D));
// Hq * s * ctx must fit in uint32. This also bounds the QK tile count cast
// to uint32 below, since the tile count cannot exceed this product.
const uint64_t aw_floats = static_cast<uint64_t>(Hq) *
static_cast<uint64_t>(s) * static_cast<uint64_t>(ctx);
if (aw_floats > UINT32_MAX) {
throw std::runtime_error("WebGPU sdpa: Hq*S*context_len exceeds uint32");
}
// s * Hkv * D is cast to uint32 for the update_cache dispatch below, so it
// is range-checked first.
const uint64_t kv_numel = static_cast<uint64_t>(s) *
static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);
if (kv_numel > UINT32_MAX) {
throw std::runtime_error("WebGPU sdpa: S*Hkv*D exceeds uint32");
}
// Cmax * Hkv * D; the same value is used for both the K and V uniforms.
const uint64_t k_cache_numel = static_cast<uint64_t>(Cmax) *
static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);

// update_cache K/V: dispatch (kv_numel) + dst offset scale with live S/pos.
// One params struct is written to both uniform buffers at offset 0.
UpdateCacheParams uc =
make_update_cache_params(kv_numel, kv_off, k_cache_numel);
wgpuQueueWriteBuffer(gr.queue(), uc_k_buf, 0, &uc, sizeof(uc));
wgpuQueueWriteBuffer(gr.queue(), uc_v_buf, 0, &uc, sizeof(uc));
const uint32_t uc_wgc = utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(kv_numel), uc_wg, "uc(resize)");
gr.dispatch_at(uc_k_idx).workgroup_count_x = uc_wgc;
gr.dispatch_at(uc_v_idx).workgroup_count_x = uc_wgc;

// QK: one thread per TM x TN tile; grid = Hq*ceil(S/TM)*ceil(ctx/TN).
AttnWeightsParams qp =
make_attn_weights_params(s, Hq, Hkv, D, ctx, pos, g, scale);
wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
const int64_t qk_tiles =
Hq * utils::div_up(s, kSdpaTileM) * utils::div_up(ctx, kSdpaTileN);
gr.dispatch_at(qk_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(qk_tiles), qk_wg, "QK(resize)");

// softmax: one workgroup per (h,s) row.
// The count is computed from Hq * s with a per-workgroup size argument of 1.
SoftmaxParams sp = make_softmax_params(Hq, s, ctx);
wgpuQueueWriteBuffer(gr.queue(), softmax_buf, 0, &sp, sizeof(sp));
gr.dispatch_at(softmax_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(Hq * s), 1, "softmax(resize)");

// AV: one thread per TM x TN tile; grid = Hq*ceil(S/TM)*ceil(D/TN).
// NOTE(review): av_tiles is narrowed to uint32 without its own range check;
// presumably small because D is bounded -- confirm.
ComputeOutParams op = make_compute_out_params(s, Hq, Hkv, D, ctx, g);
wgpuQueueWriteBuffer(gr.queue(), av_buf, 0, &op, sizeof(op));
const int64_t av_tiles =
Hq * utils::div_up(s, kSdpaTileM) * utils::div_up(D, kSdpaTileN);
gr.dispatch_at(av_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(av_tiles), av_wg, "AV(resize)");

// Output attn has the same shape as q: [.., S, Hq, D].
gr.set_cur_dims(out_id, gr.cur_dims(q_id));
};
// q and input_pos share one idempotent recompute; a double-fire is harmless.
graph.add_tensor_resize_hook(q_id, sdpa_resize);
if (dynamic_pos) {
// Registers a callback on input_pos_id that re-reads the position SymInt and
// rewrites the uniform buffers for update_cache K/V, QK, softmax and AV.
// This callback uses the captured build-time S throughout (it does not read
// q's current dims) and only updates the QK dispatch's workgroup count; the
// other dispatches keep whatever counts they already have.
// Throws std::runtime_error if the position is negative, the context length
// is outside (0, Cmax], or Hq*S*ctx does not fit in uint32.
graph.add_resize_hook(
input_pos_id,
[input_pos_id,
S,
Hq,
Hkv,
D,
Cmax,
g,
scale,
qk_idx,
qk_wg,
uc_k_buf,
uc_v_buf,
qk_buf,
softmax_buf,
av_buf](WebGPUGraph& gr) {
// Current position, as stored in the SymInt.
const int32_t pos = gr.read_symint(input_pos_id);
if (pos < 0) {
throw std::runtime_error(
"WebGPU sdpa: input_pos must be non-negative");
}
// Context length = build-time S plus the current position.
const int64_t ctx = S + pos;
if (ctx <= 0 || ctx > Cmax) {
throw std::runtime_error(
"WebGPU sdpa: context_len exceeds cache capacity");
}
// Destination offset pos * Hkv * D, computed in 64 bits then narrowed to
// uint32. NOTE(review): the narrowing is not range-checked here -- confirm
// it is bounded elsewhere.
const uint32_t kv_off = static_cast<uint32_t>(
static_cast<uint64_t>(pos) * static_cast<uint64_t>(Hkv) *
static_cast<uint64_t>(D));
// Hq * S * ctx must fit in uint32; this also bounds the QK tile count
// that is cast to uint32 below.
const uint64_t aw_floats = static_cast<uint64_t>(Hq) *
static_cast<uint64_t>(S) * static_cast<uint64_t>(ctx);
if (aw_floats > UINT32_MAX) {
throw std::runtime_error(
"WebGPU sdpa: Hq*S*context_len exceeds uint32 max");
}
// Element counts passed to the update_cache params (no uint32 check on
// kv_numel in this callback).
const uint64_t kv_numel = static_cast<uint64_t>(S) *
static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);
const uint64_t k_cache_numel = static_cast<uint64_t>(Cmax) *
static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);

// The same params struct is written to both the K and V uniform buffers.
UpdateCacheParams uc =
make_update_cache_params(kv_numel, kv_off, k_cache_numel);
wgpuQueueWriteBuffer(gr.queue(), uc_k_buf, 0, &uc, sizeof(uc));
wgpuQueueWriteBuffer(gr.queue(), uc_v_buf, 0, &uc, sizeof(uc));

// QK: refresh the uniform and the workgroup count, which depends on ctx.
AttnWeightsParams qp =
make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale);
wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) *
utils::div_up(ctx, kSdpaTileN);
const uint32_t qk_wgc = utils::compute_1d_workgroup_count(
gr.device(),
static_cast<uint32_t>(qk_tiles),
qk_wg,
"QK(resize)");
gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc;

// softmax: only the uniform is rewritten here.
SoftmaxParams sp = make_softmax_params(Hq, S, ctx);
wgpuQueueWriteBuffer(gr.queue(), softmax_buf, 0, &sp, sizeof(sp));

// AV: only the uniform is rewritten here.
ComputeOutParams op = make_compute_out_params(S, Hq, Hkv, D, ctx, g);
wgpuQueueWriteBuffer(gr.queue(), av_buf, 0, &op, sizeof(op));
});
graph.add_resize_hook(input_pos_id, sdpa_resize);
}
}

Expand Down
Loading