fix: allow for date/datetime subclasses (e.g. pd.Timestamp, FreezeGun) in pl.lit
MarcoGorelli committed Aug 31, 2024
1 parent 4dc90a9 commit f260a44
Showing 89 changed files with 335 additions and 309 deletions.
24 changes: 18 additions & 6 deletions crates/polars-python/src/functions/lazy.rs
@@ -442,12 +442,24 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult<PyExpr> {
         });
         Ok(dsl::lit(s).into())
     } else {
-        Err(PyTypeError::new_err(format!(
-            "cannot create expression literal for value of type {}: {}\
-            \n\nHint: Pass `allow_object=True` to accept any value and create a literal of type Object.",
-            value.get_type().qualname()?,
-            value.repr()?
-        )))
+        Python::with_gil(|py| {
+            // One final attempt before erroring. Do we have a date/datetime subclass?
+            // E.g. pd.Timestamp, or Freezegun.
+            let datetime_module = PyModule::import_bound(py, "datetime")?;
+            let datetime_class = datetime_module.getattr("datetime")?;
+            let date_class = datetime_module.getattr("date")?;
+            if value.is_instance(&datetime_class)? || value.is_instance(&date_class)? {
+                let av = py_object_to_any_value(value, true)?;
+                Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into())
+            } else {
+                Err(PyTypeError::new_err(format!(
+                    "cannot create expression literal for value of type {}: {}\
+                    \n\nHint: Pass `allow_object=True` to accept any value and create a literal of type Object.",
+                    value.get_type().name()?,
+                    value.repr()?
+                )))
+            }
+        })
     }
 }

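The practical effect of this fallback: `pl.lit` should now accept subclasses of `datetime.date` / `datetime.datetime` (such as `pd.Timestamp` or FreezeGun's fake datetime) instead of raising a `TypeError`. A minimal Python sketch of the expected behaviour, with the resulting dtypes inferred from this diff and the new test below rather than from a verified run:

import datetime

import polars as pl


class MyDatetime(datetime.datetime):  # stand-in for pd.Timestamp / FreezeGun's FakeDatetime
    pass


class MyDate(datetime.date):
    pass


# Before this commit, both calls raised:
#   TypeError: cannot create expression literal for value of type MyDatetime: ...
# With the fallback above, subclass instances are converted like plain date/datetime values.
print(pl.select(pl.lit(MyDatetime(2024, 8, 31, 12, 0))))  # one-row frame with a Datetime literal
print(pl.select(pl.lit(MyDate(2024, 8, 31))))  # one-row frame with a Date literal
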
6 changes: 3 additions & 3 deletions py-polars/tests/benchmark/interop/test_numpy.py
@@ -18,19 +18,19 @@ def floats_array() -> np.ndarray[Any, Any]:
    return np.random.randn(n_rows)


@pytest.fixture
@pytest.fixture()
def floats(floats_array: np.ndarray[Any, Any]) -> pl.Series:
    return pl.Series(floats_array)


@pytest.fixture
@pytest.fixture()
def floats_with_nulls(floats: pl.Series) -> pl.Series:
    null_probability = 0.1
    validity = pl.Series(np.random.uniform(size=floats.len())) > null_probability
    return pl.select(pl.when(validity).then(floats)).to_series()


@pytest.fixture
@pytest.fixture()
def floats_chunked(floats_array: np.ndarray[Any, Any]) -> pl.Series:
    n_chunks = 5
    chunk_len = len(floats_array) // n_chunks
2 changes: 1 addition & 1 deletion py-polars/tests/docs/test_user_guide.py
@@ -29,7 +29,7 @@ def _change_test_dir() -> Iterator[None]:
    os.chdir(current_path)


@pytest.mark.docs
@pytest.mark.docs()
@pytest.mark.parametrize("path", snippet_paths)
@pytest.mark.usefixtures("_change_test_dir")
def test_run_python_snippets(path: Path) -> None:
2 changes: 1 addition & 1 deletion py-polars/tests/unit/cloud/test_prepare_cloud_plan.py
@@ -94,7 +94,7 @@ def test_prepare_cloud_plan_fail_on_local_data_source(lf: pl.LazyFrame) -> None:
        prepare_cloud_plan(lf)


@pytest.mark.write_disk
@pytest.mark.write_disk()
def test_prepare_cloud_plan_fail_on_python_scan(tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)
    data_path = tmp_path / "data.parquet"
12 changes: 6 additions & 6 deletions py-polars/tests/unit/conftest.py
@@ -32,13 +32,13 @@
NESTED_DTYPES = [pl.List, pl.Struct, pl.Array]


@pytest.fixture
@pytest.fixture()
def partition_limit() -> int:
    """The limit at which Polars will start partitioning in debug builds."""
    return 15


@pytest.fixture
@pytest.fixture()
def df() -> pl.DataFrame:
    df = pl.DataFrame(
        {
@@ -68,14 +68,14 @@ def df() -> pl.DataFrame:
    )


@pytest.fixture
@pytest.fixture()
def df_no_lists(df: pl.DataFrame) -> pl.DataFrame:
    return df.select(
        pl.all().exclude(["list_str", "list_int", "list_bool", "list_int", "list_flt"])
    )


@pytest.fixture
@pytest.fixture()
def fruits_cars() -> pl.DataFrame:
    return pl.DataFrame(
        {
@@ -88,7 +88,7 @@ def fruits_cars() -> pl.DataFrame:
    )


@pytest.fixture
@pytest.fixture()
def str_ints_df() -> pl.DataFrame:
    n = 1000

@@ -199,7 +199,7 @@ def get_peak(self) -> int:
        return tracemalloc.get_traced_memory()[1]


@pytest.fixture
@pytest.fixture()
def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]:
    """
    Provide an API for measuring peak memory usage.
2 changes: 1 addition & 1 deletion py-polars/tests/unit/dataframe/test_df.py
@@ -1555,7 +1555,7 @@ def test_reproducible_hash_with_seeds() -> None:
    assert_series_equal(expected, result, check_names=False, check_exact=True)


@pytest.mark.slow
@pytest.mark.slow()
@pytest.mark.parametrize(
    "e",
    [
2 changes: 1 addition & 1 deletion py-polars/tests/unit/dataframe/test_from_dict.py
@@ -145,7 +145,7 @@ def test_from_dict_with_scalars() -> None:
    assert df9.rows() == [(0, 2, 0, "x"), (1, 1, 0, "x"), (2, 0, 0, "x")]


@pytest.mark.slow
@pytest.mark.slow()
def test_from_dict_with_values_mixed() -> None:
    # a bit of everything
    mixed_dtype_data: dict[str, Any] = {
4 changes: 2 additions & 2 deletions py-polars/tests/unit/dataframe/test_partition_by.py
@@ -6,7 +6,7 @@
import polars.selectors as cs


@pytest.fixture
@pytest.fixture()
def df() -> pl.DataFrame:
    return pl.DataFrame(
        {
@@ -75,7 +75,7 @@ def test_partition_by_as_dict_include_keys_false_maintain_order_false() -> None:
        df.partition_by(["a"], maintain_order=False, include_key=False, as_dict=True)


@pytest.mark.slow
@pytest.mark.slow()
def test_partition_by_as_dict_include_keys_false_large() -> None:
    # test with both as_dict and include_key=False
    df = pl.DataFrame(
2 changes: 1 addition & 1 deletion py-polars/tests/unit/dataframe/test_serde.py
@@ -85,7 +85,7 @@ def test_df_serde_to_from_buffer(
    assert_frame_equal(df, read_df, categorical_as_str=True)


@pytest.mark.write_disk
@pytest.mark.write_disk()
def test_df_serde_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None:
    tmp_path.mkdir(exist_ok=True)

4 changes: 2 additions & 2 deletions py-polars/tests/unit/dataframe/test_vstack.py
@@ -5,12 +5,12 @@
from polars.testing import assert_frame_equal


@pytest.fixture
@pytest.fixture()
def df1() -> pl.DataFrame:
    return pl.DataFrame({"foo": [1, 2], "bar": [6, 7], "ham": ["a", "b"]})


@pytest.fixture
@pytest.fixture()
def df2() -> pl.DataFrame:
    return pl.DataFrame({"foo": [3, 4], "bar": [8, 9], "ham": ["c", "d"]})

2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_array.py
@@ -176,7 +176,7 @@ def test_cast_list_to_array(data: Any, inner_type: pl.DataType) -> None:
    assert s.to_list() == data


@pytest.fixture
@pytest.fixture()
def data_dispersion() -> pl.DataFrame:
    return pl.DataFrame(
        {
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_bool.py
@@ -6,7 +6,7 @@
import polars as pl


@pytest.mark.slow
@pytest.mark.slow()
def test_bool_arg_min_max() -> None:
    # masks that ensures we take more than u64 chunks
    # and slicing and dicing to ensure the offsets work
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_categorical.py
@@ -480,7 +480,7 @@ def test_cast_inner_categorical() -> None:
    )


@pytest.mark.slow
@pytest.mark.slow()
def test_stringcache() -> None:
    N = 1_500
    with pl.StringCache():
4 changes: 2 additions & 2 deletions py-polars/tests/unit/datatypes/test_decimal.py
@@ -29,7 +29,7 @@ def permutations_int_dec_none() -> list[tuple[D | int | None, ...]]:
    )


@pytest.mark.slow
@pytest.mark.slow()
def test_series_from_pydecimal_and_ints(
    permutations_int_dec_none: list[tuple[D | int | None, ...]],
) -> None:
@@ -45,7 +45,7 @@ def test_series_from_pydecimal_and_ints(
        assert s.to_list() == [D(x) if x is not None else None for x in data]


@pytest.mark.slow
@pytest.mark.slow()
def test_frame_from_pydecimal_and_ints(
    permutations_int_dec_none: list[tuple[D | int | None, ...]], monkeypatch: Any
) -> None:
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_list.py
@@ -667,7 +667,7 @@ def test_as_list_logical_type() -> None:
    ).to_dict(as_series=False) == {"literal": [True], "timestamp": [[date(2000, 1, 1)]]}


@pytest.fixture
@pytest.fixture()
def data_dispersion() -> pl.DataFrame:
    return pl.DataFrame(
        {
14 changes: 14 additions & 0 deletions py-polars/tests/unit/functions/range/test_date_range.py
@@ -310,3 +310,17 @@ def test_date_ranges_datetime_input() -> None:
        "literal", [[date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3)]]
    )
    assert_series_equal(result, expected)


def test_date_range_with_subclass_18470_18447() -> None:
    class MyAmazingDate(date):
        pass

    class MyAmazingDatetime(datetime):
        pass

    result = pl.datetime_range(
        MyAmazingDate(2020, 1, 1), MyAmazingDatetime(2020, 1, 2), eager=True
    )
    expected = pl.Series("literal", [datetime(2020, 1, 1), datetime(2020, 1, 2)])
    assert_series_equal(result, expected)
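The new test reaches the fallback through `pl.datetime_range`, whose plain `start`/`end` values are wrapped into literals (effectively via `pl.lit`) before the range is built. The same path should cover the real-world subclasses named in the commit message; a hedged usage sketch, assuming pandas is installed (output described from the expected behaviour, not a verified run):

import pandas as pd

import polars as pl

# pd.Timestamp subclasses datetime.datetime, so it now takes the date/datetime
# fallback in `lit` instead of raising a TypeError.
s = pl.datetime_range(
    pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-02"), eager=True
)
print(s)  # expected: a Datetime Series [2020-01-01 00:00:00, 2020-01-02 00:00:00]
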
4 changes: 2 additions & 2 deletions py-polars/tests/unit/functions/test_concat.py
@@ -3,7 +3,7 @@
import polars as pl


@pytest.mark.slow
@pytest.mark.slow()
def test_concat_expressions_stack_overflow() -> None:
    n = 10000
    e = pl.concat([pl.lit(x) for x in range(n)])
@@ -12,7 +12,7 @@ def test_concat_expressions_stack_overflow() -> None:
    assert df.shape == (n, 1)


@pytest.mark.slow
@pytest.mark.slow()
def test_concat_lf_stack_overflow() -> None:
    n = 1000
    bar = pl.DataFrame({"a": 0}).lazy()
2 changes: 1 addition & 1 deletion py-polars/tests/unit/functions/test_when_then.py
@@ -529,7 +529,7 @@ def test_when_then_null_broadcast() -> None:
    )


@pytest.mark.slow
@pytest.mark.slow()
@pytest.mark.parametrize("len", [1, 10, 100, 500])
@pytest.mark.parametrize(
    ("dtype", "vals"),
4 changes: 2 additions & 2 deletions py-polars/tests/unit/interchange/test_from_dataframe.py
@@ -406,13 +406,13 @@ def test_construct_offsets_buffer_copy() -> None:
    assert_series_equal(result, expected)


@pytest.fixture
@pytest.fixture()
def bitmask() -> PolarsBuffer:
    data = pl.Series([False, True, True, False])
    return PolarsBuffer(data)


@pytest.fixture
@pytest.fixture()
def bytemask() -> PolarsBuffer:
    data = pl.Series([0, 1, 1, 0], dtype=pl.UInt8)
    return PolarsBuffer(data)
2 changes: 1 addition & 1 deletion py-polars/tests/unit/interop/test_from_pandas.py
@@ -321,7 +321,7 @@ def test_untrusted_categorical_input() -> None:
    assert_frame_equal(result, expected, categorical_as_str=True)


@pytest.fixture
@pytest.fixture()
def _set_pyarrow_unavailable(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "polars._utils.construction.dataframe._PYARROW_AVAILABLE", False
2 changes: 1 addition & 1 deletion py-polars/tests/unit/io/cloud/test_aws.py
@@ -49,7 +49,7 @@ def s3_base(monkeypatch_module: Any) -> Iterator[str]:
    p.kill()


@pytest.fixture
@pytest.fixture()
def s3(s3_base: str, io_files_path: Path) -> str:
    region = "us-east-1"
    client = boto3.client("s3", region_name=region, endpoint_url=s3_base)
2 changes: 1 addition & 1 deletion py-polars/tests/unit/io/cloud/test_cloud.py
@@ -4,7 +4,7 @@
from polars.exceptions import ComputeError


@pytest.mark.slow
@pytest.mark.slow()
@pytest.mark.parametrize("format", ["parquet", "csv", "ndjson", "ipc"])
def test_scan_nonexistent_cloud_path_17444(format: str) -> None:
    # https://github.com/pola-rs/polars/issues/17444
2 changes: 1 addition & 1 deletion py-polars/tests/unit/io/conftest.py
@@ -5,6 +5,6 @@
import pytest


@pytest.fixture
@pytest.fixture()
def io_files_path() -> Path:
    return Path(__file__).parent / "files"
4 changes: 2 additions & 2 deletions py-polars/tests/unit/io/database/conftest.py
@@ -10,7 +10,7 @@
from pathlib import Path


@pytest.fixture
@pytest.fixture()
def tmp_sqlite_db(tmp_path: Path) -> Path:
    test_db = tmp_path / "test.db"
    test_db.unlink(missing_ok=True)
@@ -51,7 +51,7 @@ def convert_date(val: bytes) -> date:
    return test_db


@pytest.fixture
@pytest.fixture()
def tmp_sqlite_inference_db(tmp_path: Path) -> Path:
    test_db = tmp_path / "test_inference.db"
    test_db.unlink(missing_ok=True)
4 changes: 2 additions & 2 deletions py-polars/tests/unit/io/database/test_read.py
@@ -150,7 +150,7 @@ class ExceptionTestParams(NamedTuple):
    kwargs: dict[str, Any] | None = None


@pytest.mark.write_disk
@pytest.mark.write_disk()
@pytest.mark.parametrize(
    (
        "read_method",
@@ -698,7 +698,7 @@ def test_read_database_cx_credentials(uri: str) -> None:
        pl.read_database_uri("SELECT * FROM data", uri=uri, engine="connectorx")


@pytest.mark.write_disk
@pytest.mark.write_disk()
def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None:
    import kuzu

