Skip to content

Commit

Permalink
Update vanillanet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
HantingChen authored May 23, 2023
1 parent 7cfaa57 commit 1aafbb7
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions models/vanillanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def __init__(self, in_chans=3, num_classes=1000, dims=[96, 192, 384, 768],
if self.deploy:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
nn.Conv2d(dims[0], dims[0], kernel_size=1, stride=1),
activation(dims[0], act_num)
)
else:
Expand Down Expand Up @@ -214,9 +213,9 @@ def switch_to_deploy(self):
self.stem1[0].weight.data = kernel
self.stem1[0].bias.data = bias
kernel, bias = self._fuse_bn_tensor(self.stem2[0], self.stem2[1])
self.stem2[0].weight.data = kernel
self.stem2[0].bias.data = bias
self.stem = torch.nn.Sequential(*[self.stem1[0], self.stem2[0], self.stem2[2]])
self.stem1[0].weight.data = torch.einsum('oi,icjk->ocjk', kernel.squeeze(3).squeeze(2), self.stem1[0].weight.data)
self.stem1[0].bias.data = bias + (self.stem1[0].bias.data.view(1,-1,1,1)*kernel).sum(3).sum(2).sum(1)
self.stem = torch.nn.Sequential(*[self.stem1[0], self.stem2[2]])
self.__delattr__('stem1')
self.__delattr__('stem2')

Expand Down

0 comments on commit 1aafbb7

Please sign in to comment.