Python wrapper classes for all user interfaces #750

Merged · 55 commits · Jul 18, 2024
The file changes shown below are from 13 of the 55 commits.

Commits
d00c00a
Expose missing functions to python
timsaucer Jul 9, 2024
27e4f30
Initial commit for creating wrapper classes and functions for all use…
timsaucer Jul 9, 2024
a3429ab
Remove extra level of python path that is no longer required
timsaucer Jul 9, 2024
7937963
Move import to only happen for type checking for hints
timsaucer Jul 9, 2024
1f4c829
Comment out classes from __all__ in the top level that are not curren…
timsaucer Jul 9, 2024
d7f5f68
Add license comments
timsaucer Jul 9, 2024
79bb196
Add missing import
timsaucer Jul 9, 2024
685a257
Functions now only has one level of depth
timsaucer Jul 9, 2024
45ee5ab
Applying google docstring formatting
timsaucer Jul 9, 2024
b8239e7
Addressing PR request to add google formatted docstrings
timsaucer Jul 10, 2024
4c8073e
Small docstring for ruff
timsaucer Jul 10, 2024
411c91c
Linting
timsaucer Jul 10, 2024
610adda
Add docstring format checking to pre-commit stage
timsaucer Jul 10, 2024
265aeb7
Set explicit return types on UDFs
timsaucer Jul 11, 2024
02564de
Add options of passing either a path or a string
timsaucer Jul 12, 2024
e0e55a8
Switch to google docstring style
timsaucer Jul 12, 2024
dcd5211
Update unit tests to include registering via path or string
timsaucer Jul 12, 2024
1063cff
Add py.typed file
timsaucer Jul 12, 2024
5ba2017
Resolve deprecation warnings in unit tests
timsaucer Jul 13, 2024
438afa0
Add path to unit test
timsaucer Jul 13, 2024
837e3b2
Expose an option in write_csv to include header and add unit test
timsaucer Jul 13, 2024
6e75eee
Update write_parquet unit test to include paths or strings
timsaucer Jul 13, 2024
2ebe2e5
Add unit test for write_json
timsaucer Jul 13, 2024
dad0d26
Add unit test for substrait serialization to a file
timsaucer Jul 13, 2024
ae569ff
Add unit tests for runtime config
timsaucer Jul 13, 2024
4f973af
Setting return type to typing_extensions.Self per PR recommendation
timsaucer Jul 13, 2024
f2ed822
Correcting __next__ to not return None since it will raise an excepti…
timsaucer Jul 13, 2024
c2ee65d
Add optional parameter of decimal places to round and add unit test
timsaucer Jul 13, 2024
835e374
Improve docstrings
timsaucer Jul 13, 2024
08b83ac
Set default to None instead of empty dict
timsaucer Jul 13, 2024
2ccd5ad
User request to allow passing multiple arguments to filter()
timsaucer Jul 13, 2024
13be857
Enhance Expr comparison operators to accept any python value and atte…
timsaucer Jul 13, 2024
8f1bb65
Expose overlay and add unit test
timsaucer Jul 13, 2024
75e129a
Allow select() to take either str for column names or a full expr
timsaucer Jul 13, 2024
f2b15e0
Update comments on regexp and add unit tests
timsaucer Jul 13, 2024
b76d105
Remove TODO markings no longer applicable
timsaucer Jul 13, 2024
6e87d73
Update udf documentation
timsaucer Jul 14, 2024
39f18cb
Docstring formatting
timsaucer Jul 14, 2024
94650b5
Updating docstring formatting
timsaucer Jul 14, 2024
95a4688
Updating docstring formatting
timsaucer Jul 14, 2024
39d9c00
Updating docstring formatting
timsaucer Jul 14, 2024
671d508
Updating docstring formatting
timsaucer Jul 15, 2024
49efdd0
Updating docstring formatting
timsaucer Jul 15, 2024
3c7a811
Cleaning up docstring line lengths
timsaucer Jul 15, 2024
fbf3f46
Add pre-commit check of docstring line length
timsaucer Jul 15, 2024
d6c6598
Do not emit doc entry for __init__ of some classes
timsaucer Jul 16, 2024
cccf305
Correct errors on code blocks generating in sphinx
timsaucer Jul 16, 2024
6579ac5
Resolve conflict with
timsaucer Jul 16, 2024
62197bc
Add license info to py.typed
timsaucer Jul 16, 2024
2821183
Clean up some docstring too long errors in CI
timsaucer Jul 16, 2024
c1df7db
Correct ruff complaint in unit tests
timsaucer Jul 16, 2024
461e7b5
Temporarily install google test to get clippy to pass
timsaucer Jul 16, 2024
4af541e
Adding gmock to build step due to upstream error
timsaucer Jul 16, 2024
5588f28
Add type_extensions to conda meta file
timsaucer Jul 16, 2024
39f01fb
Small comment suggestions from PR
timsaucer Jul 17, 2024
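Several of the commits above change user-facing behavior: file registration accepts either a string or a pathlib.Path, filter() accepts multiple predicates, select() accepts plain column-name strings alongside full expressions, and Expr comparison operators accept arbitrary Python values. A minimal sketch of how those changes read in user code (the file path and column names here are hypothetical):

```python
from pathlib import Path

from datafusion import SessionContext, col, lit

ctx = SessionContext()

# Read/registration functions now accept a str or a pathlib.Path.
df = ctx.read_csv(Path("data/example.csv"))

# select() takes plain column-name strings as well as full expressions.
df = df.select("a", col("b") + lit(1))

# filter() takes multiple predicates, and comparison operators accept
# plain Python values, which are coerced to literal expressions.
df = df.filter(col("a") > 4, col("b") != "missing")
```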
Files changed
2 changes: 1 addition & 1 deletion docs/source/api/functions.rst
@@ -24,4 +24,4 @@ Functions
.. autosummary::
:toctree: ../generated/

- functions.functions
+ functions
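This change mirrors the "Functions now only has one level of depth" commit: helpers that previously lived under datafusion.functions.functions are exposed directly on datafusion.functions. A small sketch of the flattened import, using F.trim as it appears in the test diff below:

```python
from datafusion import col, functions as F

# Previously reached as functions.functions.trim; now one level up.
expr = F.trim(col("c_phone"))
```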
2 changes: 2 additions & 0 deletions docs/source/conf.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

"""Documenation generation."""

# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
15 changes: 6 additions & 9 deletions examples/substrait.py
@@ -18,16 +18,15 @@
from datafusion import SessionContext
from datafusion import substrait as ss

- # TODO add user changing interface note to PR that datafusion.substrait.substrait is simplified to datafusion.substrait

# Create a DataFusion context
ctx = SessionContext()

# Register table with context
ctx.register_csv("aggregate_test_data", "./testing/data/csv/aggregate_test_100.csv")

- substrait_plan = ss.substrait.serde.serialize_to_plan(
-     "SELECT * FROM aggregate_test_data", ctx
- )
+ substrait_plan = ss.serde.serialize_to_plan("SELECT * FROM aggregate_test_data", ctx)
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>

# Encode it to bytes
@@ -38,17 +37,15 @@
# Alternative serialization approaches
# type(substrait_bytes) -> <class 'bytes'>, at this point the bytes can be distributed to file, network, etc safely
# where they could subsequently be deserialized on the receiving end.
- substrait_bytes = ss.substrait.serde.serialize_bytes(
-     "SELECT * FROM aggregate_test_data", ctx
- )
+ substrait_bytes = ss.serde.serialize_bytes("SELECT * FROM aggregate_test_data", ctx)

# Imagine here bytes would be read from network, file, etc ... for example brevity this is omitted and variable is simply reused
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>
- substrait_plan = ss.substrait.serde.deserialize_bytes(substrait_bytes)
+ substrait_plan = ss.serde.deserialize_bytes(substrait_bytes)

# type(df_logical_plan) -> <class 'substrait.LogicalPlan'>
- df_logical_plan = ss.substrait.consumer.from_substrait_plan(ctx, substrait_plan)
+ df_logical_plan = ss.consumer.from_substrait_plan(ctx, substrait_plan)

# Back to Substrait Plan just for demonstration purposes
# type(substrait_plan) -> <class 'datafusion.substrait.plan'>
- substrait_plan = ss.substrait.producer.to_substrait_plan(df_logical_plan)
+ substrait_plan = ss.producer.to_substrait_plan(df_logical_plan)
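With the extra module level removed, the whole example reduces to the following round trip (a condensed sketch of the updated example, reusing the same test CSV):

```python
from datafusion import SessionContext
from datafusion import substrait as ss

ctx = SessionContext()
ctx.register_csv("aggregate_test_data", "./testing/data/csv/aggregate_test_100.csv")

# Serialize a SQL query to bytes, then round-trip back to a Substrait plan.
substrait_bytes = ss.serde.serialize_bytes("SELECT * FROM aggregate_test_data", ctx)
substrait_plan = ss.serde.deserialize_bytes(substrait_bytes)

# Convert to a DataFusion logical plan, and back again for demonstration.
df_logical_plan = ss.consumer.from_substrait_plan(ctx, substrait_plan)
substrait_plan = ss.producer.to_substrait_plan(df_logical_plan)
```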
30 changes: 22 additions & 8 deletions examples/tpch/_tests.py
@@ -21,6 +21,7 @@
from datafusion import col, lit, functions as F
from util import get_answer_file


def df_selection(col_name, col_type):
if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
return F.round(col(col_name), lit(2)).alias(col_name)
@@ -29,14 +30,16 @@ def df_selection(col_name, col_type):
else:
return col(col_name)


def load_schema(col_name, col_type):
if col_type == pa.int64() or col_type == pa.int32():
return col_name, pa.string()
elif isinstance(col_type, pa.Decimal128Type):
return col_name, pa.float64()
else:
return col_name, col_type



def expected_selection(col_name, col_type):
if col_type == pa.int64() or col_type == pa.int32():
return F.trim(col(col_name)).cast(col_type).alias(col_name)
@@ -45,20 +48,23 @@ def expected_selection(col_name, col_type):
else:
return col(col_name)


def selections_and_schema(original_schema):
- columns = [ (c, original_schema.field(c).type) for c in original_schema.names ]
+ columns = [(c, original_schema.field(c).type) for c in original_schema.names]

- df_selections = [ df_selection(c, t) for (c, t) in columns]
- expected_schema = [ load_schema(c, t) for (c, t) in columns]
- expected_selections = [ expected_selection(c, t) for (c, t) in columns]
+ df_selections = [df_selection(c, t) for (c, t) in columns]
+ expected_schema = [load_schema(c, t) for (c, t) in columns]
+ expected_selections = [expected_selection(c, t) for (c, t) in columns]

return (df_selections, expected_schema, expected_selections)


def check_q17(df):
raw_value = float(df.collect()[0]["avg_yearly"][0].as_py())
value = round(raw_value, 2)
assert abs(value - 348406.05) < 0.001


@pytest.mark.parametrize(
("query_code", "answer_file"),
[
@@ -73,7 +79,8 @@ def check_q17(df):
("q09_product_type_profit_measure", "q9"),
("q10_returned_item_reporting", "q10"),
pytest.param(
"q11_important_stock_identification", "q11",
"q11_important_stock_identification",
"q11",
),
("q12_ship_mode_order_priority", "q12"),
("q13_customer_distribution", "q13"),
@@ -97,13 +104,20 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
if answer_file == "q17":
return check_q17(df)

- (df_selections, expected_schema, expected_selections) = selections_and_schema(df.schema())
+ (df_selections, expected_schema, expected_selections) = selections_and_schema(
+     df.schema()
+ )

df = df.select(*df_selections)

read_schema = pa.schema(expected_schema)

- df_expected = module.ctx.read_csv(get_answer_file(answer_file), schema=read_schema, delimiter="|", file_extension=".out")
+ df_expected = module.ctx.read_csv(
+     get_answer_file(answer_file),
+     schema=read_schema,
+     delimiter="|",
+     file_extension=".out",
+ )

df_expected = df_expected.select(*expected_selections)

3 changes: 1 addition & 2 deletions examples/tpch/convert_data_to_parquet.py
@@ -117,15 +117,14 @@

curr_dir = os.path.dirname(os.path.abspath(__file__))
for filename, curr_schema in all_schemas.items():

# For convenience, go ahead and convert the schema column names to lowercase
curr_schema = [(s[0].lower(), s[1]) for s in curr_schema]

# Pre-collect the output columns so we can ignore the null field we add
# in to handle the trailing | in the file
output_cols = [r[0] for r in curr_schema]

- curr_schema = [ pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema]
+ curr_schema = [pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema]

# Trailing | requires extra field for in processing
curr_schema.append(("some_null", pyarrow.null()))
4 changes: 3 additions & 1 deletion examples/tpch/q08_market_share.py
@@ -47,7 +47,9 @@

ctx = SessionContext()

- df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type")
+ df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
+     "p_partkey", "p_type"
+ )
df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_nationkey"
)
4 changes: 3 additions & 1 deletion examples/tpch/q09_product_type_profit_measure.py
@@ -39,7 +39,9 @@

ctx = SessionContext()

- df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name")
+ df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
+     "p_partkey", "p_name"
+ )
df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_nationkey"
)
4 changes: 3 additions & 1 deletion examples/tpch/q13_customer_distribution.py
@@ -41,7 +41,9 @@
df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_custkey", "o_comment"
)
- df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns("c_custkey")
+ df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
+     "c_custkey"
+ )

# Use a regex to remove special cases
df_orders = df_orders.filter(
4 changes: 3 additions & 1 deletion examples/tpch/q14_promotion_effect.py
@@ -44,7 +44,9 @@
df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_partkey", "l_shipdate", "l_extendedprice", "l_discount"
)
- df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type")
+ df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
+     "p_partkey", "p_type"
+ )


# Check part type begins with PROMO
3 changes: 2 additions & 1 deletion examples/tpch/q16_part_supplier_relationship.py
@@ -62,7 +62,8 @@
# Select the parts we are interested in
df_part = df_part.filter(col("p_brand") != lit(BRAND))
df_part = df_part.filter(
F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) != lit(TYPE_TO_IGNORE)
F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1))
!= lit(TYPE_TO_IGNORE)
)

# Python conversion of integer to literal casts it to int64 but the data for
8 changes: 7 additions & 1 deletion examples/tpch/q17_small_quantity_order.py
@@ -56,7 +56,13 @@
# Find the average quantity
window_frame = WindowFrame("rows", None, None)
df = df.with_column(
"avg_quantity", F.window("avg", [col("l_quantity")], window_frame=window_frame, partition_by=[col("l_partkey")])
"avg_quantity",
F.window(
"avg",
[col("l_quantity")],
window_frame=window_frame,
partition_by=[col("l_partkey")],
),
)

df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity"))
4 changes: 3 additions & 1 deletion examples/tpch/q20_potential_part_promotion.py
@@ -40,7 +40,9 @@

ctx = SessionContext()

- df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name")
+ df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
+     "p_partkey", "p_name"
+ )
df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_shipdate", "l_partkey", "l_suppkey", "l_quantity"
)
4 changes: 3 additions & 1 deletion examples/tpch/q22_global_sales_opportunity.py
@@ -38,7 +38,9 @@
df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
"c_phone", "c_acctbal", "c_custkey"
)
- df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns("o_custkey")
+ df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
+     "o_custkey"
+ )

# The nation code is a two digit number, but we need to convert it to a string literal
nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES])
7 changes: 5 additions & 2 deletions examples/tpch/util.py
@@ -20,14 +20,17 @@
"""

import os
from pathlib import Path


def get_data_path(filename: str) -> str:
path = os.path.dirname(os.path.abspath(__file__))

return os.path.join(path, "data", filename)


def get_answer_file(answer_file: str) -> str:
path = os.path.dirname(os.path.abspath(__file__))

- return os.path.join(path, "../../benchmarks/tpch/data/answers", f"{answer_file}.out")
+ return os.path.join(
+     path, "../../benchmarks/tpch/data/answers", f"{answer_file}.out"
+ )
15 changes: 15 additions & 0 deletions pyproject.toml
@@ -64,3 +64,18 @@ exclude = [".github/**", "ci/**", ".asf.yaml"]
# Require Cargo.lock is up to date
locked = true
features = ["substrait"]

# Enable docstring linting using the google style guide
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "D"]
ignore = ["D417"]

[tool.ruff.lint.pydocstyle]
convention = "google"

# Disable docstring checking for these directories
[tool.ruff.lint.per-file-ignores]
"python/datafusion/tests/*" = ["D"]
"examples/*" = ["D"]
"dev/*" = ["D"]
"benchmarks/*" = ["D", "F"]