Skip to content

CUBLAS_STATUS_NOT_SUPPORTED when using CUBLASLT_EPILOGUE_BIAS with NVFP4 #303

@huye566

Description

@huye566

I am encountering a CUBLAS_STATUS_NOT_SUPPORTED error when attempting to use CUBLASLT_EPILOGUE_BIAS for NVFP4 (E2M1) inference via cublasLtMatmul.

Interestingly, while the standalone BIAS epilogue fails, the fused CUBLASLT_EPILOGUE_RELU_BIAS works perfectly under the exact same configuration and environment.

Environment:
GPU: thor-U
Data Type: NVFP4 (Input), FP16 (Bias & Output)
CUDA Version: 12.8

Core Code Snippet:

bool cublaslt_gemm_nvfp4_impl(
    cublasLtHandle_t handle,
    int m, int n, int k,
    const void* A,
    const void* B,
    void* C,
    void* D,
    const void* bias,
    const nv_fp8_e4m3* a_scale,
    const nv_fp8_e4m3* b_scale,
    const nv_fp8_e4m3* c_scale,
    const nv_fp8_e4m3* d_scale,
    const nv_fp8_e4m3* d_out_scale,
    const NvFp4GemmParams& params,
    cudaStream_t stream) {

    cublasLtMatmulDesc_t operation_desc = nullptr;
    cudaDataType_t nvfp4DataType = CUDA_R_4F_E2M1;
    cudaDataType_t d_type = CUDA_R_32F;
    if (params.compute_type == ComputeType::HALF) {
        d_type = CUDA_R_16F;
    }
    CUBLASLT_CHECK(cublasLtMatmulDescCreate(&operation_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
    CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(operation_desc, CUBLASLT_MATMUL_DESC_TRANSA,
                                           &params.trans_a, sizeof(params.trans_a))); // T
    CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(operation_desc, CUBLASLT_MATMUL_DESC_TRANSB,
                                           &params.trans_b, sizeof(params.trans_b))); // N

    if (bias && (params.epilogue_mode == EpilogueMode::BIAS ||
                params.epilogue_mode == EpilogueMode::BIAS_RELU ||
                params.epilogue_mode == EpilogueMode::BIAS_GELU)) {
        cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; // error
        // cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_RELU_BIAS; // right
        CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(operation_desc,
            CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
        CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(operation_desc,
            CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
        CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(operation_desc,
            CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &d_type, sizeof(d_type)));
    }

    void *scale_A_ptr = (void *)a_scale;
    void *scale_B_ptr = (void *)b_scale;
    CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(operation_desc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
                                           &params.a_scale_mode, sizeof(params.a_scale_mode)));
    CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(operation_desc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
                                        &params.b_scale_mode, sizeof(params.b_scale_mode)));
    CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(operation_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                            &scale_B_ptr, sizeof(scale_B_ptr)));
    CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(operation_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                            &scale_A_ptr, sizeof(scale_A_ptr)));

    // 创建矩阵布局
    cublasLtMatrixLayout_t A_desc = nullptr, B_desc = nullptr, C_desc = nullptr, D_desc = nullptr;

    int rows_A, cols_A, lda;
    int rows_B, cols_B, ldb;
    int rows_D, cols_D, ldd;
    rows_A = k;
    cols_A = m;
    lda = k;
    rows_B = k;
    cols_B = n;
    ldb = k;
    rows_D = n;
    cols_D = m;
    ldd = n;

    CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&A_desc, nvfp4DataType,
                                       rows_A, cols_A, lda));
    CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&B_desc, nvfp4DataType,
                                       rows_B, cols_B, ldb));
    CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&C_desc, d_type, rows_D, cols_D, ldd));
    CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&D_desc, d_type, rows_D, cols_D, ldd));

    auto out_order = CUBLASLT_ORDER_COL;
    CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(
        D_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &out_order, sizeof(out_order)));
    cublasLtMatmulPreference_t preference = nullptr;
    size_t workspace_size = 32 * 1024 * 1024;
    CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&preference));
    CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(preference,
                                                 CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                                 &workspace_size, sizeof(workspace_size)));

    // 获取启发式算法
    int returned_results = 1;
    cublasLtMatmulHeuristicResult_t heuristic_result = {};
    CublasLtNVFP4Wrapper& wrapper = CublasLtNVFP4Wrapper::instance();

    CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(
        handle, operation_desc, B_desc, A_desc, D_desc, D_desc,
        preference, 1, &heuristic_result, &returned_results));
    if (returned_results == 0) {
        std::cerr << "No valid algorithm found for NVFP4 GEMM" << std::endl;
    } else {
        CUBLASLT_CHECK(cublasLtMatmul(
            handle, operation_desc,
            &params.alpha,
            B, B_desc,
            A, A_desc,
            &params.beta,
            D, D_desc,
            D, D_desc,
            &heuristic_result.algo,
            wrapper.workspace(),
            wrapper.max_workspace_size(),
            stream));
    }

    // 清理资源
    if (preference) cublasLtMatmulPreferenceDestroy(preference);
    if (D_desc) cublasLtMatrixLayoutDestroy(D_desc);
    if (C_desc) cublasLtMatrixLayoutDestroy(C_desc);
    if (B_desc) cublasLtMatrixLayoutDestroy(B_desc);
    if (A_desc) cublasLtMatrixLayoutDestroy(A_desc);
    if (operation_desc) cublasLtMatmulDescDestroy(operation_desc);

    return (returned_results != 0);
}

Is this a known limitation for NVFP4 kernels in the current version of cuBLASLt, or are there specific alignment/attribute requirements for CUBLASLT_EPILOGUE_BIAS that differ from the fused ReLU version?

Metadata

Metadata

Labels

Type

No type
No fields configured for issues without a type.

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions