diff --git a/README.md b/README.md index ea3afcc..4ee8847 100644 --- a/README.md +++ b/README.md @@ -1 +1,99 @@ -# evals \ No newline at end of file +# evals + +This project provides a Python implementation for training and evaluating a simple neural network on the MNIST dataset using PyTorch. + +## MNISTTrainer class + +The `MNISTTrainer` class is used to train and evaluate the model. It is defined in `src/main.py`. + +### Usage + +First, create an instance of the `MNISTTrainer` class: + +```python +trainer = MNISTTrainer() +``` + +You can then train the model using the `train` method: + +```python +trainer.train() +``` + +The trained model's parameters are automatically saved to a file named "mnist_model.pth". + +To load the model parameters from this file, use the following code: + +```python +trainer.model.load_state_dict(torch.load("mnist_model.pth")) +``` + +To evaluate the model, you can use the `predict` endpoint defined in `src/api.py`. This endpoint takes an image file as input and returns the model's prediction. + +## Dependencies + +This project requires the following Python libraries: + +- PyTorch +- torchvision +- numpy +- Pillow (imported as `PIL`) +- FastAPI + +Before installing these dependencies, ensure that `pip` is installed and working correctly. You can verify this by running the following command: + +```bash +pip --version +``` + +If `pip` is not installed or not working correctly, you may need to troubleshoot your Python installation or install `pip` separately. + +Once `pip` is working correctly, you can install the `poetry` package manager using `pip`: + +```bash +pip install poetry +``` + +After installing `poetry`, verify that it was installed correctly by checking its version: + +```bash +poetry --version +``` + +If the command does not return a version number, `poetry` was not installed correctly. + +If `poetry` was not installed correctly, here are some troubleshooting steps you can take: + +1. 
Check if the `poetry` executable is in your system's PATH. You can do this by running the following command: + +```bash +which poetry +``` + +If the `poetry` executable is in your PATH, this command will print its location. If it's not, it won't print anything. + +2. If the `poetry` executable is not in your PATH, you need to add it. The process for this varies depending on your operating system and shell, but generally involves adding a line to a shell startup file like `~/.bashrc` or `~/.bash_profile` that exports the `poetry` executable's location to the PATH. You can find more detailed instructions in the `poetry` documentation or by searching online for "add to PATH". + +3. If you're still having trouble, you can seek help from the `poetry` community or check the `poetry` documentation for more troubleshooting tips. Here's the link to the `poetry` documentation: https://python-poetry.org/docs/ + +Once `poetry` is installed, you can install the project dependencies using `poetry`: + +```bash +poetry install +``` + +If you are unable to install `poetry`, you can install the project dependencies directly using `pip`: + +```bash +pip install torch torchvision numpy pillow fastapi +``` + +## Running the project + +To run the project, first start the FastAPI server: + +```bash +uvicorn src.api:app --reload +``` + +You can then send a POST request to the `/predict` endpoint with an image file to get the model's prediction. \ No newline at end of file diff --git a/src/api.py b/src/api.py index 36c257a..f435a48 100644 --- a/src/api.py +++ b/src/api.py @@ -1,19 +1,36 @@ +import os +import subprocess +import sys + +# Check if pip is installed +try: + subprocess.run(["pip", "--version"], check=True) +except subprocess.CalledProcessError: + print("Error: pip is not installed. 
"""FastAPI application setup for serving MNIST digit predictions.

On import this module verifies that ``pip`` and ``poetry`` are available
(installing ``poetry`` when it is missing), builds an ``MNISTTrainer``, loads
the weights stored in ``mnist_model.pth`` and exposes the module-level names
``trainer``, ``transform`` and ``app`` used by the request handlers below.
"""
import os
import subprocess
import sys


def _command_succeeds(command):
    """Return True when *command* can be started and exits with status 0.

    ``subprocess.run`` raises ``FileNotFoundError`` (an ``OSError``) when the
    executable does not exist, and ``CalledProcessError`` when it exits with a
    non-zero status. Both mean "tool not usable", so both map to ``False``.
    """
    try:
        subprocess.run(command, check=True)
    except (OSError, subprocess.CalledProcessError):
        return False
    return True


# Check if pip is installed
if not _command_succeeds(["pip", "--version"]):
    print("Error: pip is not installed. Please install pip and try again.")
    sys.exit(1)

# Check if poetry is installed, if not, install it
# NOTE(review): this downloads and executes a remote installer every time the
# API module is imported on a machine without poetry. Consider moving it to a
# separate setup script; the serving code itself never calls poetry.
if not _command_succeeds(["poetry", "--version"]):
    # Download the get-poetry.py script
    subprocess.run(
        ["curl", "-sSL", "https://install.python-poetry.org", "-o", "get-poetry.py"],
        check=True,
    )
    # Execute the installer with the interpreter running this module; a bare
    # "python" may be absent from PATH or point at a different installation.
    subprocess.run([sys.executable, "get-poetry.py", "--yes"], check=True)

from fastapi import FastAPI, UploadFile, File
from PIL import Image
import torch
from torchvision import transforms

from main import MNISTTrainer  # Importing MNISTTrainer class from main.py

# Create an instance of MNISTTrainer and load the model
trainer = MNISTTrainer()
trainer.model.load_state_dict(torch.load("mnist_model.pth"))
# Switch to inference mode before serving requests
trainer.model.eval()

# Transform used for preprocessing the image (same pipeline as in training)
transform = trainer.transform

app = FastAPI()
class Net(nn.Module):
    """Fully connected MNIST classifier: 784 -> 128 -> 64 -> 10.

    Kept at module level so that ``from main import Net`` keeps working for
    existing callers.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 classes for each input image."""
        # Flatten every image to a 784-element vector
        x = x.view(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        # log_softmax output is what the NLLLoss criterion used in training expects
        return nn.functional.log_softmax(x, dim=1)


class MNISTTrainer:
    """Bundle the MNIST data loaders, the model and the training loop.

    Attributes set by ``__init__``: ``transform``, ``trainset``,
    ``trainloader``, ``testset``, ``testloader``, ``model``, ``optimizer``,
    ``criterion`` and ``epochs``.
    """

    def __init__(self, epochs=3):
        """Prepare datasets, model, optimizer and loss.

        Args:
            epochs: Number of passes over the training set made by ``train``.
        """
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        # Both splits are stored under the current directory ('.') and are
        # requested with download=True.
        self.trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform)
        self.trainloader = DataLoader(self.trainset, batch_size=64, shuffle=True)
        self.testset = datasets.MNIST('.', download=True, train=False, transform=self.transform)
        self.testloader = DataLoader(self.testset, batch_size=64, shuffle=True)
        self.model = self._define_model()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)
        self.criterion = nn.NLLLoss()
        self.epochs = epochs

    def _define_model(self):
        """Return a freshly initialised ``Net``."""
        return Net()

    def train(self):
        """Train ``self.model`` for ``self.epochs`` epochs and save its weights.

        The state dict is written to ``mnist_model.pth`` in the current
        working directory once training has finished.
        """
        for _ in range(self.epochs):
            for images, labels in self.trainloader:
                self.optimizer.zero_grad()
                output = self.model(images)
                loss = self.criterion(output, labels)
                loss.backward()
                self.optimizer.step()

        torch.save(self.model.state_dict(), "mnist_model.pth")


# Train only when executed as a script. Without this guard every
# `from main import MNISTTrainer` (e.g. in the API module) retrains the
# network. The former trailing `torch.save(model.state_dict(), ...)` is
# dropped: `model` no longer exists at module level and `train()` already
# saves the weights.
if __name__ == "__main__":
    trainer = MNISTTrainer()
    trainer.train()