From e119e5bc29045db893475eb50fba7e12b5db50c4 Mon Sep 17 00:00:00 2001 From: Soowon Jeong Date: Fri, 15 May 2026 16:34:08 +0900 Subject: [PATCH 1/3] [BugFix][Target][LLVM] Route sinh/cosh/atan/asinh/erf through libm extern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Five LLVM legalize rules used inline mathematical identities that fail on representable inputs because the intermediate computation overflows or cancels, even though the true result is in range: - `sinh`/`cosh`: `(exp(x) ± exp(-x)) / 2` — `exp(89) > FLT_MAX` so `exp(x)` itself overflows. True `sinh(89) ≈ 2.24e38` fits in float32. (#19559) - `atan`: `asin(x / sqrt(x*x + 1))` — `x*x` overflows for `|x| > sqrt(FLT_MAX) ≈ 1.84e19`, then `sqrt(inf) = inf`, `x / inf = 0`, `asin(0) = 0`. (#19560) - `asinh`: `log(x + sqrt(x*x + 1))` — same `x*x` overflow. True `asinh(3e22) ≈ 52.45`. (#19561) - `erf`: Abramowitz–Stegun `1 - poly(t) * exp(-x*x)` — for small `|x|`, `poly * exp(-x*x) ≈ 1` and the subtraction cancels to 0 in float32, flushing `erf(3e-12)` to 0 instead of the true `~3.4e-12`. (#19562) All five route through the existing `DispatchPureExtern` helper (i.e. `sinhf`, `coshf`, `atanf`, `asinhf`, `erff`), the same pattern already used by `asin`/`acos`. ULP-grade accuracy across representative ranges; `Atan` is re-enabled in `test_unary` since the overflow that previously broke it is fixed. Note for reviewers: if the inline identities were a deliberate fast-path (e.g. for autovectorization or to avoid extern call overhead in tight loops), please flag it and I'll switch to stable inline forms (`exp(x − ln 2) ± exp(−x − ln 2)` for sinh/cosh, range-reduced asinh, small-x Taylor for erf, etc.). I could not find evidence of such intent in the git history. Acosh shows the same `sqrt(x*x − 1)` overflow pattern but is not covered by any of the listed issues; happy to include it as a follow-up if maintainers want. Fixes #19559. Fixes #19560. Fixes #19561. Fixes #19562. 
--- src/target/llvm/intrin_rule_llvm.cc | 51 +++++------------------- tests/python/relax/test_frontend_onnx.py | 2 +- 2 files changed, 11 insertions(+), 42 deletions(-) diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 3244deab875b..ac8b1a7890fd 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -143,32 +143,18 @@ TVM_REGISTER_OP("tirx.tan") TVM_REGISTER_OP("tirx.cosh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; - using tirx::make_zero; + using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = exp(neg_one * x); - PrimExpr exp_posx = exp(x); - PrimExpr ret = (exp_posx + exp_negx) / two; - return ret; + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.sinh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; - using tirx::make_zero; + using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = exp(neg_one * x); - PrimExpr exp_posx = exp(x); - PrimExpr ret = (exp_posx - exp_negx) / two; - return ret; + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.asin") @@ -232,24 +218,18 @@ TVM_REGISTER_OP("tirx.acos") TVM_REGISTER_OP("tirx.atan") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; + using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in atan legalization"; - const PrimExpr& x = call->args[0]; - 
PrimExpr one = make_const(x.dtype(), 1.0); - PrimExpr denom = sqrt(x * x + one); - return asin(x / denom); + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.asinh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; + using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in asinh legalization"; - const PrimExpr& x = call->args[0]; - PrimExpr one = make_const(x.dtype(), 1.0); - PrimExpr sqrt_val = sqrt(x * x + one); - return log(x + sqrt_val); + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.acosh") @@ -275,21 +255,10 @@ TVM_REGISTER_OP("tirx.atanh") TVM_REGISTER_OP("tirx.erf") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; + using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in erf legalization"; - const PrimExpr& x = call->args[0]; - PrimExpr abs_x = tvm::abs(x); - PrimExpr t = make_const(x.dtype(), 1.0) / - (make_const(x.dtype(), 1.0) + make_const(x.dtype(), 0.3275911) * abs_x); - PrimExpr a1 = make_const(x.dtype(), 0.254829592); - PrimExpr a2 = make_const(x.dtype(), -0.284496736); - PrimExpr a3 = make_const(x.dtype(), 1.421413741); - PrimExpr a4 = make_const(x.dtype(), -1.453152027); - PrimExpr a5 = make_const(x.dtype(), 1.061405429); - PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t); - PrimExpr approx = make_const(x.dtype(), 1.0) - poly * exp(-abs_x * abs_x); - return tvm::tirx::Select(x < 0, -approx, approx); + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.clz") diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 151ec35e897f..53e93b0b4d5a 100644 --- 
a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -707,7 +707,7 @@ def test_bitwise_shift(direction: str): "Tanh", # "Asin", // TODO @jikechao, fix the precision loss due to the Taylor approximation # "Acos", - # "Atan", + "Atan", "Asinh", "Acosh", "Atanh", From 21f42efb31926ca4dea70f536d7c7a93b6517ed2 Mon Sep 17 00:00:00 2001 From: Soowon Jeong Date: Fri, 15 May 2026 16:41:02 +0900 Subject: [PATCH 2/3] [BugFix][Target][LLVM] Also route acosh through libm extern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `acosh` has the same `sqrt(x*x - 1)` overflow pattern as `asinh`: intermediate `x*x` overflows float32 for `|x| > sqrt(FLT_MAX) ≈ 1.84e19`, so `sqrt(inf) = inf`, `log(x + inf) = inf`, while the true result `acosh(3e22) ≈ 52.45` is well within range. No issue was filed for this op but the bug is identical to #19561 and the fix is the same. --- src/target/llvm/intrin_rule_llvm.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index ac8b1a7890fd..2aff9e23e4cf 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -234,13 +234,10 @@ TVM_REGISTER_OP("tirx.asinh") TVM_REGISTER_OP("tirx.acosh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tirx::make_const; + using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in acosh legalization"; - const PrimExpr& x = call->args[0]; - PrimExpr one = make_const(x.dtype(), 1.0); - PrimExpr sqrt_val = sqrt(x * x - one); - return log(x + sqrt_val); + return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); }); TVM_REGISTER_OP("tirx.atanh") From ba9d6e413cc67b685899e405e32c03cd5d07512b Mon Sep 17 00:00:00 2001 From: Soowon Jeong Date: Fri, 15 May 2026 16:45:55 +0900 Subject: [PATCH 3/3] 
[BugFix][Target][LLVM] Drop unused `using namespace intrin` in libm wrappers Per review feedback: with the inline math identities gone, the `DispatchPureExtern` call is fully qualified (`::tvm::codegen::intrin::...`) and the `using namespace intrin;` line inside each lambda no longer brings anything into scope. Drop it from the six ops touched in this PR (sinh, cosh, atan, asinh, acosh, erf). The `CallNode* call` ICHECK is kept for parity with the rest of the file (every legalize lambda in this translation unit performs that check). --- src/target/llvm/intrin_rule_llvm.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 2aff9e23e4cf..3c7f9c6d799e 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -143,7 +143,6 @@ TVM_REGISTER_OP("tirx.tan") TVM_REGISTER_OP("tirx.cosh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); @@ -151,7 +150,6 @@ TVM_REGISTER_OP("tirx.cosh") TVM_REGISTER_OP("tirx.sinh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); @@ -218,7 +216,6 @@ TVM_REGISTER_OP("tirx.acos") TVM_REGISTER_OP("tirx.atan") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in atan legalization"; return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); @@ -226,7 +223,6 @@ TVM_REGISTER_OP("tirx.atan") TVM_REGISTER_OP("tirx.asinh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> 
PrimExpr { - using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in asinh legalization"; return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); @@ -234,7 +230,6 @@ TVM_REGISTER_OP("tirx.asinh") TVM_REGISTER_OP("tirx.acosh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in acosh legalization"; return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); @@ -252,7 +247,6 @@ TVM_REGISTER_OP("tirx.atanh") TVM_REGISTER_OP("tirx.erf") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using namespace intrin; const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in erf legalization"; return ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e);