Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 98 additions & 25 deletions backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>

namespace executorch::backends::webgpu {

Expand All @@ -39,6 +41,42 @@ static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes");
constexpr int64_t kQ4gswTileM = 4;
constexpr int64_t kQ4gswTileN = 4;

// Returns the 1D workgroup count for a linear_q4gsw dispatch. Shared by the
// build-time path and the resize hook so both apply the same range/zero
// guards.
//
//   device   - device handed to the utils:: workgroup helpers.
//   use_gemv - true selects the coop4 GEMV sizing, false the tiled GEMM sizing.
//   m, n     - output rows / columns.
//   wg_size  - workgroup size forwarded to the tiled-path helper; unused for
//              GEMV.
//   op_name  - label spliced into error messages and forwarded to the
//              tiled-path helper.
//
// Throws std::runtime_error when M*N is zero or exceeds UINT32_MAX (GEMV),
// when the clamped GEMV count is zero, or when the tile count exceeds
// UINT32_MAX (tiled).
uint32_t compute_q4gsw_workgroup_count(
    WGPUDevice device,
    bool use_gemv,
    uint32_t m,
    uint32_t n,
    uint32_t wg_size,
    const char* op_name) {
  // Builds "WebGPU <op_name><detail>" only when a guard actually trips.
  const auto fail = [op_name](const char* detail) {
    throw std::runtime_error(std::string("WebGPU ") + op_name + detail);
  };

  if (!use_gemv) {
    // Tiled GEMM: one unit of work per kQ4gswTileM x kQ4gswTileN output tile.
    const int64_t tiles_along_m = utils::div_up<int64_t>(m, kQ4gswTileM);
    const int64_t tiles_along_n = utils::div_up<int64_t>(n, kQ4gswTileN);
    const int64_t tile_total = tiles_along_m * tiles_along_n;
    if (tile_total > static_cast<int64_t>(UINT32_MAX)) {
      fail(": tile count exceeds the 1D dispatch limit");
    }
    return utils::compute_1d_workgroup_count(
        device, static_cast<uint32_t>(tile_total), wg_size, op_name);
  }

  // coop4: fixed 64 lanes, 1 workgroup per output, grid-strided over M*N.
  // The product is formed in 64 bits so it cannot wrap before the range check.
  const uint64_t output_count = static_cast<uint64_t>(m) * n;
  if (output_count == 0u || output_count > UINT32_MAX) {
    fail(": M*N out of range");
  }
  const uint32_t clamped_count = utils::clamp_workgroup_count(
      device, static_cast<uint32_t>(output_count));
  if (clamped_count == 0u) {
    fail(": zero GEMV dispatch");
  }
  return clamped_count;
}

// et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out].
void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const int in_id = args.at(0);
Expand Down Expand Up @@ -122,29 +160,8 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
const bool use_gemv = (M == 1u && K % 8u == 0u && gs % 8u == 0u);
const char* shader_src = use_gemv ? kQ4gswLinearCoop4WGSL : kQ4gswLinearWGSL;
uint32_t workgroup_count;
if (use_gemv) {
// coop4: fixed 64 lanes, 1 workgroup per output, grid-strided over M*N.
const uint64_t outputs =
static_cast<uint64_t>(M) * static_cast<uint64_t>(N);
if (outputs == 0u || outputs > UINT32_MAX) {
throw std::runtime_error("WebGPU linear_q4gsw: M*N out of range");
}
workgroup_count =
utils::clamp_workgroup_count(device, static_cast<uint32_t>(outputs));
if (workgroup_count == 0u) {
throw std::runtime_error("WebGPU linear_q4gsw: zero GEMV dispatch");
}
} else {
const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) *
utils::div_up<int64_t>(N, kQ4gswTileN);
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
throw std::runtime_error(
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
}
workgroup_count = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw");
}
const uint32_t workgroup_count = compute_q4gsw_workgroup_count(
device, use_gemv, M, N, wg_size, "linear_q4gsw");

// Optional bias: real buffer if present, else a dummy for the fixed layout.
uint32_t has_bias = 0;
Expand Down Expand Up @@ -256,12 +273,68 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count, "linear_q4gsw"});
const size_t dispatch_idx = graph.add_dispatch(
{pipeline, bind_group, workgroup_count, "linear_q4gsw"});

// Dynamic shapes: whenever the input tensor is resized, recompute the dispatch
// size and rewrite the params uniform (notably params.M) for the live M.
// params_buf is a by-value copy of the uniform buffer handle for the hook.
WGPUBuffer params_buf = uniform_buffer;
graph.add_tensor_resize_hook(
    in_id,
    [in_id,
     out_id,
     M,
     K,
     N,
     K_packed,
     gs,
     padded_N,
     has_bias,
     wg_size,
     use_gemv,
     dispatch_idx,
     params_buf](WebGPUGraph& g) {
      const auto& d = g.cur_dims(in_id);
      if (d.empty()) {
        throw std::runtime_error("WebGPU linear_q4gsw: empty input dims");
      }
      // Live M is derived as numel / K, so numel must be a whole number of
      // K-length rows.
      const uint64_t numel = utils::numel_of(d);
      if (numel % static_cast<uint64_t>(K) != 0u) {
        throw std::runtime_error(
            "WebGPU linear_q4gsw: live input numel not a multiple of K");
      }
      // Range-check the row count in 64 bits BEFORE narrowing: a quotient
      // above UINT32_MAX would otherwise wrap to a small value and slip past
      // the build-time-max guard below.
      const uint64_t live_rows = numel / static_cast<uint64_t>(K);
      if (live_rows == 0u) {
        throw std::runtime_error("WebGPU linear_q4gsw: live M == 0");
      }
      // Buffers/bind-groups were sized for the build-time max M; a larger
      // live M would write out of bounds.
      if (live_rows > static_cast<uint64_t>(M)) {
        throw std::runtime_error(
            "WebGPU linear_q4gsw: live M exceeds the build-time max");
      }
      // Lossless: live_rows <= M at this point.
      const uint32_t m = static_cast<uint32_t>(live_rows);
      const uint32_t wgc = compute_q4gsw_workgroup_count(
          g.device(), use_gemv, m, N, wg_size, "linear_q4gsw(resize)");
      // Only M changes; every other field repeats its build-time value.
      Q4gswParams p = {};
      p.M = m;
      p.N = N;
      p.K = K;
      p.K_packed = K_packed;
      p.group_size = gs;
      p.padded_N = padded_N;
      p.has_bias = has_bias;
      wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
      g.dispatch_at(dispatch_idx).workgroup_count_x = wgc;
      // Output dims: copy the input dims and replace the last one with N.
      std::vector<int64_t> od(d.begin(), d.end());
      od.back() = static_cast<int64_t>(N);
      g.set_cur_dims(out_id, od);
    });

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
wgpuBufferRelease(uniform_buffer);
// Graph owns it so the resize hook can rewrite it; freed in the dtor.
graph.own_uniform_buffer(uniform_buffer);
}

} // namespace
Expand Down
Loading