Skip to content

Commit

Permalink
coprocessor: add mod(Int/Real/Decimal) RPN functions (tikv#4727)
Browse files Browse the repository at this point in the history
* coprocessor: add mod(Int/Real/Decimal) RPN functions

Signed-off-by: Lonng <heng@lonng.org>

* remove empty line

Signed-off-by: Lonng <heng@lonng.org>

* address comment

Signed-off-by: Lonng <heng@lonng.org>
  • Loading branch information
lonng authored and jswh committed May 27, 2019
1 parent c085a34 commit 5434936
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 7 deletions.
261 changes: 254 additions & 7 deletions src/coprocessor/dag/rpn_expr/impl_arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use cop_codegen::RpnFunction;

use super::types::RpnFnCallPayload;
use crate::coprocessor::codec::data_type::*;
use crate::coprocessor::codec::mysql::{Decimal, Res};
use crate::coprocessor::codec::{self, Error};
use crate::coprocessor::dag::expr::EvalContext;
use crate::coprocessor::Result;
Expand Down Expand Up @@ -138,25 +139,122 @@ impl ArithmeticOp for DecimalPlus {
}
}

#[derive(Debug)]
pub struct IntIntMod;

impl ArithmeticOp for IntIntMod {
type T = Int;

fn calc(lhs: &Int, rhs: &Int) -> Result<Option<Int>> {
if *rhs == 0i64 {
return Ok(None);
}
Ok(Some(lhs % rhs))
}
}

#[derive(Debug)]
pub struct IntUintMod;

impl ArithmeticOp for IntUintMod {
type T = Int;

fn calc(lhs: &Int, rhs: &Int) -> Result<Option<Int>> {
if *rhs == 0i64 {
return Ok(None);
}
Ok(Some(
((lhs.overflowing_abs().0 as u64) % (*rhs as u64)) as i64,
))
}
}

#[derive(Debug)]
pub struct UintIntMod;

impl ArithmeticOp for UintIntMod {
type T = Int;

fn calc(lhs: &Int, rhs: &Int) -> Result<Option<Int>> {
if *rhs == 0i64 {
return Ok(None);
}
Ok(Some(
((*lhs as u64) % (rhs.overflowing_abs().0 as u64)) as i64,
))
}
}

#[derive(Debug)]
pub struct UintUintMod;
impl ArithmeticOp for UintUintMod {
type T = Int;

fn calc(lhs: &Int, rhs: &Int) -> Result<Option<Int>> {
if *rhs == 0i64 {
return Ok(None);
}
Ok(Some(((*lhs as u64) % (*rhs as u64)) as i64))
}
}

#[derive(Debug)]
pub struct RealMod;

impl ArithmeticOp for RealMod {
type T = Real;

fn calc(lhs: &Real, rhs: &Real) -> Result<Option<Real>> {
if (*rhs).into_inner() == 0f64 {
return Ok(None);
}
Ok(Some(*lhs % *rhs))
}
}

#[derive(Debug)]
pub struct DecimalMod;

impl ArithmeticOp for DecimalMod {
type T = Decimal;

fn calc(lhs: &Decimal, rhs: &Decimal) -> Result<Option<Decimal>> {
if rhs.is_zero() {
return Ok(None);
}
match lhs % rhs {
Some(v) => match v {
Res::Ok(v) => Ok(Some(v)),
Res::Truncated(_) => Err(Error::truncated())?,
Res::Overflow(_) => {
Err(Error::overflow("DECIMAL", &format!("({} % {})", lhs, rhs)))?
}
},
None => Ok(None),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::coprocessor::dag::rpn_expr::types::test_util::RpnFnScalarEvaluator;

use cop_datatype::builder::FieldTypeBuilder;
use cop_datatype::{FieldTypeFlag, FieldTypeTp};
use std::i64;
use tipb::expression::ScalarFuncSig;

use crate::coprocessor::dag::rpn_expr::types::test_util::RpnFnScalarEvaluator;

#[test]
fn test_arithmetic_int() {
let test_cases = vec![
(None, false, Some(1), false, None),
(Some(1), false, None, false, None),
(Some(17), false, Some(25), false, Some(42)),
(
Some(i64::MIN),
Some(std::i64::MIN),
false,
Some((i64::MAX as u64 + 1) as i64),
Some((std::i64::MAX as u64 + 1) as i64),
true,
Some(0),
),
Expand All @@ -183,7 +281,7 @@ mod tests {
.push_param_with_field_type(rhs, rhs_field_type)
.evaluate(ScalarFuncSig::PlusInt)
.unwrap();
assert_eq!(output, expected, "{:?}, {:?}", output, expected);
assert_eq!(output, expected, "lhs={:?}, rhs={:?}", lhs, rhs);
}
}

Expand All @@ -207,7 +305,7 @@ mod tests {
assert!(output.is_err())
} else {
let output = output.unwrap();
assert_eq!(output, expected, "{:?}, {:?}", output, expected);
assert_eq!(output, expected, "lhs={:?}, rhs={:?}", lhs, rhs);
}
}
}
Expand All @@ -222,7 +320,156 @@ mod tests {
.push_param(rhs.parse::<Decimal>().ok())
.evaluate(ScalarFuncSig::PlusDecimal)
.unwrap();
assert_eq!(output, expected, "{:?}, {:?}", output, expected);
assert_eq!(output, expected, "lhs={:?}, rhs={:?}", lhs, rhs);
}
}

#[test]
fn test_mod_int() {
let tests = vec![
(Some(13), Some(11), Some(2)),
(Some(-13), Some(11), Some(-2)),
(Some(13), Some(-11), Some(2)),
(Some(-13), Some(-11), Some(-2)),
(Some(33), Some(11), Some(0)),
(Some(33), Some(-11), Some(0)),
(Some(-33), Some(-11), Some(0)),
(Some(-11), None, None),
(None, Some(-11), None),
(Some(11), Some(0), None),
(Some(-11), Some(0), None),
(
Some(std::i64::MAX),
Some(std::i64::MIN),
Some(std::i64::MAX),
),
(Some(std::i64::MIN), Some(std::i64::MAX), Some(-1)),
];

for (lhs, rhs, expected) in tests {
let output = RpnFnScalarEvaluator::new()
.push_param(lhs)
.push_param(rhs)
.evaluate(ScalarFuncSig::ModInt)
.unwrap();
assert_eq!(output, expected, "lhs={:?}, rhs={:?}", lhs, rhs);
}
}
#[test]
fn test_mod_int_unsigned() {
let tests = vec![
(
Some(std::u64::MAX as i64),
true,
Some(std::i64::MIN),
false,
Some(std::i64::MAX),
),
(
Some(std::i64::MIN),
false,
Some(std::u64::MAX as i64),
true,
Some(std::i64::MIN),
),
];

for (lhs, lhs_is_unsigned, rhs, rhs_is_unsigned, expected) in tests {
let lhs_field_type = FieldTypeBuilder::new()
.tp(FieldTypeTp::LongLong)
.flag(if lhs_is_unsigned {
FieldTypeFlag::UNSIGNED
} else {
FieldTypeFlag::empty()
})
.build();
let rhs_field_type = FieldTypeBuilder::new()
.tp(FieldTypeTp::LongLong)
.flag(if rhs_is_unsigned {
FieldTypeFlag::UNSIGNED
} else {
FieldTypeFlag::empty()
})
.build();
let output = RpnFnScalarEvaluator::new()
.push_param_with_field_type(lhs, lhs_field_type)
.push_param_with_field_type(rhs, rhs_field_type)
.evaluate(ScalarFuncSig::ModInt)
.unwrap();
assert_eq!(output, expected, "lhs={:?}, rhs={:?}", lhs, rhs);
}
}

#[test]
fn test_mod_real() {
let tests = vec![
(Real::new(1.0).ok(), None, None),
(None, Real::new(1.0).ok(), None),
(
Real::new(1.0).ok(),
Real::new(1.1).ok(),
Real::new(1.0).ok(),
),
(
Real::new(-1.0).ok(),
Real::new(1.1).ok(),
Real::new(-1.0).ok(),
),
(
Real::new(1.0).ok(),
Real::new(-1.1).ok(),
Real::new(1.0).ok(),
),
(
Real::new(-1.0).ok(),
Real::new(-1.1).ok(),
Real::new(-1.0).ok(),
),
(Real::new(1.0).ok(), Real::new(0.0).ok(), None),
];

for (lhs, rhs, expected) in tests {
let output = RpnFnScalarEvaluator::new()
.push_param(lhs)
.push_param(rhs)
.evaluate(ScalarFuncSig::ModReal)
.unwrap();
assert_eq!(output, expected, "lhs={:?}, rhs={:?}", lhs, rhs);
}
}

#[test]
fn test_mod_decimal() {
let tests = vec![
("13", "11", "2"),
("-13", "11", "-2"),
("13", "-11", "2"),
("-13", "-11", "-2"),
("33", "11", "0"),
("-33", "11", "0"),
("33", "-11", "0"),
("-33", "-11", "0"),
("0.0000000001", "1.0", "0.0000000001"),
("1", "1.1", "1"),
("-1", "1.1", "-1"),
("1", "-1.1", "1"),
("-1", "-1.1", "-1"),
("3", "0", ""),
("-3", "0", ""),
("0", "0", ""),
("-3", "", ""),
("", ("-3"), ""),
("", "", ""),
];

for (lhs, rhs, expected) in tests {
let expected = expected.parse::<Decimal>().ok();
let output = RpnFnScalarEvaluator::new()
.push_param(lhs.parse::<Decimal>().ok())
.push_param(rhs.parse::<Decimal>().ok())
.evaluate(ScalarFuncSig::ModDecimal)
.unwrap();
assert_eq!(output, expected, "lhs={:?}, rhs={:?}", lhs, rhs);
}
}
}
12 changes: 12 additions & 0 deletions src/coprocessor/dag/rpn_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ fn plus_mapper(lhs_is_unsigned: bool, rhs_is_unsigned: bool) -> Box<dyn RpnFunct
}
}

fn mod_mapper(lhs_is_unsigned: bool, rhs_is_unsigned: bool) -> Box<dyn RpnFunction> {
match (lhs_is_unsigned, rhs_is_unsigned) {
(false, false) => Box::new(RpnFnArithmetic::<IntIntMod>::new()),
(false, true) => Box::new(RpnFnArithmetic::<IntUintMod>::new()),
(true, false) => Box::new(RpnFnArithmetic::<UintIntMod>::new()),
(true, true) => Box::new(RpnFnArithmetic::<UintUintMod>::new()),
}
}

#[rustfmt::skip]
fn map_pb_sig_to_rpn_func(value: ScalarFuncSig, children: &[Expr]) -> Result<Box<dyn RpnFunction>> {
Ok(match value {
Expand Down Expand Up @@ -135,6 +144,9 @@ fn map_pb_sig_to_rpn_func(value: ScalarFuncSig, children: &[Expr]) -> Result<Box
ScalarFuncSig::PlusInt => map_int_sig(value, children, plus_mapper)?,
ScalarFuncSig::PlusReal => Box::new(RpnFnArithmetic::<RealPlus>::new()),
ScalarFuncSig::PlusDecimal => Box::new(RpnFnArithmetic::<DecimalPlus>::new()),
ScalarFuncSig::ModReal => Box::new(RpnFnArithmetic::<RealMod>::new()),
ScalarFuncSig::ModDecimal => Box::new(RpnFnArithmetic::<DecimalMod>::new()),
ScalarFuncSig::ModInt => map_int_sig(value, children, mod_mapper)?,
_ => return Err(box_err!(
"ScalarFunction {:?} is not supported in batch mode",
value
Expand Down

0 comments on commit 5434936

Please sign in to comment.