-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathmodels.py
65 lines (53 loc) · 2.14 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of padding.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution (defaults to 1).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
class convnet(nn.Module):
    """Small four-convolution CNN returning a feature map and class scores.

    The input must have 3 channels (``bn0`` and ``conv1`` are built for 3).
    The spatial-size comments in ``forward`` refer to a 28x28 input.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes=10):
        super(convnet, self).__init__()
        self.bn0 = nn.BatchNorm2d(3)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        # Adaptive pooling always reduces the final map to 1x1, so ``fc``
        # receives 64 features for any input size, not only when the map is
        # exactly 7x7 (i.e. a 28x28 input).
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Run the network.

        Args:
            x: Batch of 3-channel images.

        Returns:
            A tuple ``(feat_out, y_low)``: the 32-channel activation after
            ``conv2`` + ReLU, and the ``fc`` output of shape
            ``(batch, num_classes)``.
        """
        x = self.bn0(x)
        x = self.conv1(x)
        x = self.relu(x)     # 28x28 for a 28x28 input
        x = self.maxpool(x)  # 14x14
        x = self.conv2(x)
        x = self.relu(x)     # 14x14
        # Safe to expose despite the in-place ReLU: conv3 allocates a new
        # tensor, so later in-place activations never touch feat_out.
        feat_out = x
        x = self.conv3(x)
        x = self.relu(x)     # 7x7
        x = self.conv4(x)
        x = self.relu(x)     # 7x7
        feat_low = self.avgpool(x)
        # Flatten (batch, 64, 1, 1) to (batch, 64) for the linear layer.
        feat_low = feat_low.view(feat_low.size(0), -1)
        y_low = self.fc(feat_low)
        return feat_out, y_low
class Predictor(nn.Module):
    """Two-convolution head producing per-location class scores.

    Args:
        input_ch: Channel count of the incoming feature map; also the width
            of the hidden convolution.
        num_classes: Number of output channels of the final convolution.
    """

    def __init__(self, input_ch=32, num_classes=8):
        super(Predictor, self).__init__()
        # Both convolutions are 3x3 / stride 1 / padding 1, so the spatial
        # size of the input is preserved.
        same_size_conv = dict(kernel_size=3, stride=1, padding=1)
        self.pred_conv1 = nn.Conv2d(input_ch, input_ch, **same_size_conv)
        self.pred_bn1 = nn.BatchNorm2d(input_ch)
        self.relu = nn.ReLU(inplace=True)
        self.pred_conv2 = nn.Conv2d(input_ch, num_classes, **same_size_conv)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``(scores, probs)``.

        ``scores`` is the raw output of ``pred_conv2``; ``probs`` is its
        softmax taken over dimension 1.
        """
        hidden = self.relu(self.pred_bn1(self.pred_conv1(x)))
        scores = self.pred_conv2(hidden)
        return scores, self.softmax(scores)