From e69d94b43c55e16714f80c4d4e85ad0f3c76ee78 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sat, 25 Nov 2023 07:27:14 +0000 Subject: [PATCH 1/2] feat: Updated src/main.py --- src/main.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..2e33cf5 100644 --- a/src/main.py +++ b/src/main.py @@ -1,10 +1,11 @@ -from PIL import Image +import numpy as np import torch import torch.nn as nn import torch.optim as optim -from torchvision import datasets, transforms +from cnn import CNN +from PIL import Image from torch.utils.data import DataLoader -import numpy as np +from torchvision import datasets, transforms # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ @@ -31,7 +32,7 @@ def forward(self, x): return nn.functional.log_softmax(x, dim=1) # Step 3: Train the Model -model = Net() +model = CNN() optimizer = optim.SGD(model.parameters(), lr=0.01) criterion = nn.NLLLoss() From 325c8f9825efef3d1e61b23a95db7b8b0f99e1f0 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sat, 25 Nov 2023 07:29:48 +0000 Subject: [PATCH 2/2] feat: Updated src/api.py --- src/api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..1c82835 100644 --- a/src/api.py +++ b/src/api.py @@ -1,8 +1,8 @@ -from fastapi import FastAPI, UploadFile, File -from PIL import Image import torch +from cnn import CNN # Importing CNN class from cnn.py +from fastapi import FastAPI, File, UploadFile +from PIL import Image from torchvision import transforms -from main import Net # Importing Net class from main.py # Load the model model = Net()