Skip to content

Commit

Permalink
imagenet_CNN
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengxiawu committed Jan 18, 2020
1 parent db60b6d commit 28dd43d
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 14 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ python train.py --dataset cifar10 --data_path /userhome/temp_data/cifar10 --data
python train.py --dataset cifar10 --data_path /userhome/temp_data/cifar10 --data_loader_type torch --auto_augmentation --cutout_length 16 --epochs 600 --model_method proxyless_NAS --model_name ofa_482
python train.py --dataset cifar10 --data_path /userhome/temp_data/cifar10 --data_loader_type torch --auto_augmentation --cutout_length 16 --epochs 600 --model_method proxyless_NAS --model_name ofa_398

# Training with cifar10 torch on different neural networks


# Training with ImageNet torch on different neural networks
python train.py --dataset ImageNet --data_path /gdata/ImageNet2012 --data_loader_type dali --drop_path_prob 0.2 --aux_weight 0.4 --init_channels 48 --layers 14 --epochs 250 --model_method darts_NAS --model_name MDENAS
python train.py --dataset ImageNet --data_path /gdata/ImageNet2012 --data_loader_type dali --model_method proxyless_NAS --model_name proxyless_gpu

python train.py --dataset ImageNet --data_path /userhome/temp_data/cifar10 --model_method proxyless_NAS --model_name proxyless_gpu

Expand Down
85 changes: 85 additions & 0 deletions models/darts/augment_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,88 @@ def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath_):
module.p = p


class AugmentCNN_ImageNet(BaseModel.MyNetwork):
""" Augmented CNN model """
def __init__(self, input_size, C_in, C, n_classes, n_layers, auxiliary, genotype,
stem_multiplier=3, dropout_rate=0.0):
"""
Args:
input_size: size of height and width (assuming height = width)
C_in: # of input channels
C: # of starting model channels
"""
super().__init__()
self.C_in = C_in
self.C = C
self.n_classes = n_classes
self.n_layers = n_layers
self.genotype = genotype
self.dropout_rate = dropout_rate
# aux head position
self.aux_pos = 2*n_layers//3 if auxiliary else -1

self.stem0 = nn.Sequential(
nn.Conv2d(C_in, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C // 2),
nn.ReLU(inplace=True),
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)

C_prev_prev, C_prev, C_curr = C, C, C

self.cells = nn.ModuleList()
reduction_prev = True
for i in range(n_layers):
if i in [n_layers//3, 2*n_layers//3]:
C_curr *= 2
reduction = True
else:
reduction = False

cell = AugmentCell(genotype, C_prev_prev, C_prev, C_curr, reduction_prev, reduction)
reduction_prev = reduction
self.cells.append(cell)
C_prev_prev, C_prev = C_prev, len(cell.concat) * C_curr

if i == self.aux_pos:
# [!] this auxiliary head is ignored in computing parameter size
# by the name 'aux_head'
self.aux_head = AuxiliaryHead(input_size//4, C_prev, n_classes)

self.gap = nn.AdaptiveAvgPool2d(1)
# dropout
if self.dropout_rate > 0:
self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
else:
self.dropout = None
self.linear = nn.Linear(C_prev, n_classes)

def forward(self, x):
s0 = self.stem0(x)
s1 = self.stem1(s0)
aux_logits = None
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1)
if i == self.aux_pos and self.training:
aux_logits = self.aux_head(s1)

out = self.gap(s1)
out = out.view(out.size(0), -1) # flatten
if self.dropout is not None:
out = self.dropout(out)
logits = self.linear(out)
return logits, aux_logits

def drop_path_prob(self, p):
""" Set drop path probability """
for module in self.modules():
if isinstance(module, ops.DropPath_):
module.p = p
24 changes: 13 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import models.darts.genotypes as gt
import time
import utils
from models.darts.augment_cnn import AugmentCNN
from models.darts.augment_cnn import AugmentCNN, AugmentCNN_ImageNet
from models import get_model
from data import get_data
import flops_counter
Expand Down Expand Up @@ -62,17 +62,16 @@ def __init__(self):
parser = self.build_parser()
args = parser.parse_args()
super().__init__(**vars(args))
if not self.model_method == 'darts':
if not self.model_method == 'darts_NAS':
if self.aux_weight > 0 or self.drop_path_prob > 0:
print("aux head and drop path only support for daats search space!")
exit()

time_str = time.asctime(time.localtime()).replace(' ', '_')
name_componment = [self.model_method, self.model_name,
self.data_loader_type, 'epoch_' + self.epochs]
if not self.model_method == 'darts':
name_componment += ['channels_' + self.init_channels, 'layers_' + self.layers,
'aux_weight_' + self.aux_weight, 'drop_path_prob_' + self.drop_path_prob]
name_componment = [self.dataset, self.data_loader_type, 'epoch_' + str(self.epochs)]
if not self.model_method == 'darts_NAS':
name_componment += ['channels_' + str(self.init_channels), 'layers_' + str(self.layers),
'aux_weight_' + str(self.aux_weight), 'drop_path_prob_' + str(self.drop_path_prob)]
if self.auto_augmentation or self.cutout_length > 0:
print("DALI do not support Augmentation and Cutout!")
exit()
Expand All @@ -90,9 +89,7 @@ def __init__(self):
name_str += i + '_'
name_str += time_str
self.path = os.path.join('/userhome/project/pytorch_image_classification/expreiments',
self.dataset, name_str)


self.model_method, self.model_name, name_str)
if len(self.genotype) > 1:
self.genotype = gt.from_str(self.genotype)
else:
Expand Down Expand Up @@ -169,15 +166,20 @@ def main():
if config.model_method == 'darts_NAS':
if config.genotype is None:
config.genotype = get_model.get_model(config.model_method, config.model_name)
model = AugmentCNN(input_size, input_channels, config.init_channels, n_classes, config.layers,
if 'imagenet' in config.dataset.lower():
model = AugmentCNN_ImageNet(input_size, input_channels, config.init_channels, n_classes, config.layers,
use_aux, config.genotype)
else:
model = AugmentCNN(input_size, input_channels, config.init_channels, n_classes, config.layers,
use_aux, config.genotype)
else:
model_fun = get_model.get_model(config.model_method, config.model_name)
model = model_fun(num_classes=n_classes, dropout_rate=config.dropout_rate)
# set bn
model.set_bn_param(config.bn_momentum, config.bn_eps)
# model init
model.init_model(model_init=config.model_init)
model.cuda()
# model size
total_ops, total_params = flops_counter.profile(model, [1, input_channels, input_size, input_size])
logger.info("Model size = {:.3f} MB".format(total_params))
Expand Down

0 comments on commit 28dd43d

Please sign in to comment.