Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions backends/webgpu/runtime/ops/add/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,48 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) {
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});
const size_t dispatch_idx = graph.num_dispatches() - 1;

// Dynamic shapes: recompute numel/dispatch; out follows the larger operand.
WGPUBuffer params_buf = uniform_buffer;
auto add_resize = [in1_id,
in2_id,
out_id,
alpha,
wg_size,
dispatch_idx,
params_buf](WebGPUGraph& g) {
const auto& d1 = g.cur_dims(in1_id);
const auto& d2 = g.cur_dims(in2_id);
const uint64_t n1 = utils::numel_of(d1);
const uint64_t n2 = utils::numel_of(d2);
const uint64_t numel = n2 > n1 ? n2 : n1;
const uint64_t n_min = n2 > n1 ? n1 : n2;
// The flat add follows the larger operand and broadcasts the smaller; valid
// only when the smaller tiles evenly into it (rejects e.g. [4,1] vs [1,3],
// whose true [4,3] result this flat kernel cannot produce).
if (n_min == 0u || numel % n_min != 0u) {
throw std::runtime_error(
"add(resize): operands are not broadcast-compatible by numel");
}
g.set_cur_dims(out_id, n2 > n1 ? d2 : d1);
AddParams p = {};
p.num_elements = static_cast<uint32_t>(numel);
p.alpha = alpha;
wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
g.dispatch_at(dispatch_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
g.device(), static_cast<uint32_t>(numel), wg_size, "add(resize)");
};
graph.add_tensor_resize_hook(in1_id, add_resize);
graph.add_tensor_resize_hook(in2_id, add_resize);

// Release intermediate objects (pipeline and bind_group are kept by dispatch)
wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// Drop our ref; the bind group keeps the uniform buffer alive until release.
wgpuBufferRelease(uniform_buffer);
// Graph owns it so a resize hook can rewrite it; freed in the dtor.
graph.own_uniform_buffer(uniform_buffer);
}

} // namespace
Expand Down
50 changes: 45 additions & 5 deletions backends/webgpu/runtime/ops/mul/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <webgpu/webgpu.h>

#include <algorithm>
#include <stdexcept>
#include <vector>

Expand Down Expand Up @@ -164,15 +165,54 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});
const size_t dispatch_idx =
graph.add_dispatch({pipeline, bind_group, workgroup_count});

// Dynamic shapes: rebuild all 3 broadcast TensorMeta UBOs + dispatch.
WGPUBuffer o_buf = out_meta_buf, a_buf = in1_meta_buf, b_buf = in2_meta_buf;
auto mul_resize =
[in1_id, in2_id, out_id, wg_size, dispatch_idx, o_buf, a_buf, b_buf](
WebGPUGraph& g) {
const auto& a = g.cur_dims(in1_id);
const auto& b = g.cur_dims(in2_id);
const size_t r = std::max(a.size(), b.size());
std::vector<int64_t> out_d(r, 1);
for (size_t i = 0; i < r; i++) {
const int64_t av = (i + a.size() < r) ? 1 : a[i - (r - a.size())];
const int64_t bv = (i + b.size() < r) ? 1 : b[i - (r - b.size())];
if (av != bv && av != 1 && bv != 1) {
throw std::runtime_error(
"mul(resize): operands are not broadcast-compatible");
}
out_d[i] = av > bv ? av : bv;
}
g.set_cur_dims(out_id, out_d);
const uint32_t out_ndim = static_cast<uint32_t>(r);
WebGPUTensor ta, tb, to;
ta.dims = a;
tb.dims = b;
to.dims = out_d;
TensorMeta om, am, bm;
fill_tensor_meta_broadcast(to, out_ndim, &om);
fill_tensor_meta_broadcast(ta, out_ndim, &am);
fill_tensor_meta_broadcast(tb, out_ndim, &bm);
wgpuQueueWriteBuffer(g.queue(), o_buf, 0, &om, sizeof(om));
wgpuQueueWriteBuffer(g.queue(), a_buf, 0, &am, sizeof(am));
wgpuQueueWriteBuffer(g.queue(), b_buf, 0, &bm, sizeof(bm));
g.dispatch_at(dispatch_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
g.device(), om.numel, wg_size, "mul(resize)");
};
graph.add_tensor_resize_hook(in1_id, mul_resize);
graph.add_tensor_resize_hook(in2_id, mul_resize);

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// Drop our refs; the bind group keeps the uniforms alive until release.
wgpuBufferRelease(out_meta_buf);
wgpuBufferRelease(in1_meta_buf);
wgpuBufferRelease(in2_meta_buf);
// Graph owns them so a resize hook can rewrite them; freed in the dtor.
graph.own_uniform_buffer(out_meta_buf);
graph.own_uniform_buffer(in1_meta_buf);
graph.own_uniform_buffer(in2_meta_buf);
}

} // namespace
Expand Down
Loading