From d6326c88df1e36a9bb158bb96f870daf7962714b Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Mon, 8 Aug 2022 14:08:11 +0200 Subject: [PATCH] Add seed setting to image classification example (#18519) --- .../pytorch/image-classification/run_image_classification.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py index 2d26e42604da03..28000015ab173a 100644 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -43,6 +43,7 @@ HfArgumentParser, Trainer, TrainingArguments, + set_seed, ) from transformers.trainer_utils import get_last_checkpoint from transformers.utils import check_min_version, send_example_telemetry @@ -214,6 +215,9 @@ def main(): "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." ) + # Set seed before initializing model. + set_seed(training_args.seed) + # Initialize our dataset and prepare it for the 'image-classification' task. if data_args.dataset_name is not None: dataset = load_dataset(