diff --git a/ideadet/modeling/criterion/dn_components.py b/ideadet/modeling/criterion/dn_components.py
new file mode 100644
index 0000000000..07bacaf29e
--- /dev/null
+++ b/ideadet/modeling/criterion/dn_components.py
@@ -0,0 +1,272 @@
+# ------------------------------------------------------------------------
+# DN-DETR
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+import torch
+import torch.nn.functional as F
+
+from ideadet.layers import box_ops
+from ideadet.utils import accuracy
+from ideadet.utils.misc import inverse_sigmoid
+from ..losses import sigmoid_focal_loss
+
+
+def prepare_for_dn(targets, dn_args, embedweight, batch_size, training, num_queries, num_classes, hidden_dim, label_enc):
+    """
+    Prepare the denoising (dn) queries, attention mask, and bookkeeping dict for the forward pass.
+    Args:
+        dn_args: (scalar, label_noise_scale, box_noise_scale), i.e. the number of dn groups
+            and the label/box noise strengths
+        embedweight: positional queries used as anchors
+        training: whether it is training or inference
+        num_queries: number of matching queries
+        num_classes: number of classes
+        hidden_dim: transformer hidden dimension
+        label_enc: label encoding embedding
+
+    Returns: input_query_label, input_query_bbox, attn_mask, mask_dict
+    """
+    scalar, label_noise_scale, box_noise_scale = dn_args
+
+    num_patterns = 1
+    indicator0 = torch.zeros([num_queries * num_patterns, 1]).cuda()
+    tgt = label_enc(torch.tensor(num_classes).cuda()).repeat(num_queries * num_patterns, 1)
+    tgt = torch.cat([tgt, indicator0], dim=1)
+    refpoint_emb = embedweight.repeat(num_patterns, 1)
+    if training:
+        known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
+        know_idx = [torch.nonzero(t) for t in known]
+        known_num = [sum(k) for k in known]
+        # you can uncomment this to use a fixed number of dn queries
+        # if int(max(known_num)) > 0:
+        #     scalar = scalar // int(max(known_num))
+
+        # can be modified to selectively denoise some labels or boxes; also known-label prediction
+        unmask_bbox = unmask_label = torch.cat(known)
+        labels = torch.cat([t['labels'] for t in targets])
+        boxes = torch.cat([t['boxes'] for t in targets])
+        batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
+
+        known_indice = torch.nonzero(unmask_label + unmask_bbox)
+        known_indice = known_indice.view(-1)
+
+        # repeat the ground truths `scalar` times, once per dn group, then add noise
+        known_indice = known_indice.repeat(scalar, 1).view(-1)
+        known_labels = labels.repeat(scalar, 1).view(-1)
+        known_bid = batch_idx.repeat(scalar, 1).view(-1)
+        known_bboxs = boxes.repeat(scalar, 1)
+        known_labels_expaned = known_labels.clone()
+        known_bbox_expand = known_bboxs.clone()
+
+        # noise on the label
+        if label_noise_scale > 0:
+            p = torch.rand_like(known_labels_expaned.float())
+            chosen_indice = torch.nonzero(p < label_noise_scale).view(-1)  # usually half of the box noise scale
+            new_label = torch.randint_like(chosen_indice, 0, num_classes)  # randomly flip to another class
+            known_labels_expaned.scatter_(0, chosen_indice, new_label)
+        # noise on the box
+        if box_noise_scale > 0:
+            diff = torch.zeros_like(known_bbox_expand)
+            diff[:, :2] = known_bbox_expand[:, 2:] / 2
+            diff[:, 2:] = known_bbox_expand[:, 2:]
+            known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0),
+                                           diff).cuda() * box_noise_scale
+            known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)
+
+        m = known_labels_expaned.long().to('cuda')
+        input_label_embed = label_enc(m)
+        # add dn part indicator
+        indicator1 = torch.ones([input_label_embed.shape[0], 1]).cuda()
+        input_label_embed = torch.cat([input_label_embed, indicator1], dim=1)
+        input_bbox_embed = inverse_sigmoid(known_bbox_expand)
+        single_pad = int(max(known_num))
+        pad_size = int(single_pad * scalar)
+        padding_label = torch.zeros(pad_size, hidden_dim).cuda()
+        padding_bbox = torch.zeros(pad_size, 4).cuda()
+        input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1)
+        input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1)
+
+        # map in order
+        map_known_indice = torch.tensor([]).to('cuda')
+        if len(known_num):
+            map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # e.g. [0,1, 0,1,2]
+            map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long()
+        if len(known_bid):
+            input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
+            input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed
+
+        tgt_size = pad_size + num_queries * num_patterns
+        attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
+        # matching queries cannot see the denoising (reconstruction) part
+        attn_mask[pad_size:, :pad_size] = True
+        # denoising groups cannot see each other
+        for i in range(scalar):
+            if i == 0:
+                attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
+            if i == scalar - 1:
+                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
+            else:
+                attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
+                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
+        mask_dict = {
+            'known_indice': torch.as_tensor(known_indice).long(),
+            'batch_idx': torch.as_tensor(batch_idx).long(),
+            'map_known_indice': torch.as_tensor(map_known_indice).long(),
+            'known_lbs_bboxes': (known_labels, known_bboxs),
+            'know_idx': know_idx,
+            'pad_size': pad_size
+        }
+    else:  # no dn for inference
+        input_query_label = tgt.repeat(batch_size, 1, 1)
+        input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1)
+        attn_mask = None
+        mask_dict = None
+
+    input_query_label = input_query_label.transpose(0, 1)
+    input_query_bbox = input_query_bbox.transpose(0, 1)
+
+    return input_query_label, input_query_bbox, attn_mask, mask_dict
+
+
+def dn_post_process(outputs_class, outputs_coord, mask_dict):
+    """
+    Post-process after the transformer output: split off the dn part of the
+    predictions and stash it in mask_dict for the dn losses.
+    """
+    if mask_dict and mask_dict['pad_size'] > 0:
+        output_known_class = outputs_class[:, :, :mask_dict['pad_size'], :]
+        output_known_coord = outputs_coord[:, :, :mask_dict['pad_size'], :]
+        outputs_class = outputs_class[:, :, mask_dict['pad_size']:, :]
+        outputs_coord = outputs_coord[:, :, mask_dict['pad_size']:, :]
+        mask_dict['output_known_lbs_bboxes'] = (output_known_class, output_known_coord)
+    return outputs_class, outputs_coord
+
+
+def prepare_for_loss(mask_dict):
+    """
+    Prepare dn components for the loss computation.
+    Args:
+        mask_dict: a dict that contains dn information
+    """
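+    # the dn outputs have shape [num_layers, bs, pad_size, ...]; gather the
+    # real (non-padding) predictions back out using the (batch_idx,
+    # map_known_indice) pairs recorded in prepare_for_dn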
+    output_known_class, output_known_coord = mask_dict['output_known_lbs_bboxes']
+    known_labels, known_bboxs = mask_dict['known_lbs_bboxes']
+    map_known_indice = mask_dict['map_known_indice']
+
+    known_indice = mask_dict['known_indice']
+
+    batch_idx = mask_dict['batch_idx']
+    bid = batch_idx[known_indice]
+    if len(output_known_class) > 0:
+        output_known_class = output_known_class.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
+        output_known_coord = output_known_coord.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
+    num_tgt = known_indice.numel()
+    return known_labels, known_bboxs, output_known_class, output_known_coord, num_tgt
+
+
+def tgt_loss_boxes(src_boxes, tgt_boxes, num_tgt):
+    """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.
+       The boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+    """
+    if len(tgt_boxes) == 0:
+        return {
+            'tgt_loss_bbox': torch.as_tensor(0.).to('cuda'),
+            'tgt_loss_giou': torch.as_tensor(0.).to('cuda'),
+        }
+
+    loss_bbox = F.l1_loss(src_boxes, tgt_boxes, reduction='none')
+
+    losses = {}
+    losses['tgt_loss_bbox'] = loss_bbox.sum() / num_tgt
+
+    loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
+        box_ops.box_cxcywh_to_xyxy(src_boxes),
+        box_ops.box_cxcywh_to_xyxy(tgt_boxes)))
+    losses['tgt_loss_giou'] = loss_giou.sum() / num_tgt
+    return losses
+
+
+def tgt_loss_labels(src_logits_, tgt_labels_, num_tgt, focal_alpha, log=True):
+    """Classification loss (sigmoid focal loss) on the denoising queries."""
+    if len(tgt_labels_) == 0:
+        return {
+            'tgt_loss_ce': torch.as_tensor(0.).to('cuda'),
+            'tgt_class_error': torch.as_tensor(0.).to('cuda'),
+        }
+
+    src_logits, tgt_labels = src_logits_.unsqueeze(0), tgt_labels_.unsqueeze(0)
+
+    target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
+                                        dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
+    target_classes_onehot.scatter_(2, tgt_labels.unsqueeze(-1), 1)
+
+    target_classes_onehot = target_classes_onehot[:, :, :-1]
+    loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_tgt, alpha=focal_alpha, gamma=2) * src_logits.shape[1]
+
+    losses = {'tgt_loss_ce': loss_ce}
+
+    losses['tgt_class_error'] = 100 - accuracy(src_logits_, tgt_labels_)[0]
+    return losses
+
+
+def compute_dn_loss(mask_dict, training, aux_num, focal_alpha):
+    """
+    Compute the dn losses in the criterion.
+    Args:
+        mask_dict: a dict for dn information
+        training: training or inference flag
+        aux_num: number of auxiliary (intermediate decoder layer) losses
+        focal_alpha: alpha for the focal loss
+    """
+    losses = {}
+    if training and 'output_known_lbs_bboxes' in mask_dict:
+        known_labels, known_bboxs, output_known_class, output_known_coord, \
+            num_tgt = prepare_for_loss(mask_dict)
+        losses.update(tgt_loss_labels(output_known_class[-1], known_labels, num_tgt, focal_alpha))
+        losses.update(tgt_loss_boxes(output_known_coord[-1], known_bboxs, num_tgt))
+    else:
+        losses['tgt_loss_bbox'] = torch.as_tensor(0.).to('cuda')
+        losses['tgt_loss_giou'] = torch.as_tensor(0.).to('cuda')
+        losses['tgt_loss_ce'] = torch.as_tensor(0.).to('cuda')
+        losses['tgt_class_error'] = torch.as_tensor(0.).to('cuda')
+
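+    # auxiliary dn losses: one set per intermediate decoder layer, with keys
+    # suffixed `_{i}` so they line up with the matched-part aux losses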
+    if aux_num:
+        for i in range(aux_num):
+            # dn aux loss
+            if training and 'output_known_lbs_bboxes' in mask_dict:
+                l_dict = tgt_loss_labels(output_known_class[i], known_labels, num_tgt, focal_alpha)
+                l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+                losses.update(l_dict)
+                l_dict = tgt_loss_boxes(output_known_coord[i], known_bboxs, num_tgt)
+                l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+                losses.update(l_dict)
+            else:
+                l_dict = dict()
+                l_dict['tgt_loss_bbox'] = torch.as_tensor(0.).to('cuda')
+                l_dict['tgt_class_error'] = torch.as_tensor(0.).to('cuda')
+                l_dict['tgt_loss_giou'] = torch.as_tensor(0.).to('cuda')
+                l_dict['tgt_loss_ce'] = torch.as_tensor(0.).to('cuda')
+                l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+                losses.update(l_dict)
+    return losses
+
diff --git a/ideadet/modeling/criterion/dn_criterion.py b/ideadet/modeling/criterion/dn_criterion.py
new file mode 100644
index 0000000000..1ace1a2143
--- /dev/null
+++ b/ideadet/modeling/criterion/dn_criterion.py
@@ -0,0 +1,217 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ideadet.layers import box_ops
+from ideadet.utils import (
+    accuracy,
+    get_world_size,
+    interpolate,
+    is_dist_avail_and_initialized,
+    nested_tensor_from_tensor_list,
+)
+
+from ..losses import dice_loss, sigmoid_focal_loss
+from .dn_components import compute_dn_loss
+
+
+class SetCriterion(nn.Module):
+    """ This class computes the loss for DN-DETR.
+    The process happens in two steps:
+        1) we compute the Hungarian assignment between ground truth boxes and the outputs of the model
+        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+    The denoising part is supervised separately through compute_dn_loss.
+    """
+    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
+        """ Create the criterion.
+        Parameters:
+            num_classes: number of object categories, omitting the special no-object category
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            losses: list of all the losses to be applied. See get_loss for list of available losses.
+            focal_alpha: alpha in Focal Loss
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.losses = losses
+        self.focal_alpha = focal_alpha
+
+    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
+        """Classification loss (sigmoid focal loss)
+        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+        """
+        assert 'pred_logits' in outputs
+        src_logits = outputs['pred_logits']
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+                                    dtype=torch.int64, device=src_logits.device)
+        target_classes[idx] = target_classes_o
+
+        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
+                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
+        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+
+        target_classes_onehot = target_classes_onehot[:, :, :-1]
+        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
+        losses = {'loss_ce': loss_ce}
+
+        if log:
+            # TODO this should probably be a separate loss, not hacked in this one here
+            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
+        return losses
+
+    @torch.no_grad()
+    def loss_cardinality(self, outputs, targets, indices, num_boxes):
+        """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
+        This is not really a loss; it is intended for logging purposes only and doesn't propagate gradients.
+        """
+        pred_logits = outputs['pred_logits']
+        device = pred_logits.device
+        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
+        # Count the number of predictions that are NOT "no-object" (which is the last class)
+        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
+        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
+        losses = {'cardinality_error': card_err}
+        return losses
+
+    def loss_boxes(self, outputs, targets, indices, num_boxes):
+        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+        """
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+
+        losses = {}
+        losses['loss_bbox'] = loss_bbox.sum() / num_boxes
+
+        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
+            box_ops.box_cxcywh_to_xyxy(src_boxes),
+            box_ops.box_cxcywh_to_xyxy(target_boxes)))
+        losses['loss_giou'] = loss_giou.sum() / num_boxes
+        return losses
+
+    def loss_masks(self, outputs, targets, indices, num_boxes):
+        """Compute the losses related to the masks: the focal loss and the dice loss.
+           targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+        """
+        assert "pred_masks" in outputs
+
+        src_idx = self._get_src_permutation_idx(indices)
+        tgt_idx = self._get_tgt_permutation_idx(indices)
+
+        src_masks = outputs["pred_masks"]
+
+        # TODO use valid to mask invalid areas due to padding in loss
+        target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
+        target_masks = target_masks.to(src_masks)
+
+        src_masks = src_masks[src_idx]
+        # upsample predictions to the target size
+        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
+                                mode="bilinear", align_corners=False)
+        src_masks = src_masks[:, 0].flatten(1)
+
+        target_masks = target_masks[tgt_idx].flatten(1)
+
+        losses = {
+            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
+            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+        }
+        return losses
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+        loss_map = {
+            'labels': self.loss_labels,
+            'cardinality': self.loss_cardinality,
+            'boxes': self.loss_boxes,
+            'masks': self.loss_masks
+        }
+        assert loss in loss_map, f'do you really want to compute {loss} loss?'
+        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+    def forward(self, outputs, targets, mask_dict=None):
+        """ This performs the loss computation.
+        Parameters:
+             outputs: dict of tensors, see the output specification of the model for the format
+             targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depend on the losses applied, see each loss' doc
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+        num_boxes = sum(len(t["labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_boxes)
+        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            kwargs = {}
+            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs))
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if 'aux_outputs' in outputs:
+            for i, aux_outputs in enumerate(outputs['aux_outputs']):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    if loss == 'masks':
+                        # Intermediate masks losses are too costly to compute, we ignore them.
+                        continue
+                    kwargs = {}
+                    if loss == 'labels':
+                        # Logging is enabled only for the last layer
+                        kwargs['log'] = False
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
+                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        if 'enc_outputs' in outputs:
+            enc_outputs = outputs['enc_outputs']
+            for bt in targets:
+                bt['labels'] = torch.zeros_like(bt['labels'])
+            indices = self.matcher(enc_outputs, targets)
+            for loss in self.losses:
+                if loss == 'masks':
+                    # Intermediate masks losses are too costly to compute, we ignore them.
+                    continue
+                kwargs = {}
+                if loss == 'labels':
+                    # Logging is enabled only for the last layer
+                    kwargs['log'] = False
+                l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes, **kwargs)
+                l_dict = {k + f'_enc': v for k, v in l_dict.items()}
+                losses.update(l_dict)
+
+        # dn loss computation
+        aux_num = 0
+        if 'aux_outputs' in outputs:
+            aux_num = len(outputs['aux_outputs'])
+        dn_losses = compute_dn_loss(mask_dict, self.training, aux_num, self.focal_alpha)
+        losses.update(dn_losses)
+
+        return losses
\ No newline at end of file
diff --git a/projects/dn_detr/README.md b/projects/dn_detr/README.md
new file mode 100644
index 0000000000..bfe245cffe
--- /dev/null
+++ b/projects/dn_detr/README.md
@@ -0,0 +1,15 @@
+## DN-DETR
+
+- Run test
+
+```bash
+cd projects/dn_detr
+python train_net.py --config-file configs/dn_detr_r50_50epoch.py --eval-only
+```
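+
+- Run training (example invocation; adjust `--num-gpus` to your hardware)
+
+```bash
+cd projects/dn_detr
+python train_net.py --config-file configs/dn_detr_r50_50epoch.py --num-gpus 8
+```
\ No newline at end of file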
diff --git a/projects/dn_detr/configs/common/coco_loader.py b/projects/dn_detr/configs/common/coco_loader.py
new file mode 100644
index 0000000000..81f6b890bf
--- /dev/null
+++ b/projects/dn_detr/configs/common/coco_loader.py
@@ -0,0 +1,70 @@
+from omegaconf import OmegaConf
+
+import detectron2.data.transforms as T
+from detectron2.config import LazyCall as L
+from detectron2.data import (
+    build_detection_test_loader,
+    build_detection_train_loader,
+    get_detection_dataset_dicts,
+)
+from detectron2.evaluation import COCOEvaluator
+
+from ideadet.data import DetrDatasetMapper
+
+dataloader = OmegaConf.create()
+
+dataloader.train = L(build_detection_train_loader)(
+    dataset=L(get_detection_dataset_dicts)(names="coco_2017_train"),
+    mapper=L(DetrDatasetMapper)(
+        augmentation=[
+            L(T.RandomFlip)(),
+            L(T.ResizeShortestEdge)(
+                short_edge_length=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800),
+                max_size=1333,
+                sample_style="choice",
+            ),
+        ],
+        augmentation_with_crop=[
+            L(T.RandomFlip)(),
+            L(T.ResizeShortestEdge)(
+                short_edge_length=(400, 500, 600),
+                sample_style="choice",
+            ),
+            L(T.RandomCrop)(
+                crop_type="absolute_range",
+                crop_size=(384, 600),
+            ),
+            L(T.ResizeShortestEdge)(
+                short_edge_length=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800),
+                max_size=1333,
+                sample_style="choice",
+            ),
+        ],
+        is_train=True,
+        mask_on=False,
+        img_format="RGB",
+    ),
+    total_batch_size=16,
+    num_workers=4,
+)
+
+dataloader.test = L(build_detection_test_loader)(
+    dataset=L(get_detection_dataset_dicts)(names="coco_2017_val", filter_empty=False),
+    mapper=L(DetrDatasetMapper)(
+        augmentation=[
+            L(T.ResizeShortestEdge)(
+                short_edge_length=800,
+                max_size=1333,
+            ),
+        ],
+        augmentation_with_crop=None,
+        is_train=False,
+        mask_on=False,
+        img_format="RGB",
+    ),
+    num_workers=4,
+)
+
+dataloader.evaluator = L(COCOEvaluator)(
+    dataset_name="${..test.dataset.names}",
+)
diff --git a/projects/dn_detr/configs/common/schedule.py b/projects/dn_detr/configs/common/schedule.py
new file mode 100644
index 0000000000..ae8767eb87
--- /dev/null
+++ b/projects/dn_detr/configs/common/schedule.py
@@ -0,0 +1,34 @@
+from fvcore.common.param_scheduler import MultiStepParamScheduler
+
+from detectron2.config import LazyCall as L
+from detectron2.solver import WarmupParamScheduler
+
+
+def dab_coco_scheduler(epochs=50, decay_epochs=40, warmup_epochs=0.0):
+    """
+    Returns a multi-step LR schedule for COCO training at total batch size 16,
+    where one epoch is roughly 7500 iterations. The learning rate is decayed
+    by 10x after ``decay_epochs`` and training runs for ``epochs`` in total.
+    Args:
+        epochs: total number of training epochs
+        decay_epochs: epoch at which the learning rate is decayed
+        warmup_epochs: length of the linear warmup, in epochs
+    Returns:
+        DictConfig: configs that define the multiplier for LR during training
+    """
+    # one COCO epoch (~118k images) is about 7500 iterations at batch size 16
+    total_steps_16bs = epochs * 7500
+    decay_steps = decay_epochs * 7500
+    warmup_steps = warmup_epochs * 7500
+    scheduler = L(MultiStepParamScheduler)(
+        values=[1.0, 0.1],
+        milestones=[decay_steps, total_steps_16bs],
+    )
+    return L(WarmupParamScheduler)(
+        scheduler=scheduler,
+        warmup_length=warmup_steps / total_steps_16bs,
+        warmup_method="linear",
+        warmup_factor=0.001,
+    )
+
+
+lr_multiplier = dab_coco_scheduler()
diff --git a/projects/dn_detr/configs/dn_detr_r50_50epoch.py b/projects/dn_detr/configs/dn_detr_r50_50epoch.py
new file mode 100644
index 0000000000..1a28c485c2
--- /dev/null
+++ b/projects/dn_detr/configs/dn_detr_r50_50epoch.py
@@ -0,0 +1,24 @@
+from ideadet.config import get_config
+
+from .models.dab_detr_r50 import model
+from .common.coco_loader import dataloader
+from .common.schedule import lr_multiplier
+
+
+optimizer = get_config("common/optim.py").AdamW
+train = get_config("common/train.py").train
+
+# modify training config
+train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
+train.output_dir = "./output/dn_detr_r50_50epoch"
+train.max_iter = 375000  # 50 epochs x 7500 iterations
+
+
+# modify optimizer config
+optimizer.weight_decay = 1e-4
+optimizer.params.lr_factor_func = lambda module_name: 0.1 if "backbone" in module_name else 1
+
+# modify dataloader config
+dataloader.train.num_workers = 16
diff --git a/projects/dn_detr/configs/models/dab_detr_r50.py b/projects/dn_detr/configs/models/dab_detr_r50.py
new file mode 100644
index 0000000000..ab9ac4c432
--- /dev/null
+++ b/projects/dn_detr/configs/models/dab_detr_r50.py
@@ -0,0 +1,145 @@
+import torch.nn as nn
+
+from ideadet.modeling.utils import Joiner, MaskedBackbone
+from ideadet.modeling.matcher import DabMatcher
+from ideadet.modeling.criterion import DabCriterion
+from ideadet.layers import (
+    MultiheadAttention,
+    ConditionalSelfAttention,
+    ConditionalCrossAttention,
+    PositionEmbeddingSine,
+    FFN,
+    BaseTransformerLayer,
+)
+
+from detectron2.modeling.backbone import ResNet, BasicStem
+from detectron2.config import LazyCall as L
+
+from modeling import (
+    DABDETR,
+    DabDetrTransformer,
+    DabDetrTransformerDecoder,
+    DabDetrTransformerEncoder,
+)
+
+
+model = L(DABDETR)(
+    backbone=L(Joiner)(
+        backbone=L(MaskedBackbone)(
+            backbone=L(ResNet)(
+                stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
+                stages=L(ResNet.make_default_stages)(
+                    depth=50,
+                    stride_in_1x1=False,
+                    norm="FrozenBN",
+                ),
+
out_features=["res2", "res3", "res4", "res5"], + freeze_at=1, + ) + ), + position_embedding=L(PositionEmbeddingSine)( + num_pos_feats=128, temperature=20, normalize=True + ), + ), + transformer=L(DabDetrTransformer)( + encoder=L(DabDetrTransformerEncoder)( + transformer_layers=L(BaseTransformerLayer)( + attn=L(MultiheadAttention)( + embed_dim=256, + num_heads=8, + attn_drop=0.0, + batch_first=False, + ), + ffn=L(FFN)( + embed_dim=256, + feedforward_dim=2048, + ffn_drop=0.0, + activation=L(nn.PReLU)(), + ), + norm=L(nn.LayerNorm)(normalized_shape=256), + operation_order=("self_attn", "norm", "ffn", "norm"), + ), + num_layers=6, + post_norm=False, + ), + decoder=L(DabDetrTransformerDecoder)( + num_layers=6, + return_intermediate=True, + query_dim=4, + modulate_hw_attn=True, + post_norm=True, + transformer_layers=L(BaseTransformerLayer)( + attn=[ + L(ConditionalSelfAttention)( + embed_dim=256, + num_heads=8, + attn_drop=0.0, + batch_first=False, + ), + L(ConditionalCrossAttention)( + embed_dim=256, + num_heads=8, + attn_drop=0.0, + batch_first=False, + ), + ], + ffn=L(FFN)( + embed_dim=256, + feedforward_dim=2048, + ffn_drop=0.0, + activation=L(nn.PReLU)(), + ), + norm=L(nn.LayerNorm)( + normalized_shape=256, + ), + operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"), + ), + ), + ), + num_classes=80, + num_queries=300, + aux_loss=True, + query_dim=4, + iter_update=True, + random_refpoints_xy=True, + criterion=L(DabCriterion)( + num_classes=80, + matcher=L(DabMatcher)( + cost_class=1, + cost_bbox=5.0, + cost_giou=2.0, + ), + weight_dict={ + "loss_ce": 1, + "loss_bbox": 5.0, + "loss_giou": 2.0, + }, + focal_alpha=0.25, + losses=[ + "labels", + "boxes", + ], + ), + pixel_mean=[123.675, 116.280, 103.530], + pixel_std=[58.395, 57.120, 57.375], + device="cuda", + use_dn=True, + scalar=5, + label_noise_scale=0.2, + box_noise_scale=0.4, +) + +# set aux loss weight dict +if model.aux_loss: + weight_dict = model.criterion.weight_dict + if model.use_dn: + weight_dict['tgt_loss_ce'] = 1.0 + weight_dict['tgt_loss_bbox'] = 5.0 + weight_dict['tgt_loss_giou'] = 2.0 + aux_weight_dict = {} + for i in range(model.transformer.decoder.num_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + model.criterion.weight_dict = weight_dict + + diff --git a/projects/dn_detr/modeling/__init__.py b/projects/dn_detr/modeling/__init__.py new file mode 100644 index 0000000000..5035f37679 --- /dev/null +++ b/projects/dn_detr/modeling/__init__.py @@ -0,0 +1,6 @@ +from .dab_detr import DABDETR +from .dab_transformer import ( + DabDetrTransformerEncoder, + DabDetrTransformerDecoder, + DabDetrTransformer, +) diff --git a/projects/dn_detr/modeling/dab_detr.py b/projects/dn_detr/modeling/dab_detr.py new file mode 100644 index 0000000000..6737c3a790 --- /dev/null +++ b/projects/dn_detr/modeling/dab_detr.py @@ -0,0 +1,226 @@ +# coding=utf-8 +# Copyright 2022 The IDEA Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ------------------------------------------------------------------------------------------------
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------------------------------
+# Modified from:
+# https://github.com/facebookresearch/detr/blob/main/d2/detr/detr.py
+# ------------------------------------------------------------------------------------------------
+
+import math
+import torch
+import torch.nn as nn
+
+from ideadet.layers.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
+from ideadet.layers.mlp import MLP
+from ideadet.utils.misc import inverse_sigmoid, nested_tensor_from_tensor_list
+
+from detectron2.modeling import detector_postprocess
+from detectron2.structures import Boxes, ImageList, Instances
+from ideadet.modeling.criterion.dn_components import prepare_for_dn, dn_post_process
+
+
+class DABDETR(nn.Module):
+    def __init__(
+        self,
+        backbone,
+        transformer,
+        num_classes,
+        num_queries,
+        criterion,
+        pixel_mean,
+        pixel_std,
+        aux_loss=True,
+        iter_update=True,
+        query_dim=4,
+        random_refpoints_xy=True,
+        device="cuda",
+        use_dn=True,
+        scalar=5,
+        label_noise_scale=0.,
+        box_noise_scale=0.,
+    ):
+        super(DABDETR, self).__init__()
+        self.backbone = backbone
+        self.transformer = transformer
+        hidden_dim = 256
+        self.class_embed = nn.Linear(hidden_dim, num_classes)
+        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+        self.query_dim = query_dim
+        self.aux_loss = aux_loss
+        self.iter_update = iter_update
+        # dn-related settings
+        self.hidden_dim = hidden_dim
+        self.num_queries = num_queries
+        self.num_classes = num_classes
+        self.use_dn = use_dn
+        self.dn_args = (scalar, label_noise_scale, box_noise_scale)
+        # leave one dim for the dn indicator
+        self.label_enc = nn.Embedding(num_classes + 1, hidden_dim - 1)
+
+        assert self.query_dim in [2, 4]
+
+        self.refpoint_embed = nn.Embedding(num_queries, query_dim)
+        self.random_refpoints_xy = random_refpoints_xy
+        if random_refpoints_xy:
+            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
+            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
+                self.refpoint_embed.weight.data[:, :2]
+            )
+            self.refpoint_embed.weight.data[:, :2].requires_grad = False
+
+        self.input_proj = nn.Conv2d(2048, hidden_dim, kernel_size=1)
+        if self.iter_update:
+            self.transformer.decoder.bbox_embed = self.bbox_embed
+
+        self.criterion = criterion
+        self.device = device
+        pixel_mean = torch.Tensor(pixel_mean).to(self.device).view(3, 1, 1)
+        pixel_std = torch.Tensor(pixel_std).to(self.device).view(3, 1, 1)
+        self.normalizer = lambda x: (x - pixel_mean) / pixel_std
+
+        # init prior_prob setting for focal loss
+        prior_prob = 0.01
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
+        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
+        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
+
+    def forward(self, batched_inputs):
+        images = self.preprocess_image(batched_inputs)
+
+        if isinstance(images, (list, torch.Tensor)):
+            images = nested_tensor_from_tensor_list(images)
+        features, pos = self.backbone(images)
+
+        src, mask = features[-1].decompose()
+        assert mask is not None
+        embedweight = self.refpoint_embed.weight  # TODO this should be moved to the Transformer
+
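+        # denoising (dn) training: build noised copies of the ground-truth
+        # labels/boxes as extra queries, prepended to the matching queries;
+        # an attention mask keeps the matching part and the dn groups from
+        # seeing each other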
+        if self.training:
+            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
+            targets = self.prepare_targets(gt_instances)
+        else:
+            targets = None
+        input_query_label, input_query_bbox, attn_mask, mask_dict = \
+            prepare_for_dn(targets, self.dn_args, embedweight, src.size(0), self.training, self.num_queries, self.num_classes,
+                           self.hidden_dim, self.label_enc)
+
+        hs, reference = self.transformer(self.input_proj(src), mask, input_query_bbox, pos[-1], target=input_query_label,
+                                         attn_mask=attn_mask)
+
+        reference_before_sigmoid = inverse_sigmoid(reference)
+        tmp = self.bbox_embed(hs)
+        tmp[..., : self.query_dim] += reference_before_sigmoid
+        outputs_coord = tmp.sigmoid()
+        outputs_class = self.class_embed(hs)
+
+        # dn post process: split the denoising part from the matching part
+        outputs_class, outputs_coord = dn_post_process(outputs_class, outputs_coord, mask_dict)
+
+        output = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
+        if self.aux_loss:
+            output["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
+
+        if self.training:
+            loss_dict = self.criterion(output, targets, mask_dict)
+            weight_dict = self.criterion.weight_dict
+            for k in loss_dict.keys():
+                if k in weight_dict:
+                    loss_dict[k] *= weight_dict[k]
+            return loss_dict
+        else:
+            box_cls = output["pred_logits"]
+            box_pred = output["pred_boxes"]
+            results = self.inference(box_cls, box_pred, images.image_sizes)
+            processed_results = []
+            for results_per_image, input_per_image, image_size in zip(
+                results, batched_inputs, images.image_sizes
+            ):
+                height = input_per_image.get("height", image_size[0])
+                width = input_per_image.get("width", image_size[1])
+                r = detector_postprocess(results_per_image, height, width)
+                processed_results.append({"instances": r})
+            return processed_results
+
+    def inference(self, box_cls, box_pred, image_sizes):
+        """
+        Arguments:
+            box_cls (Tensor): tensor of shape (batch_size, num_queries, K).
+                The tensor predicts the classification probability for each query.
+            box_pred (Tensor): tensors of shape (batch_size, num_queries, 4).
+                The tensor predicts 4-vector (x,y,w,h) box
+                regression values for every query.
+            image_sizes (List[torch.Size]): the input image sizes
+
+        Returns:
+            results (List[Instances]): a list of #images elements.
+        """
+        assert len(box_cls) == len(image_sizes)
+        results = []
+
+        # box_cls.shape: 1, 300, 80
+        # box_pred.shape: 1, 300, 4
+        prob = box_cls.sigmoid()
+        topk_values, topk_indexes = torch.topk(prob.view(box_cls.shape[0], -1), 100, dim=1)
+        scores = topk_values
+        topk_boxes = torch.div(topk_indexes, box_cls.shape[2], rounding_mode="floor")
+        labels = topk_indexes % box_cls.shape[2]
+
+        boxes = torch.gather(box_pred, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
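+        # keep the top-100 (query, class) pairs per image, ranked by per-class
+        # sigmoid probability; each query's box is shared across its classes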
+        for i, (scores_per_image, labels_per_image, box_pred_per_image, image_size) in enumerate(
+            zip(scores, labels, boxes, image_sizes)
+        ):
+            result = Instances(image_size)
+            result.pred_boxes = Boxes(box_cxcywh_to_xyxy(box_pred_per_image))
+
+            result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0])
+            result.scores = scores_per_image
+            result.pred_classes = labels_per_image
+            results.append(result)
+        return results
+
+    def prepare_targets(self, targets):
+        new_targets = []
+        for targets_per_image in targets:
+            h, w = targets_per_image.image_size
+            image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
+            gt_classes = targets_per_image.gt_classes
+            gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
+            gt_boxes = box_xyxy_to_cxcywh(gt_boxes)
+            new_targets.append({"labels": gt_classes, "boxes": gt_boxes})
+        return new_targets
+
+    def preprocess_image(self, batched_inputs):
+        images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
+        images = ImageList.from_tensors(images)
+        return images
+
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_coord):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        return [
+            {"pred_logits": a, "pred_boxes": b}
+            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
+        ]
diff --git a/projects/dn_detr/modeling/dab_transformer.py b/projects/dn_detr/modeling/dab_transformer.py
new file mode 100644
index 0000000000..ebbe48acbf
--- /dev/null
+++ b/projects/dn_detr/modeling/dab_transformer.py
@@ -0,0 +1,237 @@
+# coding=utf-8
+# Copyright 2022 The IDEA Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+from ideadet.layers import MLP, BaseTransformerLayer, TransformerLayerSequence, get_sine_pos_embed
+from ideadet.utils.misc import inverse_sigmoid
+
+
+class DabDetrTransformerEncoder(TransformerLayerSequence):
+    def __init__(
+        self,
+        transformer_layers: BaseTransformerLayer = None,
+        post_norm: bool = True,
+        num_layers: int = None,
+    ):
+        super(DabDetrTransformerEncoder, self).__init__(
+            transformer_layers=transformer_layers, num_layers=num_layers
+        )
+        self.embed_dim = self.layers[0].embed_dim
+        self.pre_norm = self.layers[0].pre_norm
+        self.query_scale = MLP(self.embed_dim, self.embed_dim, self.embed_dim, 2)
+
+        if post_norm:
+            self.post_norm_layer = nn.LayerNorm(self.embed_dim)
+        else:
+            self.post_norm_layer = None
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        query_pos=None,
+        key_pos=None,
+        attn_masks=None,
+        query_key_padding_mask=None,
+        key_padding_mask=None,
+        **kwargs,
+    ):
+
+        for layer in self.layers:
+            position_scales = self.query_scale(query)
+            query = layer(
+                query,
+                key,
+                value,
+                query_pos=query_pos * position_scales,
+                attn_masks=attn_masks,
+                query_key_padding_mask=query_key_padding_mask,
+                key_padding_mask=key_padding_mask,
+                **kwargs,
+            )
+
+        if self.post_norm_layer is not None:
+            query = self.post_norm_layer(query)
+        return query
+
+
+class DabDetrTransformerDecoder(TransformerLayerSequence):
+    def __init__(
+        self,
+        transformer_layers: BaseTransformerLayer = None,
+        num_layers: int = None,
+        query_dim: int = 4,
+        modulate_hw_attn: bool = True,
+        post_norm: bool = True,
+        return_intermediate: bool = True,
+    ):
+        super().__init__(transformer_layers, num_layers)
+        self.return_intermediate = return_intermediate
+        self.embed_dim = self.layers[0].embed_dim
+        self.query_dim = query_dim
+
+        self.query_scale = MLP(self.embed_dim, self.embed_dim, self.embed_dim, 2)
+        self.ref_point_head = MLP(
+            query_dim // 2 * self.embed_dim, self.embed_dim, self.embed_dim, 2
+        )
+
+        self.bbox_embed = None
+        if modulate_hw_attn:
+            self.ref_anchor_head = MLP(self.embed_dim, self.embed_dim, 2, 2)
+        self.modulate_hw_attn = modulate_hw_attn
+
+        if post_norm:
+            self.post_norm_layer = nn.LayerNorm(self.embed_dim)
+        else:
+            self.post_norm_layer = None
+
+        for idx in range(num_layers - 1):
+            self.layers[idx + 1].attentions[1].query_pos_proj = None
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        query_pos=None,
+        key_pos=None,
+        attn_masks=None,
+        query_key_padding_mask=None,
+        key_padding_mask=None,
+        refpoints_embed=None,
+        **kwargs,
+    ):
+        intermediate = []
+
+        reference_points = refpoints_embed.sigmoid()
+        refpoints = [reference_points]
+
+        for idx, layer in enumerate(self.layers):
+            obj_center = reference_points[..., : self.query_dim]
+            query_sine_embed = get_sine_pos_embed(obj_center)
+            query_pos = self.ref_point_head(query_sine_embed)
+
+            # do not apply the position transform in the first decoder layer
+            if idx == 0:
+                position_transform = 1
+            else:
+                position_transform = self.query_scale(query)
+
+            # apply position transform
+            query_sine_embed = query_sine_embed[..., : self.embed_dim] * position_transform
+
+            if self.modulate_hw_attn:
+                ref_hw_cond = self.ref_anchor_head(query).sigmoid()
+                query_sine_embed[..., self.embed_dim // 2 :] *= (
+                    ref_hw_cond[..., 0] / obj_center[..., 2]
+                ).unsqueeze(-1)
+                query_sine_embed[..., : self.embed_dim // 2] *= (
+                    ref_hw_cond[..., 1] / obj_center[..., 3]
+                ).unsqueeze(-1)
+
+            query = layer(
+                query,
+                key,
+                value,
+                query_pos=query_pos,
+                key_pos=key_pos,
+                query_sine_embed=query_sine_embed,
+                attn_masks=attn_masks,
+                query_key_padding_mask=query_key_padding_mask,
+                key_padding_mask=key_padding_mask,
+                is_first_layer=(idx == 0),
+                **kwargs,
+            )
+
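+            # iterative anchor update: refine the reference boxes layer by layer,
+            # detaching them so gradients do not flow across decoder layers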
+            if self.bbox_embed is not None:
+                temp = self.bbox_embed(query)
+                temp[..., : self.query_dim] += inverse_sigmoid(reference_points)
+                new_reference_points = temp[..., : self.query_dim].sigmoid()
+
+                if idx != self.num_layers - 1:
+                    refpoints.append(new_reference_points)
+                reference_points = new_reference_points.detach()
+
+            if self.return_intermediate:
+                if self.post_norm_layer is not None:
+                    intermediate.append(self.post_norm_layer(query))
+                else:
+                    intermediate.append(query)
+
+        if self.post_norm_layer is not None:
+            query = self.post_norm_layer(query)
+            if self.return_intermediate:
+                intermediate.pop()
+                intermediate.append(query)
+
+        if self.return_intermediate:
+            if self.bbox_embed is not None:
+                return [
+                    torch.stack(intermediate).transpose(1, 2),
+                    torch.stack(refpoints).transpose(1, 2),
+                ]
+            else:
+                return [
+                    torch.stack(intermediate).transpose(1, 2),
+                    reference_points.unsqueeze(0).transpose(1, 2),
+                ]
+
+        return query.unsqueeze(0)
+
+
+class DabDetrTransformer(nn.Module):
+    def __init__(self, encoder=None, decoder=None):
+        super(DabDetrTransformer, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.embed_dim = self.encoder.embed_dim
+
+        self.init_weights()
+
+    def init_weights(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, x, mask, refpoints_embed, pos_embed, target=None, attn_mask=None):
+        bs, c, h, w = x.shape
+        x = x.view(bs, c, -1).permute(2, 0, 1)
+        pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
+
+        mask = mask.view(bs, -1)
+        memory = self.encoder(
+            query=x,
+            key=None,
+            value=None,
+            query_pos=pos_embed,
+            query_key_padding_mask=mask,
+        )
+
+        hidden_state, references = self.decoder(
+            query=target,
+            key=memory,
+            value=memory,
+            key_pos=pos_embed,
+            attn_masks=attn_mask,
+            refpoints_embed=refpoints_embed,
+        )
+
+        return hidden_state, references
diff --git a/projects/dn_detr/train_net.py b/projects/dn_detr/train_net.py
new file mode 100644
index 0000000000..0ba1475836
--- /dev/null
+++ b/projects/dn_detr/train_net.py
@@ -0,0 +1,180 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+"""
+Training script using the new "LazyConfig" python config files.
+This script reads a given python config file and runs the training or evaluation.
+It can be used to train any models or dataset as long as they can be
+instantiated by the recursive construction defined in the given config file.
+Besides lazy construction of models, dataloader, etc., this script expects a
+few common configuration parameters currently defined in "configs/common/train.py".
+To add more complicated training logic, you can easily add other configs
+in the config file and implement a new train_net.py to handle them.
+"""
+import logging
+import time
+import torch
+
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import LazyConfig, instantiate
+from detectron2.engine import (
+    AMPTrainer,
+    SimpleTrainer,
+    default_argument_parser,
+    default_setup,
+    default_writers,
+    hooks,
+    launch,
+)
+from detectron2.engine.defaults import create_ddp_model
+from detectron2.evaluation import inference_on_dataset, print_csv_format
+from detectron2.utils import comm
+
+logger = logging.getLogger("ideadet")
+
+
+class Trainer(SimpleTrainer):
+    def __init__(self, model, dataloader, optimizer):
+        super().__init__(model=model, data_loader=dataloader, optimizer=optimizer)
+
+    def run_step(self):
+        """
+        Implement the standard training logic, plus gradient clipping.
+        """
+        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
+        start = time.perf_counter()
+        """
+        If you want to do something with the data, you can wrap the dataloader.
+        """
+        data = next(self._data_loader_iter)
+        data_time = time.perf_counter() - start
+
+        """
+        If you want to do something with the losses, you can wrap the model.
+        """
+        loss_dict = self.model(data)
+        if isinstance(loss_dict, torch.Tensor):
+            losses = loss_dict
+            loss_dict = {"total_loss": loss_dict}
+        else:
+            losses = sum(loss_dict.values())
+
+        """
+        If you need to accumulate gradients or do something similar, you can
+        wrap the optimizer with your custom `zero_grad()` method.
+        """
+        self.optimizer.zero_grad()
+        losses.backward()
+
+        self._write_metrics(loss_dict, data_time)
+
+        """
+        If you need gradient clipping/scaling or other processing, you can
+        wrap the optimizer with your custom `step()` method.
+        But it is suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
+        """
+        # add gradient clip here
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.1)
+        self.optimizer.step()
+
+
+def do_test(cfg, model):
+    if "evaluator" in cfg.dataloader:
+        ret = inference_on_dataset(
+            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
+        )
+        print_csv_format(ret)
+        return ret
+
+
+def do_train(args, cfg):
+    """
+    Args:
+        cfg: an object with the following attributes:
+            model: instantiate to a module
+            dataloader.{train,test}: instantiate to dataloaders
+            dataloader.evaluator: instantiate to evaluator for test set
+            optimizer: instantiate to an optimizer
+            lr_multiplier: instantiate to a fvcore scheduler
+            train: other misc config defined in `configs/common/train.py`, including:
+                output_dir (str)
+                init_checkpoint (str)
+                amp.enabled (bool)
+                max_iter (int)
+                eval_period, log_period (int)
+                device (str)
+                checkpointer (dict)
+                ddp (dict)
+    """
+    model = instantiate(cfg.model)
+    logger = logging.getLogger("detectron2")
+    logger.info("Model:\n{}".format(model))
+    model.to(cfg.train.device)
+
+    cfg.optimizer.params.model = model
+    optim = instantiate(cfg.optimizer)
+
+    train_loader = instantiate(cfg.dataloader.train)
+
+    model = create_ddp_model(model, **cfg.train.ddp)
+    # trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
+    trainer = Trainer(model, train_loader, optim)
+    checkpointer = DetectionCheckpointer(
+        model,
+        cfg.train.output_dir,
+        trainer=trainer,
+    )
+    trainer.register_hooks(
+        [
+            hooks.IterationTimer(),
+            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
+            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
+            if comm.is_main_process()
+            else None,
+            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
+            hooks.PeriodicWriter(
+                default_writers(cfg.train.output_dir, cfg.train.max_iter),
+                period=cfg.train.log_period,
+            )
+            if comm.is_main_process()
+            else None,
+        ]
+    )
+
+    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
+    if args.resume and checkpointer.has_checkpoint():
+        # The checkpoint stores the training iteration that just finished, thus we start
+        # at the next iteration
+        start_iter = trainer.iter + 1
+    else:
+        start_iter = 0
+    trainer.train(start_iter, cfg.train.max_iter)
+
+
+def main(args):
+    cfg = LazyConfig.load(args.config_file)
+    cfg = LazyConfig.apply_overrides(cfg, args.opts)
+    default_setup(cfg, args)
+
+    if args.eval_only:
+        model = instantiate(cfg.model)
+        model.to(cfg.train.device)
+        model = create_ddp_model(model)
+        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
+        print(do_test(cfg, model))
+    else:
+        do_train(args, cfg)
+
+
+if __name__ == "__main__":
+    args = default_argument_parser().parse_args()
+    launch(
+        main,
+        args.num_gpus,
+        num_machines=args.num_machines,
+        machine_rank=args.machine_rank,
+        dist_url=args.dist_url,
+        args=(args,),
+    )