diff --git a/models/vanillanet.py b/models/vanillanet.py index c6417a3..e04c10f 100644 --- a/models/vanillanet.py +++ b/models/vanillanet.py @@ -10,7 +10,7 @@ from timm.models.layers import weight_init, DropPath from timm.models.registry import register_model - +# Series informed activation function. Implemented by conv. class activation(nn.ReLU): def __init__(self, dim, act_num=3, deploy=False): super(activation, self).__init__() @@ -84,7 +84,10 @@ def forward(self, x): x = self.conv(x) else: x = self.conv1(x) + + # We use leakyrelu to implement the deep training technique. x = torch.nn.functional.leaky_relu(x,self.act_learn) + x = self.conv2(x) x = self.pool(x)