Skip to content

Commit

Permalink
Add PyExpr to_variant conversions (apache#793)
Browse files Browse the repository at this point in the history
* make PyExpr::to_variant arms explicit

* update PyInList to wrap expr::InList

* update PyExists to wrap expr::Exists

* update PyInSubquery to wrap expr::InSubquery

* update Placeholder to wrap expr::Placeholder

* make PyLogicalPlan::to_variant match arms explicit

* add PySortExpr wrapper

* add PyUnnestExpr wrapper

* update PyAlias to wrap upstream Alias

* return not implemented error for unimplemnted variants in PyExpr::to_variant

* added to_variant python test from the GH issue

* remove unused import

* return unsupported_variants for unimplemented variants in  PyLogicalPlan::to_variant
  • Loading branch information
Michael-J-Ward authored Aug 4, 2024
1 parent 9a6805e commit 3eb198b
Show file tree
Hide file tree
Showing 10 changed files with 273 additions and 71 deletions.
28 changes: 28 additions & 0 deletions python/datafusion/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,31 @@ def test_relational_expr(test_ctx):
assert df.filter(col("b") != "beta").count() == 2

assert df.filter(col("a") == "beta").count() == 0


def test_expr_to_variant():
# Taken from https://github.com/apache/datafusion-python/issues/781
from datafusion import SessionContext
from datafusion.expr import Filter


def traverse_logical_plan(plan):
cur_node = plan.to_variant()
if isinstance(cur_node, Filter):
return cur_node.predicate().to_variant()
if hasattr(plan, 'inputs'):
for input_plan in plan.inputs():
res = traverse_logical_plan(input_plan)
if res is not None:
return res

ctx = SessionContext()
data = {'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie']}
ctx.from_pydict(data, name='table1')
query = "SELECT * FROM table1 t1 WHERE t1.name IN ('dfa', 'ad', 'dfre', 'vsa')"
logical_plan = ctx.sql(query).optimized_logical_plan()
variant = traverse_logical_plan(logical_plan)
assert variant is not None
assert variant.expr().to_variant().qualified_name() == 'table1.name'
assert str(variant.list()) == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]'
assert not variant.negated()
51 changes: 45 additions & 6 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_expr::{
};

use crate::common::data_type::{DataTypeMap, RexType};
use crate::errors::{py_runtime_err, py_type_err, DataFusionError};
use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err, DataFusionError};
use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
use crate::expr::column::PyColumn;
Expand Down Expand Up @@ -84,11 +84,13 @@ pub mod scalar_subquery;
pub mod scalar_variable;
pub mod signature;
pub mod sort;
pub mod sort_expr;
pub mod subquery;
pub mod subquery_alias;
pub mod table_scan;
pub mod union;
pub mod unnest;
pub mod unnest_expr;
pub mod window;

/// A PyExpr that can be used on a DataFrame
Expand Down Expand Up @@ -119,8 +121,9 @@ pub fn py_expr_list(expr: &[Expr]) -> PyResult<Vec<PyExpr>> {
impl PyExpr {
/// Return the specific expression
fn to_variant(&self, py: Python) -> PyResult<PyObject> {
Python::with_gil(|_| match &self.expr {
Expr::Alias(alias) => Ok(PyAlias::new(&alias.expr, &alias.name).into_py(py)),
Python::with_gil(|_| {
match &self.expr {
Expr::Alias(alias) => Ok(PyAlias::from(alias.clone()).into_py(py)),
Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_py(py)),
Expr::ScalarVariable(data_type, variables) => {
Ok(PyScalarVariable::new(data_type, variables).into_py(py))
Expand All @@ -141,10 +144,44 @@ impl PyExpr {
Expr::AggregateFunction(expr) => {
Ok(PyAggregateFunction::from(expr.clone()).into_py(py))
}
other => Err(py_runtime_err(format!(
"Cannot convert this Expr to a Python object: {:?}",
other
Expr::SimilarTo(value) => Ok(PySimilarTo::from(value.clone()).into_py(py)),
Expr::Between(value) => Ok(between::PyBetween::from(value.clone()).into_py(py)),
Expr::Case(value) => Ok(case::PyCase::from(value.clone()).into_py(py)),
Expr::Cast(value) => Ok(cast::PyCast::from(value.clone()).into_py(py)),
Expr::TryCast(value) => Ok(cast::PyTryCast::from(value.clone()).into_py(py)),
Expr::Sort(value) => Ok(sort_expr::PySortExpr::from(value.clone()).into_py(py)),
Expr::ScalarFunction(value) => Err(py_unsupported_variant_err(format!(
"Converting Expr::ScalarFunction to a Python object is not implemented: {:?}",
value
))),
Expr::WindowFunction(value) => Err(py_unsupported_variant_err(format!(
"Converting Expr::WindowFunction to a Python object is not implemented: {:?}",
value
))),
Expr::InList(value) => Ok(in_list::PyInList::from(value.clone()).into_py(py)),
Expr::Exists(value) => Ok(exists::PyExists::from(value.clone()).into_py(py)),
Expr::InSubquery(value) => {
Ok(in_subquery::PyInSubquery::from(value.clone()).into_py(py))
}
Expr::ScalarSubquery(value) => {
Ok(scalar_subquery::PyScalarSubquery::from(value.clone()).into_py(py))
}
Expr::Wildcard { qualifier } => Err(py_unsupported_variant_err(format!(
"Converting Expr::Wildcard to a Python object is not implemented : {:?}",
qualifier
))),
Expr::GroupingSet(value) => {
Ok(grouping_set::PyGroupingSet::from(value.clone()).into_py(py))
}
Expr::Placeholder(value) => {
Ok(placeholder::PyPlaceholder::from(value.clone()).into_py(py))
}
Expr::OuterReferenceColumn(data_type, column) => Err(py_unsupported_variant_err(format!(
"Converting Expr::OuterReferenceColumn to a Python object is not implemented: {:?} - {:?}",
data_type, column
))),
Expr::Unnest(value) => Ok(unnest_expr::PyUnnestExpr::from(value.clone()).into_py(py)),
}
})
}

Expand Down Expand Up @@ -599,13 +636,15 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<cross_join::PyCrossJoin>()?;
m.add_class::<union::PyUnion>()?;
m.add_class::<unnest::PyUnnest>()?;
m.add_class::<unnest_expr::PyUnnestExpr>()?;
m.add_class::<extension::PyExtension>()?;
m.add_class::<filter::PyFilter>()?;
m.add_class::<projection::PyProjection>()?;
m.add_class::<table_scan::PyTableScan>()?;
m.add_class::<create_memory_table::PyCreateMemoryTable>()?;
m.add_class::<create_view::PyCreateView>()?;
m.add_class::<distinct::PyDistinct>()?;
m.add_class::<sort_expr::PySortExpr>()?;
m.add_class::<subquery_alias::PySubqueryAlias>()?;
m.add_class::<drop_table::PyDropTable>()?;
m.add_class::<repartition::PyPartitioning>()?;
Expand Down
32 changes: 17 additions & 15 deletions src/expr/alias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,24 @@ use crate::expr::PyExpr;
use pyo3::prelude::*;
use std::fmt::{self, Display, Formatter};

use datafusion_expr::Expr;
use datafusion_expr::expr::Alias;

#[pyclass(name = "Alias", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyAlias {
expr: PyExpr,
alias_name: String,
alias: Alias,
}

impl From<Alias> for PyAlias {
fn from(alias: Alias) -> Self {
Self { alias }
}
}

impl From<PyAlias> for Alias {
fn from(py_alias: PyAlias) -> Self {
py_alias.alias
}
}

impl Display for PyAlias {
Expand All @@ -35,29 +46,20 @@ impl Display for PyAlias {
"Alias
\nExpr: `{:?}`
\nAlias Name: `{}`",
&self.expr, &self.alias_name
&self.alias.expr, &self.alias.name
)
}
}

impl PyAlias {
pub fn new(expr: &Expr, alias_name: &String) -> Self {
Self {
expr: expr.clone().into(),
alias_name: alias_name.to_owned(),
}
}
}

#[pymethods]
impl PyAlias {
/// Retrieve the "name" of the alias
fn alias(&self) -> PyResult<String> {
Ok(self.alias_name.clone())
Ok(self.alias.name.clone())
}

fn expr(&self) -> PyResult<PyExpr> {
Ok(self.expr.clone())
Ok((*self.alias.expr.clone()).into())
}

/// Get a String representation of this column
Expand Down
15 changes: 7 additions & 8 deletions src/expr/exists.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,30 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_expr::Subquery;
use datafusion_expr::expr::Exists;
use pyo3::prelude::*;

use super::subquery::PySubquery;

#[pyclass(name = "Exists", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyExists {
subquery: Subquery,
negated: bool,
exists: Exists,
}

impl PyExists {
pub fn new(subquery: Subquery, negated: bool) -> Self {
Self { subquery, negated }
impl From<Exists> for PyExists {
fn from(exists: Exists) -> Self {
PyExists { exists }
}
}

#[pymethods]
impl PyExists {
fn subquery(&self) -> PySubquery {
self.subquery.clone().into()
self.exists.subquery.clone().into()
}

fn negated(&self) -> bool {
self.negated
self.exists.negated
}
}
22 changes: 8 additions & 14 deletions src/expr/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,32 @@
// under the License.

use crate::expr::PyExpr;
use datafusion_expr::Expr;
use datafusion_expr::expr::InList;
use pyo3::prelude::*;

#[pyclass(name = "InList", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyInList {
expr: Box<Expr>,
list: Vec<Expr>,
negated: bool,
in_list: InList,
}

impl PyInList {
pub fn new(expr: Box<Expr>, list: Vec<Expr>, negated: bool) -> Self {
Self {
expr,
list,
negated,
}
impl From<InList> for PyInList {
fn from(in_list: InList) -> Self {
PyInList { in_list }
}
}

#[pymethods]
impl PyInList {
fn expr(&self) -> PyExpr {
(*self.expr).clone().into()
(*self.in_list.expr).clone().into()
}

fn list(&self) -> Vec<PyExpr> {
self.list.iter().map(|e| e.clone().into()).collect()
self.in_list.list.iter().map(|e| e.clone().into()).collect()
}

fn negated(&self) -> bool {
self.negated
self.in_list.negated
}
}
22 changes: 8 additions & 14 deletions src/expr/in_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,34 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_expr::{Expr, Subquery};
use datafusion_expr::expr::InSubquery;
use pyo3::prelude::*;

use super::{subquery::PySubquery, PyExpr};

#[pyclass(name = "InSubquery", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyInSubquery {
expr: Box<Expr>,
subquery: Subquery,
negated: bool,
in_subquery: InSubquery,
}

impl PyInSubquery {
pub fn new(expr: Box<Expr>, subquery: Subquery, negated: bool) -> Self {
Self {
expr,
subquery,
negated,
}
impl From<InSubquery> for PyInSubquery {
fn from(in_subquery: InSubquery) -> Self {
PyInSubquery { in_subquery }
}
}

#[pymethods]
impl PyInSubquery {
fn expr(&self) -> PyExpr {
(*self.expr).clone().into()
(*self.in_subquery.expr).clone().into()
}

fn subquery(&self) -> PySubquery {
self.subquery.clone().into()
self.in_subquery.subquery.clone().into()
}

fn negated(&self) -> bool {
self.negated
self.in_subquery.negated
}
}
21 changes: 10 additions & 11 deletions src/expr/placeholder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,33 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::arrow::datatypes::DataType;
use datafusion_expr::expr::Placeholder;
use pyo3::prelude::*;

use crate::common::data_type::PyDataType;

#[pyclass(name = "Placeholder", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyPlaceholder {
id: String,
data_type: Option<DataType>,
placeholder: Placeholder,
}

impl PyPlaceholder {
pub fn new(id: String, data_type: DataType) -> Self {
Self {
id,
data_type: Some(data_type),
}
impl From<Placeholder> for PyPlaceholder {
fn from(placeholder: Placeholder) -> Self {
PyPlaceholder { placeholder }
}
}

#[pymethods]
impl PyPlaceholder {
fn id(&self) -> String {
self.id.clone()
self.placeholder.id.clone()
}

fn data_type(&self) -> Option<PyDataType> {
self.data_type.as_ref().map(|e| e.clone().into())
self.placeholder
.data_type
.as_ref()
.map(|e| e.clone().into())
}
}
Loading

0 comments on commit 3eb198b

Please sign in to comment.