diff --git a/backends/webgpu/runtime/ops/add/BinaryOp.cpp b/backends/webgpu/runtime/ops/add/BinaryOp.cpp index 578799a9c38..8c56ad6c15d 100644 --- a/backends/webgpu/runtime/ops/add/BinaryOp.cpp +++ b/backends/webgpu/runtime/ops/add/BinaryOp.cpp @@ -159,13 +159,48 @@ void add_impl(WebGPUGraph& graph, const std::vector& args) { WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); graph.add_dispatch({pipeline, bind_group, workgroup_count}); + const size_t dispatch_idx = graph.num_dispatches() - 1; + + // Dynamic shapes: recompute numel/dispatch; out follows the larger operand. + WGPUBuffer params_buf = uniform_buffer; + auto add_resize = [in1_id, + in2_id, + out_id, + alpha, + wg_size, + dispatch_idx, + params_buf](WebGPUGraph& g) { + const auto& d1 = g.cur_dims(in1_id); + const auto& d2 = g.cur_dims(in2_id); + const uint64_t n1 = utils::numel_of(d1); + const uint64_t n2 = utils::numel_of(d2); + const uint64_t numel = n2 > n1 ? n2 : n1; + const uint64_t n_min = n2 > n1 ? n1 : n2; + // The flat add follows the larger operand and broadcasts the smaller; valid + // only when the smaller tiles evenly into it (rejects e.g. [4,1] vs [1,3], + // whose true [4,3] result this flat kernel cannot produce). + if (n_min == 0u || numel % n_min != 0u) { + throw std::runtime_error( + "add(resize): operands are not broadcast-compatible by numel"); + } + g.set_cur_dims(out_id, n2 > n1 ? d2 : d1); + AddParams p = {}; + p.num_elements = static_cast(numel); + p.alpha = alpha; + wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p)); + g.dispatch_at(dispatch_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + g.device(), static_cast(numel), wg_size, "add(resize)"); + }; + graph.add_tensor_resize_hook(in1_id, add_resize); + graph.add_tensor_resize_hook(in2_id, add_resize); // Release intermediate objects (pipeline and bind_group are kept by dispatch) wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); wgpuPipelineLayoutRelease(pipeline_layout); - // Drop our ref; the bind group keeps the uniform buffer alive until release. - wgpuBufferRelease(uniform_buffer); + // Graph owns it so a resize hook can rewrite it; freed in the dtor. + graph.own_uniform_buffer(uniform_buffer); } } // namespace diff --git a/backends/webgpu/runtime/ops/mul/BinaryOp.cpp b/backends/webgpu/runtime/ops/mul/BinaryOp.cpp index 007b7b2d8da..2ccb6c0e1bf 100644 --- a/backends/webgpu/runtime/ops/mul/BinaryOp.cpp +++ b/backends/webgpu/runtime/ops/mul/BinaryOp.cpp @@ -14,6 +14,7 @@ #include +#include #include #include @@ -164,15 +165,54 @@ void mul_impl(WebGPUGraph& graph, const std::vector& args) { bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - graph.add_dispatch({pipeline, bind_group, workgroup_count}); + const size_t dispatch_idx = + graph.add_dispatch({pipeline, bind_group, workgroup_count}); + + // Dynamic shapes: rebuild all 3 broadcast TensorMeta UBOs + dispatch. + WGPUBuffer o_buf = out_meta_buf, a_buf = in1_meta_buf, b_buf = in2_meta_buf; + auto mul_resize = + [in1_id, in2_id, out_id, wg_size, dispatch_idx, o_buf, a_buf, b_buf]( + WebGPUGraph& g) { + const auto& a = g.cur_dims(in1_id); + const auto& b = g.cur_dims(in2_id); + const size_t r = std::max(a.size(), b.size()); + std::vector out_d(r, 1); + for (size_t i = 0; i < r; i++) { + const int64_t av = (i + a.size() < r) ? 1 : a[i - (r - a.size())]; + const int64_t bv = (i + b.size() < r) ? 1 : b[i - (r - b.size())]; + if (av != bv && av != 1 && bv != 1) { + throw std::runtime_error( + "mul(resize): operands are not broadcast-compatible"); + } + out_d[i] = av > bv ? av : bv; + } + g.set_cur_dims(out_id, out_d); + const uint32_t out_ndim = static_cast(r); + WebGPUTensor ta, tb, to; + ta.dims = a; + tb.dims = b; + to.dims = out_d; + TensorMeta om, am, bm; + fill_tensor_meta_broadcast(to, out_ndim, &om); + fill_tensor_meta_broadcast(ta, out_ndim, &am); + fill_tensor_meta_broadcast(tb, out_ndim, &bm); + wgpuQueueWriteBuffer(g.queue(), o_buf, 0, &om, sizeof(om)); + wgpuQueueWriteBuffer(g.queue(), a_buf, 0, &am, sizeof(am)); + wgpuQueueWriteBuffer(g.queue(), b_buf, 0, &bm, sizeof(bm)); + g.dispatch_at(dispatch_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + g.device(), om.numel, wg_size, "mul(resize)"); + }; + graph.add_tensor_resize_hook(in1_id, mul_resize); + graph.add_tensor_resize_hook(in2_id, mul_resize); wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); wgpuPipelineLayoutRelease(pipeline_layout); - // Drop our refs; the bind group keeps the uniforms alive until release. - wgpuBufferRelease(out_meta_buf); - wgpuBufferRelease(in1_meta_buf); - wgpuBufferRelease(in2_meta_buf); + // Graph owns them so a resize hook can rewrite them; freed in the dtor. + graph.own_uniform_buffer(out_meta_buf); + graph.own_uniform_buffer(in1_meta_buf); + graph.own_uniform_buffer(in2_meta_buf); } } // namespace