Skip to content

Commit 992d48a

Browse files
committed
Refactor, document, make more robust and remove a lock
1 parent c0750ab commit 992d48a

File tree

1 file changed

+51
-40
lines changed

1 file changed

+51
-40
lines changed

src/C/context.jl

Lines changed: 51 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,59 @@ function _atpyexit()
3030
end
3131

3232

33-
const MAIN_THREAD_TASK_LOCK = ReentrantLock()
34-
const MAIN_THREAD_CHANNEL_INPUT = Channel(1)
35-
const MAIN_THREAD_CHANNEL_OUTPUT = Channel(1)
36-
37-
# Execute f() on the main thread.
38-
function on_main_thread(f)
39-
@lock MAIN_THREAD_TASK_LOCK begin
40-
put!(MAIN_THREAD_CHANNEL_INPUT, f)
41-
take!(MAIN_THREAD_CHANNEL_OUTPUT)
33+
function setup_onfixedthread()
34+
channel_input = Channel(1)
35+
channel_output = Channel(1)
36+
islaunched = Ref(false) # use Ref to avoid closure boxing of variable
37+
function launch_worker(tid)
38+
islaunched[] && error("Cannot launch more than once: call setup_onfixedthread again if need be.")
39+
islaunched[] = true
40+
worker_task = Task() do
41+
while true
42+
f = take!(channel_input)
43+
put!(channel_output, f())
44+
end
45+
end
46+
# code adapted from set_task_tid! in StableTasks.jl, itself taken from Dagger.jl
47+
worker_task.sticky = true
48+
for _ in 1:100
49+
# try to fix the task id to tid, retrying up to 100 times
50+
ret = ccall(:jl_set_task_tid, Cint, (Any, Cint), worker_task, tid-1)
51+
if ret == 1
52+
break # success
53+
elseif ret == 0
54+
yield()
55+
else
56+
error("Unexpected retcode from jl_set_task_tid: $ret")
57+
end
58+
end
59+
if Threads.threadid(worker_task) != tid
60+
error("Failed setting the thread ID to $tid.")
61+
end
62+
schedule(worker_task)
63+
end
64+
function onfixedthread(f)
65+
put!(channel_input, f)
66+
take!(channel_output)
4267
end
68+
launch_worker, onfixedthread
4369
end
4470

71+
# launch_on_main_thread is used in init_context(), after which on_main_thread becomes usable
72+
const launch_on_main_thread, on_main_thread = setup_onfixedthread()
73+
74+
"""
75+
on_main_thread(f)
76+
77+
Execute `f()` on the main thread.
78+
79+
!!! warning
80+
The value returned by `on_main_thread(f)` cannot be type-inferred by the compiler:
81+
if necessary, use explicit type annotations such as `on_main_thread(f)::T`, where `T` is
82+
the expected return type.
83+
"""
84+
on_main_thread
85+
4586

4687
function init_context()
4788

@@ -254,14 +295,7 @@ function init_context()
254295
"Only Python 3.9+ is supported, this is Python $(CTX.version) at $(CTX.exe_path===missing ? "unknown location" : CTX.exe_path).",
255296
)
256297

257-
main_thread_task = Task() do
258-
while true
259-
f = take!(MAIN_THREAD_CHANNEL_INPUT)
260-
put!(MAIN_THREAD_CHANNEL_OUTPUT, f())
261-
end
262-
end
263-
set_task_tid!(main_thread_task, Threads.threadid())
264-
schedule(main_thread_task)
298+
launch_on_main_thread(Threads.threadid()) # makes on_main_thread usable
265299

266300
@debug "Initialized PythonCall.jl" CTX.is_embedded CTX.is_initialized CTX.exe_path CTX.lib_path CTX.lib_ptr CTX.pyprogname CTX.pyhome CTX.version
267301

@@ -283,26 +317,3 @@ const PYTHONCALL_PKGID = Base.PkgId(PYTHONCALL_UUID, "PythonCall")
283317

284318
const PYCALL_UUID = Base.UUID("438e738f-606a-5dbb-bf0a-cddfbfd45ab0")
285319
const PYCALL_PKGID = Base.PkgId(PYCALL_UUID, "PyCall")
286-
287-
288-
# taken from StableTasks.jl, itself taken from Dagger.jl
289-
function set_task_tid!(task::Task, tid::Integer)
290-
task.sticky = true
291-
ctr = 0
292-
while true
293-
ret = ccall(:jl_set_task_tid, Cint, (Any, Cint), task, tid-1)
294-
if ret == 1
295-
break
296-
elseif ret == 0
297-
yield()
298-
else
299-
error("Unexpected retcode from jl_set_task_tid: $ret")
300-
end
301-
ctr += 1
302-
if ctr > 10
303-
@warn "Setting task TID to $tid failed, giving up!"
304-
return
305-
end
306-
end
307-
@assert Threads.threadid(task) == tid "jl_set_task_tid failed!"
308-
end

0 commit comments

Comments
 (0)