diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 7e3cb32de05..072bcf43b8b 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -194,6 +194,13 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST) ) target_compile_options(webgpu_op_test_util_test PRIVATE -fexceptions) set_property(TARGET webgpu_op_test_util_test PROPERTY CXX_STANDARD 17) + + # Dynamic-shape integration test: a gtest binary with its own main() that + # brings up the device once (like webgpu_op_test). + add_webgpu_native_test( + webgpu_dynamic_shape_test test/native/test_dynamic_shape.cpp + ) + target_link_libraries(webgpu_dynamic_shape_test PRIVATE GTest::gtest) endif() add_webgpu_native_test(webgpu_index_test test/native/test_index.cpp) endif() diff --git a/backends/webgpu/test/native/test_dynamic_shape.cpp b/backends/webgpu/test/native/test_dynamic_shape.cpp new file mode 100644 index 00000000000..167ce52483a --- /dev/null +++ b/backends/webgpu/test/native/test_dynamic_shape.cpp @@ -0,0 +1,418 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Native test for dynamic tensor shapes (Option 2). One graph is built at the +// upper-bound seq-len MAXS and run at several live S; the output must match the +// torch golden at each S (allocate-at-max + per-op resize hooks + output-EValue +// resize). Cases: +// A dyn_rms at S=MAXS -> golden (static-equivalent) +// B dyn_rms at S < MAXS (64, 8, 1) -> golden (resize shrinks dispatch) +// C ONE loaded graph reused across S -> all golden (buffers never moved => +// bind groups stayed valid) +// D static_rms (no dynamic dim) -> golden (static path unchanged) +// F dyn_rms_chain (rms(rms(x))) at 3 S -> golden (resize CASCADE, DD-4) +// G rms+residual H rms*x I dyn_linear J sdpa_dyn K emb_dyn L rope_dyn +// M dyn_sigmoid N dyn_select (select_copy(0,-1), dynamic S) +// .pte + goldens from test/ops/dynamic_shape/test_dynamic_shape_export.py. +// +// Artifacts dir: $WEBGPU_DYNAMIC_SHAPE_DIR, else argv[1], else +// /tmp/dynamic_shape. + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::webgpu; +using namespace executorch::extension; +using namespace executorch::runtime; + +namespace { + +constexpr int kHidden = 64; + +// Artifacts directory; set from env/argv in main() before RUN_ALL_TESTS(). +std::string g_dir; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) + +std::vector read_bin(const std::string& path) { + std::ifstream f(path, std::ios::binary | std::ios::ate); + if (!f) { + return {}; + } + const std::streamsize n = f.tellg(); + if (n < 0) { + return {}; + } + f.seekg(0); + std::vector v(static_cast(n) / sizeof(float)); + f.read(reinterpret_cast(v.data()), n); + return v; +} + +float max_err(const std::vector& a, const std::vector& b) { + if (a.size() != b.size() || a.empty()) { + return 1e30f; + } + float m = 0.0f; + for (size_t i = 0; i < a.size(); i++) { + m = std::fmax(m, std::fabs(a[i] - b[i])); + } + return m; +} + +// Run a [1,1,S,kHidden] input through `m` and compare to the golden. Shared by +// every single-output rms-shaped case (A-H, M). +void check_s(Module& m, const std::string& prefix, int s) { + const std::string base = g_dir + "/" + prefix + ".S" + std::to_string(s); + auto input = read_bin(base + ".input.bin"); + ASSERT_FALSE(input.empty()) << "missing input: " << prefix << ".S" << s; + ASSERT_EQ(input.size(), static_cast(s) * kHidden) + << "wrong input size: " << prefix << ".S" << s; + auto t = make_tensor_ptr({1, 1, s, kHidden}, std::move(input)); + auto r = m.forward({EValue(t)}); + ASSERT_TRUE(r.ok() && !r.get().empty() && r.get()[0].isTensor()) + << prefix << " S=" << s + << " forward failed (err=" << (r.ok() ? 0 : (int)r.error()) << ")"; + const auto& out = r.get()[0].toTensor(); + const size_t numel = static_cast(s) * kHidden; + // Output EValue must have been resized to the live shape. + ASSERT_EQ(static_cast(out.numel()), numel) + << prefix << " S=" << s << " output numel mismatch"; + const float* d = out.const_data_ptr(); + std::vector got(d, d + numel); + auto golden = read_bin(base + ".golden.bin"); + const float e = max_err(got, golden); + EXPECT_LT(e, 1e-3f) << prefix << " S=" << s << " max_err=" << e + << " (got.size=" << got.size() + << " golden.size=" << golden.size() << ")"; +} + +// Dynamic quantized linear: input [M, kLinK] -> output [M, kLinN]. +constexpr int kLinK = 64; +constexpr int kLinN = 128; +void check_linear(int m_rows) { + Module m(g_dir + "/dyn_linear.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load dyn_linear.pte"; + const std::string base = g_dir + "/dyn_linear.S" + std::to_string(m_rows); + auto input = read_bin(base + ".input.bin"); + auto golden = read_bin(base + ".golden.bin"); + ASSERT_FALSE(input.empty()) << "missing dyn_linear.S" << m_rows; + auto t = make_tensor_ptr({m_rows, kLinK}, std::move(input)); + auto r = m.forward({EValue(t)}); + ASSERT_TRUE(r.ok() && !r.get().empty() && r.get()[0].isTensor()) + << "dyn_linear M=" << m_rows << " forward failed"; + const auto& out = r.get()[0].toTensor(); + const size_t numel = static_cast(m_rows) * kLinN; + ASSERT_EQ(static_cast(out.numel()), numel) + << "dyn_linear M=" << m_rows << " output numel mismatch"; + std::vector got( + out.const_data_ptr(), out.const_data_ptr() + numel); + const float e = max_err(got, golden); + // 4-bit quant: looser tol (the kernel mirrors the dequant-matmul reference). + EXPECT_LT(e, 5e-3f) << "dyn_linear M=" << m_rows << " max_err=" << e; +} + +// Dynamic SDPA (GQA prefill, input_pos=0): q[1,s,hq,d] k/v[1,s,hkv,d] +// caches[1,cmax,hkv,d]; attn output [1,s,hq,d] selected by shape (3 outputs). +constexpr int kSdHq = 8, kSdHkv = 2, kSdD = 16, kSdCmax = 64; +void check_sdpa(int s) { + Module m(g_dir + "/sdpa_dyn.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "sdpa_dyn S=" << s << " load"; + const std::string b = g_dir + "/sdpa_dyn.S" + std::to_string(s) + "."; + auto q = read_bin(b + "q.bin"); + auto k = read_bin(b + "k.bin"); + auto v = read_bin(b + "v.bin"); + auto kc = read_bin(b + "kc.bin"); + auto vc = read_bin(b + "vc.bin"); + auto golden = read_bin(b + "golden.bin"); + ASSERT_FALSE( + q.empty() || k.empty() || v.empty() || kc.empty() || vc.empty() || + golden.empty()) + << "missing sdpa_dyn.S" << s; + auto tq = make_tensor_ptr({1, s, kSdHq, kSdD}, std::move(q)); + auto tk = make_tensor_ptr({1, s, kSdHkv, kSdD}, std::move(k)); + auto tv = make_tensor_ptr({1, s, kSdHkv, kSdD}, std::move(v)); + auto tkc = make_tensor_ptr({1, kSdCmax, kSdHkv, kSdD}, std::move(kc)); + auto tvc = make_tensor_ptr({1, kSdCmax, kSdHkv, kSdD}, std::move(vc)); + auto r = + m.forward({EValue(tq), EValue(tk), EValue(tv), EValue(tkc), EValue(tvc)}); + ASSERT_TRUE(r.ok()) << "sdpa S=" << s + << " forward failed (err=" << (int)r.error() << ")"; + // Select the attn output by full shape [1,s,hq,d] (never numel). + const float* attn = nullptr; + const size_t numel = static_cast(s) * kSdHq * kSdD; + for (size_t i = 0; i < r.get().size(); i++) { + if (!r.get()[i].isTensor()) { + continue; + } + const auto& t = r.get()[i].toTensor(); + if (t.dim() == 4 && t.size(1) == s && t.size(2) == kSdHq && + t.size(3) == kSdD) { + attn = t.const_data_ptr(); + break; + } + } + ASSERT_NE(attn, nullptr) << "sdpa S=" << s << ": no attn output of shape [1," + << s << "," << kSdHq << "," << kSdD << "]"; + std::vector got(attn, attn + numel); + const float e = max_err(got, golden); + EXPECT_LT(e, 2e-3f) << "sdpa_dyn S=" << s << " max_err=" << e; +} + +// Dynamic embedding: int64 token ids [N] -> [N, kEmbDim] fp32. The int64 host +// input exercises copy_inputs' int64->int32 narrow path under dynamic shapes. +constexpr int kEmbDim = 64; +void check_embedding(int n) { + Module m(g_dir + "/emb_dyn.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load emb_dyn.pte"; + const std::string b = g_dir + "/emb_dyn.S" + std::to_string(n) + "."; + std::ifstream f(b + "idx.bin", std::ios::binary | std::ios::ate); + ASSERT_TRUE(f.good()) << "missing emb_dyn.S" << n; + const std::streamsize nb = f.tellg(); + ASSERT_GE(nb, 0) << "missing emb_dyn.S" << n; + f.seekg(0); + std::vector idx(static_cast(nb) / sizeof(int64_t)); + f.read(reinterpret_cast(idx.data()), nb); + ASSERT_EQ(idx.size(), static_cast(n)) + << "wrong emb_dyn idx size S" << n; + auto golden = read_bin(b + "golden.bin"); + auto t = make_tensor_ptr({n}, std::move(idx)); // int64 (Long) host input + auto r = m.forward({EValue(t)}); + ASSERT_TRUE(r.ok() && !r.get().empty() && r.get()[0].isTensor()) + << "emb N=" << n + << " forward failed (err=" << (r.ok() ? 0 : (int)r.error()) << ")"; + const auto& out = r.get()[0].toTensor(); + const size_t numel = static_cast(n) * kEmbDim; + ASSERT_EQ(static_cast(out.numel()), numel) + << "emb N=" << n << " output numel mismatch"; + std::vector got( + out.const_data_ptr(), out.const_data_ptr() + numel); + const float e = max_err(got, golden); + EXPECT_LT(e, 5e-3f) << "emb_dyn N=" << n << " max_err=" << e; +} + +// Dynamic RoPE: xq[1,s,nh,hd] xk[1,s,nkv,hd] freqs[s,hd/2] -> xq_out/xk_out +// (2 outputs, selected by head count nh != nkv). +constexpr int kRopeNH = 8, kRopeNKV = 2, kRopeHD = 64; +void check_rope(int s) { + Module m(g_dir + "/rope_dyn.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load rope_dyn.pte"; + const std::string b = g_dir + "/rope_dyn.S" + std::to_string(s) + "."; + auto xq = read_bin(b + "xq.bin"); + auto xk = read_bin(b + "xk.bin"); + auto fc = read_bin(b + "fc.bin"); + auto fs = read_bin(b + "fs.bin"); + auto gq = read_bin(b + "gq.bin"); + auto gk = read_bin(b + "gk.bin"); + ASSERT_FALSE( + xq.empty() || xk.empty() || fc.empty() || fs.empty() || gq.empty() || + gk.empty()) + << "missing rope_dyn.S" << s; + auto txq = make_tensor_ptr({1, s, kRopeNH, kRopeHD}, std::move(xq)); + auto txk = make_tensor_ptr({1, s, kRopeNKV, kRopeHD}, std::move(xk)); + auto tfc = make_tensor_ptr({s, kRopeHD / 2}, std::move(fc)); + auto tfs = make_tensor_ptr({s, kRopeHD / 2}, std::move(fs)); + auto r = m.forward({EValue(txq), EValue(txk), EValue(tfc), EValue(tfs)}); + ASSERT_TRUE(r.ok()) << "rope S=" << s + << " forward failed (err=" << (int)r.error() << ")"; + // Select xq_out (nh heads) and xk_out (nkv heads) by shape. + const float *oq = nullptr, *okp = nullptr; + for (size_t i = 0; i < r.get().size(); i++) { + if (!r.get()[i].isTensor()) { + continue; + } + const auto& t = r.get()[i].toTensor(); + if (t.dim() == 4 && t.size(1) == s && t.size(3) == kRopeHD) { + if (t.size(2) == kRopeNH) { + oq = t.const_data_ptr(); + } else if (t.size(2) == kRopeNKV) { + okp = t.const_data_ptr(); + } + } + } + ASSERT_TRUE(oq != nullptr && okp != nullptr) + << "rope S=" << s << ": missing xq_out/xk_out by shape"; + std::vector gotq(oq, oq + static_cast(s) * kRopeNH * kRopeHD); + std::vector gotk( + okp, okp + static_cast(s) * kRopeNKV * kRopeHD); + const float e = std::fmax(max_err(gotq, gq), max_err(gotk, gk)); + EXPECT_LT(e, 1e-3f) << "rope_dyn S=" << s << " max_err=" << e; +} + +// Dynamic select_copy(0,-1): input [2,1,S,kHidden] -> output [1,S,kHidden]. The +// negative index resolves against the (static) leading dim live; the dynamic S +// flows to the output, so the resize hook recomputes its dispatch each S. +constexpr int kSelLead = 2; +void check_select(int s) { + Module m(g_dir + "/dyn_select.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load dyn_select.pte"; + const std::string base = g_dir + "/dyn_select.S" + std::to_string(s); + auto input = read_bin(base + ".input.bin"); + auto golden = read_bin(base + ".golden.bin"); + ASSERT_FALSE(input.empty() || golden.empty()) << "missing dyn_select.S" << s; + auto t = make_tensor_ptr({kSelLead, 1, s, kHidden}, std::move(input)); + auto r = m.forward({EValue(t)}); + ASSERT_TRUE(r.ok() && !r.get().empty() && r.get()[0].isTensor()) + << "select S=" << s + << " forward failed (err=" << (r.ok() ? 0 : (int)r.error()) << ")"; + const auto& out = r.get()[0].toTensor(); + const size_t numel = static_cast(s) * kHidden; + ASSERT_EQ(static_cast(out.numel()), numel) + << "select S=" << s << " output numel mismatch"; + std::vector got( + out.const_data_ptr(), out.const_data_ptr() + numel); + const float e = max_err(got, golden); + EXPECT_LT(e, 1e-3f) << "dyn_select S=" << s << " max_err=" << e; +} + +} // namespace + +// A + B: single dynamic rms_norm at S = MAXS .. 1 (fresh module load each S). +TEST(DynamicShape, RmsNormFreshLoad) { + for (int s : {128, 64, 8, 1}) { + Module m(g_dir + "/dyn_rms.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load dyn_rms.pte"; + check_s(m, "dyn_rms", s); + } +} + +// C: ONE loaded graph reused across S (buffers must not move => bind groups +// stay valid). +TEST(DynamicShape, RmsNormReusedGraph) { + Module m(g_dir + "/dyn_rms.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load dyn_rms.pte"; + for (int s : {128, 1, 64, 8, 128}) { + check_s(m, "dyn_rms", s); + } +} + +// D: static rms_norm (no dynamic dim) — regression that the static path is +// unchanged. +TEST(DynamicShape, StaticRmsNorm) { + Module m(g_dir + "/static_rms.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load static_rms.pte"; + check_s(m, "static_rms", 8); +} + +// F: 2-op chain rms(rms(x)) — resize cascade. +TEST(DynamicShape, RmsChainCascade) { + for (int s : {128, 16, 1}) { + Module m(g_dir + "/dyn_rms_chain.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load dyn_rms_chain.pte"; + check_s(m, "dyn_rms_chain", s); + } +} + +// G: rms(x)+x residual — cross-op (rms -> add) cascade. +TEST(DynamicShape, RmsResidualCascade) { + for (int s : {128, 32, 1}) { + Module m(g_dir + "/dyn_residual.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load dyn_residual.pte"; + check_s(m, "dyn_residual", s); + } +} + +// H: rms(x)*x — exercises the mul op resize. +TEST(DynamicShape, RmsMul) { + for (int s : {128, 32, 1}) { + Module m(g_dir + "/dyn_rmsmul.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load dyn_rmsmul.pte"; + check_s(m, "dyn_rmsmul", s); + } +} + +// I: dynamic 4-bit quantized linear (prefill GEMM) at several M. +TEST(DynamicShape, QuantizedLinear) { + for (int m_rows : {128, 32, 1}) { + check_linear(m_rows); + } +} + +// J: dynamic SDPA (GQA prefill) at several seq-len S. The whole case skips +// while op coverage is pending (the dynamic-S build throws err 48 until +// registered). +TEST(DynamicShape, Sdpa) { + { + Module probe(g_dir + "/sdpa_dyn.pte"); + if (probe.load_forward() == Error::DelegateInvalidCompatibility) { + GTEST_SKIP() << "sdpa_dyn pending op coverage (err " + << (int)Error::DelegateInvalidCompatibility << ")"; + } + } + for (int s : {64, 16, 1}) { + check_sdpa(s); + } +} + +// K: dynamic embedding (int64 token ids) at several token counts. +TEST(DynamicShape, Embedding) { + for (int n : {16, 8, 1}) { + check_embedding(n); + } +} + +// L: dynamic RoPE (two outputs) at several seq-len S. +TEST(DynamicShape, Rope) { + for (int s : {16, 8, 1}) { + check_rope(s); + } +} + +// M: dynamic sigmoid (elementwise) at several S. +TEST(DynamicShape, Sigmoid) { + for (int s : {128, 32, 1}) { + Module m(g_dir + "/dyn_sigmoid.pte"); + ASSERT_EQ(m.load_forward(), Error::Ok) << "load dyn_sigmoid.pte"; + check_s(m, "dyn_sigmoid", s); + } +} + +// N: dynamic select_copy(0,-1) at several S. +TEST(DynamicShape, Select) { + for (int s : {128, 32, 1}) { + check_select(s); + } +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + // Artifacts dir: env wins, else first positional arg, else default (gtest + // flags were already stripped by InitGoogleTest above). + g_dir = "/tmp/dynamic_shape"; + if (argc > 1) { + g_dir = argv[1]; + } + if (const char* env = std::getenv("WEBGPU_DYNAMIC_SHAPE_DIR")) { + g_dir = env; + } + + WebGPUContext ctx; + try { + ctx = create_webgpu_context(); + } catch (const std::exception& e) { + std::printf("SKIP: no WebGPU device (%s)\n", e.what()); + return 0; + } + set_default_webgpu_context(&ctx); + + const int rc = RUN_ALL_TESTS(); + set_default_webgpu_context(nullptr); + destroy_webgpu_context(ctx); + return rc; +} diff --git a/backends/webgpu/test/ops/dynamic_shape/test_dynamic_shape_export.py b/backends/webgpu/test/ops/dynamic_shape/test_dynamic_shape_export.py new file mode 100644 index 00000000000..6652d073805 --- /dev/null +++ b/backends/webgpu/test/ops/dynamic_shape/test_dynamic_shape_export.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Dynamic tensor-shape (Option 2) export tests via VulkanPartitioner. + +Exports ONE graph built at the upper-bound seq-len MAXS that the native runtime +(`test/native/test_dynamic_shape.cpp`) then runs at several live S, asserting the +output matches the torch golden and that the static path is unchanged. Numerics +are checked in the native test; this verifies the dynamic export side + writes +goldens. +""" + +import os +import unittest + +import torch +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner +from executorch.exir import to_edge_transform_and_lower + +MAXS = 128 # upper bound for the dynamic seq-len dim (within the 1D dispatch cap) +HIDDEN = 64 + + +def _rms(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + x_f32 = x.to(torch.float32) + var = x_f32.pow(2).mean(dim=-1, keepdim=True) + return (x_f32 * torch.rsqrt(var + eps)) * weight + + +class RmsNormModule(torch.nn.Module): + def __init__(self, hidden: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = torch.nn.Parameter( + torch.linspace(0.5, 1.5, hidden, dtype=torch.float32) + ) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _rms(x, self.weight, self.eps) + + +class RmsChainModule(torch.nn.Module): + """rms(rms(x)) — two ops; exercises the resize-cascade (DD-4).""" + + def __init__(self, hidden: int, eps: float = 1e-6) -> None: + super().__init__() + self.w1 = torch.nn.Parameter( + torch.linspace(0.5, 1.5, hidden, dtype=torch.float32) + ) + self.w2 = torch.nn.Parameter( + torch.linspace(1.5, 0.5, hidden, dtype=torch.float32) + ) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _rms(_rms(x, self.w1, self.eps), self.w2, self.eps) + + +class RmsResidualModule(torch.nn.Module): + """rms(x) + x — rms op feeding an add op; proves the cross-op resize cascade.""" + + def __init__(self, hidden: int, eps: float = 1e-6) -> None: + super().__init__() + self.w = torch.nn.Parameter( + torch.linspace(0.5, 1.5, hidden, dtype=torch.float32) + ) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _rms(x, self.w, self.eps) + x + + +class RmsMulModule(torch.nn.Module): + """rms(x) * x — exercises the mul op (two same-shape dynamic operands).""" + + def __init__(self, hidden: int, eps: float = 1e-6) -> None: + super().__init__() + self.w = torch.nn.Parameter( + torch.linspace(0.5, 1.5, hidden, dtype=torch.float32) + ) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _rms(x, self.w, self.eps) * x + + +class SigmoidModule(torch.nn.Module): + """sigmoid(x) — elementwise; resize hook recomputes dispatch from live numel.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.sigmoid(x) + + +class SelectModule(torch.nn.Module): + """x.select(0, -1) — negative index resolved live + dynamic output dispatch.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.select(0, -1) + + +def _ramp(shape) -> torch.Tensor: + n = 1 + for d in shape: + n *= d + return torch.linspace(-1.0, 1.0, n, dtype=torch.float32).reshape(shape) + + +def _export(model, example_inputs, dynamic_shapes, path: str) -> None: + ep = torch.export.export(model, example_inputs, dynamic_shapes=dynamic_shapes) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + found = any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ) + assert found, f"Expected VulkanBackend delegate in {path}" + with open(path, "wb") as f: + f.write(et.buffer) + print(f"Exported {path}") + + +def _write_goldens(model, prefix: str, out_dir: str, s_values) -> None: + for s in s_values: + x = _ramp((1, 1, s, HIDDEN)) + with torch.no_grad(): + g = model(x) + x.detach().numpy().astype(" None: + """Write the dynamic + static .pte's and per-S goldens for the native test.""" + os.makedirs(out_dir, exist_ok=True) + s_dim = torch.export.Dim("s", min=1, max=MAXS) + + # 1) Single dynamic rms_norm, graph built at S=MAXS (upper bound). + rms = RmsNormModule(HIDDEN) + _export( + rms, + (_ramp((1, 1, MAXS, HIDDEN)),), + {"x": {2: s_dim}}, + os.path.join(out_dir, "dyn_rms.pte"), + ) + _write_goldens(rms, "dyn_rms", out_dir, [MAXS, 64, 8, 1]) + + # 2) Two-op chain (cascade): rms(rms(x)). + chain = RmsChainModule(HIDDEN) + _export( + chain, + (_ramp((1, 1, MAXS, HIDDEN)),), + {"x": {2: s_dim}}, + os.path.join(out_dir, "dyn_rms_chain.pte"), + ) + _write_goldens(chain, "dyn_rms_chain", out_dir, [MAXS, 16, 1]) + + # 2b) rms(x)+x residual — cross-op (rms->add) cascade. + resid = RmsResidualModule(HIDDEN) + _export( + resid, + (_ramp((1, 1, MAXS, HIDDEN)),), + {"x": {2: s_dim}}, + os.path.join(out_dir, "dyn_residual.pte"), + ) + _write_goldens(resid, "dyn_residual", out_dir, [MAXS, 32, 1]) + + # 2c) rms(x)*x — exercises the mul op resize. + rmsmul = RmsMulModule(HIDDEN) + _export( + rmsmul, + (_ramp((1, 1, MAXS, HIDDEN)),), + {"x": {2: s_dim}}, + os.path.join(out_dir, "dyn_rmsmul.pte"), + ) + _write_goldens(rmsmul, "dyn_rmsmul", out_dir, [MAXS, 32, 1]) + + # 2d) 4-bit quantized linear with a DYNAMIC rows (M) dim — prefill GEMM. + _export_dynamic_linear(out_dir) + + # 2e) Fused SDPA with a DYNAMIC seq-len S (prefill, input_pos=0). + _export_dynamic_sdpa(out_dir) + + # 2f) 4-bit embedding with a DYNAMIC token count (int64 indices). + _export_dynamic_embedding(out_dir) + + # 2g) Interleaved RoPE with a DYNAMIC seq-len S (two outputs xq/xk). + _export_dynamic_rope(out_dir) + + # 2h) Elementwise sigmoid with a DYNAMIC seq-len S. + sig = SigmoidModule() + _export( + sig, + (_ramp((1, 1, MAXS, HIDDEN)),), + {"x": {2: s_dim}}, + os.path.join(out_dir, "dyn_sigmoid.pte"), + ) + _write_goldens(sig, "dyn_sigmoid", out_dir, [MAXS, 32, 1]) + + # 2i) select_copy(0, -1) over a DYNAMIC seq-len S (negative live index). + _export_dynamic_select(out_dir) + + # 3) Static rms_norm (no dynamic dim) — regression: must stay byte-identical. + static = RmsNormModule(HIDDEN) + _export( + static, + (_ramp((1, 1, 8, HIDDEN)),), + None, + os.path.join(out_dir, "static_rms.pte"), + ) + _write_goldens(static, "static_rms", out_dir, [8]) + + +# Quantized linear: K x N weight, dynamic rows M; input [M, K], output [M, N]. +LIN_K = 64 +LIN_N = 128 +LIN_GROUP = 32 +LIN_MAXM = 128 + + +def _export_dynamic_linear(out_dir: str) -> None: + from executorch.backends.webgpu.test.ops.quantized_linear.test_quantized_linear import ( + _fp64_golden, + _make_quantized_model, + ) + + model = _make_quantized_model(LIN_K, LIN_N, LIN_GROUP) + x = _ramp((LIN_MAXM, LIN_K)) + m_dim = torch.export.Dim("m", min=1, max=LIN_MAXM) + ep = torch.export.export(model, (x,), dynamic_shapes=({0: m_dim},)) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + assert any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ), "linear_q4gsw not delegated" + with open(os.path.join(out_dir, "dyn_linear.pte"), "wb") as f: + f.write(et.buffer) + print("Exported dyn_linear.pte") + for m in [LIN_MAXM, 32, 1]: + xm = _ramp((m, LIN_K)) + g = _fp64_golden(model, xm).astype(" None: + from executorch.backends.webgpu.test.ops.sdpa.test_sdpa import ( + _det_inputs, + _golden, + SdpaConfig, + SdpaModule, + ) + + def cfg(s: int) -> "SdpaConfig": + return SdpaConfig("dyn", SD_HQ, SD_HKV, SD_D, s, SD_CMAX, 0) + + model = SdpaModule(0) + q, k, v, kc, vc = _det_inputs(cfg(SD_MAXS)) + s_dim = torch.export.Dim("s", min=1, max=SD_MAXS) + ds = ({1: s_dim}, {1: s_dim}, {1: s_dim}, None, None) + ep = torch.export.export(model, (q, k, v, kc, vc), dynamic_shapes=ds) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + assert any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ), "sdpa not delegated" + with open(os.path.join(out_dir, "sdpa_dyn.pte"), "wb") as f: + f.write(et.buffer) + print("Exported sdpa_dyn.pte") + for s in [SD_MAXS, 16, 1]: + c = cfg(s) + q, k, v, kc, vc = _det_inputs(c) + g = _golden(c, q, k, v, kc, vc) + for name, t in [ + ("q", q), + ("k", k), + ("v", v), + ("kc", kc), + ("vc", vc), + ("golden", g), + ]: + t.detach().cpu().numpy().astype(" [N, EMBED] fp32. +EMB_VOCAB = 64 +EMB_DIM = 64 +EMB_GROUP = 32 +EMB_MAXN = 16 + + +def _export_dynamic_embedding(out_dir: str) -> None: + from executorch.backends.webgpu.test.ops.embedding_q4gsw.test_embedding_q4gsw import ( + _make_quantized_model, + _quant_params, + Shape, + ) + + shape = Shape("dyn", EMB_VOCAB, EMB_DIM, EMB_GROUP, list(range(EMB_MAXN))) + qm = _make_quantized_model(shape) + idx_max = torch.arange(EMB_MAXN, dtype=torch.long) + n_dim = torch.export.Dim("n", min=1, max=EMB_MAXN) + ep = torch.export.export(qm, (idx_max,), dynamic_shapes=({0: n_dim},)) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + assert any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ), "embedding_q4gsw not delegated" + with open(os.path.join(out_dir, "emb_dyn.pte"), "wb") as f: + f.write(et.buffer) + print("Exported emb_dyn.pte") + weight, scales, group_size = _quant_params(qm) + for n in [EMB_MAXN, 8, 1]: + idx = (torch.arange(n, dtype=torch.long) * 7) % EMB_VOCAB + g = torch.ops.et_vk.embedding_q4gsw.default( + weight, scales, group_size, idx, False + ) + idx.detach().numpy().astype(" None: + from executorch.backends.webgpu.test.ops.rope.test_rope import ( + _golden, + _inputs, + Shape, + ) + from executorch.examples.models.llama.rope import RotaryEmbedding + + xq, xk, fc, fs = _inputs(Shape("dyn", 1, ROPE_MAXS, ROPE_NH, ROPE_NKV, ROPE_HD)) + s_dim = torch.export.Dim("s", min=1, max=ROPE_MAXS) + ds = ({1: s_dim}, {1: s_dim}, {0: s_dim}, {0: s_dim}) + ep = torch.export.export( + RotaryEmbedding().eval(), (xq, xk, fc, fs), dynamic_shapes=ds + ) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + assert any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ), "apply_rotary_emb not delegated" + with open(os.path.join(out_dir, "rope_dyn.pte"), "wb") as f: + f.write(et.buffer) + print("Exported rope_dyn.pte") + for s in [ROPE_MAXS, 8, 1]: + xq, xk, fc, fs = _inputs(Shape("dyn", 1, s, ROPE_NH, ROPE_NKV, ROPE_HD)) + gq, gk = _golden(xq, xk, fc, fs) + for name, t in [ + ("xq", xq), + ("xk", xk), + ("fc", fc), + ("fs", fs), + ("gq", gq), + ("gk", gk), + ]: + t.detach().cpu().numpy().astype(" [1, S, HIDDEN]. +SEL_LEAD = 2 + + +def _export_dynamic_select(out_dir: str) -> None: + model = SelectModule() + s_dim = torch.export.Dim("s", min=1, max=MAXS) + ep = torch.export.export( + model, + (_ramp((SEL_LEAD, 1, MAXS, HIDDEN)),), + dynamic_shapes=({2: s_dim},), + ) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + assert any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ), "select_copy not delegated" + with open(os.path.join(out_dir, "dyn_select.pte"), "wb") as f: + f.write(et.buffer) + print("Exported dyn_select.pte") + for s in [MAXS, 32, 1]: + x = _ramp((SEL_LEAD, 1, s, HIDDEN)) + with torch.no_grad(): + g = model(x) + x.detach().numpy().astype(" None: + import tempfile + + with tempfile.TemporaryDirectory() as d: + export_dynamic_shape_cases(d) + self.assertTrue(os.path.exists(os.path.join(d, "dyn_rms.pte"))) + self.assertTrue(os.path.exists(os.path.join(d, "dyn_rms.S1.golden.bin"))) + + +if __name__ == "__main__": + unittest.main()