diff --git a/src/cnn.py b/src/cnn.py new file mode 100644 index 0000000..a9b6209 --- /dev/null +++ b/src/cnn.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class CNN(nn.Module): + """ + Convolutional Neural Network class for handling MNIST dataset. + Inherits from nn.Module. + """ + def __init__(self): + """ + Initialize the layers of the network. + """ + super(CNN, self).__init__() + self.conv1 = nn.Conv2d(1, 32, kernel_size=5) + self.conv2 = nn.Conv2d(32, 64, kernel_size=5) + self.pool = nn.MaxPool2d(2, 2) + self.fc1 = nn.Linear(64 * 4 * 4, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + """ + Define the forward pass of the network. + """ + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 64 * 4 * 4) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) diff --git a/src/main.py b/src/main.py index 243a31e..ce60cc9 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ from torchvision import datasets, transforms from torch.utils.data import DataLoader import numpy as np +from cnn import CNN # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ @@ -31,7 +32,7 @@ def forward(self, x): return nn.functional.log_softmax(x, dim=1) # Step 3: Train the Model -model = Net() +model = CNN() optimizer = optim.SGD(model.parameters(), lr=0.01) criterion = nn.NLLLoss()