diff --git a/design/mvp/CanonicalABI.md b/design/mvp/CanonicalABI.md index 3bed87ed..65a6fa9a 100644 --- a/design/mvp/CanonicalABI.md +++ b/design/mvp/CanonicalABI.md @@ -130,8 +130,11 @@ class Store: def __init__(self): self.pending = [] - def invoke(self, f: FuncInst, caller, on_start, on_resolve) -> Call: - return f(caller, on_start, on_resolve) + def invoke(self, f: FuncInst, caller: Optional[Supertask], on_start, on_resolve) -> Call: + host_caller = Supertask() + host_caller.inst = None + host_caller.supertask = caller + return f(host_caller, on_start, on_resolve) def tick(self): random.shuffle(self.pending) @@ -167,7 +170,7 @@ OnStart = Callable[[], list[any]] OnResolve = Callable[[Optional[list[any]]], None] class Supertask: - inst: ComponentInstance + inst: Optional[ComponentInstance] supertask: Optional[Supertask] class Call: @@ -190,6 +193,14 @@ However, as described in the [concurrency explainer], an async call's (currently) that the caller can know or do about it (hence there are currently no other methods on `Call`). +The optional `Supertask.inst` field either points to the `ComponentInstance` +containing the supertask or, if `None`, indicates that the supertask is a host +function. Because `Store.invoke` unconditionally appends a host `Supertask`, +every callstack is rooted by a host `Supertask`. There is no prohibition on +component-to-host-to-component calls (as long as the recursive call condition +checked by `call_is_recursive` is satisfied) and thus host `Supertask`s may +also appear anywhere else in the callstack. + ## Supporting definitions @@ -280,14 +291,17 @@ behavior and enforce invariants. 
```python class ComponentInstance: store: Store + parent: Optional[ComponentInstance] table: Table may_leave: bool backpressure: int exclusive: bool num_waiting_to_enter: int - def __init__(self, store): + def __init__(self, store, parent = None): + assert(parent is None or parent.store is store) self.store = store + self.parent = parent self.table = Table() self.may_leave = True self.backpressure = 0 @@ -295,9 +309,106 @@ class ComponentInstance: self.num_waiting_to_enter = 0 ``` Components are always instantiated in the context of a `Store` which is saved -immutably in the `store` field. The other fields are described below as they -are used. - +immutably in the `store` field. + +If a component is instantiated by an `instantiate` expression in a "parent" +component, the parent's `ComponentInstance` is immutably saved in the `parent` +field of the child's `ComponentInstance`. If instead a component is +instantiated directly by the host, the `parent` field is `None`. Thus, the set +of component instances in a store forms a forest rooted by the component +instances that were instantiated directly by the host. + +How the host instantiates and invokes root components is up to the host and not +specified by the Component Model. Exports of previously-instantiated root +components *may* be supplied as the imports of subsequently-instantiated root +components. Due to the ordered nature of instantiation, root components cannot +directly import each other's exports in a cyclic manner. However, the host *may* +perform cyclic component-to-host-to-component calls, in the same way that a +parent component can use `call_indirect` and a table of mutable `funcref`s to +make cyclic child-to-parent-to-child calls. 
+ +Because a child component is fully encapsulated by its parent component (with +all child imports specified by the parent's `instantiate` expression and access +to all child exports controlled by the parent through its private instance index +space), the host does not have direct control over how a child component is +instantiated or invoked. However, if a child's ancestors transitively forward +the root component's host-supplied imports to the child, direct child-to-host +calls are possible. Symmetrically, if a child's ancestors transitively +re-export the child's exports from the root component, direct host-to-child +calls are possible. Consequently, direct calls between child components of +distinct parent components are also possible. + +As mentioned above, cyclic calls between components are made possible by +indirecting through a parent component or the host. However, for the time +being, a "recursive" call in which a single component instance is entered +multiple times on the same `Supertask` callstack is well-defined to trap upon +attempted reentry. There are several reasons for this trapping behavior: +* automatic [backpressure] would otherwise deadlock in unpredictable and + surprising ways; +* by default, most code does not expect [recursive reentrance] and will break + in subtle and potentially security sensitive ways if allowed; +* to properly handle recursive reentrance, an extra ABI parameter is required + to link recursive calls on the same stack and this requires opting in via + some [TBD](Concurrency.md#TODO) function effect type or canonical ABI option + +The `call_is_recursive` predicate is used by `canon_lift` and +`canon_resource_drop` (defined below) to detect recursive reentrance and +subsequently trap. The supporting `ancestors` function enumerates all +transitive parents of a node, *including the node itself*, in a Python `set`, +thereby allowing set-wise union (`|`), intersection (`&`) and difference (`-`). 
+```python +def call_is_recursive(caller: Supertask, callee_inst: ComponentInstance): + callee_insts = { callee_inst } | (ancestors(callee_inst) - ancestors(caller.inst)) + while caller is not None: + if callee_insts & ancestors(caller.inst): + return True + caller = caller.supertask + return False + +def ancestors(inst: Optional[ComponentInstance]) -> set[ComponentInstance]: + s = set() + while inst is not None: + s.add(inst) + inst = inst.parent + return s +``` +The `callee_insts` set contains all the component instances being freshly +entered by the call, always including the `callee_inst` itself. The subsequent +loop then tests whether *any* of the `callee_insts` is already on the stack. +This set-wise definition considers cases like the following to be recursive: +``` + +-------+ + | A |<-. + | +---+ | | +--->| B |----' + | +---+ | + +-------+ +``` +At the point when recursively calling back into `A`, `callee_inst` is `A` +and `caller` points to the following stack: +``` +caller --> |inst=None| --supertask--> |inst=B| --supertask--> |inst=None| --supertask--> None +``` +While `A` does not appear as the `inst` of any `Supertask` on this stack, +`callee_insts` is `{ A }` and `ancestors(B)` is `{ B, A }`, so the second iteration +of the loop sees a non-empty intersection and correctly determines that `A` is +being reentered. + +An optimizing implementation can avoid the overhead of sets and loops in +several ways: +* In the quite-common case that a component does not contain *both* core module + instances *and* component instances, inter-component recursion is not possible + and can thus be statically eliminated from the generated inter-component + trampolines. +* If the runtime imposes a modest per-store upper-bound on the number of + component instances, like 64, then an `i64` can be used to represent the + `set[ComponentInstance]`, assigning each component instance a bit.
Then, + the `i64` representing the transitive union of all `supertask`'s + `ancestors(inst)`s can be propagated from caller to callee, allowing the + `while` loop to be replaced by a single bitwise-and of the callee's + `i64` with the transitive callers' `i64`. + +The other fields of `ComponentInstance` are described below as they are used. #### Table State @@ -804,7 +915,7 @@ class Task(Call, Supertask): opts: CanonicalOptions inst: ComponentInstance ft: FuncType - supertask: Optional[Task] + supertask: Supertask on_resolve: OnResolve num_borrows: int threads: list[Thread] @@ -838,37 +949,6 @@ called (by the `Task.return_` and `Task.cancel` methods, defined below). assert(self.num_borrows == 0) -The `Task.trap_if_on_the_stack` method checks for unintended reentrance, -enforcing a [component invariant]. This guard uses the `Supertask` defined by -the [Embedding](#embedding) interface to walk up the async call tree defined as -part of [structured concurrency]. The async call tree is necessary to -distinguish between the deadlock-hazardous kind of reentrance (where the new -task is a transitive subtask of a task already running in the same component -instance) and the normal kind of async reentrance (where the new task is just a -sibling of any existing tasks running in the component instance). Note that, in -the [future](Concurrency.md#TODO), there will be a way for a function to opt in -(via function type attribute) to the hazardous kind of reentrance, which will -nuance this test. -```python - def trap_if_on_the_stack(self, inst): - c = self.supertask - while c is not None: - trap_if(c.inst is inst) - c = c.supertask -``` -An optimizing implementation can avoid the O(n) loop in `trap_if_on_the_stack` -in several ways: -* Reentrance by a child component can (often) be statically ruled out when the - parent component doesn't both lift and lower the child's imports and exports - (i.e., "donut wrapping"). 
-* Reentrance of the root component by the host can either be asserted not to - happen or be tracked in a per-root-component-instance flag. -* When a potentially-reenterable child component only lifts and lowers - synchronously, reentrance can be tracked in a per-component-instance flag. -* For the remaining cases, the live instances on the stack can be maintained in - a packed bit-vector (assigning each potentially-reenterable async component - instance a static bit position) that is passed by copy from caller to callee. - The `Task.needs_exclusive` predicate returns whether the Canonical ABI options indicate that the core wasm being executed does not expect to be reentered (e.g., because the code is using a single global linear memory shadow stack). @@ -3161,8 +3241,8 @@ Based on this, `canon_lift` is defined in chunks as follows, starting with how a `lift`ed function starts executing: ```python def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve) -> Call: + trap_if(call_is_recursive(caller, inst)) task = Task(opts, inst, ft, caller, on_resolve) - task.trap_if_on_the_stack(inst) def thread_func(thread): if not task.enter(thread): return @@ -3176,16 +3256,16 @@ def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve) -> Call: flat_ft = flatten_functype(opts, ft, 'lift') assert(types_match_values(flat_ft.params, flat_args)) ``` -Each call starts by immediately checking for unexpected reentrance using -`Task.trap_if_on_the_stack`. +Each lifted function call starts by immediately trapping on recursive +reentrance (as defined by `call_is_recursive` above). The `thread_func` is immediately called from a new `Thread` created and resumed -at the end of `canon_lift` and so control flow proceeds directly from the -`trap_if_on_stack` to the `enter`. `Task.enter` (defined above) suspends the -newly-created `Thread` if there is backpressure until the backpressure is -resolved. 
If the caller cancels the new `Task` while the `Task` is still -waiting to `enter`, the call is aborted before the arguments are lowered (which -means that owned-handle arguments are not transferred). +at the end of `canon_lift` and so control flow proceeds directly to the `enter`. +`Task.enter` (defined above) suspends the newly-created `Thread` if there is +backpressure until the backpressure is resolved. If the caller cancels the new +`Task` while the `Task` is still waiting to `enter`, the call is aborted before +the arguments are lowered (which means that owned-handle arguments are not +transferred). Once the backpressure gate is cleared, the `Thread` is added to the callee's component instance's table (storing the index for later retrieval by the @@ -3570,7 +3650,7 @@ def canon_resource_drop(rt, thread, i): callee = partial(canon_lift, callee_opts, rt.impl, ft, rt.dtor) [] = canon_lower(caller_opts, ft, callee, thread, [h.rep]) else: - thread.task.trap_if_on_the_stack(rt.impl) + trap_if(call_is_recursive(thread.task, rt.impl)) else: h.borrow_scope.num_borrows -= 1 return [] @@ -3587,9 +3667,9 @@ reentrance guard of `Task.enter`, an exception is made when the resource type's implementation-instance is the same as the current instance (which is statically known for any given `canon resource.drop`). -When a destructor isn't present, the rules still perform a reentrance check +When a destructor isn't present, there is still a trap on recursive reentrance since this is the caller's responsibility and the presence or absence of a -destructor is an encapsualted implementation detail of the resource type. +destructor is an encapsulated implementation detail of the resource type. 
### `canon resource.rep` @@ -4807,6 +4887,7 @@ def canon_thread_available_parallelism(): [Concurrency Explainer]: Concurrency.md [Suspended]: Concurrency#thread-built-ins [Structured Concurrency]: Concurrency.md#subtasks-and-supertasks +[Recursive Reentrance]: Concurrency.md#subtasks-and-supertasks [Backpressure]: Concurrency.md#backpressure [Current Thread]: Concurrency.md#current-thread-and-task [Current Task]: Concurrency.md#current-thread-and-task diff --git a/design/mvp/Explainer.md b/design/mvp/Explainer.md index 0a0a9193..f173ea4f 100644 --- a/design/mvp/Explainer.md +++ b/design/mvp/Explainer.md @@ -2870,7 +2870,7 @@ three runtime invariants: component instance. 2. The Component Model disallows reentrance by trapping if a callee's component-instance is already on the stack when the call starts. - (For details, see [`trap_if_on_the_stack`](CanonicalABI.md#task-state) + (For details, see [`call_is_recursive`](CanonicalABI.md#component-instance-state) in the Canonical ABI explainer.) This default prevents obscure composition-time bugs and also enables more-efficient non-reentrant runtime glue code. 
This rule will be relaxed by an opt-in diff --git a/design/mvp/canonical-abi/definitions.py b/design/mvp/canonical-abi/definitions.py index 89d51e61..83bad16e 100644 --- a/design/mvp/canonical-abi/definitions.py +++ b/design/mvp/canonical-abi/definitions.py @@ -189,8 +189,11 @@ class Store: def __init__(self): self.pending = [] - def invoke(self, f: FuncInst, caller, on_start, on_resolve) -> Call: - return f(caller, on_start, on_resolve) + def invoke(self, f: FuncInst, caller: Optional[Supertask], on_start, on_resolve) -> Call: + host_caller = Supertask() + host_caller.inst = None + host_caller.supertask = caller + return f(host_caller, on_start, on_resolve) def tick(self): random.shuffle(self.pending) @@ -205,7 +208,7 @@ def tick(self): OnResolve = Callable[[Optional[list[any]]], None] class Supertask: - inst: ComponentInstance + inst: Optional[ComponentInstance] supertask: Optional[Supertask] class Call: @@ -252,20 +255,38 @@ class CanonicalOptions(LiftLowerOptions): class ComponentInstance: store: Store + parent: Optional[ComponentInstance] table: Table may_leave: bool backpressure: int exclusive: bool num_waiting_to_enter: int - def __init__(self, store): + def __init__(self, store, parent = None): + assert(parent is None or parent.store is store) self.store = store + self.parent = parent self.table = Table() self.may_leave = True self.backpressure = 0 self.exclusive = False self.num_waiting_to_enter = 0 +def call_is_recursive(caller: Supertask, callee_inst: ComponentInstance): + callee_insts = { callee_inst } | (ancestors(callee_inst) - ancestors(caller.inst)) + while caller is not None: + if callee_insts & ancestors(caller.inst): + return True + caller = caller.supertask + return False + +def ancestors(inst: Optional[ComponentInstance]) -> set[ComponentInstance]: + s = set() + while inst is not None: + s.add(inst) + inst = inst.parent + return s + #### Table State class Table: @@ -534,7 +555,7 @@ class State(Enum): opts: CanonicalOptions inst: 
ComponentInstance ft: FuncType - supertask: Optional[Task] + supertask: Supertask on_resolve: OnResolve num_borrows: int threads: list[Thread] @@ -560,12 +581,6 @@ def thread_stop(self, thread): trap_if(self.state != Task.State.RESOLVED) assert(self.num_borrows == 0) - def trap_if_on_the_stack(self, inst): - c = self.supertask - while c is not None: - trap_if(c.inst is inst) - c = c.supertask - def needs_exclusive(self): return not self.opts.async_ or self.opts.callback @@ -1984,8 +1999,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None): ### `canon lift` def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve) -> Call: + trap_if(call_is_recursive(caller, inst)) task = Task(opts, inst, ft, caller, on_resolve) - task.trap_if_on_the_stack(inst) def thread_func(thread): if not task.enter(thread): return @@ -2167,7 +2182,7 @@ def canon_resource_drop(rt, thread, i): callee = partial(canon_lift, callee_opts, rt.impl, ft, rt.dtor) [] = canon_lower(caller_opts, ft, callee, thread, [h.rep]) else: - thread.task.trap_if_on_the_stack(rt.impl) + trap_if(call_is_recursive(thread.task, rt.impl)) else: h.borrow_scope.num_borrows -= 1 return [] diff --git a/design/mvp/canonical-abi/run_tests.py b/design/mvp/canonical-abi/run_tests.py index 70a197f3..20a63ce6 100644 --- a/design/mvp/canonical-abi/run_tests.py +++ b/design/mvp/canonical-abi/run_tests.py @@ -2757,6 +2757,50 @@ def core_consumer(thread, args): run_lift(mk_opts(), consumer_inst, consumer_ft, core_consumer, lambda:[], lambda _:()) +def test_reentrance(): + def mk_task(supertask, inst): + t = Supertask() + t.supertask = supertask + t.inst = inst + return t + + store = Store() + root_task = mk_task(None, None) + + c1 = ComponentInstance(store, None) + c2 = ComponentInstance(store, None) + c1_task = mk_task(root_task, c1) + assert(call_is_recursive(mk_task(c1_task, None), c1)) + assert(not call_is_recursive(mk_task(c1_task, None), c2)) + c1c2_task = mk_task(c1_task, c2) + 
assert(call_is_recursive(mk_task(c1c2_task, None), c1)) + assert(call_is_recursive(mk_task(c1c2_task, None), c2)) + c1host_task = mk_task(c1_task, None) + assert(call_is_recursive(mk_task(c1host_task, None), c1)) + assert(not call_is_recursive(mk_task(c1host_task, None), c2)) + + p = ComponentInstance(store, None) + c1 = ComponentInstance(store, p) + c2 = ComponentInstance(store, p) + c3 = ComponentInstance(store, None) + c1_task = mk_task(root_task, c1) + c1c2_task = mk_task(c1_task, c2) + c1c2host_task = mk_task(c1c2_task, None) + assert(call_is_recursive(c1c2host_task, p)) + assert(call_is_recursive(c1c2host_task, c1)) + assert(call_is_recursive(c1c2host_task, c2)) + c1c2p_task = mk_task(c1c2_task, p) + assert(call_is_recursive(c1c2p_task, p)) + assert(call_is_recursive(c1c2p_task, c1)) + assert(call_is_recursive(c1c2p_task, c2)) + p_task = mk_task(root_task, p) + pc1_task = mk_task(p_task, c1) + pc1host_task = mk_task(pc1_task, None) + assert(call_is_recursive(pc1host_task, p)) + assert(call_is_recursive(pc1host_task, c1)) + assert(call_is_recursive(pc1host_task, c2)) + + test_roundtrips() test_handles() test_async_to_async() @@ -2782,5 +2826,6 @@ def core_consumer(thread, args): test_async_flat_params() test_threads() test_thread_cancel_callback() +test_reentrance() print("All tests passed")