From 45766ad97c52933bf80f37d378185d120b577ba5 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Thu, 12 Oct 2023 22:33:24 +0000 Subject: [PATCH 1/2] feat: Updated src/main.py --- src/main.py | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..6ccd3c6 100644 --- a/src/main.py +++ b/src/main.py @@ -6,16 +6,7 @@ from torch.utils.data import DataLoader import numpy as np -# Step 1: Load MNIST Data and Preprocess -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) - -trainset = datasets.MNIST('.', download=True, train=True, transform=transform) -trainloader = DataLoader(trainset, batch_size=64, shuffle=True) - -# Step 2: Define the PyTorch Model +# Define the PyTorch Model class Net(nn.Module): def __init__(self): super().__init__() @@ -30,19 +21,28 @@ def forward(self, x): x = self.fc3(x) return nn.functional.log_softmax(x, dim=1) -# Step 3: Train the Model -model = Net() -optimizer = optim.SGD(model.parameters(), lr=0.01) -criterion = nn.NLLLoss() +class MNISTTrainer: + def __init__(self): + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) -# Training loop -epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() + def load_data(self): + """Load and preprocess the MNIST data.""" + trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform) + trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + return trainloader -torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file + def train(self, model, criterion, optimizer): + """Train the model using the provided criterion and optimizer.""" + trainloader = self.load_data() + epochs = 3 + for epoch in range(epochs): + for images, labels in trainloader: + optimizer.zero_grad() + output = model(images) + loss = criterion(output, labels) + loss.backward() + optimizer.step() + torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file From d282480f69af768446414e74ae9262f69630b4cb Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Thu, 12 Oct 2023 22:35:55 +0000 Subject: [PATCH 2/2] feat: Updated src/api.py --- src/api.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..e530b0e 100644 --- a/src/api.py +++ b/src/api.py @@ -2,18 +2,18 @@ from PIL import Image import torch from torchvision import transforms -from main import Net # Importing Net class from main.py +from main import MNISTTrainer # Importing MNISTTrainer class from main.py + +# Create an instance of the MNISTTrainer class +trainer = MNISTTrainer() # Load the model -model = Net() +model = trainer.model model.load_state_dict(torch.load("mnist_model.pth")) model.eval() -# Transform used for preprocessing the image -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +# Transform used for preprocessing the image is now inside the MNISTTrainer class +transform = trainer.transform app = FastAPI()