diff --git a/dmoj/cptbox/_cptbox.pyi b/dmoj/cptbox/_cptbox.pyi index 7a52be6bf..915d74f27 100644 --- a/dmoj/cptbox/_cptbox.pyi +++ b/dmoj/cptbox/_cptbox.pyi @@ -12,6 +12,10 @@ PTBOX_ABI_COUNT: int ALL_ABIS: List[int] SUPPORTED_ABIS: List[int] +class ArgAccessor: + def __getitem__(self, reg : int) -> int: ... + def __setitem__(self, reg : int, val : int) -> None: ... + class Debugger: syscall: int result: int @@ -22,6 +26,7 @@ class Debugger: arg3: int arg4: int arg5: int + arg: ArgAccessor uresult: int uarg0: int @@ -30,6 +35,7 @@ class Debugger: uarg3: int uarg4: int uarg5: int + uarg: ArgAccessor pid: int tid: int diff --git a/dmoj/cptbox/_cptbox.pyx b/dmoj/cptbox/_cptbox.pyx index c92829669..b218436ac 100644 --- a/dmoj/cptbox/_cptbox.pyx +++ b/dmoj/cptbox/_cptbox.pyx @@ -225,6 +225,22 @@ def memory_fd_seal(int fd): cdef class Process +cdef class Debugger + +cdef class ArgAccessor: + cdef Debugger debugger + cdef object base + + def __cinit__(self, Debugger debugger, object base): + self.debugger = debugger + self.base = base + + def __getitem__(self, reg): + return getattr(self.debugger, '%s%d' % (self.base, reg)) + + def __setitem__(self, reg, val): + setattr(self.debugger, '%s%d' % (self.base, reg), val) + cdef class Debugger: cdef pt_debugger *thisptr @@ -235,6 +251,8 @@ cdef class Debugger: self.thisptr = new pt_debugger() self.process = process self.on_return_callback = {} + self.uarg = ArgAccessor(self, "uarg") + self.arg = ArgAccessor(self, "arg") def __dealloc__(self): del self.thisptr diff --git a/dmoj/cptbox/isolate.py b/dmoj/cptbox/isolate.py index 358948158..c46a54778 100644 --- a/dmoj/cptbox/isolate.py +++ b/dmoj/cptbox/isolate.py @@ -213,7 +213,7 @@ def _compile_fs_jail(self, fs) -> FilesystemPolicy: def _dirfd_getter_from_reg(self, reg: int) -> DirFDGetter: def getter(debugger: Debugger) -> int: - return getattr(debugger, 'uarg%d' % reg) + return debugger.uarg[reg] return getter @@ -222,7 +222,7 @@ def _dirfd_getter_cwd(self, debugger: Debugger) -> int: def _fs_jail_getter_from_open_flags_reg(self, reg: int) -> FSJailGetter: def getter(debugger: Debugger) -> FilesystemPolicy: - open_flags = getattr(debugger, 'uarg%d' % reg) + open_flags = debugger.uarg[reg] for flag in open_write_flags: # Strict equality is necessary here, since e.g. O_TMPFILE has multiple bits set, # and O_DIRECTORY & O_TMPFILE > 0. @@ -282,7 +282,7 @@ def check(debugger: Debugger) -> None: # We already allowed this one way or another, don't check again. return - dirfd = getattr(debugger, 'uarg%d' % dir_reg) + dirfd = debugger.uarg[dir_reg] full_path = self.get_full_path_unnormalized(debugger, rel_file, dirfd=dirfd) self._access_check(debugger, full_path, self.read_fs_jail) @@ -299,9 +299,8 @@ def check(debugger: Debugger) -> None: return check def get_rel_file(self, debugger: Debugger, *, reg: int) -> str: - ptr = getattr(debugger, 'uarg%d' % reg) try: - file = debugger.readstr(ptr) + file = debugger.readstr(debugger.uarg[reg]) except MaxLengthExceeded as e: raise DeniedSyscall(ACCESS_ENAMETOOLONG, f'Overly long path: {e.args[0]}') except UnicodeDecodeError as e: