Skip to content

Commit

Permalink
feat(python): Add a low-friction sql method for DataFrame and LazyF…
Browse files Browse the repository at this point in the history
…rame (#15783)
  • Loading branch information
alexander-beedie authored Apr 22, 2024
1 parent fe190b3 commit a078d0c
Show file tree
Hide file tree
Showing 15 changed files with 307 additions and 73 deletions.
24 changes: 17 additions & 7 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,24 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<D
| SQLDataType::Uuid
| SQLDataType::Varchar(_) => DataType::String,
SQLDataType::Date => DataType::Date,
SQLDataType::Double | SQLDataType::DoublePrecision => DataType::Float64,
SQLDataType::Float(_) => DataType::Float32,
SQLDataType::Double
| SQLDataType::DoublePrecision
| SQLDataType::Float8
| SQLDataType::Float64 => DataType::Float64,
SQLDataType::Float(n_bytes) => match n_bytes {
Some(n) if (1u64..=24u64).contains(n) => DataType::Float32,
Some(n) if (25u64..=53u64).contains(n) => DataType::Float64,
Some(n) => {
polars_bail!(ComputeError: "unsupported `float` size; expected a value between 1 and 53, found {}", n)
},
None => DataType::Float64,
},
SQLDataType::Float4 | SQLDataType::Real => DataType::Float32,
SQLDataType::Int(_) | SQLDataType::Integer(_) => DataType::Int32,
SQLDataType::Int2(_) => DataType::Int16,
SQLDataType::Int4(_) => DataType::Int32,
SQLDataType::Int8(_) => DataType::Int64,
SQLDataType::Interval => DataType::Duration(TimeUnit::Microseconds),
SQLDataType::Real => DataType::Float32,
SQLDataType::SmallInt(_) => DataType::Int16,
SQLDataType::Time(_, tz) => match tz {
TimezoneInfo::None => DataType::Time,
Expand All @@ -72,11 +82,11 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<D
SQLDataType::Timestamp(prec, tz) => {
let tu = match prec {
None => TimeUnit::Microseconds,
Some(3) => TimeUnit::Milliseconds,
Some(6) => TimeUnit::Microseconds,
Some(9) => TimeUnit::Nanoseconds,
Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds,
Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds,
Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds,
Some(n) => {
polars_bail!(ComputeError: "unsupported `timestamp` precision; expected 3, 6 or 9, found prec={}", n)
polars_bail!(ComputeError: "unsupported `timestamp` precision; expected a value between 1 and 9, found {}", n)
},
};
match tz {
Expand Down
6 changes: 4 additions & 2 deletions crates/polars-sql/tests/simple_exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ fn test_cast_exprs() {
context.register("df", df.clone().lazy());
let sql = r#"
SELECT
cast(a as FLOAT) as floats,
cast(a as FLOAT) as f64,
cast(a as FLOAT(24)) as f32,
cast(a as INT) as ints,
cast(a as BIGINT) as bigints,
cast(a as STRING) as strings,
Expand All @@ -103,7 +104,8 @@ fn test_cast_exprs() {
let df_pl = df
.lazy()
.select(&[
col("a").cast(DataType::Float32).alias("floats"),
col("a").cast(DataType::Float64).alias("f64"),
col("a").cast(DataType::Float32).alias("f32"),
col("a").cast(DataType::Int32).alias("ints"),
col("a").cast(DataType::Int64).alias("bigints"),
col("a").cast(DataType::String).alias("strings"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ Manipulation/selection
DataFrame.shrink_to_fit
DataFrame.slice
DataFrame.sort
DataFrame.sql
DataFrame.tail
DataFrame.take_every
DataFrame.to_dummies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Manipulation/selection
LazyFrame.shift_and_fill
LazyFrame.slice
LazyFrame.sort
LazyFrame.sql
LazyFrame.tail
LazyFrame.take_every
LazyFrame.top_k
Expand Down
7 changes: 6 additions & 1 deletion py-polars/docs/source/reference/sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@ SQL Interface
=============
.. currentmodule:: polars

Polars provides a SQL interface to query frame data; this is available
through the :class:`SQLContext` object, detailed below, and the DataFrame
:meth:`~polars.DataFrame.sql` and LazyFrame :meth:`~polars.LazyFrame.sql`
methods (which make use of SQLContext internally).

.. py:class:: SQLContext
:canonical: polars.sql.SQLContext

Run SQL queries against DataFrame/LazyFrame data.

.. automethod:: __init__

Note: can be used as a context manager.
**Note:** can be used as a context manager.

.. automethod:: __enter__
.. automethod:: __exit__
Expand Down
114 changes: 114 additions & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4122,6 +4122,120 @@ def sort(
.collect(_eager=True)
)

def sql(self, query: str, *, table_name: str | None = None) -> Self:
    """
    Run a SQL query with this DataFrame registered as a table.

    .. warning::
        This functionality is considered **unstable**, although it is close to
        being considered stable. It may be changed at any point without it being
        considered a breaking change.

    Parameters
    ----------
    query
        SQL query to execute.
    table_name
        Optional additional name under which the calling frame is registered;
        the alias "self" is always registered as well.

    Notes
    -----
    * The calling frame is always available to the query as the table "self".
      The SQL context is created with `register_globals=True`, so DataFrames
      and LazyFrames found in the current set of global variables are also
      registered, using their variable name.
    * For finer control over table registration and execution behaviour, use
      the :class:`SQLContext` object directly.
    * The SQL query executes entirely in lazy mode before being collected and
      returned as a DataFrame.

    See Also
    --------
    SQLContext

    Examples
    --------
    >>> from datetime import date
    >>> df1 = pl.DataFrame(
    ...     {
    ...         "a": [1, 2, 3],
    ...         "b": ["zz", "yy", "xx"],
    ...         "c": [date(1999, 12, 31), date(2010, 10, 10), date(2077, 8, 8)],
    ...     }
    ... )

    Query the DataFrame using SQL:

    >>> df1.sql("SELECT c, b FROM self WHERE a > 1")
    shape: (2, 2)
    ┌────────────┬─────┐
    │ c          ┆ b   │
    │ ---        ┆ --- │
    │ date       ┆ str │
    ╞════════════╪═════╡
    │ 2010-10-10 ┆ yy  │
    │ 2077-08-08 ┆ xx  │
    └────────────┴─────┘

    Join two DataFrames using SQL.

    >>> df2 = pl.DataFrame({"a": [3, 2, 1], "d": [125, -654, 888]})
    >>> df1.sql(
    ...     '''
    ...     SELECT self.*, d
    ...     FROM self
    ...     INNER JOIN df2 USING (a)
    ...     WHERE a > 1 AND EXTRACT(year FROM c) < 2050
    ...     '''
    ... )
    shape: (1, 4)
    ┌─────┬─────┬────────────┬──────┐
    │ a   ┆ b   ┆ c          ┆ d    │
    │ --- ┆ --- ┆ ---        ┆ ---  │
    │ i64 ┆ str ┆ date       ┆ i64  │
    ╞═════╪═════╪════════════╪══════╡
    │ 2   ┆ yy  ┆ 2010-10-10 ┆ -654 │
    └─────┴─────┴────────────┴──────┘

    Apply transformations to a DataFrame using SQL, aliasing "self" to "frame".

    >>> df1.sql(
    ...     query='''
    ...         SELECT
    ...             a,
    ...             (a % 2 == 0) AS a_is_even,
    ...             CONCAT_WS(':', b, b) AS b_b,
    ...             EXTRACT(year FROM c) AS year,
    ...             0::float4 AS "zero",
    ...         FROM frame
    ...     ''',
    ...     table_name="frame",
    ... )
    shape: (3, 5)
    ┌─────┬───────────┬───────┬──────┬──────┐
    │ a   ┆ a_is_even ┆ b_b   ┆ year ┆ zero │
    │ --- ┆ ---       ┆ ---   ┆ ---  ┆ ---  │
    │ i64 ┆ bool      ┆ str   ┆ i32  ┆ f32  │
    ╞═════╪═══════════╪═══════╪══════╪══════╡
    │ 1   ┆ false     ┆ zz:zz ┆ 1999 ┆ 0.0  │
    │ 2   ┆ true      ┆ yy:yy ┆ 2010 ┆ 0.0  │
    │ 3   ┆ false     ┆ xx:xx ┆ 2077 ┆ 0.0  │
    └─────┴───────────┴───────┴──────┴──────┘
    """
    from polars.sql import SQLContext

    issue_unstable_warning(
        "`sql` is considered **unstable** (although it is close to being considered stable)."
    )

    # Explicit registrations: the optional caller-chosen name comes first and
    # "self" last, so "self" always resolves to the calling frame.
    named_frames = (
        {"self": self} if not table_name else {table_name: self, "self": self}
    )

    with SQLContext(register_globals=True, eager_execution=True) as sql_ctx:
        sql_ctx.register_many(named_frames)
        return sql_ctx.execute(query)  # type: ignore[return-value]

def top_k(
self,
k: int,
Expand Down
105 changes: 105 additions & 0 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,111 @@ def sort(
)
)

def sql(self, query: str, *, table_name: str | None = None) -> Self:
    """
    Run a SQL query with this LazyFrame registered as a table.

    .. warning::
        This functionality is considered **unstable**, although it is close to
        being considered stable. It may be changed at any point without it being
        considered a breaking change.

    Parameters
    ----------
    query
        SQL query to execute.
    table_name
        Optional additional name under which the calling frame is registered;
        the alias "self" is always registered as well.

    Notes
    -----
    * The calling frame is always available to the query as the table "self".
      The SQL context is created with `register_globals=True`, so DataFrames
      and LazyFrames found in the current set of global variables are also
      registered, using their variable name.
    * For finer control over table registration and execution behaviour, use
      the :class:`SQLContext` object directly.

    See Also
    --------
    SQLContext

    Examples
    --------
    >>> lf1 = pl.LazyFrame({"a": [1, 2, 3], "b": [6, 7, 8], "c": ["z", "y", "x"]})
    >>> lf2 = pl.LazyFrame({"a": [3, 2, 1], "d": [125, -654, 888]})

    Query the LazyFrame using SQL:

    >>> lf1.sql("SELECT c, b FROM self WHERE a > 1").collect()
    shape: (2, 2)
    ┌─────┬─────┐
    │ c   ┆ b   │
    │ --- ┆ --- │
    │ str ┆ i64 │
    ╞═════╪═════╡
    │ y   ┆ 7   │
    │ x   ┆ 8   │
    └─────┴─────┘

    Join two LazyFrames:

    >>> lf1.sql(
    ...     '''
    ...     SELECT self.*, d
    ...     FROM self
    ...     INNER JOIN lf2 USING (a)
    ...     WHERE a > 1 AND b < 8
    ...     '''
    ... ).collect()
    shape: (1, 4)
    ┌─────┬─────┬─────┬──────┐
    │ a   ┆ b   ┆ c   ┆ d    │
    │ --- ┆ --- ┆ --- ┆ ---  │
    │ i64 ┆ i64 ┆ str ┆ i64  │
    ╞═════╪═════╪═════╪══════╡
    │ 2   ┆ 7   ┆ y   ┆ -654 │
    └─────┴─────┴─────┴──────┘

    Apply SQL transforms (aliasing "self" to "frame") and subsequently
    filter natively (you can freely mix SQL and native operations):

    >>> lf1.sql(
    ...     query='''
    ...         SELECT
    ...             a,
    ...             (a % 2 == 0) AS a_is_even,
    ...             (b::float4 / 2) AS "b/2",
    ...             CONCAT_WS(':', c, c, c) AS c_c_c
    ...         FROM frame
    ...         ORDER BY a
    ...     ''',
    ...     table_name="frame",
    ... ).filter(~pl.col("c_c_c").str.starts_with("x")).collect()
    shape: (2, 4)
    ┌─────┬───────────┬─────┬───────┐
    │ a   ┆ a_is_even ┆ b/2 ┆ c_c_c │
    │ --- ┆ ---       ┆ --- ┆ ---   │
    │ i64 ┆ bool      ┆ f32 ┆ str   │
    ╞═════╪═══════════╪═════╪═══════╡
    │ 1   ┆ false     ┆ 3.0 ┆ z:z:z │
    │ 2   ┆ true      ┆ 3.5 ┆ y:y:y │
    └─────┴───────────┴─────┴───────┘
    """
    from polars.sql import SQLContext

    issue_unstable_warning(
        "`sql` is considered **unstable** (although it is close to being considered stable)."
    )

    # Explicit registrations: the optional caller-chosen name comes first and
    # "self" last, so "self" always resolves to the calling frame.
    named_frames = (
        {"self": self} if not table_name else {table_name: self, "self": self}
    )

    # eager_execution=False: the context is asked not to execute eagerly, and
    # whatever `execute` returns is handed back to the caller unchanged.
    with SQLContext(register_globals=True, eager_execution=False) as sql_ctx:
        sql_ctx.register_many(named_frames)
        return sql_ctx.execute(query)  # type: ignore[return-value]

def top_k(
self,
k: int,
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def __init__(
named_frames[name] = obj

if frames or named_frames:
self.register_many(frames, **named_frames)
frames.update(named_frames)
self.register_many(frames)

def __enter__(self) -> SQLContext[FrameType]:
"""Track currently registered tables on scope entry; supports nested scopes."""
Expand Down
22 changes: 10 additions & 12 deletions py-polars/tests/unit/sql/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@


def test_array_to_string() -> None:
df = pl.DataFrame({"values": [["aa", "bb"], [None, "cc"], ["dd", None]]})

with pl.SQLContext(df=df, eager_execution=True) as ctx:
res = ctx.execute(
"""
SELECT
ARRAY_TO_STRING(values, '') AS v1,
ARRAY_TO_STRING(values, ':') AS v2,
ARRAY_TO_STRING(values, ':', 'NA') AS v3
FROM df
"""
)
data = {"values": [["aa", "bb"], [None, "cc"], ["dd", None]]}
res = pl.DataFrame(data).sql(
"""
SELECT
ARRAY_TO_STRING(values, '') AS v1,
ARRAY_TO_STRING(values, ':') AS v2,
ARRAY_TO_STRING(values, ':', 'NA') AS v3
FROM self
"""
)
assert_frame_equal(
res,
pl.DataFrame(
Expand Down
10 changes: 5 additions & 5 deletions py-polars/tests/unit/sql/test_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def test_case_when() -> None:
FROM test_data
"""
)
assert out.to_dict(as_series=False) == {
"v1": [None, 2, None, 4],
"v2": [101, 202, 303, 404],
"v3": ["odd", "even", "odd", "even"],
}
assert out.to_dict(as_series=False) == {
"v1": [None, 2, None, 4],
"v2": [101, 202, 303, 404],
"v3": ["odd", "even", "odd", "even"],
}


def test_control_flow(foods_ipc_path: Path) -> None:
Expand Down
Loading

0 comments on commit a078d0c

Please sign in to comment.