Skip to content

Commit

Permalink
Lock around Tracker creation and destruction
Browse files Browse the repository at this point in the history
The `Tracker.__enter__` and `Tracker.__exit` methods may wind up
releasing the GIL, which allows another thread to see an intermediate
state where the tracker is not fully installed. Add a lock, shared
across all trackers, to serialize access to the global state that
tracker installation and uninstallation writes to.

Signed-off-by: Matt Wozniski <[email protected]>
  • Loading branch information
godlygeek committed Aug 20, 2024
1 parent f848fc3 commit c4544c7
Showing 1 changed file with 37 additions and 32 deletions.
69 changes: 37 additions & 32 deletions src/memray/_memray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,9 @@ cdef class ProfileFunctionGuard:
NativeTracker.forgetPythonStack()


tracker_creation_lock = threading.Lock()


cdef class Tracker:
"""Context manager for tracking memory allocations in a Python script.
Expand Down Expand Up @@ -690,46 +693,48 @@ cdef class Tracker:

@cython.profile(False)
def __enter__(self):
if NativeTracker.getTracker() != NULL:
raise RuntimeError("No more than one Tracker instance can be active at the same time")

cdef unique_ptr[RecordWriter] writer
if self._writer == NULL:
raise RuntimeError("Attempting to use stale output handle")
writer = move(self._writer)

for attr in ("_name", "_ident"):
assert not hasattr(threading.Thread, attr)
setattr(
threading.Thread,
attr,
ThreadNameInterceptor(attr, NativeTracker.registerThreadNameById),
)
with tracker_creation_lock:
if NativeTracker.getTracker() != NULL:
raise RuntimeError("No more than one Tracker instance can be active at the same time")

if self._writer == NULL:
raise RuntimeError("Attempting to use stale output handle")
writer = move(self._writer)

for attr in ("_name", "_ident"):
assert not hasattr(threading.Thread, attr)
setattr(
threading.Thread,
attr,
ThreadNameInterceptor(attr, NativeTracker.registerThreadNameById),
)

self._previous_profile_func = sys.getprofile()
self._previous_thread_profile_func = threading._profile_hook
threading.setprofile(start_thread_trace)
self._previous_profile_func = sys.getprofile()
self._previous_thread_profile_func = threading._profile_hook
threading.setprofile(start_thread_trace)

if "greenlet" in sys.modules:
NativeTracker.beginTrackingGreenlets()
if "greenlet" in sys.modules:
NativeTracker.beginTrackingGreenlets()

NativeTracker.createTracker(
move(writer),
self._native_traces,
self._memory_interval_ms,
self._follow_fork,
self._trace_python_allocators,
)
return self
NativeTracker.createTracker(
move(writer),
self._native_traces,
self._memory_interval_ms,
self._follow_fork,
self._trace_python_allocators,
)
return self

@cython.profile(False)
def __exit__(self, exc_type, exc_value, exc_traceback):
NativeTracker.destroyTracker()
sys.setprofile(self._previous_profile_func)
threading.setprofile(self._previous_thread_profile_func)
with tracker_creation_lock:
NativeTracker.destroyTracker()
sys.setprofile(self._previous_profile_func)
threading.setprofile(self._previous_thread_profile_func)

for attr in ("_name", "_ident"):
delattr(threading.Thread, attr)
for attr in ("_name", "_ident"):
delattr(threading.Thread, attr)


def start_thread_trace(frame, event, arg):
Expand Down

0 comments on commit c4544c7

Please sign in to comment.