From cd9b5b2e959c302534182d940e5eecef878f548a Mon Sep 17 00:00:00 2001 From: Luke Wagner Date: Mon, 4 May 2026 15:15:36 -0500 Subject: [PATCH] Prevent waitable-set.{wait,poll} from stealing events from sync built-ins Resolves #642 --- design/mvp/CanonicalABI.md | 69 ++++++++++++++++++------- design/mvp/canonical-abi/definitions.py | 45 +++++++++++----- design/mvp/canonical-abi/run_tests.py | 1 + 3 files changed, 82 insertions(+), 33 deletions(-) diff --git a/design/mvp/CanonicalABI.md b/design/mvp/CanonicalABI.md index 098e4c0c..14217852 100644 --- a/design/mvp/CanonicalABI.md +++ b/design/mvp/CanonicalABI.md @@ -1013,14 +1013,23 @@ A waitable can belong to at most one "waitable set" (defined next) which is referred to by the `wset` field. A `Waitable`'s `pending_event` is delivered (via `get_pending_event`) when core wasm code waits on its waitable set (via `waitable-set.wait` or, when using `callback`, by returning to the event loop). + +Lastly, a waitable cannot be waited on *both* asynchronously (via +waitable set) and synchronously (via synchronous `subtask.cancel` or +`{stream,future}.{,cancel-}{read,write}`) since this raises the possibility that +the waitable set "steals" events from the synchronous waiter, leaving the +synchronous waiter forever waiting. This condition is asserted by the `Waitable` +methods here and guarded via traps by the relevant built-ins below. 
```python class Waitable: pending_event: Optional[Callable[[], EventTuple]] wset: Optional[WaitableSet] + has_sync_waiter: bool def __init__(self): self.pending_event = None self.wset = None + self.has_sync_waiter = False def set_pending_event(self, pending_event): self.pending_event = pending_event @@ -1028,12 +1037,22 @@ class Waitable: def has_pending_event(self): return bool(self.pending_event) + def in_waitable_set(self): + return self.wset is not None + + def wait_for_pending_event(self): + assert(not self.in_waitable_set() and not self.has_sync_waiter) + self.has_sync_waiter = True + current_thread().wait_until(self.has_pending_event, cancellable = False) + self.has_sync_waiter = False + def get_pending_event(self) -> EventTuple: pending_event = self.pending_event self.pending_event = None return pending_event() def join(self, wset): + assert(not self.has_sync_waiter) if self.wset: self.wset.elems.remove(self) self.wset = wset @@ -1042,6 +1061,7 @@ class Waitable: def drop(self): assert(not self.has_pending_event()) + assert(not self.has_sync_waiter) self.join(None) ``` @@ -1785,10 +1805,9 @@ state which is referenced by the `shared` field and either points to a ```python class CopyState(Enum): IDLE = 1 - SYNC_COPYING = 2 - ASYNC_COPYING = 3 - CANCELLING_COPY = 4 - DONE = 5 + COPYING = 2 + CANCELLING_COPY = 3 + DONE = 4 class CopyEnd(Waitable): state: CopyState @@ -1803,7 +1822,7 @@ class CopyEnd(Waitable): match self.state: case CopyState.IDLE | CopyState.DONE: return False - case CopyState.SYNC_COPYING | CopyState.ASYNC_COPYING | CopyState.CANCELLING_COPY: + case CopyState.COPYING | CopyState.CANCELLING_COPY: return True assert(False) @@ -1823,9 +1842,7 @@ class WritableStreamEnd(CopyEnd): As shown in `drop`, attempting to drop a readable or writable end while a copy is in progress or in the process of being cancelled traps. 
This means that client code must take care to wait for these operations to finish (potentially -cancelling them via `stream.cancel-{read,write}`) before dropping. The -`SYNC_COPY` vs. `ASYNC_COPY` distinction is tracked in the state to determine -whether the copy operation can be cancelled. +cancelling them via `stream.cancel-{read,write}`) before dropping. The polymorphic `copy` method dispatches to either `ReadableStream.read` or `WritableStream.write` and allows the implementations of `stream.{read,write}` @@ -4195,6 +4212,7 @@ def canon_waitable_join(wi, si): trap_if(not inst.may_leave) w = inst.handles.get(wi) trap_if(not isinstance(w, Waitable)) + trap_if(w.has_sync_waiter) if si == 0: w.join(None) else: @@ -4203,6 +4221,10 @@ def canon_waitable_join(wi, si): w.join(wset) return [] ``` +As described with the definition of `Waitable` above, to prevent surprising +deadlocks, a waitable that is currently being synchronously waited on traps if +added to a waitable set. + Note that tables do not allow elements at index `0`, so `0` is a valid sentinel that tells `join` to remove the given waitable from any set that it is currently a part of. Waitables can be a member of at most one set, so if the @@ -4248,6 +4270,7 @@ def canon_subtask_cancel(async_, i): trap_if(not isinstance(subtask, Subtask)) trap_if(subtask.resolve_delivered()) trap_if(subtask.cancellation_requested) + trap_if(subtask.in_waitable_set() and not async_) if subtask.resolved(): assert(subtask.has_pending_event()) else: @@ -4255,7 +4278,7 @@ def canon_subtask_cancel(async_, i): subtask.on_cancel() if not subtask.resolved(): if not async_: - thread.wait_until(subtask.resolved) + subtask.wait_for_pending_event() else: return [BLOCKED] code,index,payload = subtask.get_pending_event() @@ -4268,7 +4291,8 @@ unconditionally traps if it transitively attempts to make a synchronous call to `subtask.cancel` (regardless of whether the cancellation would have succeeded without blocking). 
The other traps disallow calling `subtask.cancel` twice for the same subtask or after the supertask has already been notified that the -subtask has returned. +subtask has returned, or calling it synchronously while the subtask is a +member of a waitable set. A race condition handled by the above code is that it's possible for a subtask to have already resolved (by calling `task.return` or `task.cancel`) and @@ -4384,11 +4408,14 @@ def stream_copy(EndT, BufferT, event_code, stream_t, opts, i, ptr, n): Next, `stream_copy` checks that the element at index `i` is of the right type and allowed to start a new copy. (In the future, the "trap if not `IDLE`" condition could be relaxed to allow multiple pipelined reads or writes.) +There is also a trap if attempting to synchronously read or write from a +stream that is already being asynchronously waited on via a waitable set. ```python e = thread.task.inst.handles.get(i) trap_if(not isinstance(e, EndT)) trap_if(e.shared.t != stream_t.t) trap_if(e.state != CopyState.IDLE) + trap_if(e.in_waitable_set() and not opts.async_) ``` Then a readable or writable buffer is created which (in `Buffer`'s constructor) @@ -4425,6 +4452,7 @@ independently of the `addrtype`. ```python def stream_event(result, reclaim_buffer): reclaim_buffer() + assert(e.copying()) if result == CopyResult.DROPPED: e.state = CopyState.DONE else: @@ -4440,6 +4468,7 @@ independently of the `addrtype`. 
def on_copy_done(result): e.set_pending_event(partial(stream_event, result, reclaim_buffer = lambda:())) + e.state = CopyState.COPYING e.copy(thread.task.inst, buffer, on_copy, on_copy_done) ``` @@ -4451,10 +4480,8 @@ synchronously and return `BLOCKED` if not: ```python if not e.has_pending_event(): if not opts.async_: - e.state = CopyState.SYNC_COPYING - thread.wait_until(e.has_pending_event) + e.wait_for_pending_event() else: - e.state = CopyState.ASYNC_COPYING return [BLOCKED] code,index,payload = e.get_pending_event() assert(code == event_code and index == i and payload != BLOCKED) @@ -4507,6 +4534,7 @@ def future_copy(EndT, BufferT, event_code, future_t, opts, i, ptr): trap_if(not isinstance(e, EndT)) trap_if(e.shared.t != future_t.t) trap_if(e.state != CopyState.IDLE) + trap_if(e.in_waitable_set() and not opts.async_) assert(not contains_borrow(future_t)) cx = LiftLowerContext(opts, thread.task.inst, borrow_scope = None) @@ -4526,6 +4554,7 @@ of elements copied is not packed in the high 28 bits; they're always zero. ```python def future_event(result): assert((buffer.remain() == 0) == (result == CopyResult.COMPLETED)) + assert(e.copying()) if result == CopyResult.DROPPED or result == CopyResult.COMPLETED: e.state = CopyState.DONE else: @@ -4536,6 +4565,7 @@ of elements copied is not packed in the high 28 bits; they're always zero. assert(result != CopyResult.DROPPED or event_code == EventCode.FUTURE_WRITE) e.set_pending_event(partial(future_event, result)) + e.state = CopyState.COPYING e.copy(thread.task.inst, buffer, on_copy_done) ``` @@ -4544,10 +4574,8 @@ synchronously and returning either the progress made or `BLOCKED`. 
```python if not e.has_pending_event(): if not opts.async_: - e.state = CopyState.SYNC_COPYING - thread.wait_until(e.has_pending_event) + e.wait_for_pending_event() else: - e.state = CopyState.ASYNC_COPYING return [BLOCKED] code,index,payload = e.get_pending_event() assert(code == event_code and index == i) @@ -4591,13 +4619,14 @@ def cancel_copy(EndT, event_code, stream_or_future_t, async_, i): e = thread.task.inst.handles.get(i) trap_if(not isinstance(e, EndT)) trap_if(e.shared.t != stream_or_future_t.t) - trap_if(e.state != CopyState.ASYNC_COPYING) + trap_if(e.state != CopyState.COPYING or e.has_sync_waiter) + trap_if(e.in_waitable_set() and not async_) e.state = CopyState.CANCELLING_COPY if not e.has_pending_event(): e.shared.cancel() if not e.has_pending_event(): if not async_: - thread.wait_until(e.has_pending_event) + e.wait_for_pending_event() else: return [BLOCKED] code,index,payload = e.get_pending_event() @@ -4610,7 +4639,9 @@ unconditionally traps if it transitively attempts to make a synchronous call to have completed without blocking). There is also a trap if there is not currently an async copy in progress (sync copies do not expect or check for cancellation and thus cannot be cancelled, and repeatedly cancelling the same -async copy after the first call blocked is not allowed). +async copy after the first call blocked is not allowed). Lastly, there is a +trap if attempting to synchronously cancel a stream or future copy when the +stream or future end is currently a member of a waitable set. The *first* check for `e.has_pending_event()` catches the case where the copy has already racily finished, in which case we must *not* call `cancel()`. 
Calling diff --git a/design/mvp/canonical-abi/definitions.py b/design/mvp/canonical-abi/definitions.py index 8d8bd156..c20ff525 100644 --- a/design/mvp/canonical-abi/definitions.py +++ b/design/mvp/canonical-abi/definitions.py @@ -566,10 +566,12 @@ class EventCode(IntEnum): class Waitable: pending_event: Optional[Callable[[], EventTuple]] wset: Optional[WaitableSet] + has_sync_waiter: bool def __init__(self): self.pending_event = None self.wset = None + self.has_sync_waiter = False def set_pending_event(self, pending_event): self.pending_event = pending_event @@ -577,12 +579,22 @@ def set_pending_event(self, pending_event): def has_pending_event(self): return bool(self.pending_event) + def in_waitable_set(self): + return self.wset is not None + + def wait_for_pending_event(self): + assert(not self.in_waitable_set() and not self.has_sync_waiter) + self.has_sync_waiter = True + current_thread().wait_until(self.has_pending_event, cancellable = False) + self.has_sync_waiter = False + def get_pending_event(self) -> EventTuple: pending_event = self.pending_event self.pending_event = None return pending_event() def join(self, wset): + assert(not self.has_sync_waiter) if self.wset: self.wset.elems.remove(self) self.wset = wset @@ -591,6 +603,7 @@ def join(self, wset): def drop(self): assert(not self.has_pending_event()) + assert(not self.has_sync_waiter) self.join(None) class WaitableSet: @@ -963,10 +976,9 @@ def none_or_number_type(t): class CopyState(Enum): IDLE = 1 - SYNC_COPYING = 2 - ASYNC_COPYING = 3 - CANCELLING_COPY = 4 - DONE = 5 + COPYING = 2 + CANCELLING_COPY = 3 + DONE = 4 class CopyEnd(Waitable): state: CopyState @@ -981,7 +993,7 @@ def copying(self): match self.state: case CopyState.IDLE | CopyState.DONE: return False - case CopyState.SYNC_COPYING | CopyState.ASYNC_COPYING | CopyState.CANCELLING_COPY: + case CopyState.COPYING | CopyState.CANCELLING_COPY: return True assert(False) @@ -2355,6 +2367,7 @@ def canon_waitable_join(wi, si): trap_if(not 
inst.may_leave) w = inst.handles.get(wi) trap_if(not isinstance(w, Waitable)) + trap_if(w.has_sync_waiter) if si == 0: w.join(None) else: @@ -2375,6 +2388,7 @@ def canon_subtask_cancel(async_, i): trap_if(not isinstance(subtask, Subtask)) trap_if(subtask.resolve_delivered()) trap_if(subtask.cancellation_requested) + trap_if(subtask.in_waitable_set() and not async_) if subtask.resolved(): assert(subtask.has_pending_event()) else: @@ -2382,7 +2396,7 @@ def canon_subtask_cancel(async_, i): subtask.on_cancel() if not subtask.resolved(): if not async_: - thread.wait_until(subtask.resolved) + subtask.wait_for_pending_event() else: return [BLOCKED] code,index,payload = subtask.get_pending_event() @@ -2437,6 +2451,7 @@ def stream_copy(EndT, BufferT, event_code, stream_t, opts, i, ptr, n): trap_if(not isinstance(e, EndT)) trap_if(e.shared.t != stream_t.t) trap_if(e.state != CopyState.IDLE) + trap_if(e.in_waitable_set() and not opts.async_) assert(not isinstance(stream_t, CharType)) assert(not contains_borrow(stream_t)) @@ -2445,6 +2460,7 @@ def stream_copy(EndT, BufferT, event_code, stream_t, opts, i, ptr, n): def stream_event(result, reclaim_buffer): reclaim_buffer() + assert(e.copying()) if result == CopyResult.DROPPED: e.state = CopyState.DONE else: @@ -2460,14 +2476,13 @@ def on_copy(reclaim_buffer): def on_copy_done(result): e.set_pending_event(partial(stream_event, result, reclaim_buffer = lambda:())) + e.state = CopyState.COPYING e.copy(thread.task.inst, buffer, on_copy, on_copy_done) if not e.has_pending_event(): if not opts.async_: - e.state = CopyState.SYNC_COPYING - thread.wait_until(e.has_pending_event) + e.wait_for_pending_event() else: - e.state = CopyState.ASYNC_COPYING return [BLOCKED] code,index,payload = e.get_pending_event() assert(code == event_code and index == i and payload != BLOCKED) @@ -2492,6 +2507,7 @@ def future_copy(EndT, BufferT, event_code, future_t, opts, i, ptr): trap_if(not isinstance(e, EndT)) trap_if(e.shared.t != future_t.t) 
trap_if(e.state != CopyState.IDLE) + trap_if(e.in_waitable_set() and not opts.async_) assert(not contains_borrow(future_t)) cx = LiftLowerContext(opts, thread.task.inst, borrow_scope = None) @@ -2499,6 +2515,7 @@ def future_copy(EndT, BufferT, event_code, future_t, opts, i, ptr): def future_event(result): assert((buffer.remain() == 0) == (result == CopyResult.COMPLETED)) + assert(e.copying()) if result == CopyResult.DROPPED or result == CopyResult.COMPLETED: e.state = CopyState.DONE else: @@ -2509,14 +2526,13 @@ def on_copy_done(result): assert(result != CopyResult.DROPPED or event_code == EventCode.FUTURE_WRITE) e.set_pending_event(partial(future_event, result)) + e.state = CopyState.COPYING e.copy(thread.task.inst, buffer, on_copy_done) if not e.has_pending_event(): if not opts.async_: - e.state = CopyState.SYNC_COPYING - thread.wait_until(e.has_pending_event) + e.wait_for_pending_event() else: - e.state = CopyState.ASYNC_COPYING return [BLOCKED] code,index,payload = e.get_pending_event() assert(code == event_code and index == i) @@ -2543,13 +2559,14 @@ def cancel_copy(EndT, event_code, stream_or_future_t, async_, i): e = thread.task.inst.handles.get(i) trap_if(not isinstance(e, EndT)) trap_if(e.shared.t != stream_or_future_t.t) - trap_if(e.state != CopyState.ASYNC_COPYING) + trap_if(e.state != CopyState.COPYING or e.has_sync_waiter) + trap_if(e.in_waitable_set() and not async_) e.state = CopyState.CANCELLING_COPY if not e.has_pending_event(): e.shared.cancel() if not e.has_pending_event(): if not async_: - thread.wait_until(e.has_pending_event) + e.wait_for_pending_event() else: return [BLOCKED] code,index,payload = e.get_pending_event() diff --git a/design/mvp/canonical-abi/run_tests.py b/design/mvp/canonical-abi/run_tests.py index 2b6e3c03..0f0dbfa3 100644 --- a/design/mvp/canonical-abi/run_tests.py +++ b/design/mvp/canonical-abi/run_tests.py @@ -1592,6 +1592,7 @@ def core_func(args): [event] = canon_waitable_set_wait(True, MemInst(mem, 'i32'), seti, retp) 
assert(event == EventCode.STREAM_READ) assert(mem[retp+0] == rsi4) + [] = canon_waitable_join(rsi4, 0) result,n = unpack_result(mem[retp+4]) assert(n == 4 and result == CopyResult.COMPLETED) [ret] = canon_stream_read(StreamType(U8Type()), sync_opts, rsi4, 0, 4)