From 41f49842fb78c7eafcf82eb73fabfff83ab8fc8b Mon Sep 17 00:00:00 2001 From: Yuhan Xu Date: Thu, 16 Apr 2026 07:09:47 +0000 Subject: [PATCH 1/4] [MetaX] Implement cooperative kernel launch for CINN plugin --- backends/metax_gpu/cinn/cinn_interface.cc | 15 +++++++++++ .../metax_gpu/cinn/runtime/cinn_runtime.cc | 27 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/backends/metax_gpu/cinn/cinn_interface.cc b/backends/metax_gpu/cinn/cinn_interface.cc index a01bd0e67e..332cc99674 100644 --- a/backends/metax_gpu/cinn/cinn_interface.cc +++ b/backends/metax_gpu/cinn/cinn_interface.cc @@ -67,6 +67,20 @@ extern C_Status MetaxLaunchKernel(void* dev_ptr, int shm, void* stream); +// Launches a cooperative kernel function (grid-level sync) +extern C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, + void* func_ptr, + void** args, + int num_args, + int gx, + int gy, + int gz, + int bx, + int by, + int bz, + int shm, + void* stream); + // --- From passes/pass_manager.cc --- // Applies custom graph optimization passes extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module); @@ -99,6 +113,7 @@ void InitCinnInterface(C_DeviceInterface* device_interface) { metax_cinn_impl.module_unload = MetaxModuleUnload; metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress; metax_cinn_impl.launch_kernel = MetaxLaunchKernel; + metax_cinn_impl.launch_cooperative_kernel = MetaxLaunchCooperativeKernel; // 6. Register Compilation Strategy interface metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass; diff --git a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc index 7f19db35e4..abe7c3474e 100644 --- a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc +++ b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc @@ -82,6 +82,33 @@ C_Status MetaxLaunchKernel(void* dev_ptr, return C_Status::C_SUCCESS; } +// Launch cooperative kernel: equivalent to cuLaunchCooperativeKernel +C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, + void* func_ptr, + void** args, + int num_args, + int gx, + int gy, + int gz, + int bx, + int by, + int bz, + int shm, + void* stream) { + CUresult err = cuLaunchCooperativeKernel((CUfunction)func_ptr, + gx, + gy, + gz, + bx, + by, + bz, + shm, + (CUstream)stream, + args); + if (err != CUDA_SUCCESS) return C_Status::C_FAILED; + return C_Status::C_SUCCESS; +} + } // namespace metax } // namespace custom_device } // namespace paddle From 8939eda34e8e198704c75b0dab0418030eef3828 Mon Sep 17 00:00:00 2001 From: Yuhan Xu Date: Wed, 22 Apr 2026 08:28:38 +0000 Subject: [PATCH 2/4] Metax: Add __cinn_grid_sync() into CINN_GRID_REDUCE_IMPL in compiler.cc Add stdcout in cinn_runtime.cc. --- backends/metax_gpu/cinn/compiler/compiler.cc | 41 ++++++++++++++++--- .../metax_gpu/cinn/runtime/cinn_runtime.cc | 37 +++++++++++++---- 2 files changed, 65 insertions(+), 13 deletions(-) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index b65f73e6e4..e9ee346c32 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -780,12 +780,43 @@ __device__ inline argidx_fp32_i64 cinn_discrete_reduce_min_argidx_fp32_i64( CINN_DISCRETE_REDUCE_IMPL(min_argidx_fp32_i64, value); } +// =============================================================== +// Grid-wide Barrier (emulates cooperative_groups::this_grid().sync()) +// Uses a sense-reversing barrier so it works correctly when called +// multiple times within the same kernel. +// REQUIREMENT: all thread blocks must be co-resident on the GPU. +// =============================================================== +__device__ unsigned int __cinn_grid_barrier_count[8192]; +__device__ unsigned int __cinn_grid_barrier_flag[8192]; + +__device__ inline void __cinn_grid_sync() { + __threadfence(); + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + unsigned int expected = + atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u); + unsigned int arrived = + atomicAdd(&__cinn_grid_barrier_count[blockIdx.x], 1u) + 1u; + if (arrived == (unsigned int)gridDim.y) { + atomicExch(&__cinn_grid_barrier_count[blockIdx.x], 0u); + __threadfence(); + atomicExch(&__cinn_grid_barrier_flag[blockIdx.x], 1u - expected); + __threadfence(); + } else { + while (atomicAdd(&__cinn_grid_barrier_flag[blockIdx.x], 0u) == + expected) { + } + } + } + __syncthreads(); +} + #define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE) \ - DTYPE tmp_val = init_value; \ - for (int y = 0; y < gridDim.y; y++) { \ - tmp_val = \ - cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \ - } \ + __cinn_grid_sync(); \ + DTYPE tmp_val = init_value; \ + for (int y = 0; y < gridDim.y; y++) { \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \ + } \ return tmp_val; #define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \ diff --git a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc index abe7c3474e..41d493e3df 100644 --- a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc +++ b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc @@ -28,9 +28,14 @@ namespace metax { C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) { CUmodule module; CUresult err = cuModuleLoad(&module, path); - if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - + if (err != CUDA_SUCCESS) { + std::cerr << "[MetaxModuleLoad] FAILED to load module from: " << path + << ", error=" << err << std::endl; + return C_Status::C_FAILED; + } *mod_out = reinterpret_cast(module); + std::cerr << "[MetaxModuleLoad] OK path=" << path << " module=" << module + << std::endl; return C_Status::C_SUCCESS; } @@ -47,9 +52,14 @@ C_Status MetaxGetKernelAddress(void* dev_ptr, void** func_out) { CUfunction func; CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name); - if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - + if (err != CUDA_SUCCESS) { + std::cerr << "[MetaxGetKernelAddress] FAILED func_name=" << func_name + << " module=" << module_handle << " error=" << err << std::endl; + return C_Status::C_FAILED; + } *func_out = reinterpret_cast(func); + std::cerr << "[MetaxGetKernelAddress] OK func_name=" << func_name + << " func_ptr=" << func << std::endl; return C_Status::C_SUCCESS; } @@ -82,7 +92,10 @@ C_Status MetaxLaunchKernel(void* dev_ptr, return C_Status::C_SUCCESS; } -// Launch cooperative kernel: equivalent to cuLaunchCooperativeKernel +// Launch cooperative kernel: uses cuLaunchCooperativeKernel (mapped to +// wcudaLaunchCooperativeKernel -> mcLaunchCooperativeKernel via cu-bridge) +// to guarantee all thread blocks are co-resident on the GPU, which is +// required by cross-block grid_reduce barriers (__cinn_grid_sync). C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, void* func_ptr, void** args, @@ -95,7 +108,11 @@ C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, int bz, int shm, void* stream) { - CUresult err = cuLaunchCooperativeKernel((CUfunction)func_ptr, + std::cout << "YUHAN!!! [MetaxLaunchCooperativeKernel] func_ptr=" << func_ptr + << " grid=(" << gx << "," << gy << "," << gz << ")" + << " block=(" << bx << "," << by << "," << bz << ")" + << " shm=" << shm << std::endl; + CUresult err = cuLaunchCooperativeKernel(static_cast(func_ptr), gx, gy, gz, @@ -103,9 +120,13 @@ C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, by, bz, shm, - (CUstream)stream, + static_cast(stream), args); - if (err != CUDA_SUCCESS) return C_Status::C_FAILED; + if (err != CUDA_SUCCESS) { + std::cerr << "[MetaxLaunchCooperativeKernel] FAILED error=" << err + << std::endl; + return C_Status::C_FAILED; + } return C_Status::C_SUCCESS; } From e222d76c93343282cc15c2f0aeec43f4ec4f6e6a Mon Sep 17 00:00:00 2001 From: Yuhan Xu Date: Wed, 29 Apr 2026 18:01:23 +0800 Subject: [PATCH 3/4] Metax MetaxLaunchCooperativeKernel, print more debug info --- backends/metax_gpu/cinn/compiler/compiler.cc | 4 +++ .../metax_gpu/cinn/runtime/cinn_runtime.cc | 25 ++++++++++--------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index e9ee346c32..99a4589254 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -1269,6 +1269,10 @@ C_Status MetaxCompile(void* dev_ptr, src_file << code; src_file.close(); } + // std::cout << "[MetaX] src_file content written to: " << src_path + // << "\n--- BEGIN src_file ---\n" + // << kMacaRuntimeSource << "\n" << code + // << "\n--- END src_file ---" << std::endl; // 2. Resolve compiler binary path const char* maca_path_env = std::getenv("MACA_PATH"); diff --git a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc index 41d493e3df..90093775b7 100644 --- a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc +++ b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc @@ -42,6 +42,7 @@ C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) { // Unload module C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle) { cuModuleUnload((CUmodule)module_handle); + std::cout << "YUHAN!!! [MetaxModuleUnload] module_handle=" << module_handle << std::endl; return C_Status::C_SUCCESS; } @@ -58,8 +59,8 @@ C_Status MetaxGetKernelAddress(void* dev_ptr, return C_Status::C_FAILED; } *func_out = reinterpret_cast(func); - std::cerr << "[MetaxGetKernelAddress] OK func_name=" << func_name - << " func_ptr=" << func << std::endl; + std::cout << "YUHAN!!! [MetaxGetKernelAddress] OK func_name=" << func_name + << " func_ptr=" << func << " module_handle=" << module_handle << std::endl; return C_Status::C_SUCCESS; } @@ -108,17 +109,17 @@ C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, int bz, int shm, void* stream) { - std::cout << "YUHAN!!! [MetaxLaunchCooperativeKernel] func_ptr=" << func_ptr - << " grid=(" << gx << "," << gy << "," << gz << ")" - << " block=(" << bx << "," << by << "," << bz << ")" - << " shm=" << shm << std::endl; + std::cout << "YUHAN!!! [MetaxLaunchCooperativeKernel] func_ptr=" << func_ptr; + CUmodule module; + CUresult errModule = cuFuncGetModule(&module ,static_cast(func_ptr)); + if (errModule != CUDA_SUCCESS) { + std::cerr << "[MetaxLaunchCooperativeKernel] FAILED Module error=" << errModule + << std::endl; + return C_Status::C_FAILED; + } CUresult err = cuLaunchCooperativeKernel(static_cast(func_ptr), - gx, - gy, - gz, - bx, - by, - bz, + gx, gy, gz, + bx, by, bz, shm, static_cast(stream), args); From 4ecbb10ca5a810cfd04bbd31ccd4a07d0962e4ff Mon Sep 17 00:00:00 2001 From: Yuhan Xu Date: Thu, 30 Apr 2026 18:16:58 +0800 Subject: [PATCH 4/4] Use cooperative_groups::this_grid().sync() in CINN_GRID_REDUCE_IMPL. Add CINN_GRID_REDUCE_FP16_MACRO. --- backends/metax_gpu/cinn/compiler/compiler.cc | 26 +++++++++++++++++-- .../metax_gpu/cinn/runtime/cinn_runtime.cc | 6 ----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index 99a4589254..6245f537c4 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -39,6 +39,7 @@ namespace metax { // ============================================================ static const char* kMacaRuntimeSource = R"MACA_SOURCE( #pragma once +#include #include #include @@ -812,7 +813,7 @@ __device__ inline void __cinn_grid_sync() { } #define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE) \ - __cinn_grid_sync(); \ + cooperative_groups::this_grid().sync(); \ DTYPE tmp_val = init_value; \ for (int y = 0; y < gridDim.y; y++) { \ tmp_val = cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \ @@ -830,7 +831,28 @@ EXPAND_REDUCE_INT64_MACRO(CINN_GRID_REDUCE_MACRO) EXPAND_REDUCE_FP32_MACRO(CINN_GRID_REDUCE_MACRO) EXPAND_REDUCE_FP64_MACRO(CINN_GRID_REDUCE_MACRO) EXPAND_REDUCE_BOOL_MACRO(CINN_GRID_REDUCE_MACRO) -EXPAND_REDUCE_FP16_MACRO(CINN_GRID_REDUCE_MACRO) + +// FP16 grid reduce: accumulate in FP32 to avoid precision loss when summing +// multiple FP16 block-level partial sums. Each partial sum can have magnitude +// O(block_size * input_scale), and accumulating N such values in FP16 incurs +// error proportional to N * magnitude * eps_fp16. Using FP32 for the inter- +// block accumulation step keeps the error at FP16 quantization level only. +#define CINN_GRID_REDUCE_FP16_MACRO(FP16_TYPE, FP32_FUNC, INIT_VAL) \ + __device__ inline float16 cinn_grid_reduce_##FP16_TYPE( \ + const float16 *mem, int spatial_size, int spatial_index) { \ + cooperative_groups::this_grid().sync(); \ + float tmp_val = (float)(INIT_VAL); \ + for (int y = 0; y < gridDim.y; y++) { \ + tmp_val = FP32_FUNC( \ + tmp_val, __half2float(mem[y * spatial_size + spatial_index])); \ + } \ + return __float2half(tmp_val); \ + } + +CINN_GRID_REDUCE_FP16_MACRO(sum_fp16, cinn_sum_fp32, 0.0f) +CINN_GRID_REDUCE_FP16_MACRO(prod_fp16, cinn_prod_fp32, 1.0f) +CINN_GRID_REDUCE_FP16_MACRO(max_fp16, cinn_max_fp32, -65504.0f) +CINN_GRID_REDUCE_FP16_MACRO(min_fp16, cinn_min_fp32, 65504.0f) __device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) { __shared__ bool done; diff --git a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc index 90093775b7..5aaa051eeb 100644 --- a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc +++ b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc @@ -34,15 +34,12 @@ C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) { return C_Status::C_FAILED; } *mod_out = reinterpret_cast(module); - std::cerr << "[MetaxModuleLoad] OK path=" << path << " module=" << module - << std::endl; return C_Status::C_SUCCESS; } // Unload module C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle) { cuModuleUnload((CUmodule)module_handle); - std::cout << "YUHAN!!! [MetaxModuleUnload] module_handle=" << module_handle << std::endl; return C_Status::C_SUCCESS; } @@ -59,8 +56,6 @@ C_Status MetaxGetKernelAddress(void* dev_ptr, return C_Status::C_FAILED; } *func_out = reinterpret_cast(func); - std::cout << "YUHAN!!! [MetaxGetKernelAddress] OK func_name=" << func_name - << " func_ptr=" << func << " module_handle=" << module_handle << std::endl; return C_Status::C_SUCCESS; } @@ -109,7 +104,6 @@ C_Status MetaxLaunchCooperativeKernel(void* dev_ptr, int bz, int shm, void* stream) { - std::cout << "YUHAN!!! [MetaxLaunchCooperativeKernel] func_ptr=" << func_ptr; CUmodule module; CUresult errModule = cuFuncGetModule(&module ,static_cast(func_ptr)); if (errModule != CUDA_SUCCESS) {