model-hub tests
AlexanderDokuchaev committed Nov 16, 2023
1 parent 776670b commit d88c57d
Showing 13 changed files with 1,162 additions and 1 deletion.
21 changes: 21 additions & 0 deletions .github/workflows/model_hub.yml
@@ -0,0 +1,21 @@
name: Model Hub

on:
  workflow_dispatch:

jobs:
  torch:
    runs-on: ubuntu-20.04-16-cores
    defaults:
      run:
        shell: bash
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: 3.8.10
      - name: Install NNCF and test requirements
        run: make install-models-hub-torch

      - name: Run models-hub-torch test scope
        run: make test-models-hub-torch
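The workflow above is manual-only: with workflow_dispatch as its sole trigger, the Model Hub scope never runs on push or pull-request events and has to be started by hand. For reference, a minimal sketch of dispatching it through the GitHub REST API with requests; the repository slug, target branch, and token variable below are illustrative assumptions, not part of this commit.

# Sketch: trigger the manual-only Model Hub workflow via the GitHub REST API.
# Assumed values: repository slug, branch name, and a token with workflow
# permissions exported as GITHUB_TOKEN. None of these come from the commit.
import os

import requests

REPO = "openvinotoolkit/nncf"  # assumed repository slug
WORKFLOW_FILE = "model_hub.yml"

response = requests.post(
    f"https://api.github.com/repos/{REPO}/actions/workflows/{WORKFLOW_FILE}/dispatches",
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
    },
    json={"ref": "develop"},  # assumed branch to run against
    timeout=30,
)
response.raise_for_status()  # GitHub responds with 204 No Content on success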
14 changes: 13 additions & 1 deletion Makefile
@@ -113,8 +113,17 @@ install-torch-dev: install-torch-test install-pre-commit
	pip install -r examples/post_training_quantization/torch/mobilenet_v2/requirements.txt
	pip install -r examples/post_training_quantization/torch/ssd300_vgg16/requirements.txt

install-models-hub-torch:
	pip install -U pip
	pip install -e .
	pip install -r tests/torch/models_hub_test/requirements.txt
	# Install wheel to run pip with --no-build-isolation
	pip install wheel
	pip install --no-build-isolation -r tests/torch/models_hub_test/requirements_secondary.txt


test-torch:
-	pytest ${COVERAGE_ARGS} tests/torch -m "not weekly and not nightly" --junitxml ${JUNITXML_PATH} $(DATA_ARG)
+	pytest ${COVERAGE_ARGS} tests/torch -m "not weekly and not nightly and not models_hub_test" --junitxml ${JUNITXML_PATH} $(DATA_ARG)

test-torch-nightly:
	pytest ${COVERAGE_ARGS} tests/torch -m nightly --junitxml ${JUNITXML_PATH} $(DATA_ARG)
@@ -138,6 +147,9 @@ test-examples-torch:
		--backend torch \
		--junitxml ${JUNITXML_PATH}

test-models-hub-torch:
	pytest tests/torch/models_hub_test --junitxml ${JUNITXML_PATH}

###############################################################################
# Common part
install-common-test:
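test-torch now excludes the new models_hub_test marker, while test-models-hub-torch runs tests/torch/models_hub_test directly. How that marker gets attached to the model-hub tests is not visible in the rendered part of this diff; below is a minimal, hypothetical sketch of one way to do it from a conftest.py inside tests/torch/models_hub_test. The hook-based approach is an assumption, not necessarily what the commit does.

# Hypothetical tests/torch/models_hub_test/conftest.py sketch: mark everything
# collected under this directory so "pytest tests/torch -m 'not models_hub_test'"
# skips the scope. Not part of the rendered diff.
import pytest


def pytest_collection_modifyitems(config, items):
    for item in items:
        item.add_marker(pytest.mark.models_hub_test)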
10 changes: 10 additions & 0 deletions tests/torch/models_hub_test/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
110 changes: 110 additions & 0 deletions tests/torch/models_hub_test/common.py
@@ -0,0 +1,110 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import networkx as nx
import numpy as np
import pytest
import torch
from _pytest.mark import ParameterSet

from nncf.common.graph import NNCFGraph
from nncf.torch.model_creation import wrap_model


class BaseTestModel(ABC):
    @abstractmethod
    def load_model(self, model_name: str):
        """Return a tuple of (framework model, example input) for the given model name."""

    @staticmethod
    def check_graph(graph: NNCFGraph):
        # The traced model must produce a single connected NNCF graph.
        nx_graph = graph._get_graph_for_visualization()
        nx_graph = nx_graph.to_undirected()
        num_connected_components = len(list(nx.connected_components(nx_graph)))
        assert num_connected_components == 1, f"Disconnected graph, {num_connected_components} connected components"

    def nncf_wrap(self, model_name):
        torch.manual_seed(0)

        fw_model, example = self.load_model(model_name)

        # Convert the example input to torch tensors, keeping its container type.
        example_input = None
        if isinstance(example, (list, tuple)):
            example_input = tuple([torch.tensor(x) for x in example])
        elif isinstance(example, dict):
            example_input = {k: torch.tensor(v) for k, v in example.items()}
        assert example_input is not None

        nncf_model = wrap_model(fw_model, example_input)

        self.check_graph(nncf_model.nncf.get_original_graph())


@dataclass
class ModelInfo:
    model_name: Optional[str]
    model_link: Optional[str]
    mark: Optional[str]
    reason: Optional[str]


def idfn(val):
    if isinstance(val, ModelInfo):
        return val.model_name
    return None


def get_models_list(file_name: str) -> List[ModelInfo]:
    models = []
    with open(file_name) as f:
        for model_info in f:
            model_info = model_info.rstrip()
            # skip comment in model scope file
            if model_info.startswith("#"):
                continue
            mark = None
            reason = None
            model_link = None

            splitted = model_info.split(",")
            if len(splitted) == 1:
                model_name = splitted[0]
            elif len(splitted) == 2:
                model_name, model_link = splitted
            elif len(splitted) == 4:
                model_name, model_link, mark, reason = splitted
                if model_link == "none":
                    model_link = None
                assert mark in ["skip", "xfail"], "Incorrect failure mark for model info {}".format(model_info)
            else:
                raise RuntimeError(f"Incorrect model info `{model_info}`. It must contain either 1, 2 or 4 fields.")
            models.append(ModelInfo(model_name, model_link, mark, reason))

    return models


def get_model_params(file_name: Path) -> List[Union[ModelInfo, ParameterSet]]:
    model_list = get_models_list(file_name)
    params = []
    for mi in model_list:
        if mi.mark == "skip":
            params.append(pytest.param(mi, marks=pytest.mark.skip(reason=mi.reason)))
        elif mi.mark == "xfail":
            params.append(pytest.param(mi, marks=pytest.mark.xfail(reason=mi.reason)))
        else:
            params.append(mi)
    return params
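common.py only provides shared scaffolding: concrete test modules are expected to subclass BaseTestModel and feed get_model_params into pytest.mark.parametrize. Below is a hypothetical sketch of such a module; the torchvision backend, file names, and models-list contents are illustrative assumptions rather than files from this commit.

# Hypothetical test module built on common.py (names, backend, and list file
# are illustrative; the real model-scope files are not shown in this diff).
from pathlib import Path

import pytest
import torchvision.models as models

from tests.torch.models_hub_test.common import BaseTestModel
from tests.torch.models_hub_test.common import get_model_params
from tests.torch.models_hub_test.common import idfn


class TorchvisionTestModel(BaseTestModel):
    def load_model(self, model_name: str):
        # load_model returns (framework model, example input); a tuple example
        # is converted element-wise with torch.tensor() inside nncf_wrap().
        model = models.get_model(model_name, weights=None)
        model.eval()
        example = ([[[[0.0] * 64] * 64] * 3],)  # nested list -> 1x3x64x64 tensor
        return model, example


# Assumed models list file; each line uses the 1-, 2- or 4-field format that
# get_models_list() parses, e.g.:
#   resnet18
#   swin_t,https://example.com/swin_t,xfail,tracing issue
MODELS_FILE = Path(__file__).parent / "torchvision_models.txt"


@pytest.mark.parametrize("model_info", get_model_params(MODELS_FILE), ids=idfn)
def test_nncf_wrap(model_info):
    TorchvisionTestModel().nncf_wrap(model_info.model_name)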
