From cf44e303b0dfec634c558acf58e81ac3e3a36af6 Mon Sep 17 00:00:00 2001
From: HantingChen <40243544+HantingChen@users.noreply.github.com>
Date: Mon, 22 May 2023 20:18:12 +0800
Subject: [PATCH] Add files via upload
---
License | 12 +
README.md | 155 +++++-
THIRD_PARTY_OPEN_SOURCE_SOFTWARE_NOTICE.txt | 31 ++
datasets.py | 96 ++++
engine.py | 195 +++++++
main.py | 559 ++++++++++++++++++++
optim_factory.py | 180 +++++++
test_latency.py | 45 ++
utils.py | 520 ++++++++++++++++++
9 files changed, 1784 insertions(+), 9 deletions(-)
create mode 100644 License
create mode 100644 THIRD_PARTY_OPEN_SOURCE_SOFTWARE_NOTICE.txt
create mode 100644 datasets.py
create mode 100644 engine.py
create mode 100644 main.py
create mode 100644 optim_factory.py
create mode 100644 test_latency.py
create mode 100644 utils.py
diff --git a/License b/License
new file mode 100644
index 0000000..07664df
--- /dev/null
+++ b/License
@@ -0,0 +1,12 @@
+Copyright (c) 2023. Huawei Technologies Co., Ltd.
+ All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/README.md b/README.md
index 2fef13e..44d51a5 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,13 @@
# VanillaNet: the Power of Minimalism in Deep Learning
+
+
+
+
Official PyTorch implementation of **VanillaNet**, from the following paper:\
-**VanillaNet: the Power of Minimalism in Deep Learning**\
-*Hanting chen, [Yunhe Wang](https://www.wangyunhe.site/), Jianyuan Guo and Dacheng Tao*
+[VanillaNet: the Power of Minimalism in Deep Learning](https://arxiv.org/abs/)\
+Hanting Chen, [Yunhe Wang](https://www.wangyunhe.site/), Jianyuan Guo and Dacheng Tao
@@ -11,9 +15,9 @@ VanillaNet is an innovative neural network architecture that focuses on **simpli
## Comparison of Depth and Speed
-
+
-VanillaNet, in its robust simplicity, offers comparable precision to prevalent computer vision foundation models, yet boasts a **reduced depth and enhanced processing speed** (test on Nvidia A100 GPU with batch size 1):
+VanillaNet, in its robust simplicity, offers comparable precision to prevalent computer vision foundation models, yet boasts a **reduced depth and enhanced processing speed**:
- The **9-layer** VanillaNet achieves about **80%** Top-1 accuracy with a **3.59ms** latency, over a **100%** speed increase compared to ResNet-50 (**7.64ms**).
- The **13-layer** VanillaNet achieves about **83%** Top-1 accuracy with a **9.72ms** latency, over a **100%** speed increase compared to Swin-T (**20.25ms**).
@@ -21,7 +25,7 @@ VanillaNet, in its robust simplicity, offers comparable precision to prevalent c
| Framework | Backbone | FLOPs(G) | #params(M) | FPS | APb | APm |
|:---:|:---:|:---:|:---:| :---:|:---:|:---:|
-| RetinaNet | Swin-T| 245 | 38.5 | 27.5 | 41.5 | - |
+| RetinaNet | Swin-T| 245 | 38.5 | 27.5 | 41.5 | - |
| | VanillaNet-11 | 386 | 67.0 | 30.8 | 41.8 | - |
| Mask RCNN | ConvNeXtV2-N | 221 | 35.2 | 31.7 | 42.7 | 38.9 |
| | [Swin-T](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | 267 | 47.8 | 28.2 | 42.7 | 39.3 |
@@ -32,9 +36,9 @@ VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segm
## Catalog
-- [ ] ImageNet-1K Testing Code
-- [ ] ImageNet-1K Training Code of VanillaNet-5 to VanillaNet-10
-- [ ] ImageNet-1K Pretrained Weights of VanillaNet-5 to VanillaNet-10
+- [x] ImageNet-1K Testing Code
+- [x] ImageNet-1K Training Code of VanillaNet-5 to VanillaNet-10
+- [x] ImageNet-1K Pretrained Weights of VanillaNet-5 to VanillaNet-10
- [ ] ImageNet-1K Training Code of VanillaNet-11 to VanillaNet-13
- [ ] ImageNet-1K Pretrained Weights of VanillaNet-11 to VanillaNet-13
- [ ] Downstream Transfer (Detection, Segmentation) Code
@@ -54,7 +58,7 @@ VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segm
| VanillaNet-12 | 54.3 | 11.1 | 3.82 | 81.55 | - |
| VanillaNet-13 | 58.6 | 11.9 | 4.26 | 82.05 | - |
| VanillaNet-13-1.5x | 127.8 | 26.5 | 7.83 | 82.53 | - |
-| VanillaNet-13-1.5x† | 127.8 | 48.5 | 9.72 | 83.11 | - |
+| VanillaNet-13-1.5x† | 127.8 | 48.5 | 9.72 | 83.11 | - |
## Installation
@@ -72,6 +76,139 @@ pip install tensorboardX
pip install terminaltables
```
+## Testing
+
+We give example evaluation commands for VanillaNet-5:
+
+without deploy:
+
+```
+python -m torch.distributed.launch --nproc_per_node=1 main.py --model vanillanet_5 --data_path /path/to/imagenet-1k/ --real_labels /path/to/imagenet_real_labels.json --finetune /path/to/vanillanet_5.pth --eval True --model_key model_ema --crop_pct 0.875
+```
+
+with deploy:
+```
+python -m torch.distributed.launch --nproc_per_node=1 main.py --model vanillanet_5 --data_path /path/to/imagenet-1k/ --real_labels /path/to/imagenet_real_labels.json --finetune /path/to/vanillanet_5.pth --eval True --model_key model_ema --crop_pct 0.875 --switch_to_deploy /path/to/vanillanet_5_deploy.pth
+```
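+
+If you prefer to convert a checkpoint to the deploy form directly in Python, the minimal sketch below mirrors what `main.py` does for `--switch_to_deploy` (it assumes `models/vanillanet.py` exposes `vanillanet_5` and that the checkpoint stores its weights under the `model_ema` key, as in the commands above):
+
+```
+import torch
+from models.vanillanet import vanillanet_5  # assumed entry point in models/vanillanet.py
+
+model = vanillanet_5(deploy=False)
+ckpt = torch.load('/path/to/vanillanet_5.pth', map_location='cpu')
+model.load_state_dict(ckpt['model_ema'])  # adjust the key to match your checkpoint
+model.eval()
+model.switch_to_deploy()  # fuse training-time branches into the deploy structure
+torch.save({'model': model.state_dict()}, '/path/to/vanillanet_5_deploy.pth')
+```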
+
+## Training
+
+You can use the following command to train VanillaNet-5 on a single machine with 8 GPUs:
+```
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+--model vanillanet_5 \
+--data_path /path/to/imagenet-1k \
+--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
+--lr 3.5e-3 --weight_decay 0.35 --drop 0.05 \
+--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.1 --bce_loss \
+--output_dir /path/to/save_results \
+--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
+--use_amp true
+```
+
+- Here, the effective batch size = `--nproc_per_node` * `--batch_size` * `--update_freq`. In the example above, the effective batch size is `8*128*1 = 1024`.
+
+To train other VanillaNet variants, change `--model` accordingly. Examples are given below.
+
+
+
+VanillaNet-6
+
+
+```
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+--model vanillanet_6 \
+--data_path /path/to/imagenet-1k \
+--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
+--lr 4.8e-3 --weight_decay 0.32 --drop 0.05 \
+--layer_decay 0.8 --layer_decay_num_layers 4 \
+--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.15 --bce_loss \
+--output_dir /path/to/save_results \
+--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
+--use_amp true
+```
+
+
+
+
+VanillaNet-7
+
+
+```
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+--model vanillanet_7 \
+--data_path /path/to/imagenet-1k \
+--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
+--lr 4.7e-3 --weight_decay 0.35 --drop 0.05 \
+--layer_decay 0.8 --layer_decay_num_layers 5 \
+--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.4 --bce_loss \
+--output_dir /path/to/save_results \
+--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
+--use_amp true
+```
+
+
+
+
+VanillaNet-8
+
+
+```
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+--model vanillanet_8 \
+--data_path /path/to/imagenet-1k \
+--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
+--lr 3.5e-3 --weight_decay 0.3 --drop 0.05 \
+--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.4 --bce_loss \
+--output_dir /path/to/save_results \
+--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
+--use_amp true
+```
+
+
+
+
+VanillaNet-9
+
+
+```
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+--model vanillanet_9 \
+--data_path /path/to/imagenet-1k \
+--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
+--lr 3.5e-3 --weight_decay 0.3 --drop 0.05 \
+--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.4 --bce_loss \
+--output_dir /path/to/save_results \
+--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
+--use_amp true
+```
+
+
+
+
+VanillaNet-10
+
+
+```
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+--model vanillanet_10 \
+--data_path /path/to/imagenet-1k \
+--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
+--lr 3.5e-3 --weight_decay 0.25 --drop 0.05 \
+--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.4 --bce_loss \
+--output_dir /path/to/save_results \
+--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
+--use_amp true
+```
+
+
+
+### Acknowledgement
+
+This repository is built using the [timm](https://github.com/rwightman/pytorch-image-models) library, [DeiT](https://github.com/facebookresearch/deit), [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [RepVGG](https://github.com/DingXiaoH/RepVGG), and [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) repositories.
### License
This project is released under the MIT license. Please see the [LICENSE](License) file for more information.
+
+### Citation
+If our work is useful for your research, please consider citing:
diff --git a/THIRD_PARTY_OPEN_SOURCE_SOFTWARE_NOTICE.txt b/THIRD_PARTY_OPEN_SOURCE_SOFTWARE_NOTICE.txt
new file mode 100644
index 0000000..5734dd2
--- /dev/null
+++ b/THIRD_PARTY_OPEN_SOURCE_SOFTWARE_NOTICE.txt
@@ -0,0 +1,31 @@
+Please note we provide an open source software notice for the third party open source software along with this software and/or this software component contributed by Huawei (in the following just “this SOFTWARE”). The open source software licenses are granted by the respective right holders.
+
+Warranty Disclaimer
+THE OPEN SOURCE SOFTWARE IN THIS SOFTWARE IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.
+
+Copyright Notice and License Texts
+Software: ConvNeXt
+Copyright notice:
+Copyright (c) 2022
+
+License: MIT License
+
+Copyright (c) Meta Platforms, Inc. and affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/datasets.py b/datasets.py
new file mode 100644
index 0000000..ee96135
--- /dev/null
+++ b/datasets.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import os
+from torchvision import datasets, transforms
+
+from timm.data.constants import \
+ IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.data import create_transform
+
+def build_dataset(is_train, args):
+ transform = build_transform(is_train, args)
+
+ print("Transform = ")
+ if isinstance(transform, tuple):
+ for trans in transform:
+ print(" - - - - - - - - - - ")
+ for t in trans.transforms:
+ print(t)
+ else:
+ for t in transform.transforms:
+ print(t)
+ print("---------------------------")
+
+ if args.data_set == 'CIFAR':
+ dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
+ nb_classes = 100
+ elif args.data_set == 'IMNET':
+ print("reading from datapath", args.data_path)
+ root = os.path.join(args.data_path, 'train' if is_train else 'val')
+ dataset = datasets.ImageFolder(root, transform=transform)
+ nb_classes = 1000
+ elif args.data_set == "image_folder":
+ root = args.data_path if is_train else args.eval_data_path
+ dataset = datasets.ImageFolder(root, transform=transform)
+ nb_classes = args.nb_classes
+ assert len(dataset.class_to_idx) == nb_classes
+ else:
+ raise NotImplementedError()
+ print("Number of the class = %d" % nb_classes)
+
+ return dataset, nb_classes
+
+
+def build_transform(is_train, args):
+ resize_im = args.input_size > 32
+ imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
+ mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
+ std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
+
+ if is_train:
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=args.input_size,
+ is_training=True,
+ color_jitter=args.color_jitter,
+ auto_augment=args.aa,
+ interpolation=args.train_interpolation,
+ re_prob=args.reprob,
+ re_mode=args.remode,
+ re_count=args.recount,
+ mean=mean,
+ std=std,
+ )
+ if not resize_im:
+ transform.transforms[0] = transforms.RandomCrop(
+ args.input_size, padding=4)
+ return transform
+
+ t = []
+ if resize_im:
+ # warping (no cropping) when evaluated at 384 or larger
+ if args.input_size >= 384:
+ t.append(
+ transforms.Resize((args.input_size, args.input_size),
+ interpolation=transforms.InterpolationMode.BICUBIC),
+ )
+ print(f"Warping {args.input_size} size input images...")
+ else:
+ if args.crop_pct is None:
+ args.crop_pct = 224 / 256
+ size = int(args.input_size / args.crop_pct)
+ t.append(
+ # to maintain same ratio w.r.t. 224 images
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
+ )
+ t.append(transforms.CenterCrop(args.input_size))
+
+ t.append(transforms.ToTensor())
+ t.append(transforms.Normalize(mean, std))
+ return transforms.Compose(t)
diff --git a/engine.py b/engine.py
new file mode 100644
index 0000000..609fa00
--- /dev/null
+++ b/engine.py
@@ -0,0 +1,195 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import math
+from typing import Iterable, Optional
+import torch
+from timm.data import Mixup
+from timm.utils import accuracy, ModelEma
+import logging
+
+import utils
+
+def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
+ model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
+ wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
+ num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
+ model.train(True)
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ print_freq = 200
+
+ optimizer.zero_grad()
+
+ for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ step = data_iter_step // update_freq
+ if step >= num_training_steps_per_epoch:
+ continue
+ it = start_steps + step # global training iteration
+        # Update LR & WD only on the first step of each gradient-accumulation cycle
+        if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
+ for i, param_group in enumerate(optimizer.param_groups):
+ if lr_schedule_values is not None:
+ param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
+ if wd_schedule_values is not None and param_group["weight_decay"] > 0:
+ param_group["weight_decay"] = wd_schedule_values[it]
+
+ samples = samples.to(device, non_blocking=True)
+ targets = targets.to(device, non_blocking=True)
+
+ if mixup_fn is not None:
+ samples, targets = mixup_fn(samples, targets)
+
+ if use_amp:
+ with torch.cuda.amp.autocast():
+ output = model(samples)
+ loss = criterion(output, targets)
+ else: # full precision
+ output = model(samples)
+ loss = criterion(output, targets)
+
+ loss_value = loss.item()
+
+ if not math.isfinite(loss_value): # this could trigger if using AMP
+ logging.error("Logging: Loss is {}, stopping training".format(loss_value))
+ print("Loss is {}, stopping training".format(loss_value))
+ assert math.isfinite(loss_value)
+
+ if use_amp:
+ # this attribute is added by timm on one optimizer (adahessian)
+ is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+ loss /= update_freq
+ grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
+ parameters=model.parameters(), create_graph=is_second_order,
+ update_grad=(data_iter_step + 1) % update_freq == 0)
+ if (data_iter_step + 1) % update_freq == 0:
+ optimizer.zero_grad()
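+            # keep every EMA copy in step with the live model, including the per-stage
+            # act_learn value used by the deep training strategy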
+ if model_ema is not None:
+ for iter_model_ema in model_ema:
+ iter_model_ema.update(model)
+ for i in range(len(iter_model_ema.ema.stages)):
+ if hasattr(iter_model_ema.ema.stages[i], 'act_learn'):
+ iter_model_ema.ema.stages[i].act_learn = model.module.stages[i].act_learn
+ if hasattr(iter_model_ema.ema, 'act_learn'):
+ iter_model_ema.ema.act_learn = model.module.act_learn
+ else: # full precision
+ loss /= update_freq
+ loss.backward()
+ if (data_iter_step + 1) % update_freq == 0:
+ optimizer.step()
+ optimizer.zero_grad()
+ if model_ema is not None:
+ for iter_model_ema in model_ema:
+ iter_model_ema.update(model)
+ for i in range(len(iter_model_ema.ema.stages)):
+ if hasattr(iter_model_ema.ema.stages[i], 'act_learn'):
+ iter_model_ema.ema.stages[i].act_learn = model.module.stages[i].act_learn
+ if hasattr(iter_model_ema.ema, 'act_learn'):
+ iter_model_ema.ema.act_learn = model.module.act_learn
+
+ torch.cuda.synchronize()
+
+ if mixup_fn is None:
+ class_acc = (output.max(-1)[-1] == targets).float().mean()
+ else:
+ class_acc = None
+ metric_logger.update(loss=loss_value)
+ metric_logger.update(class_acc=class_acc)
+ min_lr = 10.
+ max_lr = 0.
+ for group in optimizer.param_groups:
+ min_lr = min(min_lr, group["lr"])
+ max_lr = max(max_lr, group["lr"])
+
+ metric_logger.update(lr=max_lr)
+ metric_logger.update(min_lr=min_lr)
+ weight_decay_value = None
+ for group in optimizer.param_groups:
+ if group["weight_decay"] > 0:
+ weight_decay_value = group["weight_decay"]
+ metric_logger.update(weight_decay=weight_decay_value)
+ if use_amp:
+ metric_logger.update(grad_norm=grad_norm)
+
+ if log_writer is not None:
+ log_writer.update(loss=loss_value, head="loss")
+ log_writer.update(class_acc=class_acc, head="loss")
+ log_writer.update(lr=max_lr, head="opt")
+ log_writer.update(min_lr=min_lr, head="opt")
+ log_writer.update(weight_decay=weight_decay_value, head="opt")
+ if use_amp:
+ log_writer.update(grad_norm=grad_norm, head="opt")
+ log_writer.set_step()
+
+ if wandb_logger:
+ wandb_logger._wandb.log({
+ 'Rank-0 Batch Wise/train_loss': loss_value,
+ 'Rank-0 Batch Wise/train_max_lr': max_lr,
+ 'Rank-0 Batch Wise/train_min_lr': min_lr
+ }, commit=False)
+ if class_acc:
+ wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False)
+ if use_amp:
+ wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
+ wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})
+
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+@torch.no_grad()
+def evaluate(data_loader, model, device, use_amp=False, real_labels=None):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Test:'
+
+ # switch to evaluation mode
+ model.eval()
+ for batch in metric_logger.log_every(data_loader, 200, header):
+ images = batch[0]
+ target = batch[-1]
+
+ images = images.to(device, non_blocking=True)
+ target = target.to(device, non_blocking=True)
+
+ # compute output
+ if use_amp:
+ with torch.cuda.amp.autocast():
+ output = model(images)
+ loss = criterion(output, target)
+ else:
+ output = model(images)
+ loss = criterion(output, target)
+
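+        # accumulate predictions for ImageNet-ReaL (reassessed labels) evaluation, if provided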
+ if real_labels is not None:
+ real_labels.add_result(output)
+
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+ batch_size = images.shape[0]
+ metric_logger.update(loss=loss.item())
+ metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+ metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print('* val Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+ .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+
+ if real_labels is not None:
+ # real labels mode replaces topk values at the end
+ acc1, acc5 = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
+ print('* REAL Acc@1 {:.3f} Acc@5 {:.3f}'.format(acc1, acc5))
+
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..9fbd657
--- /dev/null
+++ b/main.py
@@ -0,0 +1,559 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+import warnings
+warnings.filterwarnings('ignore')
+import datetime
+import numpy as np
+import time
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+import json
+import math
+
+from pathlib import Path
+
+from timm.data import create_dataset, create_loader, RealLabelsImagenet, Mixup
+from timm.models import create_model
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy
+from timm.utils import ModelEma
+from optim_factory import create_optimizer, LayerDecayValueAssigner
+
+from datasets import build_dataset
+from engine import train_one_epoch, evaluate
+
+from utils import NativeScalerWithGradNormCount as NativeScaler
+import utils
+import models.vanillanet
+
+
+def str2bool(v):
+ """
+ Converts string to bool type; enables command line
+ arguments in the format of '--arg1 true --arg2 false'
+ """
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+def get_args_parser():
+ parser = argparse.ArgumentParser('Vanillanet script', add_help=False)
+ parser.add_argument('--batch_size', default=64, type=int,
+ help='Per GPU batch size')
+ parser.add_argument('--epochs', default=300, type=int)
+ parser.add_argument('--early_stop_epochs', default=None, type=int)
+ parser.add_argument('--decay_epochs', default=100, type=int,
+ help='for deep training strategy')
+ parser.add_argument('--decay_linear', type=str2bool, default=True,
+ help='cos/linear for decay manner')
+ parser.add_argument('--update_freq', default=1, type=int,
+ help='gradient accumulation steps')
+
+ # Model parameters
+ parser.add_argument('--model', default='vanillanet_5', type=str, metavar='MODEL',
+ help='Name of model to train')
+ parser.add_argument('--switch_to_deploy', default=None, type=str)
+ parser.add_argument('--deploy', type=str2bool, default=False)
+ parser.add_argument('--drop', type=float, default=0, metavar='PCT',
+ help='Drop rate (default: 0.0)')
+ parser.add_argument('--input_size', default=224, type=int,
+ help='image input size')
+
+ # EMA related parameters
+ parser.add_argument('--model_ema', type=str2bool, default=False)
+    parser.add_argument('--model_ema_decay', type=float, default=[0.9999], nargs='+')
+ parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
+ parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.')
+
+ # Optimization parameters
+    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                        help='Optimizer (default: "adamw")')
+ parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
+ help='Optimizer Epsilon (default: 1e-8)')
+ parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
+ help='Optimizer Betas (default: None, use opt default)')
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+ help='Clip gradient norm (default: None, no clipping)')
+ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+ help='SGD momentum (default: 0.9)')
+ parser.add_argument('--weight_decay', type=float, default=0.05,
+ help='weight decay (default: 0.05)')
+ parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
+ weight decay. We use a cosine schedule for WD and using a larger decay by
+ the end of training improves performance for ViTs.""")
+
+ parser.add_argument('--lr', type=float, default=4e-3, metavar='LR',
+ help='learning rate (default: 4e-3), with total batch size 4096')
+ parser.add_argument('--layer_decay', type=float, default=1.0)
+ parser.add_argument('--layer_decay_num_layers', default=4, type=int)
+ parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
+ help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
+ parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
+ help='epochs to warmup LR, if scheduler supports')
+ parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
+ help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
+
+ # Augmentation parameters
+ parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
+ help='Color jitter factor (default: 0.4)')
+    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
+ parser.add_argument('--smoothing', type=float, default=0.1,
+ help='Label smoothing (default: 0.1)')
+ parser.add_argument('--train_interpolation', type=str, default='bicubic',
+ help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
+ parser.add_argument('--bce_loss', action='store_true', default=False,
+ help='Enable BCE loss w/ Mixup/CutMix use.')
+ parser.add_argument('--bce_target_thresh', type=float, default=None,
+ help='Threshold for binarizing softened BCE targets (default: None, disabled)')
+
+ # Evaluation parameters
+ parser.add_argument('--crop_pct', type=float, default=None)
+
+ # * Random Erase params
+ parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+ help='Random erase prob (default: 0.25)')
+ parser.add_argument('--remode', type=str, default='pixel',
+ help='Random erase mode (default: "pixel")')
+ parser.add_argument('--recount', type=int, default=1,
+ help='Random erase count (default: 1)')
+ parser.add_argument('--resplit', type=str2bool, default=False,
+ help='Do not random erase first (clean) augmentation split')
+
+ # * Mixup params
+ parser.add_argument('--mixup', type=float, default=0.8,
+ help='mixup alpha, mixup enabled if > 0.')
+ parser.add_argument('--cutmix', type=float, default=1.0,
+ help='cutmix alpha, cutmix enabled if > 0.')
+ parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
+ help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+ parser.add_argument('--mixup_prob', type=float, default=1.0,
+ help='Probability of performing mixup or cutmix when either/both is enabled')
+ parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
+ help='Probability of switching to cutmix when both mixup and cutmix enabled')
+ parser.add_argument('--mixup_mode', type=str, default='batch',
+ help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+ # * Finetuning params
+ parser.add_argument('--finetune', default='',
+ help='finetune from checkpoint')
+ parser.add_argument('--model_key', default='model|module', type=str,
+ help='which key to load from saved state dict, usually model or model_ema')
+ parser.add_argument('--model_prefix', default='', type=str)
+
+ # Dataset parameters
+ parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
+ help='dataset path')
+ parser.add_argument('--eval_data_path', default=None, type=str,
+ help='dataset path for evaluation')
+ parser.add_argument('--nb_classes', default=1000, type=int,
+ help='number of the classification types')
+ parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True)
+ parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'],
+ type=str, help='ImageNet dataset path')
+ parser.add_argument('--output_dir', default='',
+ help='path where to save, empty for no saving')
+ parser.add_argument('--log_dir', default=None,
+ help='path where to tensorboard log')
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--seed', default=0, type=int)
+
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
+ parser.add_argument('--resume_new_sched', action='store_true', help='resume with new schedule')
+ parser.set_defaults(resume_new_sched=False)
+ parser.add_argument('--auto_resume', type=str2bool, default=False)
+ parser.add_argument('--save_ckpt', type=str2bool, default=True)
+ parser.add_argument('--save_ckpt_freq', default=1, type=int)
+ parser.add_argument('--test_freq', default=20, type=int)
+ parser.add_argument('--test_epoch', default=260, type=int)
+ parser.add_argument('--save_ckpt_num', default=10, type=int)
+
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ parser.add_argument('--eval', type=str2bool, default=False,
+ help='Perform evaluation only')
+ parser.add_argument('--dist_eval', type=str2bool, default=True,
+ help='Enabling distributed evaluation')
+ parser.add_argument('--disable_eval', type=str2bool, default=False,
+ help='Disabling evaluation during training')
+ parser.add_argument('--num_workers', default=10, type=int)
+ parser.add_argument('--pin_mem', type=str2bool, default=True,
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+
+ # distributed training parameters
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_on_itp', type=str2bool, default=False)
+ parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+
+ parser.add_argument('--use_amp', type=str2bool, default=False,
+ help="Use PyTorch's AMP (Automatic Mixed Precision) or not")
+
+ # Weights and Biases arguments
+ parser.add_argument('--enable_wandb', type=str2bool, default=False,
+ help="enable logging to Weights and Biases")
+ parser.add_argument('--wandb_ckpt', type=str2bool, default=False,
+ help="Save model checkpoints as W&B Artifacts.")
+
+ parser.add_argument('--act_num', default=3, type=int)
+ parser.add_argument('--real_labels', default='', type=str, metavar='FILENAME',
+ help='Real labels JSON file for imagenet evaluation')
+
+ return parser
+
+def main(args):
+ utils.init_distributed_mode(args)
+ print(args)
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ cudnn.benchmark = True
+
+ dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
+ if args.disable_eval:
+ args.dist_eval = False
+ dataset_val = None
+ else:
+ dataset_val, _ = build_dataset(is_train=False, args=args)
+
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed,
+ )
+ print("Sampler_train = %s" % str(sampler_train))
+ if args.dist_eval:
+ if len(dataset_val) % num_tasks != 0:
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
+ 'equal num of samples per-process.')
+ sampler_val = torch.utils.data.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
+ else:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+ if global_rank == 0 and args.log_dir is not None:
+ os.makedirs(args.log_dir, exist_ok=True)
+ log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
+ else:
+ log_writer = None
+
+ if global_rank == 0 and args.enable_wandb:
+ wandb_logger = utils.WandbLogger(args)
+ else:
+ wandb_logger = None
+
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=True,
+ )
+
+ if dataset_val is not None:
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ batch_size=int(1.5 * args.batch_size),
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=False
+ )
+ else:
+ data_loader_val = None
+
+ mixup_fn = None
+ mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+ if mixup_active:
+ print("Mixup is activated!")
+ mixup_fn = Mixup(
+ mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+ prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+ label_smoothing=args.smoothing, num_classes=args.nb_classes)
+
+ model = create_model(
+ args.model,
+ pretrained=False,
+ num_classes=args.nb_classes,
+ act_num=args.act_num,
+ drop_rate=args.drop,
+ deploy=args.deploy,
+ )
+
+ if args.finetune:
+ if args.finetune.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.finetune, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(args.finetune, map_location='cpu')
+
+ print("Load ckpt from %s" % args.finetune)
+ checkpoint_model = None
+ for model_key in args.model_key.split('|'):
+ if model_key in checkpoint:
+ checkpoint_model = checkpoint[model_key]
+ print("Load state_dict by model_key = %s" % model_key)
+ break
+ if checkpoint_model is None:
+ checkpoint_model = checkpoint
+ state_dict = model.state_dict()
+ for k in ['head.weight', 'head.bias']:
+ if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
+ print(f"Removing key {k} from pretrained checkpoint")
+ del checkpoint_model[k]
+ utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
+
+ model.to(device)
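+    # --switch_to_deploy: fuse the training-time structure into the deploy-time network
+    # and save the converted weights to the given path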
+ if args.switch_to_deploy:
+ model.switch_to_deploy()
+ model_ckpt = dict()
+ model_ckpt['model'] = model.state_dict()
+ torch.save(model_ckpt, args.switch_to_deploy)
+
+ model_ema = None
+ if args.model_ema:
+ # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
+ model_ema = []
+ for ema_decay in args.model_ema_decay:
+ model_ema.append(
+ ModelEma(model, decay=ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume='')
+ )
+ print("Using EMA with decay = %s" % args.model_ema_decay)
+
+ model_without_ddp = model
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ print("Model = %s" % str(model_without_ddp))
+ print('number of params (M): %.2f' % (n_parameters / 1.e6))
+
+ if not args.eval:
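+        # profile per-image compute with torchprofile on a batch of 2 (hence the /2 below)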
+ input_size = [2, 3, args.input_size, args.input_size]
+ input = torch.randn(input_size).cuda()
+ from torchprofile import profile_macs
+ macs = profile_macs(model, input)
+ print('model flops (G):', macs / 2 / 1.e9, 'input_size:', input_size)
+
+ total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
+ num_training_steps_per_epoch = len(dataset_train) // total_batch_size
+ print("LR = %.8f" % args.lr)
+ print("Batch size = %d" % total_batch_size)
+ print("Update frequent = %d" % args.update_freq)
+ print("Number of training examples = %d" % len(dataset_train))
+ print("Number of training training per epoch = %d" % num_training_steps_per_epoch)
+
+ if args.layer_decay < 1.0 or args.layer_decay > 1.0:
+ num_layers = args.layer_decay_num_layers
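+        # per-layer lr scales: layer i gets layer_decay^(num_layers + 1 - i),
+        # so the classifier head (last group) ends up with layer_decay^0 = 1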
+ assigner = LayerDecayValueAssigner(num_max_layer=num_layers, values=list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
+ else:
+ assigner = None
+
+ if assigner is not None:
+ print("Assigned values = %s" % str(assigner.values))
+
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
+ model_without_ddp = model.module
+
+ optimizer = create_optimizer(
+ args, model_without_ddp, skip_list=None,
+ get_num_layer=assigner.get_layer_id if assigner is not None else None,
+ get_layer_scale=assigner.get_scale if assigner is not None else None)
+
+ loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used
+
+ print("Use Cosine LR scheduler")
+ lr_schedule_values = utils.cosine_scheduler(
+ args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
+ warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
+ )
+
+ if args.weight_decay_end is None:
+ args.weight_decay_end = args.weight_decay
+ wd_schedule_values = utils.cosine_scheduler(
+ args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
+ print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
+
+ if mixup_fn is not None:
+ if args.bce_loss:
+ criterion = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
+ else:
+ # smoothing is handled with mixup label transform
+ criterion = SoftTargetCrossEntropy()
+ elif args.smoothing > 0.:
+ criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+ else:
+ criterion = torch.nn.CrossEntropyLoss()
+
+ print("criterion = %s" % str(criterion))
+
+ utils.auto_load_model(
+ args=args, model=model, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
+
+ if args.eval:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ batch_size=int(1.5 * args.batch_size),
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=False
+ )
+ if args.real_labels:
+ dataset = create_dataset(root=args.data_path, name='', split='validation', class_map='')
+ real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
+ else:
+ real_labels = None
+ print(f"Eval only mode")
+ test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp, real_labels=real_labels)
+ print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%")
+ return
+
+ max_accuracy = 0.0
+ max_accuracy_epoch = 0
+ if args.model_ema and args.model_ema_eval:
+ max_accuracy_ema = 0.0
+ max_accuracy_ema_epoch = 0
+ best_ema_decay = args.model_ema_decay[0]
+
+ print("Start training for %d epochs" % args.epochs)
+ start_time = time.time()
+ for epoch in range(args.start_epoch, args.epochs):
+ if args.distributed:
+ data_loader_train.sampler.set_epoch(epoch)
+ if log_writer is not None:
+ log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
+ if wandb_logger:
+ wandb_logger.set_steps()
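+        # deep training strategy: ramp the activation weight act_learn from 0 to 1
+        # over the first decay_epochs, either linearly or with a cosine schedule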
+        if 'VanillaNet' == model.module.__class__.__name__ and epoch <= args.decay_epochs:
+ if args.decay_linear:
+ act_learn = epoch / args.decay_epochs * 1.0
+ else:
+ act_learn = 0.5 * (1 - math.cos(math.pi * epoch / args.decay_epochs)) * 1.0
+ print(f"VanillaNet decay_linear: {args.decay_linear}, act_learn weight: {act_learn:.3f}")
+ model.module.change_act(act_learn)
+ train_stats = train_one_epoch(
+ model, criterion, data_loader_train, optimizer,
+ device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
+ log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch,
+ lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
+ num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
+ use_amp=args.use_amp
+ )
+
+ if args.output_dir and args.save_ckpt:
+ if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
+ utils.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, epoch_name=str(epoch), model_ema=model_ema[0])
+
+ if (data_loader_val is not None) and (epoch > 0) and (epoch % args.test_freq == 0 or epoch > args.test_epoch):
+ test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
+ print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+ if max_accuracy < test_stats["acc1"]:
+ max_accuracy = test_stats["acc1"]
+ max_accuracy_epoch = epoch
+ if args.output_dir and args.save_ckpt:
+ utils.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, epoch_name="best", model_ema=model_ema[0])
+
+ if log_writer is not None:
+ log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
+ log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
+ log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)
+
+ # repeat testing routines for EMA, if ema eval is turned on
+ if args.model_ema and args.model_ema_eval:
+ for idx, iter_model_ema in enumerate(model_ema):
+ test_stats_ema = evaluate(data_loader_val, iter_model_ema.ema, device, use_amp=args.use_amp)
+ print(f"Accuracy of the {args.model_ema_decay[idx]} EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%")
+ if max_accuracy_ema < test_stats_ema["acc1"]:
+ max_accuracy_ema = test_stats_ema["acc1"]
+ max_accuracy_ema_epoch = epoch
+ best_ema_decay = args.model_ema_decay[idx]
+ if args.output_dir and args.save_ckpt:
+ utils.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, epoch_name="best-ema", model_ema=iter_model_ema)
+
+ print(f'Max Acc: {max_accuracy:.3f}% @{max_accuracy_epoch}, {best_ema_decay} EMA: {max_accuracy_ema:.3f}% @{max_accuracy_ema_epoch}')
+ if log_writer is not None:
+ log_writer.update(ema_test_acc1=test_stats_ema['acc1'], head="perf", step=epoch)
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'test_{k}': v for k, v in test_stats.items()},
+ **{f'ema_test_{k}': v for k, v in test_stats_ema.items()},
+ 'epoch': epoch, 'n_parameters': n_parameters}
+ else:
+ print(f'Max Acc.: {max_accuracy:.3f}% @{max_accuracy_epoch}')
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'test_{k}': v for k, v in test_stats.items()},
+ 'epoch': epoch, 'n_parameters': n_parameters}
+ else:
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters}
+
+ if args.output_dir and utils.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ if wandb_logger:
+ wandb_logger.log_epoch_metrics(log_stats)
+
+ if args.early_stop_epochs and epoch == args.early_stop_epochs:
+ break
+
+ if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir:
+ wandb_logger.log_checkpoints()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+ time.sleep(10)
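+    # after training, re-evaluate the best EMA checkpoint on the ImageNet-ReaL labels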
+ if args.real_labels and args.model_ema_eval:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ batch_size=int(1.5 * args.batch_size),
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=False
+ )
+ dataset = create_dataset(root=args.data_path, name='', split='validation', class_map='')
+ real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
+ print('Start eval on REAL.')
+ ckpt = torch.load(os.path.join(args.output_dir, 'checkpoint-best-ema.pth'), map_location='cpu')
+ msg = model_without_ddp.load_state_dict(ckpt['model_ema'])
+ print(msg)
+ test_stats = evaluate(data_loader_val, model_without_ddp, device, use_amp=args.use_amp, real_labels=real_labels)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser('Vanillanet script', parents=[get_args_parser()])
+ args = parser.parse_args()
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ main(args)
diff --git a/optim_factory.py b/optim_factory.py
new file mode 100644
index 0000000..99d0c9e
--- /dev/null
+++ b/optim_factory.py
@@ -0,0 +1,180 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from torch import optim as optim
+
+from timm.optim.adafactor import Adafactor
+from timm.optim.adahessian import Adahessian
+from timm.optim.adamp import AdamP
+from timm.optim.lookahead import Lookahead
+from timm.optim.nadam import Nadam
+from timm.optim.lamb import Lamb
+from timm.optim.nvnovograd import NvNovoGrad
+from timm.optim.radam import RAdam
+from timm.optim.rmsprop_tf import RMSpropTF
+from timm.optim.sgdp import SGDP
+
+import json
+
+try:
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
+ has_apex = True
+except ImportError:
+ has_apex = False
+
+
+def get_num_layer_for_vanillanet(num_max_layer, var_name):
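+    # map a parameter name to a layer index for layer-wise lr decay:
+    # stem -> 0, stages.<i> -> i + 1, classifier ("cls") -> num_max_layer + 1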
+ if var_name.startswith("stem"):
+ return 0
+ elif var_name.startswith("stages"):
+ stage_id = int(var_name.split('.')[1])
+ return stage_id + 1
+ elif var_name.startswith("cls"):
+ return num_max_layer + 1
+ else:
+ raise ValueError('Unknown layer name: ' + var_name)
+
+
+class LayerDecayValueAssigner(object):
+ def __init__(self, num_max_layer, values):
+ self.num_max_layer = num_max_layer
+ self.values = values
+
+ def get_scale(self, layer_id):
+ return self.values[layer_id]
+
+ def get_layer_id(self, var_name):
+ return get_num_layer_for_vanillanet(self.num_max_layer, var_name)
+
+
+def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
+ parameter_group_names = {}
+ parameter_group_vars = {}
+
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
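+        # 1-D params (biases, norm weights, activation alpha/beta) and anything in skip_list
+        # are grouped without weight decay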
+ if len(param.shape) == 1 or name.endswith(".bias") or name.endswith(".alpha") or name.endswith(".beta") or name in skip_list:
+ group_name = "no_decay"
+ this_weight_decay = 0.
+ else:
+ group_name = "decay"
+ this_weight_decay = weight_decay
+ if get_num_layer is not None:
+ layer_id = get_num_layer(name)
+ group_name = "layer_%d_%s" % (layer_id, group_name)
+ else:
+ layer_id = None
+
+ if group_name not in parameter_group_names:
+ if get_layer_scale is not None:
+ scale = get_layer_scale(layer_id)
+ else:
+ scale = 1.
+
+ parameter_group_names[group_name] = {
+ "weight_decay": this_weight_decay,
+ "params": [],
+ "lr_scale": scale
+ }
+ parameter_group_vars[group_name] = {
+ "weight_decay": this_weight_decay,
+ "params": [],
+ "lr_scale": scale
+ }
+
+ parameter_group_vars[group_name]["params"].append(param)
+ parameter_group_names[group_name]["params"].append(name)
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+ return list(parameter_group_vars.values())
+
+
+def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
+ opt_lower = args.opt.lower()
+ weight_decay = args.weight_decay
+ # if weight_decay and filter_bias_and_bn:
+ if filter_bias_and_bn:
+ skip = {}
+ if skip_list is not None:
+ skip = skip_list
+ elif hasattr(model, 'no_weight_decay'):
+ skip = model.no_weight_decay()
+ parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
+ weight_decay = 0.
+ else:
+ parameters = model.parameters()
+
+ if 'fused' in opt_lower:
+ assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
+
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
+ if hasattr(args, 'opt_eps') and args.opt_eps is not None:
+ opt_args['eps'] = args.opt_eps
+ if hasattr(args, 'opt_betas') and args.opt_betas is not None:
+ opt_args['betas'] = args.opt_betas
+
+ opt_split = opt_lower.split('_')
+ opt_lower = opt_split[-1]
+ if opt_lower == 'sgd' or opt_lower == 'nesterov':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'momentum':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+ elif opt_lower == 'adam':
+ optimizer = optim.Adam(parameters, **opt_args)
+ elif opt_lower == 'adamw':
+ optimizer = optim.AdamW(parameters, **opt_args)
+ elif opt_lower == 'nadam':
+ optimizer = Nadam(parameters, **opt_args)
+ elif opt_lower == 'radam':
+ optimizer = RAdam(parameters, **opt_args)
+ elif opt_lower == 'adamp':
+ optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
+ elif opt_lower == 'sgdp':
+ optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'adadelta':
+ optimizer = optim.Adadelta(parameters, **opt_args)
+ elif opt_lower == 'adafactor':
+ if not args.lr:
+ opt_args['lr'] = None
+ optimizer = Adafactor(parameters, **opt_args)
+ elif opt_lower == 'adahessian':
+ optimizer = Adahessian(parameters, **opt_args)
+ elif opt_lower == 'rmsprop':
+ optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
+ elif opt_lower == 'rmsproptf':
+ optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
+ elif opt_lower == 'lamb':
+ optimizer = Lamb(parameters, **opt_args)
+ elif opt_lower == 'nvnovograd':
+ optimizer = NvNovoGrad(parameters, **opt_args)
+ elif opt_lower == 'fusedsgd':
+ opt_args.pop('eps', None)
+ optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'fusedmomentum':
+ opt_args.pop('eps', None)
+ optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+ elif opt_lower == 'fusedadam':
+ optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
+ elif opt_lower == 'fusedadamw':
+ optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
+ elif opt_lower == 'fusedlamb':
+ optimizer = FusedLAMB(parameters, **opt_args)
+ elif opt_lower == 'fusednovograd':
+ opt_args.setdefault('betas', (0.95, 0.98))
+ optimizer = FusedNovoGrad(parameters, **opt_args)
+ else:
+        assert False, "Invalid optimizer"
+
+ if len(opt_split) > 1:
+ if opt_split[0] == 'lookahead':
+ optimizer = Lookahead(optimizer)
+
+ return optimizer
diff --git a/test_latency.py b/test_latency.py
new file mode 100644
index 0000000..f67c20a
--- /dev/null
+++ b/test_latency.py
@@ -0,0 +1,45 @@
+#Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.
+
+#This program is free software; you can redistribute it and/or modify it under the terms of the MIT License.
+
+#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MIT License for more details.
+
+import torch.nn as nn
+import torch
+import torchvision
+import time
+
+import models.vanillanet
+
+
+if __name__ == "__main__":
+ from timm.data import create_dataset, create_loader
+ dataset_val = create_dataset(name='', root='/data/imagenet/', split='validation', is_training=False, batch_size=1)
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+ size = 224
+ data_loader_val = create_loader(dataset_val, input_size=size, batch_size=1, is_training=False, use_prefetcher=False)
+
+    net = models.vanillanet.vanillanet_5().cuda()
+ net.eval()
+ print(net)
+ for img, target in data_loader_val:
+ img = img.cuda()
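+        # a few warm-up forward passes before timing 1000 iterations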
+ for i in range(5):
+ net(img)
+ torch.cuda.synchronize()
+ t = time.time()
+ with torch.no_grad():
+ for i in range(1000):
+ net(img)
+ torch.cuda.synchronize()
+ print((time.time() - t))
+
+ n_parameters = sum(p.numel() for p in net.parameters())
+ print('number of params (M): %.2f' % (n_parameters / 1.e6))
+
+ from torchprofile import profile_macs
+ macs = profile_macs(net, img)
+ print('model flops (G):', macs / 1.e9, 'input_size:', img.shape)
+
+ break
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..537de19
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,520 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import os
+import math
+import time
+from collections import defaultdict, deque
+import datetime
+import numpy as np
+from timm.utils import get_state_dict
+
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import Sampler
+from torch._six import inf
+
+from tensorboardX import SummaryWriter
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
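+
+def _example_metric_logger():
+    # Illustrative usage sketch, not called anywhere in this file: drive a dummy
+    # training loop through log_every() and record a fake loss and learning rate,
+    # which is how this logger is meant to be used by a train/eval epoch.
+    logger = MetricLogger(delimiter="  ")
+    for step in logger.log_every(range(50), print_freq=10, header='Demo:'):
+        logger.update(loss=1.0 / (step + 1), lr=1e-3)
+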
+
+class TensorboardLogger(object):
+ def __init__(self, log_dir):
+ self.writer = SummaryWriter(logdir=log_dir)
+ self.step = 0
+
+ def set_step(self, step=None):
+ if step is not None:
+ self.step = step
+ else:
+ self.step += 1
+
+ def update(self, head='scalar', step=None, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
+
+ def flush(self):
+ self.writer.flush()
+
+
+class WandbLogger(object):
+ def __init__(self, args):
+ self.args = args
+
+ try:
+ import wandb
+ self._wandb = wandb
+ except ImportError:
+            raise ImportError(
+                "To use the Weights and Biases Logger please install wandb. "
+                "Run `pip install wandb` to install it."
+            )
+
+ # Initialize a W&B run
+ if self._wandb.run is None:
+ self._wandb.init(
+ project=args.project,
+ config=args
+ )
+
+ def log_epoch_metrics(self, metrics, commit=True):
+ """
+ Log train/test metrics onto W&B.
+ """
+ # Log number of model parameters as W&B summary
+ self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
+ metrics.pop('n_parameters', None)
+
+ # Log current epoch
+ self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
+        metrics.pop('epoch', None)
+
+ for k, v in metrics.items():
+ if 'train' in k:
+ self._wandb.log({f'Global Train/{k}': v}, commit=False)
+ elif 'test' in k:
+ self._wandb.log({f'Global Test/{k}': v}, commit=False)
+
+ self._wandb.log({})
+
+ def log_checkpoints(self):
+ output_dir = self.args.output_dir
+ model_artifact = self._wandb.Artifact(
+ self._wandb.run.id + "_model", type="model"
+ )
+
+ model_artifact.add_dir(output_dir)
+ self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])
+
+ def set_steps(self):
+ # Set global training step
+ self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
+ # Set epoch-wise step
+ self._wandb.define_metric('Global Train/*', step_metric='epoch')
+ self._wandb.define_metric('Global Test/*', step_metric='epoch')
+
+
+def setup_for_distributed(is_master):
+ """
+    This function disables printing when not in the master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print('[{}] '.format(now.strftime("%H:%M:%S")), end='') # print with time stamp
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if args.dist_on_itp:
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+
+ os.environ['RANK'] = str(args.rank)
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}, gpu {}'.format(
+ args.rank, args.dist_url, args.gpu), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+    setup_for_distributed(args.rank in (0, 8, 16, 24))  # only these ranks keep printing
+
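+
+# Illustrative note, not part of the original file: init_distributed_mode() above
+# covers three launch styles.  The OMPI_* branch reads OpenMPI environment
+# variables when args.dist_on_itp is set; the RANK/WORLD_SIZE/LOCAL_RANK branch
+# matches launchers such as `torchrun --nproc_per_node=8 <script> ...`, which
+# export exactly those variables; and the SLURM_PROCID branch derives the rank
+# and GPU index from a SLURM allocation.  Anything else falls back to
+# single-process, non-distributed mode.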
+
+def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ def load(module, prefix=''):
+ local_metadata = {} if metadata is None else metadata.get(
+ prefix[:-1], {})
+ module._load_from_state_dict(
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+
+ load(model, prefix=prefix)
+
+ warn_missing_keys = []
+ ignore_missing_keys = []
+ for key in missing_keys:
+ keep_flag = True
+ for ignore_key in ignore_missing.split('|'):
+ if ignore_key in key:
+ keep_flag = False
+ break
+ if keep_flag:
+ warn_missing_keys.append(key)
+ else:
+ ignore_missing_keys.append(key)
+
+ missing_keys = warn_missing_keys
+
+ if len(missing_keys) > 0:
+ print("Weights of {} not initialized from pretrained model: {}".format(
+ model.__class__.__name__, missing_keys))
+ if len(unexpected_keys) > 0:
+ print("Weights from pretrained model not used in {}: {}".format(
+ model.__class__.__name__, unexpected_keys))
+ if len(ignore_missing_keys) > 0:
+ print("Ignored weights of {} not initialized from pretrained model: {}".format(
+ model.__class__.__name__, ignore_missing_keys))
+ if len(error_msgs) > 0:
+ print('\n'.join(error_msgs))
+
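+
+def _example_load_state_dict():
+    # Illustrative usage sketch, not called anywhere in this file: copy one
+    # module's weights into another while silently skipping any missing keys
+    # that match the ignore_missing pattern.  The tiny Linear modules are just
+    # stand-ins for a real backbone and checkpoint.
+    src = torch.nn.Linear(8, 8)
+    dst = torch.nn.Linear(8, 8)
+    load_state_dict(dst, src.state_dict(), ignore_missing="relative_position_index")
+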
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self):
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
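+
+def _example_native_scaler():
+    # Illustrative AMP training-step sketch, not called anywhere in this file:
+    # one forward/backward pass through a placeholder model using the scaler
+    # above, with gradient clipping at 1.0.  The model, optimizer and random
+    # data are assumptions purely for demonstration.
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model = torch.nn.Linear(16, 1).to(device)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    loss_scaler = NativeScalerWithGradNormCount()
+    x = torch.randn(4, 16, device=device)
+    y = torch.randn(4, 1, device=device)
+    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
+        loss = torch.nn.functional.mse_loss(model(x), y)
+    optimizer.zero_grad()
+    grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
+    print('grad norm:', float(grad_norm))
+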
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+ return total_norm
+
+
+def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
+ start_warmup_value=0, warmup_steps=-1):
+ warmup_schedule = np.array([])
+ warmup_iters = warmup_epochs * niter_per_ep
+ if warmup_steps > 0:
+ warmup_iters = warmup_steps
+ print("Set warmup steps = %d" % warmup_iters)
+    if warmup_iters > 0:  # covers warm-up specified via either warmup_epochs or warmup_steps
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
+ schedule = np.array(
+ [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
+
+ schedule = np.concatenate((warmup_schedule, schedule))
+
+ assert len(schedule) == epochs * niter_per_ep
+ return schedule
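+
+def _example_cosine_scheduler():
+    # Illustrative usage sketch, not called anywhere in this file: build a
+    # per-iteration learning-rate schedule with a 5-epoch linear warm-up followed
+    # by cosine decay.  The epoch and iteration counts are made-up numbers.
+    lr_schedule = cosine_scheduler(base_value=1e-3, final_value=1e-6,
+                                   epochs=100, niter_per_ep=500, warmup_epochs=5)
+    print(len(lr_schedule), lr_schedule[0], lr_schedule.max(), lr_schedule[-1])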
+
+def save_model(args, epoch, epoch_name, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
+ output_dir = Path(args.output_dir)
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
+ for checkpoint_path in checkpoint_paths:
+ to_save = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'epoch': epoch,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }
+
+ if model_ema is not None:
+ to_save['model_ema'] = get_state_dict(model_ema)
+
+ save_on_master(to_save, checkpoint_path)
+
+ if is_main_process() and isinstance(epoch, int):
+ to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
+ old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
+ if os.path.exists(old_ckpt):
+ os.remove(old_ckpt)
+
+
+def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
+ output_dir = Path(args.output_dir)
+ if args.auto_resume and len(args.resume) == 0:
+ import glob
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
+ latest_ckpt = -1
+ for ckpt in all_checkpoints:
+ t = ckpt.split('-')[-1].split('.')[0]
+ if t.isdigit():
+ latest_ckpt = max(int(t), latest_ckpt)
+ if latest_ckpt >= 0:
+ args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
+ print("Auto resume checkpoint: %s" % args.resume)
+
+ if args.resume:
+ if args.resume.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.resume, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ msg = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
+ print(msg)
+ print("Resume checkpoint %s" % args.resume)
+
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint and not args.resume_new_sched:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema'
+ args.start_epoch = checkpoint['epoch'] + 1
+ else:
+ assert args.eval, 'Does not support resuming with checkpoint-best'
+ if hasattr(args, 'model_ema') and args.model_ema:
+ if 'model_ema' in checkpoint.keys():
+ model_ema.ema.load_state_dict(checkpoint['model_ema'])
+ else:
+ model_ema.ema.load_state_dict(checkpoint['model'])
+ if 'scaler' in checkpoint:
+ loss_scaler.load_state_dict(checkpoint['scaler'])
+ print("With optim & sched!")
+ elif hasattr(args, 'model_ema') and args.model_ema and 'model_ema' in checkpoint.keys():
+ if isinstance(model_ema, list):
+ for tmp_model_ema in model_ema:
+ msg = tmp_model_ema.ema.load_state_dict(checkpoint['model_ema'], strict=False)
+ print(msg)
+ print("Resume EMA checkpoint %s" % args.resume)
+ else:
+ msg = model_ema.ema.load_state_dict(checkpoint['model_ema'])
+ print(msg)
+ print("Resume EMA checkpoint %s" % args.resume)