Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update image_classification_with_vision_transformer.py #18740

Merged
merged 2 commits into from
Nov 7, 2023
Merged

Conversation

qlzh727
Copy link
Member

@qlzh727 qlzh727 commented Nov 7, 2023

Not sure if this is a bug in checkpoint logic. I got error like below:


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[10], line 41
     37     return history
     40 vit_classifier = create_vit_classifier()
---> 41 history = run_experiment(vit_classifier)
     44 def plot_history(item):
     45     plt.plot(history.history[item], label=item)

Cell In[10], line 23, in run_experiment(model)
     15 checkpoint_filepath = "/tmp/checkpoint"
     16 checkpoint_callback = keras.callbacks.ModelCheckpoint(
     17     checkpoint_filepath,
     18     monitor="val_accuracy",
     19     save_best_only=True,
     20     save_weights_only=True,
     21 )
---> 23 history = model.fit(
     24     x=x_train,
     25     y=y_train,
     26     batch_size=batch_size,
     27     epochs=num_epochs,
     28     validation_split=0.1,
     29     callbacks=[checkpoint_callback],
     30 )
     32 model.load_weights(checkpoint_filepath)
     33 _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)

File /opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    120     filtered_tb = _process_traceback_frames(e.__traceback__)
    121     # To get the full stack trace, call:
    122     # `keras.config.disable_traceback_filtering()`
--> 123     raise e.with_traceback(filtered_tb) from None
    124 finally:
    125     del filtered_tb

File /opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/models/model.py:373, in Model.save_weights(self, filepath, overwrite)
    363 """Saves all layer weights to a `.weights.h5` file.
    364 
    365 Args:
   (...)
    370         via an interactive prompt.
    371 """
    372 if not str(filepath).endswith(".weights.h5"):
--> 373     raise ValueError(
    374         "The filename must end in `.weights.h5`. "
    375         f"Received: filepath={filepath}"
    376     )
    377 try:
    378     exists = os.path.exists(filepath)

ValueError: The filename must end in `.weights.h5`. Received: filepath=/tmp/checkpoint

Not sure if this is a bug in checkpoint logic. I got error like below:

```

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[10], line 41
     37     return history
     40 vit_classifier = create_vit_classifier()
---> 41 history = run_experiment(vit_classifier)
     44 def plot_history(item):
     45     plt.plot(history.history[item], label=item)

Cell In[10], line 23, in run_experiment(model)
     15 checkpoint_filepath = "/tmp/checkpoint"
     16 checkpoint_callback = keras.callbacks.ModelCheckpoint(
     17     checkpoint_filepath,
     18     monitor="val_accuracy",
     19     save_best_only=True,
     20     save_weights_only=True,
     21 )
---> 23 history = model.fit(
     24     x=x_train,
     25     y=y_train,
     26     batch_size=batch_size,
     27     epochs=num_epochs,
     28     validation_split=0.1,
     29     callbacks=[checkpoint_callback],
     30 )
     32 model.load_weights(checkpoint_filepath)
     33 _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)

File /opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    120     filtered_tb = _process_traceback_frames(e.__traceback__)
    121     # To get the full stack trace, call:
    122     # `keras.config.disable_traceback_filtering()`
--> 123     raise e.with_traceback(filtered_tb) from None
    124 finally:
    125     del filtered_tb

File /opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/models/model.py:373, in Model.save_weights(self, filepath, overwrite)
    363 """Saves all layer weights to a `.weights.h5` file.
    364 
    365 Args:
   (...)
    370         via an interactive prompt.
    371 """
    372 if not str(filepath).endswith(".weights.h5"):
--> 373     raise ValueError(
    374         "The filename must end in `.weights.h5`. "
    375         f"Received: filepath={filepath}"
    376     )
    377 try:
    378     exists = os.path.exists(filepath)

ValueError: The filename must end in `.weights.h5`. Received: filepath=/tmp/checkpoint

```
@qlzh727
Copy link
Member Author

qlzh727 commented Nov 7, 2023

Neel, I think /tmp/checkpoint is a valid path for saving checkpoint. Maybe this is a corner case for ModelCheckpoint callback with save_weights_only=True

@codecov-commenter
Copy link

codecov-commenter commented Nov 7, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (b6df549) 78.89% compared to head (8a19814) 65.82%.

Additional details and impacted files
@@             Coverage Diff             @@
##           master   #18740       +/-   ##
===========================================
- Coverage   78.89%   65.82%   -13.08%     
===========================================
  Files         336      336               
  Lines       33937    33937               
  Branches     6651     6651               
===========================================
- Hits        26776    22338     -4438     
- Misses       5584    10153     +4569     
+ Partials     1577     1446      -131     
Flag Coverage Δ
keras 65.79% <ø> (-13.01%) ⬇️
keras-jax 61.76% <ø> (ø)
keras-numpy 56.26% <ø> (ø)
keras-tensorflow ?
keras-torch ?

Flags with carried forward coverage won't be shown. Click here to find out more.

see 78 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@nkovela1
Copy link
Contributor

nkovela1 commented Nov 7, 2023

Hi Scott,
This is actually intended behavior, since we want all weights files saved in Keras 3 to have the extension ".weights.h5".
The /tmp/checkpoint path was valid for checkpointing in legacy TF-Keras, but we are moving everyone to using this new format. Your latest commit LGTM!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Nov 7, 2023
@qlzh727 qlzh727 merged commit 0ef8a0f into master Nov 7, 2023
11 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase labels Nov 7, 2023
@sachinprasadhs sachinprasadhs deleted the qlzh727-patch-3 branch February 1, 2024 22:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants