Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal. #1134

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions scenic/projects/baselines/centernet/modeling/centernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,17 @@ def heatmap_focal_loss(self, heatmaps, gt_heatmaps):
pos_w = (gt_heatmaps == 1.).astype(jnp.float32) # B x m x C
pos_loss = jnp.log(pred) * jnp.power(1 - pred, self.focal_gamma) * pos_w
neg_loss = jnp.log(1. - pred) * jnp.power(pred, self.focal_gamma) * neg_w
norm = jnp.maximum(pos_w.sum(), 1.) # scalar
bs = pos_w.shape[0]
norm = jnp.maximum(pos_w.reshape((bs, -1)).sum(1), 1.0)
norm = jnp.mean(norm) # scalar
if self.sync_device_norm: # sync across GPUs. Helpful for small batch size.
norm = jax.lax.pmean(norm, axis_name='batch')
pos_loss = pos_loss.sum() / norm # scalar
neg_loss = neg_loss.sum() / norm # scalar
norm = jax.lax.pmean(norm, axis_name='batch') # scalar
pos_loss = jnp.mean(pos_loss.reshape((bs, -1)).sum(1)) / norm # scalar
neg_loss = jnp.mean(neg_loss.reshape((bs, -1)).sum(1)) / norm # scalar
if self.focal_alpha >= 0:
pos_loss = self.focal_alpha * pos_loss
neg_loss = (1. - self.focal_alpha) * neg_loss
return - pos_loss, - neg_loss, norm / heatmaps.shape[0]
return - pos_loss, - neg_loss, norm

def reg_loss(self, box_regs, gt_regs):
"""Compute regression loss.
Expand All @@ -242,11 +244,12 @@ def reg_loss(self, box_regs, gt_regs):
"""
reg_inds = gt_regs.max(axis=2) >= 0 # B x m: find valid pixels.
gious = centernet_utils.giou_loss(box_regs, gt_regs) # B x m
norm = jnp.maximum(reg_inds.sum(), 1.) # scalar
norm = jnp.maximum(reg_inds.sum(1), 1.)
norm = jnp.mean(norm) # scalar
if self.sync_device_norm:
norm = jax.lax.pmean(norm, axis_name='batch')
reg_loss = (gious * reg_inds).sum() / norm # scalar
return reg_loss, gious, norm / reg_inds.shape[0]
reg_loss = jnp.mean((gious * reg_inds).sum(1)) / norm # scalar
return reg_loss, gious, norm

def _get_bbox_ltrb(self, grids, boxes, m, n):
"""generate FCOS style regression targets.
Expand Down