Skip to content

Commit

Permalink
Merge pull request #4036 from kaushaladiti-2802:patch-1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 650732054
  • Loading branch information
Flax Authors committed Jul 9, 2024
2 parents 0da4f31 + 4910005 commit 36612a3
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ Setup

# Create some dummy variables for this example.
MAX_STEPS = 5
CKPT_PYTREE = [12, {'foo': 'str', 'bar': np.array((2, 3))}, [1, 4, 10]]
TARGET_PYTREE = [0, {'foo': '', 'bar': np.array((0))}, [0, 0, 0]]
CKPT_PYTREE = [12, {'bar': np.array((2, 3))}, [1, 4, 10]]
TARGET_PYTREE = [0, {'bar': np.array((0))}, [0, 0, 0]]

Most common use case: Saving/loading and managing checkpoints
*************************************************************
Expand Down Expand Up @@ -179,6 +179,33 @@ Then, you can call ``orbax.checkpoint.AsyncCheckpointer.wait_until_finished()``

For more details, read the `checkpoint guide <https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#asynchronized-checkpointing>`_.

You can also use Orbax AsyncCheckpointer with Flax APIs through async manager. Async manager internally calls wait_until_finished(). This solution is not actively maintained and the recommedation is to use Orbax async checkpointing.

For example:

.. codediff::
:title: flax.checkpoints, orbax.checkpoint
:skip_test: flax.checkpoints
:sync:

ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async'
flax.config.update('flax_use_orbax_checkpointing', True)
async_manager = checkpoints.AsyncManager()

checkpoints.save_checkpoint(ASYNC_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True, async_manager=async_manager)
checkpoints.restore_checkpoint(ASYNC_CKPT_DIR, target=TARGET_PYTREE)
---

ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async'

import orbax.checkpoint as ocp
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(ASYNC_CKPT_DIR, args=ocp.args.StandardSave(CKPT_PYTREE))
# ... Continue with your work...
# ... Until a time when you want to wait until the save completes:
ckptr.wait_until_finished() # Blocks until the checkpoint saving is completed.
ckptr.restore(ASYNC_CKPT_DIR, args=ocp.args.StandardRestore(TARGET_PYTREE))


Saving/loading a single JAX or NumPy Array
******************************************
Expand Down

0 comments on commit 36612a3

Please sign in to comment.