From 0ef8a0f6719418c885b329d9ed129d06bda45c97 Mon Sep 17 00:00:00 2001 From: Qianli Scott Zhu Date: Tue, 7 Nov 2023 09:09:50 -0800 Subject: [PATCH] Update image_classification_with_vision_transformer.py (#18740) * Update image_classification_with_vision_transformer.py Not sure if this is a bug in checkpoint logic. I got error like below: ``` --------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[10], line 41 37 return history 40 vit_classifier = create_vit_classifier() ---> 41 history = run_experiment(vit_classifier) 44 def plot_history(item): 45 plt.plot(history.history[item], label=item) Cell In[10], line 23, in run_experiment(model) 15 checkpoint_filepath = "/tmp/checkpoint" 16 checkpoint_callback = keras.callbacks.ModelCheckpoint( 17 checkpoint_filepath, 18 monitor="val_accuracy", 19 save_best_only=True, 20 save_weights_only=True, 21 ) ---> 23 history = model.fit( 24 x=x_train, 25 y=y_train, 26 batch_size=batch_size, 27 epochs=num_epochs, 28 validation_split=0.1, 29 callbacks=[checkpoint_callback], 30 ) 32 model.load_weights(checkpoint_filepath) 33 _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test) File /opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback..error_handler(*args, **kwargs) 120 filtered_tb = _process_traceback_frames(e.__traceback__) 121 # To get the full stack trace, call: 122 # `keras.config.disable_traceback_filtering()` --> 123 raise e.with_traceback(filtered_tb) from None 124 finally: 125 del filtered_tb File /opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/models/model.py:373, in Model.save_weights(self, filepath, overwrite) 363 """Saves all layer weights to a `.weights.h5` file. 364 365 Args: (...) 370 via an interactive prompt. 371 """ 372 if not str(filepath).endswith(".weights.h5"): --> 373 raise ValueError( 374 "The filename must end in `.weights.h5`. " 375 f"Received: filepath={filepath}" 376 ) 377 try: 378 exists = os.path.exists(filepath) ValueError: The filename must end in `.weights.h5`. Received: filepath=/tmp/checkpoint ``` * Update image_classification_with_vision_transformer.py --- .../vision/image_classification_with_vision_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/keras_io/vision/image_classification_with_vision_transformer.py b/examples/keras_io/vision/image_classification_with_vision_transformer.py index 0804ea1b2c7..4e2928fb0a7 100644 --- a/examples/keras_io/vision/image_classification_with_vision_transformer.py +++ b/examples/keras_io/vision/image_classification_with_vision_transformer.py @@ -276,7 +276,7 @@ def run_experiment(model): ], ) - checkpoint_filepath = "/tmp/checkpoint" + checkpoint_filepath = "/tmp/checkpoint.weights.h5" checkpoint_callback = keras.callbacks.ModelCheckpoint( checkpoint_filepath, monitor="val_accuracy",