Merge pull request #10 from Sanster/support_cat
Support cat and YOLOX
Sanster authored Nov 8, 2021
2 parents e3fd71d + 6ed095d commit 5442a50
Showing 11 changed files with 865 additions and 46 deletions.
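This commit teaches the pruner about torch.cat: gen_pruning_schema (src/pns/tracker.py) now emits a "cats" section, and SlimPruner (src/pns/pns.py) uses it to prune the convolutions that consume a concatenated tensor. A minimal sketch of such a schema entry follows; the layer names are hypothetical, only the keys come from this commit:

# Sketch of the pruning schema with the new "cats" section.
# "branch1.bn", "branch2.bn" and "head.conv" are hypothetical layer names.
schema = {
    "modules": [],                      # per Conv2d/Linear: {"name", "prev_bn", "next_bn"}
    "shortcuts": [],                    # BN groups forced to keep the same channels
    "cats": [
        {
            # BNs whose outputs are concatenated along the channel dim, in cat order
            "input_bn_names": ["branch1.bn", "branch2.bn"],
            # Conv2d layers that take the concatenated tensor as input
            "output_conv_names": ["head.conv"],
        }
    ],
    "depthwise_conv_adjacent_bn": [],
}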
1 change: 1 addition & 0 deletions .gitignore
@@ -6,3 +6,4 @@ datasets
output
*/*.egg-info
venv/
test.json
16 changes: 2 additions & 14 deletions gen_schema.py
@@ -2,6 +2,7 @@
import json

import torch
from pns.schema_post_process import mbv3_large_schema_post_process

from backbone.build import build_model
from pns import SlimPruner
@@ -47,20 +48,7 @@ def parse_args():
config["shortcuts"] = shortcuts

if "mobilenet_v3_large" in args.net and "nose" not in args.net:
# BN in block with se module should have same channels
se_shortcuts = []
modules = []
for m in config["modules"]:
if m["name"].endswith("fc2"):
# e.g: features.5.block
feature_name = ".".join(m["name"].split(".")[:-2])
m["next_bn"] = f"{feature_name}.1.1"
elif m["name"].endswith("fc1"):
m["next_bn"] = ""

modules.append(m)
# config["shortcuts"].extend(se_shortcuts)
config["modules"] = modules
mbv3_large_schema_post_process(config)

with open(args.save_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
109 changes: 93 additions & 16 deletions src/pns/pns.py
@@ -4,6 +4,7 @@
from functools import reduce
from itertools import chain
from typing import Dict, List
from loguru import logger

import pandas as pd
from torch.nn import Conv2d, BatchNorm2d, Linear
@@ -47,12 +48,16 @@ def prune(self, prev_bn=None, next_bn=None):
if next_bn is not None:
assert next_bn.is_pruned

self.in_channels_keep_idxes, self.out_channels_keep_idxes = prune_conv2d(
self.pruned_module,
self.prune_by_idxes(
prev_bn.keep_idxes if prev_bn else None,
next_bn.keep_idxes if next_bn else None,
)

def prune_by_idxes(self, prev_bn_keep_idxes=None, next_bn_keep_idxes=None):
self.in_channels_keep_idxes, self.out_channels_keep_idxes = prune_conv2d(
self.pruned_module, prev_bn_keep_idxes, next_bn_keep_idxes
)

self.is_pruned = True

def prune_info(self):
@@ -207,6 +212,7 @@ def __init__(self, model, schema: str = None):
self.conv2d_modules = {}
self.bn2d_modules = {}
self.fc_modules = {}
self.cats = schema.get("cats", [])
self.shortcuts = schema.get("shortcuts", [])
self.depthwise_conv_adjacent_bn = schema.get(
"depthwise_conv_adjacent_bn", []
@@ -251,6 +257,18 @@ def _add_prefix_to_config_name(self, config):
"""
for it in config.get("shortcuts", []):
it["names"] = [prefix + _ for _ in it["names"]]

"""
"cats": [
{
"names": [ "bn1", "bn2"],
}
]
"""
for it in config.get("cats", []):
it["input_bn_names"] = [prefix + _ for _ in it["input_bn_names"]]
it["output_conv_names"] = [prefix + _ for _ in it["output_conv_names"]]

for it in config.get("depthwise_conv_adjacent_bn", []):
it["names"] = [prefix + _ for _ in it["names"]]

@@ -302,17 +320,33 @@ def run(self, ratio: float):
print("\nBatchNorm2d prune info")
print(df.to_markdown() + "\n")

convs_after_cat = self._collect_conv_after_cat()
conv2d_prune_info = []
# fmt: off
for conv2d in self.conv2d_modules.values():
conv2d.prune(
self.bn2d_modules[conv2d.prev_bn_name] if conv2d.prev_bn_name else None,
self.bn2d_modules[conv2d.next_bn_name] if conv2d.next_bn_name else None,
)
if conv2d.name in convs_after_cat:
conv2d.prune(
None,
self.bn2d_modules[conv2d.next_bn_name] if conv2d.next_bn_name else None,
)
else:
conv2d.prune(
self.bn2d_modules[conv2d.prev_bn_name] if conv2d.prev_bn_name else None,
self.bn2d_modules[conv2d.next_bn_name] if conv2d.next_bn_name else None,
)
conv2d_prune_info.append(conv2d.prune_info())
# fmt: on

cat_conv_prune_info = self._prune_cat_conv_prev_bn()

df = pd.DataFrame(conv2d_prune_info)
print("\nConv2d prune info")
print(df.to_markdown() + "\n")

df = pd.DataFrame(cat_conv_prune_info)
print("\nConv2d after cat prune info")
print(df.to_markdown() + "\n")

fc_prune_info = []
for linear in self.fc_modules.values():
if not linear.prev_bn_name:
@@ -357,16 +391,6 @@ def _export_pruning_result(self):
# return {self.PRUNING_RESULT_KEY: prune_result}
return prune_result

def _merge_depthwise_conv2d_adjacent_bn(self):
self._align_bns(
self.depthwise_conv_adjacent_bn,
min_keep_ratio=0.05,
log_name="depthwise conv bn",
)

def _merge_shortcuts(self):
self._align_bns(self.shortcuts, min_keep_ratio=0.05, log_name="shortcuts")

def _align_bns(self, bn_groups, min_keep_ratio: float, log_name: str):
"""
bn layer is changed inplace
@@ -430,3 +454,56 @@ def _apply_fix_bn_ratio(self):
for _name in name:
if _name in self.bn2d_modules:
self.bn2d_modules[_name].set_fixed_ratio(1 - ratio)

def _collect_conv_after_cat(self) -> List[str]:
if not self.cats:
return []

res = []
for cat_group in self.cats:
output_conv_names = cat_group["output_conv_names"]
for conv_name in output_conv_names:
if conv_name not in self.conv2d_modules:
raise RuntimeError(
f"{conv_name} not exist in {self.conv2d_modules.keys()}"
)
res.extend(output_conv_names)
return res

def _prune_cat_conv_prev_bn(self) -> List[Dict]:
if not self.cats:
return []

conv2d_prune_info = []

# 1. Collect the keep indexes of the input bn2d_modules and concatenate them with channel offsets
# 2. Prune the input kernels of the convolutions that follow the cat
for cat_group in self.cats:
input_bn_names = cat_group["input_bn_names"]
output_conv_names = cat_group["output_conv_names"]

final_bn_keep_idxes = []
idx_offset = 0
for bn_name in input_bn_names:
bn_module = self.bn2d_modules.get(bn_name, None)
if bn_module is None:
raise RuntimeError(
f"{bn_name} not exist in {self.bn2d_modules.keys()}"
)

keep_idxes_offset = [it + idx_offset for it in bn_module.keep_idxes]
final_bn_keep_idxes.extend(keep_idxes_offset)
idx_offset += bn_module.in_channels()

for conv_name in output_conv_names:
if conv_name not in self.conv2d_modules:
raise RuntimeError(
f"{conv_name} not exist in {self.conv2d_modules.keys()}"
)

self.conv2d_modules[conv_name].prune_by_idxes(
prev_bn_keep_idxes=final_bn_keep_idxes
)
conv2d_prune_info.append(self.conv2d_modules[conv_name].prune_info())

return conv2d_prune_info
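The index offsetting in _prune_cat_conv_prev_bn above can be summarized with a standalone sketch: each input BN's keep indexes are shifted by the total channel count of the BNs concatenated before it, so they address channels of the concatenated tensor that feeds the following conv. The channel counts below are assumed for illustration:

# The pruner kept channels [0, 2] of the first input BN and [1, 3] of the second;
# both BNs are assumed to have 4 channels (illustrative numbers, not from this repo).
bn_groups = [([0, 2], 4), ([1, 3], 4)]  # (keep_idxes, in_channels) per input BN, in cat order

final_bn_keep_idxes = []
idx_offset = 0
for keep_idxes, in_channels in bn_groups:
    # Shift this BN's indexes by the channels contributed by the BNs concatenated before it
    final_bn_keep_idxes.extend(i + idx_offset for i in keep_idxes)
    idx_offset += in_channels

print(final_bn_keep_idxes)  # [0, 2, 5, 7] -> input channels to keep in the conv after the cat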
14 changes: 14 additions & 0 deletions src/pns/schema_post_process.py
@@ -0,0 +1,14 @@
def mbv3_large_schema_post_process(config):
# BN in block with se module should have same channels
se_shortcuts = []
modules = []
for m in config["modules"]:
if m["name"].endswith("fc2"):
# e.g: features.5.block
feature_name = ".".join(m["name"].split(".")[:-2])
m["next_bn"] = f"{feature_name}.1.1"
elif m["name"].endswith("fc1"):
m["next_bn"] = ""

modules.append(m)
config["modules"] = modules
80 changes: 76 additions & 4 deletions src/pns/tracker.py
@@ -1,3 +1,4 @@
import copy
import typing
from collections import defaultdict
from typing import Callable, Dict, List
@@ -122,6 +123,7 @@ def __init__(
self.module_output_names = defaultdict(list)

self.shortcuts_group = []
self.cat_group = []

def __enter__(self):
for hook in self.hooks:
@@ -141,6 +143,9 @@ def __init__(self, name, module):
def is_bn(self):
return isinstance(self.module, torch.nn.BatchNorm2d)

def is_cat(self):
pass

def is_conv(self):
return isinstance(self.module, torch.nn.Conv2d)

@@ -150,7 +155,7 @@ def is_fc(self):

def BFS_find_bn(
module_names: Dict[str, List], wrappers: Dict[str, ModuleWrapper], name: str
):
) -> str:
output_names = module_names[name]
while len(output_names) != 0:
output_name = output_names.pop(0)
@@ -308,6 +313,36 @@ def track_add(ctx: TrackContext):
set_outputs_name_attr(outputs, input_names)


@register_tracker("torch.cat")
def track_cat(ctx: TrackContext):
inputs = ctx.method_args
outputs = ctx.method_return
if len(inputs) == 0:
return

input_names = []
for input in inputs[0]:
name = getattr(input, TRACK_ATTR_NAME, None)
if name is None:
continue

if len(name) == 0:
continue

if isinstance(name, list):
# TODO: why 0??
# Works for YOLOX
input_names.append(name[0])
if tuple(name) in ctx.cat_group:
ctx.cat_group.remove(tuple(name))
else:
input_names.append(name)

if len(input_names) > 1:
ctx.cat_group.append(tuple(input_names))
set_outputs_name_attr(outputs, input_names)


def gen_pruning_schema(model, *args, **kwargs):
with TrackContext() as ctx:
bn_names = []
@@ -334,13 +369,22 @@ def gen_pruning_schema(model, *args, **kwargs):
if name in common_names:
target_wrappers[name] = ModuleWrapper(name, module)

info = {"modules": [], "shortcuts": [], "depthwise_conv_adjacent_bn": []}
info = {
"modules": [],
"shortcuts": [],
"cats": [],
"depthwise_conv_adjacent_bn": [],
}

module_input_names = copy.deepcopy(ctx.module_input_names)
module_output_names = copy.deepcopy(ctx.module_output_names)
for name, wrapper in target_wrappers.items():
if not (wrapper.is_conv() or wrapper.is_fc()):
continue

prev_bn = BFS_find_bn(ctx.module_input_names, target_wrappers, name)
next_bn = BFS_find_bn(ctx.module_output_names, target_wrappers, name)
# module_input_names will be consumed
prev_bn = BFS_find_bn(module_input_names, target_wrappers, name)
next_bn = BFS_find_bn(module_output_names, target_wrappers, name)
m = {"name": name, "prev_bn": prev_bn, "next_bn": next_bn}

info["modules"].append(m)
@@ -357,4 +401,32 @@ def gen_pruning_schema(model, *args, **kwargs):
continue
info["shortcuts"].append({"names": sorted(shortcuts), "method": "or"})

for cat in ctx.cat_group:
cat = list(filter(lambda it: it in bn_names, cat))
if len(cat) <= 1:
continue
"""
ctx.module_input_names:
{
"conv": ["bn1", "bn2"]
}
cat: ['bn1', 'bn2']
If the value for a conv in module_input_names equals cat, that conv takes the output of the cat as its input
"""
# cat must not be sorted, because the order determines the channel indexes
cats = {"input_bn_names": cat, "output_conv_names": []}
for name, values in ctx.module_input_names.items():
if name not in target_wrappers:
continue
if not target_wrappers[name].is_conv():
continue

if values == cat:
cats["output_conv_names"].append(name)

if len(cats["output_conv_names"]) != 0:
info["cats"].append(cats)

return info
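A minimal sketch of how the new torch.cat tracking surfaces in a generated schema: a toy module with two BN branches concatenated before a 1x1 conv should yield a cats entry pairing the branch BNs with that conv. The model below is illustrative, not from this repo; the call signature follows gen_pruning_schema(model, x) as used in tests/helper.py.

import torch
from torch import nn
from pns.tracker import gen_pruning_schema

class TwoBranchCat(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(3, 8, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(8)
        self.head = nn.Conv2d(16, 4, 1)  # consumes the concatenated 16 channels

    def forward(self, x):
        a = self.bn1(self.conv1(x))
        b = self.bn2(self.conv2(x))
        return self.head(torch.cat([a, b], dim=1))

config = gen_pruning_schema(TwoBranchCat(), torch.rand(1, 3, 32, 32))
# Expected (roughly): config["cats"] == [{"input_bn_names": ["bn1", "bn2"],
#                                         "output_conv_names": ["head"]}]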
2 changes: 1 addition & 1 deletion src/pns/version.py
@@ -1 +1 @@
__version__ = '0.1.0'
__version__ = '0.2.0'
23 changes: 18 additions & 5 deletions tests/helper.py
@@ -3,18 +3,31 @@

import torch
from pns import SlimPruner
from pns.schema_post_process import mbv3_large_schema_post_process
from pns.tracker import gen_pruning_schema
from torch import Tensor


def check_gen_schema(model):
x = torch.Tensor(1, 3, 224, 224)
def check_gen_schema(model, net: str = "", in_channel: int = 3):
x = torch.Tensor(1, in_channel, 224, 224)
y = model(x)
if isinstance(y, Tensor):
print(f"model ouput shape: {y.shape}")
elif isinstance(y, dict):
for k, v in y.items():
print(f"{k} {v.shape}")

config = gen_pruning_schema(model, x)

with open("test.json", 'w') as f:
if "mobilenet_v3_large" in net and "nose" not in net:
mbv3_large_schema_post_process(config)

with tempfile.NamedTemporaryFile("w") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
f.flush()
pruner = SlimPruner(model, f.name)
pruner.run(0.6)
pruner.pruned_model.eval()
x = torch.Tensor(1, 3, 224, 224)
pruner.pruned_model(x)
x = torch.Tensor(1, in_channel, 224, 224)
y = pruner.pruned_model(x)
print(y)
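The reworked helper now takes the backbone name and the input channel count; a hedged usage sketch follows (the torchvision model and the tests.helper import path are assumptions for illustration, not verified against this repo's test setup):

# net="mobilenet_v3_large" triggers mbv3_large_schema_post_process inside the helper
import torchvision
from tests.helper import check_gen_schema  # assumes tests/ is importable from the repo root

model = torchvision.models.mobilenet_v3_large()
check_gen_schema(model, net="mobilenet_v3_large", in_channel=3)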