Drop torch 1.8.2 support (#1736)
vshampor authored and KodiaqQ committed May 2, 2023
1 parent 0d9ad7b commit 97b1c42
Showing 9 changed files with 6 additions and 37 deletions.
@@ -114,7 +114,7 @@ def prepare_validation(model: YOLO, args: Any) -> Tuple[Validator, torch.utils.d
def benchmark_performance(model_path, config) -> float:
command = f"benchmark_app -m {model_path} -d CPU -api async -t 30"
command += f' -shape "[1,3,{config.imgsz},{config.imgsz}]"'
-cmd_output = subprocess.check_output(command, shell=True) # nosec
+cmd_output = subprocess.check_output(command, shell=True) # nosec

match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output))
return float(match.group(1))
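
Note (editorial, not part of the commit): re.search returns None when the benchmark_app output contains no throughput line, so match.group(1) would raise AttributeError on a failed run. A minimal, more defensive variant of the parsing step could look like the sketch below (parse_throughput is a hypothetical helper, and the "Throughput: <value> FPS" output format is assumed to stay exactly as the regex above expects):

import re

def parse_throughput(cmd_output: bytes) -> float:
    # Guard against benchmark_app runs that print no throughput line at all.
    match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output))
    if match is None:
        raise RuntimeError("benchmark_app output did not contain a throughput value")
    return float(match.group(1))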
2 changes: 1 addition & 1 deletion examples/torch/common/model_loader.py
@@ -110,6 +110,6 @@ def download_checkpoint(url):
if not download_path.exists():
print("Downloading checkpoint ...")
checkpoint = requests.get(url)
-with open(download_path, 'wb') as f:
+with open(download_path, "wb") as f:
f.write(checkpoint.content)
return str(download_path)
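
Note (editorial, not part of the commit): requests.get as used above loads the whole checkpoint into memory and does not check the HTTP status code. A sketch of a more defensive variant (download_checkpoint_safely is a hypothetical helper, not NNCF API; url and download_path keep the same meaning as in the function above):

import requests

def download_checkpoint_safely(url, download_path):
    # Stream the payload to disk and fail loudly on HTTP errors instead of
    # silently writing an error page into the checkpoint file.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(download_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return str(download_path)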
2 changes: 1 addition & 1 deletion examples/torch/requirements.txt
@@ -6,5 +6,5 @@ defusedxml>=0.7.0rc1
mlflow>=1.12.1
returns>0.14
opencv-python>=4.4.0.46
-torchvision>=0.9.2,<0.15 # the minor version should always match the torch minor version that is installed via NNCF's `pip install nncf[torch]`; TV minor version is torch minor version +1
+torchvision>=0.10.0,<0.15 # the minor version should always match the torch minor version that is installed via NNCF's `pip install nncf[torch]`; TV minor version is torch minor version +1
efficientnet_pytorch
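
Note (editorial, not part of the commit): the comment on the torchvision pin above means, for example, that torch 1.9.x pairs with torchvision 0.10.x and torch 1.13.x with torchvision 0.14.x. A minimal sanity check for an installed environment, assuming the torch 1.x / torchvision 0.x numbering scheme, could be:

import torch
import torchvision

# torch 1.x is expected to pair with torchvision 0.(x+1), e.g. torch 1.13 <-> torchvision 0.14.
torch_minor = int(torch.__version__.split(".")[1])
torchvision_minor = int(torchvision.__version__.split(".")[1])
assert torchvision_minor == torch_minor + 1, (
    f"torchvision {torchvision.__version__} does not pair with torch {torch.__version__}"
)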
2 changes: 1 addition & 1 deletion nncf/experimental/common/graph/netron.py
@@ -11,7 +11,7 @@
limitations under the License.
"""
# Since we are not reading XML, but creating it, the package security message is irrelevant
-import xml.etree.ElementTree as ET # nosec
+import xml.etree.ElementTree as ET # nosec
from typing import Callable, Dict, List, Optional, Tuple

from nncf.common.graph import NNCFGraph
15 changes: 1 addition & 14 deletions nncf/torch/binarization/binarize_functions.py
@@ -14,26 +14,13 @@
from typing import Any

import torch
-from torch import _C # pylint:disable=protected-access
+from torch.onnx.symbolic_helper import _is_constant # pylint:disable=protected-access

from nncf.common.logging import nncf_logger
from nncf.torch.binarization.extensions import BinarizedFunctionsCUDA
from nncf.torch.utils import add_domain


-def _is_value(x: Any) -> bool:
-return isinstance(x, _C.Value)


-# Implementation is copy-pasted from torch.onnx.symbolic_helper.
-# It's need to support torch < 1.9, since there's no such function in such versions of torch.
-def _is_constant(value: Any) -> bool:
-return not _is_value(value) or value.node().kind() in {
-"onnx::Constant",
-"prim::Constant",
-}
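
Note (editorial, not part of the commit): the _is_value/_is_constant copies removed above existed only to support torch < 1.9, as their own comment states; with torch >= 1.9.1 as the new floor, _is_constant can be imported from torch.onnx.symbolic_helper directly, which is what the new import at the top of this hunk does.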


def _unsqueeze_helper(g, input_, axes_i):
# Unsqueeze handling for different opsets inspired by torch.onnx.symbolic_helper._unsqueeze_helper
# The original unsqueeze_helper cannot be used in 1.13 since it references
2 changes: 1 addition & 1 deletion setup.py
@@ -133,7 +133,7 @@ def find_version(*file_paths):
]

TORCH_EXTRAS = [
"torch>=1.8.2,<1.14",
"torch>=1.9.1,<1.14",
]
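
Note (editorial, not part of the commit): with this change, pip install nncf[torch] resolves a torch version from 1.9.1 up to, but not including, 1.14, which is what makes the torch-version-gated test skips removed in the test files below unnecessary.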

ONNX_EXTRAS = ["onnx~=1.13.1", "onnxruntime~=1.14.1"]
@@ -18,7 +18,6 @@
import numpy as np
import pytest
import torch
-from pkg_resources import parse_version
from torch import nn
from transformers import AutoModelForAudioClassification
from transformers import AutoModelForImageClassification
@@ -123,10 +122,6 @@ def fixture_transformer_search_params_desc(request):
return request.param


-@pytest.mark.skipif(
-parse_version(torch.__version__) < parse_version("1.9"),
-reason=f"torch {torch.__version__} is not compatible with installed transformers package",
-)
def test_transformer_building_blocks(desc: TransformerSearchBBlockParamsCase):
model = desc.model_creator()
move_model_to_cuda_if_available(model)
4 changes: 0 additions & 4 deletions tests/torch/nas/test_elastic_depth.py
@@ -485,10 +485,6 @@ def forward(self, x):
return x * self.dummy + x


-@pytest.mark.skipif(
-parse_version(torch.__version__) < parse_version("1.9"),
-reason="Test uses torch.permute attribute, which is not presented in the current torch version",
-)
@pytest.mark.parametrize("model_creator", (TwoPermute, ChunkConcat, TwoBranchesBeforeInput, TwoBranchesAfterInput))
def test_can_skip_trivial_block(model_creator):
model = model_creator()
9 changes: 0 additions & 9 deletions tests/torch/nas/test_sanity_sample.py
@@ -15,8 +15,6 @@
from typing import Dict

import pytest
-import torch
-from pkg_resources import parse_version

from nncf.common.initialization.batchnorm_adaptation import BatchnormAdaptationAlgorithm
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elastic_depth import ElasticDepthHandler
@@ -101,13 +99,6 @@ def fixture_nas_desc(request, dataset_dir):


def test_e2e_supernet_training(nas_desc: NASSampleTestDescriptor, tmp_path, mocker):
-if parse_version(torch.__version__) < parse_version("1.9") and (
-"efficient_net" in nas_desc.config_name_ or "mobilenet_v3" in nas_desc.config_name_
-):
-pytest.skip(
-f"Test exports model with hardsigmoid operator to ONNX opset version 13.\n"
-f"It is not supported in the current torch version: {torch.__version__}"
-)
validator = nas_desc.get_validator()
args = validator.get_default_args(tmp_path)
validator.validate_sample(args, mocker)
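
Note (editorial, not part of the commit): every skip removed in the three test files above gated on parse_version(torch.__version__) < parse_version("1.9"), a condition that can no longer hold once the package requires torch >= 1.9.1. If a torch version gate is ever needed again, a sketch using packaging instead of pkg_resources could look like this (requires_torch_1_13 is a hypothetical marker name, and packaging is assumed to be installed):

import pytest
import torch
from packaging.version import Version

# Strip any local version suffix such as "+cu117" before comparing.
requires_torch_1_13 = pytest.mark.skipif(
    Version(torch.__version__.split("+")[0]) < Version("1.13"),
    reason="requires torch >= 1.13",
)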
