[Python] Add read_csv and read_parquet methods (#976)
andygrove committed Feb 10, 2024
1 parent c481363 commit 912f789
Showing 10 changed files with 237 additions and 106 deletions.
9 changes: 8 additions & 1 deletion python/Cargo.toml
@@ -27,11 +27,18 @@ license = "Apache-2.0"
edition = "2021"
rust-version = "1.64"
include = ["/src", "/pyballista", "/LICENSE.txt", "pyproject.toml", "Cargo.toml", "Cargo.lock"]
publish = false

[dependencies]
async-trait = "0.1.77"
ballista = { path = "../ballista/client", version = "0.12.0" }
ballista-core = { path = "../ballista/core", version = "0.12.0" }
datafusion = "35.0.0"
datafusion-python = "35.0.0"
datafusion-proto = "35.0.0"

# we need to use a recent build of arrow-datafusion-python (ADP) that has a public PyDataFrame
datafusion-python = { git = "https://github.com/apache/arrow-datafusion-python", rev = "5296c0cfcf8e6fcb654d5935252469bf04f929e9" }

pyo3 = { version = "0.20", features = ["extension-module", "abi3", "abi3-py38"] }
tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] }

27 changes: 17 additions & 10 deletions python/README.md
@@ -19,26 +19,33 @@

# PyBallista

Minimal Python client for Ballista.

The goal of this project is to provide a way to run SQL against a Ballista cluster from Python and collect
results as PyArrow record batches.

Note that this client currently only provides a SQL API and not a DataFrame API. A future release will support
using the DataFrame API from DataFusion's Python bindings to create a logical plan and then execute that logical plan
from the Ballista context ([tracking issue](https://github.com/apache/arrow-ballista/issues/971)).
Python client for Ballista.

This project is versioned and released independently from the main Ballista project and is intentionally not
part of the default Cargo workspace so that it doesn't cause overhead for maintainers of the main Ballista codebase.

## Example Usage
## Creating a SessionContext

Creates a new context and connects to a Ballista scheduler process.

```python
>>> from pyballista import SessionContext
>>> ctx = SessionContext("localhost", 50050)
```

## Example SQL Usage

```python
>>> ctx.sql("create external table t stored as parquet location '/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet'")
>>> df = ctx.sql("select * from t limit 5")
>>> df.collect()
>>> pyarrow_batches = df.collect()
```

## Example DataFrame Usage

```python
>>> df = ctx.read_parquet('/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet').limit(5)
>>> pyarrow_batches = df.collect()
```
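
This commit also adds a `read_csv` method; by analogy with the Parquet example above, a minimal sketch (the CSV path is hypothetical):

```python
>>> df = ctx.read_csv('/mnt/bigdata/example.csv', has_header=True).limit(5)
>>> pyarrow_batches = df.collect()
```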

## Creating Virtual Environment
2 changes: 0 additions & 2 deletions python/pyballista/__init__.py
@@ -27,12 +27,10 @@

from .pyballista_internal import (
SessionContext,
DataFrame
)

__version__ = importlib_metadata.version(__name__)

__all__ = [
"SessionContext",
"DataFrame",
]
27 changes: 26 additions & 1 deletion python/pyballista/tests/test_context.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from pyballista import SessionContext
import pytest

@@ -23,4 +23,28 @@ def test_create_context():
def test_select_one():
ctx = SessionContext("localhost", 50050)
df = ctx.sql("SELECT 1")
df.collect()
batches = df.collect()
assert len(batches) == 1

def test_read_csv():
ctx = SessionContext("localhost", 50050)
df = ctx.read_csv("testdata/test.csv", has_header=True)
batches = df.collect()
assert len(batches) == 1
assert len(batches[0]) == 1

def test_read_parquet():
ctx = SessionContext("localhost", 50050)
df = ctx.read_parquet("testdata/test.parquet")
batches = df.collect()
assert len(batches) == 1
assert len(batches[0]) == 8

def test_read_dataframe_api():
ctx = SessionContext("localhost", 50050)
df = ctx.read_csv("testdata/test.csv", has_header=True) \
.select_columns('a', 'b') \
.limit(1)
batches = df.collect()
assert len(batches) == 1
assert len(batches[0]) == 1
1 change: 1 addition & 0 deletions python/requirements.txt
@@ -1,2 +1,3 @@
datafusion==35.0.0
pyarrow
pytest
153 changes: 153 additions & 0 deletions python/src/context.rs
@@ -0,0 +1,153 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion::prelude::*;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::path::PathBuf;

use crate::utils::to_pyerr;
use ballista::prelude::*;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion_python::context::{
convert_table_partition_cols, parse_file_compression_type,
};
use datafusion_python::dataframe::PyDataFrame;
use datafusion_python::errors::DataFusionError;
use datafusion_python::expr::PyExpr;
use datafusion_python::utils::wait_for_future;

/// PyBallista session context. This is largely a duplicate of
/// DataFusion's PySessionContext, with the main difference being
/// that this operates on a BallistaContext instead of DataFusion's
/// SessionContext. We could probably add extra extension points to
/// DataFusion to allow for a pluggable context and remove much of
/// this code.
#[pyclass(name = "SessionContext", module = "pyballista", subclass)]
pub struct PySessionContext {
ctx: BallistaContext,
}

#[pymethods]
impl PySessionContext {
/// Create a new SessionContext by connecting to a Ballista scheduler process.
#[new]
pub fn new(host: &str, port: u16, py: Python) -> PyResult<Self> {
let config = BallistaConfig::new().unwrap();
let ballista_context = BallistaContext::remote(host, port, &config);
let ctx = wait_for_future(py, ballista_context).map_err(to_pyerr)?;
Ok(Self { ctx })
}

pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
let result = self.ctx.sql(query);
let df = wait_for_future(py, result)?;
Ok(PyDataFrame::new(df))
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (
path,
schema=None,
has_header=true,
delimiter=",",
schema_infer_max_records=1000,
file_extension=".csv",
table_partition_cols=vec![],
file_compression_type=None))]
pub fn read_csv(
&self,
path: PathBuf,
schema: Option<PyArrowType<Schema>>,
has_header: bool,
delimiter: &str,
schema_infer_max_records: usize,
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
file_compression_type: Option<String>,
py: Python,
) -> PyResult<PyDataFrame> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

let delimiter = delimiter.as_bytes();
if delimiter.len() != 1 {
return Err(PyValueError::new_err(
"Delimiter must be a single character",
));
};

let mut options = CsvReadOptions::new()
.has_header(has_header)
.delimiter(delimiter[0])
.schema_infer_max_records(schema_infer_max_records)
.file_extension(file_extension)
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.file_compression_type(parse_file_compression_type(file_compression_type)?);

if let Some(py_schema) = schema {
options.schema = Some(&py_schema.0);
let result = self.ctx.read_csv(path, options);
let df = PyDataFrame::new(wait_for_future(py, result)?);
Ok(df)
} else {
let result = self.ctx.read_csv(path, options);
let df = PyDataFrame::new(wait_for_future(py, result)?);
Ok(df)
}
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (
path,
table_partition_cols=vec![],
parquet_pruning=true,
file_extension=".parquet",
skip_metadata=true,
schema=None,
file_sort_order=None))]
pub fn read_parquet(
&self,
path: &str,
table_partition_cols: Vec<(String, String)>,
parquet_pruning: bool,
file_extension: &str,
skip_metadata: bool,
schema: Option<PyArrowType<Schema>>,
file_sort_order: Option<Vec<Vec<PyExpr>>>,
py: Python,
) -> PyResult<PyDataFrame> {
let mut options = ParquetReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.parquet_pruning(parquet_pruning)
.skip_metadata(skip_metadata);
options.file_extension = file_extension;
options.schema = schema.as_ref().map(|x| &x.0);
options.file_sort_order = file_sort_order
.unwrap_or_default()
.into_iter()
.map(|e| e.into_iter().map(|f| f.into()).collect())
.collect();

let result = self.ctx.read_parquet(path, options);
let df =
PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
Ok(df)
}
}
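
For reference, a sketch of how the pyo3 signatures above surface on the Python side (assumes a scheduler running on localhost:50050; the file paths and delimiter are hypothetical):

```python
from pyballista import SessionContext

ctx = SessionContext("localhost", 50050)

# read_csv defaults mirror the pyo3 signature:
# has_header=True, delimiter=",", schema_infer_max_records=1000, file_extension=".csv"
csv_df = ctx.read_csv(
    "data/example.csv",  # hypothetical path
    has_header=True,
    delimiter="|",       # must be a single character, per the validation above
)

# read_parquet enables pruning and metadata skipping by default
parquet_df = ctx.read_parquet("data/example.parquet")  # hypothetical path

batches = csv_df.collect()  # list of PyArrow record batches
```
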
98 changes: 6 additions & 92 deletions python/src/lib.rs
@@ -15,103 +15,17 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::prelude::DataFrame;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use std::future::Future;
use std::sync::Arc;
use tokio::runtime::Runtime;
pub mod context;
mod utils;

use ballista::prelude::*;

/// PyBallista SessionContext
#[pyclass(name = "SessionContext", module = "pyballista", subclass)]
pub struct PySessionContext {
ctx: BallistaContext,
}

#[pymethods]
impl PySessionContext {
#[new]
pub fn new(host: &str, port: u16, py: Python) -> PyResult<Self> {
let config = BallistaConfig::new().unwrap();
let ballista_context = BallistaContext::remote(host, port, &config);
let ctx = wait_for_future(py, ballista_context).map_err(to_pyerr)?;
Ok(Self { ctx })
}

pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
let result = self.ctx.sql(query);
let df = wait_for_future(py, result)?;
Ok(PyDataFrame::new(df))
}
}

#[pyclass(name = "DataFrame", module = "pyballista", subclass)]
#[derive(Clone)]
pub struct PyDataFrame {
/// DataFusion DataFrame
df: Arc<DataFrame>,
}

impl PyDataFrame {
/// creates a new PyDataFrame
pub fn new(df: DataFrame) -> Self {
Self { df: Arc::new(df) }
}
}

#[pymethods]
impl PyDataFrame {
/// Executes the plan, returning a list of `RecordBatch`es.
/// Unless some order is specified in the plan, there is no
/// guarantee of the order of the result.
fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
let batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
// cannot use PyResult<Vec<RecordBatch>> return type due to
// https://github.com/PyO3/pyo3/issues/1813
batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
}
}

fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
where
F: Send,
F::Output: Send,
{
let runtime: &Runtime = &get_tokio_runtime(py).0;
py.allow_threads(|| runtime.block_on(f))
}

fn get_tokio_runtime(py: Python) -> PyRef<TokioRuntime> {
let ballista = py.import("pyballista._internal").unwrap();
let tmp = ballista.getattr("runtime").unwrap();
match tmp.extract::<PyRef<TokioRuntime>>() {
Ok(runtime) => runtime,
Err(_e) => {
let rt = TokioRuntime(tokio::runtime::Runtime::new().unwrap());
let obj: &PyAny = Py::new(py, rt).unwrap().into_ref(py);
obj.extract().unwrap()
}
}
}

fn to_pyerr(err: BallistaError) -> PyErr {
PyException::new_err(err.to_string())
}

#[pyclass]
pub(crate) struct TokioRuntime(tokio::runtime::Runtime);
pub use crate::context::PySessionContext;

#[pymodule]
fn pyballista_internal(_py: Python, m: &PyModule) -> PyResult<()> {
// Register the Tokio Runtime as a module attribute so we can reuse it
m.add(
"runtime",
TokioRuntime(tokio::runtime::Runtime::new().unwrap()),
)?;
// Ballista structs
m.add_class::<PySessionContext>()?;
m.add_class::<PyDataFrame>()?;
// DataFusion structs
m.add_class::<datafusion_python::dataframe::PyDataFrame>()?;
Ok(())
}
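
With this change the module re-exports `PySessionContext` from `context.rs` and registers DataFusion's own `PyDataFrame`, so `collect()` still yields PyArrow record batches. A sketch of consuming them (scheduler address and path are hypothetical):

```python
import pyarrow as pa
from pyballista import SessionContext

ctx = SessionContext("localhost", 50050)
df = ctx.read_parquet("data/example.parquet")  # hypothetical path
table = pa.Table.from_batches(df.collect())    # combine the batches into a single PyArrow table
print(table.num_rows)
```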