Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Add a low-friction sql method for DataFrame and LazyFrame #15783

Merged
merged 5 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,24 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<D
| SQLDataType::Uuid
| SQLDataType::Varchar(_) => DataType::String,
SQLDataType::Date => DataType::Date,
SQLDataType::Double | SQLDataType::DoublePrecision => DataType::Float64,
SQLDataType::Float(_) => DataType::Float32,
SQLDataType::Double
| SQLDataType::DoublePrecision
| SQLDataType::Float8
| SQLDataType::Float64 => DataType::Float64,
SQLDataType::Float(n_bytes) => match n_bytes {
Some(n) if (1u64..=24u64).contains(n) => DataType::Float32,
Some(n) if (25u64..=53u64).contains(n) => DataType::Float64,
Some(n) => {
polars_bail!(ComputeError: "unsupported `float` size; expected a value between 1 and 53, found {}", n)
},
None => DataType::Float64,
},
SQLDataType::Float4 | SQLDataType::Real => DataType::Float32,
SQLDataType::Int(_) | SQLDataType::Integer(_) => DataType::Int32,
SQLDataType::Int2(_) => DataType::Int16,
SQLDataType::Int4(_) => DataType::Int32,
SQLDataType::Int8(_) => DataType::Int64,
SQLDataType::Interval => DataType::Duration(TimeUnit::Microseconds),
SQLDataType::Real => DataType::Float32,
SQLDataType::SmallInt(_) => DataType::Int16,
SQLDataType::Time(_, tz) => match tz {
TimezoneInfo::None => DataType::Time,
Expand All @@ -72,11 +82,11 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<D
SQLDataType::Timestamp(prec, tz) => {
let tu = match prec {
None => TimeUnit::Microseconds,
Some(3) => TimeUnit::Milliseconds,
Some(6) => TimeUnit::Microseconds,
Some(9) => TimeUnit::Nanoseconds,
Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds,
Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds,
Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds,
Some(n) => {
polars_bail!(ComputeError: "unsupported `timestamp` precision; expected 3, 6 or 9, found prec={}", n)
polars_bail!(ComputeError: "unsupported `timestamp` precision; expected a value between 1 and 9, found {}", n)
},
};
match tz {
Expand Down
6 changes: 4 additions & 2 deletions crates/polars-sql/tests/simple_exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ fn test_cast_exprs() {
context.register("df", df.clone().lazy());
let sql = r#"
SELECT
cast(a as FLOAT) as floats,
cast(a as FLOAT) as f64,
cast(a as FLOAT(24)) as f32,
cast(a as INT) as ints,
cast(a as BIGINT) as bigints,
cast(a as STRING) as strings,
Expand All @@ -103,7 +104,8 @@ fn test_cast_exprs() {
let df_pl = df
.lazy()
.select(&[
col("a").cast(DataType::Float32).alias("floats"),
col("a").cast(DataType::Float64).alias("f64"),
col("a").cast(DataType::Float32).alias("f32"),
col("a").cast(DataType::Int32).alias("ints"),
col("a").cast(DataType::Int64).alias("bigints"),
col("a").cast(DataType::String).alias("strings"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ Manipulation/selection
DataFrame.shrink_to_fit
DataFrame.slice
DataFrame.sort
DataFrame.sql
DataFrame.tail
DataFrame.take_every
DataFrame.to_dummies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Manipulation/selection
LazyFrame.shift_and_fill
LazyFrame.slice
LazyFrame.sort
LazyFrame.sql
LazyFrame.tail
LazyFrame.take_every
LazyFrame.top_k
Expand Down
7 changes: 6 additions & 1 deletion py-polars/docs/source/reference/sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@ SQL Interface
=============
.. currentmodule:: polars

Polars provides a SQL interface to query frame data; this is available
through the :class:`SQLContext` object, detailed below, and the DataFrame
:meth:`~polars.DataFrame.sql` and LazyFrame :meth:`~polars.LazyFrame.sql`
methods (which make use of SQLContext internally).

.. py:class:: SQLContext
:canonical: polars.sql.SQLContext

Run SQL queries against DataFrame/LazyFrame data.

.. automethod:: __init__

Note: can be used as a context manager.
**Note:** can be used as a context manager.

.. automethod:: __enter__
.. automethod:: __exit__
Expand Down
114 changes: 114 additions & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4122,6 +4122,120 @@ def sort(
.collect(_eager=True)
)

def sql(self, query: str, *, table_name: str | None = None) -> Self:
"""
Execute a SQL query against the DataFrame.

.. warning::
This functionality is considered **unstable**, although it is close to
being considered stable. It may be changed at any point without it being
considered a breaking change.

Parameters
----------
query
SQL query to execute.
table_name
Optionally provide an explicit name for the table that represents the
calling frame (the alias "self" will always be registered/available).

Notes
-----
* The calling frame is automatically registered as a table in the SQL context
under the name "self". All DataFrames and LazyFrames found in the current
set of global variables are also registered, using their variable name.
* More control over registration and execution behaviour is available by
using the :class:`SQLContext` object.
* The SQL query executes entirely in lazy mode before being collected and
returned as a DataFrame.

See Also
--------
SQLContext

Examples
--------
>>> from datetime import date
>>> df1 = pl.DataFrame(
... {
... "a": [1, 2, 3],
... "b": ["zz", "yy", "xx"],
... "c": [date(1999, 12, 31), date(2010, 10, 10), date(2077, 8, 8)],
... }
... )

Query the DataFrame using SQL:

>>> df1.sql("SELECT c, b FROM self WHERE a > 1")
shape: (2, 2)
┌────────────┬─────┐
│ c ┆ b │
│ --- ┆ --- │
│ date ┆ str │
╞════════════╪═════╡
│ 2010-10-10 ┆ yy │
│ 2077-08-08 ┆ xx │
└────────────┴─────┘

Join two DataFrames using SQL.

>>> df2 = pl.DataFrame({"a": [3, 2, 1], "d": [125, -654, 888]})
>>> df1.sql(
... '''
... SELECT self.*, d
... FROM self
... INNER JOIN df2 USING (a)
... WHERE a > 1 AND EXTRACT(year FROM c) < 2050
... '''
... )
shape: (1, 4)
┌─────┬─────┬────────────┬──────┐
│ a ┆ b ┆ c ┆ d │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ date ┆ i64 │
╞═════╪═════╪════════════╪══════╡
│ 2 ┆ yy ┆ 2010-10-10 ┆ -654 │
└─────┴─────┴────────────┴──────┘

Apply transformations to a DataFrame using SQL, aliasing "self" to "frame".

>>> df1.sql(
... query='''
... SELECT
... a,
... (a % 2 == 0) AS a_is_even,
... CONCAT_WS(':', b, b) AS b_b,
... EXTRACT(year FROM c) AS year,
... 0::float4 AS "zero",
... FROM frame
... ''',
... table_name="frame",
... )
shape: (3, 5)
┌─────┬───────────┬───────┬──────┬──────┐
│ a ┆ a_is_even ┆ b_b ┆ year ┆ zero │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ bool ┆ str ┆ i32 ┆ f32 │
╞═════╪═══════════╪═══════╪══════╪══════╡
│ 1 ┆ false ┆ zz:zz ┆ 1999 ┆ 0.0 │
│ 2 ┆ true ┆ yy:yy ┆ 2010 ┆ 0.0 │
│ 3 ┆ false ┆ xx:xx ┆ 2077 ┆ 0.0 │
└─────┴───────────┴───────┴──────┴──────┘
"""
from polars.sql import SQLContext

issue_unstable_warning(
"`sql` is considered **unstable** (although it is close to being considered stable)."
)
with SQLContext(
register_globals=True,
eager_execution=True,
) as ctx:
frames = {table_name: self} if table_name else {}
frames["self"] = self
ctx.register_many(frames)
return ctx.execute(query) # type: ignore[return-value]

def top_k(
self,
k: int,
Expand Down
105 changes: 105 additions & 0 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,111 @@ def sort(
)
)

def sql(self, query: str, *, table_name: str | None = None) -> Self:
"""
Execute a SQL query against the LazyFrame.
.. warning::
This functionality is considered **unstable**, although it is close to
being considered stable. It may be changed at any point without it being
considered a breaking change.
Parameters
----------
query
SQL query to execute.
table_name
Optionally provide an explicit name for the table that represents the
calling frame (the alias "self" will always be registered/available).
Notes
-----
* The calling frame is automatically registered as a table in the SQL context
under the name "self". All DataFrames and LazyFrames found in the current
set of global variables are also registered, using their variable name.
* More control over registration and execution behaviour is available by
using the :class:`SQLContext` object.
See Also
--------
SQLContext
Examples
--------
>>> lf1 = pl.LazyFrame({"a": [1, 2, 3], "b": [6, 7, 8], "c": ["z", "y", "x"]})
>>> lf2 = pl.LazyFrame({"a": [3, 2, 1], "d": [125, -654, 888]})
Query the LazyFrame using SQL:
>>> lf1.sql("SELECT c, b FROM self WHERE a > 1").collect()
shape: (2, 2)
┌─────┬─────┐
│ c ┆ b │
│ --- ┆ --- │
│ str ┆ i64 │
╞═════╪═════╡
│ y ┆ 7 │
│ x ┆ 8 │
└─────┴─────┘
Join two LazyFrames:
>>> lf1.sql(
... '''
... SELECT self.*, d
... FROM self
... INNER JOIN lf2 USING (a)
... WHERE a > 1 AND b < 8
... '''
... ).collect()
shape: (1, 4)
┌─────┬─────┬─────┬──────┐
│ a ┆ b ┆ c ┆ d │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ str ┆ i64 │
╞═════╪═════╪═════╪══════╡
│ 2 ┆ 7 ┆ y ┆ -654 │
└─────┴─────┴─────┴──────┘
Apply SQL transforms (aliasing "self" to "frame") and subsequently
filter natively (you can freely mix SQL and native operations):
>>> lf1.sql(
... query='''
... SELECT
... a,
... (a % 2 == 0) AS a_is_even,
... (b::float4 / 2) AS "b/2",
... CONCAT_WS(':', c, c, c) AS c_c_c
... FROM frame
... ORDER BY a
... ''',
... table_name="frame",
... ).filter(~pl.col("c_c_c").str.starts_with("x")).collect()
shape: (2, 4)
┌─────┬───────────┬─────┬───────┐
│ a ┆ a_is_even ┆ b/2 ┆ c_c_c │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ bool ┆ f32 ┆ str │
╞═════╪═══════════╪═════╪═══════╡
│ 1 ┆ false ┆ 3.0 ┆ z:z:z │
│ 2 ┆ true ┆ 3.5 ┆ y:y:y │
└─────┴───────────┴─────┴───────┘
"""
from polars.sql import SQLContext

issue_unstable_warning(
"`sql` is considered **unstable** (although it is close to being considered stable)."
)
with SQLContext(
register_globals=True,
eager_execution=False,
) as ctx:
frames = {table_name: self} if table_name else {}
frames["self"] = self
ctx.register_many(frames)
return ctx.execute(query) # type: ignore[return-value]

def top_k(
self,
k: int,
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def __init__(
named_frames[name] = obj

if frames or named_frames:
self.register_many(frames, **named_frames)
frames.update(named_frames)
self.register_many(frames)

def __enter__(self) -> SQLContext[FrameType]:
"""Track currently registered tables on scope entry; supports nested scopes."""
Expand Down
22 changes: 10 additions & 12 deletions py-polars/tests/unit/sql/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@


def test_array_to_string() -> None:
df = pl.DataFrame({"values": [["aa", "bb"], [None, "cc"], ["dd", None]]})

with pl.SQLContext(df=df, eager_execution=True) as ctx:
res = ctx.execute(
"""
SELECT
ARRAY_TO_STRING(values, '') AS v1,
ARRAY_TO_STRING(values, ':') AS v2,
ARRAY_TO_STRING(values, ':', 'NA') AS v3
FROM df
"""
)
data = {"values": [["aa", "bb"], [None, "cc"], ["dd", None]]}
res = pl.DataFrame(data).sql(
"""
SELECT
ARRAY_TO_STRING(values, '') AS v1,
ARRAY_TO_STRING(values, ':') AS v2,
ARRAY_TO_STRING(values, ':', 'NA') AS v3
FROM self
"""
)
assert_frame_equal(
res,
pl.DataFrame(
Expand Down
10 changes: 5 additions & 5 deletions py-polars/tests/unit/sql/test_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def test_case_when() -> None:
FROM test_data
"""
)
assert out.to_dict(as_series=False) == {
"v1": [None, 2, None, 4],
"v2": [101, 202, 303, 404],
"v3": ["odd", "even", "odd", "even"],
}
assert out.to_dict(as_series=False) == {
"v1": [None, 2, None, 4],
"v2": [101, 202, 303, 404],
"v3": ["odd", "even", "odd", "even"],
}


def test_control_flow(foods_ipc_path: Path) -> None:
Expand Down
Loading
Loading