From b982817093acce60630d1721b1bbe0d5d92edefd Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 22:01:10 +0000 Subject: [PATCH 1/6] feat: Updated src/main.py --- src/main.py | 76 +++++++++++++++++++++++++++-------------------------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..3b8acc3 100644 --- a/src/main.py +++ b/src/main.py @@ -1,48 +1,50 @@ -from PIL import Image import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader -import numpy as np -# Step 1: Load MNIST Data and Preprocess -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +class MNISTTrainer: + def __init__(self): + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + self.optimizer = None + self.criterion = nn.NLLLoss() -trainset = datasets.MNIST('.', download=True, train=True, transform=transform) -trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + def load_data(self): + trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform) + return DataLoader(trainset, batch_size=64, shuffle=True) -# Step 2: Define the PyTorch Model -class Net(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = nn.functional.relu(self.fc1(x)) - x = nn.functional.relu(self.fc2(x)) - x = self.fc3(x) - return nn.functional.log_softmax(x, dim=1) + class Net(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(28 * 28, 128) + self.fc2 = nn.Linear(128, 64) + self.fc3 = nn.Linear(64, 10) + + def forward(self, x): + x = x.view(-1, 28 * 28) + x = nn.functional.relu(self.fc1(x)) + x = nn.functional.relu(self.fc2(x)) + x = self.fc3(x) + return nn.functional.log_softmax(x, dim=1) -# Step 3: Train the Model -model = Net() -optimizer = optim.SGD(model.parameters(), lr=0.01) -criterion = nn.NLLLoss() + def define_model(self): + model = self.Net() + self.optimizer = optim.SGD(model.parameters(), lr=0.01) + return model -# Training loop -epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() + def train_model(self, model, trainloader): + epochs = 3 + for epoch in range(epochs): + for images, labels in trainloader: + self.optimizer.zero_grad() + output = model(images) + loss = self.criterion(output, labels) + loss.backward() + self.optimizer.step() -torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file + def save_model(self, model): + torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file From 97e4498dba89e414e0dbb14a30f40ff9f7c9c24f Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 22:02:14 +0000 Subject: [PATCH 2/6] Sandbox run src/main.py --- src/main.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/main.py b/src/main.py index 3b8acc3..ed553e1 100644 --- a/src/main.py +++ b/src/main.py @@ -1,20 +1,22 @@ import torch import torch.nn as nn import torch.optim as optim -from torchvision import datasets, transforms from torch.utils.data import DataLoader +from torchvision import datasets, transforms + class MNISTTrainer: def __init__(self): - self.transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) - ]) + self.transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] + ) self.optimizer = None self.criterion = nn.NLLLoss() def load_data(self): - trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform) + trainset = datasets.MNIST( + ".", download=True, train=True, transform=self.transform + ) return DataLoader(trainset, batch_size=64, shuffle=True) class Net(nn.Module): @@ -38,7 +40,7 @@ def define_model(self): def train_model(self, model, trainloader): epochs = 3 - for epoch in range(epochs): + for _epoch in range(epochs): for images, labels in trainloader: self.optimizer.zero_grad() output = model(images) @@ -47,4 +49,4 @@ def train_model(self, model, trainloader): self.optimizer.step() def save_model(self, model): - torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file + torch.save(model.state_dict(), "mnist_model.pth") From 47c102b32cfc4e7e1182991371b82d34d40033f4 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 22:02:17 +0000 Subject: [PATCH 3/6] feat: Updated src/api.py --- src/api.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..99c9d23 100644 --- a/src/api.py +++ b/src/api.py @@ -2,12 +2,9 @@ from PIL import Image import torch from torchvision import transforms -from main import Net # Importing Net class from main.py - -# Load the model -model = Net() -model.load_state_dict(torch.load("mnist_model.pth")) -model.eval() +from main import MNISTTrainer # Importing MNISTTrainer class from main.py +trainer = MNISTTrainer() +model = trainer.load_model("mnist_model.pth") # Transform used for preprocessing the image transform = transforms.Compose([ @@ -19,10 +16,11 @@ @app.post("/predict/") async def predict(file: UploadFile = File(...)): + """Predict the digit in the uploaded image using the loaded model.""" image = Image.open(file.file).convert("L") image = transform(image) image = image.unsqueeze(0) # Add batch dimension with torch.no_grad(): - output = model(image) + output = trainer.predict(image) # Use the predict method of the trainer _, predicted = torch.max(output.data, 1) return {"prediction": int(predicted[0])} From 2a95177f84ddcb968f4d4dbeab259885f071c23b Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 22:02:58 +0000 Subject: [PATCH 4/6] Sandbox run src/api.py --- src/api.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/api.py b/src/api.py index 99c9d23..1df68e7 100644 --- a/src/api.py +++ b/src/api.py @@ -1,19 +1,21 @@ -from fastapi import FastAPI, UploadFile, File -from PIL import Image import torch +from fastapi import FastAPI, File, UploadFile +from PIL import Image from torchvision import transforms + from main import MNISTTrainer # Importing MNISTTrainer class from main.py + trainer = MNISTTrainer() model = trainer.load_model("mnist_model.pth") # Transform used for preprocessing the image -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) app = FastAPI() + @app.post("/predict/") async def predict(file: UploadFile = File(...)): """Predict the digit in the uploaded image using the loaded model.""" From d757ce4a3981ec2e43774e13f635d7ea252df40e Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 22:04:48 +0000 Subject: [PATCH 5/6] feat: Updated src/main.py --- src/main.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/main.py b/src/main.py index ed553e1..ce8bb2f 100644 --- a/src/main.py +++ b/src/main.py @@ -50,3 +50,9 @@ def train_model(self, model, trainloader): def save_model(self, model): torch.save(model.state_dict(), "mnist_model.pth") + + def load_model(self, model_path): + model = self.define_model() + model.load_state_dict(torch.load(model_path)) + model.eval() + return model From ba73a2aaf3d1e3a88dffc22cd1e2bae4683d501e Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 22:05:41 +0000 Subject: [PATCH 6/6] Sandbox run src/main.py --- src/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main.py b/src/main.py index ce8bb2f..09508a1 100644 --- a/src/main.py +++ b/src/main.py @@ -50,7 +50,7 @@ def train_model(self, model, trainloader): def save_model(self, model): torch.save(model.state_dict(), "mnist_model.pth") - + def load_model(self, model_path): model = self.define_model() model.load_state_dict(torch.load(model_path))