Skip to content

Commit

Permalink
Check if indices values in Dataset.select are within bounds (huggin…
Browse files Browse the repository at this point in the history
…gface#3719)

* Add check

* Add test
  • Loading branch information
mariosasko committed Feb 14, 2022
1 parent 5f94a9f commit 99816c6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
13 changes: 13 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,13 @@ def _check_column_names(column_names: List[str]):
raise ValueError(f"The table can't have duplicated columns but columns {duplicated_columns} are duplicated.")


def _check_valid_indices_value(value, size):
if (value < 0 and value + size < 0) or (value >= size):
raise IndexError(
f"Invalid value {value} in indices iterable. All values must be within range [-{size}, {size - 1}]."
)


def _check_if_features_can_be_aligned(features_list: List[Features]):
"""Check if the dictionaries of features can be aligned.
Expand Down Expand Up @@ -2760,6 +2767,12 @@ def select(
path=tmp_file.name, writer_batch_size=writer_batch_size, fingerprint=new_fingerprint, unit="indices"
)

indices = list(indices)

size = len(self)
_check_valid_indices_value(int(max(indices)), size=size)
_check_valid_indices_value(int(min(indices)), size=size)

indices_array = pa.array(indices, type=pa.uint64())
# Check if we need to convert indices
if self._indices is not None:
Expand Down
16 changes: 15 additions & 1 deletion tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,7 +1492,21 @@ def test_select(self, in_memory):
with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
bad_indices = list(range(5))
bad_indices[3] = "foo"
bad_indices[-1] = len(dset) + 10 # out of bounds
tmp_file = os.path.join(tmp_dir, "test.arrow")
self.assertRaises(
Exception,
dset.select,
indices=bad_indices,
indices_cache_file_name=tmp_file,
writer_batch_size=2,
)
self.assertFalse(os.path.exists(tmp_file))

with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
bad_indices = list(range(5))
bad_indices[3] = "foo" # wrong type
tmp_file = os.path.join(tmp_dir, "test.arrow")
self.assertRaises(
Exception,
Expand Down

0 comments on commit 99816c6

Please sign in to comment.