Skip to content

Commit a10152f

Browse files
authored
gh-137400: Fix thread-safety issues when profiling all threads (gh-137518)
There were a few thread-safety issues when profiling or tracing all threads via PyEval_SetProfileAllThreads or PyEval_SetTraceAllThreads:

* The loop over thread states could crash if a thread exited concurrently (in both the free threading and default builds).
* The modification of `c_profilefunc` and `c_tracefunc` wasn't thread-safe on the free threading build.
1 parent 923d686 commit a10152f

File tree

11 files changed: 429 additions (+429) and 239 deletions (-239).

Include/internal/pycore_ceval.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ struct _ceval_runtime_state;
2222

2323
// Export for '_lsprof' shared extension
2424
PyAPI_FUNC(int) _PyEval_SetProfile(PyThreadState *tstate, Py_tracefunc func, PyObject *arg);
25+
extern int _PyEval_SetProfileAllThreads(PyInterpreterState *interp, Py_tracefunc func, PyObject *arg);
2526

2627
extern int _PyEval_SetTrace(PyThreadState *tstate, Py_tracefunc func, PyObject *arg);
28+
extern int _PyEval_SetTraceAllThreads(PyInterpreterState *interp, Py_tracefunc func, PyObject *arg);
2729

2830
extern int _PyEval_SetOpcodeTrace(PyFrameObject *f, bool enable);
2931

Include/internal/pycore_interp_structs.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ struct _ceval_runtime_state {
9999
// For example, we use a preallocated array
100100
// for the list of pending calls.
101101
struct _pending_calls pending_mainthread;
102-
PyMutex sys_trace_profile_mutex;
103102
};
104103

105104

@@ -951,8 +950,8 @@ struct _is {
951950
PyDict_WatchCallback builtins_dict_watcher;
952951

953952
_Py_GlobalMonitors monitors;
954-
bool sys_profile_initialized;
955-
bool sys_trace_initialized;
953+
_PyOnceFlag sys_profile_once_flag;
954+
_PyOnceFlag sys_trace_once_flag;
956955
Py_ssize_t sys_profiling_threads; /* Count of threads with c_profilefunc set */
957956
Py_ssize_t sys_tracing_threads; /* Count of threads with c_tracefunc set */
958957
PyObject *monitoring_callables[PY_MONITORING_TOOL_IDS][_PY_MONITORING_EVENTS];

Lib/test/test_free_threading/test_monitoring.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,31 @@ def during_threads(self):
195195

196196

197197
@threading_helper.requires_working_threading()
198+
class SetProfileAllThreadsMultiThreaded(InstrumentationMultiThreadedMixin, TestCase):
199+
"""Uses threading.setprofile_all_threads and repeatedly toggles instrumentation on and off"""
200+
201+
def setUp(self):
202+
self.set = False
203+
self.called = False
204+
205+
def after_test(self):
206+
self.assertTrue(self.called)
207+
208+
def tearDown(self):
209+
threading.setprofile_all_threads(None)
210+
211+
def trace_func(self, frame, event, arg):
212+
self.called = True
213+
return self.trace_func
214+
215+
def during_threads(self):
216+
if self.set:
217+
threading.setprofile_all_threads(self.trace_func)
218+
else:
219+
threading.setprofile_all_threads(None)
220+
self.set = not self.set
221+
222+
198223
class SetProfileAllMultiThreaded(TestCase):
199224
def test_profile_all_threads(self):
200225
done = threading.Event()
@@ -421,6 +446,38 @@ def noop():
421446

422447
self.observe_threads(noop, buf)
423448

449+
def test_trace_concurrent(self):
450+
# Test calling a function concurrently from a tracing and a non-tracing
451+
# thread
452+
b = threading.Barrier(2)
453+
454+
def func():
455+
for _ in range(100):
456+
pass
457+
458+
def noop():
459+
pass
460+
461+
def bg_thread():
462+
b.wait()
463+
func() # this may instrument `func`
464+
465+
def tracefunc(frame, event, arg):
466+
# These calls, which run under tracing, can race with the background thread
467+
for _ in range(10):
468+
func()
469+
return tracefunc
470+
471+
t = Thread(target=bg_thread)
472+
t.start()
473+
try:
474+
sys.settrace(tracefunc)
475+
b.wait()
476+
noop()
477+
finally:
478+
sys.settrace(None)
479+
t.join()
480+
424481

425482
if __name__ == "__main__":
426483
unittest.main()
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Fix a crash in the :term:`free threading` build when disabling profiling or
2+
tracing across all threads with :c:func:`PyEval_SetProfileAllThreads` or
3+
:c:func:`PyEval_SetTraceAllThreads` or their Python equivalents
4+
:func:`threading.settrace_all_threads` and
5+
:func:`threading.setprofile_all_threads`.

Python/bytecodes.c

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,15 @@ dummy_func(
178178
}
179179

180180
tier1 op(_MAYBE_INSTRUMENT, (--)) {
181-
if (tstate->tracing == 0) {
181+
#ifdef Py_GIL_DISABLED
182+
// For thread-safety, we need to check instrumentation version
183+
// even when tracing. Otherwise, another thread may concurrently
184+
// re-write the bytecode while we are executing this function.
185+
int check_instrumentation = 1;
186+
#else
187+
int check_instrumentation = (tstate->tracing == 0);
188+
#endif
189+
if (check_instrumentation) {
182190
uintptr_t global_version = _Py_atomic_load_uintptr_relaxed(&tstate->eval_breaker) & ~_PY_EVAL_EVENTS_MASK;
183191
uintptr_t code_version = FT_ATOMIC_LOAD_UINTPTR_ACQUIRE(_PyFrame_GetCode(frame)->_co_instrumentation_version);
184192
if (code_version != global_version) {

Python/ceval.c

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2510,21 +2510,10 @@ PyEval_SetProfile(Py_tracefunc func, PyObject *arg)
25102510
void
25112511
PyEval_SetProfileAllThreads(Py_tracefunc func, PyObject *arg)
25122512
{
2513-
PyThreadState *this_tstate = _PyThreadState_GET();
2514-
PyInterpreterState* interp = this_tstate->interp;
2515-
2516-
_PyRuntimeState *runtime = &_PyRuntime;
2517-
HEAD_LOCK(runtime);
2518-
PyThreadState* ts = PyInterpreterState_ThreadHead(interp);
2519-
HEAD_UNLOCK(runtime);
2520-
2521-
while (ts) {
2522-
if (_PyEval_SetProfile(ts, func, arg) < 0) {
2523-
PyErr_FormatUnraisable("Exception ignored in PyEval_SetProfileAllThreads");
2524-
}
2525-
HEAD_LOCK(runtime);
2526-
ts = PyThreadState_Next(ts);
2527-
HEAD_UNLOCK(runtime);
2513+
PyInterpreterState *interp = _PyInterpreterState_GET();
2514+
if (_PyEval_SetProfileAllThreads(interp, func, arg) < 0) {
2515+
/* Log _PySys_Audit() error */
2516+
PyErr_FormatUnraisable("Exception ignored in PyEval_SetProfileAllThreads");
25282517
}
25292518
}
25302519

@@ -2541,21 +2530,10 @@ PyEval_SetTrace(Py_tracefunc func, PyObject *arg)
25412530
void
25422531
PyEval_SetTraceAllThreads(Py_tracefunc func, PyObject *arg)
25432532
{
2544-
PyThreadState *this_tstate = _PyThreadState_GET();
2545-
PyInterpreterState* interp = this_tstate->interp;
2546-
2547-
_PyRuntimeState *runtime = &_PyRuntime;
2548-
HEAD_LOCK(runtime);
2549-
PyThreadState* ts = PyInterpreterState_ThreadHead(interp);
2550-
HEAD_UNLOCK(runtime);
2551-
2552-
while (ts) {
2553-
if (_PyEval_SetTrace(ts, func, arg) < 0) {
2554-
PyErr_FormatUnraisable("Exception ignored in PyEval_SetTraceAllThreads");
2555-
}
2556-
HEAD_LOCK(runtime);
2557-
ts = PyThreadState_Next(ts);
2558-
HEAD_UNLOCK(runtime);
2533+
PyInterpreterState *interp = _PyInterpreterState_GET();
2534+
if (_PyEval_SetTraceAllThreads(interp, func, arg) < 0) {
2535+
/* Log _PySys_Audit() error */
2536+
PyErr_FormatUnraisable("Exception ignored in PyEval_SetTraceAllThreads");
25592537
}
25602538
}
25612539

Python/generated_cases.c.h

Lines changed: 14 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Python/instrumentation.c

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,8 @@ set_version_raw(uintptr_t *ptr, uint32_t version)
10401040
static void
10411041
set_global_version(PyThreadState *tstate, uint32_t version)
10421042
{
1043+
ASSERT_WORLD_STOPPED();
1044+
10431045
assert((version & _PY_EVAL_EVENTS_MASK) == 0);
10441046
PyInterpreterState *interp = tstate->interp;
10451047
set_version_raw(&interp->ceval.instrumentation_version, version);
@@ -1939,28 +1941,26 @@ _Py_Instrument(PyCodeObject *code, PyInterpreterState *interp)
19391941

19401942

19411943
static int
1942-
instrument_all_executing_code_objects(PyInterpreterState *interp) {
1944+
instrument_all_executing_code_objects(PyInterpreterState *interp)
1945+
{
19431946
ASSERT_WORLD_STOPPED();
19441947

1945-
_PyRuntimeState *runtime = &_PyRuntime;
1946-
HEAD_LOCK(runtime);
1947-
PyThreadState* ts = PyInterpreterState_ThreadHead(interp);
1948-
HEAD_UNLOCK(runtime);
1949-
while (ts) {
1948+
int err = 0;
1949+
_Py_FOR_EACH_TSTATE_BEGIN(interp, ts) {
19501950
_PyInterpreterFrame *frame = ts->current_frame;
19511951
while (frame) {
19521952
if (frame->owner < FRAME_OWNED_BY_INTERPRETER) {
1953-
if (instrument_lock_held(_PyFrame_GetCode(frame), interp)) {
1954-
return -1;
1953+
err = instrument_lock_held(_PyFrame_GetCode(frame), interp);
1954+
if (err) {
1955+
goto done;
19551956
}
19561957
}
19571958
frame = frame->previous;
19581959
}
1959-
HEAD_LOCK(runtime);
1960-
ts = PyThreadState_Next(ts);
1961-
HEAD_UNLOCK(runtime);
19621960
}
1963-
return 0;
1961+
done:
1962+
_Py_FOR_EACH_TSTATE_END(interp);
1963+
return err;
19641964
}
19651965

19661966
static void
@@ -2006,6 +2006,7 @@ check_tool(PyInterpreterState *interp, int tool_id)
20062006
int
20072007
_PyMonitoring_SetEvents(int tool_id, _PyMonitoringEventSet events)
20082008
{
2009+
ASSERT_WORLD_STOPPED();
20092010
assert(0 <= tool_id && tool_id < PY_MONITORING_TOOL_IDS);
20102011
PyThreadState *tstate = _PyThreadState_GET();
20112012
PyInterpreterState *interp = tstate->interp;
@@ -2014,33 +2015,28 @@ _PyMonitoring_SetEvents(int tool_id, _PyMonitoringEventSet events)
20142015
return -1;
20152016
}
20162017

2017-
int res;
2018-
_PyEval_StopTheWorld(interp);
20192018
uint32_t existing_events = get_events(&interp->monitors, tool_id);
20202019
if (existing_events == events) {
2021-
res = 0;
2022-
goto done;
2020+
return 0;
20232021
}
20242022
set_events(&interp->monitors, tool_id, events);
20252023
uint32_t new_version = global_version(interp) + MONITORING_VERSION_INCREMENT;
20262024
if (new_version == 0) {
20272025
PyErr_Format(PyExc_OverflowError, "events set too many times");
2028-
res = -1;
2029-
goto done;
2026+
return -1;
20302027
}
20312028
set_global_version(tstate, new_version);
20322029
#ifdef _Py_TIER2
20332030
_Py_Executors_InvalidateAll(interp, 1);
20342031
#endif
2035-
res = instrument_all_executing_code_objects(interp);
2036-
done:
2037-
_PyEval_StartTheWorld(interp);
2038-
return res;
2032+
return instrument_all_executing_code_objects(interp);
20392033
}
20402034

20412035
int
20422036
_PyMonitoring_SetLocalEvents(PyCodeObject *code, int tool_id, _PyMonitoringEventSet events)
20432037
{
2038+
ASSERT_WORLD_STOPPED();
2039+
20442040
assert(0 <= tool_id && tool_id < PY_MONITORING_TOOL_IDS);
20452041
PyInterpreterState *interp = _PyInterpreterState_GET();
20462042
assert(events < (1 << _PY_MONITORING_LOCAL_EVENTS));
@@ -2052,28 +2048,20 @@ _PyMonitoring_SetLocalEvents(PyCodeObject *code, int tool_id, _PyMonitoringEvent
20522048
return -1;
20532049
}
20542050

2055-
int res;
2056-
_PyEval_StopTheWorld(interp);
20572051
if (allocate_instrumentation_data(code)) {
2058-
res = -1;
2059-
goto done;
2052+
return -1;
20602053
}
20612054

20622055
code->_co_monitoring->tool_versions[tool_id] = interp->monitoring_tool_versions[tool_id];
20632056

20642057
_Py_LocalMonitors *local = &code->_co_monitoring->local_monitors;
20652058
uint32_t existing_events = get_local_events(local, tool_id);
20662059
if (existing_events == events) {
2067-
res = 0;
2068-
goto done;
2060+
return 0;
20692061
}
20702062
set_local_events(local, tool_id, events);
20712063

2072-
res = force_instrument_lock_held(code, interp);
2073-
2074-
done:
2075-
_PyEval_StartTheWorld(interp);
2076-
return res;
2064+
return force_instrument_lock_held(code, interp);
20772065
}
20782066

20792067
int
@@ -2105,11 +2093,12 @@ int _PyMonitoring_ClearToolId(int tool_id)
21052093
}
21062094
}
21072095

2096+
_PyEval_StopTheWorld(interp);
21082097
if (_PyMonitoring_SetEvents(tool_id, 0) < 0) {
2098+
_PyEval_StartTheWorld(interp);
21092099
return -1;
21102100
}
21112101

2112-
_PyEval_StopTheWorld(interp);
21132102
uint32_t version = global_version(interp) + MONITORING_VERSION_INCREMENT;
21142103
if (version == 0) {
21152104
PyErr_Format(PyExc_OverflowError, "events set too many times");
@@ -2346,7 +2335,11 @@ monitoring_set_events_impl(PyObject *module, int tool_id, int event_set)
23462335
event_set &= ~(1 << PY_MONITORING_EVENT_BRANCH);
23472336
event_set |= (1 << PY_MONITORING_EVENT_BRANCH_RIGHT) | (1 << PY_MONITORING_EVENT_BRANCH_LEFT);
23482337
}
2349-
if (_PyMonitoring_SetEvents(tool_id, event_set)) {
2338+
PyInterpreterState *interp = _PyInterpreterState_GET();
2339+
_PyEval_StopTheWorld(interp);
2340+
int err = _PyMonitoring_SetEvents(tool_id, event_set);
2341+
_PyEval_StartTheWorld(interp);
2342+
if (err) {
23502343
return NULL;
23512344
}
23522345
Py_RETURN_NONE;
@@ -2427,7 +2420,11 @@ monitoring_set_local_events_impl(PyObject *module, int tool_id,
24272420
return NULL;
24282421
}
24292422

2430-
if (_PyMonitoring_SetLocalEvents((PyCodeObject*)code, tool_id, event_set)) {
2423+
PyInterpreterState *interp = _PyInterpreterState_GET();
2424+
_PyEval_StopTheWorld(interp);
2425+
int err = _PyMonitoring_SetLocalEvents((PyCodeObject*)code, tool_id, event_set);
2426+
_PyEval_StartTheWorld(interp);
2427+
if (err) {
24312428
return NULL;
24322429
}
24332430
Py_RETURN_NONE;

0 commit comments

Comments
 (0)