Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Update format, fingerprint and indices after add_item #2254

Merged
merged 6 commits into from
Apr 27, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
update format, fingerprint and indices after add_item
  • Loading branch information
lhoestq committed Apr 23, 2021
commit a877fffd9f1e90cea45dac788a3649312fab64af
18 changes: 16 additions & 2 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2858,7 +2858,9 @@ def add_elasticsearch_index(
)
return self

def add_item(self, item: dict):
@transmit_format
@fingerprint_transform(inplace=False)
def add_item(self, item: dict, new_fingerprint: str):
"""Add item to Dataset.

.. versionadded:: 1.6
Expand All @@ -2875,7 +2877,19 @@ def add_item(self, item: dict):
item_table = item_table.cast(schema)
# Concatenate tables
table = concat_tables([self._data, item_table])
return Dataset(table)
if self._indices is None:
indices_table = None
else:
new_indices_array = pa.array([len(self._data)], type=pa.uint64())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe calling it item_indices_table to be consistent with item_table above.

new_indices_table = InMemoryTable.from_arrays([new_indices_array], names=["indices"])
indices_table = concat_tables([self._indices, new_indices_table])
return Dataset(
table,
info=copy.deepcopy(self.info),
split=self.split,
indices_table=indices_table,
fingerprint=new_fingerprint,
)


def concatenate_datasets(
Expand Down
26 changes: 18 additions & 8 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,6 +1948,10 @@ def test_concatenate_datasets_duplicate_columns(dataset):
assert "duplicated" in str(excinfo.value)


@pytest.mark.parametrize(
"transform",
[None, ("shuffle", (42,), {}), ("with_format", ("pandas",), {}), ("class_encode_column", ("col_2",), {})],
)
@pytest.mark.parametrize("in_memory", [False, True])
@pytest.mark.parametrize(
"item",
Expand All @@ -1958,22 +1962,28 @@ def test_concatenate_datasets_duplicate_columns(dataset):
{"col_1": 4.0, "col_2": 4.0, "col_3": 4.0},
],
)
def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path):
dataset = (
def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path, transform):
dataset_to_test = (
Dataset(InMemoryTable.from_pydict(dataset_dict))
if in_memory
else Dataset(MemoryMappedTable.from_file(arrow_path))
)
dataset = dataset.add_item(item)
if transform is not None:
transform_name, args, kwargs = transform
dataset_to_test: Dataset = getattr(dataset_to_test, transform_name)(*args, **kwargs)
dataset = dataset_to_test.add_item(item)
assert dataset.data.shape == (5, 3)
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
expected_features = dataset_to_test.features
assert dataset.data.column_names == list(expected_features.keys())
for feature, expected_dtype in expected_features.items():
assert dataset.features[feature].dtype == expected_dtype
    assert len(dataset.data.blocks) == (1 if in_memory else 2)  # multiple InMemoryTables are consolidated as one
dataset = dataset.add_item(item)
assert dataset.data.shape == (6, 3)
assert dataset.features[feature] == expected_dtype
    assert len(dataset.data.blocks) == (1 if in_memory else 2)  # multiple InMemoryTables are consolidated as one
assert dataset.format["type"] == dataset_to_test.format["type"]
assert dataset._fingerprint != dataset_to_test._fingerprint
dataset.reset_format()
dataset_to_test.reset_format()
assert dataset[:-1] == dataset_to_test[:]
assert {k: int(v) for k, v in dataset[-1].items()} == {k: int(v) for k, v in item.items()}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be better to introduce an explicit test for the _indices?

For example, this test passes even if I wrongly set in add_item:

new_indices_array = pa.array([9], type=pa.uint64())



@pytest.mark.parametrize("keep_in_memory", [False, True])
Expand Down