From 7b05c9f9dc0ad0be6f74afe57501f3f17026b032 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Thu, 12 Oct 2023 22:07:02 +0000 Subject: [PATCH] feat: Add CNN class for MNIST dataset --- src/cnn.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/cnn.py diff --git a/src/cnn.py b/src/cnn.py new file mode 100644 index 0000000..2062dfd --- /dev/null +++ b/src/cnn.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class CNN(nn.Module): + """ + Convolutional Neural Network (CNN) class. + """ + def __init__(self): + """ + Initialize the layers of the CNN. + """ + super(CNN, self).__init__() + self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.fc1 = nn.Linear(32 * 14 * 14, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + """ + Implement the forward pass of the CNN. + """ + x = F.relu(self.conv1(x)) + x = self.pool(x) + x = x.view(-1, 32 * 14 * 14) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1)