Skip to content

Commit

Permalink
Update image_classification_with_vision_transformer.py (#18740)
Browse files Browse the repository at this point in the history
* Update image_classification_with_vision_transformer.py

Not sure if this is a bug in checkpoint logic. I got error like below:

```

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[10], line 41
     37     return history
     40 vit_classifier = create_vit_classifier()
---> 41 history = run_experiment(vit_classifier)
     44 def plot_history(item):
     45     plt.plot(history.history[item], label=item)

Cell In[10], line 23, in run_experiment(model)
     15 checkpoint_filepath = "/tmp/checkpoint"
     16 checkpoint_callback = keras.callbacks.ModelCheckpoint(
     17     checkpoint_filepath,
     18     monitor="val_accuracy",
     19     save_best_only=True,
     20     save_weights_only=True,
     21 )
---> 23 history = model.fit(
     24     x=x_train,
     25     y=y_train,
     26     batch_size=batch_size,
     27     epochs=num_epochs,
     28     validation_split=0.1,
     29     callbacks=[checkpoint_callback],
     30 )
     32 model.load_weights(checkpoint_filepath)
     33 _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)

File /opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    120     filtered_tb = _process_traceback_frames(e.__traceback__)
    121     # To get the full stack trace, call:
    122     # `keras.config.disable_traceback_filtering()`
--> 123     raise e.with_traceback(filtered_tb) from None
    124 finally:
    125     del filtered_tb

File /opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/models/model.py:373, in Model.save_weights(self, filepath, overwrite)
    363 """Saves all layer weights to a `.weights.h5` file.
    364 
    365 Args:
   (...)
    370         via an interactive prompt.
    371 """
    372 if not str(filepath).endswith(".weights.h5"):
--> 373     raise ValueError(
    374         "The filename must end in `.weights.h5`. "
    375         f"Received: filepath={filepath}"
    376     )
    377 try:
    378     exists = os.path.exists(filepath)

ValueError: The filename must end in `.weights.h5`. Received: filepath=/tmp/checkpoint

```

* Update image_classification_with_vision_transformer.py
  • Loading branch information
qlzh727 committed Nov 7, 2023
1 parent b6df549 commit 0ef8a0f
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def run_experiment(model):
],
)

checkpoint_filepath = "/tmp/checkpoint"
checkpoint_filepath = "/tmp/checkpoint.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
Expand Down

0 comments on commit 0ef8a0f

Please sign in to comment.