[Feature] support torchserver for unconditional models (#131)
* support torchserver for unconditional models
* support sample_model selection in inference + revise docstring
1 parent 6dd321d · commit 6eb7045
Showing 5 changed files with 234 additions and 1 deletion.
@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

import mmcv

try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None

def mmgen2torchserver(config_file: str,
                      checkpoint_file: str,
                      output_folder: str,
                      model_name: str,
                      model_version: str = '1.0',
                      model_type: str = 'unconditional',
                      force: bool = False):
    """Convert an MMGeneration model (config + checkpoint) to a TorchServe
    `.mar` archive.

    Args:
        config_file (str): Path of the config file. The config should be in
            MMGeneration format.
        checkpoint_file (str): Path of the checkpoint. The checkpoint should
            be in MMGeneration checkpoint format.
        output_folder (str): Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
        model_name (str): Name of the generated ``mar`` file. If not None, it
            is used to name the `{model_name}.mar` file created under
            `output_folder`. If None, `{Path(checkpoint_file).stem}` will be
            used.
        model_version (str, optional): Model's version. Defaults to '1.0'.
        model_type (str, optional): Type of the model to be converted. The
            handler named ``mmgen_{model_type}_handler`` will be used to
            generate the ``mar`` file. Defaults to 'unconditional'.
        force (bool, optional): If True, an existing `{model_name}.mar` will
            be overwritten. Defaults to False.
    """
    mmcv.mkdir_or_exist(output_folder)

    config = mmcv.Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        # dump the config into the temp dir so it is packed into the archive
        config.dump(f'{tmpdir}/config.py')

        args = Namespace(
            **{
                'model_file': f'{tmpdir}/config.py',
                'serialized_file': checkpoint_file,
                'handler':
                f'{Path(__file__).parent}/mmgen_{model_type}_handler.py',
                'model_name': model_name or Path(checkpoint_file).stem,
                'version': model_version,
                'export_path': output_folder,
                'force': force,
                'requirements_file': None,
                'extra_files': None,
                'runtime': 'python',
                'archive_format': 'default'
            })
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)

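
# A minimal usage sketch in comment form; both paths below are hypothetical
# placeholders, not files shipped with this commit:
#
#   mmgen2torchserver(
#       config_file='configs/styleganv2/stylegan2_demo.py',
#       checkpoint_file='work_dirs/stylegan2_demo.pth',
#       output_folder='model_store',
#       model_name='stylegan2_demo')
#
# would produce `model_store/stylegan2_demo.mar`.
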
def parse_args():
    parser = ArgumentParser(
        description='Convert MMGeneration models to TorchServe `.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-type',
        type=str,
        default='unconditional',
        help='Which model type and handler should be used.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='Overwrite the existing `{model_name}.mar`.')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = parse_args()

    if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mmgen2torchserver(args.config, args.checkpoint, args.output_folder,
                      args.model_name, args.model_version, args.model_type,
                      args.force)
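
Once the archive exists, serving it should only require pointing TorchServe at the output folder, e.g. `torchserve --start --model-store model_store --models stylegan2_demo.mar` (reusing the hypothetical names from the sketch above). The handler below then answers `POST /predictions/{model_name}` requests.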
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmgen.apis import init_model


class MMGenUnconditionalHandler(BaseHandler):

    def initialize(self, context):
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(
            self.map_location + ':' + str(properties.get('gpu_id'))
            if torch.cuda.is_available() else self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        # `config.py` was packed into the archive by the converter script
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_model(self.config_file, checkpoint, self.device)
        self.initialized = True

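    # TorchServe's BaseHandler.handle() drives each request through
    # preprocess -> inference -> postprocess, chaining their outputs.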
    def preprocess(self, data, *args, **kwargs):
        data_decode = dict()
        # `data` type is `list[dict]`
        for k, v in data[0].items():
            # decode strings
            if isinstance(v, bytearray):
                data_decode[k] = v.decode()
        return data_decode

    def inference(self, data, *args, **kwargs):
        # `sample_model` selects which weights to sample with
        # ('ema' or 'orig' in the client script below)
        sample_model = data['sample_model']
        results = self.model.sample_from_noise(
            None, num_batches=1, sample_model=sample_model, **kwargs)
        return results

    def postprocess(self, data):
        # convert torch tensors to numpy arrays and then to bytes
        output_list = []
        for data_ in data:
            # map values from [-1, 1] to [0, 1]
            data_ = (data_ + 1) / 2
            # reorder channels (BGR -> RGB)
            data_ = data_[[2, 1, 0], ...]
            data_ = data_.clamp_(0, 1)
            # scale to [0, 255] and move channels last (HWC)
            data_ = (data_ * 255).permute(1, 2, 0)
            data_np = data_.detach().cpu().numpy().astype(np.uint8)
            data_byte = data_np.tobytes()
            output_list.append(data_byte)

        return output_list
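
For orientation, a rough sketch of what the handler exchanges over the wire: TorchServe hands `preprocess` a `list[dict]` of form fields whose values arrive as `bytearray`, and `postprocess` returns raw `uint8` RGB bytes that a client can reshape into an image. A hedged example, assuming the 128x128 output size that the client below defaults to:

import numpy as np

# What preprocess() receives for requests.post(url, {'sample_model': 'ema'}):
payload = [{'sample_model': bytearray(b'ema')}]

# Reshaping the raw bytes that postprocess() emits (`raw` would be the
# `response.content` of a request):
def decode_image(raw, img_size=128):
    return np.frombuffer(raw, dtype=np.uint8).reshape(img_size, img_size, 3)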
@@ -0,0 +1,58 @@
from argparse import ArgumentParser

import numpy as np
import requests
from PIL import Image


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('model_name', help='The model name in the server')
    parser.add_argument(
        '--inference-addr',
        default='127.0.0.1:8080',
        help='Address and port of the inference server')
    parser.add_argument(
        '--img-path',
        type=str,
        default='demo.png',
        help='Path to save the generated image.')
    parser.add_argument(
        '--img-size', type=int, default=128, help='Size of the output image.')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema/orig',
        help="Which sampling model to use: 'ema', 'orig', or 'ema/orig' "
        'for both side by side.')
    args = parser.parse_args()
    return args


def save_results(contents, img_path, img_size):
    # a single image: decode the bytes and save directly
    if not isinstance(contents, list):
        Image.frombytes('RGB', (img_size, img_size), contents).save(img_path)
        return

    # multiple images: decode each one and concatenate them horizontally
    imgs = []
    for content in contents:
        imgs.append(
            np.array(Image.frombytes('RGB', (img_size, img_size), content)))
    Image.fromarray(np.concatenate(imgs, axis=1)).save(img_path)


def main(args):
    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name

    # 'ema/orig' queries both models and saves the samples side by side
    if args.sample_model == 'ema/orig':
        cont_ema = requests.post(url, {'sample_model': 'ema'}).content
        cont_orig = requests.post(url, {'sample_model': 'orig'}).content
        save_results([cont_ema, cont_orig], args.img_path, args.img_size)
        return

    response = requests.post(url, {'sample_model': args.sample_model})
    save_results(response.content, args.img_path, args.img_size)


if __name__ == '__main__':
    args = parse_args()
    main(args)
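
Assuming a model named `stylegan2_demo` (the hypothetical name used earlier) is already being served, the client can also be driven programmatically; the `Namespace` below mirrors the command-line defaults:

from argparse import Namespace

# Equivalent to running this script as:
#   python <script> stylegan2_demo --sample-model ema/orig
main(
    Namespace(
        model_name='stylegan2_demo',
        inference_addr='127.0.0.1:8080',
        img_path='demo.png',
        img_size=128,
        sample_model='ema/orig'))

With `sample_model='ema/orig'`, this fetches one sample from the EMA weights and one from the original weights, then writes them side by side to `demo.png`.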