Skip to content

Commit

Permalink
TF 2.5.* support (openvinotoolkit#1100)
Browse files Browse the repository at this point in the history
### Changes

TF2.5.* supported

### Related tickets

62338

### Tests

Graph tests are updated. Now graphs for 2.4.* and for 2.5.* versions are stored in separate folders.
  • Loading branch information
negvet committed Feb 16, 2022
1 parent 5f162c9 commit bcf5818
Show file tree
Hide file tree
Showing 211 changed files with 100,512 additions and 21 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ See [third_party_integration](./third_party_integration) for examples of code mo
- Python\* 3.6.2 or later
- Supported frameworks:
- PyTorch\* >=1.5.0, <=1.9.1 (1.8.0 not supported)
- TensorFlow\* 2.4.3
- TensorFlow\* >=2.4.0, <=2.5.3

This repository is tested on Python* 3.6.2+, PyTorch* 1.9.1 (NVidia CUDA\* Toolkit 10.2) and TensorFlow* 2.4.3 (NVidia CUDA\* Toolkit 11.0).
This repository is tested on Python* 3.6.2+, PyTorch* 1.9.1 (NVidia CUDA\* Toolkit 10.2) and TensorFlow* 2.5.3 (NVidia CUDA\* Toolkit 11.2).

## Installation
We suggest to install or use the package in the [Python virtual environment](https://docs.python.org/3/tutorial/venv.html).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""

import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras.backend as K # pylint: disable=no-name-in-module
from examples.tensorflow.common.object_detection.architecture import nn_ops


Expand Down
9 changes: 9 additions & 0 deletions examples/tensorflow/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import resource
from os import path as osp
from pathlib import Path
import atexit

import tensorflow as tf
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy

from examples.tensorflow.common.logger import logger as default_logger
from examples.tensorflow.common.sample_config import CustomArgumentParser
Expand Down Expand Up @@ -201,3 +203,10 @@ def reset(self):
self.start_time = 0.
self.diff = 0.
self.average_time = 0.


def close_strategy_threadpool(strategy):
"""Due to https://github.com/tensorflow/tensorflow/issues/50487"""
# pylint: disable=protected-access
if isinstance(strategy, MirroredStrategy):
atexit.register(strategy._extended._collective_ops._pool.close)
25 changes: 18 additions & 7 deletions examples/tensorflow/object_detection/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import functools
import os
import sys
from pathlib import Path

import tensorflow as tf
import numpy as np

from examples.tensorflow.common.utils import close_strategy_threadpool
from nncf.common.accuracy_aware_training import create_accuracy_aware_training_loop
from nncf.tensorflow import create_compressed_model
from nncf.tensorflow.helpers.model_manager import TFOriginalModelManager
Expand Down Expand Up @@ -266,6 +267,13 @@ def evaluate(test_step, metric, test_dist_dataset, num_batches, print_freq):
return result


def model_eval_fn(model, strategy, model_builder, test_dist_dataset, num_test_batches, config):
test_step = create_test_step_fn(strategy, model, model_builder.post_processing)
metric_result = evaluate(test_step, model_builder.eval_metrics(), test_dist_dataset,
num_test_batches, config.print_freq)
return metric_result['AP']


def run(config):
strategy = get_distribution_strategy(config)
if config.metrics_dump is not None:
Expand All @@ -286,11 +294,6 @@ def run(config):
# Create model builder
model_builder = get_model_builder(config)

def model_eval_fn(model):
test_step = create_test_step_fn(strategy, model, model_builder.post_processing)
metric_result = evaluate(test_step, model_builder.eval_metrics(), test_dist_dataset,
num_test_batches, config.print_freq)
return metric_result['AP']
# Register additional parameters in the NNCFConfig for initialization
# the compressed model during building
nncf_config = config.nncf_config
Expand All @@ -307,7 +310,13 @@ def model_eval_fn(model):
with TFOriginalModelManager(model_builder.build_model,
weights=config.get('weights', None)) as model:
with strategy.scope():
config.nncf_config.register_extra_structs([ModelEvaluationArgs(eval_fn=model_eval_fn)])
config.nncf_config.register_extra_structs(
[ModelEvaluationArgs(eval_fn=functools.partial(model_eval_fn,
strategy=strategy,
model_builder=model_builder,
test_dist_dataset=test_dist_dataset,
num_test_batches=num_test_batches,
config=config))])
compression_ctrl, compress_model = create_compressed_model(model, nncf_config, compression_state)
scheduler = build_scheduler(
config=config,
Expand Down Expand Up @@ -378,6 +387,8 @@ def validate_fn(model, **kwargs):
compression_ctrl.export_model(save_path, save_format)
logger.info("Saved to {}".format(save_path))

close_strategy_threadpool(strategy)


def export(config):
model_builder = get_model_builder(config)
Expand Down
2 changes: 1 addition & 1 deletion examples/tensorflow/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
absl-py==0.10
tensorflow_datasets==4.2.0
tensorflow_hub
tensorflow_addons==0.12.1
tensorflow_addons==0.14.0
opencv-python
protobuf==3.17.3
git+https://github.com/alexsu52/cocoapi.git#egg=pycocotools&subdirectory=PythonAPI
3 changes: 3 additions & 0 deletions examples/tensorflow/segmentation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import tensorflow as tf

from examples.tensorflow.common.utils import close_strategy_threadpool
from nncf.tensorflow import create_compressed_model
from nncf.tensorflow import register_default_init_args
from nncf.tensorflow.helpers.model_manager import TFOriginalModelManager
Expand Down Expand Up @@ -271,6 +272,8 @@ def run_evaluation(config, eval_timeout=None):
if config.metrics_dump is not None:
write_metrics(metric_result['AP'], config.metrics_dump)

close_strategy_threadpool(strategy)


def export(config):
model_builder = get_model_builder(config)
Expand Down
3 changes: 3 additions & 0 deletions examples/tensorflow/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import tensorflow as tf

from examples.tensorflow.common.utils import close_strategy_threadpool
from nncf.tensorflow import create_compressed_model
from nncf.tensorflow.helpers.model_manager import TFOriginalModelManager
from nncf.tensorflow.initialization import register_default_init_args
Expand Down Expand Up @@ -294,6 +295,8 @@ def run_train(config):
statistics = compression_ctrl.statistics()
logger.info(statistics.to_str())

close_strategy_threadpool(strategy)


def main(argv):
parser = get_argument_parser()
Expand Down
10 changes: 8 additions & 2 deletions nncf/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@

tensorflow_version = parse_version(tensorflow.__version__).base_version
if not tensorflow_version.startswith(BKC_TF_VERSION[:-2]):
raise RuntimeError(
'NNCF only supports tensorflow=={bkc}, while current tensorflow version is {curr}'.format(
import warnings
warnings.warn("NNCF provides best results with tensorflow=={bkc}, "
"while current tensorflow version is {curr} - consider switching to tensorflow=={bkc}".format(
bkc=BKC_TF_VERSION,
curr=tensorflow.__version__
))
elif not ('2.4' <= tensorflow_version[:3] <= '2.5'):
raise RuntimeError(
'NNCF only supports tensorflow >=2.4.0, ==2.5.*, while current tensorflow version is {curr}'.format(
curr=tensorflow.__version__
))


from nncf.tensorflow.helpers import create_compressed_model
Expand Down
2 changes: 1 addition & 1 deletion nncf/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '2.1.0'
BKC_TORCH_VERSION = '1.9.1'
BKC_TORCHVISION_VERSION = '0.10.1'
BKC_TF_VERSION = '2.4.*'
BKC_TF_VERSION = '2.5.*'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def find_version(*file_paths):
version_string = "{}{}".format(sys.version_info[0], sys.version_info[1])

_extra_deps = [
"tensorflow~=2.4.3",
"tensorflow~=2.5.0",
"torch>=1.5.0, <=1.9.1, !=1.8.0",
]

Expand Down
Loading

0 comments on commit bcf5818

Please sign in to comment.