From 3a75d537537a005e8a8c27af0781c2a36bcebbc8 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 13 Oct 2023 23:08:31 +0000 Subject: [PATCH 1/2] feat: add CNN class for MNIST classification --- src/cnn.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 src/cnn.py diff --git a/src/cnn.py b/src/cnn.py new file mode 100644 index 0000000..60eb97e --- /dev/null +++ b/src/cnn.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import transforms + +class CNN(nn.Module): + """ + Convolutional Neural Network for MNIST classification. + """ + def __init__(self): + super(CNN, self).__init__() + self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.fc1 = nn.Linear(7*7*64, 128) + self.relu3 = nn.ReLU() + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.pool1(self.relu1(self.conv1(x))) + x = self.pool2(self.relu2(self.conv2(x))) + x = x.view(-1, 7*7*64) + x = self.relu3(self.fc1(x)) + return self.fc2(x) + + def train(self, trainloader, lr, epochs): + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(self.parameters(), lr=lr) + + for epoch in range(epochs): + running_loss = 0.0 + for i, data in enumerate(trainloader, 0): + inputs, labels = data + optimizer.zero_grad() + outputs = self(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + running_loss += loss.item() + print(f"Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}") + + def save_model(self, file_path): + torch.save(self.state_dict(), file_path) From 90bbe88f7d4709422a3e1e48d9b987498777632b Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 13 Oct 2023 23:09:52 +0000 Subject: [PATCH 2/2] feat: Updated src/main.py --- src/main.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..f361f2e 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ from torchvision import datasets, transforms from torch.utils.data import DataLoader import numpy as np +from cnn import CNN # Import the CNN class # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ @@ -16,19 +17,14 @@ trainloader = DataLoader(trainset, batch_size=64, shuffle=True) # Step 2: Define the PyTorch Model -class Net(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = nn.functional.relu(self.fc1(x)) - x = nn.functional.relu(self.fc2(x)) - x = self.fc3(x) - return nn.functional.log_softmax(x, dim=1) +# Create an instance of the CNN class +cnn = CNN() + +# Train the CNN +cnn.train(trainloader, lr=0.001, epochs=10) + +# Save the trained model +cnn.save_model("mnist_model.pth") # Step 3: Train the Model model = Net()