diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index beb3f49e5b2..adbf4301413 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -151,19 +151,9 @@ function(add_webgpu_native_test test_name test_src) endfunction() if(EXECUTORCH_BUILD_WEBGPU_TEST) - add_webgpu_native_test(webgpu_native_test test/test_webgpu_native.cpp) - add_webgpu_native_test( - webgpu_dispatch_order_test test/native/test_dispatch_order.cpp - ) - add_webgpu_native_test( - webgpu_scratch_buffer_test test/native/test_scratch_buffer.cpp - ) - add_webgpu_native_test( - webgpu_update_cache_test test/native/test_update_cache.cpp - ) - - # Manifest-driven op-test framework: a generic gtest driver (webgpu_op_test) + - # its device-free util unit test. GTest needs -DEXECUTORCH_BUILD_TESTS=ON. + # All WebGPU native tests use GTest (device-dependent ones bring up the device + # in their own main(); the fold unit test is device-free via gtest_main). + # GTest needs -DEXECUTORCH_BUILD_TESTS=ON. if(NOT TARGET GTest::gtest) find_package(GTest QUIET) endif() @@ -195,12 +185,28 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST) target_compile_options(webgpu_op_test_util_test PRIVATE -fexceptions) set_property(TARGET webgpu_op_test_util_test PROPERTY CXX_STANDARD 17) - # Dynamic-shape integration test: a gtest binary with its own main() that - # brings up the device once (like webgpu_op_test). + # Device-dependent native tests: each has its own main() that brings up the + # device once, then RUN_ALL_TESTS(); link GTest::gtest (not gtest_main). + add_webgpu_native_test(webgpu_native_test test/test_webgpu_native.cpp) + target_link_libraries(webgpu_native_test PRIVATE GTest::gtest) + add_webgpu_native_test( + webgpu_dispatch_order_test test/native/test_dispatch_order.cpp + ) + target_link_libraries(webgpu_dispatch_order_test PRIVATE GTest::gtest) + add_webgpu_native_test( + webgpu_scratch_buffer_test test/native/test_scratch_buffer.cpp + ) + target_link_libraries(webgpu_scratch_buffer_test PRIVATE GTest::gtest) + add_webgpu_native_test( + webgpu_update_cache_test test/native/test_update_cache.cpp + ) + target_link_libraries(webgpu_update_cache_test PRIVATE GTest::gtest) add_webgpu_native_test( webgpu_dynamic_shape_test test/native/test_dynamic_shape.cpp ) target_link_libraries(webgpu_dynamic_shape_test PRIVATE GTest::gtest) + add_webgpu_native_test(webgpu_index_test test/native/test_index.cpp) + target_link_libraries(webgpu_index_test PRIVATE GTest::gtest) # Device-free fold unit test (gtest_main provides main; no device needed). add_webgpu_native_test( @@ -210,5 +216,4 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST) webgpu_dispatch_2d_test PRIVATE GTest::gtest GTest::gtest_main ) endif() - add_webgpu_native_test(webgpu_index_test test/native/test_index.cpp) endif() diff --git a/backends/webgpu/scripts/test_webgpu_native_ci.sh b/backends/webgpu/scripts/test_webgpu_native_ci.sh index 0ed8c88e3b2..810d8165303 100644 --- a/backends/webgpu/scripts/test_webgpu_native_ci.sh +++ b/backends/webgpu/scripts/test_webgpu_native_ci.sh @@ -47,6 +47,8 @@ UPDATE_CACHE_DIR="/tmp/update_cache" UPDATE_CACHE_OK=1 INDEX_DIR="/tmp/index" INDEX_OK=1 +DYNAMIC_SHAPE_DIR="/tmp/dynamic_shape" +DYNAMIC_SHAPE_OK=1 EMBEDDING_MODEL="/tmp/webgpu_embedding_q4gsw.pte" EMBEDDING_INDICES="/tmp/webgpu_embedding_q4gsw_indices.bin" EMBEDDING_GOLDEN="/tmp/webgpu_embedding_q4gsw_golden.bin" @@ -111,6 +113,11 @@ from executorch.backends.webgpu.test.ops.index.test_index import export_all_inde export_all_index_models('${INDEX_DIR}') " || { echo "WARN: index export failed; skipping index native test"; INDEX_OK=0; } +$PYTHON_EXECUTABLE -c " +from executorch.backends.webgpu.test.ops.dynamic_shape.test_dynamic_shape_export import export_dynamic_shape_cases +export_dynamic_shape_cases('${DYNAMIC_SHAPE_DIR}') +" || { echo "WARN: dynamic_shape export failed; skipping dynamic_shape native test"; DYNAMIC_SHAPE_OK=0; } + # Non-fatal: a failed sdpa export makes the required 4k/8k configs hard-fail in # webgpu_native_test below (precise per-config error), so don't exit/mask here. $PYTHON_EXECUTABLE -c " @@ -132,6 +139,7 @@ rm -rf "${BUILD_DIR}" cmake \ -DEXECUTORCH_BUILD_WEBGPU=ON \ -DEXECUTORCH_BUILD_WEBGPU_TEST=ON \ + -DEXECUTORCH_BUILD_TESTS=ON \ -DDawn_DIR="${Dawn_DIR}" \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ @@ -143,7 +151,7 @@ cmake \ "${EXECUTORCH_ROOT}" # ── Build + run every native test target that exists in this tree ──────────── -TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test webgpu_dispatch_2d_test) +TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test webgpu_dynamic_shape_test webgpu_dispatch_2d_test) BIN_DIR="${BUILD_DIR}/backends/webgpu" # Which targets are defined depends on which diffs are landed (native_test + @@ -211,6 +219,9 @@ fi if [[ "${INDEX_OK}" == "1" && -x "${BIN_DIR}/webgpu_index_test" ]]; then "${BIN_DIR}/webgpu_index_test" "${INDEX_DIR}" fi +if [[ "${DYNAMIC_SHAPE_OK}" == "1" && -x "${BIN_DIR}/webgpu_dynamic_shape_test" ]]; then + "${BIN_DIR}/webgpu_dynamic_shape_test" "${DYNAMIC_SHAPE_DIR}" +fi [[ -x "${BIN_DIR}/webgpu_scratch_buffer_test" ]] && "${BIN_DIR}/webgpu_scratch_buffer_test" # Device-free: pure 2D workgroup-count fold unit test (no .pte, no GPU). [[ -x "${BIN_DIR}/webgpu_dispatch_2d_test" ]] && "${BIN_DIR}/webgpu_dispatch_2d_test" diff --git a/backends/webgpu/test/native/test_dispatch_order.cpp b/backends/webgpu/test/native/test_dispatch_order.cpp index 0f3eb5dea8e..d8aa627eff3 100644 --- a/backends/webgpu/test/native/test_dispatch_order.cpp +++ b/backends/webgpu/test/native/test_dispatch_order.cpp @@ -10,10 +10,13 @@ #include #include +#include + #include #include #include #include +#include #include #include #include @@ -24,23 +27,8 @@ using namespace executorch::runtime; namespace { -struct Case { - const char* name; - std::vector sizes; -}; - -// Mirrors _CASES in test_dispatch_order.py (add-chain or rms_norm+add chain). -const std::vector kCases = { - {"single", {16, 16}}, - {"chain3", {64, 64}}, - {"chain5_tiny", {1, 1}}, - {"chain5_wide", {7, 896}}, - {"chain8", {256, 256}}, - {"deep32", {128, 128}}, - {"large_chain", {1024, 1024}}, - {"het_small", {1, 1, 7, 896}}, - {"het_deep", {1, 1, 5, 256}}, -}; +// Artifacts directory; set from env/argv in main() before RUN_ALL_TESTS(). +std::string g_dir; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) std::vector read_f32_bin(const std::string& path) { std::ifstream f(path, std::ios::binary | std::ios::ate); @@ -59,53 +47,35 @@ std::vector read_f32_bin(const std::string& path) { return data; } -bool run_case(const std::string& dir, const Case& tc) { - printf("\n--- dispatch_order[%s] ---\n", tc.name); - const std::string base = dir + "/" + tc.name; +// Mirrors _CASES in test_dispatch_order.py (add-chain or rms_norm+add chain). +void run_case(const char* name, const std::vector& sizes) { + const std::string base = g_dir + "/" + name; std::vector input = read_f32_bin(base + ".input.bin"); std::vector golden = read_f32_bin(base + ".golden.bin"); - if (input.empty() || golden.empty()) { - printf("FAIL: could not read input/golden for %s\n", tc.name); - return false; - } + ASSERT_FALSE(input.empty() || golden.empty()) + << "could not read input/golden for " << name; Module module(base + ".pte"); - if (module.load_forward() != Error::Ok) { - printf("FAIL: could not load %s.pte\n", tc.name); - return false; - } + ASSERT_EQ(module.load_forward(), Error::Ok) + << "could not load " << name << ".pte"; size_t expected = 1; - for (int32_t d : tc.sizes) { + for (int32_t d : sizes) { expected *= static_cast(d); } - if (input.size() != expected) { - printf( - "FAIL: input numel %zu != expected %zu for %s\n", - input.size(), - expected, - tc.name); - return false; - } - auto x = make_tensor_ptr(tc.sizes, std::vector(input)); + ASSERT_EQ(input.size(), expected) + << "input numel " << input.size() << " != expected " << expected + << " for " << name; + auto x = make_tensor_ptr(sizes, std::vector(input)); auto result = module.forward({EValue(x)}); - if (!result.ok()) { - printf("FAIL: forward failed (error %d)\n", (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) << "forward failed (error " << (int)result.error() + << ")"; const auto& outputs = result.get(); - if (outputs.empty() || !outputs[0].isTensor()) { - printf("FAIL: no tensor output\n"); - return false; - } + ASSERT_TRUE(!outputs.empty() && outputs[0].isTensor()) << "no tensor output"; const auto& out_tensor = outputs[0].toTensor(); - if (static_cast(out_tensor.numel()) != golden.size()) { - printf( - "FAIL: output numel %zu != golden %zu\n", - (size_t)out_tensor.numel(), - golden.size()); - return false; - } + ASSERT_EQ(static_cast(out_tensor.numel()), golden.size()) + << "output numel " << (size_t)out_tensor.numel() << " != golden " + << golden.size(); const float* out_data = out_tensor.const_data_ptr(); float max_abs_err = 0.0f; @@ -116,52 +86,76 @@ bool run_case(const std::string& dir, const Case& tc) { const float denom = std::max(std::abs(golden[i]), 1e-6f); max_rel_err = std::max(max_rel_err, abs_err / denom); } - printf( - "Max abs error: %e Max rel error: %e (%zu elements)\n", - max_abs_err, - max_rel_err, - golden.size()); // Lenient gate: pass iff abs<=tol OR rel<=tol (near-zero goldens). - if (max_abs_err > 1e-3f && max_rel_err > 1e-3f) { - printf("FAIL: dispatch_order[%s] exceeds tolerance 1e-3\n", tc.name); - return false; - } - printf("PASS: dispatch_order[%s]\n", tc.name); - return true; + EXPECT_FALSE(max_abs_err > 1e-3f && max_rel_err > 1e-3f) + << "dispatch_order[" << name + << "] exceeds tolerance 1e-3 (max_abs_err=" << max_abs_err + << " max_rel_err=" << max_rel_err << ", " << golden.size() + << " elements)"; } } // namespace +TEST(DispatchOrder, single) { + run_case("single", {16, 16}); +} + +TEST(DispatchOrder, chain3) { + run_case("chain3", {64, 64}); +} + +TEST(DispatchOrder, chain5_tiny) { + run_case("chain5_tiny", {1, 1}); +} + +TEST(DispatchOrder, chain5_wide) { + run_case("chain5_wide", {7, 896}); +} + +TEST(DispatchOrder, chain8) { + run_case("chain8", {256, 256}); +} + +TEST(DispatchOrder, deep32) { + run_case("deep32", {128, 128}); +} + +TEST(DispatchOrder, large_chain) { + run_case("large_chain", {1024, 1024}); +} + +TEST(DispatchOrder, het_small) { + run_case("het_small", {1, 1, 7, 896}); +} + +TEST(DispatchOrder, het_deep) { + run_case("het_deep", {1, 1, 5, 256}); +} + int main(int argc, char** argv) { - std::string dir = "/tmp/dispatch_order"; + ::testing::InitGoogleTest(&argc, argv); + + // Artifacts dir: env wins, else first positional arg, else default (gtest + // flags were already stripped by InitGoogleTest above). + g_dir = "/tmp/dispatch_order"; if (argc > 1) { - dir = argv[1]; + g_dir = argv[1]; } if (const char* env = std::getenv("WEBGPU_DISPATCH_ORDER_DIR")) { - dir = env; + g_dir = env; } WebGPUContext ctx; try { ctx = create_webgpu_context(); } catch (const std::exception& e) { - printf("SKIP: %s\n", e.what()); + std::printf("SKIP: %s\n", e.what()); return 0; } set_default_webgpu_context(&ctx); - printf("WebGPU device acquired (native); case dir: %s\n", dir.c_str()); - - bool ok = true; - for (const auto& tc : kCases) { - ok = run_case(dir, tc) && ok; - } + const int rc = RUN_ALL_TESTS(); set_default_webgpu_context(nullptr); destroy_webgpu_context(ctx); - - if (!ok) { - return 1; - } - printf("\nAll dispatch_order tests passed\n"); - return 0; + return rc; } diff --git a/backends/webgpu/test/native/test_index.cpp b/backends/webgpu/test/native/test_index.cpp index aed24c0a796..91f4ec9ea01 100644 --- a/backends/webgpu/test/native/test_index.cpp +++ b/backends/webgpu/test/native/test_index.cpp @@ -10,10 +10,14 @@ #include #include +#include + #include #include +#include #include #include +#include #include #include #include @@ -24,13 +28,8 @@ using namespace executorch::runtime; namespace { -// Names mirror test_index.py CONFIGS (self/idx/golden bins written per case). -constexpr const char* kIndexCases[] = { - "index_n16_m5", - "index_n8_rev", - "index_n32_m3", - "index_n4_rep", -}; +// Artifacts directory; set from env/argv in main() before RUN_ALL_TESTS(). +std::string g_dir; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) std::vector read_f32_bin(const std::string& path) { std::ifstream f(path, std::ios::binary | std::ios::ate); @@ -62,22 +61,19 @@ std::vector read_i32_bin(const std::string& path) { return data; } -bool run_case(const std::string& dir, const char* name) { - printf("\n--- Test: %s ---\n", name); - const std::string base = dir + "/" + name; +// index.Tensor: self [n] float, idx [m] int64 -> output [m]. Names mirror +// test_index.py CONFIGS (self/idx/golden bins written per case). +void run_case(const char* name) { + const std::string base = g_dir + "/" + name; std::vector self_data = read_f32_bin(base + ".self.bin"); std::vector idx32 = read_i32_bin(base + ".idx.bin"); std::vector golden = read_f32_bin(base + ".golden.bin"); - if (self_data.empty() || idx32.empty() || golden.empty()) { - printf("FAIL: could not read self/idx/golden for %s\n", name); - return false; - } + ASSERT_FALSE(self_data.empty() || idx32.empty() || golden.empty()) + << "could not read self/idx/golden for " << name; Module module(base + ".pte"); - if (module.load_forward() != Error::Ok) { - printf("FAIL: could not load %s.pte\n", name); - return false; - } + ASSERT_EQ(module.load_forward(), Error::Ok) + << "could not load " << name << ".pte"; const int32_t n = static_cast(self_data.size()); const int32_t m = static_cast(idx32.size()); @@ -87,33 +83,21 @@ bool run_case(const std::string& dir, const char* name) { auto idx = make_tensor_ptr({m}, std::vector(idx64)); auto result = module.forward({EValue(x), EValue(idx)}); - if (!result.ok()) { - printf("FAIL: forward failed (error %d)\n", (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) << "forward failed (error " << (int)result.error() + << ")"; const auto& outputs = result.get(); // index.Tensor has exactly one output of shape [num_indices]; fail loud else. - if (outputs.size() != 1 || !outputs[0].isTensor()) { - printf("FAIL: expected exactly one tensor output\n"); - return false; - } + ASSERT_TRUE(outputs.size() == 1 && outputs[0].isTensor()) + << "expected exactly one tensor output"; const auto& out_tensor = outputs[0].toTensor(); - if (out_tensor.dim() != 1 || out_tensor.size(0) != m) { - printf( - "FAIL: output shape mismatch (dim %d size0 %d, expected [%d])\n", - (int)out_tensor.dim(), - (int)(out_tensor.dim() == 1 ? out_tensor.size(0) : -1), - m); - return false; - } - if (static_cast(out_tensor.numel()) != golden.size()) { - printf( - "FAIL: output numel %zu != golden %zu\n", - (size_t)out_tensor.numel(), - golden.size()); - return false; - } + ASSERT_TRUE(out_tensor.dim() == 1 && out_tensor.size(0) == m) + << "output shape mismatch (dim " << (int)out_tensor.dim() << " size0 " + << (int)(out_tensor.dim() == 1 ? out_tensor.size(0) : -1) + << ", expected [" << m << "])"; + ASSERT_EQ(static_cast(out_tensor.numel()), golden.size()) + << "output numel " << (size_t)out_tensor.numel() << " != golden " + << golden.size(); const float* out_data = out_tensor.const_data_ptr(); float max_abs_err = 0.0f; @@ -124,51 +108,54 @@ bool run_case(const std::string& dir, const char* name) { const float denom = std::max(std::abs(golden[i]), 1e-6f); max_rel_err = std::max(max_rel_err, abs_err / denom); } - printf( - "Max abs error: %e Max rel error: %e (%zu elements)\n", - max_abs_err, - max_rel_err, - golden.size()); - if (max_abs_err > 1e-3f || max_rel_err > 1e-3f) { - printf("FAIL: %s exceeds tolerance 1e-3\n", name); - return false; - } - printf("PASS: %s\n", name); - return true; + EXPECT_LE(max_abs_err, 1e-3f) << name << " max_abs_err=" << max_abs_err + << " (" << golden.size() << " elements)"; + EXPECT_LE(max_rel_err, 1e-3f) << name << " max_rel_err=" << max_rel_err + << " (" << golden.size() << " elements)"; } } // namespace +TEST(Index, N16M5) { + run_case("index_n16_m5"); +} + +TEST(Index, N8Rev) { + run_case("index_n8_rev"); +} + +TEST(Index, N32M3) { + run_case("index_n32_m3"); +} + +TEST(Index, N4Rep) { + run_case("index_n4_rep"); +} + int main(int argc, char** argv) { - std::string dir = "/tmp/index"; + ::testing::InitGoogleTest(&argc, argv); + + // Artifacts dir: env wins, else first positional arg, else default (gtest + // flags were already stripped by InitGoogleTest above). + g_dir = "/tmp/index"; if (argc > 1) { - dir = argv[1]; + g_dir = argv[1]; } if (const char* env = std::getenv("WEBGPU_INDEX_DIR")) { - dir = env; + g_dir = env; } WebGPUContext ctx; try { ctx = create_webgpu_context(); } catch (const std::exception& e) { - printf("SKIP: %s\n", e.what()); + std::printf("SKIP: %s\n", e.what()); return 0; } set_default_webgpu_context(&ctx); - printf("WebGPU device acquired (native); case dir: %s\n", dir.c_str()); - - bool ok = true; - for (const char* name : kIndexCases) { - ok = run_case(dir, name) && ok; - } + const int rc = RUN_ALL_TESTS(); set_default_webgpu_context(nullptr); destroy_webgpu_context(ctx); - - if (!ok) { - return 1; - } - printf("\nAll index tests passed\n"); - return 0; + return rc; } diff --git a/backends/webgpu/test/native/test_scratch_buffer.cpp b/backends/webgpu/test/native/test_scratch_buffer.cpp index 7a4df6e9d00..98cf3648c6b 100644 --- a/backends/webgpu/test/native/test_scratch_buffer.cpp +++ b/backends/webgpu/test/native/test_scratch_buffer.cpp @@ -14,17 +14,27 @@ #include +#include + #include #include #include -#include #include +#include #include using namespace executorch::backends::webgpu; namespace { +// WebGPU context; set from create_webgpu_context() in main() before +// RUN_ALL_TESTS(). +// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) +WGPUInstance g_instance = nullptr; +WGPUDevice g_device = nullptr; +WGPUQueue g_queue = nullptr; +// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) + struct MapCb { std::atomic status{WGPUMapAsyncStatus_Error}; }; @@ -79,56 +89,46 @@ std::vector readback( return out; } +} // namespace + // Tier 1: allocation, zero-size guard, distinct non-null handles. -bool tier1_alloc(WGPUDevice device) { - printf("\n--- scratch[tier1: allocation] ---\n"); +TEST(ScratchBuffer, Tier1Alloc) { WebGPUGraph g; - g.set_device(device); + g.set_device(g_device); WGPUBuffer a = g.create_scratch_buffer(64 * sizeof(float)); WGPUBuffer z = g.create_scratch_buffer(0); // guarded to 4 bytes WGPUBuffer b = g.create_scratch_buffer(64 * sizeof(float)); - const bool ok = a && z && b && a != b && a != z && b != z; - printf(ok ? "PASS: allocation\n" : "FAIL: allocation\n"); - return ok; // graph dtor releases all three here + EXPECT_TRUE(a && z && b && a != b && a != z && b != z); + // graph dtor releases all three here } // Tier 2: host->scratch write, scratch->staging copy, read-back round-trip. -bool tier2_roundtrip( - WGPUInstance instance, - WGPUDevice device, - WGPUQueue queue) { - printf("\n--- scratch[tier2: copy round-trip] ---\n"); - bool ok = true; +TEST(ScratchBuffer, Tier2Roundtrip) { for (int n : {1, 7, 1024}) { WebGPUGraph g; - g.set_device(device); + g.set_device(g_device); WGPUBuffer s = g.create_scratch_buffer(n * sizeof(float)); std::vector in(n); for (int i = 0; i < n; i++) { in[i] = static_cast(i) * 0.5f + 1.0f; } - wgpuQueueWriteBuffer(queue, s, 0, in.data(), n * sizeof(float)); + wgpuQueueWriteBuffer(g_queue, s, 0, in.data(), n * sizeof(float)); std::vector back = - readback(instance, device, queue, s, n * sizeof(float)); + readback(g_instance, g_device, g_queue, s, n * sizeof(float)); float max_err = 0.0f; for (int i = 0; i < n; i++) { max_err = std::max(max_err, std::abs(back[i] - in[i])); } - printf(" n=%d max abs error %e\n", n, max_err); - if (max_err != 0.0f) { // pure copy: must be bit-exact - ok = false; - } + // pure copy: must be bit-exact + EXPECT_EQ(max_err, 0.0f) << "n=" << n << " max abs error " << max_err; } - printf(ok ? "PASS: copy round-trip\n" : "FAIL: copy round-trip\n"); - return ok; } // Tier 3a: bind scratch as a Storage buffer in a compute pass (its real use). -bool tier3_compute(WGPUInstance instance, WGPUDevice device, WGPUQueue queue) { - printf("\n--- scratch[tier3: compute Storage round-trip] ---\n"); +TEST(ScratchBuffer, Tier3Compute) { const int n = 256; WebGPUGraph g; - g.set_device(device); + g.set_device(g_device); WGPUBuffer s = g.create_scratch_buffer(n * sizeof(float)); const char* kWgsl = @@ -144,7 +144,7 @@ bool tier3_compute(WGPUInstance instance, WGPUDevice device, WGPUQueue queue) { wgsl.code = {kWgsl, WGPU_STRLEN}; WGPUShaderModuleDescriptor smd = {}; smd.nextInChain = &wgsl.chain; - WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &smd); + WGPUShaderModule shader = wgpuDeviceCreateShaderModule(g_device, &smd); WGPUBindGroupLayoutEntry ble = {}; ble.binding = 0; @@ -153,18 +153,18 @@ bool tier3_compute(WGPUInstance instance, WGPUDevice device, WGPUQueue queue) { WGPUBindGroupLayoutDescriptor bld = {}; bld.entryCount = 1; bld.entries = &ble; - WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bld); + WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(g_device, &bld); WGPUPipelineLayoutDescriptor pld = {}; pld.bindGroupLayoutCount = 1; pld.bindGroupLayouts = &bgl; - WGPUPipelineLayout pl = wgpuDeviceCreatePipelineLayout(device, &pld); + WGPUPipelineLayout pl = wgpuDeviceCreatePipelineLayout(g_device, &pld); WGPUComputePipelineDescriptor cpd = {}; cpd.layout = pl; cpd.compute.module = shader; cpd.compute.entryPoint = {"main", WGPU_STRLEN}; - WGPUComputePipeline pipe = wgpuDeviceCreateComputePipeline(device, &cpd); + WGPUComputePipeline pipe = wgpuDeviceCreateComputePipeline(g_device, &cpd); WGPUBindGroupEntry bge = {}; bge.binding = 0; @@ -174,10 +174,10 @@ bool tier3_compute(WGPUInstance instance, WGPUDevice device, WGPUQueue queue) { bgd.layout = bgl; bgd.entryCount = 1; bgd.entries = &bge; - WGPUBindGroup bg = wgpuDeviceCreateBindGroup(device, &bgd); + WGPUBindGroup bg = wgpuDeviceCreateBindGroup(g_device, &bgd); WGPUCommandEncoderDescriptor ed = {}; - WGPUCommandEncoder enc = wgpuDeviceCreateCommandEncoder(device, &ed); + WGPUCommandEncoder enc = wgpuDeviceCreateCommandEncoder(g_device, &ed); WGPUComputePassDescriptor pd = {}; WGPUComputePassEncoder pass = wgpuCommandEncoderBeginComputePass(enc, &pd); wgpuComputePassEncoderSetPipeline(pass, pipe); @@ -187,18 +187,17 @@ bool tier3_compute(WGPUInstance instance, WGPUDevice device, WGPUQueue queue) { wgpuComputePassEncoderRelease(pass); WGPUCommandBufferDescriptor cd = {}; WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(enc, &cd); - wgpuQueueSubmit(queue, 1, &cmd); + wgpuQueueSubmit(g_queue, 1, &cmd); wgpuCommandBufferRelease(cmd); wgpuCommandEncoderRelease(enc); std::vector back = - readback(instance, device, queue, s, n * sizeof(float)); + readback(g_instance, g_device, g_queue, s, n * sizeof(float)); float max_err = 0.0f; for (int i = 0; i < n; i++) { const float expected = static_cast(i) * 2.0f + 1.0f; max_err = std::max(max_err, std::abs(back[i] - expected)); } - printf(" max abs error %e (%d elements)\n", max_err, n); wgpuBindGroupRelease(bg); wgpuComputePipelineRelease(pipe); @@ -206,56 +205,40 @@ bool tier3_compute(WGPUInstance instance, WGPUDevice device, WGPUQueue queue) { wgpuBindGroupLayoutRelease(bgl); wgpuShaderModuleRelease(shader); - const bool ok = max_err == 0.0f; - printf( - ok ? "PASS: compute Storage round-trip\n" : "FAIL: compute round-trip\n"); - return ok; + EXPECT_EQ(max_err, 0.0f) << "max abs error " << max_err << " (" << n + << " elements)"; } // Tier 3b: many scratch buffers across repeated graphs; dtor must release all. -bool tier3_lifecycle(WGPUDevice device) { - printf("\n--- scratch[tier3: lifecycle/stress] ---\n"); - bool ok = true; +TEST(ScratchBuffer, Tier3Lifecycle) { for (int iter = 0; iter < 50; iter++) { WebGPUGraph g; - g.set_device(device); + g.set_device(g_device); for (int k = 0; k < 256; k++) { WGPUBuffer b = g.create_scratch_buffer(static_cast(k % 17) * sizeof(float)); - ok = ok && b != nullptr; + EXPECT_NE(b, nullptr); } } // each graph's dtor releases its 256 buffers here - printf( - ok ? "PASS: lifecycle/stress (50 graphs x 256 buffers)\n" - : "FAIL: lifecycle/stress (null buffer)\n"); - return ok; } -} // namespace +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); -int main() { WebGPUContext ctx; try { ctx = create_webgpu_context(); } catch (const std::exception& e) { - printf("SKIP: %s\n", e.what()); + std::printf("SKIP: %s\n", e.what()); return 0; } set_default_webgpu_context(&ctx); - printf("WebGPU device acquired (native)\n"); - - bool ok = true; - ok = tier1_alloc(ctx.device) && ok; - ok = tier2_roundtrip(ctx.instance, ctx.device, ctx.queue) && ok; - ok = tier3_compute(ctx.instance, ctx.device, ctx.queue) && ok; - ok = tier3_lifecycle(ctx.device) && ok; + g_instance = ctx.instance; + g_device = ctx.device; + g_queue = ctx.queue; + const int rc = RUN_ALL_TESTS(); set_default_webgpu_context(nullptr); destroy_webgpu_context(ctx); - - if (!ok) { - return 1; - } - printf("\nAll scratch_buffer tests passed\n"); - return 0; + return rc; } diff --git a/backends/webgpu/test/native/test_update_cache.cpp b/backends/webgpu/test/native/test_update_cache.cpp index 3f932ea7f03..dad859af669 100644 --- a/backends/webgpu/test/native/test_update_cache.cpp +++ b/backends/webgpu/test/native/test_update_cache.cpp @@ -10,10 +10,13 @@ #include #include +#include + #include #include #include #include +#include #include #include @@ -23,6 +26,9 @@ using namespace executorch::runtime; namespace { +// Artifacts directory; set from env/argv in main() before RUN_ALL_TESTS(). +std::string g_dir; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) + struct UpdateCacheCase { const char* name; int s; @@ -40,20 +46,10 @@ constexpr UpdateCacheCase kCases[] = { {"shape_b_offset", 3, 4, 8, 16, 10}, }; -bool run_case(const std::string& dir, const UpdateCacheCase& tc) { - printf( - "\n--- Test: update_cache[%s] (S=%d,H=%d,D=%d,Cmax=%d,pos=%d) ---\n", - tc.name, - tc.s, - tc.h, - tc.d, - tc.cmax, - tc.input_pos); - Module module(dir + "/" + tc.name + ".pte"); - if (module.load_forward() != Error::Ok) { - printf("FAIL: could not load %s.pte\n", tc.name); - return false; - } +void run_case(const UpdateCacheCase& tc) { + Module module(g_dir + "/" + tc.name + ".pte"); + ASSERT_EQ(module.load_forward(), Error::Ok) + << "could not load " << tc.name << ".pte"; const int vnumel = tc.s * tc.h * tc.d; const int cnumel = tc.cmax * tc.h * tc.d; @@ -79,37 +75,24 @@ bool run_case(const std::string& dir, const UpdateCacheCase& tc) { auto v = make_tensor_ptr({1, tc.s, tc.h, tc.d}, std::vector(value)); auto c = make_tensor_ptr({1, tc.cmax, tc.h, tc.d}, std::vector(cache)); auto result = module.forward({EValue(v), EValue(c)}); - if (!result.ok()) { - printf("FAIL: forward failed (error %d)\n", (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) << "forward failed (error " << (int)result.error() + << ")"; const auto& outputs = result.get(); - if (outputs.empty() || !outputs[0].isTensor()) { - printf("FAIL: no tensor output\n"); - return false; - } + ASSERT_TRUE(!outputs.empty() && outputs[0].isTensor()) << "no tensor output"; const auto& out_tensor = outputs[0].toTensor(); - if (static_cast(out_tensor.numel()) != cnumel) { - printf( - "FAIL: output numel %zu != expected %d\n", - (size_t)out_tensor.numel(), - cnumel); - return false; - } + ASSERT_EQ(static_cast(out_tensor.numel()), cnumel) + << "output numel " << (size_t)out_tensor.numel() << " != expected " + << cnumel; const float* out_data = out_tensor.const_data_ptr(); float max_abs_err = 0.0f; for (int i = 0; i < cnumel; i++) { max_abs_err = std::max(max_abs_err, std::abs(out_data[i] - ref[i])); } - printf("Max abs error: %e (checked %d elements)\n", max_abs_err, cnumel); // update_cache is a pure scatter copy: the output must be bit-exact. - if (max_abs_err > 0.0f) { - printf("FAIL: update_cache[%s] not bit-exact\n", tc.name); - return false; - } - printf("PASS: update_cache[%s]\n", tc.name); - return true; + EXPECT_EQ(max_abs_err, 0.0f) + << "update_cache[" << tc.name << "] not bit-exact (max abs error " + << max_abs_err << ", checked " << cnumel << " elements)"; } struct ReplayCase { @@ -120,18 +103,11 @@ struct ReplayCase { }; // Multi-step advancing-input_pos cache accumulation, mirroring VulkanSDPATest. -bool run_replay(const std::string& dir, const ReplayCase& rc) { +void run_replay(const ReplayCase& rc) { int cmax = 0; for (int s : rc.seq_lens) { cmax += s; } - printf( - "\n--- Replay: update_cache[%s] (H=%d,D=%d,Cmax=%d,%zu steps) ---\n", - rc.name, - rc.h, - rc.d, - cmax, - rc.seq_lens.size()); const int cnumel = cmax * rc.h * rc.d; std::vector cache(cnumel); @@ -141,7 +117,6 @@ bool run_replay(const std::string& dir, const ReplayCase& rc) { std::vector ref(cache); int input_pos = 0; - bool ok = true; for (size_t step = 0; step < rc.seq_lens.size(); step++) { const int s = rc.seq_lens[step]; const int vnumel = s * rc.h * rc.d; @@ -151,31 +126,22 @@ bool run_replay(const std::string& dir, const ReplayCase& rc) { value[i] = (base + static_cast(i)) * 0.25f; } - const std::string fname = dir + "/" + rc.name + "_step" + + const std::string fname = g_dir + "/" + rc.name + "_step" + std::to_string(step) + "_S" + std::to_string(s) + "_pos" + std::to_string(input_pos) + ".pte"; Module module(fname); - if (module.load_forward() != Error::Ok) { - printf("FAIL: could not load %s\n", fname.c_str()); - return false; - } + ASSERT_EQ(module.load_forward(), Error::Ok) << "could not load " << fname; auto v = make_tensor_ptr({1, s, rc.h, rc.d}, std::vector(value)); auto c = make_tensor_ptr({1, cmax, rc.h, rc.d}, std::vector(cache)); auto result = module.forward({EValue(v), EValue(c)}); - if (!result.ok()) { - printf( - "FAIL: forward failed step %zu (error %d)\n", - step, - (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) << "forward failed step " << step << " (error " + << (int)result.error() << ")"; const auto& outputs = result.get(); - if (outputs.empty() || !outputs[0].isTensor() || - static_cast(outputs[0].toTensor().numel()) != cnumel) { - printf("FAIL: bad cache output at step %zu\n", step); - return false; - } + ASSERT_TRUE( + !outputs.empty() && outputs[0].isTensor() && + static_cast(outputs[0].toTensor().numel()) == cnumel) + << "bad cache output at step " << step; const float* out_data = outputs[0].toTensor().const_data_ptr(); const int dst_offset = input_pos * rc.h * rc.d; @@ -190,24 +156,12 @@ bool run_replay(const std::string& dir, const ReplayCase& rc) { max_abs_err = std::max(max_abs_err, std::abs(out_data[i] - ref[i])); cache[i] = out_data[i]; // thread the accumulated cache into the next step } - printf( - " step %zu (S=%d,pos=%d): max abs error %e\n", - step, - s, - input_pos, - max_abs_err); - if (max_abs_err > 0.0f) { // pure scatter copy: must be bit-exact - ok = false; - } + // pure scatter copy: must be bit-exact + EXPECT_EQ(max_abs_err, 0.0f) + << "step " << step << " (S=" << s << ",pos=" << input_pos + << "): max abs error " << max_abs_err; input_pos += s; } - - if (ok) { - printf("PASS: update_cache[%s] replay\n", rc.name); - } else { - printf("FAIL: update_cache[%s] replay\n", rc.name); - } - return ok; } struct NegativeCase { @@ -216,76 +170,75 @@ struct NegativeCase { }; // Single-op, single-guard-violation cases: rejection maps to the named guard. -bool run_negative_case(const std::string& dir, const NegativeCase& nc) { - printf( - "\n--- Negative: update_cache[%s] (expect rejection: %s) ---\n", - nc.name, - nc.guard); - Module module(dir + "/" + nc.name + ".pte"); +void run_negative_case(const NegativeCase& nc) { + Module module(g_dir + "/" + nc.name + ".pte"); const Error err = module.load_forward(); // init catches the guard throw -> this code; other errors = setup failure. - if (err != Error::DelegateInvalidCompatibility) { - printf( - "FAIL: %s.pte -> error %d; expected DelegateInvalidCompatibility " - "from the '%s' guard\n", - nc.name, - (int)err, - nc.guard); - return false; - } - printf("PASS: rejected with DelegateInvalidCompatibility (%s)\n", nc.guard); - return true; + EXPECT_EQ(err, Error::DelegateInvalidCompatibility) + << nc.name << ".pte -> error " << (int)err + << "; expected DelegateInvalidCompatibility from the '" << nc.guard + << "' guard"; } } // namespace -int main(int argc, char** argv) { - std::string dir = "/tmp/update_cache"; - if (argc > 1) { - dir = argv[1]; - } - if (const char* env = std::getenv("WEBGPU_UPDATE_CACHE_DIR")) { - dir = env; - } - - WebGPUContext ctx; - try { - ctx = create_webgpu_context(); - } catch (const std::exception& e) { - printf("SKIP: %s\n", e.what()); - return 0; - } - set_default_webgpu_context(&ctx); - printf("WebGPU device acquired (native); case dir: %s\n", dir.c_str()); - - bool ok = true; +// Single-step scatter cases (prefill / offset / shape variants): the op output +// must equal the inline integer-exact scatter reference. +TEST(UpdateCache, ScatterCases) { for (const auto& tc : kCases) { - ok = run_case(dir, tc) && ok; + run_case(tc); } +} +// Multi-step advancing-input_pos cache accumulation, mirroring VulkanSDPATest. +TEST(UpdateCache, Replay) { const std::vector kReplays = { {"seqA", 4, 4, {3, 1, 1, 5, 1, 1, 2}}, {"seqB", 2, 8, {3, 1, 1, 5, 1, 1}}, {"llama3", 8, 128, {111, 1, 1, 1, 57, 1, 1}}, }; for (const auto& rc : kReplays) { - ok = run_replay(dir, rc) && ok; + run_replay(rc); } +} +// Guard-violation cases: each must be rejected with +// DelegateInvalidCompatibility. +TEST(UpdateCache, Negative) { const NegativeCase kNegatives[] = { {"neg_batch", "batch must be 1"}, {"neg_fp16", "fp32-only"}, }; for (const auto& nc : kNegatives) { - ok = run_negative_case(dir, nc) && ok; + run_negative_case(nc); } +} - set_default_webgpu_context(nullptr); - destroy_webgpu_context(ctx); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); - if (!ok) { - return 1; + // Artifacts dir: env wins, else first positional arg, else default (gtest + // flags were already stripped by InitGoogleTest above). + std::string dir = "/tmp/update_cache"; + if (argc > 1) { + dir = argv[1]; } - printf("\nAll update_cache tests passed\n"); - return 0; + if (const char* env = std::getenv("WEBGPU_UPDATE_CACHE_DIR")) { + dir = env; + } + g_dir = dir; + + WebGPUContext ctx; + try { + ctx = create_webgpu_context(); + } catch (const std::exception& e) { + std::printf("SKIP: %s\n", e.what()); + return 0; + } + set_default_webgpu_context(&ctx); + + const int rc = RUN_ALL_TESTS(); + set_default_webgpu_context(nullptr); + destroy_webgpu_context(ctx); + return rc; } diff --git a/backends/webgpu/test/test_build_webgpu.sh b/backends/webgpu/test/test_build_webgpu.sh index 2fd1dea1a52..1c79d17cf06 100755 --- a/backends/webgpu/test/test_build_webgpu.sh +++ b/backends/webgpu/test/test_build_webgpu.sh @@ -85,6 +85,7 @@ cmake \ -DEXECUTORCH_BUILD_WEBGPU=ON \ -DDawn_DIR="${Dawn_DIR}" \ -DEXECUTORCH_BUILD_WEBGPU_TEST=ON \ + -DEXECUTORCH_BUILD_TESTS=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index cb6d491f4ba..556eb0127b4 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -12,10 +12,14 @@ #include #include +#include + #include #include +#include #include #include +#include #include #include #include @@ -24,27 +28,31 @@ using namespace executorch::backends::webgpu; using namespace executorch::extension; using namespace executorch::runtime; +namespace { + +// Environment-derived config; captured in main() before RUN_ALL_TESTS(). +// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) +std::string g_update_cache_model_path; +std::string g_qlinear_dir; +std::string g_prepack_model_path, g_prepack_golden_path; +std::string g_prepack2_model_path, g_prepack2_golden_path; +std::string g_prepack_tied_model_path, g_prepack_tied_golden_path; +std::string g_sdpa_dir; +std::string g_symint_blob; +// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) + #ifdef WGPU_BACKEND_ENABLE_PROFILING // Capacity-overrun must throw; runs without a device or TimestampQuery. -static bool test_query_pool_overrun_throws() { - printf("\n--- Test: WebGPUQueryPool capacity-overrun guard ---\n"); +void test_query_pool_overrun_throws() { WebGPUQueryPool qp; - try { - qp.reset(1); - } catch (const std::exception&) { - printf("PASS: reset beyond capacity throws\n"); - return true; - } - printf("FAIL: reset beyond capacity did not throw\n"); - return false; + EXPECT_THROW(qp.reset(1), std::exception) + << "reset beyond capacity did not throw"; } // WebGPUQueryPool roundtrip: time a probe pass; assert non-zero GPU duration. -static bool test_query_pool_roundtrip(const WebGPUContext& ctx) { - printf("\n--- Test: WebGPUQueryPool roundtrip ---\n"); +void test_query_pool_roundtrip(const WebGPUContext& ctx) { if (!ctx.timestamp_supported) { - printf("SKIP: adapter lacks TimestampQuery feature\n"); - return true; + GTEST_SKIP() << "adapter lacks TimestampQuery feature"; } WGPUDevice device = ctx.device; @@ -134,33 +142,21 @@ static bool test_query_pool_roundtrip(const WebGPUContext& ctx) { wgpuBindGroupRelease(bg); wgpuShaderModuleRelease(shader); - if (qp.results().size() != 1) { - printf("FAIL: expected 1 duration, got %zu\n", qp.results().size()); - return false; - } + ASSERT_EQ(qp.results().size(), 1u) + << "expected 1 duration, got " << qp.results().size(); const uint64_t dur = qp.results()[0].execution_duration_ns; printf(" probe duration: %llu ns\n", (unsigned long long)dur); - if (dur == 0) { - printf("FAIL: probe duration is zero (expected monotonic non-zero)\n"); - return false; - } - printf("PASS: WebGPUQueryPool roundtrip -- non-zero GPU kernel duration\n"); - return true; + EXPECT_NE(dur, 0u) << "probe duration is zero (expected monotonic non-zero)"; } #endif // WGPU_BACKEND_ENABLE_PROFILING -static bool test_update_cache(const std::string& model_path) { +void test_update_cache(const std::string& model_path) { // update_cache: value [1,2,2,4] scattered into cache [1,8,2,4] at // input_pos=0. - printf( - "\n--- Test: update_cache (value[1,2,2,4] -> cache[1,8,2,4], pos=0) ---\n"); - Module module(model_path); auto err = module.load_forward(); - if (err != Error::Ok) { - printf("FAIL: could not load forward method (error %d)\n", (int)err); - return false; - } + ASSERT_EQ(err, Error::Ok) + << "could not load forward method (error " << (int)err << ")"; printf("Model loaded: %s\n", model_path.c_str()); constexpr int S = 2, H = 2, D = 4, Cmax = 8; @@ -187,24 +183,15 @@ static bool test_update_cache(const std::string& model_path) { auto v = make_tensor_ptr({1, S, H, D}, std::vector(value)); auto c = make_tensor_ptr({1, Cmax, H, D}, std::vector(cache)); auto result = module.forward({EValue(v), EValue(c)}); - if (!result.ok()) { - printf("FAIL: forward failed (error %d)\n", (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) << "forward failed (error " << (int)result.error() + << ")"; const auto& outputs = result.get(); - if (outputs.empty() || !outputs[0].isTensor()) { - printf("FAIL: no tensor output\n"); - return false; - } + ASSERT_TRUE(!outputs.empty() && outputs[0].isTensor()) << "no tensor output"; const auto& out_tensor = outputs[0].toTensor(); - if (out_tensor.numel() != cnumel) { - printf( - "FAIL: output numel %zu != expected %d\n", - (size_t)out_tensor.numel(), - cnumel); - return false; - } + ASSERT_EQ((int)out_tensor.numel(), cnumel) + << "output numel " << (size_t)out_tensor.numel() << " != expected " + << cnumel; const float* out_data = out_tensor.const_data_ptr(); float max_abs_err = 0.0f; @@ -212,15 +199,10 @@ static bool test_update_cache(const std::string& model_path) { max_abs_err = std::max(max_abs_err, std::abs(out_data[i] - ref[i])); } printf("Max abs error: %e (checked %d elements)\n", max_abs_err, cnumel); - if (max_abs_err > 1e-3f) { - printf("FAIL: max error exceeds tolerance 1e-3\n"); - return false; - } - printf("PASS: update_cache test\n"); - return true; + EXPECT_LE(max_abs_err, 1e-3f) << "max error exceeds tolerance 1e-3"; } -static std::vector load_golden(const std::string& path, size_t numel) { +std::vector load_golden(const std::string& path, size_t numel) { // Load a raw little-endian fp32 golden written by the export .py (the native // binary has no ATen/torch, so the reference is computed offline). std::vector g(numel); @@ -241,7 +223,7 @@ static std::vector load_golden(const std::string& path, size_t numel) { // value can't blow up the rel metric (the kernel's ~1e-8 abs error is the real // signal at llama3 scale). Sets the reported maxima; true iff all elements // pass. -static bool sdpa_within_tol( +bool sdpa_within_tol( const float* out, const float* golden, int n, @@ -277,7 +259,7 @@ struct Q4gswConfig { // Llama-3.2-1B linear shapes (q/o/k/v/gate/up/down + lm_head) + 4k/8k prefill. // tol scales with K (fp32 accum depth), not M; down_proj (K=8192) is looser. -static const Q4gswConfig kQ4gswConfigs[] = { +const Q4gswConfig kQ4gswConfigs[] = { // name M K N tol_abs tol_rel req heavy {"q_proj", 1, 2048, 2048, 1e-4f, 1e-3f, true, false}, {"kv_proj", 1, 2048, 512, 1e-4f, 1e-3f, true, false}, @@ -299,23 +281,36 @@ static const Q4gswConfig kQ4gswConfigs[] = { }; // /16 ramp over the flat index; mirrors test_quantized_linear.py _ramp_input. -static float q4gsw_ramp(int i) { +float q4gsw_ramp(int i) { return static_cast((i % 17) - 8) / 16.0f; } -// Fwd decl of the per-element abs-OR-rel tolerance helper (defined below). -static bool quant_within_tol( +// Per-element abs-OR-rel tolerance helper. +bool quant_within_tol( const float* out, const float* golden, int n, float atol, float rtol, float* ma, - float* mr); + float* mr) { + float max_abs = 0.0f, max_rel = 0.0f; + bool ok = true; + for (int i = 0; i < n; i++) { + const float ae = std::abs(out[i] - golden[i]); + const float re = ae / std::max(std::abs(golden[i]), 1e-6f); + max_abs = std::max(max_abs, ae); + max_rel = std::max(max_rel, re); + if (ae > atol && re > rtol) { + ok = false; + } + } + *ma = max_abs; + *mr = max_rel; + return ok; +} -static std::vector load_indices( - const std::string& path, - size_t numel) { +std::vector load_indices(const std::string& path, size_t numel) { // Load raw little-endian int32 indices written by the export .py. std::vector g(numel); FILE* f = std::fopen(path.c_str(), "rb"); @@ -330,7 +325,7 @@ static std::vector load_indices( return g; } -static bool test_embedding_q4gsw( +void test_embedding_q4gsw( const std::string& model_path, const std::string& indices_path, const std::string& golden_path, @@ -347,44 +342,29 @@ static bool test_embedding_q4gsw( Module module(model_path); auto err = module.load_forward(); - if (err != Error::Ok) { - printf("FAIL: could not load forward method (error %d)\n", (int)err); - return false; - } + ASSERT_EQ(err, Error::Ok) + << "could not load forward method (error " << (int)err << ")"; printf("Model loaded: %s\n", model_path.c_str()); std::vector idx32 = load_indices(indices_path, num_indices); std::vector golden = load_golden(golden_path, out_numel); - if (idx32.empty() || golden.empty()) { - printf( - "FAIL: could not load indices %s / golden %s\n", - indices_path.c_str(), - golden_path.c_str()); - return false; - } + ASSERT_FALSE(idx32.empty() || golden.empty()) + << "could not load indices " << indices_path << " / golden " + << golden_path; // int64 at the program boundary; copy_inputs narrows to the int32 buffer. std::vector idx64(idx32.begin(), idx32.end()); auto idx = make_tensor_ptr({num_indices}, std::move(idx64)); auto result = module.forward({EValue(idx)}); - if (!result.ok()) { - printf("FAIL: forward failed (error %d)\n", (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) << "forward failed (error " << (int)result.error() + << ")"; const auto& outputs = result.get(); - if (outputs.empty() || !outputs[0].isTensor()) { - printf("FAIL: no tensor output\n"); - return false; - } + ASSERT_TRUE(!outputs.empty() && outputs[0].isTensor()) << "no tensor output"; const auto& out_tensor = outputs[0].toTensor(); - if (out_tensor.numel() != out_numel) { - printf( - "FAIL: output numel %zu != expected %d\n", - (size_t)out_tensor.numel(), - out_numel); - return false; - } + ASSERT_EQ((int)out_tensor.numel(), out_numel) + << "output numel " << (size_t)out_tensor.numel() << " != expected " + << out_numel; const float* out_data = out_tensor.const_data_ptr(); float max_abs_err = 0.0f, max_rel_err = 0.0f; @@ -401,39 +381,10 @@ static bool test_embedding_q4gsw( max_abs_err, max_rel_err, out_numel); - if (!pass) { - printf("FAIL: embedding_q4gsw exceeds tolerance 1e-3 (abs AND rel)\n"); - return false; - } - printf("PASS: embedding_q4gsw test\n"); - return true; -} - -static bool quant_within_tol( - const float* out, - const float* golden, - int n, - float atol, - float rtol, - float* ma, - float* mr) { - float max_abs = 0.0f, max_rel = 0.0f; - bool ok = true; - for (int i = 0; i < n; i++) { - const float ae = std::abs(out[i] - golden[i]); - const float re = ae / std::max(std::abs(golden[i]), 1e-6f); - max_abs = std::max(max_abs, ae); - max_rel = std::max(max_rel, re); - if (ae > atol && re > rtol) { - ok = false; - } - } - *ma = max_abs; - *mr = max_rel; - return ok; + EXPECT_TRUE(pass) << "embedding_q4gsw exceeds tolerance 1e-3 (abs AND rel)"; } -static bool test_rope( +void test_rope( const std::string& model_path, const std::string& xq_golden_path, const std::string& xk_golden_path, @@ -456,10 +407,8 @@ static bool test_rope( Module module(model_path); auto err = module.load_forward(); - if (err != Error::Ok) { - printf("FAIL: could not load forward method (error %d)\n", (int)err); - return false; - } + ASSERT_EQ(err, Error::Ok) + << "could not load forward method (error " << (int)err << ")"; printf("Model loaded: %s\n", model_path.c_str()); // ((i % mod) - off) / 16: exact in fp32, matches test_rope.py::_ramp. @@ -486,43 +435,30 @@ static bool test_rope( auto result = module.forward({EValue(xqt), EValue(xkt), EValue(fct), EValue(fst)}); - if (!result.ok()) { - printf("FAIL: forward failed (error %d)\n", (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) << "forward failed (error " << (int)result.error() + << ")"; const auto& outputs = result.get(); // Outputs in graph order [0]=xq_out, [1]=xk_out (positional; the numel check // below guards a swap, since NH != NKV under GQA). - if (outputs.size() < 2 || !outputs[0].isTensor() || !outputs[1].isTensor()) { - printf("FAIL: expected 2 tensor outputs, got %zu\n", outputs.size()); - return false; - } + ASSERT_TRUE( + outputs.size() >= 2 && outputs[0].isTensor() && outputs[1].isTensor()) + << "expected 2 tensor outputs, got " << outputs.size(); const auto& xq_t = outputs[0].toTensor(); const auto& xk_t = outputs[1].toTensor(); - if (xq_t.numel() != xq_numel || xk_t.numel() != xk_numel) { - printf( - "FAIL: output shapes [%zu,%zu] != expected [%d,%d]\n", - (size_t)xq_t.numel(), - (size_t)xk_t.numel(), - xq_numel, - xk_numel); - return false; - } + ASSERT_TRUE(xq_t.numel() == xq_numel && xk_t.numel() == xk_numel) + << "output shapes [" << (size_t)xq_t.numel() << "," + << (size_t)xk_t.numel() << "] != expected [" << xq_numel << "," + << xk_numel << "]"; const float* xq_out = xq_t.const_data_ptr(); const float* xk_out = xk_t.const_data_ptr(); std::vector gq = load_golden(xq_golden_path, xq_numel); std::vector gk = load_golden(xk_golden_path, xk_numel); - if (gq.empty() || gk.empty()) { - printf( - "FAIL: could not load goldens %s / %s\n", - xq_golden_path.c_str(), - xk_golden_path.c_str()); - return false; - } + ASSERT_FALSE(gq.empty() || gk.empty()) + << "could not load goldens " << xq_golden_path << " / " << xk_golden_path; - // Per-element abs-OR-rel on xq and xk (shared helper, defined above). + // Per-element abs-OR-rel on xq and xk (shared helper). float maq = 0.0f, mrq = 0.0f, mak = 0.0f, mrk = 0.0f; const bool pass_q = quant_within_tol(xq_out, gq.data(), xq_numel, 1e-3f, 1e-3f, &maq, &mrq); @@ -536,15 +472,11 @@ static bool test_rope( max_abs_err, max_rel_err, xq_numel + xk_numel); - if (!(pass_q && pass_k)) { - printf("FAIL: apply_rotary_emb exceeds tolerance 1e-3 (abs AND rel)\n"); - return false; - } - printf("PASS: apply_rotary_emb test\n"); - return true; + EXPECT_TRUE(pass_q && pass_k) + << "apply_rotary_emb exceeds tolerance 1e-3 (abs AND rel)"; } -static bool test_prepack( +void test_prepack( const std::string& model_path, const std::string& golden_path, const std::string& label = "x + const w") { @@ -555,17 +487,12 @@ static bool test_prepack( Module module(model_path); auto err = module.load_forward(); - if (err != Error::Ok) { - printf("FAIL: could not load forward method (error %d)\n", (int)err); - return false; - } + ASSERT_EQ(err, Error::Ok) + << "could not load forward method (error " << (int)err << ")"; printf("Model loaded: %s\n", model_path.c_str()); std::vector golden = load_golden(golden_path, numel); - if (golden.empty()) { - printf("FAIL: could not load golden %s\n", golden_path.c_str()); - return false; - } + ASSERT_FALSE(golden.empty()) << "could not load golden " << golden_path; // ((i % 13) - 6) / 16: exact in fp32, matches test_prepack.py::_inputs. std::vector x_data(numel); @@ -575,23 +502,14 @@ static bool test_prepack( auto x = make_tensor_ptr({n, n}, std::vector(x_data)); auto result = module.forward({EValue(x)}); - if (!result.ok()) { - printf("FAIL: forward failed (error %d)\n", (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) << "forward failed (error " << (int)result.error() + << ")"; const auto& outputs = result.get(); - if (outputs.empty() || !outputs[0].isTensor()) { - printf("FAIL: no tensor output\n"); - return false; - } + ASSERT_TRUE(!outputs.empty() && outputs[0].isTensor()) << "no tensor output"; const auto& out_tensor = outputs[0].toTensor(); - if (out_tensor.numel() != numel) { - printf( - "FAIL: output numel %zu != expected %d\n", - (size_t)out_tensor.numel(), - numel); - return false; - } + ASSERT_EQ((int)out_tensor.numel(), numel) + << "output numel " << (size_t)out_tensor.numel() << " != expected " + << numel; const float* out_data = out_tensor.const_data_ptr(); float max_abs_err = 0.0f, max_rel_err = 0.0f; @@ -604,16 +522,11 @@ static bool test_prepack( max_abs_err, max_rel_err, numel); - if (!within) { - printf("FAIL: prepack exceeds tolerance 1e-3\n"); - return false; - } - printf("PASS: prepack test\n"); - return true; + EXPECT_TRUE(within) << "prepack exceeds tolerance 1e-3"; } // Reconstruct _ramp_input bit-for-bit, run the op, compare to the fp64 golden. -static bool test_q4gsw_config( +void test_q4gsw_config( const Q4gswConfig& cfg, const std::string& pte, const std::string& golden_path) { @@ -625,10 +538,7 @@ static bool test_q4gsw_config( cfg.n); Module module(pte); - if (module.load_forward() != Error::Ok) { - printf("FAIL: could not load %s\n", pte.c_str()); - return false; - } + ASSERT_EQ(module.load_forward(), Error::Ok) << "could not load " << pte; const int in_numel = cfg.m * cfg.k; const int out_numel = cfg.m * cfg.n; @@ -639,30 +549,18 @@ static bool test_q4gsw_config( auto x = make_tensor_ptr({cfg.m, cfg.k}, std::vector(input)); auto result = module.forward({EValue(x)}); - if (!result.ok()) { - printf("FAIL: forward failed (error %d)\n", (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) << "forward failed (error " << (int)result.error() + << ")"; const auto& outputs = result.get(); - if (outputs.empty() || !outputs[0].isTensor()) { - printf("FAIL: no tensor output\n"); - return false; - } + ASSERT_TRUE(!outputs.empty() && outputs[0].isTensor()) << "no tensor output"; const auto& out_tensor = outputs[0].toTensor(); - if (out_tensor.numel() != out_numel) { - printf( - "FAIL: output numel %zu != expected %d\n", - (size_t)out_tensor.numel(), - out_numel); - return false; - } + ASSERT_EQ((int)out_tensor.numel(), out_numel) + << "output numel " << (size_t)out_tensor.numel() << " != expected " + << out_numel; const float* out_data = out_tensor.const_data_ptr(); std::vector golden = load_golden(golden_path, out_numel); - if (golden.empty()) { - printf("FAIL: could not load golden %s\n", golden_path.c_str()); - return false; - } + ASSERT_FALSE(golden.empty()) << "could not load golden " << golden_path; float ma = 0.0f, mr = 0.0f; const bool pass = quant_within_tol( @@ -672,47 +570,8 @@ static bool test_q4gsw_config( ma, mr, out_numel); - if (!pass) { - printf( - "FAIL: linear_q4gsw %s exceeds tolerance (abs %g OR rel %g)\n", - cfg.name, - cfg.tol_abs, - cfg.tol_rel); - return false; - } - printf("PASS: linear_q4gsw %s\n", cfg.name); - return true; -} - -// q4gsw sweep: self-discover q4gsw_.pte; required=FAIL, heavy=gate, *ran. -static bool test_q4gsw_sweep(const std::string& dir, bool* ran) { - bool ok = true; - const bool heavy_run = std::getenv("WEBGPU_TEST_HEAVY") != nullptr; - for (const auto& cfg : kQ4gswConfigs) { - const std::string pte = dir + "q4gsw_" + cfg.name + ".pte"; - FILE* f = std::fopen(pte.c_str(), "rb"); - if (!f) { - if (cfg.required && !dir.empty()) { - printf( - "FAIL: required q4gsw config %s has no .pte in %s\n", - cfg.name, - dir.c_str()); - ok = false; - } - continue; - } - std::fclose(f); - if (cfg.heavy && !heavy_run) { - printf( - "SKIP: heavy q4gsw config %s (set WEBGPU_TEST_HEAVY=1 on a real GPU)\n", - cfg.name); - continue; - } - const std::string golden = dir + "q4gsw_" + cfg.name + ".golden.bin"; - *ran = true; - ok = test_q4gsw_config(cfg, pte, golden) && ok; - } - return ok; + EXPECT_TRUE(pass) << "linear_q4gsw " << cfg.name << " exceeds tolerance (abs " + << cfg.tol_abs << " OR rel " << cfg.tol_rel << ")"; } // Fused sdpa_with_kv_cache sweep config. Mirrors the Python CONFIGS table in @@ -730,7 +589,7 @@ struct SdpaConfig { bool expect_reject = false; // load MUST fail (e.g. D%4 guard), no golden }; -static const SdpaConfig kSdpaConfigs[] = { +const SdpaConfig kSdpaConfigs[] = { // name Hq Hkv D S Cmax pos denom {"gqa31_prefill", 6, 2, 8, 4, 16, 0, 16.0f}, // GQA 3:1 (original case) {"mha_ctxodd", 4, 4, 16, 3, 8, 0, 16.0f}, // MHA; context_len=3 (odd) @@ -777,14 +636,18 @@ constexpr float kSdpaRampDenom = 16.0f; // /denom ramp: ((i % mod) - off) / denom, exact in fp32 (power-of-two denom). // Mirrors test_sdpa.py::_ramp. -static float sdpa_ramp(int i, int mod, int off, float denom = kSdpaRampDenom) { +float sdpa_ramp(int i, int mod, int off, float denom = kSdpaRampDenom) { return static_cast((i % mod) - off) / denom; } // Step-indexed ramp; mirrors test_sdpa.py::_ramp_t bit-for-bit. denom defaults // to kSdpaRampDenom and must match the Python denom for bit-identity. -static float -sdpa_ramp_t(int i, int mod, int off, int t, float denom = kSdpaRampDenom) { +float sdpa_ramp_t( + int i, + int mod, + int off, + int t, + float denom = kSdpaRampDenom) { return static_cast(((i + 31 * t) % mod) - off) / denom; } @@ -800,13 +663,13 @@ struct SdpaSequence { std::vector seq_lens; }; -static const SdpaSequence kSdpaSequences[] = { +const SdpaSequence kSdpaSequences[] = { {"small", 8, 4, 4, 16, {3, 1, 1, 5, 1, 1, 2}}, {"small_d", 6, 2, 8, 16, {3, 1, 1, 5, 1, 1}}, {"llama3", 24, 8, 128, 256, {111, 1, 1, 1, 57, 1, 1}}, }; -static bool test_sdpa_config( +void test_sdpa_config( const SdpaConfig& cfg, const std::string& model_path, const std::string& golden_path) { @@ -825,17 +688,13 @@ static bool test_sdpa_config( auto err = module.load_forward(); if (cfg.expect_reject) { // D not a multiple of 4 must be rejected at load by the head_dim guard. - if (err != Error::Ok) { - printf("PASS: %s rejected at load (error %d)\n", cfg.name, (int)err); - return true; - } - printf("FAIL: %s loaded OK; head_dim%%4 guard did not fire\n", cfg.name); - return false; - } - if (err != Error::Ok) { - printf("FAIL: could not load forward method (error %d)\n", (int)err); - return false; + ASSERT_NE(err, Error::Ok) + << cfg.name << " loaded OK; head_dim%4 guard did not fire"; + printf("PASS: %s rejected at load (error %d)\n", cfg.name, (int)err); + return; } + ASSERT_EQ(err, Error::Ok) + << "could not load forward method (error " << (int)err << ")"; printf("Model loaded: %s\n", model_path.c_str()); const int qn = cfg.s * cfg.hq * cfg.d; @@ -869,10 +728,8 @@ static bool test_sdpa_config( auto result = module.forward( {EValue(qt), EValue(kt), EValue(vt), EValue(kct), EValue(vct)}); - if (!result.ok()) { - printf("FAIL: forward failed (error %d)\n", (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) << "forward failed (error " << (int)result.error() + << ")"; const auto& outputs = result.get(); // Select the attention output [1,S,Hq,D] by shape; the op returns @@ -891,32 +748,17 @@ static bool test_sdpa_config( attn_matches++; } } - if (attn_idx < 0) { - printf( - "FAIL: no attention output [1,%d,%d,%d] among %zu outputs\n", - cfg.s, - cfg.hq, - cfg.d, - outputs.size()); - return false; - } - if (attn_matches > 1) { - printf( - "FAIL: ambiguous attention output: %d tensors match shape [1,%d,%d,%d]\n", - attn_matches, - cfg.s, - cfg.hq, - cfg.d); - return false; - } + ASSERT_GE(attn_idx, 0) << "no attention output [1," << cfg.s << "," << cfg.hq + << "," << cfg.d << "] among " << outputs.size() + << " outputs"; + ASSERT_LE(attn_matches, 1) << "ambiguous attention output: " << attn_matches + << " tensors match shape [1," << cfg.s << "," + << cfg.hq << "," << cfg.d << "]"; const auto& out_tensor = outputs[attn_idx].toTensor(); const float* out_data = out_tensor.const_data_ptr(); std::vector golden = load_golden(golden_path, on); - if (golden.empty()) { - printf("FAIL: could not load golden %s\n", golden_path.c_str()); - return false; - } + ASSERT_FALSE(golden.empty()) << "could not load golden " << golden_path; float max_abs_err = 0.0f, max_rel_err = 0.0f; const bool pass = @@ -926,48 +768,14 @@ static bool test_sdpa_config( max_abs_err, max_rel_err, on); - if (!pass) { - printf( - "FAIL: %s exceeds tolerance (per-element abs 1e-4 OR rel 1e-3)\n", - cfg.name); - return false; - } - printf("PASS: sdpa test (%s)\n", cfg.name); - return true; -} - -// Run the full SDPA sweep. Each config self-discovers its embedded/on-disk -// sdpa_.pte; a config is skipped silently when its .pte is absent, so the -// same binary works whether one or all configs are embedded. Returns false only -// if a discovered config actually fails. Sets *ran true if any config ran. -static bool test_sdpa_sweep(const std::string& dir, bool* ran) { - bool ok = true; - for (const auto& cfg : kSdpaConfigs) { - const std::string pte = dir + "sdpa_" + cfg.name + ".pte"; - FILE* f = std::fopen(pte.c_str(), "rb"); - if (!f) { - // required config absent (dir set) = FAIL; otherwise skip silently. - if (cfg.required && !dir.empty()) { - printf( - "FAIL: required sdpa config %s has no .pte in %s\n", - cfg.name, - dir.c_str()); - ok = false; - } - continue; // not embedded in this binary - } - std::fclose(f); - const std::string golden = dir + "sdpa_" + cfg.name + ".golden.bin"; - *ran = true; - ok = test_sdpa_config(cfg, pte, golden) && ok; - } - return ok; + EXPECT_TRUE(pass) << cfg.name + << " exceeds tolerance (per-element abs 1e-4 OR rel 1e-3)"; } // Replay one sequence: thread the op's returned (mutated) KV cache across // steps, comparing each step's attention output to its accumulated-context // golden. -static bool test_sdpa_replay(const SdpaSequence& seq, const std::string& dir) { +void test_sdpa_replay(const SdpaSequence& seq, const std::string& dir) { printf( "\n--- Test: sdpa replay (%s: Hq=%d,Hkv=%d,D=%d,Cmax=%d, %zu steps) ---\n", seq.name, @@ -982,7 +790,6 @@ static bool test_sdpa_replay(const SdpaSequence& seq, const std::string& dir) { int input_pos = 0; int k_idx = -1, v_idx = -1; // pinned at step 0 by content (caches share numel) - bool ok = true; for (size_t t = 0; t < seq.seq_lens.size(); t++) { const int s = seq.seq_lens[t]; @@ -990,10 +797,8 @@ static bool test_sdpa_replay(const SdpaSequence& seq, const std::string& dir) { std::to_string(t) + "_S" + std::to_string(s) + "_pos" + std::to_string(input_pos); Module module(base + ".pte"); - if (module.load_forward() != Error::Ok) { - printf("FAIL: could not load %s.pte\n", base.c_str()); - return false; - } + ASSERT_EQ(module.load_forward(), Error::Ok) + << "could not load " << base << ".pte"; const int qn = s * seq.hq * seq.d; const int kvn = s * seq.hkv * seq.d; @@ -1016,13 +821,8 @@ static bool test_sdpa_replay(const SdpaSequence& seq, const std::string& dir) { auto result = module.forward( {EValue(qt), EValue(kt), EValue(vt), EValue(kct), EValue(vct)}); - if (!result.ok()) { - printf( - "FAIL: forward %s.pte (error %d)\n", - base.c_str(), - (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) + << "forward " << base << ".pte (error " << (int)result.error() << ")"; const auto& outs = result.get(); // The op returns [k_cache, v_cache, attn_output]: attn has a unique numel; @@ -1040,10 +840,8 @@ static bool test_sdpa_replay(const SdpaSequence& seq, const std::string& dir) { cache_idxs.push_back(static_cast(i)); } } - if (attn_idx < 0 || cache_idxs.size() != 2) { - printf("FAIL: %s step%zu: expected 1 attn + 2 caches\n", seq.name, t); - return false; - } + ASSERT_TRUE(attn_idx >= 0 && cache_idxs.size() == 2) + << seq.name << " step" << t << ": expected 1 attn + 2 caches"; if (t == 0) { const float* c0 = outs[cache_idxs[0]].toTensor().const_data_ptr(); @@ -1063,18 +861,13 @@ static bool test_sdpa_replay(const SdpaSequence& seq, const std::string& dir) { k_idx = cache_idxs[1]; v_idx = cache_idxs[0]; } else { - printf( - "FAIL: %s step0 cannot identify k/v cache by content\n", seq.name); - return false; + FAIL() << seq.name << " step0 cannot identify k/v cache by content"; } printf(" k/v cache outputs: k_idx=%d v_idx=%d\n", k_idx, v_idx); } std::vector golden = load_golden(base + ".golden.bin", qn); - if (golden.empty()) { - printf("FAIL: could not load %s.golden.bin\n", base.c_str()); - return false; - } + ASSERT_FALSE(golden.empty()) << "could not load " << base << ".golden.bin"; const float* ad = outs[attn_idx].toTensor().const_data_ptr(); float ma = 0.0f, mr = 0.0f; const bool step_ok = sdpa_within_tol(ad, golden.data(), qn, &ma, &mr); @@ -1086,13 +879,9 @@ static bool test_sdpa_replay(const SdpaSequence& seq, const std::string& dir) { input_pos + s, ma, mr); - if (!step_ok) { - printf( - "FAIL: %s step%zu exceeds tolerance (per-element abs 1e-4 OR rel 1e-3)\n", - seq.name, - t); - ok = false; - } + EXPECT_TRUE(step_ok) + << seq.name << " step" << t + << " exceeds tolerance (per-element abs 1e-4 OR rel 1e-3)"; // Thread the device-written caches into the next step (K->K, V->V). const float* kd = outs[k_idx].toTensor().const_data_ptr(); @@ -1101,28 +890,6 @@ static bool test_sdpa_replay(const SdpaSequence& seq, const std::string& dir) { vc.assign(vd, vd + cn); input_pos += s; } - - if (ok) { - printf("PASS: sdpa replay (%s)\n", seq.name); - } - return ok; -} - -// Run all replay sequences whose step0 .pte is present (self-skip otherwise). -static bool test_sdpa_replay_sweep(const std::string& dir, bool* ran) { - bool ok = true; - for (const auto& seq : kSdpaSequences) { - const std::string step0 = dir + "sdpa_" + seq.name + "_step0_S" + - std::to_string(seq.seq_lens[0]) + "_pos0.pte"; - FILE* f = std::fopen(step0.c_str(), "rb"); - if (!f) { - continue; // sequence not embedded in this binary - } - std::fclose(f); - *ran = true; - ok = test_sdpa_replay(seq, dir) && ok; - } - return ok; } // Dynamic input_pos decode: ONE .pte (S=1, runtime SymInt input_pos) reused @@ -1133,7 +900,7 @@ static bool test_sdpa_replay_sweep(const std::string& dir, bool* ran) { // per-step input_pos actually being read + applied. negative=true pins // input_pos at 0 every step (stale context_len) and asserts the run DIVERGES, // proving the runtime input_pos + resize hook are load-bearing (no false-pass). -static bool test_sdpa_dynamic_decode( +void test_sdpa_dynamic_decode( const SdpaSequence& seq, const std::string& dir, bool negative) { @@ -1150,16 +917,12 @@ static bool test_sdpa_dynamic_decode( const std::string pte = dir + "sdpa_dyn_" + seq.name + ".pte"; Module module(pte); - if (module.load_forward() != Error::Ok) { - printf("FAIL: could not load %s\n", pte.c_str()); - return false; - } + ASSERT_EQ(module.load_forward(), Error::Ok) << "could not load " << pte; const int cn = seq.cmax * seq.hkv * seq.d; std::vector kc(cn, 0.0f), vc(cn, 0.0f); int k_idx = -1, v_idx = -1; // pinned at step 0 by content (caches share numel) - bool ok = true; bool any_mismatch = false; for (int t = 0; t < kSteps; t++) { @@ -1190,10 +953,8 @@ static bool test_sdpa_dynamic_decode( EValue(kct), EValue(vct), EValue(ipt)}); - if (!result.ok()) { - printf("FAIL: forward step%d (error %d)\n", t, (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) + << "forward step" << t << " (error " << (int)result.error() << ")"; const auto& outs = result.get(); int attn_idx = -1; @@ -1209,10 +970,8 @@ static bool test_sdpa_dynamic_decode( cache_idxs.push_back(static_cast(i)); } } - if (attn_idx < 0 || cache_idxs.size() != 2) { - printf("FAIL: %s step%d: expected 1 attn + 2 caches\n", seq.name, t); - return false; - } + ASSERT_TRUE(attn_idx >= 0 && cache_idxs.size() == 2) + << seq.name << " step" << t << ": expected 1 attn + 2 caches"; if (t == 0) { const float* c0 = outs[cache_idxs[0]].toTensor().const_data_ptr(); const float* c1 = outs[cache_idxs[1]].toTensor().const_data_ptr(); @@ -1231,18 +990,14 @@ static bool test_sdpa_dynamic_decode( k_idx = cache_idxs[1]; v_idx = cache_idxs[0]; } else { - printf("FAIL: %s step0 cannot identify k/v cache\n", seq.name); - return false; + FAIL() << seq.name << " step0 cannot identify k/v cache"; } } const std::string gpath = dir + "sdpa_dyn_" + seq.name + "_step" + std::to_string(t) + ".golden.bin"; std::vector golden = load_golden(gpath, qn); - if (golden.empty()) { - printf("FAIL: could not load %s\n", gpath.c_str()); - return false; - } + ASSERT_FALSE(golden.empty()) << "could not load " << gpath; const float* ad = outs[attn_idx].toTensor().const_data_ptr(); float ma = 0.0f, mr = 0.0f; const bool step_ok = sdpa_within_tol(ad, golden.data(), qn, &ma, &mr); @@ -1265,46 +1020,24 @@ static bool test_sdpa_dynamic_decode( } if (negative) { + // The negative control must DIVERGE: a stale input_pos=0 every step cannot + // match the accumulating golden -- otherwise the oracle has no teeth. + EXPECT_TRUE(any_mismatch) + << seq.name + << " negative control matched the golden (oracle has no teeth)"; if (any_mismatch) { printf( "PASS: sdpa dynamic decode NEGATIVE (%s): stale input_pos diverges " "as expected\n", seq.name); - return true; } - printf( - "FAIL: %s negative control matched the golden (oracle has no teeth)\n", - seq.name); - return false; + return; } - if (any_mismatch) { - printf( - "FAIL: %s exceeds tolerance (per-element abs 1e-4 OR rel 1e-3)\n", - seq.name); - ok = false; - } - if (ok) { + EXPECT_FALSE(any_mismatch) + << seq.name << " exceeds tolerance (per-element abs 1e-4 OR rel 1e-3)"; + if (!any_mismatch) { printf("PASS: sdpa dynamic decode (%s)\n", seq.name); } - return ok; -} - -// Run dynamic decode (positive + negative control) for each param set whose -// sdpa_dyn_.pte is embedded (self-skip otherwise). -static bool test_sdpa_dynamic_decode_sweep(const std::string& dir, bool* ran) { - bool ok = true; - for (const auto& seq : kSdpaSequences) { - const std::string pte = dir + "sdpa_dyn_" + seq.name + ".pte"; - FILE* f = std::fopen(pte.c_str(), "rb"); - if (!f) { - continue; - } - std::fclose(f); - *ran = true; - ok = test_sdpa_dynamic_decode(seq, dir, /*negative=*/false) && ok; - ok = test_sdpa_dynamic_decode(seq, dir, /*negative=*/true) && ok; - } - return ok; } // In-graph mutable KV cache: ONE .pte whose k_cache/v_cache are mutable buffers @@ -1314,7 +1047,7 @@ static bool test_sdpa_dynamic_decode_sweep(const std::string& dir, bool* ran) { // Module each step re-seeds the cache to zeros, so it MUST diverge from the // accumulating golden at step>=1. Persistent-matches + fresh-diverges = proof // the pass comes from real accumulation, not a static artifact. -static bool test_sdpa_incache_decode( +void test_sdpa_incache_decode( const SdpaSequence& seq, const std::string& dir, bool fresh_per_step) { @@ -1333,10 +1066,8 @@ static bool test_sdpa_incache_decode( std::unique_ptr persistent; if (!fresh_per_step) { persistent = std::make_unique(pte); - if (persistent->load_forward() != Error::Ok) { - printf("FAIL: could not load %s\n", pte.c_str()); - return false; - } + ASSERT_EQ(persistent->load_forward(), Error::Ok) + << "could not load " << pte; } bool any_mismatch = false; @@ -1363,10 +1094,7 @@ static bool test_sdpa_incache_decode( Module* mod = persistent.get(); if (fresh_per_step) { fresh = std::make_unique(pte); - if (fresh->load_forward() != Error::Ok) { - printf("FAIL: could not load %s\n", pte.c_str()); - return false; - } + ASSERT_EQ(fresh->load_forward(), Error::Ok) << "could not load " << pte; mod = fresh.get(); } @@ -1374,10 +1102,8 @@ static bool test_sdpa_incache_decode( // buffers). auto result = mod->forward({EValue(qt), EValue(kt), EValue(vt), EValue(ipt)}); - if (!result.ok()) { - printf("FAIL: forward step%d (error %d)\n", t, (int)result.error()); - return false; - } + ASSERT_TRUE(result.ok()) + << "forward step" << t << " (error " << (int)result.error() << ")"; const auto& outs = result.get(); int attn_idx = -1; for (size_t i = 0; i < outs.size(); i++) { @@ -1387,18 +1113,13 @@ static bool test_sdpa_incache_decode( break; } } - if (attn_idx < 0) { - printf("FAIL: %s step%d: no attn output (numel %d)\n", seq.name, t, qn); - return false; - } + ASSERT_GE(attn_idx, 0) << seq.name << " step" << t + << ": no attn output (numel " << qn << ")"; const std::string gpath = dir + "sdpa_incache_" + seq.name + "_step" + std::to_string(t) + ".golden.bin"; std::vector golden = load_golden(gpath, qn); - if (golden.empty()) { - printf("FAIL: could not load %s\n", gpath.c_str()); - return false; - } + ASSERT_FALSE(golden.empty()) << "could not load " << gpath; const float* ad = outs[attn_idx].toTensor().const_data_ptr(); float ma = 0.0f, mr = 0.0f; const bool step_ok = sdpa_within_tol(ad, golden.data(), qn, &ma, &mr); @@ -1418,55 +1139,36 @@ static bool test_sdpa_incache_decode( if (fresh_per_step) { // The control must DIVERGE: a fresh Module per step has no accumulated // history, so it cannot match the accumulating golden at step>=1. + EXPECT_TRUE(any_mismatch) + << seq.name + << " static control matched the accumulating golden -- " + "accumulation was not actually exercised (false-pass risk)"; if (any_mismatch) { printf( "PASS: in-graph-cache STATIC CONTROL (%s) diverges as expected -- " "persistence is load-bearing; the positive pass is real accumulation\n", seq.name); - return true; } - printf( - "FAIL: %s static control matched the accumulating golden -- " - "accumulation was not actually exercised (false-pass risk)\n", - seq.name); - return false; + return; } + EXPECT_FALSE(any_mismatch) + << seq.name << " in-graph-cache decode exceeds tolerance"; if (!any_mismatch) { printf( "PASS: sdpa in-graph-cache decode (%s) -- cache accumulated in-graph " "with NO host threading\n", seq.name); - return true; - } - printf("FAIL: %s in-graph-cache decode exceeds tolerance\n", seq.name); - return false; -} - -static bool test_sdpa_incache_decode_sweep(const std::string& dir, bool* ran) { - bool ok = true; - for (const auto& seq : kSdpaSequences) { - const std::string pte = dir + "sdpa_incache_" + seq.name + ".pte"; - FILE* f = std::fopen(pte.c_str(), "rb"); - if (!f) { - continue; - } - std::fclose(f); - *ran = true; - ok = test_sdpa_incache_decode(seq, dir, /*fresh_per_step=*/false) && ok; - ok = test_sdpa_incache_decode(seq, dir, /*fresh_per_step=*/true) && ok; } - return ok; } // S1 SymInt round-trip: build a graph directly from a dynamic-input_pos SDPA // blob; confirm input_pos deserializes as a live SymInt and set/read // round-trips. -static bool test_symint_roundtrip(const std::string& blob_path) { +void test_symint_roundtrip(const std::string& blob_path) { printf("\n--- Test: symint round-trip (%s) ---\n", blob_path.c_str()); FILE* f = std::fopen(blob_path.c_str(), "rb"); if (!f) { - printf("SKIP: %s not present\n", blob_path.c_str()); - return true; + GTEST_SKIP() << blob_path << " not present"; } std::fseek(f, 0, SEEK_END); long n = std::ftell(f); @@ -1474,24 +1176,17 @@ static bool test_symint_roundtrip(const std::string& blob_path) { std::vector blob(static_cast(n)); size_t rd = std::fread(blob.data(), 1, blob.size(), f); std::fclose(f); - if (rd != blob.size()) { - printf("FAIL: short read of %s\n", blob_path.c_str()); - return false; - } + ASSERT_EQ(rd, blob.size()) << "short read of " << blob_path; auto header = WebGPUDelegateHeader::parse(blob.data()); - if (!header.ok()) { - printf("FAIL: delegate header parse\n"); - return false; - } + ASSERT_TRUE(header.ok()) << "delegate header parse"; const uint8_t* base = blob.data(); WebGPUGraph graph; try { graph.build( base + header->flatbuffer_offset, base + header->bytes_offset, nullptr); } catch (const std::exception& e) { - printf("FAIL: graph build: %s\n", e.what()); - return false; + FAIL() << "graph build: " << e.what(); } int sid = -1; @@ -1501,35 +1196,21 @@ static bool test_symint_roundtrip(const std::string& blob_path) { break; } } - if (sid < 0) { - printf( - "FAIL: no SymInt value deserialized (input_pos should be a SymInt)\n"); - return false; - } - if (graph.symint_buffer(sid) == nullptr) { - printf("FAIL: SymInt %d has no live uniform buffer\n", sid); - return false; - } - if (graph.read_symint(sid) != 0) { - printf( - "FAIL: SymInt %d placeholder != 0 (got %d)\n", - sid, - graph.read_symint(sid)); - return false; - } + ASSERT_GE(sid, 0) + << "no SymInt value deserialized (input_pos should be a SymInt)"; + ASSERT_NE(graph.symint_buffer(sid), nullptr) + << "SymInt " << sid << " has no live uniform buffer"; + ASSERT_EQ(graph.read_symint(sid), 0) + << "SymInt " << sid << " placeholder != 0 (got " << graph.read_symint(sid) + << ")"; graph.set_symint(sid, 7); - if (graph.read_symint(sid) != 7) { - printf("FAIL: set/read round-trip (got %d)\n", graph.read_symint(sid)); - return false; - } + ASSERT_EQ(graph.read_symint(sid), 7) + << "set/read round-trip (got " << graph.read_symint(sid) << ")"; // Execute-read: feed a fake input_pos=5 via the recorded select_as_symint // source and confirm update_symints_from_inputs populates the SymInt. const auto& srcs = graph.symint_sources(); - if (srcs.empty()) { - printf("FAIL: no select_as_symint source recorded\n"); - return false; - } + ASSERT_FALSE(srcs.empty()) << "no select_as_symint source recorded"; const auto& in_ids = graph.input_ids(); std::vector fake_inputs(in_ids.size()); int64_t fake_pos = 5; @@ -1539,29 +1220,24 @@ static bool test_symint_roundtrip(const std::string& blob_path) { } } graph.update_symints_from_inputs(fake_inputs); - if (graph.read_symint(srcs[0].symint_id) != 5) { - printf( - "FAIL: execute-read (got %d, want 5)\n", - graph.read_symint(srcs[0].symint_id)); - return false; - } + ASSERT_EQ(graph.read_symint(srcs[0].symint_id), 5) + << "execute-read (got " << graph.read_symint(srcs[0].symint_id) + << ", want 5)"; printf( "PASS: symint round-trip (SymInt %d: deserialize, live buffer, " "set 0->7, execute-read input_pos->5)\n", sid); - return true; } // Group 1: the resize-hook dirty-gating mechanism (no SDPA dependency). // A hook keyed to a SymInt must run via propagate_resize() iff that SymInt // changed since the last propagate_resize, and exactly once per change. -static bool test_resize_hook(const std::string& blob_path) { +void test_resize_hook(const std::string& blob_path) { printf("\n--- Test: resize-hook dirty-gating (%s) ---\n", blob_path.c_str()); FILE* f = std::fopen(blob_path.c_str(), "rb"); if (!f) { - printf("SKIP: %s not present\n", blob_path.c_str()); - return true; + GTEST_SKIP() << blob_path << " not present"; } std::fseek(f, 0, SEEK_END); long n = std::ftell(f); @@ -1569,23 +1245,16 @@ static bool test_resize_hook(const std::string& blob_path) { std::vector blob(static_cast(n)); size_t rd = std::fread(blob.data(), 1, blob.size(), f); std::fclose(f); - if (rd != blob.size()) { - printf("FAIL: short read of %s\n", blob_path.c_str()); - return false; - } + ASSERT_EQ(rd, blob.size()) << "short read of " << blob_path; auto header = WebGPUDelegateHeader::parse(blob.data()); - if (!header.ok()) { - printf("FAIL: delegate header parse\n"); - return false; - } + ASSERT_TRUE(header.ok()) << "delegate header parse"; const uint8_t* base = blob.data(); WebGPUGraph graph; try { graph.build( base + header->flatbuffer_offset, base + header->bytes_offset, nullptr); } catch (const std::exception& e) { - printf("FAIL: graph build: %s\n", e.what()); - return false; + FAIL() << "graph build: " << e.what(); } int sid = -1; @@ -1595,10 +1264,7 @@ static bool test_resize_hook(const std::string& blob_path) { break; } } - if (sid < 0) { - printf("FAIL: no SymInt value deserialized\n"); - return false; - } + ASSERT_GE(sid, 0) << "no SymInt value deserialized"; int run_count = 0; int last_seen = -1; @@ -1610,279 +1276,405 @@ static bool test_resize_hook(const std::string& blob_path) { // 1: change 0->3 then propagate -> hook runs once, sees 3. graph.set_symint(sid, 3); graph.propagate_resize(); - if (run_count != 1 || last_seen != 3) { - printf( - "FAIL: after set(3)+propagate run_count=%d last_seen=%d (want 1,3)\n", - run_count, - last_seen); - return false; - } + ASSERT_TRUE(run_count == 1 && last_seen == 3) + << "after set(3)+propagate run_count=" << run_count + << " last_seen=" << last_seen << " (want 1,3)"; // 2: propagate again with no change -> hook does NOT run. graph.propagate_resize(); - if (run_count != 1) { - printf( - "FAIL: propagate with clean dirty-set ran the hook (run_count=%d)\n", - run_count); - return false; - } + ASSERT_EQ(run_count, 1) + << "propagate with clean dirty-set ran the hook (run_count=" << run_count + << ")"; // 3: set to the SAME value -> not dirty -> hook does NOT run. graph.set_symint(sid, 3); graph.propagate_resize(); - if (run_count != 1) { - printf( - "FAIL: set(same)+propagate ran the hook (run_count=%d)\n", run_count); - return false; - } + ASSERT_EQ(run_count, 1) << "set(same)+propagate ran the hook (run_count=" + << run_count << ")"; // 4: change 3->8 then propagate -> hook runs again, sees 8. graph.set_symint(sid, 8); graph.propagate_resize(); - if (run_count != 2 || last_seen != 8) { - printf( - "FAIL: after set(8)+propagate run_count=%d last_seen=%d (want 2,8)\n", - run_count, - last_seen); - return false; - } + ASSERT_TRUE(run_count == 2 && last_seen == 8) + << "after set(8)+propagate run_count=" << run_count + << " last_seen=" << last_seen << " (want 2,8)"; printf( "PASS: resize-hook dirty-gating (SymInt %d: runs only on change, " "once per change; saw 3 then 8)\n", sid); - return true; } -int main(int argc, char** argv) { - std::string update_cache_model_path; - if (const char* env = std::getenv("WEBGPU_TEST_UPDATE_CACHE_MODEL")) { - update_cache_model_path = env; - } +// q4gsw embedding_q4gsw on-GPU configs: small + llama1b (env-gated, +// run-if-present). +struct EmbConfig { + const char* name; + const char* model_env; + const char* indices_env; + const char* golden_env; + int num_indices; + int embed; +}; +const EmbConfig kEmbConfigs[] = { + {"small", + "WEBGPU_TEST_EMBEDDING_Q4GSW_MODEL", + "WEBGPU_TEST_EMBEDDING_Q4GSW_INDICES", + "WEBGPU_TEST_EMBEDDING_Q4GSW_GOLDEN", + 4, + 64}, + {"llama1b", + "WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_MODEL", + "WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_INDICES", + "WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_GOLDEN", + 4, + 2048}, +}; - // Quantized-linear sweep dir (mirrors WEBGPU_TEST_SDPA_DIR). - std::string qlinear_dir; - if (const char* env = std::getenv("WEBGPU_TEST_QUANTIZED_LINEAR_DIR")) { - qlinear_dir = env; - if (!qlinear_dir.empty() && qlinear_dir.back() != '/') { - qlinear_dir += '/'; - } - } +// apply_rotary_emb on-GPU configs: multi + decode (env-gated, run-if-present). +struct RopeConfig { + const char* name; + const char* model_env; + const char* xq_env; + const char* xk_env; + int S; + int NH; + int NKV; + int HD; +}; +const RopeConfig kRopeConfigs[] = { + {"multi", + "WEBGPU_TEST_ROPE_MODEL", + "WEBGPU_TEST_ROPE_XQ_GOLDEN", + "WEBGPU_TEST_ROPE_XK_GOLDEN", + 5, + 8, + 2, + 64}, + {"decode", + "WEBGPU_TEST_ROPE_DECODE_MODEL", + "WEBGPU_TEST_ROPE_DECODE_XQ_GOLDEN", + "WEBGPU_TEST_ROPE_DECODE_XK_GOLDEN", + 1, + 32, + 8, + 64}, +}; - // embedding_q4gsw on-GPU configs: small + llama1b (env-gated, - // run-if-present). - struct EmbConfig { - const char* name; - const char* model_env; - const char* indices_env; - const char* golden_env; - int num_indices; - int embed; - }; - const EmbConfig emb_configs[] = { - {"small", - "WEBGPU_TEST_EMBEDDING_Q4GSW_MODEL", - "WEBGPU_TEST_EMBEDDING_Q4GSW_INDICES", - "WEBGPU_TEST_EMBEDDING_Q4GSW_GOLDEN", - 4, - 64}, - {"llama1b", - "WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_MODEL", - "WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_INDICES", - "WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_GOLDEN", - 4, - 2048}, - }; +} // namespace - // apply_rotary_emb on-GPU configs: multi + decode (env-gated, - // run-if-present). - struct RopeConfig { - const char* name; - const char* model_env; - const char* xq_env; - const char* xk_env; - int S; - int NH; - int NKV; - int HD; - }; - const RopeConfig rope_configs[] = { - {"multi", - "WEBGPU_TEST_ROPE_MODEL", - "WEBGPU_TEST_ROPE_XQ_GOLDEN", - "WEBGPU_TEST_ROPE_XK_GOLDEN", - 5, - 8, - 2, - 64}, - {"decode", - "WEBGPU_TEST_ROPE_DECODE_MODEL", - "WEBGPU_TEST_ROPE_DECODE_XQ_GOLDEN", - "WEBGPU_TEST_ROPE_DECODE_XK_GOLDEN", - 1, - 32, - 8, - 64}, - }; +#ifdef WGPU_BACKEND_ENABLE_PROFILING +TEST(WebGPUNative, QueryPoolOverrunThrows) { + test_query_pool_overrun_throws(); +} - std::string prepack_model_path, prepack_golden_path; - if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_MODEL")) { - prepack_model_path = env; - } - if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_GOLDEN")) { - prepack_golden_path = env; +TEST(WebGPUNative, QueryPoolRoundtrip) { + test_query_pool_roundtrip(*get_default_webgpu_context()); +} +#endif // WGPU_BACKEND_ENABLE_PROFILING + +TEST(WebGPUNative, UpdateCache) { + if (g_update_cache_model_path.empty()) { + GTEST_SKIP() << "WEBGPU_TEST_UPDATE_CACHE_MODEL not set"; } + test_update_cache(g_update_cache_model_path); +} - std::string prepack2_model_path, prepack2_golden_path; - if (const char* env = std::getenv("WEBGPU_TEST_PREPACK2_MODEL")) { - prepack2_model_path = env; +// Guard python<->C++ ramp bit-identity: q4gsw_ramp(0) = -0.5 exactly. +TEST(WebGPUNative, Q4gswRampBitIdentity) { + EXPECT_LT(std::abs(q4gsw_ramp(0) - (-0.5f)), 1e-12f) + << "q4gsw_ramp bit-identity check"; +} + +// q4gsw sweep: self-discover q4gsw_.pte; required=FAIL, heavy=gate. +TEST(WebGPUNative, QuantizedLinearSweep) { + const std::string& dir = g_qlinear_dir; + const bool heavy_run = std::getenv("WEBGPU_TEST_HEAVY") != nullptr; + bool ran = false; + for (const auto& cfg : kQ4gswConfigs) { + const std::string pte = dir + "q4gsw_" + cfg.name + ".pte"; + FILE* f = std::fopen(pte.c_str(), "rb"); + if (!f) { + if (cfg.required && !dir.empty()) { + ADD_FAILURE() << "required q4gsw config " << cfg.name + << " has no .pte in " << dir; + } + continue; + } + std::fclose(f); + if (cfg.heavy && !heavy_run) { + printf( + "SKIP: heavy q4gsw config %s (set WEBGPU_TEST_HEAVY=1 on a real GPU)\n", + cfg.name); + continue; + } + const std::string golden = dir + "q4gsw_" + cfg.name + ".golden.bin"; + ran = true; + test_q4gsw_config(cfg, pte, golden); } - if (const char* env = std::getenv("WEBGPU_TEST_PREPACK2_GOLDEN")) { - prepack2_golden_path = env; + if (!dir.empty() && !ran) { + ADD_FAILURE() + << "WEBGPU_TEST_QUANTIZED_LINEAR_DIR set but no q4gsw config ran"; } +} - std::string prepack_tied_model_path, prepack_tied_golden_path; - if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_TIED_MODEL")) { - prepack_tied_model_path = env; +TEST(WebGPUNative, EmbeddingQ4gsw) { + bool any = false; + for (const auto& c : kEmbConfigs) { + const char* m = std::getenv(c.model_env); + const char* ip = std::getenv(c.indices_env); + const char* g = std::getenv(c.golden_env); + if (m && ip && g && *m && *ip && *g) { + any = true; + test_embedding_q4gsw(m, ip, g, c.num_indices, c.embed, c.name); + } } - if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_TIED_GOLDEN")) { - prepack_tied_golden_path = env; + if (!any) { + GTEST_SKIP() << "no embedding_q4gsw config env set"; } +} - // SDPA sweep: configs self-discover their sdpa_.pte/.golden.bin under - // this directory (default "" = the embedded-file root / cwd). Set - // WEBGPU_TEST_SDPA_DIR to point at the exported .pte directory (e.g. /tmp/). - std::string sdpa_dir; - if (const char* env = std::getenv("WEBGPU_TEST_SDPA_DIR")) { - sdpa_dir = env; - if (!sdpa_dir.empty() && sdpa_dir.back() != '/') { - sdpa_dir += '/'; +TEST(WebGPUNative, Rope) { + bool any = false; + for (const auto& c : kRopeConfigs) { + const char* m = std::getenv(c.model_env); + const char* xq = std::getenv(c.xq_env); + const char* xk = std::getenv(c.xk_env); + if (m && xq && xk && *m && *xq && *xk) { + any = true; + test_rope(m, xq, xk, c.S, c.NH, c.NKV, c.HD, c.name); } } - - WebGPUContext ctx; - try { - ctx = create_webgpu_context(); - } catch (const std::exception& e) { - printf("SKIP: %s\n", e.what()); - return 0; + if (!any) { + GTEST_SKIP() << "no apply_rotary_emb config env set"; } +} - set_default_webgpu_context(&ctx); - printf("WebGPU device acquired (native)\n"); +TEST(WebGPUNative, Prepack) { + if (g_prepack_model_path.empty() || g_prepack_golden_path.empty()) { + GTEST_SKIP() << "WEBGPU_TEST_PREPACK_MODEL/GOLDEN not set"; + } + test_prepack(g_prepack_model_path, g_prepack_golden_path); +} - bool ok = true; -#ifdef WGPU_BACKEND_ENABLE_PROFILING - ok = test_query_pool_overrun_throws() && ok; - ok = test_query_pool_roundtrip(ctx) && ok; -#endif // WGPU_BACKEND_ENABLE_PROFILING - if (!update_cache_model_path.empty()) { - ok = test_update_cache(update_cache_model_path) && ok; +TEST(WebGPUNative, Prepack2) { + if (g_prepack2_model_path.empty() || g_prepack2_golden_path.empty()) { + GTEST_SKIP() << "WEBGPU_TEST_PREPACK2_MODEL/GOLDEN not set"; } + test_prepack(g_prepack2_model_path, g_prepack2_golden_path, "x + w1 + w2"); +} - bool q4gsw_ran = false; - bool q4gsw_ok = test_q4gsw_sweep(qlinear_dir, &q4gsw_ran); - if (q4gsw_ran) { - ok = q4gsw_ok && ok; +TEST(WebGPUNative, PrepackTied) { + if (g_prepack_tied_model_path.empty() || g_prepack_tied_golden_path.empty()) { + GTEST_SKIP() << "WEBGPU_TEST_PREPACK_TIED_MODEL/GOLDEN not set"; } - // Guard python<->C++ ramp bit-identity: q4gsw_ramp(0) = -0.5 exactly. - if (std::abs(q4gsw_ramp(0) - (-0.5f)) > 1e-12f) { - printf("FAIL: q4gsw_ramp bit-identity check\n"); - ok = false; + test_prepack( + g_prepack_tied_model_path, + g_prepack_tied_golden_path, + "x + w + w (tied weights, shared key)"); +} + +// SDPA sweep: configs self-discover sdpa_.pte; required=FAIL else skip. +TEST(WebGPUNative, SdpaSweep) { + const std::string& dir = g_sdpa_dir; + bool ran = false; + for (const auto& cfg : kSdpaConfigs) { + const std::string pte = dir + "sdpa_" + cfg.name + ".pte"; + FILE* f = std::fopen(pte.c_str(), "rb"); + if (!f) { + // required config absent (dir set) = FAIL; otherwise skip silently. + if (cfg.required && !dir.empty()) { + ADD_FAILURE() << "required sdpa config " << cfg.name + << " has no .pte in " << dir; + } + continue; // not embedded in this binary + } + std::fclose(f); + const std::string golden = dir + "sdpa_" + cfg.name + ".golden.bin"; + ran = true; + test_sdpa_config(cfg, pte, golden); } - if (!qlinear_dir.empty() && !q4gsw_ran) { - printf( - "FAIL: WEBGPU_TEST_QUANTIZED_LINEAR_DIR set but no q4gsw config ran\n"); - ok = false; + if (!dir.empty() && !ran) { + ADD_FAILURE() << "WEBGPU_TEST_SDPA_DIR set but no sdpa config found a .pte"; } +} - for (const auto& c : emb_configs) { - const char* m = std::getenv(c.model_env); - const char* ip = std::getenv(c.indices_env); - const char* g = std::getenv(c.golden_env); - if (m && ip && g && *m && *ip && *g) { - ok = test_embedding_q4gsw(m, ip, g, c.num_indices, c.embed, c.name) && ok; +// Guard python<->C++ ramp bit-identity (recorded: _ramp_t(0,17,8,2)=0.1875). +TEST(WebGPUNative, SdpaRampTBitIdentity) { + EXPECT_LT(std::abs(sdpa_ramp_t(0, 17, 8, 2) - 0.1875f), 1e-12f) + << "sdpa_ramp_t bit-identity check"; +} + +// Guard the adversarial denom path: sdpa_ramp(0,17,8,0.5)= -16.0 exactly. +TEST(WebGPUNative, SdpaRampDenomBitIdentity) { + EXPECT_LT(std::abs(sdpa_ramp(0, 17, 8, 0.5f) - (-16.0f)), 1e-12f) + << "sdpa_ramp denom bit-identity check"; +} + +// Replay sweep: run every sequence whose step0 .pte is present. +TEST(WebGPUNative, SdpaReplaySweep) { + const std::string& dir = g_sdpa_dir; + for (const auto& seq : kSdpaSequences) { + const std::string step0 = dir + "sdpa_" + seq.name + "_step0_S" + + std::to_string(seq.seq_lens[0]) + "_pos0.pte"; + FILE* f = std::fopen(step0.c_str(), "rb"); + if (!f) { + continue; // sequence not embedded in this binary } + std::fclose(f); + test_sdpa_replay(seq, dir); } +} - for (const auto& c : rope_configs) { - const char* m = std::getenv(c.model_env); - const char* xq = std::getenv(c.xq_env); - const char* xk = std::getenv(c.xk_env); - if (m && xq && xk && *m && *xq && *xk) { - ok = test_rope(m, xq, xk, c.S, c.NH, c.NKV, c.HD, c.name) && ok; +// Dynamic decode sweep: positive + negative control per embedded param set. +TEST(WebGPUNative, SdpaDynamicDecodeSweep) { + const std::string& dir = g_sdpa_dir; + for (const auto& seq : kSdpaSequences) { + const std::string pte = dir + "sdpa_dyn_" + seq.name + ".pte"; + FILE* f = std::fopen(pte.c_str(), "rb"); + if (!f) { + continue; } + std::fclose(f); + test_sdpa_dynamic_decode(seq, dir, /*negative=*/false); + test_sdpa_dynamic_decode(seq, dir, /*negative=*/true); } +} - if (!prepack_model_path.empty() && !prepack_golden_path.empty()) { - ok = test_prepack(prepack_model_path, prepack_golden_path) && ok; +// In-graph-cache decode sweep: persistent + fresh (static control) per set. +TEST(WebGPUNative, SdpaIncacheDecodeSweep) { + const std::string& dir = g_sdpa_dir; + for (const auto& seq : kSdpaSequences) { + const std::string pte = dir + "sdpa_incache_" + seq.name + ".pte"; + FILE* f = std::fopen(pte.c_str(), "rb"); + if (!f) { + continue; + } + std::fclose(f); + test_sdpa_incache_decode(seq, dir, /*fresh_per_step=*/false); + test_sdpa_incache_decode(seq, dir, /*fresh_per_step=*/true); } +} - if (!prepack2_model_path.empty() && !prepack2_golden_path.empty()) { - ok = test_prepack( - prepack2_model_path, prepack2_golden_path, "x + w1 + w2") && - ok; +// If an SDPA dir was given, the exports must have produced .ptes for every +// family; a self-skip there means a silent export failure, not a pass. +TEST(WebGPUNative, SdpaAllFamiliesRanWhenDirSet) { + const std::string& dir = g_sdpa_dir; + if (dir.empty()) { + GTEST_SKIP() << "WEBGPU_TEST_SDPA_DIR not set"; + } + auto has_glob = [&](const std::string& prefix, const std::string& suffix) { + for (const auto& seq : kSdpaSequences) { + const std::string p = dir + prefix + seq.name + suffix; + FILE* f = std::fopen(p.c_str(), "rb"); + if (f) { + std::fclose(f); + return true; + } + } + return false; + }; + bool sdpa_ran = false; + for (const auto& cfg : kSdpaConfigs) { + const std::string pte = dir + "sdpa_" + cfg.name + ".pte"; + FILE* f = std::fopen(pte.c_str(), "rb"); + if (f) { + std::fclose(f); + sdpa_ran = true; + break; + } } + const bool replay_ran = [&] { + for (const auto& seq : kSdpaSequences) { + const std::string step0 = dir + "sdpa_" + seq.name + "_step0_S" + + std::to_string(seq.seq_lens[0]) + "_pos0.pte"; + FILE* f = std::fopen(step0.c_str(), "rb"); + if (f) { + std::fclose(f); + return true; + } + } + return false; + }(); + const bool dyn_ran = has_glob("sdpa_dyn_", ".pte"); + const bool incache_ran = has_glob("sdpa_incache_", ".pte"); + EXPECT_TRUE(sdpa_ran && replay_ran && dyn_ran && incache_ran) + << "WEBGPU_TEST_SDPA_DIR set but an SDPA family found no .pte"; +} - if (!prepack_tied_model_path.empty() && !prepack_tied_golden_path.empty()) { - ok = test_prepack( - prepack_tied_model_path, - prepack_tied_golden_path, - "x + w + w (tied weights, shared key)") && - ok; +TEST(WebGPUNative, SymintRoundtrip) { + if (g_symint_blob.empty()) { + GTEST_SKIP() << "WEBGPU_TEST_SYMINT_BLOB not set"; } + test_symint_roundtrip(g_symint_blob); +} - bool sdpa_ran = false; - bool sdpa_ok = test_sdpa_sweep(sdpa_dir, &sdpa_ran); - if (sdpa_ran) { - ok = sdpa_ok && ok; +TEST(WebGPUNative, ResizeHook) { + if (g_symint_blob.empty()) { + GTEST_SKIP() << "WEBGPU_TEST_SYMINT_BLOB not set"; } + test_resize_hook(g_symint_blob); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); - // Guard python<->C++ ramp bit-identity (recorded: _ramp_t(0,17,8,2)=0.1875). - if (std::abs(sdpa_ramp_t(0, 17, 8, 2) - 0.1875f) > 1e-12f) { - printf("FAIL: sdpa_ramp_t bit-identity check\n"); - ok = false; + if (const char* env = std::getenv("WEBGPU_TEST_UPDATE_CACHE_MODEL")) { + g_update_cache_model_path = env; } - // Guard the adversarial denom path: sdpa_ramp(0,17,8,0.5)= -16.0 exactly. - if (std::abs(sdpa_ramp(0, 17, 8, 0.5f) - (-16.0f)) > 1e-12f) { - printf("FAIL: sdpa_ramp denom bit-identity check\n"); - ok = false; + + // Quantized-linear sweep dir (mirrors WEBGPU_TEST_SDPA_DIR). + if (const char* env = std::getenv("WEBGPU_TEST_QUANTIZED_LINEAR_DIR")) { + g_qlinear_dir = env; + if (!g_qlinear_dir.empty() && g_qlinear_dir.back() != '/') { + g_qlinear_dir += '/'; + } } - bool replay_ran = false; - bool replay_ok = test_sdpa_replay_sweep(sdpa_dir, &replay_ran); - if (replay_ran) { - ok = replay_ok && ok; + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_MODEL")) { + g_prepack_model_path = env; + } + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_GOLDEN")) { + g_prepack_golden_path = env; } - bool dyn_ran = false; - bool dyn_ok = test_sdpa_dynamic_decode_sweep(sdpa_dir, &dyn_ran); - if (dyn_ran) { - ok = dyn_ok && ok; + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK2_MODEL")) { + g_prepack2_model_path = env; + } + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK2_GOLDEN")) { + g_prepack2_golden_path = env; } - bool incache_ran = false; - bool incache_ok = test_sdpa_incache_decode_sweep(sdpa_dir, &incache_ran); - if (incache_ran) { - ok = incache_ok && ok; + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_TIED_MODEL")) { + g_prepack_tied_model_path = env; + } + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_TIED_GOLDEN")) { + g_prepack_tied_golden_path = env; } - // If an SDPA dir was given, the exports must have produced .ptes for every - // family; a self-skip there means a silent export failure, not a pass. - if (!sdpa_dir.empty() && - !(sdpa_ran && replay_ran && dyn_ran && incache_ran)) { - printf("FAIL: WEBGPU_TEST_SDPA_DIR set but an SDPA family found no .pte\n"); - ok = false; + // SDPA sweep: configs self-discover their sdpa_.pte/.golden.bin under + // this directory (default "" = the embedded-file root / cwd). Set + // WEBGPU_TEST_SDPA_DIR to point at the exported .pte directory (e.g. /tmp/). + if (const char* env = std::getenv("WEBGPU_TEST_SDPA_DIR")) { + g_sdpa_dir = env; + if (!g_sdpa_dir.empty() && g_sdpa_dir.back() != '/') { + g_sdpa_dir += '/'; + } } if (const char* env = std::getenv("WEBGPU_TEST_SYMINT_BLOB")) { - ok = test_symint_roundtrip(env) && ok; - ok = test_resize_hook(env) && ok; + g_symint_blob = env; + } + + WebGPUContext ctx; + try { + ctx = create_webgpu_context(); + } catch (const std::exception& e) { + printf("SKIP: %s\n", e.what()); + return 0; } + set_default_webgpu_context(&ctx); + printf("WebGPU device acquired (native)\n"); + + const int rc = RUN_ALL_TESTS(); set_default_webgpu_context(nullptr); destroy_webgpu_context(ctx); - - if (!ok) { - return 1; - } - printf("\nAll tests passed\n"); - return 0; + return rc; }