Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PyExpr to_variant conversions #793

Merged
merged 13 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions python/datafusion/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,31 @@ def test_relational_expr(test_ctx):
assert df.filter(col("b") != "beta").count() == 2

assert df.filter(col("a") == "beta").count() == 0


def test_expr_to_variant():
# Taken from https://github.com/apache/datafusion-python/issues/781
from datafusion import SessionContext
from datafusion.expr import Filter


def traverse_logical_plan(plan):
cur_node = plan.to_variant()
if isinstance(cur_node, Filter):
return cur_node.predicate().to_variant()
if hasattr(plan, 'inputs'):
for input_plan in plan.inputs():
res = traverse_logical_plan(input_plan)
if res is not None:
return res

ctx = SessionContext()
data = {'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie']}
ctx.from_pydict(data, name='table1')
query = "SELECT * FROM table1 t1 WHERE t1.name IN ('dfa', 'ad', 'dfre', 'vsa')"
logical_plan = ctx.sql(query).optimized_logical_plan()
variant = traverse_logical_plan(logical_plan)
assert variant is not None
assert variant.expr().to_variant().qualified_name() == 'table1.name'
assert str(variant.list()) == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]'
assert not variant.negated()
51 changes: 45 additions & 6 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_expr::{
};

use crate::common::data_type::{DataTypeMap, RexType};
use crate::errors::{py_runtime_err, py_type_err, DataFusionError};
use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err, DataFusionError};
use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
use crate::expr::column::PyColumn;
Expand Down Expand Up @@ -84,11 +84,13 @@ pub mod scalar_subquery;
pub mod scalar_variable;
pub mod signature;
pub mod sort;
pub mod sort_expr;
pub mod subquery;
pub mod subquery_alias;
pub mod table_scan;
pub mod union;
pub mod unnest;
pub mod unnest_expr;
pub mod window;

/// A PyExpr that can be used on a DataFrame
Expand Down Expand Up @@ -119,8 +121,9 @@ pub fn py_expr_list(expr: &[Expr]) -> PyResult<Vec<PyExpr>> {
impl PyExpr {
/// Return the specific expression
fn to_variant(&self, py: Python) -> PyResult<PyObject> {
Python::with_gil(|_| match &self.expr {
Expr::Alias(alias) => Ok(PyAlias::new(&alias.expr, &alias.name).into_py(py)),
Python::with_gil(|_| {
match &self.expr {
Expr::Alias(alias) => Ok(PyAlias::from(alias.clone()).into_py(py)),
Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_py(py)),
Expr::ScalarVariable(data_type, variables) => {
Ok(PyScalarVariable::new(data_type, variables).into_py(py))
Expand All @@ -141,10 +144,44 @@ impl PyExpr {
Expr::AggregateFunction(expr) => {
Ok(PyAggregateFunction::from(expr.clone()).into_py(py))
}
other => Err(py_runtime_err(format!(
"Cannot convert this Expr to a Python object: {:?}",
other
Expr::SimilarTo(value) => Ok(PySimilarTo::from(value.clone()).into_py(py)),
Expr::Between(value) => Ok(between::PyBetween::from(value.clone()).into_py(py)),
Expr::Case(value) => Ok(case::PyCase::from(value.clone()).into_py(py)),
Expr::Cast(value) => Ok(cast::PyCast::from(value.clone()).into_py(py)),
Expr::TryCast(value) => Ok(cast::PyTryCast::from(value.clone()).into_py(py)),
Expr::Sort(value) => Ok(sort_expr::PySortExpr::from(value.clone()).into_py(py)),
Expr::ScalarFunction(value) => Err(py_unsupported_variant_err(format!(
"Converting Expr::ScalarFunction to a Python object is not implemented: {:?}",
value
))),
Expr::WindowFunction(value) => Err(py_unsupported_variant_err(format!(
"Converting Expr::WindowFunction to a Python object is not implemented: {:?}",
value
))),
Expr::InList(value) => Ok(in_list::PyInList::from(value.clone()).into_py(py)),
Expr::Exists(value) => Ok(exists::PyExists::from(value.clone()).into_py(py)),
Expr::InSubquery(value) => {
Ok(in_subquery::PyInSubquery::from(value.clone()).into_py(py))
}
Expr::ScalarSubquery(value) => {
Ok(scalar_subquery::PyScalarSubquery::from(value.clone()).into_py(py))
}
Expr::Wildcard { qualifier } => Err(py_unsupported_variant_err(format!(
"Converting Expr::Wildcard to a Python object is not implemented : {:?}",
qualifier
))),
Expr::GroupingSet(value) => {
Ok(grouping_set::PyGroupingSet::from(value.clone()).into_py(py))
}
Expr::Placeholder(value) => {
Ok(placeholder::PyPlaceholder::from(value.clone()).into_py(py))
}
Expr::OuterReferenceColumn(data_type, column) => Err(py_unsupported_variant_err(format!(
"Converting Expr::OuterReferenceColumn to a Python object is not implemented: {:?} - {:?}",
data_type, column
))),
Expr::Unnest(value) => Ok(unnest_expr::PyUnnestExpr::from(value.clone()).into_py(py)),
}
})
}

Expand Down Expand Up @@ -599,13 +636,15 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<cross_join::PyCrossJoin>()?;
m.add_class::<union::PyUnion>()?;
m.add_class::<unnest::PyUnnest>()?;
m.add_class::<unnest_expr::PyUnnestExpr>()?;
m.add_class::<extension::PyExtension>()?;
m.add_class::<filter::PyFilter>()?;
m.add_class::<projection::PyProjection>()?;
m.add_class::<table_scan::PyTableScan>()?;
m.add_class::<create_memory_table::PyCreateMemoryTable>()?;
m.add_class::<create_view::PyCreateView>()?;
m.add_class::<distinct::PyDistinct>()?;
m.add_class::<sort_expr::PySortExpr>()?;
m.add_class::<subquery_alias::PySubqueryAlias>()?;
m.add_class::<drop_table::PyDropTable>()?;
m.add_class::<repartition::PyPartitioning>()?;
Expand Down
32 changes: 17 additions & 15 deletions src/expr/alias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,24 @@ use crate::expr::PyExpr;
use pyo3::prelude::*;
use std::fmt::{self, Display, Formatter};

use datafusion_expr::Expr;
use datafusion_expr::expr::Alias;

#[pyclass(name = "Alias", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyAlias {
expr: PyExpr,
alias_name: String,
alias: Alias,
}

impl From<Alias> for PyAlias {
fn from(alias: Alias) -> Self {
Self { alias }
}
}

impl From<PyAlias> for Alias {
fn from(py_alias: PyAlias) -> Self {
py_alias.alias
}
}

impl Display for PyAlias {
Expand All @@ -35,29 +46,20 @@ impl Display for PyAlias {
"Alias
\nExpr: `{:?}`
\nAlias Name: `{}`",
&self.expr, &self.alias_name
&self.alias.expr, &self.alias.name
)
}
}

impl PyAlias {
pub fn new(expr: &Expr, alias_name: &String) -> Self {
Self {
expr: expr.clone().into(),
alias_name: alias_name.to_owned(),
}
}
}

#[pymethods]
impl PyAlias {
/// Retrieve the "name" of the alias
fn alias(&self) -> PyResult<String> {
Ok(self.alias_name.clone())
Ok(self.alias.name.clone())
}

fn expr(&self) -> PyResult<PyExpr> {
Ok(self.expr.clone())
Ok((*self.alias.expr.clone()).into())
}

/// Get a String representation of this column
Expand Down
15 changes: 7 additions & 8 deletions src/expr/exists.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,30 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_expr::Subquery;
use datafusion_expr::expr::Exists;
use pyo3::prelude::*;

use super::subquery::PySubquery;

#[pyclass(name = "Exists", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyExists {
subquery: Subquery,
negated: bool,
exists: Exists,
}

impl PyExists {
pub fn new(subquery: Subquery, negated: bool) -> Self {
Self { subquery, negated }
impl From<Exists> for PyExists {
fn from(exists: Exists) -> Self {
PyExists { exists }
}
}

#[pymethods]
impl PyExists {
fn subquery(&self) -> PySubquery {
self.subquery.clone().into()
self.exists.subquery.clone().into()
}

fn negated(&self) -> bool {
self.negated
self.exists.negated
}
}
22 changes: 8 additions & 14 deletions src/expr/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,32 @@
// under the License.

use crate::expr::PyExpr;
use datafusion_expr::Expr;
use datafusion_expr::expr::InList;
use pyo3::prelude::*;

#[pyclass(name = "InList", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyInList {
expr: Box<Expr>,
list: Vec<Expr>,
negated: bool,
in_list: InList,
}

impl PyInList {
pub fn new(expr: Box<Expr>, list: Vec<Expr>, negated: bool) -> Self {
Self {
expr,
list,
negated,
}
impl From<InList> for PyInList {
fn from(in_list: InList) -> Self {
PyInList { in_list }
}
}

#[pymethods]
impl PyInList {
fn expr(&self) -> PyExpr {
(*self.expr).clone().into()
(*self.in_list.expr).clone().into()
}

fn list(&self) -> Vec<PyExpr> {
self.list.iter().map(|e| e.clone().into()).collect()
self.in_list.list.iter().map(|e| e.clone().into()).collect()
}

fn negated(&self) -> bool {
self.negated
self.in_list.negated
}
}
22 changes: 8 additions & 14 deletions src/expr/in_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,34 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_expr::{Expr, Subquery};
use datafusion_expr::expr::InSubquery;
use pyo3::prelude::*;

use super::{subquery::PySubquery, PyExpr};

#[pyclass(name = "InSubquery", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyInSubquery {
expr: Box<Expr>,
subquery: Subquery,
negated: bool,
in_subquery: InSubquery,
}

impl PyInSubquery {
pub fn new(expr: Box<Expr>, subquery: Subquery, negated: bool) -> Self {
Self {
expr,
subquery,
negated,
}
impl From<InSubquery> for PyInSubquery {
fn from(in_subquery: InSubquery) -> Self {
PyInSubquery { in_subquery }
}
}

#[pymethods]
impl PyInSubquery {
fn expr(&self) -> PyExpr {
(*self.expr).clone().into()
(*self.in_subquery.expr).clone().into()
}

fn subquery(&self) -> PySubquery {
self.subquery.clone().into()
self.in_subquery.subquery.clone().into()
}

fn negated(&self) -> bool {
self.negated
self.in_subquery.negated
}
}
21 changes: 10 additions & 11 deletions src/expr/placeholder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,33 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::arrow::datatypes::DataType;
use datafusion_expr::expr::Placeholder;
use pyo3::prelude::*;

use crate::common::data_type::PyDataType;

#[pyclass(name = "Placeholder", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyPlaceholder {
id: String,
data_type: Option<DataType>,
placeholder: Placeholder,
}

impl PyPlaceholder {
pub fn new(id: String, data_type: DataType) -> Self {
Self {
id,
data_type: Some(data_type),
}
impl From<Placeholder> for PyPlaceholder {
fn from(placeholder: Placeholder) -> Self {
PyPlaceholder { placeholder }
}
}

#[pymethods]
impl PyPlaceholder {
fn id(&self) -> String {
self.id.clone()
self.placeholder.id.clone()
}

fn data_type(&self) -> Option<PyDataType> {
self.data_type.as_ref().map(|e| e.clone().into())
self.placeholder
.data_type
.as_ref()
.map(|e| e.clone().into())
}
}
Loading
Loading