Commit b973a823 authored by Yahor Yuzefovich's avatar Yahor Yuzefovich

sem: unify division by zero check and fix it in a few places

Release note (bug fix): Previously, in some cases, CockroachDB didn't
check whether the right argument of `Div` (`/`), `FloorDiv` (`//`),
or `Mod` (`%`) operations was zero, so instead of correctly returning
a "division by zero" error, we were returning `NaN`, and this is now
fixed. Additionally, the error message of "modulus by zero" has been
changed to "division by zero" to be inline with Postgres.
parent c9b9c01c
......@@ -241,7 +241,29 @@ func (decimalCustomizer) getBinOpAssignFunc() assignFunc {
func (c floatCustomizer) getBinOpAssignFunc() assignFunc {
return func(op *lastArgWidthOverload, targetElem, leftElem, rightElem, targetCol, leftCol, rightCol string) string {
return fmt.Sprintf("%s = float64(%s) %s float64(%s)", targetElem, leftElem, op.overloadBase.OpStr, rightElem)
binOp := op.overloadBase.BinOp
computeBinOp := fmt.Sprintf("float64(%s) %s float64(%s)", leftElem, binOp, rightElem)
args := map[string]interface{}{
"CheckRightIsZero": binOp == tree.Div,
"Target": targetElem,
"Right": rightElem,
"ComputeBinOp": computeBinOp,
}
buf := strings.Builder{}
t := template.Must(template.New("").Parse(`
{
{{if .CheckRightIsZero}}
if {{.Right}} == 0.0 {
colexecerror.ExpectedError(tree.ErrDivByZero)
}
{{end}}
{{.Target}} = {{.ComputeBinOp}}
}
`))
if err := t.Execute(&buf, args); err != nil {
colexecerror.InternalError(err)
}
return buf.String()
}
}
......
......@@ -877,10 +877,10 @@ SELECT mod(5.0::float, 2.0), mod(1.0::float, 0.0), mod(5, 2), mod(19.3::decimal,
# mod returns the same results as PostgreSQL 9.4.4
# in tests below (except for the error message).
query error mod\(\): zero modulus
query error mod\(\): division by zero
SELECT mod(5, 0)
query error mod\(\): zero modulus
query error mod\(\): division by zero
SELECT mod(5::decimal, 0::decimal)
query II
......
......@@ -334,7 +334,7 @@ var mathBuiltins = map[string]builtinDefinition{
}, "Calculates `x`%`y`.", tree.VolatilityImmutable),
decimalOverload2("x", "y", func(x, y *apd.Decimal) (tree.Datum, error) {
if y.Sign() == 0 {
return nil, tree.ErrZeroModulus
return nil, tree.ErrDivByZero
}
dd := &tree.DDecimal{}
_, err := tree.HighPrecisionCtx.Rem(&dd.Decimal, x, y)
......@@ -346,7 +346,7 @@ var mathBuiltins = map[string]builtinDefinition{
Fn: func(_ *tree.EvalContext, args tree.Datums) (tree.Datum, error) {
y := tree.MustBeDInt(args[1])
if y == 0 {
return nil, tree.ErrZeroModulus
return nil, tree.ErrDivByZero
}
x := tree.MustBeDInt(args[0])
return tree.NewDInt(x % y), nil
......
......@@ -55,8 +55,6 @@ var (
// ErrDivByZero is reported on a division by zero.
ErrDivByZero = pgerror.New(pgcode.DivisionByZero, "division by zero")
errSqrtOfNegNumber = pgerror.New(pgcode.InvalidArgumentForPowerFunction, "cannot take square root of a negative number")
// ErrZeroModulus is reported when computing the rest of a division by zero.
ErrZeroModulus = pgerror.New(pgcode.DivisionByZero, "zero modulus")
big10E6 = big.NewInt(1e6)
big10E10 = big.NewInt(1e10)
......@@ -1310,13 +1308,13 @@ var BinOps = map[BinaryOperator]binOpOverload{
ReturnType: types.Decimal,
Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) {
rInt := MustBeDInt(right)
if rInt == 0 {
return nil, ErrDivByZero
}
div := ctx.getTmpDec().SetFinite(int64(rInt), 0)
dd := &DDecimal{}
dd.SetFinite(int64(MustBeDInt(left)), 0)
cond, err := DecimalCtx.Quo(&dd.Decimal, &dd.Decimal, div)
if cond.DivisionByZero() {
return dd, ErrDivByZero
}
_, err := DecimalCtx.Quo(&dd.Decimal, &dd.Decimal, div)
return dd, err
},
Volatility: VolatilityImmutable,
......@@ -1341,11 +1339,11 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := &left.(*DDecimal).Decimal
r := &right.(*DDecimal).Decimal
dd := &DDecimal{}
cond, err := DecimalCtx.Quo(&dd.Decimal, l, r)
if cond.DivisionByZero() {
return dd, ErrDivByZero
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
_, err := DecimalCtx.Quo(&dd.Decimal, l, r)
return dd, err
},
Volatility: VolatilityImmutable,
......@@ -1357,12 +1355,12 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := &left.(*DDecimal).Decimal
r := MustBeDInt(right)
if r == 0 {
return nil, ErrDivByZero
}
dd := &DDecimal{}
dd.SetFinite(int64(r), 0)
cond, err := DecimalCtx.Quo(&dd.Decimal, l, &dd.Decimal)
if cond.DivisionByZero() {
return dd, ErrDivByZero
}
_, err := DecimalCtx.Quo(&dd.Decimal, l, &dd.Decimal)
return dd, err
},
Volatility: VolatilityImmutable,
......@@ -1374,12 +1372,12 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := MustBeDInt(left)
r := &right.(*DDecimal).Decimal
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
dd.SetFinite(int64(l), 0)
cond, err := DecimalCtx.Quo(&dd.Decimal, &dd.Decimal, r)
if cond.DivisionByZero() {
return dd, ErrDivByZero
}
_, err := DecimalCtx.Quo(&dd.Decimal, &dd.Decimal, r)
return dd, err
},
Volatility: VolatilityImmutable,
......@@ -1433,6 +1431,9 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := float64(*left.(*DFloat))
r := float64(*right.(*DFloat))
if r == 0.0 {
return nil, ErrDivByZero
}
return NewDFloat(DFloat(math.Trunc(l / r))), nil
},
Volatility: VolatilityImmutable,
......@@ -1444,6 +1445,9 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := &left.(*DDecimal).Decimal
r := &right.(*DDecimal).Decimal
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
_, err := HighPrecisionCtx.QuoInteger(&dd.Decimal, l, r)
return dd, err
......@@ -1474,7 +1478,7 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := MustBeDInt(left)
r := &right.(*DDecimal).Decimal
if r.Sign() == 0 {
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
......@@ -1494,7 +1498,7 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
r := MustBeDInt(right)
if r == 0 {
return nil, ErrZeroModulus
return nil, ErrDivByZero
}
return NewDInt(MustBeDInt(left) % r), nil
},
......@@ -1505,7 +1509,12 @@ var BinOps = map[BinaryOperator]binOpOverload{
RightType: types.Float,
ReturnType: types.Float,
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
return NewDFloat(DFloat(math.Mod(float64(*left.(*DFloat)), float64(*right.(*DFloat))))), nil
l := float64(*left.(*DFloat))
r := float64(*right.(*DFloat))
if r == 0.0 {
return nil, ErrDivByZero
}
return NewDFloat(DFloat(math.Mod(l, r))), nil
},
Volatility: VolatilityImmutable,
},
......@@ -1516,6 +1525,9 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := &left.(*DDecimal).Decimal
r := &right.(*DDecimal).Decimal
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
_, err := HighPrecisionCtx.Rem(&dd.Decimal, l, r)
return dd, err
......@@ -1529,6 +1541,9 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := &left.(*DDecimal).Decimal
r := MustBeDInt(right)
if r == 0 {
return nil, ErrDivByZero
}
dd := &DDecimal{}
dd.SetFinite(int64(r), 0)
_, err := HighPrecisionCtx.Rem(&dd.Decimal, l, &dd.Decimal)
......@@ -1543,6 +1558,9 @@ var BinOps = map[BinaryOperator]binOpOverload{
Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) {
l := MustBeDInt(left)
r := &right.(*DDecimal).Decimal
if r.IsZero() {
return nil, ErrDivByZero
}
dd := &DDecimal{}
dd.SetFinite(int64(l), 0)
_, err := HighPrecisionCtx.Rem(&dd.Decimal, &dd.Decimal, r)
......
......@@ -258,7 +258,7 @@ func TestEvalError(t *testing.T) {
expr string
expected string
}{
{`1 % 0`, `zero modulus`},
{`1 % 0`, `division by zero`},
{`1 / 0`, `division by zero`},
{`1::float / 0::float`, `division by zero`},
{`1 // 0`, `division by zero`},
......
......@@ -208,7 +208,7 @@ func TestEval(t *testing.T) {
result.Op,
typs,
nil, /* output */
nil, /* metadataSourcesQueue */
result.MetadataSources,
nil, /* toClose */
nil, /* outputStatsToTrace */
nil, /* cancelFlow */
......@@ -228,10 +228,6 @@ func TestEval(t *testing.T) {
t.Fatalf("unexpected metadata: %+v", meta)
}
if row == nil {
// Might be some metadata.
if meta := mat.DrainHelper(); meta.Err != nil {
t.Fatalf("unexpected error: %s", meta.Err)
}
t.Fatal("unexpected end of input")
}
return row[0].Datum.String()
......
......@@ -88,11 +88,31 @@ eval
----
2
eval
1 // 0
----
division by zero
eval
-4.5 // 1.2
----
-3
eval
1.0 // 0.0
----
division by zero
eval
1.0 // 0
----
division by zero
eval
1 // 0.0
----
division by zero
eval
3.1 % 2.0
----
......@@ -118,6 +138,11 @@ eval
----
2
eval
1 % 0
----
division by zero
eval
1 + NULL
----
......@@ -148,11 +173,36 @@ eval
----
1
eval
1.0 % 0.0
----
division by zero
eval
1.0 % 0
----
division by zero
eval
1 % 0.0
----
division by zero
eval
-4.5:::float // 1.2:::float
----
-3.0
eval
1:::float // 0:::float
----
division by zero
eval
1:::float % 0:::float
----
division by zero
eval
2 ^ 3
----
......
......@@ -419,7 +419,6 @@ var ignoredErrorPatterns = []string{
"overflow",
"requested length too large",
"division by zero",
"zero modulus",
"is out of range",
// Type checking
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment