Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -816,15 +816,15 @@ void WebGPUGraph::execute() {
wgpuComputePassEncoderSetBindGroup(
pass, 0, dispatch.bind_group, 0, nullptr);
wgpuComputePassEncoderDispatchWorkgroups(
pass, dispatch.workgroup_count_x, 1, 1);
pass, dispatch.workgroup_count_x, dispatch.workgroup_count_y, 1);
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
#ifdef WGPU_BACKEND_ENABLE_PROFILING
if (qp) {
qp->record(
static_cast<uint32_t>(i),
dispatch.kernel_name,
{dispatch.workgroup_count_x, 1, 1},
{dispatch.workgroup_count_x, dispatch.workgroup_count_y, 1},
{1, 1, 1});
}
#endif // WGPU_BACKEND_ENABLE_PROFILING
Expand Down Expand Up @@ -896,7 +896,10 @@ void WebGPUGraph::execute() {
wgpuComputePassEncoderSetBindGroup(
pass, 0, dispatches_[i].bind_group, 0, nullptr);
wgpuComputePassEncoderDispatchWorkgroups(
pass, dispatches_[i].workgroup_count_x, 1, 1);
pass,
dispatches_[i].workgroup_count_x,
dispatches_[i].workgroup_count_y,
1);
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
}
Expand Down
1 change: 1 addition & 0 deletions backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct WebGPUDispatch {
WGPUBindGroup bind_group = nullptr;
uint32_t workgroup_count_x = 1;
std::string kernel_name; // bench label
uint32_t workgroup_count_y = 1; // 2D fold (>65535); 1 = unchanged 1D path
// DMA copy command; default Compute keeps existing positional inits valid.
enum class Kind { Compute, Copy };
Kind kind = Kind::Compute;
Expand Down
55 changes: 48 additions & 7 deletions backends/webgpu/runtime/WebGPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,27 +47,68 @@ inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) {
return desired;
}

// Workgroup counts for the x and y dimensions of a single compute dispatch.
// Produced by fold_workgroup_count_2d / compute_2d_workgroup_count; a y of 1
// corresponds to the plain 1D dispatch path.
struct WgCount {
// Workgroups along the x dimension.
uint32_t x;
// Workgroups along the y dimension (1 when no fold was needed).
uint32_t y;
};

// Maximum number of workgroups the device accepts per dispatch dimension.
// Falls back to 65535 (the WebGPU spec-default limit) when the limits query
// does not succeed or reports 0, so the result is always non-zero.
inline uint32_t queried_max_workgroups(WGPUDevice device) {
  constexpr uint32_t kSpecDefaultMax = 65535u;

  WGPULimits device_limits = {};
  if (wgpuDeviceGetLimits(device, &device_limits) != WGPUStatus_Success) {
    return kSpecDefaultMax;
  }

  const uint32_t reported = device_limits.maxComputeWorkgroupsPerDimension;
  if (reported > 0) {
    return reported;
  }
  return kSpecDefaultMax;
}

// Pure 2D fold of a 1D workgroup count (device-free, unit-testable).
//
// Returns {count, 1} when count <= max_count. Otherwise returns
// {max_count, ceil(count / max_count)}, so x * y >= count; the shader rebuilds
// the linear index from @builtin(num_workgroups) and must bounds-check it to
// discard the surplus invocations.
//
// Throws std::runtime_error when max_count is 0 for a non-zero count, or when
// the fold would need y > max_count (i.e. a 3rd dispatch dimension, which is
// out of scope).
inline WgCount fold_workgroup_count_2d(
    uint32_t count,
    uint32_t max_count,
    const char* op_name) {
  if (count <= max_count) {
    return {count, 1u};
  }
  // count > max_count == 0 cannot be folded at all; fail loudly instead of
  // dividing by zero below.
  if (max_count == 0u) {
    throw std::runtime_error(
        std::string("WebGPU ") + op_name +
        ": per-dimension workgroup limit is zero");
  }
  // Ceiling division without forming count + max_count - 1: that sum wraps in
  // uint32_t for counts near UINT32_MAX and would produce a too-small y
  // (even 0), silently under-dispatching instead of reporting the overflow.
  const uint32_t y = count / max_count + (count % max_count != 0u ? 1u : 0u);
  if (y > max_count) {
    throw std::runtime_error(
        std::string("WebGPU ") + op_name +
        ": workgroup count needs a 3rd dispatch dimension (unsupported)");
  }
  return {max_count, y};
}

// 1D dispatch count: div_up(num_threads, workgroup_size), validated against
// the device's per-dimension workgroup limit (see queried_max_workgroups).
//
// Throws std::runtime_error, prefixed with op_name, when the count exceeds
// that limit; callers that need larger workloads should use the 2D variant.
inline uint32_t compute_1d_workgroup_count(
    WGPUDevice device,
    uint32_t num_threads,
    uint32_t workgroup_size,
    const char* op_name) {
  const uint32_t count = div_up(num_threads, workgroup_size);
  // Single source of truth for the limit (and its 65535 fallback) so the 1D
  // and 2D paths cannot disagree about the device capacity.
  if (count > queried_max_workgroups(device)) {
    throw std::runtime_error(
        std::string("WebGPU ") + op_name +
        ": workgroup count exceeds the 1D dispatch limit");
  }
  return count;
}

// 2D dispatch count: computes ceil(num_threads / workgroup_size) and folds it
// across x/y when it exceeds the device's per-dimension limit (lifts the 1D
// cap, e.g. for SDPA prefill). When the count fits, the result is {count, 1},
// the same shape of dispatch compute_1d_workgroup_count would produce.
//
// Throws std::runtime_error when workgroup_size is 0, or (via
// fold_workgroup_count_2d) when even a 2D fold cannot hold the count.
inline WgCount compute_2d_workgroup_count(
    WGPUDevice device,
    uint32_t num_threads,
    uint32_t workgroup_size,
    const char* op_name) {
  if (workgroup_size == 0u) {
    throw std::runtime_error(
        std::string("WebGPU ") + op_name + ": workgroup size is zero");
  }
  // Ceiling division without forming num_threads + workgroup_size - 1, which
  // wraps in uint32_t for thread counts near UINT32_MAX and would yield a
  // too-small (possibly zero) workgroup count.
  const uint32_t count = num_threads / workgroup_size +
      (num_threads % workgroup_size != 0u ? 1u : 0u);
  return fold_workgroup_count_2d(
      count, queried_max_workgroups(device), op_name);
}

// Create a uniform buffer mapped-at-creation, copy `size` bytes in, and unmap.
inline WGPUBuffer
make_uniform(WGPUDevice device, const void* data, size_t size) {
Expand Down
60 changes: 29 additions & 31 deletions backends/webgpu/runtime/ops/add/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) {

uint32_t wg_size =
utils::clamp_workgroup_size(device, kBinaryAddWorkgroupSizeX);
uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, num_elements, wg_size, "add");
utils::WgCount workgroup_count =
utils::compute_2d_workgroup_count(device, num_elements, wg_size, "add");

WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
Expand Down Expand Up @@ -158,40 +158,38 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) {
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});
graph.add_dispatch(
{pipeline, bind_group, workgroup_count.x, "", workgroup_count.y});
const size_t dispatch_idx = graph.num_dispatches() - 1;

// Dynamic shapes: recompute numel/dispatch; out follows the larger operand.
WGPUBuffer params_buf = uniform_buffer;
auto add_resize = [in1_id,
in2_id,
out_id,
alpha,
wg_size,
dispatch_idx,
params_buf](WebGPUGraph& g) {
const auto& d1 = g.cur_dims(in1_id);
const auto& d2 = g.cur_dims(in2_id);
const uint64_t n1 = utils::numel_of(d1);
const uint64_t n2 = utils::numel_of(d2);
const uint64_t numel = n2 > n1 ? n2 : n1;
const uint64_t n_min = n2 > n1 ? n1 : n2;
// The flat add follows the larger operand and broadcasts the smaller; valid
// only when the smaller tiles evenly into it (rejects e.g. [4,1] vs [1,3],
// whose true [4,3] result this flat kernel cannot produce).
if (n_min == 0u || numel % n_min != 0u) {
throw std::runtime_error(
"add(resize): operands are not broadcast-compatible by numel");
}
g.set_cur_dims(out_id, n2 > n1 ? d2 : d1);
AddParams p = {};
p.num_elements = static_cast<uint32_t>(numel);
p.alpha = alpha;
wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
g.dispatch_at(dispatch_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
auto add_resize =
[in1_id, in2_id, out_id, alpha, wg_size, dispatch_idx, params_buf](
WebGPUGraph& g) {
const auto& d1 = g.cur_dims(in1_id);
const auto& d2 = g.cur_dims(in2_id);
const uint64_t n1 = utils::numel_of(d1);
const uint64_t n2 = utils::numel_of(d2);
const uint64_t numel = n2 > n1 ? n2 : n1;
const uint64_t n_min = n2 > n1 ? n1 : n2;
// The flat add follows the larger operand and broadcasts the smaller;
// valid only when the smaller tiles evenly into it (rejects e.g. [4,1]
// vs [1,3], whose true [4,3] result this flat kernel cannot produce).
if (n_min == 0u || numel % n_min != 0u) {
throw std::runtime_error(
"add(resize): operands are not broadcast-compatible by numel");
}
g.set_cur_dims(out_id, n2 > n1 ? d2 : d1);
AddParams p = {};
p.num_elements = static_cast<uint32_t>(numel);
p.alpha = alpha;
wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
const utils::WgCount wgc = utils::compute_2d_workgroup_count(
g.device(), static_cast<uint32_t>(numel), wg_size, "add(resize)");
};
g.dispatch_at(dispatch_idx).workgroup_count_x = wgc.x;
g.dispatch_at(dispatch_idx).workgroup_count_y = wgc.y;
};
graph.add_tensor_resize_hook(in1_id, add_resize);
graph.add_tensor_resize_hook(in2_id, add_resize);

Expand Down
6 changes: 4 additions & 2 deletions backends/webgpu/runtime/ops/add/binary_add.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ struct Params {
override wg_size: u32 = 256;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
if (idx >= params.num_elements) {
return;
}
Expand Down
8 changes: 5 additions & 3 deletions backends/webgpu/runtime/ops/add/binary_add_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace executorch::backends::webgpu {

// @generated from binary_add.wgsl - DO NOT EDIT.
// wgsl-sha256: c1ceec80c8d4d3d56986ad91ce0d7f9a57cd8467b8c3aa07a28da70e51d141d9
// wgsl-sha256: e66bd67465c2a0296e09668df54f87605a4c91015a615f3734cdd0f140a74477
inline constexpr const char* kBinaryAddWGSL = R"(
@group(0) @binding(0) var<storage, read> input1: array<f32>;
@group(0) @binding(1) var<storage, read> input2: array<f32>;
Expand All @@ -28,8 +28,10 @@ struct Params {
override wg_size: u32 = 256;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
if (idx >= params.num_elements) {
return;
}
Expand Down
45 changes: 29 additions & 16 deletions backends/webgpu/runtime/ops/sdpa/Sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ void build_dispatch(
WGPUBuffer uniform_buffer,
uint64_t uniform_size,
uint32_t workgroup_count_x,
uint32_t workgroup_count_y,
uint32_t wg_size,
bool retain_uniform = false,
const char* kernel_name = "") {
Expand Down Expand Up @@ -216,7 +217,12 @@ void build_dispatch(
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count_x, kernel_name});
graph.add_dispatch(
{pipeline,
bind_group,
workgroup_count_x,
kernel_name,
workgroup_count_y});

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
Expand Down Expand Up @@ -257,6 +263,7 @@ static WGPUBuffer record_update_cache_dispatch(
ubuf,
sizeof(uc),
wgc,
1,
uc_wg,
retain_uniform,
"update_cache");
Expand Down Expand Up @@ -478,7 +485,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
}
const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) *
utils::div_up(context_len, kSdpaTileN);
const uint32_t wgc = utils::compute_1d_workgroup_count(
const utils::WgCount wgc = utils::compute_2d_workgroup_count(
device, static_cast<uint32_t>(qk_tiles), qk_wg, "QK");
AttnWeightsParams p = make_attn_weights_params(
S, Hq, Hkv, D, context_len, input_pos, g, scale);
Expand All @@ -494,7 +501,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
3,
ubuf,
sizeof(p),
wgc,
wgc.x,
wgc.y,
qk_wg,
true,
"sdpa_compute_attn_weights");
Expand All @@ -505,7 +513,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// Dispatch 4: softmax, one workgroup per (h,s) row of width context_len.
{
// One workgroup per (h,s) row; wg_size 1 keeps the device dispatch check.
const uint32_t wgc = utils::compute_1d_workgroup_count(
const utils::WgCount wgc = utils::compute_2d_workgroup_count(
device, static_cast<uint32_t>(Hq * S), 1, "softmax");
SoftmaxParams p = make_softmax_params(Hq, S, context_len);
WGPUBuffer ubuf = graph.make_uniform_buffer(&p, sizeof(p));
Expand All @@ -518,7 +526,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
2,
ubuf,
sizeof(p),
wgc,
wgc.x,
wgc.y,
0,
true,
"sdpa_softmax");
Expand All @@ -530,7 +539,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
{
const int64_t av_tiles =
Hq * utils::div_up(S, kSdpaTileM) * utils::div_up(D, kSdpaTileN);
const uint32_t wgc = utils::compute_1d_workgroup_count(
const utils::WgCount wgc = utils::compute_2d_workgroup_count(
device, static_cast<uint32_t>(av_tiles), av_wg, "AV");
ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g);
WGPUBuffer ubuf = graph.make_uniform_buffer(&p, sizeof(p));
Expand All @@ -545,7 +554,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
3,
ubuf,
sizeof(p),
wgc,
wgc.x,
wgc.y,
av_wg,
true,
"sdpa_compute_out");
Expand Down Expand Up @@ -632,25 +642,28 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
const int64_t qk_tiles =
Hq * utils::div_up(s, kSdpaTileM) * utils::div_up(ctx, kSdpaTileN);
gr.dispatch_at(qk_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(qk_tiles), qk_wg, "QK(resize)");
const utils::WgCount qk_wgc = utils::compute_2d_workgroup_count(
gr.device(), static_cast<uint32_t>(qk_tiles), qk_wg, "QK(resize)");
gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc.x;
gr.dispatch_at(qk_idx).workgroup_count_y = qk_wgc.y;

// softmax: one workgroup per (h,s) row.
SoftmaxParams sp = make_softmax_params(Hq, s, ctx);
wgpuQueueWriteBuffer(gr.queue(), softmax_buf, 0, &sp, sizeof(sp));
gr.dispatch_at(softmax_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(Hq * s), 1, "softmax(resize)");
const utils::WgCount sm_wgc = utils::compute_2d_workgroup_count(
gr.device(), static_cast<uint32_t>(Hq * s), 1, "softmax(resize)");
gr.dispatch_at(softmax_idx).workgroup_count_x = sm_wgc.x;
gr.dispatch_at(softmax_idx).workgroup_count_y = sm_wgc.y;

// AV: one thread per TM x TN tile; grid = Hq*ceil(S/TM)*ceil(D/TN).
ComputeOutParams op = make_compute_out_params(s, Hq, Hkv, D, ctx, g);
wgpuQueueWriteBuffer(gr.queue(), av_buf, 0, &op, sizeof(op));
const int64_t av_tiles =
Hq * utils::div_up(s, kSdpaTileM) * utils::div_up(D, kSdpaTileN);
gr.dispatch_at(av_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(av_tiles), av_wg, "AV(resize)");
const utils::WgCount av_wgc = utils::compute_2d_workgroup_count(
gr.device(), static_cast<uint32_t>(av_tiles), av_wg, "AV(resize)");
gr.dispatch_at(av_idx).workgroup_count_x = av_wgc.x;
gr.dispatch_at(av_idx).workgroup_count_y = av_wgc.y;

// Output attn has the same shape as q: [.., S, Hq, D].
gr.set_cur_dims(out_id, gr.cur_dims(q_id));
Expand Down
12 changes: 8 additions & 4 deletions backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,21 @@ fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
}

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
let nrt = (params.S + TM - 1u) / TM;
let nct = (params.context_len + TN - 1u) / TN;
let tiles = nrt * nct;
let total = tiles * params.Hq;
if (gid.x >= total) {
// 2D dispatch fold: recover the linear tile index across x/y.
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
if (idx >= total) {
return;
}

let h = gid.x / tiles;
let rem = gid.x % tiles;
let h = idx / tiles;
let rem = idx % tiles;
let row_tile = rem / nct;
let col_tile = rem % nct;
let kvh = h / params.g;
Expand Down
Loading
Loading