utils.py
import torch
from omegaconf.dictconfig import DictConfig


def recursively_cast_dictconfigs(cfg):
    """Recursively convert an OmegaConf DictConfig into a plain Python dict."""
    if isinstance(cfg, DictConfig):
        return {k: recursively_cast_dictconfigs(v) for k, v in cfg.items()}
    else:
        return cfg


def torch_load_cpu(path):
    """Load a checkpoint onto the CPU, casting its stored config to plain dicts."""
    state = torch.load(path, map_location=torch.device("cpu"))
    if not isinstance(state, dict):
        return state
    if "cfg" in state:
        state["cfg"] = recursively_cast_dictconfigs(state["cfg"])
        # If the model was trained with fp16, cast the loaded state_dict
        # back to fp16 so it matches the training precision.
        if (
            state["cfg"]["common"]["fp16"]
            or state["cfg"]["common"]["memory_efficient_fp16"]
        ):
            state["model"] = {k: v.half() for k, v in state["model"].items()}
    return state


def load_and_pop_last_optimizer_state(pth):
    """Load a checkpoint and drop its optimizer state to reduce memory use."""
    st = torch_load_cpu(pth)
    st.pop("last_optimizer_state", None)
    return st
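
A minimal usage sketch of these helpers. The checkpoint path is a hypothetical example, and `OmegaConf.create` is used here only to build a small test config; actual checkpoints would come from training.

    from omegaconf import OmegaConf

    # Casting a nested DictConfig to plain dicts:
    cfg = OmegaConf.create({"common": {"fp16": True}})
    assert recursively_cast_dictconfigs(cfg) == {"common": {"fp16": True}}

    # Loading a checkpoint for inference without its optimizer state:
    state = load_and_pop_last_optimizer_state("checkpoints/checkpoint_last.pt")  # hypothetical path
    weights = state["model"]  # fp16 tensors if the run was trained with fp16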