Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions backends/webgpu/runtime/ops/view_copy/ViewCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/view_copy/view_copy.h>

Expand Down Expand Up @@ -38,11 +39,49 @@ void add_flat_copy(WebGPUGraph& graph, int in_id, int out_id) {
}

// When in/out alias the same buffer the data is already in place, so no copy is needed (CopyBufferToBuffer also rejects src == dst).
if (in_tensor.buffer == out_tensor.buffer) {
return;
}
const bool aliased = in_tensor.buffer == out_tensor.buffer;
const size_t dispatch_idx = aliased
? 0
: graph.add_buffer_copy(
in_tensor.buffer, out_tensor.buffer, out_tensor.nbytes);

graph.add_buffer_copy(in_tensor.buffer, out_tensor.buffer, out_tensor.nbytes);
// Dynamic shapes: a view preserves numel, so copy_nbytes and the output dims must track the live input shape.
std::vector<int64_t> out_max = out_tensor.dims;
graph.add_tensor_resize_hook(
in_id, [in_id, out_id, out_max, dispatch_idx, aliased](WebGPUGraph& g) {
const uint64_t target = utils::numel_of(g.cur_dims(in_id));
std::vector<int64_t> od = out_max;
const uint64_t maxnumel = utils::numel_of(out_max);
if (maxnumel != target) {
bool resolved = false;
// Assumes exactly one dynamic dim: picks the leftmost dim that can shrink (others held at max) so the output numel equals the live numel.
for (size_t d = 0; d < od.size(); d++) {
if (out_max[d] <= 0) {
continue;
}
const uint64_t rest = maxnumel / static_cast<uint64_t>(out_max[d]);
if (rest != 0 && target % rest == 0) {
const uint64_t nd = target / rest;
if (nd <= static_cast<uint64_t>(out_max[d])) {
od[d] = static_cast<int64_t>(nd);
resolved = true;
break;
}
}
}
// Fail loud: a silent miss would leave od at max while copy_nbytes
// shrinks to the live size, desyncing consumers from the real copy.
if (!resolved) {
throw std::runtime_error(
"view_copy(resize): could not resolve live output shape");
}
}
g.set_cur_dims(out_id, od);
if (!aliased) {
g.dispatch_at(dispatch_idx).copy_nbytes =
static_cast<size_t>(target) * sizeof(float);
}
});
}

namespace {
Expand Down
Loading