PyO3 Bound<'py, T> API (apache#734)
* remove gil-refs feature from pyo3
* migrate module instantiation to Bound API
* migrate utils.rs to Bound API
* migrate config.rs to Bound API
* migrate context.rs to Bound API
* migrate udaf.rs to Bound API
* migrate pyarrow_filter_expression to Bound API
* migrate dataframe.rs to Bound API
* migrate dataset and dataset_exec to Bound API
* migrate substrait.rs to Bound API
Michael-J-Ward authored Jun 18, 2024
1 parent 1f49d46 commit c7ea90d
Showing 15 changed files with 136 additions and 115 deletions.
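
The heart of the change: PyO3 0.21 deprecates the GIL Ref API (`&PyAny`, `&PyDict`, ...) in favor of the `Bound<'py, T>` smart pointer, which carries the GIL lifetime explicitly and no longer needs the `gil-refs` compatibility feature. A minimal before/after sketch of the pattern (illustrative, not taken from this diff):

```rust
use pyo3::prelude::*;
use pyo3::types::PyDict;

// GIL Ref API (pre-0.21 style, behind the now-removed `gil-refs` feature):
//     let dict: &PyDict = PyDict::new(py);
// Bound API: the smart pointer carries the GIL lifetime 'py explicitly.
fn make_config(py: Python<'_>) -> PyResult<Py<PyDict>> {
    let dict: Bound<'_, PyDict> = PyDict::new_bound(py);
    dict.set_item("batch_size", 8192)?;
    // unbind() drops the 'py lifetime so the dict can outlive the GIL scope.
    Ok(dict.unbind())
}
```

The diffs below are that substitution applied file by file: `*_bound` constructors, `extract_bound`, and `unbind()`/`bind()` at storage boundaries.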
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -36,7 +36,7 @@ substrait = ["dep:datafusion-substrait"]
[dependencies]
tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] }
rand = "0.8"
pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38", "gil-refs"] }
pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] }
arrow = { version = "52", features = ["pyarrow"] }
datafusion = { version = "39.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-common = { version = "39.0.0", features = ["pyarrow"] }
@@ -67,3 +67,4 @@ crate-type = ["cdylib", "rlib"]
[profile.release]
lto = true
codegen-units = 1

2 changes: 1 addition & 1 deletion src/common.rs
@@ -23,7 +23,7 @@ pub mod function;
pub mod schema;

/// Initializes the `common` module to match the pattern of `datafusion-common` https://docs.rs/datafusion-common/18.0.0/datafusion_common/index.html
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<df_schema::PyDFSchema>()?;
m.add_class::<data_type::PyDataType>()?;
m.add_class::<data_type::DataTypeMap>()?;
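
`init_module` now takes `&Bound<'_, PyModule>`; its caller lives in src/lib.rs, outside this excerpt. A sketch of how a parent `#[pymodule]` would wire it up under the Bound API (the entry-point name here is hypothetical):

```rust
use pyo3::prelude::*;

#[pymodule]
fn _internal(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Submodules are created with new_bound and passed by reference,
    // matching init_module's new &Bound<'_, PyModule> parameter.
    let common = PyModule::new_bound(py, "common")?;
    crate::common::init_module(&common)?;
    m.add_submodule(&common)?;
    Ok(())
}
```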
2 changes: 1 addition & 1 deletion src/config.rs
@@ -65,7 +65,7 @@ impl PyConfig

/// Get all configuration options
pub fn get_all(&mut self, py: Python) -> PyResult<PyObject> {
let dict = PyDict::new(py);
let dict = PyDict::new_bound(py);
let options = self.config.to_owned();
for entry in options.entries() {
dict.set_item(entry.key, entry.value.clone().into_py(py))?;
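
The tail of `get_all` is folded out of the hunk above; a self-contained sketch of the pattern it follows, with the options iteration simplified to a plain slice:

```rust
use pyo3::prelude::*;
use pyo3::types::PyDict;

// Was `PyDict::new(py)` under gil-refs; new_bound returns Bound<'_, PyDict>.
fn entries_to_dict(py: Python<'_>, entries: &[(String, Option<String>)]) -> PyResult<PyObject> {
    let dict = PyDict::new_bound(py);
    for (key, value) in entries {
        dict.set_item(key, value.clone().into_py(py))?;
    }
    Ok(dict.into_py(py)) // Bound<'_, PyDict> -> PyObject
}
```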
23 changes: 14 additions & 9 deletions src/context.rs
@@ -291,11 +291,11 @@ impl PySessionContext
pub fn register_object_store(
&mut self,
scheme: &str,
store: &PyAny,
store: &Bound<'_, PyAny>,
host: Option<&str>,
) -> PyResult<()> {
let res: Result<(Arc<dyn ObjectStore>, String), PyErr> =
match StorageContexts::extract(store) {
match StorageContexts::extract_bound(store) {
Ok(store) => match store {
StorageContexts::AmazonS3(s3) => Ok((s3.inner, s3.bucket_name)),
StorageContexts::GoogleCloudStorage(gcs) => Ok((gcs.inner, gcs.bucket_name)),
@@ -443,8 +443,8 @@
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
let table_class = py.import("pyarrow")?.getattr("Table")?;
let args = PyTuple::new(py, &[data]);
let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
let args = PyTuple::new_bound(py, &[data]);
let table = table_class.call_method1("from_pylist", args)?.into();

// Convert Arrow Table to datafusion DataFrame
@@ -463,8 +463,8 @@
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
let table_class = py.import("pyarrow")?.getattr("Table")?;
let args = PyTuple::new(py, &[data]);
let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
let args = PyTuple::new_bound(py, &[data]);
let table = table_class.call_method1("from_pydict", args)?.into();

// Convert Arrow Table to datafusion DataFrame
@@ -507,8 +507,8 @@
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
let table_class = py.import("pyarrow")?.getattr("Table")?;
let args = PyTuple::new(py, &[data]);
let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
let args = PyTuple::new_bound(py, &[data]);
let table = table_class.call_method1("from_pandas", args)?.into();

// Convert Arrow Table to datafusion DataFrame
@@ -710,7 +710,12 @@ impl PySessionContext
}

// Registers a PyArrow.Dataset
pub fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
pub fn register_dataset(
&self,
name: &str,
dataset: &Bound<'_, PyAny>,
py: Python,
) -> PyResult<()> {
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);

self.ctx
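
Two recurring patterns in this file, sketched stand-alone (function names are illustrative): `extract_bound`, the Bound-API replacement for `FromPyObject::extract`, and the `import_bound`/`new_bound` pair behind `from_pylist`, `from_pydict`, and `from_pandas`:

```rust
use pyo3::prelude::*;
use pyo3::types::PyTuple;

// 1) extract_bound takes &Bound<'_, PyAny> where extract took &PyAny;
//    StorageContexts::extract_bound above is the derived-enum version of this.
fn extract_int(obj: &Bound<'_, PyAny>) -> PyResult<i64> {
    i64::extract_bound(obj)
}

// 2) Import a Python class and call a classmethod on it; `data` is whatever
//    pyarrow.Table.from_pylist accepts.
fn pylist_to_table(py: Python<'_>, data: PyObject) -> PyResult<PyObject> {
    let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
    let args = PyTuple::new_bound(py, &[data]);
    Ok(table_class.call_method1("from_pylist", args)?.into())
}
```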
67 changes: 41 additions & 26 deletions src/dataframe.rs
@@ -28,6 +28,7 @@ use datafusion::prelude::*;
use datafusion_common::UnnestOptions;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyTuple;
use tokio::task::JoinHandle;

@@ -56,23 +57,25 @@ impl PyDataFrame

#[pymethods]
impl PyDataFrame {
fn __getitem__(&self, key: PyObject) -> PyResult<Self> {
Python::with_gil(|py| {
if let Ok(key) = key.extract::<&str>(py) {
self.select_columns(vec![key])
} else if let Ok(tuple) = key.extract::<&PyTuple>(py) {
let keys = tuple
.iter()
.map(|item| item.extract::<&str>())
.collect::<PyResult<Vec<&str>>>()?;
self.select_columns(keys)
} else if let Ok(keys) = key.extract::<Vec<&str>>(py) {
self.select_columns(keys)
} else {
let message = "DataFrame can only be indexed by string index or indices";
Err(PyTypeError::new_err(message))
}
})
/// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(key) = key.extract::<PyBackedStr>() {
// df[col]
self.select_columns(vec![key])
} else if let Ok(tuple) = key.extract::<&PyTuple>() {
// df[col1, col2, col3]
let keys = tuple
.iter()
.map(|item| item.extract::<PyBackedStr>())
.collect::<PyResult<Vec<PyBackedStr>>>()?;
self.select_columns(keys)
} else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
// df[[col1, col2, col3]]
self.select_columns(keys)
} else {
let message = "DataFrame can only be indexed by string index or indices";
Err(PyTypeError::new_err(message))
}
}

fn __repr__(&self, py: Python) -> PyResult<String> {
@@ -98,7 +101,8 @@ impl PyDataFrame
}

#[pyo3(signature = (*args))]
fn select_columns(&self, args: Vec<&str>) -> PyResult<Self> {
fn select_columns(&self, args: Vec<PyBackedStr>) -> PyResult<Self> {
let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
let df = self.df.as_ref().clone().select_columns(&args)?;
Ok(Self::new(df))
}
@@ -194,7 +198,7 @@
fn join(
&self,
right: PyDataFrame,
join_keys: (Vec<&str>, Vec<&str>),
join_keys: (Vec<PyBackedStr>, Vec<PyBackedStr>),
how: &str,
) -> PyResult<Self> {
let join_type = match how {
@@ -212,11 +216,22 @@
}
};

let left_keys = join_keys
.0
.iter()
.map(|s| s.as_ref())
.collect::<Vec<&str>>();
let right_keys = join_keys
.1
.iter()
.map(|s| s.as_ref())
.collect::<Vec<&str>>();

let df = self.df.as_ref().clone().join(
right.df.as_ref().clone(),
join_type,
&join_keys.0,
&join_keys.1,
&left_keys,
&right_keys,
None,
)?;
Ok(Self::new(df))
@@ -414,8 +429,8 @@ impl PyDataFrame

Python::with_gil(|py| {
// Instantiate pyarrow Table object and use its from_batches method
let table_class = py.import("pyarrow")?.getattr("Table")?;
let args = PyTuple::new(py, &[batches, schema]);
let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
let args = PyTuple::new_bound(py, &[batches, schema]);
let table: PyObject = table_class.call_method1("from_batches", args)?.into();
Ok(table)
})
@@ -489,8 +504,8 @@ impl PyDataFrame
let table = self.to_arrow_table(py)?;

Python::with_gil(|py| {
let dataframe = py.import("polars")?.getattr("DataFrame")?;
let args = PyTuple::new(py, &[table]);
let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
let args = PyTuple::new_bound(py, &[table]);
let result: PyObject = dataframe.call1(args)?.into();
Ok(result)
})
@@ -514,7 +529,7 @@ fn print_dataframe(py: Python, df: DataFrame) -> PyResult<()> {

// Import the Python 'builtins' module to access the print function
// Note that println! does not print to the Python debug console and is not visible in notebooks for instance
let print = py.import("builtins")?.getattr("print")?;
let print = py.import_bound("builtins")?.getattr("print")?;
print.call1((result,))?;
Ok(())
}
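
`PyBackedStr` (new in PyO3 0.21) is an owned string that keeps the backing Python object alive; it replaces `&str` extraction, which under the Bound API would try to borrow from a temporary. A sketch of the extract-then-reborrow round trip that `__getitem__`, `select_columns`, and `join` use above (helper names are illustrative):

```rust
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;

// Extract owned strings from any Python sequence of str...
fn extract_names(key: &Bound<'_, PyAny>) -> PyResult<Vec<PyBackedStr>> {
    key.extract::<Vec<PyBackedStr>>()
}

// ...then borrow them back as &str (PyBackedStr: AsRef<str>) at the point
// where DataFusion wants string slices.
fn as_strs(names: &[PyBackedStr]) -> Vec<&str> {
    names.iter().map(|s| s.as_ref()).collect()
}
```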
13 changes: 7 additions & 6 deletions src/dataset.rs
@@ -46,13 +46,14 @@ pub(crate) struct Dataset

impl Dataset {
// Creates a Python PyArrow.Dataset
pub fn new(dataset: &PyAny, py: Python) -> PyResult<Self> {
pub fn new(dataset: &Bound<'_, PyAny>, py: Python) -> PyResult<Self> {
// Ensure that we were passed an instance of pyarrow.dataset.Dataset
let ds = PyModule::import(py, "pyarrow.dataset")?;
let ds_type: &PyType = ds.getattr("Dataset")?.downcast()?;
let ds = PyModule::import_bound(py, "pyarrow.dataset")?;
let ds_attr = ds.getattr("Dataset")?;
let ds_type = ds_attr.downcast::<PyType>()?;
if dataset.is_instance(ds_type)? {
Ok(Dataset {
dataset: dataset.into(),
dataset: dataset.clone().unbind(),
})
} else {
Err(PyValueError::new_err(
@@ -73,7 +74,7 @@ impl TableProvider for Dataset
/// Get a reference to the schema for this table
fn schema(&self) -> SchemaRef {
Python::with_gil(|py| {
let dataset = self.dataset.as_ref(py);
let dataset = self.dataset.bind(py);
// This can panic, but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never panic
Arc::new(
dataset
@@ -108,7 +109,7 @@
) -> DFResult<Arc<dyn ExecutionPlan>> {
Python::with_gil(|py| {
let plan: Arc<dyn ExecutionPlan> = Arc::new(
DatasetExec::new(py, self.dataset.as_ref(py), projection.cloned(), filters)
DatasetExec::new(py, self.dataset.bind(py), projection.cloned(), filters)
.map_err(|err| DataFusionError::External(Box::new(err)))?,
);
Ok(plan)
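
`unbind()` and `bind()` are the two halves of the storage pattern `Dataset` now follows: detach from the `'py` lifetime to store the object in a struct, then re-attach under a fresh GIL acquisition to call back into Python. A minimal sketch (the holder type is illustrative):

```rust
use pyo3::prelude::*;

// Mirrors Dataset { dataset: Py<PyAny> }.
struct PyObjectHolder {
    obj: Py<PyAny>,
}

impl PyObjectHolder {
    // clone().unbind(): Bound<'_, PyAny> -> Py<PyAny>, storable without 'py.
    fn new(obj: &Bound<'_, PyAny>) -> Self {
        Self { obj: obj.clone().unbind() }
    }

    // bind(py): &Py<PyAny> -> &Bound<'py, PyAny>, re-attached to the GIL.
    fn attr_len(&self, attr: &str) -> PyResult<usize> {
        Python::with_gil(|py| self.obj.bind(py).getattr(attr)?.len())
    }
}
```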
36 changes: 17 additions & 19 deletions src/dataset_exec.rs
@@ -53,7 +53,7 @@ impl Iterator for PyArrowBatchesAdapter

fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| {
let mut batches: &PyIterator = self.batches.as_ref(py);
let mut batches = self.batches.clone().into_bound(py);
Some(
batches
.next()?
@@ -79,7 +79,7 @@ pub(crate) struct DatasetExec
impl DatasetExec {
pub fn new(
py: Python,
dataset: &PyAny,
dataset: &Bound<'_, PyAny>,
projection: Option<Vec<usize>>,
filters: &[Expr],
) -> Result<Self, DataFusionError> {
@@ -103,15 +103,15 @@ impl DatasetExec
})
.transpose()?;

let kwargs = PyDict::new(py);
let kwargs = PyDict::new_bound(py);

kwargs.set_item("columns", columns.clone())?;
kwargs.set_item(
"filter",
filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
)?;

let scanner = dataset.call_method("scanner", (), Some(kwargs))?;
let scanner = dataset.call_method("scanner", (), Some(&kwargs))?;

let schema = Arc::new(
scanner
@@ -120,19 +120,17 @@
.0,
);

let builtins = Python::import(py, "builtins")?;
let builtins = Python::import_bound(py, "builtins")?;
let pylist = builtins.getattr("list")?;

// Get the fragments or partitions of the dataset
let fragments_iterator: &PyAny = dataset.call_method1(
let fragments_iterator: Bound<'_, PyAny> = dataset.call_method1(
"get_fragments",
(filter_expr.as_ref().map(|expr| expr.clone_ref(py)),),
)?;

let fragments: &PyList = pylist
.call1((fragments_iterator,))?
.downcast()
.map_err(PyErr::from)?;
let fragments_iter = pylist.call1((fragments_iterator,))?;
let fragments = fragments_iter.downcast::<PyList>().map_err(PyErr::from)?;

let projected_statistics = Statistics::new_unknown(&schema);
let plan_properties = datafusion::physical_plan::PlanProperties::new(
@@ -142,9 +140,9 @@
);

Ok(DatasetExec {
dataset: dataset.into(),
dataset: dataset.clone().unbind(),
schema,
fragments: fragments.into(),
fragments: fragments.clone().unbind(),
columns,
filter_expr,
projected_statistics,
@@ -183,8 +181,8 @@
) -> DFResult<SendableRecordBatchStream> {
let batch_size = context.session_config().batch_size();
Python::with_gil(|py| {
let dataset = self.dataset.as_ref(py);
let fragments = self.fragments.as_ref(py);
let dataset = self.dataset.bind(py);
let fragments = self.fragments.bind(py);
let fragment = fragments
.get_item(partition)
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
Expand All @@ -193,7 +191,7 @@ impl ExecutionPlan for DatasetExec {
let dataset_schema = dataset
.getattr("schema")
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
let kwargs = PyDict::new(py);
let kwargs = PyDict::new_bound(py);
kwargs
.set_item("columns", self.columns.clone())
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
@@ -207,15 +205,15 @@
.set_item("batch_size", batch_size)
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
let scanner = fragment
.call_method("scanner", (dataset_schema,), Some(kwargs))
.call_method("scanner", (dataset_schema,), Some(&kwargs))
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
let schema: SchemaRef = Arc::new(
scanner
.getattr("projected_schema")
.and_then(|schema| Ok(schema.extract::<PyArrowType<_>>()?.0))
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?,
);
let record_batches: &PyIterator = scanner
let record_batches: Bound<'_, PyIterator> = scanner
.call_method0("to_batches")
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?
.iter()
@@ -264,7 +262,7 @@ impl ExecutionPlanProperties for DatasetExec
impl DisplayAs for DatasetExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Python::with_gil(|py| {
let number_of_fragments = self.fragments.as_ref(py).len();
let number_of_fragments = self.fragments.bind(py).len();
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
let projected_columns: Vec<String> = self
Expand All @@ -274,7 +272,7 @@ impl DisplayAs for DatasetExec {
.map(|x| x.name().to_owned())
.collect();
if let Some(filter_expr) = &self.filter_expr {
let filter_expr = filter_expr.as_ref(py).str().or(Err(std::fmt::Error))?;
let filter_expr = filter_expr.bind(py).str().or(Err(std::fmt::Error))?;
write!(
f,
"DatasetExec: number_of_fragments={}, filter_expr={}, projection=[{}]",
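
The adapter change at the top of this file leans on `Bound<'py, PyIterator>` implementing `Iterator` (yielding `PyResult<Bound<'py, PyAny>>`). A sketch of that re-bind-and-advance pattern (function name is hypothetical):

```rust
use pyo3::prelude::*;
use pyo3::types::PyIterator;

// A stored Py<PyIterator> carries no GIL lifetime; each call re-binds it
// and advances it by one, mirroring PyArrowBatchesAdapter::next.
fn next_item(stored: &Py<PyIterator>) -> Option<PyResult<PyObject>> {
    Python::with_gil(|py| {
        let mut iter = stored.clone_ref(py).into_bound(py);
        iter.next().map(|item| item.map(Bound::unbind))
    })
}
```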
2 changes: 1 addition & 1 deletion src/expr.rs
@@ -553,7 +553,7 @@ impl PyExpr
}

/// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyExpr>()?;
m.add_class::<PyColumn>()?;
m.add_class::<PyLiteral>()?;
(The remaining 7 changed files are not shown.)
