Skip to content

Commit

Permalink
Always set nullable fields in the writer (huggingface#5835)
Browse files: browse the repository at this point in the history
* always use nullable fields in writer

* doc

* test
  • Loading branch information
lhoestq committed May 19, 2023
1 parent a827967 commit 5ebda17
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,13 @@ def _build_writer(self, inferred_schema: pa.Schema):
schema: pa.Schema = inferred_schema
else:
self._features = inferred_features
schema: pa.Schema = inferred_schema
schema: pa.Schema = inferred_features.arrow_schema
if self.disable_nullable:
schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in schema)
if self.with_metadata:
schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=self._features), self.fingerprint))
else:
schema = schema.with_metadata({})
self._schema = schema
self.pa_writer = self._WRITER_CLASS(self.stream, schema)

Expand Down
1 change: 1 addition & 0 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,7 @@ def from_arrow_schema(cls, pa_schema: pa.Schema) -> "Features":
"""
Construct [`Features`] from Arrow Schema.
It also checks the schema metadata for Hugging Face Datasets features.
Non-nullable fields are not supported and set to nullable.
Args:
pa_schema (`pyarrow.Schema`):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,11 @@ def test_writer_embed_local_files(tmp_path, embed_local_files):
else:
assert out["image"][0]["path"] == image_path
assert out["image"][0]["bytes"] is None


def test_always_nullable():
    """Check that the writer keeps fields nullable for a non-nullable inferred schema.

    A schema whose only field is declared ``nullable=False`` is handed to
    ``ArrowWriter._build_writer``; the schema stored on the writer is then
    expected to equal the same schema with a default (nullable) field.
    """
    strict_field = pa.field("col_1", pa.string(), nullable=False)
    inferred = pa.schema([strict_field])
    # pa.field defaults to nullable=True, which is what the writer should end up with.
    expected = pa.schema([pa.field("col_1", pa.string())])
    sink = pa.BufferOutputStream()
    with ArrowWriter(stream=sink) as writer:
        writer._build_writer(inferred_schema=inferred)
        assert writer._schema == expected

0 comments on commit 5ebda17

Please sign in to comment.