Skip to content

Commit

Permalink
place test_input on same device as model in jit_compile (#80)
Browse files · Browse the repository at this point in the history
Branch information: authored by kaczmarj, Jan 21, 2023
1 parent 1158030 commit 0277ddf
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions wsinfer/_modellib/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def jit_compile(
) -> typing.Union[torch.jit.ScriptModule, torch.nn.Module, typing.Callable]:
"""JIT-compile a model for inference."""
noncompiled = model
device = next(model.parameters()).device
# Attempt to script. If it fails, return the original.
test_input = torch.ones(1, 3, 224, 224).to(device)
w = "Warning: could not JIT compile the model. Using non-compiled model instead."
# TODO: consider freezing the model as well.
# PyTorch 2.x has torch.compile.
Expand All @@ -189,8 +192,6 @@ def jit_compile(
return noncompiled
# For pytorch 1.x, use torch.jit.script.
else:
# Attempt to script. If it fails, return the original.
test_input = torch.ones(1, 3, 224, 224)
try:
mjit = torch.jit.script(model)
with torch.no_grad():
Expand Down

0 comments on commit 0277ddf

Please sign in to comment.