Skip to content

Commit

Permalink
Add known networks
Browse files Browse the repository at this point in the history
  • Loading branch information
xhlulu committed Nov 6, 2021
1 parent 449279b commit 4a9c882
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
import torch

def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
release_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights"
known = {
name: f"{release_url}/{name}.pt"
for name in [
'face_paint_512_v0', 'face_paint_512_v2'
]
}

from model import Generator

device = torch.device(device)

model = Generator().to(device)

if type(pretrained) == str:
ckpt_url = pretrained
# Look if a known name is passed, otherwise assume it's a URL
ckpt_url = known.get(pretrained, pretrained)
pretrained = True
else:
ckpt_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights/face_paint_512_v2_0.pt"
ckpt_url = known.get('face_paint_512_v2')

if pretrained is True:
state_dict = torch.hub.load_state_dict_from_url(
Expand Down

0 comments on commit 4a9c882

Please sign in to comment.