ENH: add a --speedup option and tests (#75)
* add --speedup option to jit-compile model

* fix typo in import alias

* fix mypy complaint

* cache output of load_model

* do not cache model because weights are not hashable

* use eval mode for model in load_model

* rename _jit_compile to jit_compile

* fix control flow in jit_compile

Test that the model works as we apply optimizations. If a certain step
fails, return the model from the previous (successful) step.

* add test_jit_compile

* add tests on pytorch nightly

* add ModelType type + use control flow for torch.compile

* cast model to Module if jit_compile

* using typing.cast instead of type.cast

* run regtests with and without --speedup
kaczmarj authored Jan 20, 2023
1 parent be9c991 commit ee4e669
Showing 5 changed files with 130 additions and 4 deletions.
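
The new --speedup flag is exercised in the regression tests below both with and without compilation. For orientation, here is a minimal sketch of driving the command the same way through click's test runner; the import path of the command, the --wsi-dir option name, and the model/weights names are assumptions for illustration, not taken from this diff.

# Hedged sketch: run the command with and without the new --speedup flag.
# The import path, the --wsi-dir option name, and the model/weights names
# below are assumed for illustration only.
from click.testing import CliRunner

from wsinfer.cli.infer import cli  # assumed location of the click command

runner = CliRunner()
for flag in ("--speedup", "--no-speedup"):
    result = runner.invoke(
        cli,
        [
            "--wsi-dir", "slides/",        # hypothetical input directory
            "--results-dir", "results/",   # hypothetical output directory
            "--model", "resnet34",         # placeholder model name
            "--weights", "TCGA-BRCA-v1",   # placeholder weights name
            flag,
        ],
    )
    assert result.exit_code == 0, result.output
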
19 changes: 19 additions & 0 deletions .github/workflows/ci.yml
@@ -29,6 +29,25 @@ jobs:
        run: python -m mypy wsinfer/
      - name: Run tests
        run: python -m pytest --verbose tests/
  test-pytorch-nightly:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.10
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"
      - name: Install the package
        run: |
          python -m pip install --upgrade pip setuptools wheel
          python -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu
          python -m pip install --editable .[dev] --find-links https://girder.github.io/large_image_wheels
      - name: Check style
        run: python -m flake8 wsinfer/
      - name: Check types
        run: python -m mypy wsinfer/
      - name: Run tests
        run: python -m pytest --verbose tests/
  test-docker:
    runs-on: ubuntu-latest
    steps:
50 changes: 46 additions & 4 deletions tests/test_all.py
@@ -7,14 +7,18 @@
from typing import List

from click.testing import CliRunner
import geojson as geojsoblib
import geojson as geojsonlib
import h5py
import numpy as np
import pandas as pd
import pytest
import tifffile
import torch
import yaml

from wsinfer import get_model_weights
from wsinfer import list_all_models_and_weights


@pytest.fixture
def tiff_image(tmp_path: Path) -> Path:
@@ -254,13 +258,15 @@ def test_cli_run_args(tmp_path: Path):
        ),
    ],
)
@pytest.mark.parametrize("speedup", [False, True])
def test_cli_run_regression(
    model: str,
    weights: str,
    class_names: List[str],
    expected_probs: List[float],
    expected_patch_size: int,
    expected_num_patches: int,
    speedup: bool,
    tiff_image: Path,
    tmp_path: Path,
):
@@ -281,6 +287,7 @@ def test_cli_run_regression(
            weights,
            "--results-dir",
            str(results_dir),
            "--speedup" if speedup else "--no-speedup",
        ],
    )
    assert result.exit_code == 0
@@ -324,7 +331,7 @@ def test_cli_run_regression(
    result = runner.invoke(cli, ["togeojson", str(results_dir), str(geojson_dir)])
    assert result.exit_code == 0
    with open(geojson_dir / "purple.json") as f:
        d: geojsoblib.GeoJSON = geojsoblib.load(f)
        d: geojsonlib.GeoJSON = geojsonlib.load(f)
    assert d.is_valid, "geojson not valid!"
    assert len(d["features"]) == expected_num_patches

@@ -888,10 +895,45 @@ def test_patch_cli(
    for x in range(0, orig_slide_width, expected_patch_size):
        for y in range(0, orig_slide_height, expected_patch_size):
            expected_coords.append([x, y])
    expected_coords = np.array(expected_coords)
    expected_coords_arr = np.array(expected_coords)

    with h5py.File(savedir / "patches" / f"{stem}.h5") as f:
        assert f["/coords"].attrs["patch_size"] == expected_patch_size
        coords = f["/coords"][()]
        assert coords.shape == (expected_num_patches, 2)
        assert np.array_equal(expected_coords, coords)
        assert np.array_equal(expected_coords_arr, coords)


@pytest.mark.parametrize(["model_name", "weights_name"], list_all_models_and_weights())
def test_jit_compile(model_name: str, weights_name: str):
    import time
    from wsinfer._modellib.run_inference import jit_compile

    w = get_model_weights(model_name, weights_name)
    size = w.transform.resize_size
    x = torch.ones(20, 3, size, size, dtype=torch.float32)
    model = w.load_model()
    model.eval()
    NUM_SAMPLES = 1
    with torch.no_grad():
        t0 = time.perf_counter()
        for _ in range(NUM_SAMPLES):
            out_nojit = model(x).detach().cpu()
        time_nojit = time.perf_counter() - t0
    model_nojit = model
    model = jit_compile(model)
    if model is model_nojit:
        pytest.skip("Failed to compile model (would use original model)")
    with torch.no_grad():
        model(x).detach().cpu()  # run it once to compile
        t0 = time.perf_counter()
        for _ in range(NUM_SAMPLES):
            out_jit = model(x).detach().cpu()
        time_yesjit = time.perf_counter() - t0

    assert torch.allclose(out_nojit, out_jit)
    if time_nojit < time_yesjit:
        pytest.skip(
            "JIT-compiled model was SLOWER than original: "
            f"jit={time_yesjit:0.3f} vs nojit={time_nojit:0.3f}"
        )
1 change: 1 addition & 0 deletions wsinfer/_modellib/models.py
@@ -217,6 +217,7 @@ def load_model(self) -> torch.nn.Module:
            raise RuntimeError("cannot find weights")

        model.load_state_dict(state_dict, strict=True)
        model.eval()
        return model

    def get_sha256_of_weights(self) -> str:
56 changes: 56 additions & 0 deletions wsinfer/_modellib/run_inference.py
@@ -164,12 +164,59 @@ def __getitem__(
        return patch_im, torch.as_tensor([minx, miny, width, height])


def jit_compile(
    model: torch.nn.Module,
) -> typing.Union[torch.jit.ScriptModule, torch.nn.Module, typing.Callable]:
    """JIT-compile a model for inference."""
    noncompiled = model
    w = "Warning: could not JIT compile the model. Using non-compiled model instead."
    # TODO: consider freezing the model as well.
    # PyTorch 2.x has torch.compile.
    if hasattr(torch, "compile"):
        # Try to get the most optimized model.
        try:
            return torch.compile(model, fullgraph=True, mode="max-autotune")
        except Exception:
            pass
        try:
            return torch.compile(model, mode="max-autotune")
        except Exception:
            pass
        try:
            return torch.compile(model)
        except Exception:
            warnings.warn(w)
            return noncompiled
    # For pytorch 1.x, use torch.jit.script.
    else:
        # Attempt to script. If it fails, return the original.
        test_input = torch.ones(1, 3, 224, 224)
        try:
            mjit = torch.jit.script(model)
            with torch.no_grad():
                mjit(test_input)
        except Exception:
            warnings.warn(w)
            return noncompiled
        # Now that we have scripted the model, try to optimize it further. If that
        # fails, return the scripted model.
        try:
            mjit_frozen = torch.jit.freeze(mjit)
            mjit_opt = torch.jit.optimize_for_inference(mjit_frozen)
            with torch.no_grad():
                mjit_opt(test_input)
            return mjit_opt
        except Exception:
            return mjit


def run_inference(
    wsi_dir: PathType,
    results_dir: PathType,
    weights: Weights,
    batch_size: int = 32,
    num_workers: int = 0,
    speedup: bool = False,
) -> None:
    """Run model inference on a directory of whole slide images and save results to CSV.
@@ -192,6 +239,9 @@
        The batch size during the forward pass (default is 32).
    num_workers : int
        Number of workers for data loading (default is 0, meaning use a single thread).
    speedup : bool
        If True, JIT-compile the model. This has a startup cost but model inference
        should be faster (default False).

    Returns
    -------
@@ -231,6 +281,12 @@
    model.eval()
    model.to(device)

    if speedup:
        if typing.TYPE_CHECKING:
            model = typing.cast(torch.nn.Module, jit_compile(model))
        else:
            model = jit_compile(model)

    # results_for_all_slides: typing.List[pd.DataFrame] = []
    for i, (wsi_path, patch_path) in enumerate(zip(wsi_paths, patch_paths)):
        print(f"Slide {i+1} of {len(wsi_paths)}")
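
For completeness, a hedged sketch of calling run_inference directly with the new speedup argument rather than through the command line; the model and weights names and the directory paths are placeholders, and it assumes patch coordinates for the slides have already been generated (the command normally performs that step first).

# Hedged sketch: pass speedup=True to run_inference directly.
# Model/weights names and paths are placeholders; patch coordinates are
# assumed to have been generated already for the slides in wsi_dir.
from wsinfer import get_model_weights
from wsinfer._modellib.run_inference import run_inference

weights = get_model_weights("resnet34", "TCGA-BRCA-v1")  # placeholder names
run_inference(
    wsi_dir="slides/",        # hypothetical directory of whole slide images
    results_dir="results/",   # hypothetical output directory
    weights=weights,
    batch_size=32,
    num_workers=0,
    speedup=True,             # JIT-compile the model via jit_compile()
)
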
8 changes: 8 additions & 0 deletions wsinfer/cli/infer.py
@@ -225,6 +225,12 @@ def get_stdout(args) -> str:
    help="Number of workers to use for data loading during model inference (default=0"
    " for single thread). A reasonable value is 8.",
)
@click.option(
    "--speedup/--no-speedup",
    default=False,
    show_default=True,
    help="JIT-compile the model for potential speedups.",
)
@click.option(
    "--dense-grid/--no-dense-grid",
    default=False,
@@ -242,6 +248,7 @@ def cli(
    config: typing.Optional[Path],
    batch_size: int,
    num_workers: int = 0,
    speedup: bool = False,
    dense_grid: bool = False,
):
    """Run model inference on a directory of whole slide images.
@@ -321,6 +328,7 @@
        weights=weights_obj,
        batch_size=batch_size,
        num_workers=num_workers,
        speedup=speedup,
    )

    run_metadata_outpath = results_dir / "run_metadata.json"
