minor changes
LoicGrobol committed Sep 18, 2023
1 parent 413e2b6 commit 6244473
Showing 4 changed files with 10 additions and 10 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -13,6 +13,7 @@ classifiers = [
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
+"Programming Language :: Python :: 3.11",
"Environment :: Console",
]
keywords = ["nlp", "transformers", "language-model"]
13 changes: 6 additions & 7 deletions zeldarose/datasets/mbart.py
@@ -24,9 +24,7 @@
from loguru import logger
from torch.nn.utils.rnn import pad_sequence

-# New plan: a custom jsonlines reader that pre-splits into source/target with src and tgt
-# attributes, we load that into a dataset for the caching system, then in the dataloader we
-# sample, and somewhere in the trainmodule we add the noise. Use
+# Use
# <https://huggingface.co/docs/datasets/loading#python-generator> so that we can stream the input
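For context, a minimal sketch of the generator-based loading that the comment above points to, built on `datasets.Dataset.from_generator`; the generator function, its arguments, and the file name are illustrative assumptions, not zeldarose's actual loading code.

```python
# Hedged sketch: stream a jsonlines corpus into 🤗 datasets through a Python
# generator, as the linked documentation describes. Names and paths here are
# placeholders, not the module's real ones.
import json

import datasets


def gen_rows(path: str):
    # Yield one dict per jsonline so datasets can build (and cache) the
    # dataset without holding the whole corpus in memory.
    with open(path, encoding="utf-8") as in_stream:
        for line in in_stream:
            yield json.loads(line)


train_dataset = datasets.Dataset.from_generator(
    gen_rows, gen_kwargs={"path": "train.jsonl"}
)
```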


@@ -44,8 +42,9 @@ def extract_from_jsonline(
source_langs: Collection[str],
target_langs: Collection[str],
) -> Generator[DataRow, None, None]:
-# We deal with both top-level tranlatifrdgggggggggggggggggggggggggggggggggggggggggwons and 🤗's
-# conventional format for this task
+# We deal with both top-level (`{fr: "J'ai chanté", "br": "Me m'eus kanet."}`) and 🤗's
+# conventional format (`{"translation": {fr: "J'ai chanté", "br": "Me m'eus kanet."}}`) for this
+# task.
example = cast(Mapping[str, str], example.get("translation", example))
for dns_lang in denoise_langs:
if not (dns_str := example.get(dns_lang)):
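To make the two accepted record shapes concrete, here is a small illustrative check of the normalization performed by the `example.get("translation", example)` line above; the sample sentences are the ones from the comment.

```python
# Illustrative only: these two jsonlines records carry the same parallel text.
top_level = {"fr": "J'ai chanté", "br": "Me m'eus kanet."}
hf_style = {"translation": {"fr": "J'ai chanté", "br": "Me m'eus kanet."}}

for example in (top_level, hf_style):
    # Same normalization as above: fall back to the record itself when there
    # is no "translation" key, leaving a flat {lang: sentence} mapping.
    flat = example.get("translation", example)
    assert flat == {"fr": "J'ai chanté", "br": "Me m'eus kanet."}
```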
@@ -314,8 +313,8 @@ def __init__(
else:
self.val_dataset_path = None

-self.train_dataset = None
-self.val_dataset = None
+self.train_dataset: Optional[datasets.Dataset] = None
+self.val_dataset: Optional[datasets.Dataset] = None

def prepare_data(self):
# NOTE (2021-08-12): This shouldn't be needed since this method should only be called on
4 changes: 2 additions & 2 deletions zeldarose/datasets/transform.py
@@ -180,8 +180,8 @@ def __init__(
else:
self.val_dataset_path = None

-self.train_dataset = None
-self.val_dataset = None
+self.train_dataset: Optional[datasets.Dataset] = None
+self.val_dataset: Optional[datasets.Dataset] = None

def prepare_data(self):
# NOTE(2021-08-12): This shouldn't be needed since this method should only be called on rank
2 changes: 1 addition & 1 deletion zeldarose/train_transformer.py
@@ -418,7 +418,7 @@ def main(
f" loader batch size({device_batch_size} samples per device × {total_devices} devices)"
" try using fewer devices"
)
-elif tuning_config.batch_size % (device_batch_size * total_devices):
+elif tuning_config.batch_size % (device_batch_size * total_devices) != 0:
remainder = tuning_config.batch_size % device_batch_size * total_devices
logger.warning(
f"Batch size ({tuning_config.batch_size}) is not a multiple of loader batch size"
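For reference, a hedged sketch of the relationship this check guards; the variable names mirror those in `main`, but the numbers and the accumulation arithmetic are illustrative assumptions rather than an excerpt of the actual logic.

```python
# Hedged sketch: requested batch size vs. what the loader can deliver per
# optimizer step. All values are made up for illustration.
tuning_batch_size = 64   # stands in for tuning_config.batch_size
device_batch_size = 8    # samples per device and per step
total_devices = 3        # e.g. 3 GPUs

loader_batch_size = device_batch_size * total_devices  # 24 samples per step
if tuning_batch_size % loader_batch_size != 0:
    # This is the situation the warning above flags: 64 is not a multiple of
    # 24, so accumulating whole loader batches cannot reach exactly 64.
    accumulation_steps = tuning_batch_size // loader_batch_size      # 2
    effective_batch_size = accumulation_steps * loader_batch_size    # 48, not 64
```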
