add RAG for better PR feedback
amiicao committed Nov 17, 2024
1 parent a9aa1b8 commit ca547ab
Showing 4 changed files with 231 additions and 38 deletions.
47 changes: 47 additions & 0 deletions src/code_indexer.py
@@ -0,0 +1,47 @@
import os
import tempfile
import subprocess
import shutil
from typing import List, Dict
from src.github_api import get_access_token


def clone_repo_branch(
installation_id: str, repo_full_name: str, branch: str = "main"
) -> str:
"""Clone specific branch of repository to temporary directory"""

temp_dir = tempfile.mkdtemp()
access_token = get_access_token(installation_id)
clone_url = f"https://x-access-token:{access_token}@github.com/{repo_full_name}.git"
try:
subprocess.run(
["git", "clone", "-b", branch, "--single-branch", clone_url, temp_dir],
check=True,
)
return temp_dir
except Exception as e:
shutil.rmtree(temp_dir, ignore_errors=True) # Clean up on error
raise e


def index_code_files(temp_dir: str) -> List[Dict[str, str]]:
"""Index all code files from a directory"""
code_files = []

for root, _, files in os.walk(temp_dir):
for file in files:
if file.startswith(".") or "node_modules" in root:
continue

file_path = os.path.join(root, file)
relative_path = os.path.relpath(file_path, temp_dir)

try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
code_files.append({"path": relative_path, "content": content})
except UnicodeDecodeError:
continue # Skip binary files

return code_files
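
A minimal usage sketch of the new indexer helpers, not part of the commit: the installation ID and repository name below are placeholders, and it assumes the GitHub App token returned by get_access_token can read the repository.

import shutil

from src.code_indexer import clone_repo_branch, index_code_files

# Placeholder installation ID and repository; clone_repo_branch builds the
# authenticated clone URL from the GitHub App token internally.
temp_dir = clone_repo_branch("12345678", "octocat/hello-world", branch="main")
try:
    code_files = index_code_files(temp_dir)
    print(f"indexed {len(code_files)} text files")
finally:
    # The temporary clone is owned by the caller, so remove it when done.
    shutil.rmtree(temp_dir, ignore_errors=True)
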
84 changes: 59 additions & 25 deletions src/pull_request_handler.py
@@ -1,40 +1,74 @@
import logging

import ollama
from typing import List, Dict

from config import OLLAMA_MODEL
from src.vector_db import query_similar_code
from src.github_api import leave_comment

logger = logging.getLogger(__name__)


def handle_new_pull_request(
installation_id,
repo_full_name,
pull_request_number,
pull_request_title,
pull_request_body,
pr_diff,
):
context = (
"I am programming and I plan to merge in a pull request.\ Given the title, description, and pull rquest "
"code diff "
"of my pull request, succinctly identify any potential issues or downsides. Remember that in the code "
"diff, '+' is a code addition and '-' is code subtraction. \ End your response by asking if the PR "
"author has considered these points "
def generate_pr_feedback(
similar_code: List[Dict], pr_title: str, pr_body: str, pr_diff: str
) -> str:
"""Generate feedback using Ollama"""
context = "\n".join(
[f"File: {code['file_path']}\n{code['content']}\n---" for code in similar_code]
)

prompt = f'{context} \n Title: {pull_request_title} \n Description: {pull_request_body} \n Code diff: \n """\n {pr_diff} \n """\n '
prompt = f"""
I am programming and I plan to merge in a pull request. I will provide details of the pull request. You are to
review it and succinctly identify in point form any potential issues, code smells, duplication, or downsides of
the pull request. Consider interactions with the codebase and architectural design. Remember that in the code
diff, '+' is a code addition and '-' is code subtraction. End your response by asking if the PR author has
considered these points.
Pull Request:
Title: {pr_title}
Description: {pr_body}
Relevant context from codebase:
{context}
Changes in PR:
{pr_diff}
"""
logging.info("generating feedback...")
response = ollama.chat(
model=OLLAMA_MODEL,
messages=[
{
"role": "user",
"content": prompt,
},
],
model=OLLAMA_MODEL, messages=[{"role": "user", "content": prompt}]
)
comment_text = response["message"]["content"]

leave_comment(installation_id, repo_full_name, pull_request_number, comment_text)
return response["message"]["content"]


def handle_new_pull_request(
installation_id: str,
repo_id: int,
repo_full_name: str,
pr_number: int,
pr_title: str,
pr_body: str,
pr_diff: str,
changed_files: List[str],
):
"""Handle new pull request webhook"""
try:
# Query similar code from main branch
similar_code = query_similar_code(
changed_files, f"{pr_title} {pr_body} {pr_diff}", repo_id
)

# Generate feedback
feedback = generate_pr_feedback(similar_code, pr_title, pr_body, pr_diff)

# Post comment
leave_comment(installation_id, repo_full_name, pr_number, feedback)

logger.info(f"Posted feedback for PR #{pr_number} in {repo_full_name}")

except Exception as e:
logger.error(f"Error handling PR #{pr_number}: {str(e)}")
raise
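
For reference, a sketch of calling generate_pr_feedback on its own, not taken from the commit: the similar_code entries mirror the shape returned by query_similar_code, every value is made up, and it assumes a local Ollama server with OLLAMA_MODEL already pulled.

from src.pull_request_handler import generate_pr_feedback

# Illustrative stand-in for query_similar_code output.
similar_code = [
    {
        "file_path": "src/github_api.py",
        "content": "def leave_comment(installation_id, repo_full_name, pr_number, body): ...",
        "distance": 0.37,
    }
]

feedback = generate_pr_feedback(
    similar_code,
    pr_title="Add retry logic to GitHub API calls",
    pr_body="Wraps outbound requests in a simple retry loop.",
    pr_diff="+ for attempt in range(3):\n+     ...\n- response = requests.post(url)",
)
print(feedback)
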
87 changes: 87 additions & 0 deletions src/vector_db.py
@@ -1,3 +1,5 @@
from typing import List, Dict

import chromadb
from sentence_transformers import SentenceTransformer

@@ -58,3 +60,88 @@ def add_issues_to_chroma(issues):
repo_id = issue["repository"]["id"]

add_issue_to_chroma(full_issue, issue_number, issue_title, repo_id)


def get_collection_for_repo_branch(repo_id: int, branch: str = "main"):
    """Get or create the Chroma collection holding code for one repository branch"""
    return chroma_client.get_or_create_collection(f"github_code_{repo_id}_{branch}")


def add_code_to_chroma(
code_files: List[Dict[str, str]], repo_id: int, branch: str = "main"
):
"""Add or update code files in the collection"""
collection = get_collection_for_repo_branch(repo_id, branch)

# Prepare all documents for batch processing
documents = []
metadatas = []
ids = []
embeddings = []

for file in code_files:
file_id = f"{repo_id}_{branch}_{file['path']}"
embedding = model.encode(file["content"]).tolist()

documents.append(file["content"])
metadatas.append(
{"file_path": file["path"], "repo_id": repo_id, "branch": branch}
)
embeddings.append(embedding)
ids.append(file_id)

# Get existing IDs in the collection
existing_ids = set()
try:
existing = collection.get()
if existing and existing["ids"]:
existing_ids = set(existing["ids"])
except Exception:
pass # Collection might be empty

# Split into new and existing documents
new_indices = []
update_indices = []

for i, doc_id in enumerate(ids):
if doc_id in existing_ids:
update_indices.append(i)
else:
new_indices.append(i)

# Add new documents
if new_indices:
collection.add(
documents=[documents[i] for i in new_indices],
metadatas=[metadatas[i] for i in new_indices],
embeddings=[embeddings[i] for i in new_indices],
ids=[ids[i] for i in new_indices],
)

# Update existing documents
if update_indices:
collection.update(
documents=[documents[i] for i in update_indices],
metadatas=[metadatas[i] for i in update_indices],
embeddings=[embeddings[i] for i in update_indices],
ids=[ids[i] for i in update_indices],
)


def query_similar_code(
    changed_files: List[str], pr_content: str, repo_id: int
) -> List[Dict]:
    """Query the main-branch collection for code most similar to the PR content"""
collection = get_collection_for_repo_branch(repo_id)
embedding = model.encode(pr_content).tolist()

results = collection.query(
query_embeddings=[embedding],
n_results=5,
where={"file_path": {"$in": changed_files}},
)

return [
{"file_path": meta["file_path"], "content": doc, "distance": dist}
for meta, doc, dist in zip(
results["metadatas"][0], results["documents"][0], results["distances"][0]
)
]
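
A rough round trip through the new collection helpers, not part of the commit: it assumes chroma_client and the sentence-transformer model are initialised earlier in this module, and the repository ID and file contents are placeholders.

from src.vector_db import add_code_to_chroma, query_similar_code

repo_id = 42  # placeholder repository ID

# Index a couple of toy files into the repo's main-branch collection.
add_code_to_chroma(
    [
        {"path": "src/app.py", "content": "def main():\n    run_server()\n"},
        {"path": "src/server.py", "content": "def run_server():\n    ...\n"},
    ],
    repo_id,
    branch="main",
)

# Retrieve the indexed files most similar to a PR's text, limited to its changed files.
# With fewer than five indexed documents, recent Chroma versions typically cap the
# result count rather than error.
matches = query_similar_code(
    changed_files=["src/app.py", "src/server.py"],
    pr_content="refactor how the server is started",
    repo_id=repo_id,
)
for match in matches:
    print(match["file_path"], round(match["distance"], 3))
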
51 changes: 38 additions & 13 deletions src/webhook_handler.py
@@ -1,15 +1,22 @@
import hashlib
import hmac
import logging
import os
import shutil

import requests
from flask import Blueprint, request, jsonify, abort

from config import WEBHOOK_SECRET
from src.code_indexer import clone_repo_branch, index_code_files
from src.github_api import fetch_existing_issues
from src.issue_handler import handle_new_issue
from src.pull_request_handler import handle_new_pull_request
from src.vector_db import add_issues_to_chroma, remove_issues_from_chroma
from src.vector_db import (
add_issues_to_chroma,
remove_issues_from_chroma,
add_code_to_chroma,
)

logger = logging.getLogger(__name__)
webhook_blueprint = Blueprint("webhook", __name__)
@@ -127,18 +134,36 @@ def handle_pull_requests(data, installation_id):
action = data.get("action")
pull_request = data["pull_request"]
repo_full_name = data.get("repository", {}).get("full_name")
repo_id = data.get("repository", {}).get("id")

if not repo_full_name:
abort(400, "Repository full name is missing")
if not repo_full_name or not repo_id:
abort(400, "Repository information missing")

pr_diff = requests.get(pull_request["diff_url"]).text

if action == "opened":
handle_new_pull_request(
installation_id,
repo_full_name,
pull_request["number"],
pull_request.get("title", ""),
pull_request.get("body", ""),
pr_diff,
)
changed_files = [
f["filename"] for f in requests.get(pull_request["url"] + "/files").json()
]

if action == "opened" or action == "synchronize":
temp_dir = None
try:
# Update main branch collection if needed
temp_dir = clone_repo_branch(installation_id, repo_full_name, "main")
code_files = index_code_files(temp_dir)
add_code_to_chroma(code_files, repo_id, "main")

# Handle the pull request
handle_new_pull_request(
installation_id,
repo_id,
repo_full_name,
pull_request["number"],
pull_request.get("title", ""),
pull_request.get("body", ""),
pr_diff,
changed_files,
)
finally:
if temp_dir and os.path.exists(temp_dir):
logging.info(f"removing temp_dir {temp_dir}...")
shutil.rmtree(temp_dir, ignore_errors=True)
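
To make the expected webhook shape concrete, here is a hypothetical, heavily trimmed "opened" pull_request payload; it is not from the commit, real GitHub payloads carry many more fields, and exercising it end to end would also need a valid installation, network access, and a running Ollama instance.

# Every value below is a placeholder; only the fields handle_pull_requests reads are included.
payload = {
    "action": "opened",
    "repository": {"full_name": "octocat/hello-world", "id": 42},
    "pull_request": {
        "number": 7,
        "title": "Add retry logic",
        "body": "Wraps API calls with retries.",
        "diff_url": "https://github.com/octocat/hello-world/pull/7.diff",
        "url": "https://api.github.com/repos/octocat/hello-world/pulls/7",
    },
}

handle_pull_requests(payload, installation_id="12345678")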
