Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add dynamic literals to ensure schema correctness #15832

Merged
merged 11 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/polars-core/src/datatypes/_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ enum SerializableDataType {
#[cfg(feature = "dtype-struct")]
Struct(Vec<Field>),
// some logical types we cannot know statically, e.g. Datetime
Unknown,
Unknown(UnknownKind),
#[cfg(feature = "dtype-categorical")]
Categorical(Option<Wrap<Utf8ViewArray>>, CategoricalOrdering),
#[cfg(feature = "dtype-decimal")]
Expand Down Expand Up @@ -141,7 +141,7 @@ impl From<&DataType> for SerializableDataType {
#[cfg(feature = "dtype-array")]
Array(dt, width) => Self::Array(Box::new(dt.as_ref().into()), *width),
Null => Self::Null,
Unknown => Self::Unknown,
Unknown(kind) => Self::Unknown(*kind),
#[cfg(feature = "dtype-struct")]
Struct(flds) => Self::Struct(flds.clone()),
#[cfg(feature = "dtype-categorical")]
Expand Down Expand Up @@ -185,7 +185,7 @@ impl From<SerializableDataType> for DataType {
#[cfg(feature = "dtype-array")]
Array(dt, width) => Self::Array(Box::new((*dt).into()), width),
Null => Self::Null,
Unknown => Self::Unknown,
Unknown(kind) => Self::Unknown(kind),
#[cfg(feature = "dtype-struct")]
Struct(flds) => Self::Struct(flds),
#[cfg(feature = "dtype-categorical")]
Expand Down
72 changes: 65 additions & 7 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,22 @@ pub type TimeZone = String;
pub static DTYPE_ENUM_KEY: &str = "POLARS.CATEGORICAL_TYPE";
pub static DTYPE_ENUM_VALUE: &str = "ENUM";

#[derive(Clone, Debug, Default)]
/// The kind of a dtype that is not (yet) concretely known.
///
/// This is the payload of `DataType::Unknown`. `Any` is the default kind.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
#[cfg_attr(
any(feature = "serde", feature = "serde-lazy"),
derive(Serialize, Deserialize)
)]
pub enum UnknownKind {
/// Dynamic integer. Holds the value so the concrete size can be determined.
Int(i128),
/// Dynamic float.
Float,
/// Dynamic string. Can be Categorical or String.
Str,
/// Nothing is known about the dtype (the default).
#[default]
Any,
}

#[derive(Clone, Debug)]
pub enum DataType {
Boolean,
UInt8,
Expand Down Expand Up @@ -59,8 +74,13 @@ pub enum DataType {
#[cfg(feature = "dtype-struct")]
Struct(Vec<Field>),
// some logical types we cannot know statically, e.g. Datetime
#[default]
Unknown,
Unknown(UnknownKind),
}

/// A [`DataType`] defaults to the fully unknown dtype, `Unknown(UnknownKind::Any)`.
impl Default for DataType {
    fn default() -> Self {
        Self::Unknown(UnknownKind::Any)
    }
}

pub trait AsRefDataType {
Expand Down Expand Up @@ -144,7 +164,7 @@ impl DataType {
DataType::List(inner) => inner.is_known(),
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => fields.iter().all(|fld| fld.dtype.is_known()),
DataType::Unknown => false,
DataType::Unknown(_) => false,
_ => true,
}
}
Expand Down Expand Up @@ -208,7 +228,14 @@ impl DataType {

/// Check if this [`DataType`] is a basic numeric type (excludes Decimal).
pub fn is_numeric(&self) -> bool {
self.is_float() || self.is_integer()
self.is_float() || self.is_integer() || self.is_dynamic()
}

/// Check if this [`DataType`] is a dynamic literal type, i.e. an `Unknown`
/// whose kind is `Int`, `Float` or `Str`. `Unknown(UnknownKind::Any)` is not dynamic.
pub fn is_dynamic(&self) -> bool {
    // Anything that is not `Unknown` is a concrete dtype and therefore not dynamic.
    let DataType::Unknown(kind) = self else {
        return false;
    };
    matches!(
        kind,
        UnknownKind::Int(_) | UnknownKind::Float | UnknownKind::Str
    )
}

/// Check if this [`DataType`] is a boolean
Expand Down Expand Up @@ -382,6 +409,32 @@ impl DataType {
}
}

/// Check if this [`DataType`] is a string: either the concrete `String` dtype
/// or a dynamic string literal (`Unknown(UnknownKind::Str)`).
pub fn is_string(&self) -> bool {
    let is_concrete_string = matches!(self, DataType::String);
    let is_dynamic_string = matches!(self, DataType::Unknown(UnknownKind::Str));
    is_concrete_string || is_dynamic_string
}

/// Check if this [`DataType`] is a `Categorical`.
///
/// Always returns `false` when the `dtype-categorical` feature is disabled.
pub fn is_categorical(&self) -> bool {
// Exactly one of the two cfg-gated blocks is compiled; it is the return value.
#[cfg(feature = "dtype-categorical")]
{
matches!(self, DataType::Categorical(_, _))
}
#[cfg(not(feature = "dtype-categorical"))]
{
false
}
}

/// Check if this [`DataType`] is an `Enum`.
///
/// Always returns `false` when the `dtype-categorical` feature is disabled.
pub fn is_enum(&self) -> bool {
// Exactly one of the two cfg-gated blocks is compiled; it is the return value.
#[cfg(feature = "dtype-categorical")]
{
matches!(self, DataType::Enum(_, _))
}
#[cfg(not(feature = "dtype-categorical"))]
{
false
}
}

/// Convert to an Arrow Field
pub fn to_arrow_field(&self, name: &str, pl_flavor: bool) -> ArrowField {
let metadata = match self {
Expand Down Expand Up @@ -490,7 +543,7 @@ impl DataType {
Ok(ArrowDataType::Struct(fields))
},
BinaryOffset => Ok(ArrowDataType::LargeBinary),
Unknown => Ok(ArrowDataType::Unknown),
Unknown(_) => Ok(ArrowDataType::Unknown),
}
}

Expand Down Expand Up @@ -591,7 +644,12 @@ impl Display for DataType {
DataType::Enum(_, _) => "enum",
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => return write!(f, "struct[{}]", fields.len()),
DataType::Unknown => "unknown",
DataType::Unknown(kind) => match kind {
UnknownKind::Any => "unknown",
UnknownKind::Int(_) => "dyn int",
UnknownKind::Float => "dyn float",
UnknownKind::Str => "dyn str",
},
DataType::BinaryOffset => "binary[offset]",
};
f.write_str(s)
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ impl Field {
}
}

/// Lets a [`Field`] be used wherever a borrowed [`DataType`] is accepted.
impl AsRef<DataType> for Field {
/// Borrow the dtype of this field.
fn as_ref(&self) -> &DataType {
&self.dtype
}
}

/// Identity conversion, so APIs taking `impl AsRef<DataType>` also accept a
/// [`DataType`] directly.
impl AsRef<DataType> for DataType {
/// Return `self` unchanged.
fn as_ref(&self) -> &DataType {
self
}
}

impl DataType {
pub fn boxed(self) -> Box<DataType> {
Box::new(self)
Expand Down
23 changes: 15 additions & 8 deletions crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ macro_rules! impl_polars_num_datatype {
};
}

macro_rules! impl_polars_datatype {
($ca:ident, $variant:ident, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty) => {
macro_rules! impl_polars_datatype2 {
($ca:ident, $dtype:expr, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty) => {
#[derive(Clone, Copy)]
pub struct $ca {}

Expand All @@ -128,12 +128,18 @@ macro_rules! impl_polars_datatype {

#[inline]
fn get_dtype() -> DataType {
DataType::$variant
$dtype
}
}
};
}

// Convenience wrapper around `impl_polars_datatype2!` for dtypes that are
// named by a bare `DataType` variant: `$variant` is expanded to the
// expression `DataType::$variant` and all other arguments are forwarded as-is.
macro_rules! impl_polars_datatype {
($ca:ident, $variant:ident, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty) => {
impl_polars_datatype2!($ca, DataType::$variant, $arr, $lt, $phys, $zerophys);
};
}

impl_polars_num_datatype!(PolarsIntegerType, UInt8Type, UInt8, u8);
impl_polars_num_datatype!(PolarsIntegerType, UInt16Type, UInt16, u16);
impl_polars_num_datatype!(PolarsIntegerType, UInt32Type, UInt32, u32);
Expand All @@ -145,17 +151,18 @@ impl_polars_num_datatype!(PolarsIntegerType, Int64Type, Int64, i64);
impl_polars_num_datatype!(PolarsFloatType, Float32Type, Float32, f32);
impl_polars_num_datatype!(PolarsFloatType, Float64Type, Float64, f64);
impl_polars_datatype!(DateType, Date, PrimitiveArray<i32>, 'a, i32, i32);
#[cfg(feature = "dtype-decimal")]
impl_polars_datatype!(DecimalType, Unknown, PrimitiveArray<i128>, 'a, i128, i128);
impl_polars_datatype!(DatetimeType, Unknown, PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype!(DurationType, Unknown, PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype!(CategoricalType, Unknown, PrimitiveArray<u32>, 'a, u32, u32);
impl_polars_datatype!(TimeType, Time, PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype!(StringType, String, Utf8ViewArray, 'a, &'a str, Option<&'a str>);
impl_polars_datatype!(BinaryType, Binary, BinaryViewArray, 'a, &'a [u8], Option<&'a [u8]>);
impl_polars_datatype!(BinaryOffsetType, BinaryOffset, BinaryArray<i64>, 'a, &'a [u8], Option<&'a [u8]>);
impl_polars_datatype!(BooleanType, Boolean, BooleanArray, 'a, bool, bool);

#[cfg(feature = "dtype-decimal")]
impl_polars_datatype2!(DecimalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i128>, 'a, i128, i128);
impl_polars_datatype2!(DatetimeType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype2!(DurationType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype2!(CategoricalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<u32>, 'a, u32, u32);

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ListType {}
unsafe impl PolarsDataType for ListType {
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ impl Series {
}
},
Null => new_null(name, &chunks),
Unknown => panic!("uh oh, somehow we don't know the dtype?"),
Unknown(_) => {
panic!("dtype is unknown; consider supplying data-types for all operations")
},
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
Expand Down
87 changes: 84 additions & 3 deletions crates/polars-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use num_traits::Signed;

use super::*;

/// Given two data types, determine the data type that both types can safely be cast to.
Expand Down Expand Up @@ -195,9 +197,9 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
(Time, Float64) => Some(Float64),

// every known type can be casted to a string except binary
(dt, String) if dt != &DataType::Unknown && dt != &DataType::Binary => Some(String),
(dt, String) if !matches!(dt, DataType::Unknown(UnknownKind::Any)) && dt != &DataType::Binary => Some(String),

(dt, String) if dt != &DataType::Unknown => Some(String),
(dt, String) if !matches!(dt, DataType::Unknown(UnknownKind::Any)) => Some(String),

(dt, Null) => Some(dt.clone()),

Expand Down Expand Up @@ -253,7 +255,35 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
let st = get_supertype(inner_left, inner_right)?;
Some(DataType::List(Box::new(st)))
}
(_, Unknown) => Some(Unknown),
#[cfg(feature = "dtype-struct")]
(Struct(inner), right @ Unknown(UnknownKind::Float | UnknownKind::Int(_))) => {
match inner.first() {
Some(inner) => get_supertype(&inner.dtype, right),
None => None
}
},
(dt, Unknown(kind)) => {
match kind {
UnknownKind::Float | UnknownKind::Int(_) if dt.is_float() | dt.is_string() => Some(dt.clone()),
UnknownKind::Float if dt.is_numeric() => Some(Unknown(UnknownKind::Float)),
UnknownKind::Str if dt.is_string() | dt.is_enum() => Some(dt.clone()),
#[cfg(feature = "dtype-categorical")]
UnknownKind::Str if dt.is_categorical() => {
let Categorical(_, ord) = dt else { unreachable!()};
Some(Categorical(None, *ord))
},
dynam if dt.is_null() => Some(Unknown(*dynam)),
UnknownKind::Int(v) if dt.is_numeric() => {
let smallest_fitting_dtype = if dt.is_unsigned_integer() && v.is_positive() {
materialize_dyn_int_pos(*v).dtype()
} else {
materialize_smallest_dyn_int(*v).dtype()
};
get_supertype(dt, &smallest_fitting_dtype)
}
_ => Some(Unknown(UnknownKind::Any))
}
},
#[cfg(feature = "dtype-struct")]
(Struct(fields_a), Struct(fields_b)) => {
super_type_structs(fields_a, fields_b)
Expand Down Expand Up @@ -341,3 +371,54 @@ fn super_type_structs(fields_a: &[Field], fields_b: &[Field]) -> Option<DataType
Some(DataType::Struct(new_fields))
}
}

/// Materialize a dynamic integer literal into a concrete integer [`AnyValue`].
///
/// The candidate types are tried in the order `Int32`, `Int64`, `UInt64`; the
/// first one that can represent `v` wins. If `v` fits in none of them,
/// `AnyValue::Null` is returned.
pub fn materialize_dyn_int(v: i128) -> AnyValue<'static> {
    // Try to get the "smallest" fitting value.
    // TODO! next breaking go to true smallest.
    i32::try_from(v)
        .map(AnyValue::Int32)
        .or_else(|_| i64::try_from(v).map(AnyValue::Int64))
        .or_else(|_| u64::try_from(v).map(AnyValue::UInt64))
        .unwrap_or(AnyValue::Null)
}
/// Materialize a dynamic integer literal into the narrowest unsigned integer
/// [`AnyValue`] that can hold it.
///
/// The candidate types are tried in the order `UInt8`, `UInt16`, `UInt32`,
/// `UInt64`. If `v` fits in none of them (for example because it is negative),
/// `AnyValue::Null` is returned.
fn materialize_dyn_int_pos(v: i128) -> AnyValue<'static> {
    u8::try_from(v)
        .map(AnyValue::UInt8)
        .or_else(|_| u16::try_from(v).map(AnyValue::UInt16))
        .or_else(|_| u32::try_from(v).map(AnyValue::UInt32))
        .or_else(|_| u64::try_from(v).map(AnyValue::UInt64))
        .unwrap_or(AnyValue::Null)
}

/// Materialize a dynamic integer literal into the narrowest integer
/// [`AnyValue`] that can hold it, preferring signed types.
///
/// The candidate types are tried in the order `Int8`, `Int16`, `Int32`,
/// `Int64`, and finally `UInt64` for values above `i64::MAX`. If `v` fits in
/// none of them, `AnyValue::Null` is returned.
fn materialize_smallest_dyn_int(v: i128) -> AnyValue<'static> {
    i8::try_from(v)
        .map(AnyValue::Int8)
        .or_else(|_| i16::try_from(v).map(AnyValue::Int16))
        .or_else(|_| i32::try_from(v).map(AnyValue::Int32))
        .or_else(|_| i64::try_from(v).map(AnyValue::Int64))
        .or_else(|_| u64::try_from(v).map(AnyValue::UInt64))
        .unwrap_or(AnyValue::Null)
}
6 changes: 6 additions & 0 deletions crates/polars-lazy/src/physical_plan/expressions/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ impl PhysicalExpr for LiteralExpr {
.into_time()
.into_series(),
Series(series) => series.deref().clone(),
lv @ (Int(_) | Float(_) | StrCat(_)) => polars_core::prelude::Series::from_any_values(
LITERAL_NAME,
&[lv.to_any_value().unwrap()],
false,
)
.unwrap(),
};
Ok(s)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ pub(crate) fn insert_streaming_nodes(
.iter()
.all(|fld| allowed_dtype(fld.data_type(), string_cache)),
// We need to be able to sink to disk or produce the aggregate return dtype.
DataType::Unknown => false,
DataType::Unknown(_) => false,
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(_, _) => false,
_ => true,
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/tests/optimization_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,12 @@ pub fn test_predicate_block_cast() -> PolarsResult<()> {
let lf1 = df
.clone()
.lazy()
.with_column(col("value").cast(DataType::Int16) * lit(0.1f32))
.with_column(col("value").cast(DataType::Int16) * lit(0.1).cast(DataType::Float32))
.filter(col("value").lt(lit(2.5f32)));

let lf2 = df
.lazy()
.select([col("value").cast(DataType::Int16) * lit(0.1f32)])
.select([col("value").cast(DataType::Int16) * lit(0.1).cast(DataType::Float32)])
.filter(col("value").lt(lit(2.5f32)));

for lf in [lf1, lf2] {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ fn test_simplify_expr() {

let plan = df
.lazy()
.select(&[lit(1.0f32) + lit(1.0f32) + col("sepal_width")])
.select(&[lit(1.0) + lit(1.0) + col("sepal_width")])
.logical_plan;

let mut expr_arena = Arena::new();
Expand All @@ -564,7 +564,7 @@ fn test_simplify_expr() {
.unwrap();
let plan = node_to_lp(lp_top, &expr_arena, &mut lp_arena);
assert!(
matches!(plan, DslPlan::Select{ expr, ..} if matches!(&expr[0], Expr::BinaryExpr{left, ..} if **left == Expr::Literal(LiteralValue::Float32(2.0))))
matches!(plan, DslPlan::Select{ expr, ..} if matches!(&expr[0], Expr::BinaryExpr{left, ..} if **left == Expr::Literal(LiteralValue::Float(2.0))))
);
}

Expand Down
Loading
Loading