Skip to content

Commit

Permalink
Fix task template reload from dict (huggingface#5106)
Browse files Browse the repository at this point in the history
fix task template reload from dict
  • Loading branch information
lhoestq committed Oct 13, 2022
1 parent dc4c764 commit 99680a7
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/datasets/tasks/audio_classificiation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import Audio, ClassLabel, Features
Expand All @@ -8,7 +8,7 @@

@dataclass(frozen=True)
class AudioClassification(TaskTemplate):
task: str = "audio-classification"
task: str = field(default="audio-classification", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"audio": Audio()})
label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
audio_column: str = "audio"
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import Audio, Features, Value
Expand All @@ -8,7 +8,7 @@

@dataclass(frozen=True)
class AutomaticSpeechRecognition(TaskTemplate):
task: str = "automatic-speech-recognition"
task: str = field(default="automatic-speech-recognition", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"audio": Audio()})
label_schema: ClassVar[Features] = Features({"transcription": Value("string")})
audio_column: str = "audio"
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/image_classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import ClassLabel, Features, Image
Expand All @@ -8,7 +8,7 @@

@dataclass(frozen=True)
class ImageClassification(TaskTemplate):
task: str = "image-classification"
task: str = field(default="image-classification", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"image": Image()})
label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
image_column: str = "image"
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/language_modeling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import Features, Value
Expand All @@ -7,7 +7,7 @@

@dataclass(frozen=True)
class LanguageModeling(TaskTemplate):
task: str = "language-modeling"
task: str = field(default="language-modeling", metadata={"include_in_asdict_even_if_is_default": True})

input_schema: ClassVar[Features] = Features({"text": Value("string")})
label_schema: ClassVar[Features] = Features({})
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/question_answering.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import Features, Sequence, Value
Expand All @@ -8,7 +8,7 @@
@dataclass(frozen=True)
class QuestionAnsweringExtractive(TaskTemplate):
# `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
task: str = "question-answering-extractive"
task: str = field(default="question-answering-extractive", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")})
label_schema: ClassVar[Features] = Features(
{
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/summarization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import Features, Value
Expand All @@ -8,7 +8,7 @@
@dataclass(frozen=True)
class Summarization(TaskTemplate):
# `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
task: str = "summarization"
task: str = field(default="summarization", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"text": Value("string")})
label_schema: ClassVar[Features] = Features({"summary": Value("string")})
text_column: str = "text"
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/text_classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import ClassLabel, Features, Value
Expand All @@ -9,7 +9,7 @@
@dataclass(frozen=True)
class TextClassification(TaskTemplate):
# `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
task: str = "text-classification"
task: str = field(default="text-classification", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"text": Value("string")})
label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
text_column: str = "text"
Expand Down
23 changes: 23 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from copy import deepcopy
from unittest.case import TestCase

import pytest

from datasets.arrow_dataset import Dataset
from datasets.features import Audio, ClassLabel, Features, Image, Sequence, Value
from datasets.info import DatasetInfo
Expand All @@ -12,7 +14,9 @@
QuestionAnsweringExtractive,
Summarization,
TextClassification,
task_template_from_dict,
)
from datasets.utils.py_utils import asdict


SAMPLE_QUESTION_ANSWERING_EXTRACTIVE = {
Expand All @@ -24,6 +28,25 @@
}


@pytest.mark.parametrize(
"task_cls",
[
AudioClassification,
AutomaticSpeechRecognition,
ImageClassification,
LanguageModeling,
QuestionAnsweringExtractive,
Summarization,
TextClassification,
],
)
def test_reload_task_from_dict(task_cls):
task = task_cls()
task_dict = asdict(task)
reloaded = task_template_from_dict(task_dict)
assert task == reloaded


class TestLanguageModeling:
def test_column_mapping(self):
task = LanguageModeling(text_column="input_text")
Expand Down

0 comments on commit 99680a7

Please sign in to comment.