-
Notifications
You must be signed in to change notification settings - Fork 19.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add NASNet models #8714
Add NASNet models #8714
Changes from 1 commit
9a62e76
2dc0bda
50a3e29
318bcc7
5462200
77a2096
073c826
15fc1d1
9c7385c
6a6d608
9c3712c
49250b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,7 +42,6 @@ | |
from ..layers import Input | ||
from ..layers import Activation | ||
from ..layers import Dense | ||
from ..layers import Dropout | ||
from ..layers import BatchNormalization | ||
from ..layers import MaxPooling2D | ||
from ..layers import AveragePooling2D | ||
|
@@ -54,18 +53,17 @@ | |
from ..layers import Cropping2D | ||
from ..layers import concatenate | ||
from ..layers import add | ||
from ..regularizers import l2 | ||
from ..utils.data_utils import get_file | ||
from ..engine.topology import get_source_inputs | ||
from ..applications.imagenet_utils import _obtain_input_shape | ||
from ..applications.inception_v3 import preprocess_input | ||
from ..applications.imagenet_utils import decode_predictions | ||
from .. import backend as K | ||
|
||
NASNET_MOBILE_WEIGHT_PATH = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.0/NASNet-mobile.h5' | ||
NASNET_MOBILE_WEIGHT_PATH_NO_TOP = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.0/NASNet-mobile-no-top.h5' | ||
NASNET_LARGE_WEIGHT_PATH = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.1/NASNet-large.h5' | ||
NASNET_LARGE_WEIGHT_PATH_NO_TOP = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.1/NASNet-large-no-top.h5' | ||
NASNET_MOBILE_WEIGHT_PATH = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.2/NASNet-mobile.h5' | ||
NASNET_MOBILE_WEIGHT_PATH_NO_TOP = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.2/NASNet-mobile-no-top.h5' | ||
NASNET_LARGE_WEIGHT_PATH = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.2/NASNet-large.h5' | ||
NASNET_LARGE_WEIGHT_PATH_NO_TOP = 'https://github.com/titu1994/Keras-NASNet/releases/download/v1.2/NASNet-large-no-top.h5' | ||
|
||
|
||
def NASNet(input_shape=None, | ||
|
@@ -74,14 +72,14 @@ def NASNet(input_shape=None, | |
stem_filters=96, | ||
skip_reduction=True, | ||
filter_multiplier=2, | ||
weight_decay=5e-5, | ||
include_top=True, | ||
weights=None, | ||
input_tensor=None, | ||
pooling=None, | ||
classes=1000, | ||
default_size=None): | ||
'''Instantiates a NASNet model. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Introduce a blank line after this |
||
|
||
Note that only TensorFlow is supported for now, | ||
therefore it only works with the data format | ||
`image_data_format='channels_last'` in your Keras config | ||
|
@@ -113,7 +111,6 @@ def NASNet(input_shape=None, | |
of filters in each layer. | ||
- If `filter_multiplier` = 1, default number of filters from the | ||
paper are used at each layer. | ||
weight_decay: l2 regularization weight | ||
include_top: whether to include the fully-connected | ||
layer at the top of the network. | ||
weights: `None` (random initialization) or | ||
|
@@ -208,51 +205,47 @@ def NASNet(input_shape=None, | |
if not skip_reduction: | ||
x = Conv2D(stem_filters, (3, 3), strides=(2, 2), padding='valid', | ||
use_bias=False, name='stem_conv1', | ||
kernel_initializer='he_normal', | ||
kernel_regularizer=l2(weight_decay))(img_input) | ||
kernel_initializer='he_normal')(img_input) | ||
else: | ||
x = Conv2D(stem_filters, (3, 3), strides=(1, 1), padding='same', | ||
use_bias=False, name='stem_conv1', | ||
kernel_initializer='he_normal', | ||
kernel_regularizer=l2(weight_decay))(img_input) | ||
kernel_initializer='he_normal')(img_input) | ||
|
||
x = BatchNormalization(axis=channel_dim, momentum=0.9997, | ||
epsilon=1e-3, name='stem_bn1')(x) | ||
|
||
p = None | ||
if not skip_reduction: # imagenet / mobile mode | ||
x, p = _reduction_A(x, p, filters // (filter_multiplier ** 2), | ||
weight_decay, id='stem_1') | ||
x, p = _reduction_A(x, p, filters // filter_multiplier, weight_decay, | ||
id='stem_2') | ||
id='stem_1') | ||
x, p = _reduction_A(x, p, filters // filter_multiplier, id='stem_2') | ||
|
||
for i in range(num_blocks): | ||
x, p = _normal_A(x, p, filters, weight_decay, id='%d' % (i)) | ||
x, p = _normal_A(x, p, filters, id='%d' % (i)) | ||
|
||
x, p0 = _reduction_A(x, p, filters * filter_multiplier, weight_decay, | ||
x, p0 = _reduction_A(x, p, filters * filter_multiplier, | ||
id='reduce_%d' % (num_blocks)) | ||
|
||
p = p0 if not skip_reduction else p | ||
|
||
for i in range(num_blocks): | ||
x, p = _normal_A(x, p, filters * filter_multiplier, weight_decay, | ||
x, p = _normal_A(x, p, filters * filter_multiplier, | ||
id='%d' % (num_blocks + i + 1)) | ||
|
||
x, p0 = _reduction_A(x, p, filters * filter_multiplier ** 2, weight_decay, | ||
x, p0 = _reduction_A(x, p, filters * filter_multiplier ** 2, | ||
id='reduce_%d' % (2 * num_blocks)) | ||
|
||
p = p0 if not skip_reduction else p | ||
|
||
for i in range(num_blocks): | ||
x, p = _normal_A(x, p, filters * filter_multiplier ** 2, weight_decay, | ||
x, p = _normal_A(x, p, filters * filter_multiplier ** 2, | ||
id='%d' % (2 * num_blocks + i + 1)) | ||
|
||
x = Activation('relu')(x) | ||
|
||
if include_top: | ||
x = GlobalAveragePooling2D()(x) | ||
x = Dense(classes, activation='softmax', name='predictions', | ||
kernel_regularizer=l2(weight_decay))(x) | ||
x = Dense(classes, activation='softmax', name='predictions')(x) | ||
else: | ||
if pooling == 'avg': | ||
x = GlobalAveragePooling2D()(x) | ||
|
@@ -280,7 +273,7 @@ def NASNet(input_shape=None, | |
|
||
weights_file = get_file(model_name, weight_path, | ||
cache_subdir='models') | ||
model.load_weights(weights_file, by_name=True) | ||
model.load_weights(weights_file) | ||
|
||
elif default_size == 331: # large version | ||
if include_top: | ||
|
@@ -292,7 +285,7 @@ def NASNet(input_shape=None, | |
|
||
weights_file = get_file(model_name, weight_path, | ||
cache_subdir='models') | ||
model.load_weights(weights_file, by_name=True) | ||
model.load_weights(weights_file) | ||
else: | ||
raise ValueError( | ||
'ImageNet weights can only be loaded on NASNetLarge' | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. "with" would fit better than "on" I think |
||
|
@@ -313,6 +306,7 @@ def NASNetLarge(input_shape=None, | |
pooling=None, | ||
classes=1000): | ||
'''Instantiates a NASNet model in ImageNet mode. | ||
|
||
Note that only TensorFlow is supported for now, | ||
therefore it only works with the data format | ||
`image_data_format='channels_last'` in your Keras config | ||
|
@@ -364,7 +358,6 @@ def NASNetLarge(input_shape=None, | |
stem_filters=96, | ||
skip_reduction=False, | ||
filter_multiplier=2, | ||
weight_decay=5e-5, | ||
include_top=include_top, | ||
weights=weights, | ||
input_tensor=input_tensor, | ||
|
@@ -380,6 +373,7 @@ def NASNetMobile(input_shape=None, | |
pooling=None, | ||
classes=1000): | ||
'''Instantiates a Mobile NASNet model in ImageNet mode. | ||
|
||
Note that only TensorFlow is supported for now, | ||
therefore it only works with the data format | ||
`image_data_format='channels_last'` in your Keras config | ||
|
@@ -429,7 +423,6 @@ def NASNetMobile(input_shape=None, | |
stem_filters=32, | ||
skip_reduction=False, | ||
filter_multiplier=2, | ||
weight_decay=4e-5, | ||
include_top=include_top, | ||
weights=weights, | ||
input_tensor=input_tensor, | ||
|
@@ -439,7 +432,7 @@ def NASNetMobile(input_shape=None, | |
|
||
|
||
def _separable_conv_block(ip, filters, kernel_size=(3, 3), strides=(1, 1), | ||
weight_decay=5e-5, id=None): | ||
id=None): | ||
'''Adds 2 blocks of [relu-separable conv-batchnorm] | ||
|
||
# Arguments: | ||
|
@@ -460,24 +453,22 @@ def _separable_conv_block(ip, filters, kernel_size=(3, 3), strides=(1, 1), | |
x = SeparableConv2D(filters, kernel_size, strides=strides, | ||
name='separable_conv_1_%s' % id, | ||
padding='same', use_bias=False, | ||
kernel_initializer='he_normal', | ||
kernel_regularizer=l2(weight_decay))(x) | ||
kernel_initializer='he_normal')(x) | ||
x = BatchNormalization(axis=channel_dim, momentum=0.9997, | ||
epsilon=1e-3, | ||
name='separable_conv_1_bn_%s' % (id))(x) | ||
x = Activation('relu')(x) | ||
x = SeparableConv2D(filters, kernel_size, | ||
name='separable_conv_2_%s' % id, | ||
padding='same', use_bias=False, | ||
kernel_initializer='he_normal', | ||
kernel_regularizer=l2(weight_decay))(x) | ||
kernel_initializer='he_normal')(x) | ||
x = BatchNormalization(axis=channel_dim, momentum=0.9997, | ||
epsilon=1e-3, | ||
name='separable_conv_2_bn_%s' % (id))(x) | ||
return x | ||
|
||
|
||
def _adjust_block(p, ip, filters, weight_decay=5e-5, id=None): | ||
def _adjust_block(p, ip, filters, id=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't use `id` |
||
''' | ||
Adjusts the input `previou path` to match the shape of the `input` | ||
or situations where the output number of filters needs to be changed | ||
|
@@ -506,17 +497,15 @@ def _adjust_block(p, ip, filters, weight_decay=5e-5, id=None): | |
p1 = AveragePooling2D((1, 1), strides=(2, 2), padding='valid', | ||
name='adjust_avg_pool_1_%s' % id)(p) | ||
p1 = Conv2D(filters // 2, (1, 1), padding='same', | ||
use_bias=False, kernel_regularizer=l2(weight_decay), | ||
name='adjust_conv_1_%s' % id, | ||
use_bias=False, name='adjust_conv_1_%s' % id, | ||
kernel_initializer='he_normal')(p1) | ||
|
||
p2 = ZeroPadding2D(padding=((0, 1), (0, 1)))(p) | ||
p2 = Cropping2D(cropping=((1, 0), (1, 0)))(p2) | ||
p2 = AveragePooling2D((1, 1), strides=(2, 2), padding='valid', | ||
name='adjust_avg_pool_2_%s' % id)(p2) | ||
p2 = Conv2D(filters // 2, (1, 1), padding='same', | ||
use_bias=False, kernel_regularizer=l2(weight_decay), | ||
name='adjust_conv_2_%s' % id, | ||
use_bias=False, name='adjust_conv_2_%s' % id, | ||
kernel_initializer='he_normal')(p2) | ||
|
||
p = concatenate([p1, p2], axis=channel_dim) | ||
|
@@ -529,15 +518,14 @@ def _adjust_block(p, ip, filters, weight_decay=5e-5, id=None): | |
p = Activation('relu')(p) | ||
p = Conv2D(filters, (1, 1), strides=(1, 1), padding='same', | ||
name='adjust_conv_projection_%s' % id, | ||
use_bias=False, kernel_regularizer=l2(weight_decay), | ||
kernel_initializer='he_normal')(p) | ||
use_bias=False, kernel_initializer='he_normal')(p) | ||
p = BatchNormalization(axis=channel_dim, momentum=0.9997, | ||
epsilon=1e-3, | ||
name='adjust_bn_%s' % id)(p) | ||
return p | ||
|
||
|
||
def _normal_A(ip, p, filters, weight_decay=5e-5, id=None): | ||
def _normal_A(ip, p, filters, id=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't use `id`; function names should be snake case (no caps) |
||
'''Adds a Normal cell for NASNet-A (Fig. 4 in the paper) | ||
|
||
# Arguments: | ||
|
@@ -553,31 +541,26 @@ def _normal_A(ip, p, filters, weight_decay=5e-5, id=None): | |
channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 | ||
|
||
with K.name_scope('normal_A_block_%s' % id): | ||
p = _adjust_block(p, ip, filters, weight_decay, id) | ||
p = _adjust_block(p, ip, filters, id) | ||
|
||
h = Activation('relu')(ip) | ||
h = Conv2D(filters, (1, 1), strides=(1, 1), padding='same', | ||
name='normal_conv_1_%s' % id, | ||
use_bias=False, kernel_initializer='he_normal', | ||
kernel_regularizer=l2(weight_decay))(h) | ||
use_bias=False, kernel_initializer='he_normal')(h) | ||
h = BatchNormalization(axis=channel_dim, momentum=0.9997, | ||
epsilon=1e-3, | ||
name='normal_bn_1_%s' % id)(h) | ||
|
||
with K.name_scope('block_1'): | ||
x1_1 = _separable_conv_block(h, filters, kernel_size=(5, 5), | ||
weight_decay=weight_decay, | ||
id='normal_left1_%s' % id) | ||
x1_2 = _separable_conv_block(p, filters, weight_decay=weight_decay, | ||
id='normal_right1_%s' % id) | ||
x1_2 = _separable_conv_block(p, filters, id='normal_right1_%s' % id) | ||
x1 = add([x1_1, x1_2], name='normal_add_1_%s' % id) | ||
|
||
with K.name_scope('block_2'): | ||
x2_1 = _separable_conv_block(p, filters, (5, 5), | ||
weight_decay=weight_decay, | ||
id='normal_left2_%s' % id) | ||
x2_2 = _separable_conv_block(p, filters, (3, 3), | ||
weight_decay=weight_decay, | ||
id='normal_right2_%s' % id) | ||
x2 = add([x2_1, x2_2], name='normal_add_2_%s' % id) | ||
|
||
|
@@ -594,16 +577,15 @@ def _normal_A(ip, p, filters, weight_decay=5e-5, id=None): | |
x4 = add([x4_1, x4_2], name='normal_add_4_%s' % id) | ||
|
||
with K.name_scope('block_5'): | ||
x5 = _separable_conv_block(h, filters, weight_decay=weight_decay, | ||
id='normal_left5_%s' % id) | ||
x5 = _separable_conv_block(h, filters, id='normal_left5_%s' % id) | ||
x5 = add([x5, h], name='normal_add_5_%s' % id) | ||
|
||
x = concatenate([p, x1, x2, x3, x4, x5], axis=channel_dim, | ||
name='normal_concat_%s' % id) | ||
return x, ip | ||
|
||
|
||
def _reduction_A(ip, p, filters, weight_decay=5e-5, id=None): | ||
def _reduction_A(ip, p, filters, id=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same remarks as above |
||
'''Adds a Reduction cell for NASNet-A (Fig. 4 in the paper) | ||
|
||
# Arguments: | ||
|
@@ -619,39 +601,34 @@ def _reduction_A(ip, p, filters, weight_decay=5e-5, id=None): | |
channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 | ||
|
||
with K.name_scope('reduction_A_block_%s' % id): | ||
p = _adjust_block(p, ip, filters, weight_decay, id) | ||
p = _adjust_block(p, ip, filters, id) | ||
|
||
h = Activation('relu')(ip) | ||
h = Conv2D(filters, (1, 1), strides=(1, 1), padding='same', | ||
name='reduction_conv_1_%s' % id, | ||
use_bias=False, kernel_initializer='he_normal', | ||
kernel_regularizer=l2(weight_decay))(h) | ||
use_bias=False, kernel_initializer='he_normal')(h) | ||
h = BatchNormalization(axis=channel_dim, momentum=0.9997, | ||
epsilon=1e-3, | ||
name='reduction_bn_1_%s' % id)(h) | ||
|
||
with K.name_scope('block_1'): | ||
x1_1 = _separable_conv_block(h, filters, (5, 5), strides=(2, 2), | ||
weight_decay=weight_decay, | ||
id='reduction_left1_%s' % id) | ||
x1_2 = _separable_conv_block(p, filters, (7, 7), strides=(2, 2), | ||
weight_decay=weight_decay, | ||
id='reduction_1_%s' % id) | ||
x1 = add([x1_1, x1_2], name='reduction_add_1_%s' % id) | ||
|
||
with K.name_scope('block_2'): | ||
x2_1 = MaxPooling2D((3, 3), strides=(2, 2), padding='same', | ||
name='reduction_left2_%s' % id)(h) | ||
x2_2 = _separable_conv_block(p, filters, (7, 7), strides=(2, 2), | ||
weight_decay=weight_decay, | ||
id='reduction_right2_%s' % id) | ||
x2 = add([x2_1, x2_2], name='reduction_add_2_%s' % id) | ||
|
||
with K.name_scope('block_3'): | ||
x3_1 = AveragePooling2D((3, 3), strides=(2, 2), padding='same', | ||
name='reduction_left3_%s' % id)(h) | ||
x3_2 = _separable_conv_block(p, filters, (5, 5), strides=(2, 2), | ||
weight_decay=weight_decay, | ||
id='reduction_right3_%s' % id) | ||
x3 = add([x3_1, x3_2], name='reduction_add3_%s' % id) | ||
|
||
|
@@ -662,7 +639,6 @@ def _reduction_A(ip, p, filters, weight_decay=5e-5, id=None): | |
|
||
with K.name_scope('block_5'): | ||
x5_1 = _separable_conv_block(x1, filters, (3, 3), | ||
weight_decay=weight_decay, | ||
id='reduction_left4_%s' % id) | ||
x5_2 = MaxPooling2D((3, 3), strides=(2, 2), padding='same', | ||
name='reduction_right5_%s' % id)(h) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we name this to something more explicit? Most users will not understand what this is without further context.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about `num_stem_block_filters`?