Skip to content

Commit eb18286

Browse files
authored
Update fast_rcnn.py
add new loss function
1 parent 0e22bd5 commit eb18286

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

detectron2/modeling/roi_heads/fast_rcnn.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -340,13 +340,22 @@ def losses(self, predictions, proposals):
340340
else:
341341
proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)
342342

343-
if mode == 1:
344-
from my_fastrcnn_loss_with_focal_loss import fastrcnn_loss
345-
loss_cls = fastrcnn_loss
343+
#書き換えここから
344+
loss_type = self.cfg.MODEL.ROI_HEADS.LOSS_TYPE
345+
if loss_type == "focal":
346+
# Focal Loss
347+
gamma = self.cfg.MODEL.ROI_HEADS.FOCAL_LOSS_GAMMA
348+
alpha = self.cfg.MODEL.ROI_HEADS.FOCAL_LOSS_ALPHA
349+
loss_cls = focal_loss(pred_class_logits, gt_classes, gamma, alpha)
350+
elif loss_type == "bce":
351+
# BCE Loss
352+
gt_one_hot = F.one_hot(gt_classes, num_classes=pred_class_logits.size(1)).float()
353+
loss_cls = F.binary_cross_entropy_with_logits(pred_class_logits, gt_one_hot, reduction="mean")
346354
elif self.use_sigmoid_ce:
347355
loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes)
348356
else:
349357
loss_cls = cross_entropy(scores, gt_classes, reduction="mean")
358+
#ここまで
350359

351360
losses = {
352361
"loss_cls": loss_cls,

0 commit comments

Comments
 (0)