@@ -29,6 +29,20 @@ function _atpyexit()
29
29
return
30
30
end
31
31
32
+
33
+ const MAIN_THREAD_TASK_LOCK = ReentrantLock ()
34
+ const MAIN_THREAD_CHANNEL_INPUT = Channel (1 )
35
+ const MAIN_THREAD_CHANNEL_OUTPUT = Channel (1 )
36
+
37
+ # Execute f() on the main thread.
38
+ function on_main_thread (f)
39
+ @lock MAIN_THREAD_TASK_LOCK begin
40
+ put! (MAIN_THREAD_CHANNEL_INPUT, f)
41
+ take! (MAIN_THREAD_CHANNEL_OUTPUT)
42
+ end
43
+ end
44
+
45
+
32
46
function init_context ()
33
47
34
48
CTX. is_embedded = hasproperty (Base. Main, :__PythonCall_libptr )
@@ -240,6 +254,15 @@ function init_context()
240
254
" Only Python 3.9+ is supported, this is Python $(CTX. version) at $(CTX. exe_path=== missing ? " unknown location" : CTX. exe_path) ." ,
241
255
)
242
256
257
+ main_thread_task = Task () do
258
+ while true
259
+ f = take! (MAIN_THREAD_CHANNEL_INPUT)
260
+ put! (MAIN_THREAD_CHANNEL_OUTPUT, f ())
261
+ end
262
+ end
263
+ set_task_tid! (main_thread_task, Threads. threadid ())
264
+ schedule (main_thread_task)
265
+
243
266
@debug " Initialized PythonCall.jl" CTX. is_embedded CTX. is_initialized CTX. exe_path CTX. lib_path CTX. lib_ptr CTX. pyprogname CTX. pyhome CTX. version
244
267
245
268
return
@@ -260,3 +283,26 @@ const PYTHONCALL_PKGID = Base.PkgId(PYTHONCALL_UUID, "PythonCall")
260
283
261
284
const PYCALL_UUID = Base. UUID (" 438e738f-606a-5dbb-bf0a-cddfbfd45ab0" )
262
285
const PYCALL_PKGID = Base. PkgId (PYCALL_UUID, " PyCall" )
286
+
287
+
288
+ # taken from StableTasks.jl, itself taken from Dagger.jl
289
+ function set_task_tid! (task:: Task , tid:: Integer )
290
+ task. sticky = true
291
+ ctr = 0
292
+ while true
293
+ ret = ccall (:jl_set_task_tid , Cint, (Any, Cint), task, tid- 1 )
294
+ if ret == 1
295
+ break
296
+ elseif ret == 0
297
+ yield ()
298
+ else
299
+ error (" Unexpected retcode from jl_set_task_tid: $ret " )
300
+ end
301
+ ctr += 1
302
+ if ctr > 10
303
+ @warn " Setting task TID to $tid failed, giving up!"
304
+ return
305
+ end
306
+ end
307
+ @assert Threads. threadid (task) == tid " jl_set_task_tid failed!"
308
+ end
0 commit comments