Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document file formats in VarStore::save and load methods #829

Merged
merged 1 commit into from
Jan 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Document file formats in VarStore::save and load methods
  • Loading branch information
necrashter committed Dec 19, 2023
commit 2dcdac66d7151bcca2ef25587b71d2f4eba62405
10 changes: 10 additions & 0 deletions src/nn/var_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ impl VarStore {
///
/// Weight values for all the tensors currently stored in the
/// var-store are saved in the given file.
///
/// If the given path ends with the suffix `.safetensors`, the file will
/// be saved in safetensors format. Otherwise, libtorch C++ module format
/// will be used. Note that saving in pickle format (`.pt` extension) is
/// not supported by the C++ API of Torch.
pub fn save<T: AsRef<std::path::Path>>(&self, path: T) -> Result<(), TchError> {
let variables = self.variables_.lock().unwrap();
let named_tensors = variables.named_variables.iter().collect::<Vec<_>>();
Expand Down Expand Up @@ -216,6 +221,11 @@ impl VarStore {
/// var-store are loaded from the given file. Note that the set of
/// variables stored in the var-store is not changed, only the values
/// for these tensors are modified.
///
/// The format of the file is deduced from the file extension:
/// - `.safetensors`: The file is assumed to be in safetensors format.
/// - `.bin` or `.pt`: The file is assumed to be in pickle format.
/// - Otherwise, the file is assumed to be in libtorch C++ module format.
pub fn load<T: AsRef<std::path::Path>>(&mut self, path: T) -> Result<(), TchError> {
if self.device != Device::Mps {
self.load_internal(path)
Expand Down