Skip to content

Commit

Permalink
update best checkpoint and config show
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing committed May 1, 2022
1 parent 5e62438 commit 2cf3da0
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
24 changes: 17 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from utils.callbacks import (ExponentDecayScheduler, LossHistory,
ParallelModelCheckpoint)
from utils.dataloader import YoloDatasets
from utils.utils import get_anchors, get_classes
from utils.utils import get_anchors, get_classes, show_config

'''
训练自己的目标检测模型一定需要注意以下几点:
Expand Down Expand Up @@ -230,6 +230,12 @@
num_train = len(train_lines)
num_val = len(val_lines)

show_config(
classes_path = classes_path, anchors_path = anchors_path, anchors_mask = anchors_mask, model_path = model_path, input_shape = input_shape, \
Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \
Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \
save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val
)
#-----------------------------------------------#
# 总训练世代指的是遍历全部数据的总次数
# 总训练步长指的是梯度下降的总次数
Expand All @@ -238,10 +244,10 @@
wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4
total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch
if total_step <= wanted_step:
wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size)
print("\033[1;33;40m\n[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m"%(optimizer_type, wanted_step))
print("\033[1;33;40m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))
print("\033[1;33;40m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d。\033[0m"%(total_step, wanted_step, wanted_epoch))
wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1
print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m"%(optimizer_type, wanted_step))
print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))
print("\033[1;33;44m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d。\033[0m"%(total_step, wanted_step, wanted_epoch))

for layer in model_body.layers:
if isinstance(layer, DepthwiseConv2D):
Expand Down Expand Up @@ -315,14 +321,18 @@
monitor = 'val_loss', save_weights_only = True, save_best_only = False, period = save_period)
checkpoint_last = ParallelModelCheckpoint(model_body, os.path.join(save_dir, "last_epoch_weights.h5"),
monitor = 'val_loss', save_weights_only = True, save_best_only = False, period = 1)
checkpoint_best = ParallelModelCheckpoint(model_body, os.path.join(save_dir, "best_epoch_weights.h5"),
monitor = 'val_loss', save_weights_only = True, save_best_only = True, period = 1)
else:
checkpoint = ModelCheckpoint(os.path.join(save_dir, "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5"),
monitor = 'val_loss', save_weights_only = True, save_best_only = False, period = save_period)
checkpoint_last = ModelCheckpoint(os.path.join(save_dir, "last_epoch_weights.h5"),
monitor = 'val_loss', save_weights_only = True, save_best_only = False, period = 1)
checkpoint_best = ModelCheckpoint(os.path.join(save_dir, "best_epoch_weights.h5"),
monitor = 'val_loss', save_weights_only = True, save_best_only = True, period = 1)
early_stopping = EarlyStopping(monitor='val_loss', min_delta = 0, patience = 10, verbose = 1)
lr_scheduler = LearningRateScheduler(lr_scheduler_func, verbose = 1)
callbacks = [logging, loss_history, checkpoint, checkpoint_last, lr_scheduler]
callbacks = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler]

if start_epoch < end_epoch:
print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
Expand Down Expand Up @@ -359,7 +369,7 @@
#---------------------------------------#
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
lr_scheduler = LearningRateScheduler(lr_scheduler_func, verbose = 1)
callbacks = [logging, loss_history, checkpoint, lr_scheduler]
callbacks = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler]

for i in range(len(model_body.layers)):
model_body.layers[i].trainable = True
Expand Down
9 changes: 9 additions & 0 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def preprocess_input(image):
image /= 255.0
return image

def show_config(**kwargs):
print('Configurations:')
print('-' * 70)
print('|%25s | %40s|' % ('keys', 'values'))
print('-' * 70)
for key, value in kwargs.items():
print('|%25s | %40s|' % (str(key), str(value)))
print('-' * 70)

#-------------------------------------------------------------------------------------------------------------------------------#
# From https://github.com/ckyrkou/Keras_FLOP_Estimator
# Fix lots of bugs
Expand Down
3 changes: 2 additions & 1 deletion yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from nets.yolo import yolo_body
from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
resize_image)
resize_image, show_config)
from utils.utils_bbox import DecodeBox


Expand Down Expand Up @@ -65,6 +65,7 @@ def get_defaults(cls, n):
#---------------------------------------------------#
def __init__(self, **kwargs):
self.__dict__.update(self._defaults)
show_config(**self._defaults)
for name, value in kwargs.items():
setattr(self, name, value)

Expand Down

0 comments on commit 2cf3da0

Please sign in to comment.