From 3d8d9337380d4a0a2f91370aa32e97558b90179a Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 28 Jun 2026 09:22:24 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/webgpu/runtime/ops/select/Select.cpp | 60 ++++++++++++++++--- .../webgpu/runtime/ops/sigmoid/UnaryOp.cpp | 26 +++++++- 2 files changed, 74 insertions(+), 12 deletions(-) diff --git a/backends/webgpu/runtime/ops/select/Select.cpp b/backends/webgpu/runtime/ops/select/Select.cpp index 5686bbc79c0..d674cf2a4e5 100644 --- a/backends/webgpu/runtime/ops/select/Select.cpp +++ b/backends/webgpu/runtime/ops/select/Select.cpp @@ -58,10 +58,9 @@ void select_impl(WebGPUGraph& graph, const std::vector& args) { throw std::runtime_error("select: dim out of range"); } const int64_t in_size = in_tensor.dims[dim]; - int64_t index = read_scalar(graph, args.at(2), "index"); - if (index < 0) { - index += in_size; - } + // Keep the RAW index: -1 normalizes against the LIVE dim (the resize hook). + const int64_t raw_index = read_scalar(graph, args.at(2), "index"); + int64_t index = raw_index < 0 ? raw_index + in_size : raw_index; if (index < 0 || index >= in_size) { throw std::runtime_error("select: index out of range"); } @@ -164,15 +163,58 @@ void select_impl(WebGPUGraph& graph, const std::vector& args) { bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - graph.add_dispatch({pipeline, bind_group, workgroup_count}); + const size_t dispatch_idx = + graph.add_dispatch({pipeline, bind_group, workgroup_count}); + + // Dynamic shapes: out = in minus `dim`; re-resolve index, meta, dispatch. + graph.add_tensor_resize_hook( + in_id, + [in_id, + out_id, + dim, + raw_index, + out_meta_buf, + in_meta_buf, + params_buf, + wg_size, + dispatch_idx](WebGPUGraph& g) { + const auto& ind = g.cur_dims(in_id); + const int64_t live_in_size = ind[dim]; + int64_t idx = raw_index < 0 ? raw_index + live_in_size : raw_index; + if (idx < 0 || idx >= live_in_size) { + throw std::runtime_error("select(resize): index out of range"); + } + std::vector od; + for (size_t k = 0; k < ind.size(); k++) { + if (static_cast(k) != dim) { + od.push_back(ind[k]); + } + } + g.set_cur_dims(out_id, od); + WebGPUTensor to, ti; + to.dims = od; + ti.dims = ind; + TensorMeta om, im; + fill_tensor_meta(to, &om); + fill_tensor_meta(ti, &im); + wgpuQueueWriteBuffer(g.queue(), out_meta_buf, 0, &om, sizeof(om)); + wgpuQueueWriteBuffer(g.queue(), in_meta_buf, 0, &im, sizeof(im)); + SelectParams p = {}; + p.dim = static_cast(dim); + p.index = static_cast(idx); + wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p)); + g.dispatch_at(dispatch_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + g.device(), om.numel, wg_size, "select(resize)"); + }); wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); wgpuPipelineLayoutRelease(pipeline_layout); - // Drop our refs; the bind group keeps the uniforms alive until release. - wgpuBufferRelease(out_meta_buf); - wgpuBufferRelease(in_meta_buf); - wgpuBufferRelease(params_buf); + // Graph owns them so the resize hook can rewrite them; freed in the dtor. + graph.own_uniform_buffer(out_meta_buf); + graph.own_uniform_buffer(in_meta_buf); + graph.own_uniform_buffer(params_buf); } } // namespace diff --git a/backends/webgpu/runtime/ops/sigmoid/UnaryOp.cpp b/backends/webgpu/runtime/ops/sigmoid/UnaryOp.cpp index 4d1a087cae5..7c8b8c0e9ce 100644 --- a/backends/webgpu/runtime/ops/sigmoid/UnaryOp.cpp +++ b/backends/webgpu/runtime/ops/sigmoid/UnaryOp.cpp @@ -135,14 +135,34 @@ void add_unary_op( bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - graph.add_dispatch({pipeline, bind_group, workgroup_count}); + const size_t dispatch_idx = + graph.add_dispatch({pipeline, bind_group, workgroup_count}); + + // Dynamic shapes: recompute num_elements/dispatch for the live shape. + WGPUBuffer params_buf = uniform_buffer; + graph.add_tensor_resize_hook( + in_id, + [in_id, out_id, wg_size, dispatch_idx, params_buf](WebGPUGraph& g) { + const auto& d = g.cur_dims(in_id); + const uint64_t numel = utils::numel_of(d); + g.set_cur_dims(out_id, d); + UnaryParams p = {}; + p.num_elements = static_cast(numel); + wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p)); + g.dispatch_at(dispatch_idx).workgroup_count_x = + utils::compute_1d_workgroup_count( + g.device(), + static_cast(numel), + wg_size, + "unary(resize)"); + }); // Release intermediates (pipeline + bind_group are kept by dispatch). wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); wgpuPipelineLayoutRelease(pipeline_layout); - // Drop our ref; the bind group keeps the uniform buffer alive until release. - wgpuBufferRelease(uniform_buffer); + // Graph owns it so the resize hook can rewrite it; freed in the dtor. + graph.own_uniform_buffer(uniform_buffer); } void sigmoid_impl(WebGPUGraph& graph, const std::vector& args) {