Skip to content

Commit

Permalink
fix(rust): Indicative error in list.gather when wrong indices type …
Browse files Browse the repository at this point in the history
…is supplied (#18611)
  • Loading branch information
barak1412 authored Sep 9, 2024
1 parent 18d3073 commit d3a14de
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 1 deletion.
2 changes: 1 addition & 1 deletion crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ pub trait ListNameSpaceImpl: AsList {

use DataType::*;
match idx.dtype() {
List(_) => {
List(boxed_dt) if boxed_dt.is_integer() => {
let idx_ca = idx.list().unwrap();
let mut out = {
list_ca
Expand Down
71 changes: 71 additions & 0 deletions py-polars/tests/unit/operations/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
from datetime import date, datetime

import numpy as np
Expand Down Expand Up @@ -159,6 +160,76 @@ def test_list_categorical_get() -> None:
)


def test_list_gather_wrong_indices_list_type() -> None:
    """Check which index list dtypes ``list.gather`` accepts.

    A list of indices with any integer inner dtype (signed or unsigned,
    8 to 64 bits) must gather the expected elements. A list of strings must
    be rejected with a ``ComputeError`` whose message names the offending
    dtype.
    """
    a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]])
    expected = pl.Series("a", [[1, 2], [4], [6, 9]])

    # The same gather is checked for every integer inner dtype; only the
    # dtype of the indices series changes between iterations.
    integer_dtypes = (
        pl.Int8,
        pl.Int16,
        pl.Int32,
        pl.Int64,
        pl.UInt8,
        pl.UInt16,
        pl.UInt32,
        pl.UInt64,
    )
    for inner_dtype in integer_dtypes:
        indices_series = pl.Series(
            "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(inner_dtype)
        )
        result = a.list.gather(indices=indices_series)
        assert_series_equal(result, expected)

    # Non-integer indices: a list[str] column must raise instead of being
    # used as an index.
    df = pl.DataFrame(
        {
            "index": [["2"], ["2"], ["2"]],
            "lists": [[3, 4, 5], [4, 5, 6], [7, 8, 9, 4]],
        }
    )
    with pytest.raises(
        ComputeError, match=re.escape("cannot use dtype `list[str]` as an index")
    ):
        df.select(pl.col("lists").list.gather(pl.col("index")))


def test_contains() -> None:
a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]])
out = a.list.contains(2)
Expand Down

0 comments on commit d3a14de

Please sign in to comment.