Add Arrow type casting to struct for Image and Audio + Support nested casting (huggingface#3575)

* add storage cast

* implement dict cast for image

* factorize extension type creation for audio and image + implement type cast for those custom types

* fix tests

* style

* [big] allow extension array in nested arrays

* docs

* style

* fix Features pickling

* fix some tests

* fix more tests

* fix more tests

* add base extensionarray for pyarrow<6

* add extensionarray for pyarrow<6

* add soundfile to tests requirements

* minor

* remove NotImplementedError for complex casting in pyarrow 3

* style

* style again

* add casting for fixed size lists

* add libsndfile1 to the Linux CI

* style

* typo

* start adding new tests just to notice the concatenation issue...

* [big] remove extension types + move cast_storage to the Image and Audio classes (see the sketch after this list)

* minor

* fix test

* style

* add more tests to image

* add audio tests

* support casting from null array

* fix field names verifications when casting

* docs + tests

* use the new table_cast on pyarrow tables

* whoops forgot one line

* remove unused string handling in Image.decode_example

* update tests accordingly
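
The central design change described above is that Arrow-level casting now lives on the Image and Audio feature classes themselves (a cast_storage hook) rather than on dedicated extension types. A minimal sketch of what such a hook can look like for an Image-style feature stored as struct<bytes: binary, path: string> (the class name, branches and error handling are illustrative assumptions, not the library's exact code):

    import pyarrow as pa

    class ImageLikeFeature:
        """Illustrative stand-in for a feature class that owns its Arrow storage cast."""

        pa_type = pa.struct({"bytes": pa.binary(), "path": pa.string()})

        def cast_storage(self, storage: pa.Array) -> pa.StructArray:
            if pa.types.is_string(storage.type):
                # a column of file paths -> struct with empty bytes
                bytes_array = pa.array([None] * len(storage), type=pa.binary())
                return pa.StructArray.from_arrays([bytes_array, storage], names=["bytes", "path"])
            if pa.types.is_binary(storage.type):
                # a column of encoded image bytes -> struct with empty paths
                path_array = pa.array([None] * len(storage), type=pa.string())
                return pa.StructArray.from_arrays([storage, path_array], names=["bytes", "path"])
            if pa.types.is_struct(storage.type) and storage.type == self.pa_type:
                # already in the target storage layout
                return storage
            raise TypeError(f"Cannot cast storage of type {storage.type} to {self.pa_type}")

Dataset.cast_column (further down in this diff) checks for this hook and casts the underlying Arrow table directly when it is present.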
lhoestq committed Jan 21, 2022
1 parent e58ce4b · commit 6ca96c7
Showing 20 changed files with 975 additions and 411 deletions.
2 changes: 2 additions & 0 deletions .circleci/config.yml
@@ -11,6 +11,7 @@ jobs:
resource_class: medium
steps:
- checkout
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: python -m venv venv
- run: source venv/bin/activate
@@ -26,6 +27,7 @@ jobs:
resource_class: medium
steps:
- checkout
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: python -m venv venv
- run: source venv/bin/activate
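
libsndfile1-dev is the system library behind the soundfile package that the new audio tests (added to TESTS_REQUIRE in setup.py below) rely on. Roughly, decoding in those tests boils down to something like the following sketch (the file name is a placeholder):

    import soundfile as sf

    # soundfile wraps libsndfile, hence the extra apt package in the CI images
    array, sampling_rate = sf.read("example.wav")  # placeholder path
    print(array.shape, sampling_rate)
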
Binary file modified datasets/wmt19/dummy/cs-en/1.0.0/dummy_data.zip
2 changes: 2 additions & 0 deletions setup.py
@@ -136,6 +136,7 @@
"tensorflow>=2.3,!=2.6.0,!=2.6.1",
"torch",
"torchaudio",
"soundfile",
"transformers",
# datasets dependencies
"bs4",
@@ -172,6 +173,7 @@
]

TESTS_REQUIRE.extend(VISION_REQURE)
TESTS_REQUIRE.extend(AUDIO_REQUIRE)

if os.name != "nt":
# dependencies of unbabel-comet
46 changes: 19 additions & 27 deletions src/datasets/arrow_dataset.py
@@ -74,14 +74,7 @@
from .info import DatasetInfo
from .search import IndexableMixin
from .splits import NamedSplit, Split, SplitInfo
from .table import (
InMemoryTable,
MemoryMappedTable,
Table,
cast_with_sliced_list_support,
concat_tables,
list_table_cache_files,
)
from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files, table_cast
from .tasks import TaskTemplate
from .utils import logging
from .utils.deprecation_utils import deprecated
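
The helper cast_with_sliced_list_support is replaced by a single table_cast entry point. Based on how it is called later in this diff (partial(table_cast, schema=schema) applied to Arrow tables), usage is roughly the following sketch:

    import pyarrow as pa
    from datasets.table import table_cast

    table = pa.table({"col": [1, 2, 3]})
    target_schema = pa.schema({"col": pa.int8()})
    # table_cast is meant to cover the nested and feature-aware casts that plain
    # Table.cast does not handle (sliced lists, Image/Audio struct storage, ...)
    casted = table_cast(table, schema=target_schema)
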
@@ -636,8 +629,6 @@ def __init__(

if self._data.schema.metadata is not None and "huggingface".encode("utf-8") in self._data.schema.metadata:
metadata = json.loads(self._data.schema.metadata["huggingface".encode("utf-8")].decode())
if "info" in metadata and self.info.features is None: # try to load features from the arrow file metadata
self._info.features = DatasetInfo.from_dict(metadata["info"]).features
if (
"fingerprint" in metadata and self._fingerprint is None
): # try to load fingerprint from the arrow file metadata
@@ -780,7 +771,7 @@ def from_pandas(
info = DatasetInfo()
info.features = features
table = InMemoryTable.from_pandas(
df=df, preserve_index=preserve_index, schema=pa.schema(features.type) if features is not None else None
df=df, preserve_index=preserve_index, schema=features.arrow_schema if features is not None else None
)
return cls(table, info=info, split=split)

Expand Down Expand Up @@ -815,10 +806,12 @@ def from_dict(
if features is not None:
mapping = features.encode_batch(mapping)
mapping = {
col: OptimizedTypedSequence(data, type=features.type[col].type if features is not None else None, col=col)
col: OptimizedTypedSequence(data, type=features[col] if features is not None else None, col=col)
for col, data in mapping.items()
}
pa_table = InMemoryTable.from_pydict(mapping=mapping)
if info.features is None:
info.features = Features({col: ts.get_inferred_type() for col, ts in mapping.items()})
return cls(pa_table, info=info, split=split)

@staticmethod
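
With the change above, Dataset.from_dict fills in info.features from the typed columns when no features are passed, instead of leaving them as None. A rough illustration:

    from datasets import Dataset

    # no explicit features: they are now inferred from the typed sequences
    ds = Dataset.from_dict({"tokens": [["a", "b"], ["c"]], "label": [0, 1]})
    print(ds.features)  # e.g. tokens -> Sequence(string), label -> int64
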
@@ -1221,12 +1214,13 @@ def stringify_column(batch):
dset = self

# Create the new feature
class_names = sorted(sample for sample in dset.unique(column) if include_nulls or sample is not None)
class_names = sorted(str(sample) for sample in dset.unique(column) if include_nulls or sample is not None)
dst_feat = ClassLabel(names=class_names)

def cast_to_class_labels(batch):
batch[column] = [
dst_feat.str2int(sample) if include_nulls or sample is not None else None for sample in batch[column]
dst_feat.str2int(str(sample)) if include_nulls or sample is not None else None
for sample in batch[column]
]
return batch

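Unique values are now stringified before building the ClassLabel vocabulary, so class_encode_column also works on non-string columns. A rough illustration:

    from datasets import Dataset

    ds = Dataset.from_dict({"label": [0, 1, 1, 0]})
    # integer labels are cast to str for the ClassLabel names, then encoded back to ids
    ds = ds.class_encode_column("label")
    print(ds.features["label"].names)  # e.g. ['0', '1']
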
@@ -1280,7 +1274,7 @@ def flatten_(self, max_depth=16):
self._data = self._data.flatten()
else:
break
self.info.features = Features.from_arrow_schema(self._data.schema)
self.info.features = self.features.flatten(max_depth=max_depth)
self._data = update_metadata_with_features(self._data, self.features)
logger.info(f'Flattened dataset from depth {depth} to depth { 1 if depth + 1 < max_depth else "unknown"}.')

@@ -1299,7 +1293,7 @@ def flatten(self, new_fingerprint, max_depth=16) -> "Dataset":
dataset._data = dataset._data.flatten()
else:
break
dataset.info.features = Features.from_arrow_schema(dataset._data.schema)
dataset.info.features = self.features.flatten(max_depth=max_depth)
dataset._data = update_metadata_with_features(dataset._data, dataset.features)
logger.info(f'Flattened dataset from depth {depth} to depth {1 if depth + 1 < max_depth else "unknown"}.')
dataset._fingerprint = new_fingerprint
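
flatten and flatten_ now recompute the features with Features.flatten instead of re-deriving them from the Arrow schema, so feature types that are not plain Arrow types survive flattening. A rough illustration of the user-facing behaviour:

    from datasets import Dataset, Features, Value

    features = Features({"answers": {"text": Value("string"), "start": Value("int32")}})
    ds = Dataset.from_dict({"answers": [{"text": "ok", "start": 0}]}, features=features)
    flat = ds.flatten()
    # nested keys become top-level "answers.text" / "answers.start" columns
    print(flat.features)
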
@@ -1348,10 +1342,8 @@ def cast_(
type = features.type
schema = pa.schema({col_name: type[col_name].type for col_name in self._data.column_names})
dataset = self.with_format("arrow")
# capture the PyArrow version here to make the lambda serializable on Windows
is_pyarrow_at_least_4 = config.PYARROW_VERSION.major >= 4
dataset = dataset.map(
lambda t: t.cast(schema) if is_pyarrow_at_least_4 else cast_with_sliced_list_support(t, schema),
partial(table_cast, schema=schema),
batched=True,
batch_size=batch_size,
keep_in_memory=keep_in_memory,
@@ -1406,14 +1398,12 @@ def cast(
f"as the columns in the dataset: {self._data.column_names}"
)

type = features.type
schema = pa.schema({col_name: type[col_name].type for col_name in self._data.column_names})
schema = features.arrow_schema
format = self.format
dataset = self.with_format("arrow")
# capture the PyArrow version here to make the lambda serializable on Windows
is_pyarrow_at_least_4 = config.PYARROW_VERSION.major >= 4
dataset = dataset.map(
lambda t: t.cast(schema) if is_pyarrow_at_least_4 else cast_with_sliced_list_support(t, schema),
partial(table_cast, schema=schema),
batched=True,
batch_size=batch_size,
keep_in_memory=keep_in_memory,
@@ -1433,15 +1423,17 @@ def cast_column(self, column: str, feature: FeatureType, new_fingerprint: str) -> "Dataset":
Args:
column (:obj:`str`): Column name.
feature (:class:`Feature`): Target feature.
feature (:class:`FeatureType`): Target feature.
Returns:
:class:`Dataset`
"""
if hasattr(feature, "decode_example"):
if hasattr(feature, "cast_storage"):
dataset = copy.deepcopy(self)
dataset.features[column] = feature
dataset._fingerprint = new_fingerprint
dataset._data = dataset._data.cast(dataset.features.arrow_schema)
dataset._data = update_metadata_with_features(dataset._data, dataset.features)
return dataset
else:
features = self.features.copy()
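
With Image and Audio now exposing cast_storage, cast_column can take the direct Arrow cast branch above instead of a full map()-based re-encode. A rough usage illustration (the paths are placeholders; decoding only happens on access):

    from datasets import Dataset, Image

    ds = Dataset.from_dict({"image": ["path/to/a.png", "path/to/b.png"]})  # placeholder paths
    # Image defines cast_storage, so the string column is cast to the
    # struct<bytes, path> storage at the Arrow level
    ds = ds.cast_column("image", Image())
    print(ds.features["image"])
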
@@ -3860,8 +3852,8 @@ def add_item(self, item: dict, new_fingerprint: str):
# Cast to align the schemas of the tables and concatenate the tables
table = concat_tables(
[
self._data.cast(pa.schema(dset_features.type)) if self.features != dset_features else self._data,
item_table.cast(pa.schema(item_features.type)),
self._data.cast(dset_features.arrow_schema) if self.features != dset_features else self._data,
item_table.cast(item_features.arrow_schema),
]
)
if self._indices is None: