
Commit
[TF] Align the optimizer section of the configuration file (openvinotoolkit#531)

* align optimizer and scheduler config params
evgeniya-egupova committed Mar 18, 2021
1 parent 83c5947 commit 7b4cbdb
Showing 52 changed files with 179 additions and 128 deletions.
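The sample configs below capitalize the optimizer "type" value (for example "adam" becomes "Adam"). As the optimizer.py hunk further down shows, build_optimizer lowercases the value before dispatching, so both spellings select the same branch. A minimal sketch of that lookup, using a hypothetical config dict in place of the samples' full JSON files:

# Sketch only: mirrors the type normalization done in build_optimizer.
optimizer_config = {
    "type": "Adam",                          # capitalized, as in the updated samples
    "schedule_type": "piecewise_constant",
    "schedule_params": {"boundaries": [5], "values": [1e-3, 1e-4]},  # hypothetical values
}

optimizer_type = optimizer_config.get("type", "adam").lower()
assert optimizer_type == "adam"              # "Adam" and "adam" hit the same branch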
@@ -9,7 +9,7 @@
"epochs": 10,

"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [5],
@@ -9,7 +9,7 @@
"epochs": 15,

"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [5, 10],
@@ -7,7 +7,7 @@
"batch_size": 128,
"epochs": 4,
"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [2],
@@ -9,7 +9,7 @@
"epochs": 40,

"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [20, 25, 30],
@@ -9,7 +9,7 @@
"epochs": 17,

"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [12, 14, 16],
@@ -9,7 +9,7 @@
"epochs": 40,

"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [15, 20, 30],
@@ -7,7 +7,7 @@
"batch_size": 256,
"epochs": 55,
"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [30, 45, 50],
@@ -9,7 +9,7 @@
"epochs": 60,

"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [35, 50, 55],
@@ -9,7 +9,7 @@
"epochs": 40,

"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [20, 28],
@@ -7,7 +7,7 @@
"batch_size": 256,
"epochs": 55,
"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [30, 45, 50],
4 changes: 1 addition & 3 deletions beta/examples/tensorflow/classification/main.py
@@ -156,9 +156,7 @@ def run(config):

scheduler = build_scheduler(
config=config,
epoch_size=train_builder.num_examples,
batch_size=train_builder.global_batch_size,
steps=train_steps)
steps_per_epoch=train_steps)
optimizer = build_optimizer(
config=config,
scheduler=scheduler)
69 changes: 36 additions & 33 deletions beta/examples/tensorflow/common/optimizer.py
@@ -21,51 +21,54 @@ def build_optimizer(config, scheduler):
optimizer_config = config.get('optimizer', {})

optimizer_type = optimizer_config.get('type', 'adam').lower()
optimizer_params = optimizer_config.get("optimizer_params", {})
optimizer_params = optimizer_config.get('optimizer_params', {})

logger.info('Building %s optimizer with params %s', optimizer_type, optimizer_params)

if optimizer_type == 'sgd':
logger.info('Using SGD optimizer')
nesterov = optimizer_params.get('nesterov', False)
optimizer = tf.keras.optimizers.SGD(learning_rate=scheduler,
nesterov=nesterov)
elif optimizer_type == 'momentum':
logger.info('Using momentum optimizer')
if optimizer_type in ['sgd', 'momentum']:
printable_names = {'sgd': 'SGD', 'momentum': 'momentum'}
logger.info('Using %s optimizer', printable_names[optimizer_type])

default_momentum_value = 0.9 if optimizer_type == 'momentum' else 0.0
momentum = optimizer_params.get('momentum', default_momentum_value)
nesterov = optimizer_params.get('nesterov', False)
momentum = optimizer_params.get('momentum', 0.9)
optimizer = tf.keras.optimizers.SGD(learning_rate=scheduler,
momentum=momentum,
nesterov=nesterov)
weight_decay = optimizer_config.get('weight_decay', None)
common_params = {'learning_rate': scheduler,
'nesterov': nesterov,
'momentum': momentum}
if weight_decay:
optimizer = tfa.optimizers.SGDW(**common_params,
weight_decay=weight_decay)
else:
optimizer = tf.keras.optimizers.SGD(**common_params)
elif optimizer_type == 'rmsprop':
logger.info('Using RMSProp')
logger.info('Using RMSProp optimizer')
rho = optimizer_params.get('rho', 0.9)
momentum = optimizer_params.get('momentum', 0.9)
epsilon = optimizer_params.get('epsilon', 1e-07)
optimizer = tf.keras.optimizers.RMSprop(learning_rate=scheduler,
rho=rho,
momentum=momentum,
epsilon=epsilon)
elif optimizer_type == 'adam':
logger.info('Using Adam')
beta_1 = optimizer_params.get('beta_1', 0.9)
beta_2 = optimizer_params.get('beta_2', 0.999)
epsilon = optimizer_params.get('epsilon', 1e-07)
optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon)
elif optimizer_type == 'adamw':
logger.info('Using AdamW')
weight_decay = optimizer_params.get('weight_decay', 0.01)
beta_1 = optimizer_params.get('beta_1', 0.9)
beta_2 = optimizer_params.get('beta_2', 0.999)
epsilon = optimizer_params.get('epsilon', 1e-07)
optimizer = tfa.optimizers.AdamW(weight_decay=weight_decay,
learning_rate=scheduler,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon)
elif optimizer_type in ['adam', 'adamw']:
printable_names = {'adam': 'Adam', 'adamw': 'AdamW'}
logger.info('Using %s optimizer', printable_names[optimizer_type])

beta_1, beta_2 = optimizer_params.get('betas', [0.9, 0.999])
epsilon = optimizer_params.get('eps', 1e-07)
amsgrad = optimizer_params.get('amsgrad', False)
w_decay_defaul_value = 0.01 if optimizer_type == 'adamw' else None
weight_decay = optimizer_config.get('weight_decay', w_decay_defaul_value)
common_params = {'learning_rate': scheduler,
'beta_1': beta_1,
'beta_2': beta_2,
'epsilon': epsilon,
'amsgrad': amsgrad}
if weight_decay:
optimizer = tfa.optimizers.AdamW(**common_params,
weight_decay=weight_decay)
else:
optimizer = tf.keras.optimizers.Adam(**common_params)
else:
raise ValueError('Unknown optimizer %s' % optimizer_type)

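For illustration, here is a condensed, self-contained sketch of the unified adam/adamw branch above, showing the aligned parameter names ("betas", "eps", "amsgrad", and a top-level "weight_decay" that switches to the Addons AdamW variant). The config values in the usage line are hypothetical; the real build_optimizer also covers the SGD/momentum and RMSProp branches shown in the hunk.

import tensorflow as tf
import tensorflow_addons as tfa

def make_adam(optimizer_config, learning_rate):
    """Sketch of the adam/adamw branch only; not the full build_optimizer."""
    optimizer_type = optimizer_config.get('type', 'adam').lower()
    params = optimizer_config.get('optimizer_params', {})
    beta_1, beta_2 = params.get('betas', [0.9, 0.999])   # aligned name: 'betas'
    common = {
        'learning_rate': learning_rate,
        'beta_1': beta_1,
        'beta_2': beta_2,
        'epsilon': params.get('eps', 1e-07),             # aligned name: 'eps'
        'amsgrad': params.get('amsgrad', False),
    }
    default_wd = 0.01 if optimizer_type == 'adamw' else None
    weight_decay = optimizer_config.get('weight_decay', default_wd)
    if weight_decay:
        # Decoupled weight decay variant from TensorFlow Addons
        return tfa.optimizers.AdamW(weight_decay=weight_decay, **common)
    return tf.keras.optimizers.Adam(**common)

# Hypothetical config: plain Adam with custom betas, so tf.keras.optimizers.Adam is used
opt = make_adam({'type': 'Adam', 'optimizer_params': {'betas': [0.9, 0.98]}},
                learning_rate=1e-4)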
123 changes: 87 additions & 36 deletions beta/examples/tensorflow/common/scheduler.py
@@ -19,10 +19,9 @@
class StepLearningRateWithLinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Class to generate learning rate tensor"""

def __init__(self, total_steps, params):
def __init__(self, params):
"""Creates the step learning rate tensor with linear warmup"""
super().__init__()
self._total_steps = total_steps
self._params = params

def __call__(self, global_step):
@@ -41,60 +40,112 @@ def __call__(self, global_step):
return learning_rate

def get_config(self):
return {'_params': self._params.as_dict()}
return {'params': self._params.as_dict()}


class MultiStepLearningRate(tf.keras.optimizers.schedules.LearningRateSchedule):
def __init__(self, init_lr, steps, gamma=0.1):
"""
Creates the multistep learning rate schedule.
Decays learning rate by `gamma` once `global_step` reaches
one of the milestones in the `steps` list.
For init_lr = 0.01, steps = [10, 15] and gamma = 0.1
lr = 0.01 if global_step < 10
lr = 0.001 if 10 <= global_step < 15
lr = 0.0001 if global_step >= 15
Args:
init_lr: Initial learning rate
steps: List of step indices
gamma: Learning rate decay rate
"""
super().__init__()
self._init_lr = init_lr
self._steps = sorted(steps)
self._gamma = gamma
self._lr_values = [init_lr * self._gamma ** (i + 1) for i in range(len(self._steps))]

def __call__(self, global_step):
learning_rate = self._init_lr
for next_learning_rate, start_step in zip(self._lr_values, self._steps):
learning_rate = tf.where(global_step >= start_step, next_learning_rate, learning_rate)
return learning_rate

def get_config(self):
return {'init_lr': self._init_lr,
'steps': self._steps,
'gamma': self._gamma}


def build_scheduler(config, epoch_size, batch_size, steps):
def build_scheduler(config, steps_per_epoch):
optimizer_config = config.get('optimizer', {})
schedule_type = optimizer_config.get('schedule_type', 'exponential').lower()
schedule_params = optimizer_config.get("schedule_params", {})
schedule_type = optimizer_config.get('schedule_type', 'step').lower()
schedule_params = optimizer_config.get('schedule_params', {})
gamma = schedule_params.get('gamma', optimizer_config.get('gamma', 0.1))
base_lr = schedule_params.get('base_lr', optimizer_config.get('base_lr', None))

if schedule_type == 'exponential':
decay_rate = schedule_params.get('decay_rate', None)
if decay_rate is None:
raise ValueError('decay_rate parameter must be specified '
if base_lr is None:
raise ValueError('`base_lr` parameter must be specified '
'for the exponential scheduler')

initial_lr = schedule_params.get('initial_lr', None)
if initial_lr is None:
raise ValueError('initial_lr parameter must be specified '
'for the exponential scheduler')

decay_epochs = schedule_params.get('decay_epochs', None)
decay_steps = decay_epochs * steps if decay_epochs is not None else 0
step = schedule_params.get('step', optimizer_config.get('step', 1))
decay_steps = step * steps_per_epoch

logger.info('Using exponential learning rate with: '
'initial_learning_rate: {initial_lr}, decay_steps: {decay_steps}, '
'decay_rate: {decay_rate}'.format(initial_lr=initial_lr,
decay_steps=decay_steps,
decay_rate=decay_rate))
'initial lr: %f, decay steps: %d, '
'decay rate: %f', base_lr, decay_steps, gamma)
lr = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=initial_lr,
initial_learning_rate=base_lr,
decay_steps=decay_steps,
decay_rate=decay_rate)
decay_rate=gamma)

elif schedule_type == 'piecewise_constant':
boundaries = schedule_params.get('boundaries', None)
boundaries = schedule_params.get('boundaries', optimizer_config.get('boundaries', None))
if boundaries is None:
raise ValueError('boundaries parameter must be specified '
'for the piecewise_constant scheduler')
raise ValueError('`boundaries` parameter must be specified '
'for the `piecewise_constant` scheduler')

values = schedule_params.get('values', None)
values = schedule_params.get('values', optimizer_config.get('values', None))
if values is None:
raise ValueError('values parameter must be specified '
'for the piecewise_constant')
raise ValueError('`values` parameter must be specified '
'for the `piecewise_constant` scheduler')

logger.info('Using Piecewise constant decay with warmup. '
'Parameters: batch_size: {batch_size}, epoch_size: {epoch_size}, '
'boundaries: {boundaries}, values: {values}'.format(batch_size=batch_size,
epoch_size=epoch_size,
boundaries=boundaries,
values=values))
steps_per_epoch = epoch_size // batch_size
'Parameters: boundaries: %s, values: %s', boundaries, values)
boundaries = [steps_per_epoch * x for x in boundaries]
lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)

elif schedule_type == 'multistep':
logger.info('Using MultiStep learning rate.')
if base_lr is None:
raise ValueError('`base_lr` parameter must be specified '
'for the `multistep` scheduler')
steps = schedule_params.get('steps', optimizer_config.get('steps', None))
if steps is None:
raise ValueError('`steps` parameter must be specified '
'for the `multistep` scheduler')
steps = [steps_per_epoch * x for x in steps]
lr = MultiStepLearningRate(base_lr, steps, gamma=gamma)

elif schedule_type == 'step':
lr = StepLearningRateWithLinearWarmup(steps, schedule_params)
if base_lr is None:
raise ValueError('`base_lr` parameter must be specified '
'for the `step` scheduler')
step = schedule_params.get('step', optimizer_config.get('step', 1))
decay_steps = step * steps_per_epoch

logger.info('Using Step learning rate with: '
'base_lr: %f, decay steps: %d, '
'gamma: %f', base_lr, decay_steps, gamma)
lr = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=base_lr,
decay_steps=decay_steps,
decay_rate=gamma,
staircase=True
)
elif schedule_type == 'step_warmup':
lr = StepLearningRateWithLinearWarmup(schedule_params)
else:
raise NameError(f"unsupported type of learning rate scheduler: {schedule_type}")
raise KeyError(f'Unknown learning rate scheduler type: {schedule_type}')

return lr
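To make the multistep schedule concrete, here is a small standalone check of the docstring example above (a pure-Python equivalent of MultiStepLearningRate.__call__, written for illustration rather than importing the class from the examples package). Note that build_scheduler treats the configured milestones as epoch indices and multiplies them by steps_per_epoch before building the schedule.

def multistep_lr(global_step, init_lr, steps, gamma=0.1):
    # Decay by `gamma` at each milestone, mirroring MultiStepLearningRate
    steps = sorted(steps)
    lr_values = [init_lr * gamma ** (i + 1) for i in range(len(steps))]
    lr = init_lr
    for next_lr, start_step in zip(lr_values, steps):
        if global_step >= start_step:
            lr = next_lr
    return lr

# Matches the docstring example: init_lr=0.01, steps=[10, 15], gamma=0.1
assert multistep_lr(5, 0.01, [10, 15]) == 0.01
assert abs(multistep_lr(12, 0.01, [10, 15]) - 0.001) < 1e-9
assert abs(multistep_lr(20, 0.01, [10, 15]) - 0.0001) < 1e-9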
@@ -17,7 +17,7 @@
"momentum": 0.9,
"nesterov": true
},
"schedule_type": "step",
"schedule_type": "step_warmup",
"schedule_params": {
"warmup_learning_rate": 0.0067,
"warmup_steps": 500,
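The renamed step_warmup type maps to StepLearningRateWithLinearWarmup from the scheduler hunk above. Its __call__ body is collapsed in this diff, so the following is only a generic illustration of linear warmup driven by the two parameters this config exposes, not the class's actual formula; base_lr here is a made-up value.

def linear_warmup_lr(global_step, warmup_learning_rate, warmup_steps, base_lr):
    # Illustrative only: ramp linearly from warmup_learning_rate to base_lr over warmup_steps
    if global_step >= warmup_steps:
        return base_lr
    progress = global_step / float(warmup_steps)
    return warmup_learning_rate + progress * (base_lr - warmup_learning_rate)

# Halfway through warmup with the config values above and a hypothetical base_lr of 0.02
print(linear_warmup_lr(250, warmup_learning_rate=0.0067, warmup_steps=500, base_lr=0.02))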
@@ -12,7 +12,7 @@
"dataset_type": "tfds",

"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [
4 changes: 1 addition & 3 deletions beta/examples/tensorflow/object_detection/main.py
@@ -281,9 +281,7 @@ def run(config):

scheduler = build_scheduler(
config=config,
epoch_size=train_builder.num_examples,
batch_size=train_builder.global_batch_size,
steps=steps_per_epoch)
steps_per_epoch=steps_per_epoch)

optimizer = build_optimizer(
config=config,
@@ -16,7 +16,7 @@
"momentum": 0.9,
"nesterov": true
},
"schedule_type": "step",
"schedule_type": "step_warmup",
"schedule_params": {
"warmup_learning_rate": 0.0067,
"warmup_steps": 500,
@@ -11,7 +11,7 @@
"dataset": "coco/2017",

"optimizer": {
"type": "adam",
"type": "Adam",
"schedule_type": "piecewise_constant",
"schedule_params": {
"boundaries": [1],