Refactor training loop from script to class #17

Closed
wants to merge 10 commits into from
100 changes: 99 additions & 1 deletion README.md
@@ -1 +1,99 @@
# evals
# evals

This project provides a Python implementation for training and evaluating a simple neural network on the MNIST dataset using PyTorch.

## MNISTTrainer class

The `MNISTTrainer` class is used to train and evaluate the model. It is defined in `src/main.py`.

### Usage

First, create an instance of the `MNISTTrainer` class:

```python
trainer = MNISTTrainer()
```

You can then train the model using the `train` method:

```python
trainer.train()
```

When training finishes, `train` saves the model's parameters (its `state_dict`) to `mnist_model.pth` in the current working directory.

To load the model parameters from this file, use the following code:

```python
import torch

trainer.model.load_state_dict(torch.load("mnist_model.pth"))
```

To get predictions from the trained model, use the `predict` endpoint defined in `src/api.py`. It accepts an uploaded image file and returns the predicted digit.

## Dependencies

This project requires the following Python libraries:

- PyTorch
- torchvision
- numpy
- Pillow (imported as `PIL`)
- FastAPI

Before installing these dependencies, ensure that `pip` is installed and working correctly. You can verify this by running the following command:

```bash
pip --version
```

If `pip` is not installed or not working correctly, you may need to troubleshoot your Python installation or install `pip` separately.

Once `pip` is working correctly, you can install the `poetry` package manager using `pip`:

```bash
pip install poetry
```

After installing `poetry`, verify that it was installed correctly by checking its version:

```bash
poetry --version
```

If the command does not print a version number, `poetry` was not installed correctly. In that case, try the following steps:

1. Check if the `poetry` executable is in your system's PATH. You can do this by running the following command:

```bash
which poetry
```

If the `poetry` executable is in your PATH, this command will print its location. If it's not, it won't print anything.

2. If the `poetry` executable is not in your PATH, you need to add it. The process varies with your operating system and shell, but it generally means adding a line to a shell startup file such as `~/.bashrc` or `~/.bash_profile` that appends the directory containing the `poetry` executable to PATH. You can find more detailed instructions in the `poetry` documentation or by searching online for "add to PATH".

3. If you're still having trouble, you can seek help from the `poetry` community or check the `poetry` documentation for more troubleshooting tips. Here's the link to the `poetry` documentation: https://python-poetry.org/docs/
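
The PATH check in step 1 can also be done from Python, which behaves the same way on systems where `which` is not available. The following is a minimal sketch using only the standard library; the helper names `find_executable` and `path_entries` are illustrative and not part of this project.

```python
import os
import shutil
from typing import List, Optional


def find_executable(name: str) -> Optional[str]:
    """Return the full path of `name` if it is on PATH, else None."""
    return shutil.which(name)


def path_entries() -> List[str]:
    """Return the non-empty directories currently listed in PATH."""
    return [p for p in os.environ.get("PATH", "").split(os.pathsep) if p]


if __name__ == "__main__":
    location = find_executable("poetry")
    if location:
        print(f"poetry found at {location}")
    else:
        # Listing the searched directories helps spot a missing entry.
        print("poetry is not on PATH; directories searched:")
        for entry in path_entries():
            print(f"  {entry}")
```

If the directory where `poetry` was installed does not appear in the printed list, that is the entry to add in step 2.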

Once `poetry` is installed, you can install the project dependencies using `poetry`:

```bash
poetry install
```

If you are unable to install `poetry`, you can install the project dependencies directly using `pip`:

```bash
pip install torch torchvision numpy pillow fastapi
```

## Running the project

To run the project, first start the FastAPI server:

```bash
uvicorn src.api:app --reload
```

You can then send a POST request to the `/predict` endpoint with an image file to get the model's prediction.
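
As a rough illustration of such a request, the sketch below uploads an image using only the Python standard library. It assumes the server is listening on uvicorn's default address `http://127.0.0.1:8000` and that the form field is named `file`, matching the `file` parameter of `predict` in `src/api.py`. The helper names and the `digit.png` file in the usage note are hypothetical.

```python
import json
import os
import urllib.request
import uuid


def build_multipart(field, filename, data, content_type="application/octet-stream"):
    """Encode one file as a multipart/form-data body.

    Returns (body_bytes, content_type_header_value).
    """
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def request_prediction(image_path, url="http://127.0.0.1:8000/predict"):
    """POST an image to the server and return the decoded JSON response."""
    with open(image_path, "rb") as fh:
        body, header = build_multipart(
            "file", os.path.basename(image_path), fh.read(), "image/png"
        )
    req = urllib.request.Request(
        url, data=body, method="POST", headers={"Content-Type": header}
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

With the server running, `request_prediction("digit.png")` should return a dictionary with a `prediction` key, since that is what the endpoint returns. Note that FastAPI's file-upload handling needs the `python-multipart` package installed on the server side.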
37 changes: 27 additions & 10 deletions src/api.py
@@ -1,19 +1,36 @@
import os
import subprocess
import sys

# Check if pip is installed
try:
    subprocess.run(["pip", "--version"], check=True)
except (subprocess.CalledProcessError, FileNotFoundError):
    # FileNotFoundError is raised when the executable is not on PATH
    print("Error: pip is not installed. Please install pip and try again.")
    sys.exit(1)

# Check if poetry is installed, if not, install it
try:
    subprocess.run(["poetry", "--version"], check=True)
except (subprocess.CalledProcessError, FileNotFoundError):
    # Download the get-poetry.py script
    subprocess.run(["curl", "-sSL", "https://install.python-poetry.org", "-o", "get-poetry.py"], check=True)
    # Execute the get-poetry.py script with the current interpreter to install poetry
    subprocess.run([sys.executable, "get-poetry.py", "--yes"], check=True)

from fastapi import FastAPI, UploadFile, File
from PIL import Image
import torch
from torchvision import transforms
from main import Net # Importing Net class from main.py
from main import MNISTTrainer # Importing MNISTTrainer class from main.py

# Load the model
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()
# Create an instance of MNISTTrainer and load the model
trainer = MNISTTrainer()
trainer.model.load_state_dict(torch.load("mnist_model.pth"))
trainer.model.eval()

# Transform used for preprocessing the image
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
transform = trainer.transform

app = FastAPI()

@@ -23,6 +40,6 @@ async def predict(file: UploadFile = File(...)):
    image = transform(image)
    image = image.unsqueeze(0) # Add batch dimension
    with torch.no_grad():
        output = model(image)
        output = trainer.model(image)
    _, predicted = torch.max(output.data, 1)
    return {"prediction": int(predicted[0])}
74 changes: 40 additions & 34 deletions src/main.py
@@ -6,43 +6,49 @@
from torch.utils.data import DataLoader
import numpy as np

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
class MNISTTrainer:
    def __init__(self, epochs=3):
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        self.trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform)
        self.trainloader = DataLoader(self.trainset, batch_size=64, shuffle=True)
        self.testset = datasets.MNIST('.', download=True, train=False, transform=self.transform)
        self.testloader = DataLoader(self.testset, batch_size=64, shuffle=True)
        self.model = self._define_model()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)
        self.criterion = nn.NLLLoss()
        self.epochs = epochs

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
    def _define_model(self):
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(28 * 28, 128)
                self.fc2 = nn.Linear(128, 64)
                self.fc3 = nn.Linear(64, 10)

# Step 2: Define the PyTorch Model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return nn.functional.log_softmax(x, dim=1)
            def forward(self, x):
                x = x.view(-1, 28 * 28)
                x = nn.functional.relu(self.fc1(x))
                x = nn.functional.relu(self.fc2(x))
                x = self.fc3(x)
                return nn.functional.log_softmax(x, dim=1)
        return Net()

# Step 3: Train the Model
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()
    def train(self):
        for epoch in range(self.epochs):
            for images, labels in self.trainloader:
                self.optimizer.zero_grad()
                output = self.model(images)
                loss = self.criterion(output, labels)
                loss.backward()
                self.optimizer.step()

# Training loop
epochs = 3
for epoch in range(epochs):
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        torch.save(self.model.state_dict(), "mnist_model.pth")

trainer = MNISTTrainer()
trainer.train()

torch.save(model.state_dict(), "mnist_model.pth")
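
One consequence of this change is worth noting: `src/api.py` now imports `MNISTTrainer` from `main`, and `main.py` creates a trainer and calls `train()` at module level, so importing the module would appear to start a full training run. A common way to avoid this is an `if __name__ == "__main__":` guard. The sketch below shows the pattern with a hypothetical `StubTrainer` standing in for `MNISTTrainer` so that it runs without PyTorch; it is an illustration of the idiom, not the code in this PR.

```python
class StubTrainer:
    """Hypothetical stand-in for MNISTTrainer; no PyTorch needed."""

    def __init__(self, epochs=3):
        self.epochs = epochs
        self.completed_epochs = 0

    def train(self):
        # The real class would iterate over the data loader here.
        for _ in range(self.epochs):
            self.completed_epochs += 1


def main():
    """Entry point used only when the file is executed as a script."""
    trainer = StubTrainer()
    trainer.train()
    print(f"trained for {trainer.completed_epochs} epochs")


if __name__ == "__main__":
    # Runs for `python main.py`, but not for `from main import ...`.
    main()
```

With this structure, importing the class from another module only defines it, and training starts only when the file is run directly.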