Skip to content

Commit

Permalink
[traner] fix --lr_scheduler_type choices (huggingface#9800)
Browse files Browse the repository at this point in the history
* fix --lr_scheduler_type choices

* rewrite to fix for all enum-based cl args

* cleanup

* adjust test

* style

* Proposal that should work

* Remove needless code

* Fix test

Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
  • Loading branch information
stas00 and sgugger authored Jan 27, 2021
1 parent 893120f commit 7c6d632
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
8 changes: 4 additions & 4 deletions src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
field.type = prim_type

if isinstance(field.type, type) and issubclass(field.type, Enum):
kwargs["choices"] = list(field.type)
kwargs["type"] = field.type
kwargs["choices"] = [x.value for x in field.type]
kwargs["type"] = type(kwargs["choices"][0])
if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default
elif field.type is bool or field.type is Optional[bool]:
Expand Down Expand Up @@ -198,7 +198,7 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
data = json.loads(Path(json_file).read_text())
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype)}
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in data.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
Expand All @@ -211,7 +211,7 @@ def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
"""
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype)}
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in args.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
Expand Down
15 changes: 11 additions & 4 deletions tests/test_hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ class BasicEnum(Enum):

@dataclass
class EnumExample:
foo: BasicEnum = BasicEnum.toto
foo: BasicEnum = "toto"

def __post_init__(self):
self.foo = BasicEnum(self.foo)


@dataclass
Expand Down Expand Up @@ -133,14 +136,18 @@ def test_with_enum(self):
parser = HfArgumentParser(EnumExample)

expected = argparse.ArgumentParser()
expected.add_argument("--foo", default=BasicEnum.toto, choices=list(BasicEnum), type=BasicEnum)
expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str)
self.argparsersEqual(parser, expected)

args = parser.parse_args([])
self.assertEqual(args.foo, BasicEnum.toto)
self.assertEqual(args.foo, "toto")
enum_ex = parser.parse_args_into_dataclasses([])[0]
self.assertEqual(enum_ex.foo, BasicEnum.toto)

args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, BasicEnum.titi)
self.assertEqual(args.foo, "titi")
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
self.assertEqual(enum_ex.foo, BasicEnum.titi)

def test_with_list(self):
parser = HfArgumentParser(ListExample)
Expand Down

0 comments on commit 7c6d632

Please sign in to comment.