From 912f789e7f73877d2c262fbe2d4d41e34812294f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 10 Feb 2024 11:09:45 -0700 Subject: [PATCH] [Python] Add `read_csv` and `read_parquet` methods (#976) --- python/Cargo.toml | 9 +- python/README.md | 27 +++-- python/pyballista/__init__.py | 2 - python/pyballista/tests/test_context.py | 27 ++++- python/requirements.txt | 1 + python/src/context.rs | 153 ++++++++++++++++++++++++ python/src/lib.rs | 98 +-------------- python/src/utils.rs | 24 ++++ python/testdata/test.csv | 2 + python/testdata/test.parquet | Bin 0 -> 1851 bytes 10 files changed, 237 insertions(+), 106 deletions(-) create mode 100644 python/src/context.rs create mode 100644 python/src/utils.rs create mode 100644 python/testdata/test.csv create mode 100755 python/testdata/test.parquet diff --git a/python/Cargo.toml b/python/Cargo.toml index 6a63b6f73..833f6c7c2 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -27,11 +27,18 @@ license = "Apache-2.0" edition = "2021" rust-version = "1.64" include = ["/src", "/pyballista", "/LICENSE.txt", "pyproject.toml", "Cargo.toml", "Cargo.lock"] +publish = false [dependencies] +async-trait = "0.1.77" ballista = { path = "../ballista/client", version = "0.12.0" } +ballista-core = { path = "../ballista/core", version = "0.12.0" } datafusion = "35.0.0" -datafusion-python = "35.0.0" +datafusion-proto = "35.0.0" + +# we need to use a recent build of ADP that has a public PyDataFrame +datafusion-python = { git = "https://github.com/apache/arrow-datafusion-python", rev = "5296c0cfcf8e6fcb654d5935252469bf04f929e9" } + pyo3 = { version = "0.20", features = ["extension-module", "abi3", "abi3-py38"] } tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] } diff --git a/python/README.md b/python/README.md index 1819a7201..2898cb165 100644 --- a/python/README.md +++ b/python/README.md @@ -19,26 +19,33 @@ # PyBallista -Minimal Python client for Ballista. - -The goal of this project is to provide a way to run SQL against a Ballista cluster from Python and collect -results as PyArrow record batches. - -Note that this client currently only provides a SQL API and not a DataFrame API. A future release will support -using the DataFrame API from DataFusion's Python bindings to create a logical plan and then execute that logical plan -from the Ballista context ([tracking issue](https://github.com/apache/arrow-ballista/issues/971)). +Python client for Ballista. This project is versioned and released independently from the main Ballista project and is intentionally not part of the default Cargo workspace so that it doesn't cause overhead for maintainers of the main Ballista codebase. -## Example Usage +## Creating a SessionContext + +Creates a new context and connects to a Ballista scheduler process. ```python from pyballista import SessionContext >>> ctx = SessionContext("localhost", 50050) +``` + +## Example SQL Usage + +```python >>> ctx.sql("create external table t stored as parquet location '/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet'") >>> df = ctx.sql("select * from t limit 5") ->>> df.collect() +>>> pyarrow_batches = df.collect() +``` + +## Example DataFrame Usage + +```python +>>> df = ctx.read_parquet('/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet').limit(5) +>>> pyarrow_batches = df.collect() ``` ## Creating Virtual Environment diff --git a/python/pyballista/__init__.py b/python/pyballista/__init__.py index 480e9edab..62a6bc790 100644 --- a/python/pyballista/__init__.py +++ b/python/pyballista/__init__.py @@ -27,12 +27,10 @@ from .pyballista_internal import ( SessionContext, - DataFrame ) __version__ = importlib_metadata.version(__name__) __all__ = [ "SessionContext", - "DataFrame", ] diff --git a/python/pyballista/tests/test_context.py b/python/pyballista/tests/test_context.py index 46e67e10b..9c264e98c 100644 --- a/python/pyballista/tests/test_context.py +++ b/python/pyballista/tests/test_context.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from pyballista import SessionContext import pytest @@ -23,4 +24,28 @@ def test_create_context(): def test_select_one(): ctx = SessionContext("localhost", 50050) df = ctx.sql("SELECT 1") - df.collect() \ No newline at end of file + batches = df.collect() + assert len(batches) == 1 + +def test_read_csv(): + ctx = SessionContext("localhost", 50050) + df = ctx.read_csv("testdata/test.csv", has_header=True) + batches = df.collect() + assert len(batches) == 1 + assert len(batches[0]) == 1 + +def test_read_parquet(): + ctx = SessionContext("localhost", 50050) + df = ctx.read_parquet("testdata/test.parquet") + batches = df.collect() + assert len(batches) == 1 + assert len(batches[0]) == 8 + +def test_read_dataframe_api(): + ctx = SessionContext("localhost", 50050) + df = ctx.read_csv("testdata/test.csv", has_header=True) \ + .select_columns('a', 'b') \ + .limit(1) + batches = df.collect() + assert len(batches) == 1 + assert len(batches[0]) == 1 diff --git a/python/requirements.txt b/python/requirements.txt index f6acb1761..a03a8f8d2 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,2 +1,3 @@ +datafusion==35.0.0 pyarrow pytest \ No newline at end of file diff --git a/python/src/context.rs b/python/src/context.rs new file mode 100644 index 000000000..d9e7feeee --- /dev/null +++ b/python/src/context.rs @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::prelude::*; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use std::path::PathBuf; + +use crate::utils::to_pyerr; +use ballista::prelude::*; +use datafusion::arrow::datatypes::Schema; +use datafusion::arrow::pyarrow::PyArrowType; +use datafusion_python::context::{ + convert_table_partition_cols, parse_file_compression_type, +}; +use datafusion_python::dataframe::PyDataFrame; +use datafusion_python::errors::DataFusionError; +use datafusion_python::expr::PyExpr; +use datafusion_python::utils::wait_for_future; + +/// PyBallista session context. This is largely a duplicate of +/// DataFusion's PySessionContext, with the main difference being +/// that this operates on a BallistaContext instead of DataFusion's +/// SessionContext. We could probably add extra extension points to +/// DataFusion to allow for a pluggable context and remove much of +/// this code. +#[pyclass(name = "SessionContext", module = "pyballista", subclass)] +pub struct PySessionContext { + ctx: BallistaContext, +} + +#[pymethods] +impl PySessionContext { + /// Create a new SessionContext by connecting to a Ballista scheduler process. + #[new] + pub fn new(host: &str, port: u16, py: Python) -> PyResult { + let config = BallistaConfig::new().unwrap(); + let ballista_context = BallistaContext::remote(host, port, &config); + let ctx = wait_for_future(py, ballista_context).map_err(to_pyerr)?; + Ok(Self { ctx }) + } + + pub fn sql(&mut self, query: &str, py: Python) -> PyResult { + let result = self.ctx.sql(query); + let df = wait_for_future(py, result)?; + Ok(PyDataFrame::new(df)) + } + + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = ( + path, + schema=None, + has_header=true, + delimiter=",", + schema_infer_max_records=1000, + file_extension=".csv", + table_partition_cols=vec![], + file_compression_type=None))] + pub fn read_csv( + &self, + path: PathBuf, + schema: Option>, + has_header: bool, + delimiter: &str, + schema_infer_max_records: usize, + file_extension: &str, + table_partition_cols: Vec<(String, String)>, + file_compression_type: Option, + py: Python, + ) -> PyResult { + let path = path + .to_str() + .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; + + let delimiter = delimiter.as_bytes(); + if delimiter.len() != 1 { + return Err(PyValueError::new_err( + "Delimiter must be a single character", + )); + }; + + let mut options = CsvReadOptions::new() + .has_header(has_header) + .delimiter(delimiter[0]) + .schema_infer_max_records(schema_infer_max_records) + .file_extension(file_extension) + .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .file_compression_type(parse_file_compression_type(file_compression_type)?); + + if let Some(py_schema) = schema { + options.schema = Some(&py_schema.0); + let result = self.ctx.read_csv(path, options); + let df = PyDataFrame::new(wait_for_future(py, result)?); + Ok(df) + } else { + let result = self.ctx.read_csv(path, options); + let df = PyDataFrame::new(wait_for_future(py, result)?); + Ok(df) + } + } + + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = ( + path, + table_partition_cols=vec![], + parquet_pruning=true, + file_extension=".parquet", + skip_metadata=true, + schema=None, + file_sort_order=None))] + pub fn read_parquet( + &self, + path: &str, + table_partition_cols: Vec<(String, String)>, + parquet_pruning: bool, + file_extension: &str, + skip_metadata: bool, + schema: Option>, + file_sort_order: Option>>, + py: Python, + ) -> PyResult { + let mut options = ParquetReadOptions::default() + .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .parquet_pruning(parquet_pruning) + .skip_metadata(skip_metadata); + options.file_extension = file_extension; + options.schema = schema.as_ref().map(|x| &x.0); + options.file_sort_order = file_sort_order + .unwrap_or_default() + .into_iter() + .map(|e| e.into_iter().map(|f| f.into()).collect()) + .collect(); + + let result = self.ctx.read_parquet(path, options); + let df = + PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?); + Ok(df) + } +} diff --git a/python/src/lib.rs b/python/src/lib.rs index 186a570e0..04cf232a2 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -15,103 +15,17 @@ // specific language governing permissions and limitations // under the License. -use datafusion::arrow::pyarrow::ToPyArrow; -use datafusion::prelude::DataFrame; -use pyo3::exceptions::PyException; use pyo3::prelude::*; -use std::future::Future; -use std::sync::Arc; -use tokio::runtime::Runtime; +pub mod context; +mod utils; -use ballista::prelude::*; - -/// PyBallista SessionContext -#[pyclass(name = "SessionContext", module = "pyballista", subclass)] -pub struct PySessionContext { - ctx: BallistaContext, -} - -#[pymethods] -impl PySessionContext { - #[new] - pub fn new(host: &str, port: u16, py: Python) -> PyResult { - let config = BallistaConfig::new().unwrap(); - let ballista_context = BallistaContext::remote(host, port, &config); - let ctx = wait_for_future(py, ballista_context).map_err(to_pyerr)?; - Ok(Self { ctx }) - } - - pub fn sql(&mut self, query: &str, py: Python) -> PyResult { - let result = self.ctx.sql(query); - let df = wait_for_future(py, result)?; - Ok(PyDataFrame::new(df)) - } -} - -#[pyclass(name = "DataFrame", module = "pyballista", subclass)] -#[derive(Clone)] -pub struct PyDataFrame { - /// DataFusion DataFrame - df: Arc, -} - -impl PyDataFrame { - /// creates a new PyDataFrame - pub fn new(df: DataFrame) -> Self { - Self { df: Arc::new(df) } - } -} - -#[pymethods] -impl PyDataFrame { - /// Executes the plan, returning a list of `RecordBatch`es. - /// Unless some order is specified in the plan, there is no - /// guarantee of the order of the result. - fn collect(&self, py: Python) -> PyResult> { - let batches = wait_for_future(py, self.df.as_ref().clone().collect())?; - // cannot use PyResult> return type due to - // https://github.com/PyO3/pyo3/issues/1813 - batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect() - } -} - -fn wait_for_future(py: Python, f: F) -> F::Output -where - F: Send, - F::Output: Send, -{ - let runtime: &Runtime = &get_tokio_runtime(py).0; - py.allow_threads(|| runtime.block_on(f)) -} - -fn get_tokio_runtime(py: Python) -> PyRef { - let ballista = py.import("pyballista._internal").unwrap(); - let tmp = ballista.getattr("runtime").unwrap(); - match tmp.extract::>() { - Ok(runtime) => runtime, - Err(_e) => { - let rt = TokioRuntime(tokio::runtime::Runtime::new().unwrap()); - let obj: &PyAny = Py::new(py, rt).unwrap().into_ref(py); - obj.extract().unwrap() - } - } -} - -fn to_pyerr(err: BallistaError) -> PyErr { - PyException::new_err(err.to_string()) -} - -#[pyclass] -pub(crate) struct TokioRuntime(tokio::runtime::Runtime); +pub use crate::context::PySessionContext; #[pymodule] fn pyballista_internal(_py: Python, m: &PyModule) -> PyResult<()> { - // Register the Tokio Runtime as a module attribute so we can reuse it - m.add( - "runtime", - TokioRuntime(tokio::runtime::Runtime::new().unwrap()), - )?; + // Ballista structs m.add_class::()?; - m.add_class::()?; + // DataFusion structs + m.add_class::()?; Ok(()) } diff --git a/python/src/utils.rs b/python/src/utils.rs new file mode 100644 index 000000000..10278537e --- /dev/null +++ b/python/src/utils.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ballista_core::error::BallistaError; +use pyo3::exceptions::PyException; +use pyo3::PyErr; + +pub(crate) fn to_pyerr(err: BallistaError) -> PyErr { + PyException::new_err(err.to_string()) +} diff --git a/python/testdata/test.csv b/python/testdata/test.csv new file mode 100644 index 000000000..00910b0fe --- /dev/null +++ b/python/testdata/test.csv @@ -0,0 +1,2 @@ +a,b +1,2 \ No newline at end of file diff --git a/python/testdata/test.parquet b/python/testdata/test.parquet new file mode 100755 index 0000000000000000000000000000000000000000..a63f5dca7c3821909748f34752966a0d7e08d47f GIT binary patch literal 1851 zcmb7F&ui0g6#pje+AM2l7<*qVv_zM;YQD;X!1G_`Ye@W}TbqmweOL$NPTX=lg!8G{0y<66Rp8 z2pS|Aqlfj;PSH-&mT4zwizU$p1|0Y`VGJoyS_YbwNId;`jBg|z+DN8!b`S=|Sr$Ee51+|8u<)E>*X#arx$Xz24Q}8O`6X`}Xhr%F1Zjm_hG6In z7b&rg?-Cs*15K~?$g4Hmpi6uSk7fKqck31Rd$NO@X;dxW?*`sZ;$!02EAVEj1Dx*0 z{MLuNloZ0uLp~A&5eQYhXi-eh3&y9k4#_aQs_m^t;cL8xuhMu#`94GW^Wlpd7r_2h zbWlRre%G&Crz3oz;A@fee~~T(>&n~(=)0;8>IvyeeZ%&hb^-l_Dsj zFvuM<3gd=3Zp;SqWJI2b$YdaF$o()3pD7?YQ98!mj1HO5|D}r6be0>LFTr05hN163Gw zN`^s(6x}&&X(N%RnM7u%??nSFr{~`PZ?MG}V6i6>#vU;kXJ%l`=Er#5j4|7?$M(UP z-GIH6Am3BD#zq&s>YC+S`G?MW!>iZw=2&6OxPE8h?#;!8`C@+5-thcNe#V-dsZ?y! oaow58so4p;;FgVPyFBeqnNHc9FdWl$-SX^Jc0`|z5`8xR0vs|-V*mgE literal 0 HcmV?d00001