Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions backends/webgpu/runtime/ops/mul/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {

uint32_t wg_size =
utils::clamp_workgroup_size(device, kBinaryMulWorkgroupSizeX);
uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, out_meta.numel, wg_size, "mul");
utils::WgCount workgroup_count =
utils::compute_2d_workgroup_count(device, out_meta.numel, wg_size, "mul");

WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
Expand Down Expand Up @@ -165,8 +165,8 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

const size_t dispatch_idx =
graph.add_dispatch({pipeline, bind_group, workgroup_count});
const size_t dispatch_idx = graph.add_dispatch(
{pipeline, bind_group, workgroup_count.x, "mul", workgroup_count.y});

// Dynamic shapes: rebuild all 3 broadcast TensorMeta UBOs + dispatch.
WGPUBuffer o_buf = out_meta_buf, a_buf = in1_meta_buf, b_buf = in2_meta_buf;
Expand Down Expand Up @@ -199,9 +199,10 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
wgpuQueueWriteBuffer(g.queue(), o_buf, 0, &om, sizeof(om));
wgpuQueueWriteBuffer(g.queue(), a_buf, 0, &am, sizeof(am));
wgpuQueueWriteBuffer(g.queue(), b_buf, 0, &bm, sizeof(bm));
g.dispatch_at(dispatch_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
g.device(), om.numel, wg_size, "mul(resize)");
const utils::WgCount wgc = utils::compute_2d_workgroup_count(
g.device(), om.numel, wg_size, "mul(resize)");
g.dispatch_at(dispatch_idx).workgroup_count_x = wgc.x;
g.dispatch_at(dispatch_idx).workgroup_count_y = wgc.y;
};
graph.add_tensor_resize_hook(in1_id, mul_resize);
graph.add_tensor_resize_hook(in2_id, mul_resize);
Expand Down
7 changes: 5 additions & 2 deletions backends/webgpu/runtime/ops/mul/binary_mul.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ struct TensorMeta {
override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
// 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel).
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
if (idx >= out_meta.numel) {
return;
}
Expand Down
9 changes: 6 additions & 3 deletions backends/webgpu/runtime/ops/mul/binary_mul_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace executorch::backends::webgpu {

// @generated from binary_mul.wgsl - DO NOT EDIT.
// wgsl-sha256: e7f77426cbaf48e6085e0d882522c027302ec97ef017b86a2275eed9820f7891
// wgsl-sha256: cca69c3428f37f293942637e23f664225dec81a56f184bcb63185b6629dd155e
inline constexpr const char* kBinaryMulWGSL = R"(
@group(0) @binding(0) var<storage, read> input1: array<f32>;
@group(0) @binding(1) var<storage, read> input2: array<f32>;
Expand All @@ -32,8 +32,11 @@ struct TensorMeta {
override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
// 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel).
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
if (idx >= out_meta.numel) {
return;
}
Expand Down
5 changes: 3 additions & 2 deletions backends/webgpu/runtime/ops/permute/Permute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void permute_impl(WebGPUGraph& graph, const std::vector<int>& args) {

uint32_t wg_size =
utils::clamp_workgroup_size(device, kPermuteWorkgroupSizeX);
uint32_t workgroup_count = utils::compute_1d_workgroup_count(
utils::WgCount workgroup_count = utils::compute_2d_workgroup_count(
device, out_meta.numel, wg_size, "permute");

WGPUConstantEntry wg_size_constant = {};
Expand Down Expand Up @@ -176,7 +176,8 @@ void permute_impl(WebGPUGraph& graph, const std::vector<int>& args) {
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});
graph.add_dispatch(
{pipeline, bind_group, workgroup_count.x, "permute", workgroup_count.y});

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
Expand Down
7 changes: 5 additions & 2 deletions backends/webgpu/runtime/ops/permute/permute.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ struct Params {
override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_bufi = gid.x;
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
// 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel).
let out_bufi = gid.x + gid.y * (num_workgroups.x * wg_size);
if (out_bufi >= out_meta.numel) {
return;
}
Expand Down
9 changes: 6 additions & 3 deletions backends/webgpu/runtime/ops/permute/permute_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace executorch::backends::webgpu {

// @generated from permute.wgsl - DO NOT EDIT.
// wgsl-sha256: d34f59730cda7317589b6ed5691a1ccab8666b9c94e17ac2cb3658b036300197
// wgsl-sha256: 05884aeb14426c979ea037b066266d8cab11f4fed76ee21ee8778e7fc13ad84e
inline constexpr const char* kPermuteWGSL = R"(
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
Expand All @@ -35,8 +35,11 @@ struct Params {
override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_bufi = gid.x;
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
// 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel).
let out_bufi = gid.x + gid.y * (num_workgroups.x * wg_size);
if (out_bufi >= out_meta.numel) {
return;
}
Expand Down
Loading