diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 3244deab875b..3c7f9c6d799e 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -143,32 +143,16 @@ TVM_REGISTER_OP("tirx.tan") TVM_REGISTER_OP("tirx.cosh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; - using tirx::make_zero; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = exp(neg_one * x); - PrimExpr exp_posx = exp(x); - PrimExpr ret = (exp_posx + exp_negx) / two; - return ret; + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.sinh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; - using tirx::make_zero; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = exp(neg_one * x); - PrimExpr exp_posx = exp(x); - PrimExpr ret = (exp_posx - exp_negx) / two; - return ret; + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.asin") @@ -232,35 +216,23 @@ TVM_REGISTER_OP("tirx.acos") TVM_REGISTER_OP("tirx.atan") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in atan legalization"; - const PrimExpr& x = call->args[0]; - PrimExpr one = make_const(x.dtype(), 1.0); - PrimExpr denom = sqrt(x * x + one); - return asin(x / denom); + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.asinh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in asinh legalization"; - const PrimExpr& x = call->args[0]; - PrimExpr one = make_const(x.dtype(), 1.0); - PrimExpr sqrt_val = sqrt(x * x + one); - return log(x + sqrt_val); + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.acosh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in acosh legalization"; - const PrimExpr& x = call->args[0]; - PrimExpr one = make_const(x.dtype(), 1.0); - PrimExpr sqrt_val = sqrt(x * x - one); - return log(x + sqrt_val); + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.atanh") @@ -275,21 +247,9 @@ TVM_REGISTER_OP("tirx.atanh") TVM_REGISTER_OP("tirx.erf") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in erf legalization"; - const PrimExpr& x = call->args[0]; - PrimExpr abs_x = tvm::abs(x); - PrimExpr t = make_const(x.dtype(), 1.0) / - (make_const(x.dtype(), 1.0) + make_const(x.dtype(), 0.3275911) * abs_x); - PrimExpr a1 = make_const(x.dtype(), 0.254829592); - PrimExpr a2 = make_const(x.dtype(), -0.284496736); - PrimExpr a3 = make_const(x.dtype(), 1.421413741); - PrimExpr a4 = make_const(x.dtype(), -1.453152027); - PrimExpr a5 = make_const(x.dtype(), 1.061405429); - PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t); - PrimExpr approx = make_const(x.dtype(), 1.0) - poly * exp(-abs_x * abs_x); - return tvm::tirx::Select(x < 0, -approx, approx); + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.clz") diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 151ec35e897f..53e93b0b4d5a 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -707,7 +707,7 @@ def test_bitwise_shift(direction: str): "Tanh", # "Asin", // TODO @jikechao, fix the precision loss due to the Taylor approximation # "Acos", - # "Atan", + "Atan", "Asinh", "Acosh", "Atanh",