diff --git a/models/vanillanet.py b/models/vanillanet.py index e9e22d5..35fbd45 100644 --- a/models/vanillanet.py +++ b/models/vanillanet.py @@ -123,7 +123,6 @@ def __init__(self, in_chans=3, num_classes=1000, dims=[96, 192, 384, 768], if self.deploy: self.stem = nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), - nn.Conv2d(dims[0], dims[0], kernel_size=1, stride=1), activation(dims[0], act_num) ) else: @@ -214,9 +213,9 @@ def switch_to_deploy(self): self.stem1[0].weight.data = kernel self.stem1[0].bias.data = bias kernel, bias = self._fuse_bn_tensor(self.stem2[0], self.stem2[1]) - self.stem2[0].weight.data = kernel - self.stem2[0].bias.data = bias - self.stem = torch.nn.Sequential(*[self.stem1[0], self.stem2[0], self.stem2[2]]) + self.stem1[0].weight.data = torch.einsum('oi,icjk->ocjk', kernel.squeeze(3).squeeze(2), self.stem1[0].weight.data) + self.stem1[0].bias.data = bias + (self.stem1[0].bias.data.view(1,-1,1,1)*kernel).sum(3).sum(2).sum(1) + self.stem = torch.nn.Sequential(*[self.stem1[0], self.stem2[2]]) self.__delattr__('stem1') self.__delattr__('stem2')