diff --git a/src/cnn.py b/src/cnn.py new file mode 100644 index 0000000..2062dfd --- /dev/null +++ b/src/cnn.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class CNN(nn.Module): + """ + Convolutional Neural Network (CNN) class. + """ + def __init__(self): + """ + Initialize the layers of the CNN. + """ + super(CNN, self).__init__() + self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.fc1 = nn.Linear(32 * 14 * 14, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + """ + Implement the forward pass of the CNN. + """ + x = F.relu(self.conv1(x)) + x = self.pool(x) + x = x.view(-1, 32 * 14 * 14) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1)