Skip to content

Commit

Permalink
implement onnx export for inception3/4, resnext, mobilenetv2 (#346)
Browse files Browse the repository at this point in the history
* add inceptionv4 backbone/training settings
* add converted backbone, top-1 acc 80.08
  • Loading branch information
lostkevin committed Jul 18, 2024
1 parent 8c90cea commit 3189798
Show file tree
Hide file tree
Showing 14 changed files with 761 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,31 @@
# model settings
model = dict(
type='Classification',
backbone=dict(type='Inception3'),
head=dict(
type='ClsHead',
with_avg_pool=True,
in_channels=2048,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
backbone=dict(type='Inception3', num_classes=1000),
head=[
dict(
type='ClsHead',
with_fc=False,
in_channels=2048,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes,
input_feature_index=[1],
),
num_classes=num_classes))
dict(
type='ClsHead',
with_fc=False,
in_channels=768,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes,
input_feature_index=[0],
)
])

class_list = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13',
Expand Down Expand Up @@ -196,3 +211,5 @@
interval=10,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])

export = dict(export_type='raw', export_neck=True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
_base_ = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py'

num_classes = 1000
# model settings
model = dict(
type='Classification',
backbone=dict(type='Inception4', num_classes=num_classes),
head=[
dict(
type='ClsHead',
with_fc=False,
in_channels=1536,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes,
input_feature_index=[1],
),
dict(
type='ClsHead',
with_fc=False,
in_channels=768,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes,
input_feature_index=[0],
)
])

img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# A config with the optimization settings from https://arxiv.org/pdf/1602.07261
# May run with 20 GPUs
_base_ = 'configs/classification/imagenet/inception/inceptionv3_b32x8_100e.py'

num_classes = 1000
# model settings
model = dict(
type='Classification',
backbone=dict(type='Inception4', num_classes=num_classes),
head=[
dict(
type='ClsHead',
with_fc=False,
in_channels=1536,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes,
input_feature_index=[1],
),
dict(
type='ClsHead',
with_fc=False,
in_channels=768,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes,
input_feature_index=[0],
)
])

img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# optimizer
optimizer = dict(
type='RMSprop', lr=0.045, momentum=0.9, weight_decay=0.9, eps=1.0)

# learning policy
lr_config = dict(policy='exp', gamma=0.96954) # gamma**2 ~ 0.94
checkpoint_config = dict(interval=10)

# runtime settings
total_epochs = 200
5 changes: 3 additions & 2 deletions configs/classification/imagenet/mobilenet/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes))
num_classes=num_classes),
pretrained=True)

# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
Expand All @@ -25,4 +26,4 @@
# runtime settings
total_epochs = 100
checkpoint_sync_export = True
export = dict(export_neck=True)
export = dict(export_type='raw', export_neck=True)
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
type='CrossEntropyLossWithLabelSmooth',
label_smooth=0,
),
num_classes=num_classes))
num_classes=num_classes),
pretrained=True)

# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
Expand All @@ -30,3 +31,4 @@

# runtime settings
total_epochs = 100
export = dict(export_type='raw', export_neck=True)
69 changes: 42 additions & 27 deletions easycv/apis/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,34 +158,49 @@ def _get_blade_model():


def _export_onnx_cls(model, model_config, cfg, filename, meta):
support_backbones = {
'ResNet': {
'depth': [50]
},
'MobileNetV2': {},
'Inception3': {},
'Inception4': {},
'ResNeXt': {
'depth': [50]
}
}
if model_config['backbone'].get('type', None) not in support_backbones:
tmp = ' '.join(support_backbones.keys())
info_str = f'Only support export onnx model for {tmp} now!'
raise ValueError(info_str)
configs = support_backbones[model_config['backbone'].get('type')]
for k, v in configs.items():
if v[0].__class__(model_config['backbone'].get(k, None)) not in v:
raise ValueError(
f"Unsupport config for {model_config['backbone'].get('type')}")

# save json config for test_pipline and class
with io.open(
filename +
'.config.json' if filename.endswith('onnx') else filename +
'.onnx.config.json', 'w') as ofile:
json.dump(meta, ofile)

if model_config['backbone'].get(
'type', None) == 'ResNet' and model_config['backbone'].get(
'depth', None) == 50:
# save json config for test_pipline and class
with io.open(
filename +
'.config.json' if filename.endswith('onnx') else filename +
'.onnx.config.json', 'w') as ofile:
json.dump(meta, ofile)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
model.to(device)
img_size = int(cfg.image_size2)
x_input = torch.randn((1, 3, img_size, img_size)).to(device)
torch.onnx.export(
model,
(x_input, 'onnx'),
filename if filename.endswith('onnx') else filename + '.onnx',
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
)
else:
raise ValueError('Only support export onnx model for ResNet now!')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
model.to(device)
img_size = int(cfg.image_size2)
x_input = torch.randn((1, 3, img_size, img_size)).to(device)
torch.onnx.export(
model,
(x_input, 'onnx'),
filename if filename.endswith('onnx') else filename + '.onnx',
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
)


def _export_cls(model, cfg, filename):
Expand Down
1 change: 1 addition & 0 deletions easycv/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .genet import PlainNet
from .hrnet import HRNet
from .inceptionv3 import Inception3
from .inceptionv4 import Inception4
from .lighthrnet import LiteHRNet
from .mae_vit_transformer import *
from .mit import MixVisionTransformer
Expand Down
11 changes: 2 additions & 9 deletions easycv/models/backbones/inceptionv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
r""" This model is taken from the official PyTorch model zoo.
- torchvision.models.inception.py on 31th Aug, 2019
"""

from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -16,8 +13,6 @@

__all__ = ['Inception3']

_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])


@BACKBONES.register_module
class Inception3(nn.Module):
Expand Down Expand Up @@ -113,6 +108,7 @@ def forward(self, x):
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
aux = None
if self.training and self.aux_logits:
aux = self.AuxLogits(x)
# N x 768 x 17 x 17
Expand All @@ -132,10 +128,7 @@ def forward(self, x):
if hasattr(self, 'fc'):
x = self.fc(x)

# N x 1000 (num_classes)
if self.training and self.aux_logits and hasattr(self, 'fc'):
return [_InceptionOutputs(x, aux)]
return [x]
return [aux, x]


class InceptionA(nn.Module):
Expand Down
Loading

0 comments on commit 3189798

Please sign in to comment.