diff --git a/scenic/projects/baselines/centernet/modeling/centernet.py b/scenic/projects/baselines/centernet/modeling/centernet.py index caa4710ad..f506bbc7a 100644 --- a/scenic/projects/baselines/centernet/modeling/centernet.py +++ b/scenic/projects/baselines/centernet/modeling/centernet.py @@ -214,15 +214,17 @@ def heatmap_focal_loss(self, heatmaps, gt_heatmaps): pos_w = (gt_heatmaps == 1.).astype(jnp.float32) # B x m x C pos_loss = jnp.log(pred) * jnp.power(1 - pred, self.focal_gamma) * pos_w neg_loss = jnp.log(1. - pred) * jnp.power(pred, self.focal_gamma) * neg_w - norm = jnp.maximum(pos_w.sum(), 1.) # scalar + bs = pos_w.shape[0] + norm = jnp.maximum(pos_w.reshape((bs, -1)).sum(1), 1.0) + norm = jnp.mean(norm) # scalar if self.sync_device_norm: # sync across GPUs. Helpful for small batch size. - norm = jax.lax.pmean(norm, axis_name='batch') - pos_loss = pos_loss.sum() / norm # scalar - neg_loss = neg_loss.sum() / norm # scalar + norm = jax.lax.pmean(norm, axis_name='batch') # scalar + pos_loss = jnp.mean(pos_loss.reshape((bs, -1)).sum(1)) / norm # scalar + neg_loss = jnp.mean(neg_loss.reshape((bs, -1)).sum(1)) / norm # scalar if self.focal_alpha >= 0: pos_loss = self.focal_alpha * pos_loss neg_loss = (1. - self.focal_alpha) * neg_loss - return - pos_loss, - neg_loss, norm / heatmaps.shape[0] + return - pos_loss, - neg_loss, norm def reg_loss(self, box_regs, gt_regs): """Compute regression loss. @@ -242,11 +244,12 @@ def reg_loss(self, box_regs, gt_regs): """ reg_inds = gt_regs.max(axis=2) >= 0 # B x m: find valid pixels. gious = centernet_utils.giou_loss(box_regs, gt_regs) # B x m - norm = jnp.maximum(reg_inds.sum(), 1.) # scalar + norm = jnp.maximum(reg_inds.sum(1), 1.) + norm = jnp.mean(norm) # scalar if self.sync_device_norm: norm = jax.lax.pmean(norm, axis_name='batch') - reg_loss = (gious * reg_inds).sum() / norm # scalar - return reg_loss, gious, norm / reg_inds.shape[0] + reg_loss = jnp.mean((gious * reg_inds).sum(1)) / norm # scalar + return reg_loss, gious, norm def _get_bbox_ltrb(self, grids, boxes, m, n): """generate FCOS style regression targets.