Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,15 @@ trait DivModLike extends BinaryArithmetic {
case _ => x => x == 0
}

// Whether the divisor is a non-null and non-zero literal (foldable divisors are already folded
// to a literal before codegen). When true, the divide-by-zero check is dead code (it can never
// trigger), so codegen skips emitting both the check and its otherwise-unreachable error-context
// reference.
private lazy val divisorIsNonZero: Boolean = right match {
case Literal(divisor, _) => divisor != null && !isZero(divisor)
case _ => false
}

final override def eval(input: InternalRow): Any = {
// evaluate right first as we have a chance to skip left if right is 0
val input2 = right.eval(input)
Expand Down Expand Up @@ -705,7 +714,9 @@ trait DivModLike extends BinaryArithmetic {
s"${eval2.value} == 0"
}
val javaType = CodeGenerator.javaType(dataType)
val errorContextCode = getContextOrNullCode(ctx, failOnError)
// Lazy so the error-context reference is only registered when actually emitted; a statically
// non-zero divisor (see divisorIsNonZero) skips the divide-by-zero check that would use it.
lazy val errorContextCode = getContextOrNullCode(ctx, failOnError)
val operation = super.dataType match {
case DecimalType.Fixed(precision, scale) =>
val castUtils = classOf[CastUtils].getName
Expand Down Expand Up @@ -741,25 +752,34 @@ trait DivModLike extends BinaryArithmetic {

// evaluate right first as we have a chance to skip left if right is 0
if (!left.nullable && !right.nullable) {
val divByZero = if (failOnError) {
s"throw ${divideByZeroErrorCode(ctx)};"
val divisionBody =
s"""
|${eval1.code}
|$checkIntegralDivideOverflow
|$operation""".stripMargin
// A statically non-zero divisor makes the zero check dead code, so emit only the division.
val guardedBody = if (divisorIsNonZero) {
divisionBody
} else {
s"${ev.isNull} = true;"
val divByZero = if (failOnError) {
s"throw ${divideByZeroErrorCode(ctx)};"
} else {
s"${ev.isNull} = true;"
}
s"""
|if ($isZero) {
| $divByZero
|} else {$divisionBody
|}""".stripMargin
}
ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if ($isZero) {
$divByZero
} else {
${eval1.code}
$checkIntegralDivideOverflow
$operation
}""")
$guardedBody""")
} else {
val nullOnErrorCondition = if (failOnError) "" else s" || $isZero"
val failOnErrorBranch = if (failOnError) {
val nullOnErrorCondition = if (failOnError || divisorIsNonZero) "" else s" || $isZero"
val failOnErrorBranch = if (failOnError && !divisorIsNonZero) {
s"if ($isZero) throw ${divideByZeroErrorCode(ctx)};"
} else {
""
Expand Down Expand Up @@ -1079,6 +1099,14 @@ case class Pmod(
case _ => x => x == 0
}

// Whether the divisor is a non-null and non-zero literal (foldable divisors are already folded
// to a literal before codegen). When true, the remainder-by-zero check is dead code, so codegen
// skips emitting it.
private lazy val divisorIsNonZero: Boolean = right match {
case Literal(divisor, _) => divisor != null && !isZero(divisor)
case _ => false
}

private lazy val pmodFunc: (Any, Any) => Any = dataType match {
case _: IntegerType => (l, r) => MathUtils.pmod(l.asInstanceOf[Int], r.asInstanceOf[Int])
case _: LongType => (l, r) => MathUtils.pmod(l.asInstanceOf[Long], r.asInstanceOf[Long])
Expand Down Expand Up @@ -1119,7 +1147,9 @@ case class Pmod(
}
val remainder = ctx.freshName("remainder")
val javaType = CodeGenerator.javaType(dataType)
val errorContext = getContextOrNullCode(ctx)
// Lazy so the error-context reference is only registered when actually emitted; a statically
// non-zero divisor (see divisorIsNonZero) skips the remainder-by-zero check that would use it.
lazy val errorContext = getContextOrNullCode(ctx)
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
val result = dataType match {
case DecimalType.Fixed(precision, scale) =>
Expand All @@ -1146,24 +1176,33 @@ case class Pmod(

// evaluate right first as we have a chance to skip left if right is 0
if (!left.nullable && !right.nullable) {
val divByZero = if (failOnError) {
s"throw QueryExecutionErrors.remainderByZeroError($errorContext);"
val remainderBody =
s"""
|${eval1.code}
|$result""".stripMargin
// A statically non-zero divisor makes the zero check dead code, so emit only the remainder.
val guardedBody = if (divisorIsNonZero) {
remainderBody
} else {
s"${ev.isNull} = true;"
val divByZero = if (failOnError) {
s"throw QueryExecutionErrors.remainderByZeroError($errorContext);"
} else {
s"${ev.isNull} = true;"
}
s"""
|if ($isZero) {
| $divByZero
|} else {$remainderBody
|}""".stripMargin
}
ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if ($isZero) {
$divByZero
} else {
${eval1.code}
$result
}""")
$guardedBody""")
} else {
val nullOnErrorCondition = if (failOnError) "" else s" || $isZero"
val failOnErrorBranch = if (failOnError) {
val nullOnErrorCondition = if (failOnError || divisorIsNonZero) "" else s" || $isZero"
val failOnErrorBranch = if (failOnError && !divisorIsNonZero) {
s"if ($isZero) throw QueryExecutionErrors.remainderByZeroError($errorContext);"
} else {
""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,41 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
}
}

test("SPARK-57198: skip the divide-by-zero check when the divisor is a non-zero literal") {
def codeOf(e: Expression): String = e.genCode(new CodegenContext).code.toString

Seq(true, false).foreach { ansi =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi.toString) {
// A non-zero literal divisor makes the zero check dead code; codegen must not emit it.
Seq(
Divide(Literal(1.0), Literal(2.0)),
Remainder(Literal(7), Literal(3)),
IntegralDivide(Literal(7L), Literal(3L)),
Pmod(Literal(7), Literal(3))).foreach { e =>
val code = codeOf(e)
assert(!code.contains("ByZeroError"),
s"expected no by-zero check for a non-zero literal divisor in $e:\n$code")
}

// A variable (non-foldable) divisor keeps the check: in ANSI mode the by-zero error must
// still be generated.
Seq(
Divide(
BoundReference(0, DoubleType, nullable = false),
BoundReference(1, DoubleType, nullable = false)),
Pmod(
BoundReference(0, IntegerType, nullable = false),
BoundReference(1, IntegerType, nullable = false))).foreach { e =>
val code = codeOf(e)
if (ansi) {
assert(code.contains("ByZeroError"),
s"expected a by-zero check for a variable divisor in $e:\n$code")
}
}
}
}
}

test("SPARK-34677: exact add and subtract of day-time and year-month intervals") {
Seq(EvalMode.ANSI, EvalMode.LEGACY).foreach { evalMode =>
checkExceptionInExpression[ArithmeticException](
Expand Down