-
Notifications
You must be signed in to change notification settings - Fork 0
/
complexity.py
48 lines (35 loc) · 1.12 KB
/
complexity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import logging
import os
from model import DCRNet
from omegaconf import DictConfig, OmegaConf
import hydra
import thop
# Module-level logger for this script; main() emits the complexity report
# (expansion/reduction settings, FLOPs, params) through it. Handlers are
# presumably installed by Hydra's job logging when run via @hydra.main — confirm.
log: logging.Logger = logging.getLogger(__name__)
def handle_config(cfg):
    """Echo the composed configuration to stdout and return it unchanged.

    Args:
        cfg: The configuration object Hydra passes to ``main``; it is rendered
            with ``OmegaConf.to_yaml``.

    Returns:
        The very same ``cfg`` object (no copy, no modification).
    """
    rendered_yaml = OmegaConf.to_yaml(cfg)
    print(rendered_yaml)
    return cfg
def _select_device(gpu):
    """Return the torch device to profile on, restricting CUDA to one GPU.

    Args:
        gpu: ``None`` to run on CPU; otherwise the integer index of the
            physical GPU to use.

    Returns:
        ``torch.device("cpu")`` when ``gpu`` is ``None``; otherwise a CUDA
        device that refers to physical GPU ``gpu``.
    """
    if gpu is None:
        return torch.device("cpu")

    # CUDA_VISIBLE_DEVICES is only honoured if CUDA has not been initialised
    # yet in this process; check before setting it.
    cuda_already_initialized = torch.cuda.is_initialized()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
    # When the mask takes effect, the chosen GPU is the only visible one and is
    # renumbered to ordinal 0. Addressing it as ``gpu`` (the previous behaviour)
    # fails with "invalid device ordinal" for any gpu != 0.
    ordinal = gpu if cuda_already_initialized else 0
    device = torch.device("cuda", ordinal)
    torch.backends.cudnn.benchmark = True
    return device


@hydra.main(version_base=None, config_path="config", config_name="config")
def main(cfg):
    """Profile DCRNet's compute cost and parameter count for one configuration.

    Hydra composes ``cfg`` from the ``config`` config in the ``config``
    directory (plus any command-line overrides). Fields read here:
    ``cfg.gpu`` (``None`` for CPU, else a GPU index), ``cfg.db.shape``,
    ``cfg.reduction`` and ``cfg.expansion``.

    The results are written through ``log``; nothing is returned.
    """
    cfg = handle_config(cfg)

    device = _select_device(cfg.gpu)
    print("Using Device: {}".format(device))

    model = DCRNet(
        cfg.db.shape,
        reduction=cfg.reduction,
        expansion=cfg.expansion,
    ).to(device)

    # NOTE(review): the dummy input is fixed at (1, 2, 32, 32) rather than
    # derived from cfg.db.shape — confirm the two agree for every db config.
    input_ = torch.randn([1, 2, 32, 32]).to(device)
    # NOTE: thop's README documents profile()'s first return value as MACs
    # (multiply-accumulates); it is reported under the "FLOPs" label unchanged.
    flops, params = thop.profile(model, inputs=(input_,), verbose=False)
    flops, params = thop.clever_format([flops, params], "%.3f")

    log.info(f"Expansions: {cfg.expansion}, Reductions: {cfg.reduction}")
    log.info("FLOPs: {}".format(flops))
    log.info("Params: {}".format(params))


if __name__ == "__main__":
    # @hydra.main supplies cfg, so main is called without arguments.
    main()