diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java index fa1741cb08f7..eba771c37e16 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java @@ -346,6 +346,27 @@ public static UTF8String quote(UTF8String str) { return UTF8String.fromString(qtChar + sp + qtChar); } + /** + * Inverse hyperbolic cosine for the {@code acosh} expression, using the + * fdlibm {@code e_acosh.c} algorithm (returns {@code NaN} for {@code x < 1}). + * Shared by the eval and codegen paths so the generated Java is a single call + * rather than an inline five-branch if/else. + */ + public static double acosh(double x) { + if (x < 1.0) { + return Double.NaN; + } else if (x >= (1 << 28)) { + return StrictMath.log(x) + StrictMath.log(2.0); + } else if (x == 1.0) { + return 0.0; + } else if (x > 2.0) { + return StrictMath.log(2.0 * x - 1.0 / (x + Math.sqrt(x * x - 1.0))); + } else { + double t = x - 1.0; + return StrictMath.log1p(t + Math.sqrt(2.0 * t + t * t)); + } + } + /** * Returns the single-character string for the {@code chr} expression: the * ASCII/Latin-1 character for {@code longVal & 0xFF}. A negative argument diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index d816043e710d..2f62e1427d16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -418,40 +418,11 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" since = "3.0.0", group = "math_funcs") case class Acosh(child: Expression) - extends UnaryMathExpression((x: Double) => { - // fdlibm e_acosh.c algorithm - if (x < 1.0) { - Double.NaN - } else if (x >= (1 << 28)) { - StrictMath.log(x) + StrictMath.log(2.0) - } else if (x == 1.0) { - 0.0 - } else if (x > 2.0) { - StrictMath.log(2.0 * x - 1.0 / (x + math.sqrt(x * x - 1.0))) - } else { - val t = x - 1.0 - StrictMath.log1p(t + math.sqrt(2.0 * t + t * t)) - } - }, "ACOSH") { + // fdlibm e_acosh.c algorithm, shared with codegen via ExpressionImplUtils.acosh. + extends UnaryMathExpression((x: Double) => ExpressionImplUtils.acosh(x), "ACOSH") { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, c => { - val sm = "java.lang.StrictMath" - val t = ctx.freshName("t") - s""" - |if ($c < 1.0) { - | ${ev.value} = java.lang.Double.NaN; - |} else if ($c >= ${1 << 28}.0) { - | ${ev.value} = $sm.log($c) + $sm.log(2.0); - |} else if ($c == 1.0) { - | ${ev.value} = 0.0; - |} else if ($c > 2.0) { - | ${ev.value} = $sm.log(2.0 * $c - 1.0 / ($c + java.lang.Math.sqrt($c * $c - 1.0))); - |} else { - | double $t = $c - 1.0; - | ${ev.value} = $sm.log1p($t + java.lang.Math.sqrt(2.0 * $t + $t * $t)); - |} - |""".stripMargin - }) + val utils = classOf[ExpressionImplUtils].getName + defineCodeGen(ctx, ev, c => s"$utils.acosh($c)") } override protected def withNewChildInternal(newChild: Expression): Acosh = copy(child = newChild) }