feat: support multi-teacher kd
Summary: support multi-teacher knowledge distillation (KD) with logit and overhaul distillation
L1aoXingyu committed Jan 29, 2021
1 parent db8670d commit 77a91b1
Showing 7 changed files with 118 additions and 55 deletions.
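
At a glance, the change replaces the single teacher with a list of teachers: the student trains as before, and a distillation loss is computed against every teacher and averaged. A minimal, framework-free sketch of that pattern (fastreid's actual loss is `self.jsdiv_loss`, visible in the distiller.py diff below; the temperature-scaled KL used here, and all names, are illustrative):

```python
from typing import List

import torch
import torch.nn.functional as F


def multi_teacher_logit_kd(s_logits: torch.Tensor,
                           t_logits_list: List[torch.Tensor],
                           temp: float = 4.0) -> torch.Tensor:
    """Soft-label distillation averaged over teachers (illustrative)."""
    log_p_s = F.log_softmax(s_logits / temp, dim=1)
    loss = s_logits.new_zeros(())
    for t_logits in t_logits_list:
        # Teachers contribute targets only, so no gradients flow to them.
        p_t = F.softmax(t_logits.detach() / temp, dim=1)
        loss = loss + F.kl_div(log_p_s, p_t, reduction="batchmean") * temp ** 2
    return loss / len(t_logits_list)
```

Averaging rather than summing keeps the loss magnitude stable as teachers are added, which is the same choice the new `losses` methods below make.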
41 changes: 35 additions & 6 deletions .gitignore
@@ -1,8 +1,37 @@
-.idea
-
-logs
-
+# compilation and distribution
__pycache__
-.DS_Store
-.vscode
_ext
+*.pyc
+*.pyd
+*.so
+logs/
+.ipynb_checkpoints
+logs
+*.dll
+*.egg-info/
+build/
+dist/
+wheels/
+
+# pytorch/python/numpy formats
+*.pth
+*.pkl
+*.npy
+*.ts
+model_ts*.txt
+
+# ipython/jupyter notebooks
+*.ipynb
+**/.ipynb_checkpoints/
+
+# Editor temporaries
+*.swn
+*.swo
+*.swp
+*~
+
+# editor settings
+.idea
+.vscode
+_darcs
+.DS_Store
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,12 @@
+# Changelog
+
+### v1.1 (29/01/2021)
+
+#### New Features
+
+- NAIC20 (reid track) [1st-place solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20)
+- Multi-teacher Knowledge Distillation
+
+#### Bug Fixes
+
+#### Improvements
30 changes: 17 additions & 13 deletions README.md
@@ -4,48 +4,52 @@ FastReID is a research platform that implements state-of-the-art re-identification algorithms.

## What's New

-- [Jan 2021] NAIC20 (reid track) [1st-place solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20) based on fastreid has been released!
+- [Jan 2021] NAIC20 (reid track) [1st-place solution](projects/NAIC20) based on fastreid has been released!
- [Jan 2021] FastReID V1.0 has been released!🎉
  Support many tasks beyond reid, such as image retrieval and face recognition. See [release notes](https://github.com/JDAI-CV/fast-reid/releases/tag/v1.0.0).
-- [Oct 2020] Added the [Hyper-Parameter Optimization](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastTune) based on fastreid. See `projects/FastTune`.
-- [Sep 2020] Added the [person attribute recognition](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastAttr) based on fastreid. See `projects/FastAttr`.
+- [Oct 2020] Added the [Hyper-Parameter Optimization](projects/FastTune) based on fastreid. See `projects/FastTune`.
+- [Sep 2020] Added the [person attribute recognition](projects/FastAttr) based on fastreid. See `projects/FastAttr`.
- [Sep 2020] Automatic Mixed Precision training is supported with `apex`. Set `cfg.SOLVER.FP16_ENABLED=True` to switch it on.
-- [Aug 2020] [Model Distillation](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastDistill) is supported, thanks to [guan'an wang](https://github.com/wangguanan)'s contribution.
+- [Aug 2020] [Model Distillation](projects/FastDistill) is supported, thanks to [guan'an wang](https://github.com/wangguanan)'s contribution.
- [Aug 2020] ONNX/TensorRT converter is supported.
- [Jul 2020] Distributed training with multiple GPUs; it trains much faster.
- Includes more features such as circle loss, abundant visualization methods and evaluation metrics, SoTA results on conventional, cross-domain, partial and vehicle re-id, and testing on multiple datasets simultaneously.
-- Can be used as a library to support [different projects](https://github.com/JDAI-CV/fast-reid/tree/master/projects) on top of it. We'll open-source more research projects in this way.
+- Can be used as a library to support [different projects](projects) on top of it. We'll open-source more research projects in this way.
- Removed the [ignite](https://github.com/pytorch/ignite) (a high-level library) dependency; now powered directly by [PyTorch](https://pytorch.org/).

We wrote a [Chinese blog](https://l1aoxingyu.github.io/blogpages/reid/2020/05/29/fastreid.html) about this toolbox.

+## Changelog
+
+Please refer to [changelog.md](CHANGELOG.md) for details and release history.
+
## Installation

-See [INSTALL.md](https://github.com/JDAI-CV/fast-reid/blob/master/INSTALL.md).
+See [INSTALL.md](INSTALL.md).

## Quick Start

The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), so you can check each folder's purpose by yourself.

-See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/GETTING_STARTED.md).
+See [GETTING_STARTED.md](GETTING_STARTED.md).

-Learn more at our [documentation](https://fast-reid.readthedocs.io/). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are built on top of fastreid.
+Learn more at our [documentation](https://fast-reid.readthedocs.io/). And see [projects/](projects) for some projects that are built on top of fastreid.

## Model Zoo and Baselines

-We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md).
+We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](MODEL_ZOO.md).

## Deployment

-We provide some examples and scripts to convert fastreid models to Caffe, ONNX, and TensorRT formats in [Fastreid deploy](https://github.com/JDAI-CV/fast-reid/blob/master/tools/deploy).
+We provide some examples and scripts to convert fastreid models to Caffe, ONNX, and TensorRT formats in [Fastreid deploy](tools/deploy).

## License

-Fastreid is released under the [Apache 2.0 license](https://github.com/JDAI-CV/fast-reid/blob/master/LICENSE).
+Fastreid is released under the [Apache 2.0 license](LICENSE).

-## Citing Fastreid
+## Citing FastReID

-If you use Fastreid in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
+If you use FastReID in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.

```BibTeX
@article{he2020fastreid,
  title={FastReID: A Pytorch Toolbox for General Instance Re-identification},
  author={He, Lingxiao and Liao, Xingyu and Liu, Wu and Liu, Xinchen and Cheng, Peng and Mei, Tao},
  journal={arXiv preprint arXiv:2006.02631},
  year={2020}
}
```
4 changes: 2 additions & 2 deletions fastreid/config/defaults.py
@@ -128,8 +128,8 @@
# -----------------------------------------------------------------------------

_C.KD = CN()
-_C.KD.MODEL_CONFIG = ""
-_C.KD.MODEL_WEIGHTS = ""
+_C.KD.MODEL_CONFIG = ['',]
+_C.KD.MODEL_WEIGHTS = ['',]

# -----------------------------------------------------------------------------
# INPUT
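Since `MODEL_CONFIG` and `MODEL_WEIGHTS` are now lists, one config file and one weights file are supplied per teacher, and the two lists are consumed index-by-index (config i is loaded with weights i, as the distiller.py loop below shows). A hypothetical two-teacher setup (the paths are made up for illustration):

```python
from fastreid.config import get_cfg

cfg = get_cfg()
# One entry per teacher; entry i of MODEL_CONFIG is paired with entry i
# of MODEL_WEIGHTS. Both paths below are hypothetical.
cfg.KD.MODEL_CONFIG = [
    "logs/teacher_r101_ibn/config.yaml",
    "logs/teacher_r50/config.yaml",
]
cfg.KD.MODEL_WEIGHTS = [
    "logs/teacher_r101_ibn/model_best.pth",
    "logs/teacher_r50/model_best.pth",
]
assert len(cfg.KD.MODEL_CONFIG) == len(cfg.KD.MODEL_WEIGHTS)
```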
38 changes: 24 additions & 14 deletions fastreid/modeling/meta_arch/distiller.py
@@ -22,22 +22,25 @@ def __init__(self, cfg):
        super(Distiller, self).__init__(cfg)

        # Get teacher model config
-        cfg_t = get_cfg()
-        cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG)
+        model_ts = []
+        for i in range(len(cfg.KD.MODEL_CONFIG)):
+            cfg_t = get_cfg()
+            cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])

-        model_t = build_model(cfg_t)
-        logger.info("Teacher model:\n{}".format(model_t))
+            model_t = build_model(cfg_t)

-        # No gradients for teacher model
-        for param in model_t.parameters():
-            param.requires_grad_(False)
+            # No gradients for teacher model
+            for param in model_t.parameters():
+                param.requires_grad_(False)

-        logger.info("Loading teacher model weights ...")
-        Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS)
+            logger.info("Loading teacher model weights ...")
+            Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS[i])
+
+            model_ts.append(model_t)

        # Do not register the teacher models as `nn.Module`s; this
        # makes sure the teacher weights are not saved in checkpoints
-        self.model_t = [model_t.backbone, model_t.heads]
+        self.model_ts = model_ts

    def forward(self, batched_inputs):
        if self.training:
@@ -51,10 +54,13 @@

            s_outputs = self.heads(s_feat, targets)

+            t_outputs = []
            # teacher model forward
            with torch.no_grad():
-                t_feat = self.model_t[0](images)
-                t_outputs = self.model_t[1](t_feat, targets)
+                for model_t in self.model_ts:
+                    t_feat = model_t.backbone(images)
+                    t_output = model_t.heads(t_feat, targets)
+                    t_outputs.append(t_output)

            losses = self.losses(s_outputs, t_outputs, targets)
            return losses
@@ -71,8 +77,12 @@ def losses(self, s_outputs, t_outputs, gt_labels):
        loss_dict = super(Distiller, self).losses(s_outputs, gt_labels)

        s_logits = s_outputs["pred_class_logits"]
-        t_logits = t_outputs["pred_class_logits"].detach()
-        loss_dict["loss_jsdiv"] = self.jsdiv_loss(s_logits, t_logits)
+        loss_jsdiv = 0.
+        for t_output in t_outputs:
+            t_logits = t_output["pred_class_logits"].detach()
+            loss_jsdiv += self.jsdiv_loss(s_logits, t_logits)
+
+        loss_dict["loss_jsdiv"] = loss_jsdiv / len(t_outputs)

        return loss_dict

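The reworked `losses` divides the accumulated JS-divergence by `len(t_outputs)`, so adding teachers does not inflate the loss scale. The definition of `jsdiv_loss` is not part of this diff; a plausible stand-in for a symmetric Jensen-Shannon loss over softened logits (the temperature and exact scaling are assumptions, not fastreid's code) looks like:

```python
import torch
import torch.nn.functional as F


def jsdiv_loss(s_logits: torch.Tensor, t_logits: torch.Tensor,
               temp: float = 16.0) -> torch.Tensor:
    """Symmetric JS divergence between softened class distributions (sketch)."""
    p_s = F.softmax(s_logits / temp, dim=1)
    p_t = F.softmax(t_logits / temp, dim=1)
    m = (0.5 * (p_s + p_t)).clamp_min(1e-12)  # mixture, guarded for log()
    js = 0.5 * (F.kl_div(m.log(), p_s, reduction="batchmean")
                + F.kl_div(m.log(), p_t, reduction="batchmean"))
    return js * temp ** 2
```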
4 changes: 2 additions & 2 deletions projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yml
@@ -8,8 +8,8 @@ MODEL:
    WITH_IBN: False

KD:
-  MODEL_CONFIG: projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml
-  MODEL_WEIGHTS: projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth
+  MODEL_CONFIG: ("projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml",)
+  MODEL_WEIGHTS: ("projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth",)

DATASETS:
  NAMES: ("DukeMTMC",)
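The tuple syntax mirrors how this config already writes `DATASETS.NAMES`. When yacs merges a YAML file it runs string values through `ast.literal_eval` where possible, so the quoted tuple arrives as a real Python sequence compatible with the new list-typed defaults (this is my reading of yacs' value decoding; worth verifying against the pinned version):

```python
from ast import literal_eval

# What yacs effectively does to the scalar read from the YAML file above.
raw = '("projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml",)'
value = literal_eval(raw)
print(type(value), value[0])
# <class 'tuple'> projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml
```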
44 changes: 26 additions & 18 deletions projects/FastDistill/fastdistill/overhaul.py
@@ -61,16 +61,18 @@ def __init__(self, cfg):
        super().__init__(cfg)

        s_channels = self.backbone.get_channel_nums()
-        t_channels = self.model_t[0].get_channel_nums()

-        self.connectors = nn.ModuleList(
-            [build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)])
+        for i in range(len(self.model_ts)):
+            t_channels = self.model_ts[i].backbone.get_channel_nums()

-        teacher_bns = self.model_t[0].get_bn_before_relu()
-        margins = [get_margin_from_BN(bn) for bn in teacher_bns]
-        for i, margin in enumerate(margins):
-            self.register_buffer("margin%d" % (i + 1),
-                                 margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
+            setattr(self, "connectors_{}".format(i), nn.ModuleList(
+                [build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)]))
+
+            teacher_bns = self.model_ts[i].backbone.get_bn_before_relu()
+            margins = [get_margin_from_BN(bn) for bn in teacher_bns]
+            for j, margin in enumerate(margins):
+                self.register_buffer("margin{}_{}".format(i, j + 1),
+                                     margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())

    def forward(self, batched_inputs):
        if self.training:
@@ -84,20 +86,25 @@

            s_outputs = self.heads(s_feat, targets)

+            t_feats_list = []
+            t_outputs = []
            # teacher model forward
            with torch.no_grad():
-                t_feats, t_feat = self.model_t[0].extract_feature(images, preReLU=True)
-                t_outputs = self.model_t[1](t_feat, targets)
+                for model_t in self.model_ts:
+                    t_feats, t_feat = model_t.backbone.extract_feature(images, preReLU=True)
+                    t_output = model_t.heads(t_feat, targets)
+                    t_feats_list.append(t_feats)
+                    t_outputs.append(t_output)

-            losses = self.losses(s_outputs, s_feats, t_outputs, t_feats, targets)
+            losses = self.losses(s_outputs, s_feats, t_outputs, t_feats_list, targets)
            return losses

        else:
            outputs = super(DistillerOverhaul, self).forward(batched_inputs)
            return outputs

-    def losses(self, s_outputs, s_feats, t_outputs, t_feats, gt_labels):
-        r"""
+    def losses(self, s_outputs, s_feats, t_outputs, t_feats_list, gt_labels):
+        """
        Compute loss from the model's outputs; the loss function's input
        arguments must match the outputs of the model's forward pass.
        """
@@ -106,11 +113,12 @@ def losses(self, s_outputs, s_feats, t_outputs, t_feats, gt_labels):
        # Overhaul distillation loss
        feat_num = len(s_feats)
        loss_distill = 0
-        for i in range(feat_num):
-            s_feats[i] = self.connectors[i](s_feats[i])
-            loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(
-                self, "margin%d" % (i + 1)).to(s_feats[i].dtype)) / 2 ** (feat_num - i - 1)
+        for i in range(len(t_feats_list)):
+            for j in range(feat_num):
+                s_feats_connect = getattr(self, "connectors_{}".format(i))[j](s_feats[j])
+                loss_distill += distillation_loss(s_feats_connect, t_feats_list[i][j].detach(), getattr(
+                    self, "margin{}_{}".format(i, j + 1)).to(s_feats_connect.dtype)) / 2 ** (feat_num - j - 1)

-        loss_dict["loss_overhaul"] = loss_distill / len(gt_labels) / 10000
+        loss_dict["loss_overhaul"] = loss_distill / len(t_feats_list) / len(gt_labels) / 10000

        return loss_dict
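
For context on the loop above: student feature `j` is first passed through teacher `i`'s connector to match channel widths, then compared against that teacher's pre-ReLU feature using the margin precomputed from the teacher's BatchNorm statistics in `__init__`. The `2 ** (feat_num - j - 1)` divisor halves the weight of each earlier stage so the deepest features dominate, and the total is normalized by the number of teachers, the batch size, and a constant 10000. `distillation_loss` itself is not shown in this diff; a sketch matching the partial-L2 loss from Heo et al.'s overhaul paper, which this project follows (the exact fastreid helper may differ):

```python
import torch
import torch.nn.functional as F


def distillation_loss(source: torch.Tensor, target: torch.Tensor,
                      margin: torch.Tensor) -> torch.Tensor:
    """Margin-ReLU partial L2 distance (sketch of the overhaul-paper loss)."""
    # Lift teacher activations below the margin up to it, so the student
    # is not pulled toward strongly negative (dead) teacher responses.
    target = torch.max(target, margin)
    loss = F.mse_loss(source, target, reduction="none")
    # Partial L2: drop positions where the student already exceeds a
    # non-positive target; there is nothing useful left to transfer.
    loss = loss * ((source > target) | (target > 0)).float()
    return loss.sum()
```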
