Skip to content

Commit

Permalink
update a lot
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing committed Feb 23, 2022
1 parent 02b0e1f commit 935d05a
Show file tree
Hide file tree
Showing 9 changed files with 600 additions and 149 deletions.
2 changes: 1 addition & 1 deletion nets/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#---------------------------------------------------#
@wraps(Conv2D)
def DarknetConv2D(*args, **kwargs):
    # NOTE: @wraps(Conv2D) copies Conv2D's metadata (including __doc__) onto
    # this wrapper, so the contract is documented with comments instead of a
    # docstring.
    #
    # Thin wrapper around Conv2D that fills in Darknet defaults:
    #   * kernel_initializer: random_normal(stddev=0.02)
    #   * padding: 'valid' when strides is exactly the tuple (2, 2),
    #              otherwise 'same' (an int 2 or a list [2, 2] does not
    #              compare equal to (2, 2) and therefore yields 'same')
    # Positional arguments are forwarded to Conv2D untouched, and any keyword
    # supplied by the caller overrides the defaults above.
    #
    # The dictionary used to be assigned twice in a row; the first assignment
    # (which also carried an l2(5e-4) kernel_regularizer) was overwritten
    # before it was ever read, so only the effective defaults are kept here.
    darknet_conv_kwargs = {
        'kernel_initializer': random_normal(stddev=0.02),
        'padding': 'valid' if kwargs.get('strides') == (2, 2) else 'same',
    }
    darknet_conv_kwargs.update(kwargs)
    return Conv2D(*args, **darknet_conv_kwargs)
Expand Down
11 changes: 10 additions & 1 deletion nets/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,16 @@ def get_train_model(model_body, input_shape, num_classes, anchors, anchors_mask)
yolo_loss,
output_shape = (1, ),
name = 'yolo_loss',
arguments = {'input_shape' : input_shape, 'anchors' : anchors, 'anchors_mask' : anchors_mask, 'num_classes' : num_classes}
arguments = {
'input_shape' : input_shape,
'anchors' : anchors,
'anchors_mask' : anchors_mask,
'num_classes' : num_classes,
'balance' : [0.4, 1.0, 4],
'box_ratio' : 0.05,
'obj_ratio' : 5 * (input_shape[0] * input_shape[1]) / (416 ** 2),
'cls_ratio' : 1 * (num_classes / 80)
}
)([*model_body.output, *y_true])
model = Model([model_body.input, *y_true], model_loss)
return model
197 changes: 163 additions & 34 deletions nets/yolo_training.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,78 @@
import math
from functools import partial

import tensorflow as tf
from keras import backend as K
from utils.utils_bbox import get_anchors_and_decode


def box_ciou(b1, b2):
    """
    Element-wise Complete-IoU (CIoU) between two box tensors.

    Both inputs hold boxes as (center_x, center_y, width, height) in their
    last dimension; the corners are derived as center -/+ half the size.

    Parameters
    ----------
    b1: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
        (the predicted boxes)
    b2: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
        (the ground-truth boxes)

    Returns
    -------
    ciou: tensor, shape=(batch, feat_w, feat_h, anchor_num, 1)
    """
    # Every denominator below is clamped to this value to avoid dividing by 0.
    eps = K.epsilon()

    #-----------------------------------------------------------#
    #   Split each box into center / size and derive its
    #   top-left and bottom-right corners.
    #   corners: (batch, feat_w, feat_h, anchor_num, 2)
    #-----------------------------------------------------------#
    pred_center, pred_size = b1[..., :2], b1[..., 2:4]
    pred_half = pred_size / 2.
    pred_tl, pred_br = pred_center - pred_half, pred_center + pred_half

    true_center, true_size = b2[..., :2], b2[..., 2:4]
    true_half = true_size / 2.
    true_tl, true_br = true_center - true_half, true_center + true_half

    #-----------------------------------------------------------#
    #   Plain IoU of every prediction / ground-truth pair.
    #   iou: (batch, feat_w, feat_h, anchor_num)
    #-----------------------------------------------------------#
    overlap_tl = K.maximum(pred_tl, true_tl)
    overlap_br = K.minimum(pred_br, true_br)
    # Negative extents mean the boxes do not overlap -> clamp to zero.
    overlap_wh = K.maximum(overlap_br - overlap_tl, 0.)
    overlap_area = overlap_wh[..., 0] * overlap_wh[..., 1]
    pred_area = pred_size[..., 0] * pred_size[..., 1]
    true_area = true_size[..., 0] * true_size[..., 1]
    iou = overlap_area / K.maximum(pred_area + true_area - overlap_area, eps)

    #-----------------------------------------------------------#
    #   Distance penalty: squared distance between the centers
    #   divided by the squared diagonal of the smallest box
    #   enclosing both boxes.
    #   both terms: (batch, feat_w, feat_h, anchor_num)
    #-----------------------------------------------------------#
    center_dist_sq = K.sum(K.square(pred_center - true_center), axis=-1)
    hull_wh = K.maximum(
        K.maximum(pred_br, true_br) - K.minimum(pred_tl, true_tl), 0.0
    )
    hull_diag_sq = K.sum(K.square(hull_wh), axis=-1)
    diou = iou - center_dist_sq / K.maximum(hull_diag_sq, eps)

    #-----------------------------------------------------------#
    #   Aspect-ratio penalty: v measures the difference between
    #   the width/height angles of the two boxes, alpha is its
    #   trade-off weight.
    #-----------------------------------------------------------#
    pred_angle = tf.math.atan2(pred_size[..., 0], K.maximum(pred_size[..., 1], eps))
    true_angle = tf.math.atan2(true_size[..., 0], K.maximum(true_size[..., 1], eps))
    v = 4 * K.square(pred_angle - true_angle) / (math.pi * math.pi)
    alpha = v / K.maximum((1.0 - iou + v), eps)

    # Append a trailing axis so the result broadcasts against per-box masks.
    return K.expand_dims(diou - alpha * v, -1)

#---------------------------------------------------#
# 用于计算每个预测框与真实框的iou
#---------------------------------------------------#
Expand Down Expand Up @@ -44,7 +114,20 @@ def box_iou(b1, b2):
#---------------------------------------------------#
# loss值计算
#---------------------------------------------------#
def yolo_loss(args, input_shape, anchors, anchors_mask, num_classes, ignore_thresh=0.5, print_loss=False):
def yolo_loss(
args,
input_shape,
anchors,
anchors_mask,
num_classes,
ignore_thresh = 0.5,
balance = [0.4, 1.0, 4],
box_ratio = 0.05,
obj_ratio = 1,
cls_ratio = 0.5 / 4,
ciou_flag = True,
print_loss = False
):
num_layers = len(anchors_mask)
#---------------------------------------------------------------------------------------------------#
# 将预测结果和实际ground truth分开,args是[*model_body.output, *y_true]
Expand Down Expand Up @@ -76,7 +159,6 @@ def yolo_loss(args, input_shape, anchors, anchors_mask, num_classes, ignore_thre
m = K.shape(yolo_outputs[0])[0]

loss = 0
num_pos = 0
#---------------------------------------------------------------------------------------------------#
# y_true是一个列表,包含三个特征层,shape分别为(m,13,13,3,85),(m,26,26,3,85),(m,52,52,3,85)。
# yolo_outputs是一个列表,包含三个特征层,shape分别为(m,13,13,3,85),(m,26,26,3,85),(m,52,52,3,85)。
Expand Down Expand Up @@ -159,33 +241,42 @@ def loop_body(b, ignore_mask):
# (m,13,13,3,1)
ignore_mask = K.expand_dims(ignore_mask, -1)

#-----------------------------------------------------------#
# 将真实框进行编码,使其格式与预测的相同,后面用于计算loss
#-----------------------------------------------------------#
raw_true_xy = y_true[l][..., :2] * grid_shapes[l][::-1] - grid
raw_true_wh = K.log(y_true[l][..., 2:4] / anchors[anchors_mask[l]] * input_shape[::-1])

#-----------------------------------------------------------#
# object_mask如果真实存在目标则保存其wh值
# switch接口,就是一个if/else条件判断语句
#-----------------------------------------------------------#
raw_true_wh = K.switch(object_mask, raw_true_wh, K.zeros_like(raw_true_wh))
#-----------------------------------------------------------#
# reshape_y_true[...,2:3]和reshape_y_true[...,3:4]
# 表示真实框的宽高,二者均在0-1之间
# 真实框越大,比重越小,小框的比重更大。
#-----------------------------------------------------------#
box_loss_scale = 2 - y_true[l][...,2:3]*y_true[l][...,3:4]
box_loss_scale = 2 - y_true[l][...,2:3] * y_true[l][...,3:4]
if ciou_flag:
#-----------------------------------------------------------#
# 计算Ciou loss
#-----------------------------------------------------------#
raw_true_box = y_true[l][...,0:4]
ciou = box_ciou(pred_box, raw_true_box)
ciou_loss = object_mask * (1 - ciou) * box_loss_scale
location_loss = K.sum(ciou_loss)
else:
#-----------------------------------------------------------#
# 将真实框进行编码,使其格式与预测的相同,后面用于计算loss
#-----------------------------------------------------------#
raw_true_xy = y_true[l][..., :2] * grid_shapes[l][::-1] - grid
raw_true_wh = K.log(y_true[l][..., 2:4] / anchors[anchors_mask[l]] * input_shape[::-1])

#-----------------------------------------------------------#
# 利用binary_crossentropy计算中心点偏移情况,效果更好
#-----------------------------------------------------------#
xy_loss = object_mask * box_loss_scale * K.binary_crossentropy(raw_true_xy, raw_pred[...,0:2], from_logits=True)
#-----------------------------------------------------------#
# wh_loss用于计算宽高损失
#-----------------------------------------------------------#
wh_loss = object_mask * box_loss_scale * 0.5 * K.square(raw_true_wh - raw_pred[...,2:4])

#-----------------------------------------------------------#
# object_mask如果真实存在目标则保存其wh值
# switch接口,就是一个if/else条件判断语句
#-----------------------------------------------------------#
raw_true_wh = K.switch(object_mask, raw_true_wh, K.zeros_like(raw_true_wh))
#-----------------------------------------------------------#
# 利用binary_crossentropy计算中心点偏移情况,效果更好
#-----------------------------------------------------------#
xy_loss = object_mask * box_loss_scale * K.binary_crossentropy(raw_true_xy, raw_pred[...,0:2], from_logits=True)
#-----------------------------------------------------------#
# wh_loss用于计算宽高损失
#-----------------------------------------------------------#
wh_loss = object_mask * box_loss_scale * 0.5 * K.square(raw_true_wh - raw_pred[...,2:4])
location_loss = K.sum(xy_loss) + K.sum(wh_loss)

#------------------------------------------------------------------------------#
# 如果该位置本来有框,那么计算1与置信度的交叉熵
# 如果该位置本来没有框,那么计算0与置信度的交叉熵
Expand All @@ -200,20 +291,58 @@ def loop_body(b, ignore_mask):
class_loss = object_mask * K.binary_crossentropy(true_class_probs, raw_pred[...,5:], from_logits=True)

#-----------------------------------------------------------#
# 将所有损失求和
# 计算正样本数量
#-----------------------------------------------------------#
xy_loss = K.sum(xy_loss)
wh_loss = K.sum(wh_loss)
confidence_loss = K.sum(confidence_loss)
class_loss = K.sum(class_loss)
num_pos = tf.maximum(K.sum(K.cast(object_mask, tf.float32)), 1)
num_neg = tf.maximum(K.sum(K.cast((1 - object_mask) * ignore_mask, tf.float32)), 1)
#-----------------------------------------------------------#
# 计算正样本数量
# 将所有损失求和
#-----------------------------------------------------------#
num_pos += tf.maximum(K.sum(K.cast(object_mask, tf.float32)), 1)
loss += xy_loss + wh_loss + confidence_loss + class_loss
location_loss = location_loss * box_ratio / num_pos
confidence_loss = K.sum(confidence_loss) * balance[l] * obj_ratio / (num_pos + num_neg)
class_loss = K.sum(class_loss) * cls_ratio / num_pos / num_classes

loss += location_loss + confidence_loss + class_loss
if print_loss:
loss = tf.Print(loss, [loss, xy_loss, wh_loss, confidence_loss, class_loss, tf.shape(ignore_mask)], summarize=100, message='loss: ')

loss = loss / num_pos
loss = tf.Print(loss, [loss, location_loss, confidence_loss, class_loss, tf.shape(ignore_mask)], summarize=100, message='loss: ')
return loss


def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.1, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.3, step_num = 10):
    """
    Build a function that maps an iteration index to a learning rate.

    Parameters
    ----------
    lr_decay_type: "cos" selects warm-up + cosine annealing; any other value
        selects the step schedule.
    lr: peak (initial) learning rate.
    min_lr: final learning rate.
    total_iters: total length of the schedule, in the same unit as the index
        later passed to the returned function.
    warmup_iters_ratio: fraction of total_iters used for warm-up ("cos" only);
        the resulting length is clamped to the range [1, 3].
    warmup_lr_ratio: warm-up starts at max(warmup_lr_ratio * lr, 1e-6)
        ("cos" only).
    no_aug_iter_ratio: fraction of total_iters at the end of the schedule that
        is held at min_lr ("cos" only); the length is clamped to [1, 15].
    step_num: number of constant-lr stages of the step schedule.

    Returns
    -------
    A callable ``f(iters) -> lr``.

    Raises
    ------
    ValueError: raised by the returned step schedule when it is called and
        total_iters / step_num is smaller than 1.
    """
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # Quadratic (not linear) ramp from warmup_lr_start up to lr.
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            # Tail of the schedule: hold the minimum learning rate.
            lr = min_lr
        else:
            # Half-cosine decay from lr to min_lr between the two phases above.
            progress = (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)
            lr = min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * progress))
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        # Number of completed stages; lr is multiplied by decay_rate per stage.
        n = iters // step_size
        return lr * decay_rate ** n

    if lr_decay_type == "cos":
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        if step_num == 1:
            # A single stage never decays. The general formula below would
            # divide by zero (1 / (step_num - 1)), so keep lr constant.
            decay_rate = 1.0
        else:
            # Chosen so that lr * decay_rate ** (step_num - 1) == min_lr.
            decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)

    return func

16 changes: 11 additions & 5 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
# 'fps'表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
# 'dir_predict'表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
#----------------------------------------------------------------------------------------------------------#
mode = "predict"
mode = "predict"
#-------------------------------------------------------------------------#
# crop指定了是否在单张图片预测后对目标进行截取
# crop仅在mode='predict'时有效
#-------------------------------------------------------------------------#
crop = False
#----------------------------------------------------------------------------------------------------------#
# video_path用于指定视频的路径,当video_path=0时表示检测摄像头
# 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
Expand Down Expand Up @@ -62,7 +67,7 @@
print('Open Error! Try again!')
continue
else:
r_image = yolo.detect_image(image)
r_image = yolo.detect_image(image, crop = crop)
r_image.show()

elif mode == "video":
Expand Down Expand Up @@ -111,14 +116,15 @@
print("Save processed video to the path :" + video_save_path)
out.release()
cv2.destroyAllWindows()

elif mode == "fps":
img = Image.open('img/street.jpg')
tact_time = yolo.get_FPS(img, test_interval)
print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')

elif mode == "dir_predict":
import os

from tqdm import tqdm

img_names = os.listdir(dir_origin_path)
Expand All @@ -129,7 +135,7 @@
r_image = yolo.detect_image(image)
if not os.path.exists(dir_save_path):
os.makedirs(dir_save_path)
r_image.save(os.path.join(dir_save_path, img_name))
r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)

else:
raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.")
Loading

0 comments on commit 935d05a

Please sign in to comment.