Fix add_column on datasets with indices mapping (huggingface#3647)
* Flatten indices in add_column if indices table exists

* Add test

* Address review comment
mariosasko committed Jan 28, 2022
1 parent 3adc314 commit e8cd145
Showing 2 changed files with 13 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/datasets/arrow_dataset.py
@@ -3649,13 +3649,14 @@ def add_column(self, name: str, column: Union[list, np.array], new_fingerprint:
         """
         column_table = InMemoryTable.from_pydict({name: column})
         _check_column_names(self._data.column_names + column_table.column_names)
+        dataset = self.flatten_indices() if self._indices is not None else self
         # Concatenate tables horizontally
-        table = concat_tables([self._data, column_table], axis=1)
+        table = concat_tables([dataset._data, column_table], axis=1)
         # Update features
-        info = self.info.copy()
+        info = dataset.info.copy()
         info.features.update(Features.from_arrow_schema(column_table.schema))
         table = update_metadata_with_features(table, info.features)
-        return Dataset(table, info=info, split=self.split, indices_table=self._indices, fingerprint=new_fingerprint)
+        return Dataset(table, info=info, split=self.split, indices_table=None, fingerprint=new_fingerprint)

     def add_faiss_index(
         self,
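For context, here is a minimal sketch of the scenario this commit fixes. The data and column names are illustrative (not taken from the repository) and assume the public datasets API of the time: Dataset.from_dict, select, and add_column.

    from datasets import Dataset

    ds = Dataset.from_dict({"col_1": ["a", "b", "c", "d"]})

    # select() does not rewrite the Arrow table; it records an indices mapping,
    # so the underlying table still has 4 rows while the dataset exposes 3.
    subset = ds.select(range(3))

    # Before this fix, add_column concatenated the new 3-row column table with the
    # full 4-row underlying table while keeping the stale indices mapping.
    # Flattening the indices first keeps the rows aligned.
    subset = subset.add_column("col_2", [10, 11, 12])

    print(subset[0])          # {'col_1': 'a', 'col_2': 10}
    print(subset.data.shape)  # (3, 2) -- the indices were flattened away
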
11 changes: 9 additions & 2 deletions tests/test_arrow_dataset.py
@@ -2491,7 +2491,13 @@ def test_interleave_datasets_probabilities():
 @pytest.mark.parametrize("in_memory", [False, True])
 @pytest.mark.parametrize(
     "transform",
-    [None, ("shuffle", (42,), {}), ("with_format", ("pandas",), {}), ("class_encode_column", ("col_2",), {})],
+    [
+        None,
+        ("shuffle", (42,), {}),
+        ("with_format", ("pandas",), {}),
+        ("class_encode_column", ("col_2",), {}),
+        ("select", (range(3),), {}),
+    ],
 )
 def test_dataset_add_column(column, expected_dtype, in_memory, transform, dataset_dict, arrow_path):
     column_name = "col_4"
@@ -2503,8 +2509,9 @@ def test_dataset_add_column(column, expected_dtype, in_memory, transform, dataset_dict, arrow_path):
     if transform is not None:
         transform_name, args, kwargs = transform
         original_dataset: Dataset = getattr(original_dataset, transform_name)(*args, **kwargs)
+    column = column[:3] if transform is not None and transform_name == "select" else column
     dataset = original_dataset.add_column(column_name, column)
-    assert dataset.data.shape == (4, 4)
+    assert dataset.data.shape == ((3, 4) if transform is not None and transform_name == "select" else (4, 4))
     expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
     # Sort expected features as in the original dataset
     expected_features = {feature: expected_features[feature] for feature in original_dataset.features}
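The new "select" case above exercises the code path that calls flatten_indices() internally. Below is a short sketch of what that method does, again with illustrative data rather than the test fixtures:

    from datasets import Dataset

    ds = Dataset.from_dict({"col_1": ["a", "b", "c", "d"]})
    subset = ds.select(range(3))

    assert subset._indices is not None   # an indices mapping is present
    flat = subset.flatten_indices()      # materialize the selected rows into a new table
    assert flat._indices is None         # the mapping is gone
    assert flat.data.shape == (3, 1)     # the table now holds only the 3 selected rows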
