diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index cb7ba16aeb81..88b409c4e867 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -653,6 +653,15 @@ trait DivModLike extends BinaryArithmetic {
     case _ => x => x == 0
   }
 
+  // Whether the divisor is a non-null and non-zero literal (foldable divisors are already folded
+  // to a literal before codegen). When true, the divide-by-zero check is dead code (it can never
+  // trigger), so codegen skips emitting both the check and its otherwise-unreachable error-context
+  // reference.
+  private lazy val divisorIsNonZero: Boolean = right match {
+    case Literal(divisor, _) => divisor != null && !isZero(divisor)
+    case _ => false
+  }
+
   final override def eval(input: InternalRow): Any = {
     // evaluate right first as we have a chance to skip left if right is 0
     val input2 = right.eval(input)
@@ -705,7 +714,9 @@ trait DivModLike extends BinaryArithmetic {
       s"${eval2.value} == 0"
     }
     val javaType = CodeGenerator.javaType(dataType)
-    val errorContextCode = getContextOrNullCode(ctx, failOnError)
+    // Lazy so the error-context reference is only registered when actually emitted; a statically
+    // non-zero divisor (see divisorIsNonZero) skips the divide-by-zero check that would use it.
+    lazy val errorContextCode = getContextOrNullCode(ctx, failOnError)
     val operation = super.dataType match {
       case DecimalType.Fixed(precision, scale) =>
         val castUtils = classOf[CastUtils].getName
@@ -741,25 +752,34 @@ trait DivModLike extends BinaryArithmetic {
 
     // evaluate right first as we have a chance to skip left if right is 0
     if (!left.nullable && !right.nullable) {
-      val divByZero = if (failOnError) {
-        s"throw ${divideByZeroErrorCode(ctx)};"
+      val divisionBody =
+        s"""
+           |${eval1.code}
+           |$checkIntegralDivideOverflow
+           |$operation""".stripMargin
+      // A statically non-zero divisor makes the zero check dead code, so emit only the division.
+      val guardedBody = if (divisorIsNonZero) {
+        divisionBody
       } else {
-        s"${ev.isNull} = true;"
+        val divByZero = if (failOnError) {
+          s"throw ${divideByZeroErrorCode(ctx)};"
+        } else {
+          s"${ev.isNull} = true;"
+        }
+        s"""
+           |if ($isZero) {
+           |  $divByZero
+           |} else {$divisionBody
+           |}""".stripMargin
       }
       ev.copy(code = code"""
         ${eval2.code}
         boolean ${ev.isNull} = false;
         $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-        if ($isZero) {
-          $divByZero
-        } else {
-          ${eval1.code}
-          $checkIntegralDivideOverflow
-          $operation
-        }""")
+        $guardedBody""")
     } else {
-      val nullOnErrorCondition = if (failOnError) "" else s" || $isZero"
-      val failOnErrorBranch = if (failOnError) {
+      val nullOnErrorCondition = if (failOnError || divisorIsNonZero) "" else s" || $isZero"
+      val failOnErrorBranch = if (failOnError && !divisorIsNonZero) {
         s"if ($isZero) throw ${divideByZeroErrorCode(ctx)};"
       } else {
         ""
@@ -1079,6 +1099,14 @@ case class Pmod(
     case _ => x => x == 0
   }
 
+  // Whether the divisor is a non-null and non-zero literal (foldable divisors are already folded
+  // to a literal before codegen). When true, the remainder-by-zero check is dead code, so codegen
+  // skips emitting it.
+  private lazy val divisorIsNonZero: Boolean = right match {
+    case Literal(divisor, _) => divisor != null && !isZero(divisor)
+    case _ => false
+  }
+
   private lazy val pmodFunc: (Any, Any) => Any = dataType match {
     case _: IntegerType => (l, r) => MathUtils.pmod(l.asInstanceOf[Int], r.asInstanceOf[Int])
     case _: LongType => (l, r) => MathUtils.pmod(l.asInstanceOf[Long], r.asInstanceOf[Long])
@@ -1119,7 +1147,9 @@ case class Pmod(
     }
     val remainder = ctx.freshName("remainder")
     val javaType = CodeGenerator.javaType(dataType)
-    val errorContext = getContextOrNullCode(ctx)
+    // Lazy so the error-context reference is only registered when actually emitted; a statically
+    // non-zero divisor (see divisorIsNonZero) skips the remainder-by-zero check that would use it.
+    lazy val errorContext = getContextOrNullCode(ctx)
     val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
     val result = dataType match {
       case DecimalType.Fixed(precision, scale) =>
@@ -1146,24 +1176,33 @@ case class Pmod(
 
     // evaluate right first as we have a chance to skip left if right is 0
     if (!left.nullable && !right.nullable) {
-      val divByZero = if (failOnError) {
-        s"throw QueryExecutionErrors.remainderByZeroError($errorContext);"
+      val remainderBody =
+        s"""
+           |${eval1.code}
+           |$result""".stripMargin
+      // A statically non-zero divisor makes the zero check dead code, so emit only the remainder.
+      val guardedBody = if (divisorIsNonZero) {
+        remainderBody
       } else {
-        s"${ev.isNull} = true;"
+        val divByZero = if (failOnError) {
+          s"throw QueryExecutionErrors.remainderByZeroError($errorContext);"
+        } else {
+          s"${ev.isNull} = true;"
+        }
+        s"""
+           |if ($isZero) {
+           |  $divByZero
+           |} else {$remainderBody
+           |}""".stripMargin
       }
       ev.copy(code = code"""
         ${eval2.code}
         boolean ${ev.isNull} = false;
         $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-        if ($isZero) {
-          $divByZero
-        } else {
-          ${eval1.code}
-          $result
-        }""")
+        $guardedBody""")
     } else {
-      val nullOnErrorCondition = if (failOnError) "" else s" || $isZero"
-      val failOnErrorBranch = if (failOnError) {
+      val nullOnErrorCondition = if (failOnError || divisorIsNonZero) "" else s" || $isZero"
+      val failOnErrorBranch = if (failOnError && !divisorIsNonZero) {
         s"if ($isZero) throw QueryExecutionErrors.remainderByZeroError($errorContext);"
       } else {
         ""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 7d659fca5df2..f8cdf825bc4c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -912,6 +912,47 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
+  test("SPARK-57198: skip the divide-by-zero check when the divisor is a non-zero literal") {
+    def codeOf(e: Expression): String = e.genCode(new CodegenContext).code.toString
+
+    Seq(true, false).foreach { ansi =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi.toString) {
+        // A non-zero literal divisor makes the zero check dead code; codegen must not emit it.
+        // The by-zero error is only generated in ANSI mode, so also assert on the zero
+        // comparison itself; otherwise the non-ANSI iteration would pass vacuously.
+        Seq(
+          Divide(Literal(1.0), Literal(2.0)),
+          Remainder(Literal(7), Literal(3)),
+          IntegralDivide(Literal(7L), Literal(3L)),
+          Pmod(Literal(7), Literal(3))).foreach { e =>
+          val code = codeOf(e)
+          assert(!code.contains("ByZeroError"),
+            s"expected no by-zero check for a non-zero literal divisor in $e:\n$code")
+          assert(!code.contains("== 0"),
+            s"expected no zero comparison for a non-zero literal divisor in $e:\n$code")
+        }
+
+        // A variable (non-foldable) divisor keeps the check in both modes: the zero comparison
+        // is always emitted, and in ANSI mode the by-zero error must still be generated.
+        Seq(
+          Divide(
+            BoundReference(0, DoubleType, nullable = false),
+            BoundReference(1, DoubleType, nullable = false)),
+          Pmod(
+            BoundReference(0, IntegerType, nullable = false),
+            BoundReference(1, IntegerType, nullable = false))).foreach { e =>
+          val code = codeOf(e)
+          assert(code.contains("== 0"),
+            s"expected a zero comparison for a variable divisor in $e:\n$code")
+          if (ansi) {
+            assert(code.contains("ByZeroError"),
+              s"expected a by-zero check for a variable divisor in $e:\n$code")
+          }
+        }
+      }
+    }
+  }
+
   test("SPARK-34677: exact add and subtract of day-time and year-month intervals") {
     Seq(EvalMode.ANSI, EvalMode.LEGACY).foreach { evalMode =>
       checkExceptionInExpression[ArithmeticException](