diff --git a/src/api.py b/src/api.py index 36c257a..2270bd1 100644 --- a/src/api.py +++ b/src/api.py @@ -2,10 +2,13 @@ from PIL import Image import torch from torchvision import transforms -from main import Net # Importing Net class from main.py +from main import Net, MNISTDataLoader # Importing Net and MNISTDataLoader classes from main.py + +# Instantiate the MNISTDataLoader +data_loader = MNISTDataLoader() # Load the model -model = Net() +model = Net(data_loader) model.load_state_dict(torch.load("mnist_model.pth")) model.eval() diff --git a/src/main.py b/src/main.py index 243a31e..fc02a55 100644 --- a/src/main.py +++ b/src/main.py @@ -6,19 +6,21 @@ from torch.utils.data import DataLoader import numpy as np -# Step 1: Load MNIST Data and Preprocess -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +class MNISTDataLoader: + def __init__(self): + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) -trainset = datasets.MNIST('.', download=True, train=True, transform=transform) -trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + self.trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform) + self.trainloader = DataLoader(self.trainset, batch_size=64, shuffle=True) # Step 2: Define the PyTorch Model class Net(nn.Module): - def __init__(self): + def __init__(self, data_loader): super().__init__() + self.data_loader = data_loader self.fc1 = nn.Linear(28 * 28, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 10) @@ -31,14 +33,15 @@ def forward(self, x): return nn.functional.log_softmax(x, dim=1) # Step 3: Train the Model -model = Net() +data_loader = MNISTDataLoader() +model = Net(data_loader) optimizer = optim.SGD(model.parameters(), lr=0.01) criterion = nn.NLLLoss() # Training loop epochs = 3 for epoch in range(epochs): - for images, labels in trainloader: + for images, labels in data_loader.trainloader: optimizer.zero_grad() output = model(images) loss = criterion(output, labels)