8
8
import shutil
9
9
import datetime
10
10
from urllib .parse import quote
11
- from git import Repo # type: ignore
12
-
13
- from llm4papers .models import EditTrigger , EditResult , EditType , DocumentID , RevisionID
11
+ from git import Repo , GitCommandError # type: ignore
12
+ from typing import Iterable
13
+ import re
14
+
15
+ from llm4papers .models import (
16
+ EditTrigger ,
17
+ EditResult ,
18
+ EditType ,
19
+ DocumentID ,
20
+ RevisionID ,
21
+ LineRange ,
22
+ )
14
23
from llm4papers .paper_remote .MultiDocumentPaperRemote import MultiDocumentPaperRemote
15
24
from llm4papers .logger import logger
16
25
17
26
27
+ diff_line_edit_re = re .compile (
28
+ r"@{2,}\s*-(?P<old_line>\d+),(?P<old_count>\d+)\s*\+(?P<new_line>\d+),(?P<new_count>\d+)\s*@{2,}"
29
+ )
30
+
31
+
32
+ def _diff_to_ranges (diff : str ) -> Iterable [LineRange ]:
33
+ """Given a git diff, return LineRange object(s) indicating which lines in the
34
+ original document were changed.
35
+ """
36
+ for match in diff_line_edit_re .finditer (diff ):
37
+ git_line_number = int (match .group ("new_line" ))
38
+ git_line_count = int (match .group ("new_count" ))
39
+ # Git counts from 1 and gives (start, length), inclusive. LineRange counts from
40
+ # 0 and gives start:end indices (exclusive).
41
+ zero_index_start = git_line_number - 1
42
+ yield zero_index_start , zero_index_start + git_line_count
43
+
44
+
45
+ def _ranges_overlap (a : LineRange , b : LineRange ) -> bool :
46
+ """Given two LineRanges, return True if they overlap, False otherwise."""
47
+ return not (a [1 ] < b [0 ] or b [1 ] < a [0 ])
48
+
49
+
18
50
def _too_close_to_human_edits (
19
- repo : Repo , filename : str , line_number : int , last_n : int = 2
51
+ repo : Repo , filename : str , line_range : LineRange , last_n : int = 2
20
52
) -> bool :
21
53
"""
22
54
Determine if the line `line_number` of the file `filename` was changed in
@@ -41,22 +73,19 @@ def _too_close_to_human_edits(
41
73
logger .info (f"Last commit was { sec_since_last_commit } s ago, approving edit." )
42
74
return False
43
75
44
- # Get the diff for HEAD~n:
76
+ # Get the diff for HEAD~n. Note that the gitpython DiffIndex and Diff objects
77
+ # drop the line number info (!) so we can't use the gitpython object-oriented API
78
+ # to do this. Calling repo.git.diff is pretty much a direct pass-through to
79
+ # running "git diff HEAD~n -- <filename>" on the command line.
45
80
total_diff = repo .git .diff (f"HEAD~{ last_n } " , filename , unified = 0 )
46
81
47
- # Get the current repo state of that line:
48
- current_line = repo .git .show (f"HEAD:{ filename } " ).split ("\n " )[line_number ]
49
-
50
- logger .debug ("Diff: " + total_diff )
51
- logger .debug ("Current line: " + current_line )
52
-
53
- # Match the line in the diff:
54
- if current_line in total_diff :
55
- logger .info (
56
- f"Found current line ({ current_line [:10 ]} ...) in diff, rejecting edit."
57
- )
58
- return True
59
-
82
+ for git_line_range in _diff_to_ranges (total_diff ):
83
+ if _ranges_overlap (git_line_range , line_range ):
84
+ logger .info (
85
+ f"Line range { line_range } overlaps with git-edited { git_line_range } , "
86
+ f"rejecting edit."
87
+ )
88
+ return True
60
89
return False
61
90
62
91
@@ -77,15 +106,28 @@ def _add_auth(uri: str):
77
106
return uri
78
107
79
108
109
+ def _add_git_user_from_config (repo : Repo ) -> None :
110
+ try :
111
+ from llm4papers .config import OverleafConfig
112
+
113
+ config = OverleafConfig ()
114
+ repo .config_writer ().set_value ("user" , "name" , config .git_name ).release ()
115
+ repo .config_writer ().set_value ("user" , "email" , config .git_email ).release ()
116
+ except ImportError :
117
+ logger .debug ("No config file found, assuming public repo." )
118
+ repo .config_writer ().set_value ("user" , "name" , "no-config" ).release ()
119
+ repo .config_writer ().set_value (
120
+ "user" ,
"email" ,
"[email protected] "
121
+ ).release ()
122
+
123
+
80
124
class OverleafGitPaperRemote (MultiDocumentPaperRemote ):
81
125
"""
82
126
Overleaf exposes a git remote for each project. This class handles reading
83
127
and writing to Overleaf documents using gitpython, and implements the
84
128
PaperRemote protocol for use by the AI editor.
85
129
"""
86
130
87
- current_revision_id : RevisionID
88
-
89
131
def __init__ (self , git_cached_repo : str ):
90
132
"""
91
133
Saves the git repo to a local temporary directory using gitpython.
@@ -100,6 +142,10 @@ def __init__(self, git_cached_repo: str):
100
142
self ._cached_repo : Repo | None = None
101
143
self .refresh ()
102
144
145
+ @property
146
+ def current_revision_id (self ) -> RevisionID :
147
+ return self ._get_repo ().head .commit .hexsha
148
+
103
149
def _get_repo (self ) -> Repo :
104
150
if self ._cached_repo is None :
105
151
# TODO - this makes me anxious about race conditions. every time we refresh,
@@ -119,7 +165,7 @@ def _doc_id_to_path(self, doc_id: DocumentID) -> pathlib.Path:
119
165
# so we can cast to a string on this next line:
120
166
return pathlib .Path (git_root ) / str (doc_id )
121
167
122
- def refresh (self ):
168
+ def refresh (self , retry : bool = True ):
123
169
"""
124
170
This is a fallback method (that likely needs some love) to ensure that
125
171
the repo is up to date with the latest upstream changes.
@@ -134,6 +180,7 @@ def refresh(self):
134
180
)
135
181
136
182
self ._cached_repo = Repo (f"/tmp/{ self ._reposlug } " )
183
+ _add_git_user_from_config (self ._cached_repo )
137
184
138
185
logger .info (f"Pulling latest from repo { self ._reposlug } ." )
139
186
try :
@@ -143,15 +190,14 @@ def refresh(self):
143
190
f"Latest change at { self ._get_repo ().head .commit .committed_datetime } "
144
191
)
145
192
logger .info (f"Repo dirty: { self ._get_repo ().is_dirty ()} " )
146
- self .current_revision_id = self ._get_repo ().head .commit .hexsha
147
193
try :
148
194
self ._get_repo ().git .stash ("pop" )
149
- except Exception as e :
195
+ except GitCommandError as e :
150
196
# TODO: this just means there was nothing to pop, but
151
197
# we should handle this more gracefully.
152
198
logger .debug (f"Nothing to pop: { e } " )
153
199
pass
154
- except Exception as e :
200
+ except GitCommandError as e :
155
201
logger .error (
156
202
f"Error pulling from repo { self ._reposlug } : { e } . "
157
203
"Falling back on DESTRUCTION!!!"
@@ -161,7 +207,10 @@ def refresh(self):
161
207
self ._cached_repo = None
162
208
# recursively delete the repo
163
209
shutil .rmtree (f"/tmp/{ self ._reposlug } " )
164
- self .refresh ()
210
+ if retry :
211
+ self .refresh (retry = False )
212
+ else :
213
+ raise e
165
214
166
215
def list_doc_ids (self ) -> list [DocumentID ]:
167
216
"""
@@ -196,14 +245,15 @@ def is_edit_ok(self, edit: EditTrigger) -> bool:
196
245
# want to wait for the user to move on to the next line.
197
246
for doc_range in edit .input_ranges + edit .output_ranges :
198
247
repo_scoped_file = str (self ._doc_id_to_path (doc_range .doc_id ))
199
- for i in range (doc_range .selection [0 ], doc_range .selection [1 ]):
200
- if _too_close_to_human_edits (self ._get_repo (), repo_scoped_file , i ):
201
- logger .info (
202
- f"Temporarily skipping edit request in { doc_range .doc_id } "
203
- " at line {i} because it was still in progress"
204
- " in the last commit."
205
- )
206
- return False
248
+ if _too_close_to_human_edits (
249
+ self ._get_repo (), repo_scoped_file , doc_range .selection
250
+ ):
251
+ logger .info (
252
+ f"Temporarily skipping edit request in { doc_range .doc_id } "
253
+ " at line {i} because it was still in progress"
254
+ " in the last commit."
255
+ )
256
+ return False
207
257
return True
208
258
209
259
def to_dict (self ):
@@ -221,27 +271,30 @@ def perform_edit(self, edit: EditResult) -> bool:
221
271
Returns:
222
272
True if the edit was successful, False otherwise
223
273
"""
274
+ if not self ._doc_id_to_path (edit .range .doc_id ).exists ():
275
+ logger .error (f"Document { edit .range .doc_id } does not exist." )
276
+ return False
277
+
224
278
logger .info (f"Performing edit { edit } on remote { self ._reposlug } " )
225
279
226
- if edit .type == EditType .replace :
227
- success = self ._perform_replace (edit )
228
- elif edit .type == EditType .comment :
229
- success = self ._perform_comment (edit )
230
- else :
231
- raise ValueError (f"Unknown edit type { edit .type } " )
280
+ try :
281
+ with self .rewind (edit .range .revision_id , message = "AI edit" ) as paper :
282
+ if edit .type == EditType .replace :
283
+ success = paper ._perform_replace (edit )
284
+ elif edit .type == EditType .comment :
285
+ success = paper ._perform_comment (edit )
286
+ else :
287
+ raise ValueError (f"Unknown edit type { edit .type } " )
288
+ except GitCommandError as e :
289
+ logger .error (
290
+ f"Git error performing edit { edit } on remote { self ._reposlug } : { e } "
291
+ )
292
+ success = False
232
293
233
294
if success :
234
- # TODO - apply edit relative to the edit.range.revision_id commit and then
235
- # rebase onto HEAD for poor-man's operational transforms
236
- self ._get_repo ().index .add ([self ._doc_id_to_path (str (edit .range .doc_id ))])
237
- self ._get_repo ().index .commit ("AI edit completed." )
238
- # Instead of just pushing, we need to rebase and then push.
239
- # This is because we want to make sure that the AI edits are always
240
- # on top of the stack.
241
- self ._get_repo ().git .pull ()
242
- # TODO: We could do a better job catching WARNs here and then maybe setting
243
- # success = False
244
295
self ._get_repo ().git .push ()
296
+ else :
297
+ self .refresh ()
245
298
246
299
return success
247
300
@@ -257,6 +310,19 @@ def _perform_replace(self, edit: EditResult) -> bool:
257
310
"""
258
311
doc_range , text = edit .range , edit .content
259
312
try :
313
+ num_lines = len (self .get_lines (doc_range .doc_id ))
314
+ if (
315
+ any (i < 0 for i in doc_range .selection )
316
+ or doc_range .selection [1 ] < doc_range .selection [0 ]
317
+ or any (
318
+ i > len (self .get_lines (doc_range .doc_id ))
319
+ for i in doc_range .selection
320
+ )
321
+ ):
322
+ raise IndexError (
323
+ f"Invalid selection { doc_range .selection } for document "
324
+ f"{ doc_range .doc_id } with { num_lines } lines."
325
+ )
260
326
lines = self .get_lines (doc_range .doc_id )
261
327
lines = (
262
328
lines [: doc_range .selection [0 ]]
@@ -284,3 +350,49 @@ def _perform_comment(self, edit: EditResult) -> bool:
284
350
# TODO - implement this for real
285
351
logger .info (f"Performing comment edit { edit } on remote { self ._reposlug } " )
286
352
return True
353
+
354
+ def rewind (self , commit : str , message : str ):
355
+ return self .RewindContext (self , commit , message )
356
+
357
+ # Create an inner class for "with" semantics so that we can do
358
+ # `with remote.rewind(commit)` to rewind to a particular commit and play some edits
359
+ # onto it, then merge when the 'with' context exits.
360
+ class RewindContext :
361
+ # TODO - there are tricks in gitpython where an IndexFile can be used to
362
+ # handle changes to files in-memory without having to call checkout() and
363
+ # (briefly) modify the state of things on disk. This would be an improvement,
364
+ # but would require using the gitpython API more directly inside of
365
+ # perform_edit, such as calling git.IndexFile.write() instead of python's
366
+ # open() and write()
367
+
368
+ def __init__ (self , remote : "OverleafGitPaperRemote" , commit : str , message : str ):
369
+ self ._remote = remote
370
+ self ._message = message
371
+ self ._rewind_commit = commit
372
+
373
+ def __enter__ (self ):
374
+ repo = self ._remote ._get_repo ()
375
+ self ._restore_ref = repo .head .ref
376
+ self ._new_branch = repo .create_head (
377
+ "tmp-edit-branch" , commit = self ._rewind_commit
378
+ )
379
+ self ._new_branch .checkout ()
380
+ return self ._remote
381
+
382
+ def __exit__ (self , exc_type , exc_val , exc_tb ):
383
+ repo = self ._remote ._get_repo ()
384
+ assert (
385
+ repo .active_branch == self ._new_branch
386
+ ), "Branch changed unexpectedly mid-`with`"
387
+ # Add files that changed
388
+ repo .index .add ([_file for (_file , _ ), _ in repo .index .entries .items ()])
389
+ repo .index .commit (self ._message )
390
+ self ._restore_ref .checkout ()
391
+ try :
392
+ repo .git .merge ("tmp-edit-branch" )
393
+ except GitCommandError as e :
394
+ # Hard reset on failure
395
+ repo .git .reset ("--hard" , self ._restore_ref .commit .hexsha )
396
+ raise e
397
+ finally :
398
+ repo .delete_head (self ._new_branch , force = True )
0 commit comments