diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 3244deab875b..ae57e8d9a607 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -173,61 +173,18 @@ TVM_REGISTER_OP("tirx.sinh") TVM_REGISTER_OP("tirx.asin") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - - PrimExpr threshold = make_const(x.dtype(), 0.5); - PrimExpr abs_x = tvm::abs(x); - PrimExpr use_lib = abs_x >= threshold; - - PrimExpr x2 = x * x; - PrimExpr term1 = x; - PrimExpr term3 = term1 * x2 / make_const(x.dtype(), 6); - PrimExpr term5 = term3 * x2 * make_const(x.dtype(), 9) / make_const(x.dtype(), 40); - PrimExpr term7 = term5 * x2 * make_const(x.dtype(), 25) / make_const(x.dtype(), 112); - PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456); - PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160); - PrimExpr series = term1 + term3 + term5 + term7 + term9 + term11; - - PrimExpr lib_result = - ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); - - PrimExpr lower = make_const(x.dtype(), -1.0); - PrimExpr upper = make_const(x.dtype(), 1.0); - PrimExpr out_range = tirx::Or(x upper); - PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits::quiet_NaN()); - - return tirx::Select(out_range, nan_const, tirx::Select(use_lib, lib_result, series)); + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.acos") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in acos legalization"; - const PrimExpr& x = call->args[0]; - - PrimExpr threshold = make_const(x.dtype(), 0.5); - PrimExpr abs_x = tvm::abs(x); - PrimExpr use_lib = abs_x >= threshold; - - PrimExpr half_pi = make_const(x.dtype(), M_PI / 2); - PrimExpr asin_x = asin(x); - PrimExpr formula_result = half_pi - asin_x; - - PrimExpr lib_result = - ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); - - PrimExpr lower = make_const(x.dtype(), -1.0); - PrimExpr upper = make_const(x.dtype(), 1.0); - PrimExpr out_range = tirx::Or(x upper); - PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits::quiet_NaN()); - - return tirx::Select(out_range, nan_const, tirx::Select(use_lib, lib_result, formula_result)); + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.atan") diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 151ec35e897f..464c06c3b6a9 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -705,9 +705,9 @@ def test_bitwise_shift(direction: str): "Sinh", "Cosh", "Tanh", - # "Asin", // TODO @jikechao, fix the precision loss due to the Taylor approximation - # "Acos", - # "Atan", + "Asin", + "Acos", + # "Atan", // TODO: fix x²+1 overflow in llvm legalize for huge inputs (issue #19560) "Asinh", "Acosh", "Atanh",