Skip to content

Commit

Permalink
focal
Browse files Browse the repository at this point in the history
  • Loading branch information
yhlleo committed Jan 4, 2020
1 parent 015bb2c commit c21c913
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 3 deletions.
7 changes: 5 additions & 2 deletions models/deepcrack_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import itertools
from .base_model import BaseModel
from .deepcrack_networks import define_deepcrack
from .deepcrack_networks import define_deepcrack, BinaryFocalLoss

class DeepCrackModel(BaseModel):
"""
Expand Down Expand Up @@ -50,7 +50,10 @@ def __init__(self, opt):
# define loss functions
#self.weight = torch.from_numpy(np.array([0.0300, 1.0000], dtype='float32')).float().to(self.device)
#self.criterionSeg = torch.nn.CrossEntropyLoss(weight=self.weight)
self.criterionSeg = nn.BCEWithLogitsLoss(size_average=True, reduce=True,
if self.opt.loss_mode == 'focal':
self.criterionSeg = BinaryFocalLoss()
else:
self.criterionSeg = nn.BCEWithLogitsLoss(size_average=True, reduce=True,
pos_weight=torch.tensor(1.0/3e-2).to(self.device))
self.weight_side = [0.5, 0.75, 1.0, 0.75, 0.5]

Expand Down
20 changes: 20 additions & 0 deletions models/deepcrack_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,23 @@ def define_deepcrack(in_nc,
gpu_ids=[]):
net = DeepCrackNet(in_nc, num_classes, ngf, norm)
return init_net(net, init_type, init_gain, gpu_ids)


class BinaryFocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2, logits=False, size_average=True):
super(BinaryFocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.logits = logits
self.size_average = size_average
self.criterion = nn.BCEWithLogitsLoss(reduction='none')

def forward(self, inputs, targets):
BCE_loss = self.criterion(inputs, targets)
pt = torch.exp(-BCE_loss)
F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

if self.size_average:
return F_loss.mean()
else:
return F_loss.sum()
1 change: 1 addition & 0 deletions options/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def initialize(self, parser):
#parser.add_argument('--use_l1', type=int, default=1, help='using l1 loss')

parser.add_argument('--use_selu', type=int, default=1, help='using selu active function')
parser.add_argument('--loss_mode', type=str, default='focal', help='[bce | focal]')
self.initialized = True
return parser

Expand Down
4 changes: 3 additions & 1 deletion scripts/train_deepcrack.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ BATCH_SIZE=1
NUM_CLASSES=1
LOAD_WIDTH=256
LOAD_HEIGHT=256
LOSS_MODE=focal

NORM=batch
NITER=400
Expand All @@ -30,4 +31,5 @@ python3 train.py \
--load_width ${LOAD_WIDTH} \
--load_height ${LOAD_HEIGHT} \
--no_flip 0 \
--display_id 0
--display_id 0 \
--loss_mode ${LOSS_MODE}

0 comments on commit c21c913

Please sign in to comment.