From a078d0c7d3f02dedde3f9ab58d7a7c4bfd7d885b Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Mon, 22 Apr 2024 10:56:33 +0400 Subject: [PATCH] feat(python): Add a low-friction `sql` method for DataFrame and LazyFrame (#15783) --- crates/polars-sql/src/sql_expr.rs | 24 ++-- crates/polars-sql/tests/simple_exprs.rs | 6 +- .../reference/dataframe/modify_select.rst | 1 + .../reference/lazyframe/modify_select.rst | 1 + py-polars/docs/source/reference/sql.rst | 7 +- py-polars/polars/dataframe/frame.py | 114 ++++++++++++++++++ py-polars/polars/lazyframe/frame.py | 105 ++++++++++++++++ py-polars/polars/sql/context.py | 3 +- py-polars/tests/unit/sql/test_array.py | 22 ++-- py-polars/tests/unit/sql/test_conditional.py | 10 +- py-polars/tests/unit/sql/test_joins.py | 14 +-- .../tests/unit/sql/test_miscellaneous.py | 12 +- py-polars/tests/unit/sql/test_numeric.py | 43 ++++--- py-polars/tests/unit/sql/test_strings.py | 8 +- py-polars/tests/unit/sql/test_temporal.py | 10 +- 15 files changed, 307 insertions(+), 73 deletions(-) diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index e5bc71f7858f..a1568ded9f03 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -54,14 +54,24 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult DataType::String, SQLDataType::Date => DataType::Date, - SQLDataType::Double | SQLDataType::DoublePrecision => DataType::Float64, - SQLDataType::Float(_) => DataType::Float32, + SQLDataType::Double + | SQLDataType::DoublePrecision + | SQLDataType::Float8 + | SQLDataType::Float64 => DataType::Float64, + SQLDataType::Float(n_bytes) => match n_bytes { + Some(n) if (1u64..=24u64).contains(n) => DataType::Float32, + Some(n) if (25u64..=53u64).contains(n) => DataType::Float64, + Some(n) => { + polars_bail!(ComputeError: "unsupported `float` size; expected a value between 1 and 53, found {}", n) + }, + None => DataType::Float64, + }, + SQLDataType::Float4 | SQLDataType::Real => DataType::Float32, SQLDataType::Int(_) | SQLDataType::Integer(_) => DataType::Int32, SQLDataType::Int2(_) => DataType::Int16, SQLDataType::Int4(_) => DataType::Int32, SQLDataType::Int8(_) => DataType::Int64, SQLDataType::Interval => DataType::Duration(TimeUnit::Microseconds), - SQLDataType::Real => DataType::Float32, SQLDataType::SmallInt(_) => DataType::Int16, SQLDataType::Time(_, tz) => match tz { TimezoneInfo::None => DataType::Time, @@ -72,11 +82,11 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult { let tu = match prec { None => TimeUnit::Microseconds, - Some(3) => TimeUnit::Milliseconds, - Some(6) => TimeUnit::Microseconds, - Some(9) => TimeUnit::Nanoseconds, + Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds, + Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds, + Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds, Some(n) => { - polars_bail!(ComputeError: "unsupported `timestamp` precision; expected 3, 6 or 9, found prec={}", n) + polars_bail!(ComputeError: "unsupported `timestamp` precision; expected a value between 1 and 9, found {}", n) }, }; match tz { diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index a2d0a0a7a9db..92a69a03ea0c 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -93,7 +93,8 @@ fn test_cast_exprs() { context.register("df", df.clone().lazy()); let sql = r#" SELECT - cast(a as FLOAT) as floats, + cast(a as FLOAT) as f64, + cast(a as FLOAT(24)) as f32, cast(a as INT) as ints, cast(a as BIGINT) as bigints, cast(a as STRING) as strings, @@ -103,7 +104,8 @@ fn test_cast_exprs() { let df_pl = df .lazy() .select(&[ - col("a").cast(DataType::Float32).alias("floats"), + col("a").cast(DataType::Float64).alias("f64"), + col("a").cast(DataType::Float32).alias("f32"), col("a").cast(DataType::Int32).alias("ints"), col("a").cast(DataType::Int64).alias("bigints"), col("a").cast(DataType::String).alias("strings"), diff --git a/py-polars/docs/source/reference/dataframe/modify_select.rst b/py-polars/docs/source/reference/dataframe/modify_select.rst index 1a82e58027b0..17a1b60cbde4 100644 --- a/py-polars/docs/source/reference/dataframe/modify_select.rst +++ b/py-polars/docs/source/reference/dataframe/modify_select.rst @@ -65,6 +65,7 @@ Manipulation/selection DataFrame.shrink_to_fit DataFrame.slice DataFrame.sort + DataFrame.sql DataFrame.tail DataFrame.take_every DataFrame.to_dummies diff --git a/py-polars/docs/source/reference/lazyframe/modify_select.rst b/py-polars/docs/source/reference/lazyframe/modify_select.rst index c71126c7093a..e2c3f065ad82 100644 --- a/py-polars/docs/source/reference/lazyframe/modify_select.rst +++ b/py-polars/docs/source/reference/lazyframe/modify_select.rst @@ -44,6 +44,7 @@ Manipulation/selection LazyFrame.shift_and_fill LazyFrame.slice LazyFrame.sort + LazyFrame.sql LazyFrame.tail LazyFrame.take_every LazyFrame.top_k diff --git a/py-polars/docs/source/reference/sql.rst b/py-polars/docs/source/reference/sql.rst index bf28f9cc6e20..5010ed54d611 100644 --- a/py-polars/docs/source/reference/sql.rst +++ b/py-polars/docs/source/reference/sql.rst @@ -3,6 +3,11 @@ SQL Interface ============= .. currentmodule:: polars +Polars provides a SQL interface to query frame data; this is available +through the :class:`SQLContext` object, detailed below, and the DataFrame +:meth:`~polars.DataFrame.sql` and LazyFrame :meth:`~polars.LazyFrame.sql` +methods (which make use of SQLContext internally). + .. py:class:: SQLContext :canonical: polars.sql.SQLContext @@ -10,7 +15,7 @@ SQL Interface .. automethod:: __init__ - Note: can be used as a context manager. + **Note:** can be used as a context manager. .. automethod:: __enter__ .. automethod:: __exit__ diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 4f35506ce121..eee0d8b56c73 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4122,6 +4122,120 @@ def sort( .collect(_eager=True) ) + def sql(self, query: str, *, table_name: str | None = None) -> Self: + """ + Execute a SQL query against the DataFrame. + + .. warning:: + This functionality is considered **unstable**, although it is close to + being considered stable. It may be changed at any point without it being + considered a breaking change. + + Parameters + ---------- + query + SQL query to execute. + table_name + Optionally provide an explicit name for the table that represents the + calling frame (the alias "self" will always be registered/available). + + Notes + ----- + * The calling frame is automatically registered as a table in the SQL context + under the name "self". All DataFrames and LazyFrames found in the current + set of global variables are also registered, using their variable name. + * More control over registration and execution behaviour is available by + using the :class:`SQLContext` object. + * The SQL query executes entirely in lazy mode before being collected and + returned as a DataFrame. + + See Also + -------- + SQLContext + + Examples + -------- + >>> from datetime import date + >>> df1 = pl.DataFrame( + ... { + ... "a": [1, 2, 3], + ... "b": ["zz", "yy", "xx"], + ... "c": [date(1999, 12, 31), date(2010, 10, 10), date(2077, 8, 8)], + ... } + ... ) + + Query the DataFrame using SQL: + + >>> df1.sql("SELECT c, b FROM self WHERE a > 1") + shape: (2, 2) + ┌────────────┬─────┐ + │ c ┆ b │ + │ --- ┆ --- │ + │ date ┆ str │ + ╞════════════╪═════╡ + │ 2010-10-10 ┆ yy │ + │ 2077-08-08 ┆ xx │ + └────────────┴─────┘ + + Join two DataFrames using SQL. + + >>> df2 = pl.DataFrame({"a": [3, 2, 1], "d": [125, -654, 888]}) + >>> df1.sql( + ... ''' + ... SELECT self.*, d + ... FROM self + ... INNER JOIN df2 USING (a) + ... WHERE a > 1 AND EXTRACT(year FROM c) < 2050 + ... ''' + ... ) + shape: (1, 4) + ┌─────┬─────┬────────────┬──────┐ + │ a ┆ b ┆ c ┆ d │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ date ┆ i64 │ + ╞═════╪═════╪════════════╪══════╡ + │ 2 ┆ yy ┆ 2010-10-10 ┆ -654 │ + └─────┴─────┴────────────┴──────┘ + + Apply transformations to a DataFrame using SQL, aliasing "self" to "frame". + + >>> df1.sql( + ... query=''' + ... SELECT + ... a, + ... (a % 2 == 0) AS a_is_even, + ... CONCAT_WS(':', b, b) AS b_b, + ... EXTRACT(year FROM c) AS year, + ... 0::float4 AS "zero", + ... FROM frame + ... ''', + ... table_name="frame", + ... ) + shape: (3, 5) + ┌─────┬───────────┬───────┬──────┬──────┐ + │ a ┆ a_is_even ┆ b_b ┆ year ┆ zero │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ bool ┆ str ┆ i32 ┆ f32 │ + ╞═════╪═══════════╪═══════╪══════╪══════╡ + │ 1 ┆ false ┆ zz:zz ┆ 1999 ┆ 0.0 │ + │ 2 ┆ true ┆ yy:yy ┆ 2010 ┆ 0.0 │ + │ 3 ┆ false ┆ xx:xx ┆ 2077 ┆ 0.0 │ + └─────┴───────────┴───────┴──────┴──────┘ + """ + from polars.sql import SQLContext + + issue_unstable_warning( + "`sql` is considered **unstable** (although it is close to being considered stable)." + ) + with SQLContext( + register_globals=True, + eager_execution=True, + ) as ctx: + frames = {table_name: self} if table_name else {} + frames["self"] = self + ctx.register_many(frames) + return ctx.execute(query) # type: ignore[return-value] + def top_k( self, k: int, diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 39f5ce12dc16..90bb5cd94fe8 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -1211,6 +1211,111 @@ def sort( ) ) + def sql(self, query: str, *, table_name: str | None = None) -> Self: + """ + Execute a SQL query against the LazyFrame. + + .. warning:: + This functionality is considered **unstable**, although it is close to + being considered stable. It may be changed at any point without it being + considered a breaking change. + + Parameters + ---------- + query + SQL query to execute. + table_name + Optionally provide an explicit name for the table that represents the + calling frame (the alias "self" will always be registered/available). + + Notes + ----- + * The calling frame is automatically registered as a table in the SQL context + under the name "self". All DataFrames and LazyFrames found in the current + set of global variables are also registered, using their variable name. + * More control over registration and execution behaviour is available by + using the :class:`SQLContext` object. + + See Also + -------- + SQLContext + + Examples + -------- + >>> lf1 = pl.LazyFrame({"a": [1, 2, 3], "b": [6, 7, 8], "c": ["z", "y", "x"]}) + >>> lf2 = pl.LazyFrame({"a": [3, 2, 1], "d": [125, -654, 888]}) + + Query the LazyFrame using SQL: + + >>> lf1.sql("SELECT c, b FROM self WHERE a > 1").collect() + shape: (2, 2) + ┌─────┬─────┐ + │ c ┆ b │ + │ --- ┆ --- │ + │ str ┆ i64 │ + ╞═════╪═════╡ + │ y ┆ 7 │ + │ x ┆ 8 │ + └─────┴─────┘ + + Join two LazyFrames: + + >>> lf1.sql( + ... ''' + ... SELECT self.*, d + ... FROM self + ... INNER JOIN lf2 USING (a) + ... WHERE a > 1 AND b < 8 + ... ''' + ... ).collect() + shape: (1, 4) + ┌─────┬─────┬─────┬──────┐ + │ a ┆ b ┆ c ┆ d │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ str ┆ i64 │ + ╞═════╪═════╪═════╪══════╡ + │ 2 ┆ 7 ┆ y ┆ -654 │ + └─────┴─────┴─────┴──────┘ + + Apply SQL transforms (aliasing "self" to "frame") and subsequently + filter natively (you can freely mix SQL and native operations): + + >>> lf1.sql( + ... query=''' + ... SELECT + ... a, + ... (a % 2 == 0) AS a_is_even, + ... (b::float4 / 2) AS "b/2", + ... CONCAT_WS(':', c, c, c) AS c_c_c + ... FROM frame + ... ORDER BY a + ... ''', + ... table_name="frame", + ... ).filter(~pl.col("c_c_c").str.starts_with("x")).collect() + shape: (2, 4) + ┌─────┬───────────┬─────┬───────┐ + │ a ┆ a_is_even ┆ b/2 ┆ c_c_c │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ bool ┆ f32 ┆ str │ + ╞═════╪═══════════╪═════╪═══════╡ + │ 1 ┆ false ┆ 3.0 ┆ z:z:z │ + │ 2 ┆ true ┆ 3.5 ┆ y:y:y │ + └─────┴───────────┴─────┴───────┘ + """ + from polars.sql import SQLContext + + issue_unstable_warning( + "`sql` is considered **unstable** (although it is close to being considered stable)." + ) + with SQLContext( + register_globals=True, + eager_execution=False, + ) as ctx: + frames = {table_name: self} if table_name else {} + frames["self"] = self + ctx.register_many(frames) + return ctx.execute(query) # type: ignore[return-value] + def top_k( self, k: int, diff --git a/py-polars/polars/sql/context.py b/py-polars/polars/sql/context.py index db02e51a1dd8..70a8a78e7e1a 100644 --- a/py-polars/polars/sql/context.py +++ b/py-polars/polars/sql/context.py @@ -125,7 +125,8 @@ def __init__( named_frames[name] = obj if frames or named_frames: - self.register_many(frames, **named_frames) + frames.update(named_frames) + self.register_many(frames) def __enter__(self) -> SQLContext[FrameType]: """Track currently registered tables on scope entry; supports nested scopes.""" diff --git a/py-polars/tests/unit/sql/test_array.py b/py-polars/tests/unit/sql/test_array.py index a62cd6ffd984..7bb4bc53c302 100644 --- a/py-polars/tests/unit/sql/test_array.py +++ b/py-polars/tests/unit/sql/test_array.py @@ -5,18 +5,16 @@ def test_array_to_string() -> None: - df = pl.DataFrame({"values": [["aa", "bb"], [None, "cc"], ["dd", None]]}) - - with pl.SQLContext(df=df, eager_execution=True) as ctx: - res = ctx.execute( - """ - SELECT - ARRAY_TO_STRING(values, '') AS v1, - ARRAY_TO_STRING(values, ':') AS v2, - ARRAY_TO_STRING(values, ':', 'NA') AS v3 - FROM df - """ - ) + data = {"values": [["aa", "bb"], [None, "cc"], ["dd", None]]} + res = pl.DataFrame(data).sql( + """ + SELECT + ARRAY_TO_STRING(values, '') AS v1, + ARRAY_TO_STRING(values, ':') AS v2, + ARRAY_TO_STRING(values, ':', 'NA') AS v3 + FROM self + """ + ) assert_frame_equal( res, pl.DataFrame( diff --git a/py-polars/tests/unit/sql/test_conditional.py b/py-polars/tests/unit/sql/test_conditional.py index 4d14765fcbb6..1a475581a6b6 100644 --- a/py-polars/tests/unit/sql/test_conditional.py +++ b/py-polars/tests/unit/sql/test_conditional.py @@ -29,11 +29,11 @@ def test_case_when() -> None: FROM test_data """ ) - assert out.to_dict(as_series=False) == { - "v1": [None, 2, None, 4], - "v2": [101, 202, 303, 404], - "v3": ["odd", "even", "odd", "even"], - } + assert out.to_dict(as_series=False) == { + "v1": [None, 2, None, 4], + "v2": [101, 202, 303, 404], + "v3": ["odd", "even", "odd", "even"], + } def test_control_flow(foods_ipc_path: Path) -> None: diff --git a/py-polars/tests/unit/sql/test_joins.py b/py-polars/tests/unit/sql/test_joins.py index 10534076ec72..a498cbb6629e 100644 --- a/py-polars/tests/unit/sql/test_joins.py +++ b/py-polars/tests/unit/sql/test_joins.py @@ -70,20 +70,19 @@ def test_join_anti_semi(sql: str, expected: pl.DataFrame) -> None: ], ) def test_join_inner(foods_ipc_path: Path, join_clause: str) -> None: - lf = pl.scan_ipc(foods_ipc_path) + foods1 = pl.scan_ipc(foods_ipc_path) + foods2 = foods1 # noqa: F841 - ctx = pl.SQLContext() - ctx.register_many(foods1=lf, foods2=lf) - - out = ctx.execute( + out = foods1.sql( f""" SELECT * FROM foods1 INNER JOIN foods2 {join_clause} LIMIT 2 """ - ) - assert out.collect().to_dict(as_series=False) == { + ).collect() + + assert out.to_dict(as_series=False) == { "category": ["vegetables", "vegetables"], "calories": [45, 20], "fats_g": [0.5, 0.0], @@ -171,6 +170,7 @@ def test_join_left_multi_nested() -> None: ORDER BY tbl_x.a ASC """ ).collect() + assert out.rows() == [ (1, 4, "z", 25.5), (2, None, None, None), diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py index ba86a0b434ca..d32674db78ff 100644 --- a/py-polars/tests/unit/sql/test_miscellaneous.py +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -21,7 +21,7 @@ def test_any_all() -> None: "y": [1, 0, 0, 1, 2, 3], } ) - res = pl.SQLContext(df=df).execute( + res = df.sql( """ SELECT x >= ALL(df.y) as 'All Geq', @@ -36,9 +36,7 @@ def test_any_all() -> None: x != ANY(df.y) as 'Any Neq', FROM df """, - eager=True, ) - assert res.to_dict(as_series=False) == { "All Geq": [0, 0, 0, 0, 1, 1], "All G": [0, 0, 0, 0, 0, 1], @@ -88,16 +86,16 @@ def test_distinct() -> None: def test_in_no_ops_11946() -> None: - df = pl.LazyFrame( + lf = pl.LazyFrame( [ {"i1": 1}, {"i1": 2}, {"i1": 3}, ] ) - ctx = pl.SQLContext(frame_data=df, eager_execution=False) - out = ctx.execute( - "SELECT * FROM frame_data WHERE i1 in (1, 3)", eager=False + out = lf.sql( + query="SELECT * FROM frame_data WHERE i1 in (1, 3)", + table_name="frame_data", ).collect() assert out.to_dict(as_series=False) == {"i1": [1, 3]} diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py index dc72bc30d0af..ca54df80be98 100644 --- a/py-polars/tests/unit/sql/test_numeric.py +++ b/py-polars/tests/unit/sql/test_numeric.py @@ -22,29 +22,28 @@ def test_modulo() -> None: "d": [16.5, 17.0, 18.5, None, 20.0], } ) - with pl.SQLContext(df=df) as ctx: - out = ctx.execute( - """ - SELECT - a % 2 AS a2, - b % 3 AS b3, - MOD(c, 4) AS c4, - MOD(d, 5.5) AS d55 - FROM df - """ - ).collect() + out = df.sql( + """ + SELECT + a % 2 AS a2, + b % 3 AS b3, + MOD(c, 4) AS c4, + MOD(d, 5.5) AS d55 + FROM df + """ + ) - assert_frame_equal( - out, - pl.DataFrame( - { - "a2": [1.5, None, 1.0, 1 / 3, 1.0], - "b3": [0, 1, 2, 0, 1], - "c4": [3, 0, 1, 2, 3], - "d55": [0.0, 0.5, 2.0, None, 3.5], - } - ), - ) + assert_frame_equal( + out, + pl.DataFrame( + { + "a2": [1.5, None, 1.0, 1 / 3, 1.0], + "b3": [0, 1, 2, 0, 1], + "c4": [3, 0, 1, 2, 3], + "d55": [0.0, 0.5, 2.0, None, 3.5], + } + ), + ) @pytest.mark.parametrize( diff --git a/py-polars/tests/unit/sql/test_strings.py b/py-polars/tests/unit/sql/test_strings.py index ee7685847750..019b102a4d05 100644 --- a/py-polars/tests/unit/sql/test_strings.py +++ b/py-polars/tests/unit/sql/test_strings.py @@ -46,7 +46,7 @@ def test_string_concat() -> None: "z": [1, 2, 3], } ) - res = pl.SQLContext(data=lf).execute( + res = lf.sql( """ SELECT ("x" || "x" || "y") AS c0, @@ -56,10 +56,10 @@ def test_string_concat() -> None: CONCAT("x", "y", ("z" * 2)) AS c4, CONCAT_WS(':', "x", "y", "z") AS c5, CONCAT_WS('', "y", "z", '!') AS c6 - FROM data + FROM self """, - eager=True, - ) + ).collect() + assert res.to_dict(as_series=False) == { "c0": ["aad", None, "ccf"], "c1": ["ad1", None, "cf3"], diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py index 9659c720ce84..4babd435374f 100644 --- a/py-polars/tests/unit/sql/test_temporal.py +++ b/py-polars/tests/unit/sql/test_temporal.py @@ -45,10 +45,9 @@ def test_datetime_to_time(time_unit: Literal["ns", "us", "ms"]) -> None: }, schema={"dtm": pl.Datetime(time_unit)}, ) - with pl.SQLContext(df=df, eager_execution=True) as ctx: - result = ctx.execute("SELECT dtm::time as tm from df")["tm"].to_list() - assert result == [ + res = df.sql("SELECT dtm::time as tm from df")["tm"].to_list() + assert res == [ time(23, 59, 59), time(12, 30, 30), time(1, 1, 1), @@ -180,8 +179,9 @@ def test_timestamp_time_unit_errors() -> None: df = pl.DataFrame({"ts": [datetime(2024, 1, 7, 1, 2, 3, 123456)]}) with pl.SQLContext(frame_data=df, eager_execution=True) as ctx: - for prec in (0, 4, 15): + for prec in (0, 15): with pytest.raises( - ComputeError, match=f"unsupported `timestamp` precision; .* prec={prec}" + ComputeError, + match=f"unsupported `timestamp` precision; expected a value between 1 and 9, found {prec}", ): ctx.execute(f"SELECT ts::timestamp({prec}) FROM frame_data")