Skip to content

Commit

Permalink
Add test that diagonal concat with mismatching schemas raises (rapids…
Browse files Browse the repository at this point in the history
…ai#16006)

Arguably this should be determined during query optimization by polars, but for now it is raised late during compute, so we must validate on our side.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Thomas Li (https://github.com/lithomas1)

URL: rapidsai#16006
  • Loading branch information
wence- authored Jun 12, 2024
1 parent 97518ac commit b35991c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,10 +933,10 @@ class Union(IR):
"""Optional slice to apply after concatenation."""

def __post_init__(self) -> None:
"""Validated preconditions."""
"""Validate preconditions."""
schema = self.dfs[0].schema
if not all(s.schema == schema for s in self.dfs[1:]):
raise ValueError("Schema mismatch")
raise NotImplementedError("Schema mismatch")

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
Expand Down
16 changes: 16 additions & 0 deletions python/cudf_polars/tests/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars import translate_ir
from cudf_polars.testing.asserts import assert_gpu_result_equal


Expand All @@ -19,6 +22,19 @@ def test_union():
assert_gpu_result_equal(query)


def test_union_schema_mismatch_raises():
ldf = pl.DataFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7],
"b": [1, 1, 1, 1, 1, 1, 1],
}
).lazy()
ldf2 = ldf.select(pl.col("a").cast(pl.Float32))
query = pl.concat([ldf, ldf2], how="diagonal")
with pytest.raises(NotImplementedError):
_ = translate_ir(query._ldf.visit())


def test_concat_vertical():
ldf = pl.LazyFrame(
{
Expand Down

0 comments on commit b35991c

Please sign in to comment.