Skip to content

Commit

Permalink
added use_tpu option
Browse files Browse the repository at this point in the history
  • Loading branch information
rishigami committed Jun 12, 2021
1 parent 8abcfc9 commit c3469d5
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions swintransformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def call(self, x):
return x


def SwinTransformer(model_name='swin_tiny_224', num_classes=1000, include_top=True, pretrained=True, cfgs=CFGS):
def SwinTransformer(model_name='swin_tiny_224', num_classes=1000, include_top=True, pretrained=True, use_tpu=False, cfgs=CFGS):
cfg = cfgs[model_name]
net = SwinTransformerModel(
model_name=model_name, include_top=include_top, num_classes=num_classes, img_size=cfg['input_size'], window_size=cfg[
Expand All @@ -442,6 +442,12 @@ def SwinTransformer(model_name='swin_tiny_224', num_classes=1000, include_top=Tr
if pretrained_ckpt:
if tf.io.gfile.isdir(pretrained_ckpt):
pretrained_ckpt = f'{pretrained_ckpt}/{model_name}.ckpt'
net.load_weights(pretrained_ckpt)

if use_tpu:
load_locally = tf.saved_model.LoadOptions(
experimental_io_device='/job:localhost')
net.load_weights(pretrained_ckpt, options=load_locally)
else:
net.load_weights(pretrained_ckpt)

return net

0 comments on commit c3469d5

Please sign in to comment.