Skip to content

Commit

Permalink
Update device
Browse files Browse the repository at this point in the history
  • Loading branch information
xhlulu committed Nov 6, 2021
1 parent b8e3a27 commit 449279b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
from model import Generator

model = Generator()
device = torch.device(device)

model = Generator().to(device)

if type(pretrained) == str:
ckpt_url = pretrained
Expand All @@ -14,7 +16,7 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
if pretrained is True:
state_dict = torch.hub.load_state_dict_from_url(
ckpt_url,
map_location=torch.device(device),
map_location=device,
progress=progress,
check_hash=check_hash,
)
Expand Down

0 comments on commit 449279b

Please sign in to comment.