-
Notifications
You must be signed in to change notification settings - Fork 19.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update image_classification_with_vision_transformer.py #18740
Conversation
Not sure if this is a bug in checkpoint logic. I got error like below: ``` --------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[10], line 41 37 return history 40 vit_classifier = create_vit_classifier() ---> 41 history = run_experiment(vit_classifier) 44 def plot_history(item): 45 plt.plot(history.history[item], label=item) Cell In[10], line 23, in run_experiment(model) 15 checkpoint_filepath = "/tmp/checkpoint" 16 checkpoint_callback = keras.callbacks.ModelCheckpoint( 17 checkpoint_filepath, 18 monitor="val_accuracy", 19 save_best_only=True, 20 save_weights_only=True, 21 ) ---> 23 history = model.fit( 24 x=x_train, 25 y=y_train, 26 batch_size=batch_size, 27 epochs=num_epochs, 28 validation_split=0.1, 29 callbacks=[checkpoint_callback], 30 ) 32 model.load_weights(checkpoint_filepath) 33 _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test) File /opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs) 120 filtered_tb = _process_traceback_frames(e.__traceback__) 121 # To get the full stack trace, call: 122 # `keras.config.disable_traceback_filtering()` --> 123 raise e.with_traceback(filtered_tb) from None 124 finally: 125 del filtered_tb File /opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/models/model.py:373, in Model.save_weights(self, filepath, overwrite) 363 """Saves all layer weights to a `.weights.h5` file. 364 365 Args: (...) 370 via an interactive prompt. 371 """ 372 if not str(filepath).endswith(".weights.h5"): --> 373 raise ValueError( 374 "The filename must end in `.weights.h5`. " 375 f"Received: filepath={filepath}" 376 ) 377 try: 378 exists = os.path.exists(filepath) ValueError: The filename must end in `.weights.h5`. Received: filepath=/tmp/checkpoint ```
Neel, I think |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #18740 +/- ##
===========================================
- Coverage 78.89% 65.82% -13.08%
===========================================
Files 336 336
Lines 33937 33937
Branches 6651 6651
===========================================
- Hits 26776 22338 -4438
- Misses 5584 10153 +4569
+ Partials 1577 1446 -131
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
Hi Scott, |
Not sure if this is a bug in checkpoint logic. I got error like below: