Skip to content

Commit

Permalink
[Feature] support torchserver for unconditional models (#131)
Browse files Browse the repository at this point in the history
* support torchserver for unconditional models

* support sample_model selection in inference + revise docstring
  • Loading branch information
LeoXing1996 committed Oct 13, 2021
1 parent 6dd321d commit 6eb7045
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,12 @@ data
*.log.json
work_dirs/
*.DS_Store

# PyTorch
*.pth
mmgen/configs/
mmgen/tools/
runs/

# Pytorch Server
*.mar
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ line_length=79
multi_line_output=0
known_standard_library=argparse,inspect,contextlib,hashlib,subprocess,unittest,tempfile,copy,pkg_resources,logging,pickle,platform,setuptools,abc,collections,functools,os,math,time,warnings,random,shutil,sys
known_first_party=mmgen
known_third_party=PIL,click,cv2,m2r,mmcls,mmcv,numpy,prettytable,pytest,pytorch_sphinx_theme,recommonmark,requests,scipy,torch,torchvision,tqdm
known_third_party=PIL,click,cv2,m2r,mmcls,mmcv,numpy,prettytable,pytest,pytorch_sphinx_theme,recommonmark,requests,scipy,torch,torchvision,tqdm,ts
no_lines_before=STDLIB,LOCALFOLDER
default_section=THIRDPARTY
114 changes: 114 additions & 0 deletions tools/deployment/mmgen2torchserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

import mmcv

try:
from model_archiver.model_packaging import package_model
from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
package_model = None


def mmgen2torchserver(config_file: str,
checkpoint_file: str,
output_folder: str,
model_name: str,
model_version: str = '1.0',
model_type: str = 'unconditional',
force: bool = False):
"""Converts MMGeneration model (config + checkpoint) to TorchServe `.mar`.
Args:
config_file (str): Path of config file. The config should in
MMGeneration format.
checkpoint_file (str): Path of checkpoint. The checkpoint should in
MMGeneration checkpoint format.
output_folder (str): Folder where `{model_name}.mar` will be created.
The file created will be in TorchServe archive format.
model_name (str): Name of the generated ``'mar'`` file. If not None,
used for naming the `{model_name}.mar` file that will be created
under `output_folder`. If None, `{Path(checkpoint_file).stem}`
will be used.
model_version (str, optional): Model's version. Defaults to '1.0'.
model_type (str, optional): Type of the model to be convert. Handler
named ``{model_type}_handler`` would be used to generate ``mar``
file. Defaults to 'unconditional'.
force (bool, optional): If True, existing `{model_name}.mar` will be
overwritten. Default to False.
"""
mmcv.mkdir_or_exist(output_folder)

config = mmcv.Config.fromfile(config_file)

with TemporaryDirectory() as tmpdir:
config.dump(f'{tmpdir}/config.py')

args = Namespace(
**{
'model_file': f'{tmpdir}/config.py',
'serialized_file': checkpoint_file,
'handler':
f'{Path(__file__).parent}/mmgen_{model_type}_handler.py',
'model_name': model_name or Path(checkpoint_file).stem,
'version': model_version,
'export_path': output_folder,
'force': force,
'requirements_file': None,
'extra_files': None,
'runtime': 'python',
'archive_format': 'default'
})
manifest = ModelExportUtils.generate_manifest_json(args)
package_model(args, manifest)


def parse_args():
parser = ArgumentParser(
description='Convert MMGeneration models to TorchServe `.mar` format.')
parser.add_argument('config', type=str, help='config file path')
parser.add_argument('checkpoint', type=str, help='checkpoint file path')
parser.add_argument(
'--output-folder',
type=str,
required=True,
help='Folder where `{model_name}.mar` will be created.')
parser.add_argument(
'--model-name',
type=str,
default=None,
help='If not None, used for naming the `{model_name}.mar`'
'file that will be created under `output_folder`.'
'If None, `{Path(checkpoint_file).stem}` will be used.')
parser.add_argument(
'--model-type',
type=str,
default='unconditional',
help='Which model type and handler to be used.')
parser.add_argument(
'--model-version',
type=str,
default='1.0',
help='Number used for versioning.')
parser.add_argument(
'-f',
'--force',
action='store_true',
help='overwrite the existing `{model_name}.mar`')
args = parser.parse_args()

return args


if __name__ == '__main__':
args = parse_args()

if package_model is None:
raise ImportError('`torch-model-archiver` is required.'
'Try: pip install torch-model-archiver')

mmgen2torchserver(args.config, args.checkpoint, args.output_folder,
args.model_name, args.model_version, args.model_type,
args.force)
57 changes: 57 additions & 0 deletions tools/deployment/mmgen_unconditional_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmgen.apis import init_model


class MMGenUnconditionalHandler(BaseHandler):

def initialize(self, context):
properties = context.system_properties
self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = torch.device(self.map_location + ':' +
str(properties.get('gpu_id')) if torch.cuda.
is_available() else self.map_location)
self.manifest = context.manifest

model_dir = properties.get('model_dir')
serialized_file = self.manifest['model']['serializedFile']
checkpoint = os.path.join(model_dir, serialized_file)
self.config_file = os.path.join(model_dir, 'config.py')

self.model = init_model(self.config_file, checkpoint, self.device)
self.initialized = True

def preprocess(self, data, *args, **kwargs):
data_decode = dict()
# `data` type is `list[dict]`
for k, v in data[0].items():
# deocde strings
if isinstance(v, bytearray):
data_decode[k] = v.decode()
return data_decode

def inference(self, data, *args, **kwargs):
sample_model = data['sample_model']
print(sample_model)
results = self.model.sample_from_noise(
None, num_batches=1, sample_model=sample_model, **kwargs)
return results

def postprocess(self, data):
# convert torch tensor to numpy and then covert to bytes
output_list = []
for data_ in data:
data_ = (data_ + 1) / 2
data_ = data_[[2, 1, 0], ...]
data_ = data_.clamp_(0, 1)
data_ = (data_ * 255).permute(1, 2, 0)
data_np = data_.detach().cpu().numpy().astype(np.uint8)
data_byte = data_np.tobytes()
output_list.append(data_byte)

return output_list
58 changes: 58 additions & 0 deletions tools/deployment/test_torchserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from argparse import ArgumentParser

import numpy as np
import requests
from PIL import Image


def parse_args():
parser = ArgumentParser()
parser.add_argument('model_name', help='The model name in the server')
parser.add_argument(
'--inference-addr',
default='127.0.0.1:8080',
help='Address and port of the inference server')
parser.add_argument(
'--img-path',
type=str,
default='demo.png',
help='Path to save generated image.')
parser.add_argument(
'--img-size', type=int, default=128, help='Size of the output image.')
parser.add_argument(
'--sample-model',
type=str,
default='ema/orig',
help='Which model you want to use.')
args = parser.parse_args()
return args


def save_results(contents, img_path, img_size):
if not isinstance(contents, list):
Image.frombytes('RGB', (img_size, img_size), contents).save(img_path)
return

imgs = []
for content in contents:
imgs.append(
np.array(Image.frombytes('RGB', (img_size, img_size), content)))
Image.fromarray(np.concatenate(imgs, axis=1)).save(img_path)


def main(args):
url = 'http://' + args.inference_addr + '/predictions/' + args.model_name

if args.sample_model == 'ema/orig':
cont_ema = requests.post(url, {'sample_model': 'ema'}).content
cont_orig = requests.post(url, {'sample_model': 'orig'}).content
save_results([cont_ema, cont_orig], args.img_path, args.img_size)
return

response = requests.post(url, {'sample_model': args.sample_model})
save_results(response.content, args.img_path, args.img_size)


if __name__ == '__main__':
args = parse_args()
main(args)

0 comments on commit 6eb7045

Please sign in to comment.