Skip to content

Commit 7d54b98

Browse files
authored
Improvements to OverleafGitPaperRemote (#25)
* add test fixture for OverleafGitPaperRemote * testing that git can resolve non-conflicting edits using merge (partially addressing #19) * testing that we recover gracefully from merge conflicts and other kinds of bad edits * better git diff handling to test for recent human edits * handling edits' revision_id in OverleafGitPaperRemote. Snazzy `with paper.rewind(commit_id)` syntax
1 parent a8eb7c5 commit 7d54b98

File tree

6 files changed

+533
-49
lines changed

6 files changed

+533
-49
lines changed

llm4papers/config.example.py

+4
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@ class OpenAIConfig(BaseSettings):
77

88

99
class OverleafConfig(BaseSettings):
    """Settings for talking to Overleaf and for authoring git commits."""

    # Username and password for logging into your overleaf account.
    # "###" is presumably a placeholder to be replaced in a real config.
    username: str = "###"
    password: str = "###"
    # Author name and email that should appear in git history
    git_name: str = "AI assistant"
    # NOTE(review): "[email protected]" is not a usable address; it looks like text
    # left behind by web e-mail obfuscation. Confirm the intended default.
    git_email: str = "[email protected]"
1216

1317

1418
class Settings(BaseSettings):

llm4papers/paper_remote/OverleafGitPaperRemote.py

+161-49
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,47 @@
88
import shutil
99
import datetime
1010
from urllib.parse import quote
11-
from git import Repo # type: ignore
12-
13-
from llm4papers.models import EditTrigger, EditResult, EditType, DocumentID, RevisionID
11+
from git import Repo, GitCommandError # type: ignore
12+
from typing import Iterable
13+
import re
14+
15+
from llm4papers.models import (
16+
EditTrigger,
17+
EditResult,
18+
EditType,
19+
DocumentID,
20+
RevisionID,
21+
LineRange,
22+
)
1423
from llm4papers.paper_remote.MultiDocumentPaperRemote import MultiDocumentPaperRemote
1524
from llm4papers.logger import logger
1625

1726

27+
diff_line_edit_re = re.compile(
28+
r"@{2,}\s*-(?P<old_line>\d+),(?P<old_count>\d+)\s*\+(?P<new_line>\d+),(?P<new_count>\d+)\s*@{2,}"
29+
)
30+
31+
32+
def _diff_to_ranges(diff: str) -> Iterable[LineRange]:
33+
"""Given a git diff, return LineRange object(s) indicating which lines in the
34+
original document were changed.
35+
"""
36+
for match in diff_line_edit_re.finditer(diff):
37+
git_line_number = int(match.group("new_line"))
38+
git_line_count = int(match.group("new_count"))
39+
# Git counts from 1 and gives (start, length), inclusive. LineRange counts from
40+
# 0 and gives start:end indices (exclusive).
41+
zero_index_start = git_line_number - 1
42+
yield zero_index_start, zero_index_start + git_line_count
43+
44+
45+
def _ranges_overlap(a: LineRange, b: LineRange) -> bool:
46+
"""Given two LineRanges, return True if they overlap, False otherwise."""
47+
return not (a[1] < b[0] or b[1] < a[0])
48+
49+
1850
def _too_close_to_human_edits(
19-
repo: Repo, filename: str, line_number: int, last_n: int = 2
51+
repo: Repo, filename: str, line_range: LineRange, last_n: int = 2
2052
) -> bool:
2153
"""
2254
Determine if any line in `line_range` of the file `filename` was changed in
@@ -41,22 +73,19 @@ def _too_close_to_human_edits(
4173
logger.info(f"Last commit was {sec_since_last_commit}s ago, approving edit.")
4274
return False
4375

44-
# Get the diff for HEAD~n:
76+
# Get the diff for HEAD~n. Note that the gitpython DiffIndex and Diff objects
77+
# drop the line number info (!) so we can't use the gitpython object-oriented API
78+
# to do this. Calling repo.git.diff is pretty much a direct pass-through to
79+
# running "git diff HEAD~n -- <filename>" on the command line.
4580
total_diff = repo.git.diff(f"HEAD~{last_n}", filename, unified=0)
4681

47-
# Get the current repo state of that line:
48-
current_line = repo.git.show(f"HEAD:{filename}").split("\n")[line_number]
49-
50-
logger.debug("Diff: " + total_diff)
51-
logger.debug("Current line: " + current_line)
52-
53-
# Match the line in the diff:
54-
if current_line in total_diff:
55-
logger.info(
56-
f"Found current line ({current_line[:10]}...) in diff, rejecting edit."
57-
)
58-
return True
59-
82+
for git_line_range in _diff_to_ranges(total_diff):
83+
if _ranges_overlap(git_line_range, line_range):
84+
logger.info(
85+
f"Line range {line_range} overlaps with git-edited {git_line_range}, "
86+
f"rejecting edit."
87+
)
88+
return True
6089
return False
6190

6291

@@ -77,15 +106,28 @@ def _add_auth(uri: str):
77106
return uri
78107

79108

109+
def _add_git_user_from_config(repo: Repo) -> None:
    """Set the git ``user.name`` and ``user.email`` of ``repo``.

    The identity is read from ``OverleafConfig`` (``git_name`` / ``git_email``).
    If ``llm4papers.config`` cannot be imported, a fixed "no-config" identity is
    written instead.

    Args:
        repo: Repository whose configuration is updated through
            ``repo.config_writer()``.
    """
    try:
        from llm4papers.config import OverleafConfig
    except ImportError:
        logger.debug("No config file found, assuming public repo.")
        # NOTE(review): "[email protected]" is not a usable address; it looks like
        # text left behind by web e-mail obfuscation. Confirm the intended value.
        name, email = "no-config", "[email protected]"
    else:
        config = OverleafConfig()
        name, email = config.git_name, config.git_email

    # A single context-managed writer releases the config file even if one of
    # the writes raises.
    with repo.config_writer() as writer:
        writer.set_value("user", "name", name)
        writer.set_value("user", "email", email)
122+
123+
80124
class OverleafGitPaperRemote(MultiDocumentPaperRemote):
81125
"""
82126
Overleaf exposes a git remote for each project. This class handles reading
83127
and writing to Overleaf documents using gitpython, and implements the
84128
PaperRemote protocol for use by the AI editor.
85129
"""
86130

87-
current_revision_id: RevisionID
88-
89131
def __init__(self, git_cached_repo: str):
90132
"""
91133
Saves the git repo to a local temporary directory using gitpython.
@@ -100,6 +142,10 @@ def __init__(self, git_cached_repo: str):
100142
self._cached_repo: Repo | None = None
101143
self.refresh()
102144

145+
@property
def current_revision_id(self) -> RevisionID:
    """Hex SHA of the commit that HEAD of the local repo currently points to."""
    head_commit = self._get_repo().head.commit
    return head_commit.hexsha
148+
103149
def _get_repo(self) -> Repo:
104150
if self._cached_repo is None:
105151
# TODO - this makes me anxious about race conditions. every time we refresh,
@@ -119,7 +165,7 @@ def _doc_id_to_path(self, doc_id: DocumentID) -> pathlib.Path:
119165
# so we can cast to a string on this next line:
120166
return pathlib.Path(git_root) / str(doc_id)
121167

122-
def refresh(self):
168+
def refresh(self, retry: bool = True):
123169
"""
124170
This is a fallback method (that likely needs some love) to ensure that
125171
the repo is up to date with the latest upstream changes.
@@ -134,6 +180,7 @@ def refresh(self):
134180
)
135181

136182
self._cached_repo = Repo(f"/tmp/{self._reposlug}")
183+
_add_git_user_from_config(self._cached_repo)
137184

138185
logger.info(f"Pulling latest from repo {self._reposlug}.")
139186
try:
@@ -143,15 +190,14 @@ def refresh(self):
143190
f"Latest change at {self._get_repo().head.commit.committed_datetime}"
144191
)
145192
logger.info(f"Repo dirty: {self._get_repo().is_dirty()}")
146-
self.current_revision_id = self._get_repo().head.commit.hexsha
147193
try:
148194
self._get_repo().git.stash("pop")
149-
except Exception as e:
195+
except GitCommandError as e:
150196
# TODO: this just means there was nothing to pop, but
151197
# we should handle this more gracefully.
152198
logger.debug(f"Nothing to pop: {e}")
153199
pass
154-
except Exception as e:
200+
except GitCommandError as e:
155201
logger.error(
156202
f"Error pulling from repo {self._reposlug}: {e}. "
157203
"Falling back on DESTRUCTION!!!"
@@ -161,7 +207,10 @@ def refresh(self):
161207
self._cached_repo = None
162208
# recursively delete the repo
163209
shutil.rmtree(f"/tmp/{self._reposlug}")
164-
self.refresh()
210+
if retry:
211+
self.refresh(retry=False)
212+
else:
213+
raise e
165214

166215
def list_doc_ids(self) -> list[DocumentID]:
167216
"""
@@ -196,14 +245,15 @@ def is_edit_ok(self, edit: EditTrigger) -> bool:
196245
# want to wait for the user to move on to the next line.
197246
for doc_range in edit.input_ranges + edit.output_ranges:
198247
repo_scoped_file = str(self._doc_id_to_path(doc_range.doc_id))
199-
for i in range(doc_range.selection[0], doc_range.selection[1]):
200-
if _too_close_to_human_edits(self._get_repo(), repo_scoped_file, i):
201-
logger.info(
202-
f"Temporarily skipping edit request in {doc_range.doc_id}"
203-
" at line {i} because it was still in progress"
204-
" in the last commit."
205-
)
206-
return False
248+
if _too_close_to_human_edits(
249+
self._get_repo(), repo_scoped_file, doc_range.selection
250+
):
251+
logger.info(
252+
f"Temporarily skipping edit request in {doc_range.doc_id}"
253+
" at line {i} because it was still in progress"
254+
" in the last commit."
255+
)
256+
return False
207257
return True
208258

209259
def to_dict(self):
@@ -221,27 +271,30 @@ def perform_edit(self, edit: EditResult) -> bool:
221271
Returns:
222272
True if the edit was successful, False otherwise
223273
"""
274+
if not self._doc_id_to_path(edit.range.doc_id).exists():
275+
logger.error(f"Document {edit.range.doc_id} does not exist.")
276+
return False
277+
224278
logger.info(f"Performing edit {edit} on remote {self._reposlug}")
225279

226-
if edit.type == EditType.replace:
227-
success = self._perform_replace(edit)
228-
elif edit.type == EditType.comment:
229-
success = self._perform_comment(edit)
230-
else:
231-
raise ValueError(f"Unknown edit type {edit.type}")
280+
try:
281+
with self.rewind(edit.range.revision_id, message="AI edit") as paper:
282+
if edit.type == EditType.replace:
283+
success = paper._perform_replace(edit)
284+
elif edit.type == EditType.comment:
285+
success = paper._perform_comment(edit)
286+
else:
287+
raise ValueError(f"Unknown edit type {edit.type}")
288+
except GitCommandError as e:
289+
logger.error(
290+
f"Git error performing edit {edit} on remote {self._reposlug}: {e}"
291+
)
292+
success = False
232293

233294
if success:
234-
# TODO - apply edit relative to the edit.range.revision_id commit and then
235-
# rebase onto HEAD for poor-man's operational transforms
236-
self._get_repo().index.add([self._doc_id_to_path(str(edit.range.doc_id))])
237-
self._get_repo().index.commit("AI edit completed.")
238-
# Instead of just pushing, we need to rebase and then push.
239-
# This is because we want to make sure that the AI edits are always
240-
# on top of the stack.
241-
self._get_repo().git.pull()
242-
# TODO: We could do a better job catching WARNs here and then maybe setting
243-
# success = False
244295
self._get_repo().git.push()
296+
else:
297+
self.refresh()
245298

246299
return success
247300

@@ -257,6 +310,19 @@ def _perform_replace(self, edit: EditResult) -> bool:
257310
"""
258311
doc_range, text = edit.range, edit.content
259312
try:
313+
num_lines = len(self.get_lines(doc_range.doc_id))
314+
if (
315+
any(i < 0 for i in doc_range.selection)
316+
or doc_range.selection[1] < doc_range.selection[0]
317+
or any(
318+
i > len(self.get_lines(doc_range.doc_id))
319+
for i in doc_range.selection
320+
)
321+
):
322+
raise IndexError(
323+
f"Invalid selection {doc_range.selection} for document "
324+
f"{doc_range.doc_id} with {num_lines} lines."
325+
)
260326
lines = self.get_lines(doc_range.doc_id)
261327
lines = (
262328
lines[: doc_range.selection[0]]
@@ -284,3 +350,49 @@ def _perform_comment(self, edit: EditResult) -> bool:
284350
# TODO - implement this for real
285351
logger.info(f"Performing comment edit {edit} on remote {self._reposlug}")
286352
return True
353+
354+
def rewind(self, commit: str, message: str):
    """Build the context manager used as ``with remote.rewind(commit, message)``.

    Args:
        commit: Commit to rewind to before the edits are applied.
        message: Commit message for the commit created when the ``with`` block
            exits normally.

    Returns:
        A ``RewindContext`` bound to this remote.
    """
    context = self.RewindContext(self, commit, message)
    return context
356+
357+
# Create an inner class for "with" semantics so that we can do
358+
# `with remote.rewind(commit)` to rewind to a particular commit and play some edits
359+
# onto it, then merge when the 'with' context exits.
360+
class RewindContext:
    """Context manager that applies edits on top of an older commit.

    On entry a temporary branch ("tmp-edit-branch") is created at the requested
    commit and checked out, and the remote is handed back so edits can be
    written to the working tree.

    On a clean exit the tracked files are committed on the temporary branch,
    the previously active branch is checked out again and the temporary branch
    is merged into it. If the merge fails, the previously active branch is
    hard-reset to where it was and the error is re-raised.

    If the ``with`` body raised, nothing is committed or merged and the changes
    to tracked files are discarded. In both cases the temporary branch is
    deleted, and exceptions are never suppressed.
    """

    # TODO - there are tricks in gitpython where an IndexFile can be used to
    # handle changes to files in-memory without having to call checkout() and
    # (briefly) modify the state of things on disk. This would be an improvement,
    # but would require using the gitpython API more directly inside of
    # perform_edit, such as calling git.IndexFile.write() instead of python's
    # open() and write()

    def __init__(self, remote: "OverleafGitPaperRemote", commit: str, message: str):
        self._remote = remote
        self._message = message
        self._rewind_commit = commit

    def __enter__(self):
        repo = self._remote._get_repo()
        self._restore_ref = repo.head.ref
        # force=True lets us reuse the fixed branch name even if an earlier run
        # died before it could delete its temporary branch.
        self._new_branch = repo.create_head(
            "tmp-edit-branch", commit=self._rewind_commit, force=True
        )
        self._new_branch.checkout()
        return self._remote

    def __exit__(self, exc_type, exc_val, exc_tb):
        repo = self._remote._get_repo()
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if repo.active_branch != self._new_branch:
            raise AssertionError("Branch changed unexpectedly mid-`with`")

        try:
            # Only keep the work if the `with` body finished without raising;
            # a half-applied edit must not be merged.
            if exc_type is None:
                # Add files that changed (every path already tracked by the index)
                repo.index.add([path for path, _stage in repo.index.entries])
                repo.index.commit(self._message)
                self._restore_ref.checkout()
                try:
                    repo.git.merge("tmp-edit-branch")
                except GitCommandError:
                    # Hard reset on failure
                    repo.git.reset("--hard", self._restore_ref.commit.hexsha)
                    raise
        finally:
            # Leave the temporary branch before deleting it: this is the path
            # taken when the body raised or when committing failed. The forced
            # checkout throws away uncommitted changes to tracked files.
            if repo.active_branch == self._new_branch:
                self._restore_ref.checkout(force=True)
            repo.delete_head(self._new_branch, force=True)

0 commit comments

Comments
 (0)