
ENH: add a --speedup option and tests #75

Merged
merged 15 commits into from Jan 20, 2023
12 changes: 8 additions & 4 deletions wsinfer/_modellib/run_inference.py
@@ -28,7 +28,6 @@
 from .models import Weights

 PathType = typing.Union[str, Path]
-ModelType = typing.Union[torch.jit.ScriptModule, torch.nn.Module, typing.Callable]


 class WholeSlideImageDirectoryNotFound(FileNotFoundError):
@@ -165,7 +164,9 @@ def __getitem__(
         return patch_im, torch.as_tensor([minx, miny, width, height])


-def jit_compile(model: torch.nn.Module) -> ModelType:
+def jit_compile(
+    model: torch.nn.Module,
+) -> typing.Union[torch.jit.ScriptModule, torch.nn.Module, typing.Callable]:
     """JIT-compile a model for inference."""
     noncompiled = model
     w = "Warning: could not JIT compile the model. Using non-compiled model instead."
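Note: the body of `jit_compile` is collapsed in this diff. For context, here is a minimal sketch of a fallback-style implementation consistent with the signature, docstring, and warning string shown above; the `torch.jit.script` call is an assumption about the collapsed body, not necessarily what wsinfer actually does:

```python
import typing

import torch


def jit_compile(
    model: torch.nn.Module,
) -> typing.Union[torch.jit.ScriptModule, torch.nn.Module, typing.Callable]:
    """JIT-compile a model for inference."""
    noncompiled = model
    w = "Warning: could not JIT compile the model. Using non-compiled model instead."
    try:
        # Assumption: the collapsed body attempts TorchScript compilation.
        return torch.jit.script(model)
    except Exception:
        # On any compilation failure, warn and fall back to the original model.
        print(w, flush=True)
        return noncompiled
```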
@@ -276,12 +277,15 @@ def run_inference(
     model_output_dir.mkdir(exist_ok=True)

     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model: ModelType = weights.load_model()
+    model = weights.load_model()
     model.eval()
     model.to(device)

     if speedup:
-        model = jit_compile(model)
+        if typing.TYPE_CHECKING:
+            model = typing.cast(torch.nn.Module, jit_compile(model))
+        else:
+            model = jit_compile(model)

     # results_for_all_slides: typing.List[pd.DataFrame] = []
     for i, (wsi_path, patch_path) in enumerate(zip(wsi_paths, patch_paths)):
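A note on the `typing.TYPE_CHECKING` branch added above: that constant is `False` at runtime, so execution always takes the `else` path and calls `jit_compile` directly. Static checkers such as mypy treat `TYPE_CHECKING` as `True`, so they analyze only the first branch, where `typing.cast(torch.nn.Module, ...)` (a runtime no-op) narrows the wide `Union` return type back to `torch.nn.Module` for later attribute access on `model`. A minimal self-contained sketch of the same pattern; the `maybe_compile` name is illustrative, not from wsinfer:

```python
import typing

import torch


def maybe_compile(
    model: torch.nn.Module,
) -> typing.Union[torch.jit.ScriptModule, torch.nn.Module, typing.Callable]:
    # Illustrative stand-in for jit_compile's wide return type.
    return model


model: torch.nn.Module = torch.nn.Linear(4, 2)
if typing.TYPE_CHECKING:
    # Only type checkers see this branch; typing.cast is a no-op at runtime.
    model = typing.cast(torch.nn.Module, maybe_compile(model))
else:
    # This branch is the one that actually executes.
    model = maybe_compile(model)

print(model(torch.zeros(1, 4)).shape)  # torch.Size([1, 2])
```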