From 0277ddff4706811db7681895910dd1c488571c8f Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Fri, 20 Jan 2023 22:58:39 -0500
Subject: [PATCH] place test_input on same device as model in jit_compile (#80)

---
 wsinfer/_modellib/run_inference.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/wsinfer/_modellib/run_inference.py b/wsinfer/_modellib/run_inference.py
index 631eecf..7d15cad 100644
--- a/wsinfer/_modellib/run_inference.py
+++ b/wsinfer/_modellib/run_inference.py
@@ -169,6 +169,9 @@ def jit_compile(
 ) -> typing.Union[torch.jit.ScriptModule, torch.nn.Module, typing.Callable]:
     """JIT-compile a model for inference."""
     noncompiled = model
+    device = next(model.parameters()).device
+    # Attempt to script. If it fails, return the original.
+    test_input = torch.ones(1, 3, 224, 224).to(device)
     w = "Warning: could not JIT compile the model. Using non-compiled model instead."
     # TODO: consider freezing the model as well.
     # PyTorch 2.x has torch.compile.
@@ -189,8 +192,6 @@ def jit_compile(
         return noncompiled
     # For pytorch 1.x, use torch.jit.script.
     else:
-        # Attempt to script. If it fails, return the original.
-        test_input = torch.ones(1, 3, 224, 224)
         try:
             mjit = torch.jit.script(model)
             with torch.no_grad():
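
For reference, below is a minimal sketch of the pattern this patch establishes: read the device from the model's parameters, build the trial input on that device, and fall back to the non-compiled model if scripting or the trial forward pass fails. The helper name jit_compile_sketch is hypothetical; like the patch, it assumes a model with at least one parameter and a 1x3x224x224 input.

    import torch

    def jit_compile_sketch(model: torch.nn.Module) -> torch.nn.Module:
        """Script `model` and smoke-test it with a dummy input on the model's device."""
        noncompiled = model
        # Assumes the model has at least one parameter; that parameter's device
        # tells us where the weights live, so the dummy input is placed there too.
        device = next(model.parameters()).device
        # 1x3x224x224 matches the input shape assumed in the patch (RGB, 224x224).
        test_input = torch.ones(1, 3, 224, 224).to(device)
        try:
            mjit = torch.jit.script(model)
            with torch.no_grad():
                mjit(test_input)  # trial forward pass; raises on device mismatch
            return mjit
        except Exception:
            print(
                "Warning: could not JIT compile the model. "
                "Using non-compiled model instead."
            )
            return noncompiled

Before the patch, test_input was always allocated on the CPU, so the trial forward pass would fail for a model whose weights had been moved to a GPU, and jit_compile would silently fall back to the non-compiled model; placing the input on the model's own device lets scripting succeed in that case.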