diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py index 2d26e42604da03..28000015ab173a 100644 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -43,6 +43,7 @@ HfArgumentParser, Trainer, TrainingArguments, + set_seed, ) from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version, send_example_telemetry @@ -214,6 +215,9 @@ def main(): "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." ) + # Set seed before initializing model. + set_seed(training_args.seed) + # Initialize our dataset and prepare it for the 'image-classification' task. if data_args.dataset_name is not None: dataset = load_dataset(