Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Support shortcut eval of common boolean filters in SQL interface "WHERE" clause #18571

Merged
merged 2 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -847,8 +847,28 @@ impl SQLContext {
expr: &Option<SQLExpr>,
) -> PolarsResult<LazyFrame> {
if let Some(expr) = expr {
let schema = Some(self.get_frame_schema(&mut lf)?);
let mut filter_expression = parse_sql_expr(expr, self, schema.as_deref())?;
let schema = self.get_frame_schema(&mut lf)?;

// shortcut filter evaluation if given expression is just TRUE or FALSE
let (all_true, all_false) = match expr {
SQLExpr::Value(SQLValue::Boolean(b)) => (*b, !*b),
SQLExpr::BinaryOp { left, op, right } => match (&**left, &**right, op) {
(SQLExpr::Value(a), SQLExpr::Value(b), BinaryOperator::Eq) => (a == b, a != b),
(SQLExpr::Value(a), SQLExpr::Value(b), BinaryOperator::NotEq) => {
(a != b, a == b)
},
_ => (false, false),
},
_ => (false, false),
};
if all_true {
return Ok(lf);
} else if all_false {
return Ok(DataFrame::empty_with_schema(schema.as_ref()).lazy());
}

// ...otherwise parse and apply the filter as normal
let mut filter_expression = parse_sql_expr(expr, self, Some(schema).as_deref())?;
if filter_expression.clone().meta().has_multiple_outputs() {
filter_expression = all_horizontal([filter_expression])?;
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-sql/src/function_registry.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! This module defines the function registry and user defined functions.
//! This module defines a FunctionRegistry for supported SQL functions and UDFs.

use polars_error::{polars_bail, PolarsResult};
use polars_plan::prelude::udf::UserDefinedFunction;
Expand Down
8 changes: 3 additions & 5 deletions crates/polars-sql/src/keywords.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
//! Keywords that are supported by Polars SQL
//!
//! This is useful for syntax highlighting
//! Keywords that are supported by the Polars SQL interface.
//!
//! This module defines:
//! - all Polars SQL keywords [`all_keywords`]
//! - all of polars SQL functions [`all_functions`]
//! - all recognised Polars SQL keywords [`all_keywords`]
//! - all recognised Polars SQL functions [`all_functions`]
use crate::functions::PolarsSQLFunctions;
use crate::table_functions::PolarsTableFunctions;

Expand Down
1 change: 1 addition & 0 deletions crates/polars-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod functions;
pub mod keywords;
mod sql_expr;
mod table_functions;
mod types;

pub use context::SQLContext;
pub use sql_expr::sql_expr;
223 changes: 18 additions & 205 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
//! Expressions that are supported by the Polars SQL interface.
//!
//! This is useful for syntax highlighting
//!
//! This module defines:
//! - all Polars SQL keywords [`all_keywords`]
//! - all of polars SQL functions [`all_functions`]

use std::fmt::Display;
use std::ops::Div;

Expand All @@ -9,216 +17,39 @@ use polars_plan::prelude::LiteralValue::Null;
use polars_time::Duration;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use regex::{Regex, RegexBuilder};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "dtype-decimal")]
use sqlparser::ast::ExactNumberInfo;
use sqlparser::ast::{
ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind,
BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind,
DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident,
Interval, ObjectName, Query as Subquery, SelectItem, Subscript, TimezoneInfo, TrimWhereField,
Interval, Query as Subquery, SelectItem, Subscript, TimezoneInfo, TrimWhereField,
UnaryOperator, Value as SQLValue,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};

use crate::functions::SQLFunctionVisitor;
use crate::types::{
bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
};
use crate::SQLContext;

static DATETIME_LITERAL_RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
static DATE_LITERAL_RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
static TIME_LITERAL_RE: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();

fn is_iso_datetime(value: &str) -> bool {
let dtm_regex = DATETIME_LITERAL_RE.get_or_init(|| {
RegexBuilder::new(
r"^\d{4}-[01]\d-[0-3]\d[ T](?:[01][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9](\.\d{1,9})?$",
)
.build()
.unwrap()
});
dtm_regex.is_match(value)
}

fn is_iso_date(value: &str) -> bool {
let dt_regex = DATE_LITERAL_RE.get_or_init(|| {
RegexBuilder::new(r"^\d{4}-[01]\d-[0-3]\d$")
.build()
.unwrap()
});
dt_regex.is_match(value)
}

fn is_iso_time(value: &str) -> bool {
let tm_regex = TIME_LITERAL_RE.get_or_init(|| {
RegexBuilder::new(r"^(?:[01][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9](\.\d{1,9})?$")
.build()
.unwrap()
});
tm_regex.is_match(value)
}

#[inline]
#[cold]
#[must_use]
/// Convert a Display-able error to PolarsError::SQLInterface
pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
PolarsError::SQLInterface(err.to_string().into())
}

fn timeunit_from_precision(prec: &Option<u64>) -> PolarsResult<TimeUnit> {
Ok(match prec {
None => TimeUnit::Microseconds,
Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds,
Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds,
Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds,
Some(n) => {
polars_bail!(SQLSyntax: "invalid temporal type precision (expected 1-9, found {})", n)
},
})
}

pub(crate) fn map_sql_polars_datatype(dtype: &SQLDataType) -> PolarsResult<DataType> {
Ok(match dtype {
// ---------------------------------
// array/list
// ---------------------------------
SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_type))
| SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type, _)) => {
DataType::List(Box::new(map_sql_polars_datatype(inner_type)?))
},

// ---------------------------------
// binary
// ---------------------------------
SQLDataType::Bytea
| SQLDataType::Bytes(_)
| SQLDataType::Binary(_)
| SQLDataType::Blob(_)
| SQLDataType::Varbinary(_) => DataType::Binary,

// ---------------------------------
// boolean
// ---------------------------------
SQLDataType::Boolean | SQLDataType::Bool => DataType::Boolean,

// ---------------------------------
// signed integer
// ---------------------------------
SQLDataType::Int(_) | SQLDataType::Integer(_) => DataType::Int32,
SQLDataType::Int2(_) | SQLDataType::SmallInt(_) => DataType::Int16,
SQLDataType::Int4(_) | SQLDataType::MediumInt(_) => DataType::Int32,
SQLDataType::Int8(_) | SQLDataType::BigInt(_) => DataType::Int64,
SQLDataType::TinyInt(_) => DataType::Int8,

// ---------------------------------
// unsigned integer: the following do not map to PostgreSQL types/syntax, but
// are enabled for wider compatibility (eg: "CAST(col AS BIGINT UNSIGNED)").
// ---------------------------------
SQLDataType::UnsignedTinyInt(_) => DataType::UInt8, // see also: "custom" types below
SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => DataType::UInt32,
SQLDataType::UnsignedInt2(_) | SQLDataType::UnsignedSmallInt(_) => DataType::UInt16,
SQLDataType::UnsignedInt4(_) | SQLDataType::UnsignedMediumInt(_) => DataType::UInt32,
SQLDataType::UnsignedInt8(_) | SQLDataType::UnsignedBigInt(_) | SQLDataType::UInt8 => {
DataType::UInt64
},

// ---------------------------------
// float
// ---------------------------------
SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => {
DataType::Float64
},
SQLDataType::Float(n_bytes) => match n_bytes {
Some(n) if (1u64..=24u64).contains(n) => DataType::Float32,
Some(n) if (25u64..=53u64).contains(n) => DataType::Float64,
Some(n) => {
polars_bail!(SQLSyntax: "unsupported `float` size (expected a value between 1 and 53, found {})", n)
},
None => DataType::Float64,
},
SQLDataType::Float4 | SQLDataType::Real => DataType::Float32,

// ---------------------------------
// decimal
// ---------------------------------
#[cfg(feature = "dtype-decimal")]
SQLDataType::Dec(info) | SQLDataType::Decimal(info) | SQLDataType::Numeric(info) => {
match *info {
ExactNumberInfo::PrecisionAndScale(p, s) => {
DataType::Decimal(Some(p as usize), Some(s as usize))
},
ExactNumberInfo::Precision(p) => DataType::Decimal(Some(p as usize), Some(0)),
ExactNumberInfo::None => DataType::Decimal(Some(38), Some(9)),
}
},

// ---------------------------------
// temporal
// ---------------------------------
SQLDataType::Date => DataType::Date,
SQLDataType::Interval => DataType::Duration(TimeUnit::Microseconds),
SQLDataType::Time(_, tz) => match tz {
TimezoneInfo::None => DataType::Time,
_ => {
polars_bail!(SQLInterface: "`time` with timezone is not supported; found tz={}", tz)
},
},
SQLDataType::Datetime(prec) => DataType::Datetime(timeunit_from_precision(prec)?, None),
SQLDataType::Timestamp(prec, tz) => match tz {
TimezoneInfo::None => DataType::Datetime(timeunit_from_precision(prec)?, None),
_ => {
polars_bail!(SQLInterface: "`timestamp` with timezone is not (yet) supported")
},
},

// ---------------------------------
// string
// ---------------------------------
SQLDataType::Char(_)
| SQLDataType::CharVarying(_)
| SQLDataType::Character(_)
| SQLDataType::CharacterVarying(_)
| SQLDataType::Clob(_)
| SQLDataType::String(_)
| SQLDataType::Text
| SQLDataType::Uuid
| SQLDataType::Varchar(_) => DataType::String,

// ---------------------------------
// custom
// ---------------------------------
SQLDataType::Custom(ObjectName(idents), _) => match idents.as_slice() {
[Ident { value, .. }] => match value.to_lowercase().as_str() {
// these integer types are not supported by the PostgreSQL core distribution,
// but they ARE available via `pguint` (https://github.com/petere/pguint), an
// extension maintained by one of the PostgreSQL core developers.
"uint1" => DataType::UInt8,
"uint2" => DataType::UInt16,
"uint4" | "uint" => DataType::UInt32,
"uint8" => DataType::UInt64,
// `pguint` also provides a 1 byte (8bit) integer type alias
"int1" => DataType::Int8,
_ => {
polars_bail!(SQLInterface: "datatype {:?} is not currently supported", value)
},
},
_ => {
polars_bail!(SQLInterface: "datatype {:?} is not currently supported", idents)
},
},
_ => {
polars_bail!(SQLInterface: "datatype {:?} is not currently supported", dtype)
},
})
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)]
/// Categorises the type of (allowed) subquery constraint
pub enum SubqueryRestriction {
// SingleValue,
/// Subquery must return a single column
SingleColumn,
// SingleRow,
// SingleValue,
// Any
}

Expand Down Expand Up @@ -889,7 +720,7 @@ impl SQLExprVisitor<'_> {
if dtype == &SQLDataType::JSON {
return Ok(expr.str().json_decode(None, None));
}
let polars_type = map_sql_polars_datatype(dtype)?;
let polars_type = map_sql_dtype_to_polars(dtype)?;
Ok(match cast_kind {
CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
Expand Down Expand Up @@ -1319,24 +1150,6 @@ pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
}
}

fn bitstring_to_bytes_literal(b: &String) -> PolarsResult<Expr> {
let n_bits = b.len();
if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 {
polars_bail!(
SQLSyntax:
"bit string literal should contain only 0s and 1s and have length <= 64; found '{}' with length {}", b, n_bits
)
}
let s = b.as_str();
Ok(lit(match n_bits {
0 => b"".to_vec(),
1..=8 => u8::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
9..=16 => u16::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
17..=32 => u32::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
_ => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(),
}))
}

pub(crate) fn resolve_compound_identifier(
ctx: &mut SQLContext,
idents: &[Ident],
Expand Down
Loading
Loading