From ee4e669d32c59872c4f02d1941d48579817f7eab Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Fri, 20 Jan 2023 15:18:03 -0500
Subject: [PATCH] ENH: add a --speedup option and tests (#75)

* add --speedup option to jit-compile model

* fix typo in import alias

* fix mypy complaint

* cache output of load_model

* do not cache model because weights are not hashable

* use eval mode for model in load_model

* rename _jit_compile to jit_compile

* fix control flow in jit_compile

Test that the model works as we apply optimizations. If a certain step
fails, return the model from the previous (successful) step.

* add test_jit_compile

* add tests on pytorch nightly

* add ModelType type + use control flow for torch.compile

* cast model to Module if jit_compile

* using typing.cast instead of type.cast

* run regtests with and without --speedup
---
 .github/workflows/ci.yml           | 19 ++++++++++
 tests/test_all.py                  | 50 +++++++++++++++++++++++---
 wsinfer/_modellib/models.py        |  1 +
 wsinfer/_modellib/run_inference.py | 56 ++++++++++++++++++++++++++++++
 wsinfer/cli/infer.py               |  8 +++++
 5 files changed, 130 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e8891b9..37bb587 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -29,6 +29,25 @@ jobs:
         run: python -m mypy wsinfer/
       - name: Run tests
         run: python -m pytest --verbose tests/
+  test-pytorch-nightly:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - name: Install the package
+        run: |
+          python -m pip install --upgrade pip setuptools wheel
+          python -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu
+          python -m pip install --editable .[dev] --find-links https://girder.github.io/large_image_wheels
+      - name: Check style
+        run: python -m flake8 wsinfer/
+      - name: Check types
+        run: python -m mypy wsinfer/
+      - name: Run tests
+        run: python -m pytest --verbose tests/
   test-docker:
     runs-on: ubuntu-latest
     steps:
diff --git a/tests/test_all.py b/tests/test_all.py
index 1975cfd..deab1f0 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -7,14 +7,18 @@
 from typing import List

 from click.testing import CliRunner
-import geojson as geojsoblib
+import geojson as geojsonlib
 import h5py
 import numpy as np
 import pandas as pd
 import pytest
 import tifffile
+import torch
 import yaml

+from wsinfer import get_model_weights
+from wsinfer import list_all_models_and_weights
+

 @pytest.fixture
 def tiff_image(tmp_path: Path) -> Path:
@@ -254,6 +258,7 @@ def test_cli_run_args(tmp_path: Path):
         ),
     ],
 )
+@pytest.mark.parametrize("speedup", [False, True])
 def test_cli_run_regression(
     model: str,
     weights: str,
@@ -261,6 +266,7 @@ def test_cli_run_regression(
     expected_probs: List[float],
     expected_patch_size: int,
     expected_num_patches: int,
+    speedup: bool,
     tiff_image: Path,
     tmp_path: Path,
 ):
@@ -281,6 +287,7 @@ def test_cli_run_regression(
             weights,
             "--results-dir",
             str(results_dir),
+            "--speedup" if speedup else "--no-speedup",
         ],
     )
     assert result.exit_code == 0
@@ -324,7 +331,7 @@ def test_cli_run_regression(
     result = runner.invoke(cli, ["togeojson", str(results_dir), str(geojson_dir)])
     assert result.exit_code == 0
     with open(geojson_dir / "purple.json") as f:
-        d: geojsoblib.GeoJSON = geojsoblib.load(f)
+        d: geojsonlib.GeoJSON = geojsonlib.load(f)
     assert d.is_valid, "geojson not valid!"
     assert len(d["features"]) == expected_num_patches

@@ -888,10 +895,45 @@ def test_patch_cli(
     for x in range(0, orig_slide_width, expected_patch_size):
         for y in range(0, orig_slide_height, expected_patch_size):
             expected_coords.append([x, y])
-    expected_coords = np.array(expected_coords)
+    expected_coords_arr = np.array(expected_coords)

     with h5py.File(savedir / "patches" / f"{stem}.h5") as f:
         assert f["/coords"].attrs["patch_size"] == expected_patch_size
         coords = f["/coords"][()]
         assert coords.shape == (expected_num_patches, 2)
-        assert np.array_equal(expected_coords, coords)
+        assert np.array_equal(expected_coords_arr, coords)
+
+
+@pytest.mark.parametrize(["model_name", "weights_name"], list_all_models_and_weights())
+def test_jit_compile(model_name: str, weights_name: str):
+    import time
+    from wsinfer._modellib.run_inference import jit_compile
+
+    w = get_model_weights(model_name, weights_name)
+    size = w.transform.resize_size
+    x = torch.ones(20, 3, size, size, dtype=torch.float32)
+    model = w.load_model()
+    model.eval()
+    NUM_SAMPLES = 1
+    with torch.no_grad():
+        t0 = time.perf_counter()
+        for _ in range(NUM_SAMPLES):
+            out_nojit = model(x).detach().cpu()
+        time_nojit = time.perf_counter() - t0
+    model_nojit = model
+    model = jit_compile(model)
+    if model is model_nojit:
+        pytest.skip("Failed to compile model (would use original model)")
+    with torch.no_grad():
+        model(x).detach().cpu()  # run it once to compile
+        t0 = time.perf_counter()
+        for _ in range(NUM_SAMPLES):
+            out_jit = model(x).detach().cpu()
+        time_yesjit = time.perf_counter() - t0
+
+    assert torch.allclose(out_nojit, out_jit)
+    if time_nojit < time_yesjit:
+        pytest.skip(
+            "JIT-compiled model was SLOWER than original: "
+            f"jit={time_yesjit:0.3f} vs nojit={time_nojit:0.3f}"
+        )
diff --git a/wsinfer/_modellib/models.py b/wsinfer/_modellib/models.py
index f2692bc..3209d6f 100644
--- a/wsinfer/_modellib/models.py
+++ b/wsinfer/_modellib/models.py
@@ -217,6 +217,7 @@ def load_model(self) -> torch.nn.Module:
             raise RuntimeError("cannot find weights")

         model.load_state_dict(state_dict, strict=True)
+        model.eval()
         return model

     def get_sha256_of_weights(self) -> str:
diff --git a/wsinfer/_modellib/run_inference.py b/wsinfer/_modellib/run_inference.py
index 3cec0e5..631eecf 100644
--- a/wsinfer/_modellib/run_inference.py
+++ b/wsinfer/_modellib/run_inference.py
@@ -164,12 +164,59 @@ def __getitem__(
         return patch_im, torch.as_tensor([minx, miny, width, height])


+def jit_compile(
+    model: torch.nn.Module,
+) -> typing.Union[torch.jit.ScriptModule, torch.nn.Module, typing.Callable]:
+    """JIT-compile a model for inference."""
+    noncompiled = model
+    w = "Warning: could not JIT compile the model. Using non-compiled model instead."
+    # TODO: consider freezing the model as well.
+    # PyTorch 2.x has torch.compile.
+    if hasattr(torch, "compile"):
+        # Try to get the most optimized model.
+        try:
+            return torch.compile(model, fullgraph=True, mode="max-autotune")
+        except Exception:
+            pass
+        try:
+            return torch.compile(model, mode="max-autotune")
+        except Exception:
+            pass
+        try:
+            return torch.compile(model)
+        except Exception:
+            warnings.warn(w)
+            return noncompiled
+    # For pytorch 1.x, use torch.jit.script.
+    else:
+        # Attempt to script. If it fails, return the original.
+        test_input = torch.ones(1, 3, 224, 224)
+        try:
+            mjit = torch.jit.script(model)
+            with torch.no_grad():
+                mjit(test_input)
+        except Exception:
+            warnings.warn(w)
+            return noncompiled
+        # Now that we have scripted the model, try to optimize it further. If that
If that + # fails, return the scripted model. + try: + mjit_frozen = torch.jit.freeze(mjit) + mjit_opt = torch.jit.optimize_for_inference(mjit_frozen) + with torch.no_grad(): + mjit_opt(test_input) + return mjit_opt + except Exception: + return mjit + + def run_inference( wsi_dir: PathType, results_dir: PathType, weights: Weights, batch_size: int = 32, num_workers: int = 0, + speedup: bool = False, ) -> None: """Run model inference on a directory of whole slide images and save results to CSV. @@ -192,6 +239,9 @@ def run_inference( The batch size during the forward pass (default is 32). num_workers : int Number of workers for data loading (default is 0, meaning use a single thread). + speedup : bool + If True, JIT-compile the model. This has a startup cost but model inference + should be faster (default False). Returns ------- @@ -231,6 +281,12 @@ def run_inference( model.eval() model.to(device) + if speedup: + if typing.TYPE_CHECKING: + model = typing.cast(torch.nn.Module, jit_compile(model)) + else: + model = jit_compile(model) + # results_for_all_slides: typing.List[pd.DataFrame] = [] for i, (wsi_path, patch_path) in enumerate(zip(wsi_paths, patch_paths)): print(f"Slide {i+1} of {len(wsi_paths)}") diff --git a/wsinfer/cli/infer.py b/wsinfer/cli/infer.py index ee3df1a..a069429 100644 --- a/wsinfer/cli/infer.py +++ b/wsinfer/cli/infer.py @@ -225,6 +225,12 @@ def get_stdout(args) -> str: help="Number of workers to use for data loading during model inference (default=0" " for single thread). A reasonable value is 8.", ) +@click.option( + "--speedup/--no-speedup", + default=False, + show_default=True, + help="JIT-compile the model for potential speedups.", +) @click.option( "--dense-grid/--no-dense-grid", default=False, @@ -242,6 +248,7 @@ def cli( config: typing.Optional[Path], batch_size: int, num_workers: int = 0, + speedup: bool = False, dense_grid: bool = False, ): """Run model inference on a directory of whole slide images. @@ -321,6 +328,7 @@ def cli( weights=weights_obj, batch_size=batch_size, num_workers=num_workers, + speedup=speedup, ) run_metadata_outpath = results_dir / "run_metadata.json"