Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion backend/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1 +1,29 @@
"""Entry point for the API (e.g., FastAPI or Flask)."""
from fastapi import FastAPI
from backend.app.controllers.segmentation_controller import router as segmentation_router
from backend.app.controllers.training_controller import router as training_router
from backend.app.controllers.user_controller import router as user_router
from backend.app.db.db import init_db
import uvicorn


def create_app():
# Initialize the FastAPI app
app = FastAPI()

# Include the routers (controllers) for different functionality
app.include_router(segmentation_router, prefix="/segmentation", tags=["segmentation"])
app.include_router(training_router, prefix="/training", tags=["training"])
app.include_router(user_router, prefix="/user", tags=["user"])

# Initialize the database
init_db()

return app


# Create the FastAPI app
app = create_app()

if __name__ == "__main__":
# Run the application with Uvicorn server
uvicorn.run(app, host="0.0.0.0", port=8000)
22 changes: 22 additions & 0 deletions backend/app/api/routes.py
Original file line number Diff line number Diff line change
@@ -1 +1,23 @@
"""API routes for model inference, training, etc."""

from fastapi import APIRouter, UploadFile, File
from backend.app.services.segmentation_service import SegmentationService
from backend.app.services.training_service import TrainingService

router = APIRouter()

segmentation_service = SegmentationService()
training_service = TrainingService()


@router.post("/segment")
async def segment_image(file: UploadFile, user_id: int):
image = Image.open(file.file)
result = segmentation_service.predict_and_save(image, user_id)
return {"segmentation_result": result}


@router.post("/train")
async def train_model():
training_service.train(dataset="your_dataset_path")
return {"status": "Training started"}
29 changes: 29 additions & 0 deletions backend/app/controllers/segmentation_controller.py
Original file line number Diff line number Diff line change
@@ -1 +1,30 @@
"""Handles segmentation API requests (e.g., POST image, GET result)."""
from fastapi import APIRouter, UploadFile, File, HTTPException
from backend.app.services.segmentation_service import SegmentationService

router = APIRouter()
segmentation_service = SegmentationService()


@router.post("/segment")
async def segment_image(file: UploadFile = File(...)):
try:
# Save the uploaded file locally or directly upload to a temporary S3 location
file_location = f"temp/{file.filename}"
with open(file_location, "wb") as buffer:
buffer.write(await file.read())

# Perform segmentation and get results
result = segmentation_service.segment_image(file_location)
return {"message": "Segmentation successful", "result": result}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@router.get("/results")
async def get_segmentation_results(limit: int = 10):
try:
results = segmentation_service.repo.get_segmentation_results(limit)
return {"segmentation_results": results}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
37 changes: 37 additions & 0 deletions backend/app/controllers/training_controller.py
Original file line number Diff line number Diff line change
@@ -1 +1,38 @@
"""Handles training-related requests (e.g., POST training data)."""

# backend/app/controllers/training_controller.py

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from backend.services.training_service import TrainingService
from backend.db.db import get_db
from sqlalchemy.orm import Session
import os

router = APIRouter()


# Pydantic model to receive training parameters
class TrainRequest(BaseModel):
epochs: int = 10
batch_size: int = 32
learning_rate: float = 1e-4


@router.post("/train")
async def start_training(
train_request: TrainRequest,
db: Session = Depends(get_db)
):
try:
# Initialize training service with the parameters
training_service = TrainingService(db, train_request.epochs, train_request.batch_size,
train_request.learning_rate)

# Start the training process (this can be a background task as well)
training_service.start_training()

return {"message": "Training started successfully"}

except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
24 changes: 24 additions & 0 deletions backend/app/controllers/user_controller.py
Original file line number Diff line number Diff line change
@@ -1 +1,25 @@
"""Handles user related requests."""

from fastapi import APIRouter, HTTPException
from backend.app.services.user_service import UserService

router = APIRouter()
user_service = UserService()


@router.post("/signup")
async def signup(username: str, password: str, email: str):
try:
user = user_service.register_user(username, password, email)
return {"message": "User created successfully", "user": user}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))


@router.post("/login")
async def login(username: str, password: str):
try:
user = user_service.login_user(username, password)
return {"message": "Login successful", "user": user}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
6 changes: 6 additions & 0 deletions backend/app/db/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# backend/db/base.py

from sqlalchemy.ext.declarative import declarative_base

# Create the base class
Base = declarative_base()
23 changes: 23 additions & 0 deletions backend/app/db/db.py
Original file line number Diff line number Diff line change
@@ -1 +1,24 @@
"""Database connection setup (SQLAlchemy or another ORM)."""

# backend/db/db.py

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from base import Base

SQLALCHEMY_DATABASE_URL = "postgresql://user:password@localhost/db_name"

engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def init_db():
Base.metadata.create_all(bind=engine)


def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
18 changes: 18 additions & 0 deletions backend/app/db/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# backend/db/models.py

from sqlalchemy import Column, Integer, String, Float, DateTime
from datetime import datetime
from base import Base


class TrainingLog(Base):
__tablename__ = 'training_logs'

id = Column(Integer, primary_key=True, index=True)
status = Column(String, index=True)
epoch = Column(Integer)
loss = Column(Float)
timestamp = Column(DateTime, default=datetime.utcnow)

def __repr__(self):
return f"<TrainingLog(status={self.status}, epoch={self.epoch}, loss={self.loss}, timestamp={self.timestamp})>"
46 changes: 45 additions & 1 deletion backend/app/repositories/segmentation_repo.py
Original file line number Diff line number Diff line change
@@ -1 +1,45 @@
"""DB logic for storing/fetching segmentation results."""
import psycopg2
import os

class SegmentationRepository:
def __init__(self):
self.connection = psycopg2.connect(
dbname=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASSWORD"),
host=os.getenv("POSTGRES_HOST")
)
self._initialize_database()

def _initialize_database(self):
with self.connection.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS segmentation_results (
id SERIAL PRIMARY KEY,
original_image_url TEXT NOT NULL,
segmented_image_url TEXT NOT NULL,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
self.connection.commit()

def save_segmentation_result(self, original_image_url, segmented_image_url):
with self.connection.cursor() as cursor:
cursor.execute("""
INSERT INTO segmentation_results (original_image_url, segmented_image_url)
VALUES (%s, %s)
RETURNING id
""", (original_image_url, segmented_image_url))
result_id = cursor.fetchone()[0]
self.connection.commit()
return result_id

def get_segmentation_results(self, limit=10):
with self.connection.cursor() as cursor:
cursor.execute("""
SELECT id, original_image_url, segmented_image_url, timestamp
FROM segmentation_results
ORDER BY timestamp DESC
LIMIT %s
""", (limit,))
return cursor.fetchall()
19 changes: 18 additions & 1 deletion backend/app/repositories/training_repo.py
Original file line number Diff line number Diff line change
@@ -1 +1,18 @@
"""DB logic for storing/fetching training data."""
# backend/app/repositories/training_repo.py

from sqlalchemy.orm import Session
from backend.db.models import TrainingLog
from datetime import datetime


class TrainingRepo:
def __init__(self, db: Session):
self.db = db

def save_training_log(self, status: str, epoch: int, loss: float):
log = TrainingLog(status=status, epoch=epoch, loss=loss, timestamp=datetime.now())
self.db.add(log)
self.db.commit()

def get_training_logs(self):
return self.db.query(TrainingLog).all()
48 changes: 48 additions & 0 deletions backend/app/repositories/user_repo.py
Original file line number Diff line number Diff line change
@@ -1 +1,49 @@
"""DB logic for storing/fetching user data."""

import psycopg2
import os
from werkzeug.security import generate_password_hash, check_password_hash


class UserRepository:
def __init__(self):
self.connection = psycopg2.connect(
dbname=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASSWORD"),
host=os.getenv("POSTGRES_HOST")
)
self._initialize_database()

def _initialize_database(self):
with self.connection.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
username VARCHAR(255) UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL
)
""")
self.connection.commit()

def create_user(self, username, password, email):
password_hash = generate_password_hash(password)
with self.connection.cursor() as cursor:
cursor.execute("""
INSERT INTO users (username, password_hash, email)
VALUES (%s, %s, %s)
RETURNING id
""", (username, password_hash, email))
user_id = cursor.fetchone()[0]
self.connection.commit()
return user_id

def find_user_by_username(self, username):
with self.connection.cursor() as cursor:
cursor.execute("SELECT id, username, password_hash FROM users WHERE username = %s", (username,))
user = cursor.fetchone()
return user if user else None

def verify_password(self, stored_password_hash, password):
return check_password_hash(stored_password_hash, password)
45 changes: 44 additions & 1 deletion backend/app/services/segmentation_service.py
Original file line number Diff line number Diff line change
@@ -1 +1,44 @@
"""Handles segmentation logic (calls model inference, processing)."""
import torch
import os
from backend.app.repositories.segmentation_repo import SegmentationRepository
from backend.app.utils.s3_client import S3Client
from backend.models.vit_segmentation import VisionTransformerSegmentation
from backend.scripts.image_proccessing import preprocess_image # Import preprocessing function


def _save_output_image(output):
"""
Save the segmentation output to an image file.
You might want to save the output tensor as a .png or .jpg file.
"""
output_image = output.squeeze().cpu().numpy() # Remove batch dimension
output_image_path = "output/segmented_image.png"
# Save the image as a PNG (you can use other formats like .jpg, etc.)
from matplotlib import pyplot as plt
plt.imsave(output_image_path, output_image, cmap="jet")
return output_image_path


class SegmentationService:
def __init__(self, model_path="models/vit_model.pth"):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = VisionTransformerSegmentation().to(self.device)
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
self.repo = SegmentationRepository()
self.s3_client = S3Client()

def segment_image(self, image_path):
# Load and preprocess the image using the script's function
preprocessed_image = preprocess_image(image_path)

# Perform segmentation with the model
with torch.no_grad():
output = self.model(preprocessed_image.to(self.device))

# Save the segmented image and upload to S3
segmented_image_path = _save_output_image(output)
segmented_image_url = self.s3_client.upload_file(segmented_image_path, os.path.basename(segmented_image_path))

# Log the result in the database
result_id = self.repo.save_segmentation_result(image_path, segmented_image_url)
return {"result_id": result_id, "segmented_image_url": segmented_image_url}
Loading
Loading