Skip to content

Commit

Permalink
Fix ClassLabel to/from dict when passed names_file (huggingface#3695)
Browse files (browse the repository at this point in the history)
* Test ClassLabel to/from dict with names_file

* Fix ClassLabel to/from dict with names_file

* Add explanatory comment
  • Loading branch information
albertvillanova committed Feb 11, 2022
1 parent 3634e16 commit b150e58
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import re
import sys
from collections.abc import Iterable
from dataclasses import _asdict_inner, dataclass, field, fields
from dataclasses import InitVar, _asdict_inner, dataclass, field, fields
from functools import reduce
from operator import mul
from typing import Any, ClassVar, Dict, List, Optional
Expand Down Expand Up @@ -761,7 +761,7 @@ class ClassLabel:

num_classes: int = None
names: List[str] = None
names_file: Optional[str] = None
names_file: InitVar[Optional[str]] = None # Pseudo-field: ignored by asdict and fields when converting to/from dict
id: Optional[str] = None
# Automatically constructed
dtype: ClassVar[str] = "int64"
Expand All @@ -770,7 +770,8 @@ class ClassLabel:
_int2str: ClassVar[Dict[int, int]] = None
_type: str = field(default="ClassLabel", init=False, repr=False)

def __post_init__(self):
def __post_init__(self, names_file):
self.names_file = names_file
if self.names_file is not None and self.names is not None:
raise ValueError("Please provide either names or names_file but not both.")
# Set self.names
Expand Down
15 changes: 15 additions & 0 deletions tests/features/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_cast_to_python_objects,
cast_to_python_objects,
encode_nested_example,
generate_from_dict,
string_to_arrow,
)
from datasets.info import DatasetInfo
Expand Down Expand Up @@ -269,6 +270,20 @@ def test_classlabel_int2str():
classlabel.int2str(len(names))


@pytest.mark.parametrize("class_label_arg", ["names", "names_file"])
def test_class_label_to_and_from_dict(class_label_arg, tmp_path_factory):
    """Round-trip a ClassLabel through ``asdict`` and ``generate_from_dict``.

    The label is built either from an in-memory list of names or from a
    names file written to a temporary directory, depending on the
    parametrized ``class_label_arg``. The regenerated feature must compare
    equal to the one it was serialized from.
    """
    names = ["negative", "positive"]

    # Always materialize the labels file, even for the "names" case, so both
    # parametrizations share the same setup.
    labels_path = tmp_path_factory.mktemp("features") / "labels.txt"
    labels_path.write_text("\n".join(names), encoding="utf-8")
    names_file = str(labels_path)

    if class_label_arg == "names":
        init_kwargs = {"names": names}
    elif class_label_arg == "names_file":
        init_kwargs = {"names_file": names_file}
    class_label = ClassLabel(**init_kwargs)

    as_dict = asdict(class_label)
    generated_class_label = generate_from_dict(as_dict)
    assert generated_class_label == class_label


def test_encode_nested_example_sequence_with_none():
schema = Sequence(Value("int32"))
obj = None
Expand Down

0 comments on commit b150e58

Please sign in to comment.