diff --git a/src/array/ops.rs b/src/array/ops.rs index 17966419..55eb23c2 100644 --- a/src/array/ops.rs +++ b/src/array/ops.rs @@ -13,8 +13,8 @@ use super::*; use crate::for_all_variants; use crate::parser::{BinaryOperator, UnaryOperator}; use crate::types::{ - Blob, ConvertError, DataTypeKind, DataValue, Date, DateTimeField, Interval, Timestamp, - TimestampTz, F64, + Blob, ConvertError, DataTypeKind, DataValue, Date, DateTimeField, Interval, NativeType, + Timestamp, TimestampTz, F64, }; type A = ArrayImpl; @@ -147,7 +147,7 @@ impl ArrayImpl { arith!(add, +); arith!(sub, -); arith!(mul, *); - arith!(div, /); + arith!(unchecked_div, /); arith!(rem, %); cmp!(eq, ==); cmp!(ne, !=); @@ -156,6 +156,17 @@ impl ArrayImpl { cmp!(ge, >=); cmp!(le, <=); + pub fn div(&self, other: &Self) -> Result { + let valid_rhs = other.get_valid_bitmap(); + let other = safen_dividend(other, valid_rhs).ok_or(ConvertError::NoBinaryOp( + "div".into(), + self.type_string(), + other.type_string(), + ))?; + + self.unchecked_div(&other) + } + pub fn and(&self, other: &Self) -> Result { let (A::Bool(a), A::Bool(b)) = (self, other) else { return Err(ConvertError::NoBinaryOp("and".into(), self.type_string(), other.type_string())); @@ -625,6 +636,55 @@ macro_rules! impl_agg { for_all_variants! { impl_agg } +fn safen_dividend(array: &ArrayImpl, valid: &BitVec) -> Option { + fn f(array: &PrimitiveArray, valid: &BitVec, value: N) -> T + where + T: ArrayFromDataExt, + N: NativeType + num_traits::Zero + Borrow<::Item>, + { + let mut valid = valid.to_owned(); + + // 1. set valid as false if item is zero + for (idx, item) in array.raw_iter().enumerate() { + if item.is_zero() { + valid.set(idx, false); + } + } + + // 2. replace item with safe dividend if valid is false + let data = array + .raw_iter() + .map(|item| if item.is_zero() { value } else { *item }); + + T::from_data(data, valid) + } + + // all valid dividend case + Some(match array { + ArrayImpl::Int16(array) => { + let array = f(array, valid, 1); + ArrayImpl::Int16(Arc::new(array)) + } + ArrayImpl::Int32(array) => { + let array = f(array, valid, 1); + ArrayImpl::Int32(Arc::new(array)) + } + ArrayImpl::Int64(array) => { + let array = f(array, valid, 1); + ArrayImpl::Int64(Arc::new(array)) + } + ArrayImpl::Float64(array) => { + let array = f(array, valid, 1.0.into()); + ArrayImpl::Float64(Arc::new(array)) + } + ArrayImpl::Decimal(array) => { + let array = f(array, valid, Decimal::new(1, 0)); + ArrayImpl::Decimal(Arc::new(array)) + } + _ => return None, + }) +} + fn binary_op(a: &A, b: &B, f: F) -> O where A: ArrayValidExt, diff --git a/tests/sql/nullable_operator.slt b/tests/sql/nullable_operator.slt new file mode 100644 index 00000000..8aab6df7 --- /dev/null +++ b/tests/sql/nullable_operator.slt @@ -0,0 +1,20 @@ +statement ok +create table t(x int, y int) + +statement ok +insert into t values (1, 2), (2, NULL) + +query I +select x / y from t +---- +0 +NULL + +query I +select x / 0 from t +---- +NULL +NULL + +statement ok +drop table t