forked from open-mmlab/OpenPCDet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
187 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from . import box_utils | ||
|
||
|
||
class SigmoidFocalClassificationLoss(nn.Module): | ||
""" | ||
Sigmoid focal cross entropy loss. | ||
""" | ||
|
||
def __init__(self, gamma: float = 2.0, alpha: float = 0.25): | ||
""" | ||
Args: | ||
gamma: Weighting parameter to balance loss for hard and easy examples. | ||
alpha: Weighting parameter to balance loss for positive and negative examples. | ||
""" | ||
super(SigmoidFocalClassificationLoss, self).__init__() | ||
self.alpha = alpha | ||
self.gamma = gamma | ||
|
||
@staticmethod | ||
def sigmoid_cross_entropy_with_logits(input: torch.Tensor, target: torch.Tensor): | ||
""" PyTorch Implementation for tf.nn.sigmoid_cross_entropy_with_logits: | ||
max(x, 0) - x * z + log(1 + exp(-abs(x))) in | ||
https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits | ||
Args: | ||
input: (B, #anchors, #classes) float tensor. | ||
Predicted logits for each class | ||
target: (B, #anchors, #classes) float tensor. | ||
One-hot encoded classification targets | ||
Returns: | ||
loss: (B, #anchors, #classes) float tensor. | ||
Sigmoid cross entropy loss without reduction | ||
""" | ||
loss = torch.clamp(input, min=0) - input * target + \ | ||
torch.log1p(torch.exp(-torch.abs(input))) | ||
return loss | ||
|
||
def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor): | ||
""" | ||
Args: | ||
input: (B, #anchors, #classes) float tensor. | ||
Predicted logits for each class | ||
target: (B, #anchors, #classes) float tensor. | ||
One-hot encoded classification targets | ||
weights: (B, #anchors) float tensor. | ||
Anchor-wise weights. | ||
Returns: | ||
weighted_loss: (B, #anchors, #classes) float tensor after weighting. | ||
""" | ||
pred_sigmoid = torch.sigmoid(input) | ||
alpha_weight = target * self.alpha + (1 - target) * (1 - self.alpha) | ||
pt = target * (1.0 - pred_sigmoid) + (1.0 - target) * pred_sigmoid | ||
focal_weight = alpha_weight * torch.pow(pt, self.gamma) | ||
|
||
bce_loss = self.sigmoid_cross_entropy_with_logits(input, target) | ||
|
||
loss = focal_weight * bce_loss | ||
|
||
if weights.shape.__len__() == 2 or \ | ||
(weights.shape.__len__() == 1 and target.shape.__len__() == 2): | ||
weights = weights.unsqueeze(-1) | ||
|
||
assert weights.shape.__len__() == loss.shape.__len__() | ||
|
||
return loss * weights | ||
|
||
|
||
class WeightedSmoothL1Loss(nn.Module): | ||
""" | ||
Code-wise Weighted Smooth L1 Loss modified based on fvcore.nn.smooth_l1_loss | ||
https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/smooth_l1_loss.py | ||
| 0.5 * x ** 2 / beta if abs(x) < beta | ||
smoothl1(x) = | | ||
| abs(x) - 0.5 * beta otherwise, | ||
where x = input - target. | ||
""" | ||
def __init__(self, beta: float = 1.0 / 9.0, code_weights: list = None): | ||
""" | ||
Args: | ||
beta: Scalar float. | ||
L1 to L2 change point. | ||
For beta values < 1e-5, L1 loss is computed. | ||
code_weights: (#codes) float list if not None. | ||
Code-wise weights. | ||
""" | ||
super(WeightedSmoothL1Loss, self).__init__() | ||
self.beta = beta | ||
if code_weights is not None: | ||
self.code_weights = np.array(code_weights, dtype=np.float32) | ||
self.code_weights = torch.from_numpy(self.code_weights).cuda() | ||
|
||
@staticmethod | ||
def smooth_l1_loss(diff, beta): | ||
if beta < 1e-5: | ||
loss = torch.abs(diff) | ||
else: | ||
n = torch.abs(diff) | ||
loss = torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta) | ||
|
||
return loss | ||
|
||
def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None): | ||
""" | ||
Args: | ||
input: (B, #anchors, #codes) float tensor. | ||
Ecoded predicted locations of objects. | ||
target: (B, #anchors, #codes) float tensor. | ||
Regression targets. | ||
weights: (B, #anchors) float tensor if not None. | ||
Returns: | ||
loss: (B, #anchors) float tensor. | ||
Weighted smooth l1 loss without reduction. | ||
""" | ||
diff = input - target | ||
# code-wise weighting | ||
if self.code_weights is not None: | ||
diff = diff * self.code_weights.view(1, 1, -1) | ||
|
||
loss = self.smooth_l1_loss(diff, self.beta) | ||
|
||
# anchor-wise weighting | ||
if weights is not None: | ||
assert weights.shape[0] == loss.shape[0] and weights.shape[1] == loss.shape[1] | ||
loss = loss * weights.unsqueeze(-1) | ||
|
||
return loss | ||
|
||
|
||
class WeightedCrossEntropyLoss(nn.Module): | ||
""" | ||
Transform input to fit the fomation of PyTorch offical cross entropy loss | ||
with anchor-wise weighting. | ||
""" | ||
def __init__(self): | ||
super(WeightedCrossEntropyLoss, self).__init__() | ||
|
||
def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor): | ||
""" | ||
Args: | ||
input: (B, #anchors, #classes) float tensor. | ||
Predited logits for each class. | ||
target: (B, #anchors, #classes) float tensor. | ||
One-hot classification targets. | ||
weights: (B, #anchors) float tensor. | ||
Anchor-wise weights. | ||
Returns: | ||
loss: (B, #anchors) float tensor. | ||
Weighted cross entropy loss without reduction | ||
""" | ||
input = input.permute(0, 2, 1) | ||
target = target.argmax(dim=-1) | ||
loss = F.cross_entropy(input, target, reduction='none') * weights | ||
return loss | ||
|
||
|
||
def get_corner_loss_lidar(pred_bbox3d: torch.Tensor, gt_bbox3d: torch.Tensor): | ||
""" | ||
Args: | ||
pred_bbox3d: (N, 7) float Tensor. | ||
gt_bbox3d: (N, 7) float Tensor. | ||
Returns: | ||
corner_loss: (N) float Tensor. | ||
""" | ||
assert pred_bbox3d.shape[0] == gt_bbox3d.shape[0] | ||
|
||
pred_box_corners = box_utils.boxes_to_corners_3d(pred_bbox3d) | ||
gt_box_corners = box_utils.boxes_to_corners_3d(gt_bbox3d) | ||
|
||
gt_bbox3d_flip = gt_bbox3d.clone() | ||
gt_bbox3d_flip[:, 6] += np.pi | ||
gt_box_corners_flip = box_utils.boxes_to_corners_3d(gt_bbox3d_flip) | ||
# (N, 8) | ||
corner_dist = torch.min(torch.norm(pred_box_corners - gt_box_corners, dim=2), | ||
torch.norm(pred_box_corners - gt_box_corners_flip, dim=2)) | ||
# (N, 8) | ||
corner_loss = WeightedSmoothL1Loss.smooth_l1_loss(corner_dist, beta=1.0) | ||
|
||
return corner_loss.mean(dim=1) |