diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py index 81accfed1287..f7d79ba4348c 100644 --- a/python/tvm/target/detect_target.py +++ b/python/tvm/target/detect_target.py @@ -41,6 +41,7 @@ def _detect_cuda(dev: Device) -> Target: "max_threads_per_block": dev.max_threads_per_block, "thread_warp_size": dev.warp_size, "arch": "sm_" + dev.compute_version.replace(".", ""), + "enable_fast_math": False, } ) diff --git a/python/tvm/target/tag_registry/cuda.py b/python/tvm/target/tag_registry/cuda.py index 6b1bd9e8a8bd..d3740cb5151a 100644 --- a/python/tvm/target/tag_registry/cuda.py +++ b/python/tvm/target/tag_registry/cuda.py @@ -28,12 +28,14 @@ def _register_cuda_tag(name, arch, shared_mem=49152, regs=65536, **extra): "max_threads_per_block": 1024, "thread_warp_size": 32, "registers_per_block": regs, + # Default to disable fast math + "enable_fast_math": False, } config.update(extra) register_tag(name, config) -def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536): +def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536, enable_fast_math=False): register_tag( name, { @@ -49,6 +51,7 @@ def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536): "mcpu": mcpu, "num-cores": num_cores, }, + "enable_fast_math": enable_fast_math, }, ) diff --git a/src/target/cuda/intrin_rule_cuda.cc b/src/target/cuda/intrin_rule_cuda.cc index d38db9fe8372..5e3c5214cc7f 100644 --- a/src/target/cuda/intrin_rule_cuda.cc +++ b/src/target/cuda/intrin_rule_cuda.cc @@ -174,37 +174,46 @@ TVM_REGISTER_OP("tirx.nearbyint") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.exp") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.exp2") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.exp10") - .set_attr("cuda.FLowerIntrinsic", 
DispatchPureExtern); + .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.erf") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.log") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.log2") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.log10") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.tan") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + // Now the fast math version of tan and the default version of tan are the same. 
+ .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.cos") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.cosh") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.sin") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.sinh") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); @@ -213,12 +222,17 @@ TVM_REGISTER_OP("tirx.atan") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.tanh") + .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.sqrt") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); +TVM_REGISTER_OP("tirx.rsqrt") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + TVM_REGISTER_OP("tirx.pow") + .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.popcount") diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index b19c41056deb..588a4f271498 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -183,6 +183,14 @@ ffi::Map UpdateCUDAAttrs(ffi::Map } target.Set("arch", ffi::String("sm_") + std::to_string(archInt)); } + // Update enable_fast_math + if (target.count("enable_fast_math")) { + // If enable_fast_math has been specified, validate that enable_fast_math is a bool + Downcast(target.at("enable_fast_math")); + } else { + // If enable_fast_math has not been specified, default to false + target.Set("enable_fast_math", false); + } return target; } @@ -367,6 +375,7 @@ 
TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .add_attr_option("l2_cache_size_bytes") .add_attr_option("max_num_threads", refl::DefaultValue(1024)) // TODO(@zxybazh): deprecate it + .add_attr_option("enable_fast_math") .set_default_keys({"cuda", "gpu"}) .set_target_canonicalizer(UpdateCUDAAttrs); diff --git a/src/tirx/transform/lower_intrin.cc b/src/tirx/transform/lower_intrin.cc index 981615b0d1d5..7f4b1aa30b4a 100644 --- a/src/tirx/transform/lower_intrin.cc +++ b/src/tirx/transform/lower_intrin.cc @@ -46,11 +46,21 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitStmt_; using FLowerGeneral = ffi::TypedFunction; - IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") - : IRMutatorWithAnalyzer(analyzer) { + IntrinInjecter(arith::Analyzer* analyzer, const Target& tgt) : IRMutatorWithAnalyzer(analyzer) { + std::string target = tgt->kind->name; + ffi::String mtriple = tgt->GetAttr("mtriple").value_or(""); + std::vector patterns; + // For CUDA targets, we need to add the fast math patterns if enable_fast_math is true. + // The priority of the fast math patterns is higher than the normal patterns. 
+ bool is_fast_math = tgt->GetAttr("enable_fast_math").value_or(false); + if (is_fast_math) { + patterns.push_back(target + ".fastmath.FLowerIntrinsic"); + patterns.push_back(target + ".fastmath.FLegalize"); + } patterns.push_back(target + ".FLowerIntrinsic"); patterns.push_back(target + ".FLegalize"); + bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos); if (is_llvm_aarch64) { patterns.push_back(target + ".aarch64.FLowerIntrinsic"); @@ -354,7 +364,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { arith::Analyzer analyzer; - return IntrinInjecter(&analyzer, target)(std::move(stmt)); + return IntrinInjecter(&analyzer, Target(ffi::String(target)))(std::move(stmt)); } namespace transform { @@ -365,9 +375,7 @@ Pass LowerIntrin() { auto target = f->GetAttr(tvm::attr::kTarget); TVM_FFI_ICHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - auto mtriple = target.value()->GetAttr("mtriple", ""); - n->body = - IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body)); + n->body = IntrinInjecter(&analyzer, target.value())(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tirx.LowerIntrin", {}); diff --git a/tests/python/codegen/test_target_codegen_cuda_fastmath.py b/tests/python/codegen/test_target_codegen_cuda_fastmath.py new file mode 100644 index 000000000000..84cac4361e61 --- /dev/null +++ b/tests/python/codegen/test_target_codegen_cuda_fastmath.py @@ -0,0 +1,298 @@ +# Licensed to the Apache Software Foundation (ASF) under one + +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +from collections.abc import Callable +from dataclasses import dataclass + +import numpy as np +import pytest + +import tvm +import tvm.testing +import tvm.tirx as tirx +from tvm.contrib.nvcc import have_fp16 +from tvm.ir.module import IRModule +from tvm.runtime.executable import Executable +from tvm.script import tirx as T + +VECTOR_N_INPUTS = 8 + + +def make_prim_func( + name: str, + dtype: str, + num_inputs: int, + op: Callable[[tirx.PrimExpr, ...], tirx.PrimExpr], +) -> tirx.PrimFunc: + """Make a primitive function that applies the given operation to the input buffer.""" + if num_inputs == 1: + + @T.prim_func + def kernel( + A: T.Buffer((VECTOR_N_INPUTS,), dtype), + B: T.Buffer((VECTOR_N_INPUTS,), dtype), + ): + T.func_attr({"global_symbol": name + "_kernel", "tirx.noalias": True}) + for i in T.thread_binding(VECTOR_N_INPUTS, thread="threadIdx.x"): + B[i] = op(A[i]) + + return kernel + elif num_inputs == 2: + + @T.prim_func + def kernel( + A: T.Buffer((VECTOR_N_INPUTS,), dtype), + E: T.Buffer((VECTOR_N_INPUTS,), dtype), + B: T.Buffer((VECTOR_N_INPUTS,), dtype), + ): + T.func_attr({"global_symbol": name + "_kernel", "tirx.noalias": True}) + for i in T.thread_binding(VECTOR_N_INPUTS, thread="threadIdx.x"): + B[i] = op(A[i], E[i]) + + return kernel + else: + raise ValueError(f"Unsupported number of inputs: {num_inputs}") + + +@dataclass(frozen=True) +class MathCase: + name: str + op: Callable[[tirx.PrimExpr, ...], tirx.PrimExpr] + num_inputs: int + default_intrinsic_f16: str + default_intrinsic_bf16: str + default_intrinsic_f32: str + 
default_intrinsic_f64: str + fast_math_intrinsic_f32: str + np_ref: object + rtol: float = 1e-5 + atol: float = 1e-6 + + +MATH_CASES = [ + MathCase( + "exp_case", + T.exp, + 1, + "hexp", + "hexp", + "expf", + "exp", + "__expf", + lambda x: np.exp(x), + ), + MathCase( + "exp10_case", + T.exp10, + 1, + "hexp10", + "hexp10", + "exp10f", + "exp10", + "__exp10f", + lambda x: np.power(10.0, x), + ), + MathCase( + "log_case", + T.log, + 1, + "hlog", + "hlog", + "logf", + "log", + "__logf", + lambda x: np.log(x), + ), + MathCase( + "log2_case", + T.log2, + 1, + "hlog2", + "hlog2", + "log2f", + "log2", + "__log2f", + lambda x: np.log2(x), + ), + MathCase( + "log10_case", + T.log10, + 1, + "hlog10", + "hlog10", + "log10f", + "log10", + "__log10f", + lambda x: np.log10(x), + ), + MathCase( + "tan_case", + T.tan, + 1, + "htan", + "htan", + "tanf", + "tan", + "tanf", + lambda x: np.tan(x), + ), + MathCase( + "cos_case", + T.cos, + 1, + "hcos", + "hcos", + "cosf", + "cos", + "__cosf", + lambda x: np.cos(x), + ), + MathCase( + "sin_case", + T.sin, + 1, + "hsin", + "hsin", + "sinf", + "sin", + "__sinf", + lambda x: np.sin(x), + ), + MathCase( + "tanh_case", + T.tanh, + 1, + "htanh", + "htanh", + "tanhf", + "tanh", + "__tanhf", + lambda x: np.tanh(x), + ), + MathCase( + "pow_case", + T.pow, + 2, + "hpow", + "hpow", + "powf", + "pow", + "__powf", + lambda x, y: np.power(x, y), + ), +] + + +def make_mod( + dtype: str, case: MathCase, enable_fast_math: bool +) -> tuple[tvm.target.Target, tvm.IRModule]: + """Make a module for the given dtype and case.""" + target = tvm.target.Target({"kind": "cuda", "enable_fast_math": enable_fast_math}) + prim_func = make_prim_func(case.name, dtype, case.num_inputs, case.op) + return target, tvm.IRModule.from_expr(prim_func.with_attr("target", target)) + + +def expected_intrinsic(dtype: str, case: MathCase, enable_fast_math: bool) -> str: + """Get the expected intrinsic for the given dtype and case.""" + if dtype == "float16": + return 
case.default_intrinsic_f16 + elif dtype == "bfloat16": + return case.default_intrinsic_bf16 + elif dtype == "float32": + return case.fast_math_intrinsic_f32 if enable_fast_math else case.default_intrinsic_f32 + elif dtype == "float64": + return case.default_intrinsic_f64 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def check_lowered_ir( + dtype: str, case: MathCase, enable_fast_math: bool +) -> tuple[tvm.target.Target, IRModule]: + """Check the lowered IR for the given dtype and case.""" + target, mod = make_mod(dtype, case, enable_fast_math) + lowered_mod = tvm.tirx.transform.LowerIntrin()(mod) + script = lowered_mod.script(show_meta=False) + expected = expected_intrinsic(dtype, case, enable_fast_math) + assert re.search(rf"""["']{re.escape(expected)}["']""", script) + return target, lowered_mod + + +def check_cuda_source( + target: tvm.target.Target, + mod: IRModule, + dtype: str, + case: MathCase, + enable_fast_math: bool, +) -> Executable: + """Check the CUDA source for the given dtype and case.""" + executable = tvm.compile(mod, target=target) + source = executable.mod.imports[0].inspect_source() + expected = expected_intrinsic(dtype, case, enable_fast_math) + assert re.search(rf"(? bool: @@ -133,32 +129,77 @@ def supports_device(cls, device: str) -> bool: # validated against the ONNX Backend Test Suite. They can be added # incrementally as the importer improves. 
_INCLUDE_OPS = [ - "abs", "acos", "acosh", "add", "and", "argmax", "argmin", - "averagepool", "bitshift", - "bitwise_and", "bitwise_not", "bitwise_or", "bitwise_xor", - "ceil", "clip", "compress", "concat", - "conv", "cos", "cosh", - "depthtospace", "div", - "einsum", "erf", "exp", - "flatten", "floor", - "gathernd", "gemm", - "globalaveragepool", "globalmaxpool", "greater", "greater_equal", - "hardmax", "hardswish", + "abs", + "acos", + "acosh", + "add", + "and", + "argmax", + "argmin", + "averagepool", + "bitshift", + "bitwise_and", + "bitwise_not", + "bitwise_or", + "bitwise_xor", + "ceil", + "clip", + "compress", + "concat", + "conv", + "cos", + "cosh", + "depthtospace", + "div", + "einsum", + "erf", + "exp", + "flatten", + "floor", + "gathernd", + "gemm", + "globalaveragepool", + "globalmaxpool", + "greater", + "greater_equal", + "hardmax", + "hardswish", "isnan", - "less", "less_equal", "lrn", - "matmul", "matmulinteger", "mean", "min", "mod", "mul", "neg", - "nonzero", "not", + "less", + "less_equal", + "lrn", + "matmul", + "matmulinteger", + "mean", + "min", + "mod", + "mul", + "neg", + "nonzero", + "not", "or", "reciprocal", "round", "scatternd", - "sigmoid", "sign", - "sin", "sinh", "size", "slice", + "sigmoid", + "sign", + "sin", + "sinh", + "size", + "slice", "spacetodepth", - "sqrt", "squeeze", "sub", "sum", - "tan", "tanh", "tile", "transpose", - "unique", "unsqueeze", - "where", "xor", + "sqrt", + "squeeze", + "sub", + "sum", + "tan", + "tanh", + "tile", + "transpose", + "unique", + "unsqueeze", + "where", + "xor", ] for _op in _INCLUDE_OPS: diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py index 1b2246adb09c..6c8d1d9939c1 100644 --- a/tests/python/target/test_target_target.py +++ b/tests/python/target/test_target_target.py @@ -148,6 +148,7 @@ def test_target_tag_0(): assert tgt.attrs["max_threads_per_block"] == 1024 assert tgt.attrs["thread_warp_size"] == 32 assert tgt.attrs["registers_per_block"] == 
65536 + assert not tgt.attrs["enable_fast_math"] def test_target_tag_1(): @@ -158,15 +159,19 @@ def test_target_tag_1(): assert tgt.attrs["max_threads_per_block"] == 1024 assert tgt.attrs["thread_warp_size"] == 32 assert tgt.attrs["registers_per_block"] == 32768 + assert not tgt.attrs["enable_fast_math"] def test_target_tag_override(): """Test creating a target from a tag with attribute overrides.""" - tgt = tvm.target.Target({"tag": "nvidia/nvidia-a100", "l2_cache_size_bytes": 12345}) + tgt = tvm.target.Target( + {"tag": "nvidia/nvidia-a100", "l2_cache_size_bytes": 12345, "enable_fast_math": True} + ) assert tgt.kind.name == "cuda" assert tgt.attrs["arch"] == "sm_80" # Override should take effect assert int(tgt.attrs["l2_cache_size_bytes"]) == 12345 + assert tgt.attrs["enable_fast_math"] # Base tag fields should be preserved assert tgt.attrs["max_shared_memory_per_block"] == 49152 assert tgt.attrs["thread_warp_size"] == 32 @@ -189,12 +194,14 @@ def test_target_host_tags(): assert tgt.attrs["max_threads_per_block"] == 1024 assert tgt.attrs["thread_warp_size"] == 32 assert tgt.attrs["registers_per_block"] == 32768 + assert not tgt.attrs["enable_fast_math"] assert tgt.host.kind.name == "cuda" assert tgt.host.attrs["arch"] == "sm_75" assert tgt.host.attrs["max_shared_memory_per_block"] == 49152 assert tgt.host.attrs["max_threads_per_block"] == 1024 assert tgt.host.attrs["thread_warp_size"] == 32 assert tgt.host.attrs["registers_per_block"] == 65536 + assert not tgt.host.attrs["enable_fast_math"] def test_target_host_tag_dict(): @@ -205,6 +212,7 @@ def test_target_host_tag_dict(): assert tgt.attrs["max_threads_per_block"] == 1024 assert tgt.attrs["thread_warp_size"] == 32 assert tgt.attrs["registers_per_block"] == 32768 + assert not tgt.attrs["enable_fast_math"] assert tgt.host.kind.name == "llvm" @@ -217,6 +225,7 @@ def test_target_host_single_dict(): assert tgt.host.attrs["max_threads_per_block"] == 1024 assert tgt.host.attrs["thread_warp_size"] == 32 assert 
tgt.host.attrs["registers_per_block"] == 32768 + assert not tgt.host.attrs["enable_fast_math"] def test_target_host_single_string(): @@ -234,6 +243,7 @@ def test_target_host_single_string_with_tag(): assert tgt.host.attrs["max_threads_per_block"] == 1024 assert tgt.host.attrs["thread_warp_size"] == 32 assert tgt.host.attrs["registers_per_block"] == 32768 + assert not tgt.host.attrs["enable_fast_math"] def test_target_host_merge_0(): @@ -245,6 +255,7 @@ def test_target_host_merge_0(): assert tgt.host.attrs["max_threads_per_block"] == 1024 assert tgt.host.attrs["thread_warp_size"] == 32 assert tgt.host.attrs["registers_per_block"] == 32768 + assert not tgt.host.attrs["enable_fast_math"] def test_target_host_merge_1(): @@ -295,6 +306,7 @@ def test_target_with_host(): assert tgt.host.attrs["max_threads_per_block"] == 1024 assert tgt.host.attrs["thread_warp_size"] == 32 assert tgt.host.attrs["registers_per_block"] == 32768 + assert not tgt.host.attrs["enable_fast_math"] def test_target_attr_bool_value():