From 77a91b1204ce275b06d477103d01c8f4bae98392 Mon Sep 17 00:00:00 2001
From: liaoxingyu
Date: Fri, 29 Jan 2021 17:25:31 +0800
Subject: [PATCH] feat: support multi-teacher kd

Summary: support multi-teacher kd with logits and overhaul distillation
---
 .gitignore                                    | 41 ++++++++++++++---
 CHANGELOG.md                                  | 12 +++++
 README.md                                     | 30 +++++++------
 fastreid/config/defaults.py                   |  4 +-
 fastreid/modeling/meta_arch/distiller.py      | 38 ++++++++++------
 .../configs/kd-sbs_r101ibn-sbs_r34.yml        |  4 +-
 projects/FastDistill/fastdistill/overhaul.py  | 44 +++++++++++--------
 7 files changed, 118 insertions(+), 55 deletions(-)
 create mode 100644 CHANGELOG.md

diff --git a/.gitignore b/.gitignore
index 41d06a514..8be82c3e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,8 +1,37 @@
-.idea
+
+logs
+
+# compilation and distribution
 __pycache__
-.DS_Store
-.vscode
+_ext
+*.pyc
+*.pyd
 *.so
-logs/
-.ipynb_checkpoints
-logs
\ No newline at end of file
+*.dll
+*.egg-info/
+build/
+dist/
+wheels/
+
+# pytorch/python/numpy formats
+*.pth
+*.pkl
+*.npy
+*.ts
+model_ts*.txt
+
+# ipython/jupyter notebooks
+*.ipynb
+**/.ipynb_checkpoints/
+
+# Editor temporaries
+*.swn
+*.swo
+*.swp
+*~
+
+# editor settings
+.idea
+.vscode
+_darcs
+.DS_Store
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 000000000..d89c008e8
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,12 @@
+# Changelog
+
+### v1.1 (29/01/2021)
+
+#### New Features
+
+- NAIC20 (reid track) [1st solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20)
+- Multi-teacher Knowledge Distillation
+
+#### Bug Fixes
+
+#### Improvements
\ No newline at end of file
diff --git a/README.md b/README.md
index f72d102c1..5ff912bf5 100644
--- a/README.md
+++ b/README.md
@@ -4,48 +4,52 @@
 FastReID is a research platform that implements state-of-the-art re-identification algorithms.

 ## What's New

-- [Jan 2021] NAIC20(reid track) [1-st solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20) based on fastreid has been released!
+- [Jan 2021] NAIC20 (reid track) [1st solution](projects/NAIC20) based on fastreid has been released!
 - [Jan 2021] FastReID V1.0 has been released! 🎉 It supports many tasks beyond reid, such as image retrieval and face recognition. See [release notes](https://github.com/JDAI-CV/fast-reid/releases/tag/v1.0.0).
-- [Oct 2020] Added the [Hyper-Parameter Optimization](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastTune) based on fastreid. See `projects/FastTune`.
-- [Sep 2020] Added the [person attribute recognition](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastAttr) based on fastreid. See `projects/FastAttr`.
+- [Oct 2020] Added [Hyper-Parameter Optimization](projects/FastTune) based on fastreid. See `projects/FastTune`.
+- [Sep 2020] Added [person attribute recognition](projects/FastAttr) based on fastreid. See `projects/FastAttr`.
 - [Sep 2020] Automatic Mixed Precision training is supported with `apex`. Set `cfg.SOLVER.FP16_ENABLED=True` to switch it on.
-- [Aug 2020] [Model Distillation](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastDistill) is supported, thanks for [guan'an wang](https://github.com/wangguanan)'s contribution.
+- [Aug 2020] [Model Distillation](projects/FastDistill) is supported, thanks to [guan'an wang](https://github.com/wangguanan)'s contribution.
 - [Aug 2020] ONNX/TensorRT converter is supported.
 - [Jul 2020] Distributed training with multiple GPUs; it trains much faster.
 - Includes more features such as circle loss, abundant visualization methods and evaluation metrics, SoTA results on conventional, cross-domain, partial and vehicle re-id, testing on multiple datasets simultaneously, etc.
-- Can be used as a library to support [different projects](https://github.com/JDAI-CV/fast-reid/tree/master/projects) on top of it. We'll open source more research projects in this way.
+- Can be used as a library to support [different projects](projects) on top of it. We'll open source more research projects in this way.
 - Removed the [ignite](https://github.com/pytorch/ignite) (a high-level library) dependency; powered by [PyTorch](https://pytorch.org/).

 We wrote a [Chinese blog](https://l1aoxingyu.github.io/blogpages/reid/2020/05/29/fastreid.html) about this toolbox.

+## Changelog
+
+Please refer to [CHANGELOG.md](CHANGELOG.md) for details and release history.
+
 ## Installation

-See [INSTALL.md](https://github.com/JDAI-CV/fast-reid/blob/master/INSTALL.md).
+See [INSTALL.md](INSTALL.md).

 ## Quick Start

 The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template); you can check each folder's purpose by yourself.

-See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/GETTING_STARTED.md).
+See [GETTING_STARTED.md](GETTING_STARTED.md).

-Learn more at out [documentation](https://fast-reid.readthedocs.io/). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are build on top of fastreid.
+Learn more at our [documentation](https://fast-reid.readthedocs.io/), and see [projects/](projects) for some projects that are built on top of fastreid.

 ## Model Zoo and Baselines

-We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md).
+We provide a large set of baseline results and trained models, available for download in the [FastReID Model Zoo](MODEL_ZOO.md).

 ## Deployment

-We provide some examples and scripts to convert fastreid model to Caffe, ONNX and TensorRT format in [Fastreid deploy](https://github.com/JDAI-CV/fast-reid/blob/master/tools/deploy).
+We provide some examples and scripts to convert a fastreid model to Caffe, ONNX and TensorRT formats in [FastReID deploy](tools/deploy).

 ## License

-Fastreid is released under the [Apache 2.0 license](https://github.com/JDAI-CV/fast-reid/blob/master/LICENSE).
+FastReID is released under the [Apache 2.0 license](LICENSE).

-## Citing Fastreid
+## Citing FastReID

-If you use Fastreid in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
+If you use FastReID in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
 ```BibTeX
 @article{he2020fastreid,
diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py
index dbc21706c..c32be61ed 100644
--- a/fastreid/config/defaults.py
+++ b/fastreid/config/defaults.py
@@ -128,8 +128,8 @@
 # -----------------------------------------------------------------------------
 _C.KD = CN()
-_C.KD.MODEL_CONFIG = ""
-_C.KD.MODEL_WEIGHTS = ""
+_C.KD.MODEL_CONFIG = ['',]
+_C.KD.MODEL_WEIGHTS = ['',]

 # -----------------------------------------------------------------------------
 # INPUT
diff --git a/fastreid/modeling/meta_arch/distiller.py b/fastreid/modeling/meta_arch/distiller.py
index 54057212e..4e409700d 100644
--- a/fastreid/modeling/meta_arch/distiller.py
+++ b/fastreid/modeling/meta_arch/distiller.py
@@ -22,22 +22,25 @@ def __init__(self, cfg):
         super(Distiller, self).__init__(cfg)

         # Get teacher model config
-        cfg_t = get_cfg()
-        cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG)
+        model_ts = []
+        for i in range(len(cfg.KD.MODEL_CONFIG)):
+            cfg_t = get_cfg()
+            cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])

-        model_t = build_model(cfg_t)
-        logger.info("Teacher model:\n{}".format(model_t))
+            model_t = build_model(cfg_t)

-        # No gradients for teacher model
-        for param in model_t.parameters():
-            param.requires_grad_(False)
+            # No gradients for the teacher model
+            for param in model_t.parameters():
+                param.requires_grad_(False)

-        logger.info("Loading teacher model weights ...")
-        Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS)
+            logger.info("Loading teacher model weights ...")
+            Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS[i])
+
+            model_ts.append(model_t)

         # Do not register the teacher models as `nn.Module`s; this
         # makes sure the teacher model weights are not saved
-        self.model_t = [model_t.backbone, model_t.heads]
+        self.model_ts = model_ts

     def forward(self, batched_inputs):
         if self.training:
@@ -51,10 +54,13 @@ def forward(self, batched_inputs):

             s_outputs = self.heads(s_feat, targets)

+            t_outputs = []
             # teacher model forward
             with torch.no_grad():
-                t_feat = self.model_t[0](images)
-                t_outputs = self.model_t[1](t_feat, targets)
+                for model_t in self.model_ts:
+                    t_feat = model_t.backbone(images)
+                    t_output = model_t.heads(t_feat, targets)
+                    t_outputs.append(t_output)

             losses = self.losses(s_outputs, t_outputs, targets)
             return losses
@@ -71,8 +77,12 @@ def losses(self, s_outputs, t_outputs, gt_labels):
         loss_dict = super(Distiller, self).losses(s_outputs, gt_labels)

         s_logits = s_outputs["pred_class_logits"]
-        t_logits = t_outputs["pred_class_logits"].detach()
-        loss_dict["loss_jsdiv"] = self.jsdiv_loss(s_logits, t_logits)
+        loss_jsdiv = 0.
+        for t_output in t_outputs:
+            t_logits = t_output["pred_class_logits"].detach()
+            loss_jsdiv += self.jsdiv_loss(s_logits, t_logits)
+
+        loss_dict["loss_jsdiv"] = loss_jsdiv / len(t_outputs)

         return loss_dict

diff --git a/projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yml b/projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yml
index 46f30cf46..55efb37b0 100644
--- a/projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yml
+++ b/projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yml
@@ -8,8 +8,8 @@ MODEL:
     WITH_IBN: False

 KD:
-  MODEL_CONFIG: projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml
-  MODEL_WEIGHTS: projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth
+  MODEL_CONFIG: ("projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml",)
+  MODEL_WEIGHTS: ("projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth",)

 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/projects/FastDistill/fastdistill/overhaul.py b/projects/FastDistill/fastdistill/overhaul.py
index cf263295c..d3111029b 100644
--- a/projects/FastDistill/fastdistill/overhaul.py
+++ b/projects/FastDistill/fastdistill/overhaul.py
@@ -61,16 +61,18 @@ def __init__(self, cfg):
         super().__init__(cfg)

         s_channels = self.backbone.get_channel_nums()
-        t_channels = self.model_t[0].get_channel_nums()

-        self.connectors = nn.ModuleList(
-            [build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)])
+        for i in range(len(self.model_ts)):
+            t_channels = self.model_ts[i].backbone.get_channel_nums()

-        teacher_bns = self.model_t[0].get_bn_before_relu()
-        margins = [get_margin_from_BN(bn) for bn in teacher_bns]
-        for i, margin in enumerate(margins):
-            self.register_buffer("margin%d" % (i + 1),
-                                 margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
+            setattr(self, "connectors_{}".format(i), nn.ModuleList(
+                [build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)]))
+
+            teacher_bns = self.model_ts[i].backbone.get_bn_before_relu()
+            margins = [get_margin_from_BN(bn) for bn in teacher_bns]
+            for j, margin in enumerate(margins):
+                self.register_buffer("margin{}_{}".format(i, j + 1),
+                                     margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())

     def forward(self, batched_inputs):
         if self.training:
@@ -84,20 +86,25 @@ def forward(self, batched_inputs):

             s_outputs = self.heads(s_feat, targets)

+            t_feats_list = []
+            t_outputs = []
             # teacher model forward
             with torch.no_grad():
-                t_feats, t_feat = self.model_t[0].extract_feature(images, preReLU=True)
-                t_outputs = self.model_t[1](t_feat, targets)
+                for model_t in self.model_ts:
+                    t_feats, t_feat = model_t.backbone.extract_feature(images, preReLU=True)
+                    t_output = model_t.heads(t_feat, targets)
+                    t_feats_list.append(t_feats)
+                    t_outputs.append(t_output)

-            losses = self.losses(s_outputs, s_feats, t_outputs, t_feats, targets)
+            losses = self.losses(s_outputs, s_feats, t_outputs, t_feats_list, targets)
             return losses
         else:
             outputs = super(DistillerOverhaul, self).forward(batched_inputs)
             return outputs

-    def losses(self, s_outputs, s_feats, t_outputs, t_feats, gt_labels):
-        r"""
+    def losses(self, s_outputs, s_feats, t_outputs, t_feats_list, gt_labels):
+        """
         Compute loss from the model's outputs; the loss function's input arguments
         must be the same as the outputs of the model's forward pass.
""" @@ -106,11 +113,12 @@ def losses(self, s_outputs, s_feats, t_outputs, t_feats, gt_labels): # Overhaul distillation loss feat_num = len(s_feats) loss_distill = 0 - for i in range(feat_num): - s_feats[i] = self.connectors[i](s_feats[i]) - loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr( - self, "margin%d" % (i + 1)).to(s_feats[i].dtype)) / 2 ** (feat_num - i - 1) + for i in range(len(t_feats_list)): + for j in range(feat_num): + s_feats_connect = getattr(self, "connectors_{}".format(i))[j](s_feats[j]) + loss_distill += distillation_loss(s_feats_connect, t_feats_list[i][j].detach(), getattr( + self, "margin{}_{}".format(i, j + 1)).to(s_feats_connect.dtype)) / 2 ** (feat_num - j - 1) - loss_dict["loss_overhaul"] = loss_distill / len(gt_labels) / 10000 + loss_dict["loss_overhaul"] = loss_distill / len(t_feats_list) / len(gt_labels) / 10000 return loss_dict