diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py
index 28568bd893c5..ad3d5bc66469 100644
--- a/sdks/python/apache_beam/transforms/async_dofn.py
+++ b/sdks/python/apache_beam/transforms/async_dofn.py
@@ -153,21 +153,39 @@ def _run_event_loop():
 
   @staticmethod
   def reset_state():
+    event_loop_thread_to_join = None
     with AsyncWrapper._lock:
       if AsyncWrapper._event_loop:
         AsyncWrapper._event_loop.call_soon_threadsafe(
             AsyncWrapper._event_loop.stop)
       if AsyncWrapper._event_loop_thread:
-        AsyncWrapper._event_loop_thread.join()
+        event_loop_thread_to_join = AsyncWrapper._event_loop_thread
       AsyncWrapper._event_loop = None
       AsyncWrapper._event_loop_thread = None
       if AsyncWrapper._loop_started is not None:
         AsyncWrapper._loop_started.clear()
 
-      for pool in AsyncWrapper._pool.values():
-        pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown(
-            wait=True, cancel_futures=True)
+      pools = list(AsyncWrapper._pool.values())
+
+    # We must join the asyncio event loop thread outside of the lock block.
+    # If joined inside the lock, the waiting thread holds the lock while blocking,
+    # preventing active coroutines' done callbacks from acquiring the lock on the
+    # event loop thread, resulting in a deadlock.
+    if event_loop_thread_to_join:
+      event_loop_thread_to_join.join()
+
+    # We must acquire and shut down the thread pools outside of the lock block.
+    # If shutdown(wait=True) is called inside the lock, the caller blocks holding
+    # the lock, preventing active worker threads from acquiring the lock to run
+    # their done callbacks, resulting in a deadlock.
+    pools_to_shutdown = [
+        pool.acquire(AsyncWrapper.initialize_pool(1)) for pool in pools
+    ]
+
+    for pool in pools_to_shutdown:
+      pool.shutdown(wait=True, cancel_futures=True)
+
     with AsyncWrapper._lock:
       AsyncWrapper._pool = {}
       AsyncWrapper._processing_elements = {}
@@ -268,7 +286,8 @@ async def _collect(result):
 
   def decrement_items_in_buffer(self, future):
     with AsyncWrapper._lock:
-      AsyncWrapper._items_in_buffer[self._uuid] -= 1
+      if self._uuid in AsyncWrapper._items_in_buffer:
+        AsyncWrapper._items_in_buffer[self._uuid] -= 1
 
   def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs):
     """Schedules an item to be processed asynchronously if there is room.
diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py
index 81c7b8e163ff..39901d791fb9 100644
--- a/sdks/python/apache_beam/transforms/async_dofn_test.py
+++ b/sdks/python/apache_beam/transforms/async_dofn_test.py
@@ -16,6 +16,7 @@
 #
 
 import logging
+import multiprocessing
 import random
 import time
 import unittest
@@ -487,6 +488,40 @@ def add_item(i):
       self.check_output(results[i], expected_outputs['key' + str(i)])
       self.assertEqual(bag_states['key' + str(i)].items, [])
 
+  @staticmethod
+  def _run_reset_state_concurrent_teardown(use_asyncio):
+    dofn = BasicDofn(sleep_time=0.5)
+    async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=use_asyncio)
+    async_dofn.setup()
+    fake_bag_state = FakeBagState([])
+    fake_timer = FakeTimer(0)
+
+    # Start processing an item. This starts a worker thread/coroutine sleeping for 0.5s.
+    async_dofn.process(('key1', 1), to_process=fake_bag_state, timer=fake_timer)
+    time.sleep(0.05)
+
+    # Verify that calling reset_state() while background tasks are actively running
+    # completes cleanly without causing lock-ordering deadlocks.
+    async_lib.AsyncWrapper.reset_state()
+
+  def test_reset_state_concurrent_teardown(self):
+    # Verify concurrent teardown safety in a separate process to prevent any potential
+    # regressions from freezing the main pytest process at exit.
+    p = multiprocessing.Process(
+        target=AsyncTest._run_reset_state_concurrent_teardown,
+        args=(self.use_asyncio, ))
+    p.start()
+    p.join(timeout=10.0)
+
+    if p.is_alive():
+      p.terminate()
+      p.join()
+      self.fail(
+          "reset_state() deadlocked/hung waiting for active threads/tasks to finish"
+      )
+    else:
+      self.assertEqual(p.exitcode, 0)
+
 if __name__ == '__main__':
   unittest.main()