diff --git a/dapr/ext/workflow/AGENTS.md b/dapr/ext/workflow/AGENTS.md index d0aaa8878..112c7a117 100644 --- a/dapr/ext/workflow/AGENTS.md +++ b/dapr/ext/workflow/AGENTS.md @@ -107,6 +107,26 @@ The entry point for registration and lifecycle: Internally wraps user functions: workflow functions get a `DaprWorkflowContext`, activity functions get a `WorkflowActivityContext`. Tracks registration state via `_workflow_registered` / `_activity_registered` attributes on functions to prevent double registration. +#### Sync and async activities + +Activities can be either `def my_activity(ctx, inp)` or `async def my_activity(ctx, inp)`. At registration, `_make_activity_wrapper` calls `_is_async_callable(fn)` to detect async-ness. That helper unwraps `functools.partial`, `@functools.wraps` chains, and callable-class `__call__` so common decorator patterns route correctly. The wrapper is built `async def` or `def` to match, then stored in the registry. + +At dispatch time (the gRPC stream loop in `_durabletask/worker.py`), `is_async_callable(activity_fn)` on the wrapper selects between two handlers. + +- **Async activities** go through `_execute_activity_async`, then `_ActivityExecutor.execute_async`, which awaits `fn(...)` directly on the event loop. The gRPC response is delivered via `loop.run_in_executor(self._async_worker_manager.thread_pool, stub.CompleteActivityTask, ...)` — the same pool sync activities use, sized by `maximum_thread_pool_workers`. +- **Sync activities** go through `_execute_activity`, dispatched to the thread pool by `_AsyncWorkerManager._run_func`. The activity runs on a worker thread, and the response is delivered from the same thread. + +Workflow (orchestrator) functions must remain generators (`def` with `yield`). They cannot be `async def` because durabletask's deterministic replay depends on synchronous generator semantics. Only activities support async. + +**Decorator ordering gotcha.** Wrapping `@wfr.activity` over `@alternate_name(...)` over `async def` works because `@alternate_name` now emits an `async def innerfn` when the wrapped function is async. A user-written decorator that wraps an async function in a sync `def` (without `@functools.wraps` exposing `__wrapped__`) defeats `_is_async_callable`, routes the activity to the sync path, and produces an un-awaited coroutine. Such decorators should use `@functools.wraps(fn)` so the unwrap walks through them. + +**`maximum_thread_pool_workers` covers both paths.** This knob sizes the worker thread pool used for sync-activity bodies and for async-activity gRPC response sends. Mixed workloads with long-running sync activities can starve async response delivery (and vice versa) since they share the pool — size to the sum of peak sync activity concurrency and peak in-flight async response sends. + +**Concurrency sizing and load characterization.** See `docs/concurrency.md` for sizing recommendations (`maximum_concurrent_activity_work_items`, `maximum_thread_pool_workers`) and an async-vs-sync decision tree. `tests/ext/workflow/durabletask/test_async_dispatch_regression.py` (marked `perf`) guards the core invariant: a batch of async activities overlaps on the event loop instead of serializing through the thread pool. + +**grpc.aio poller log noise.** The async client can emit benign `BlockingIOError: [Errno 11]` ERROR lines from `grpc.aio`'s `PollerCompletionQueue` under load. It is harmless and retried. `get_grpc_aio_channel` installs an internal `asyncio`-logger filter (`_silence_grpc_aio_poller_noise`) that drops only those records, so the SDK suppresses it automatically with no user action. + + ### DaprWorkflowClient (`dapr_workflow_client.py`) Client for workflow lifecycle management: @@ -165,7 +185,7 @@ Retry configuration for activities and child workflows: 1. **Registration**: User decorates functions with `@wfr.workflow` / `@wfr.activity`. The runtime wraps them and stores them in the durabletask worker's registry. 2. **Startup**: `wfr.start()` opens a gRPC stream to the Dapr sidecar. The worker polls for work items. 3. **Scheduling**: Client calls `schedule_new_workflow(fn, input=...)`. The function's name (or `_dapr_alternate_name`) is sent to the backend. -4. **Execution**: The durabletask engine dispatches work items. Workflow functions are Python **generators** that `yield` tasks (activity calls, timers, child workflows). The engine records history; on replay, yielded tasks return cached results without re-executing. +4. **Execution**: The durabletask engine dispatches work items. Workflow functions are Python **generators** that `yield` tasks (activity calls, timers, child workflows). Activity functions are either sync (dispatched to the worker's thread pool) or `async def` (awaited directly on the worker's event loop). The engine records history; on replay, yielded tasks return cached results without re-executing. 5. **Determinism**: Workflows must be deterministic — no random, no wall-clock time, no I/O. Use `ctx.current_utc_datetime` instead of `datetime.now()`. Use `ctx.is_replaying` to guard side effects like logging. 6. **Completion**: Client polls via `wait_for_workflow_completion()` or `get_workflow_state()`. @@ -193,6 +213,7 @@ Two example directories exercise workflows: - `cross-app1.py`, `cross-app2.py`, `cross-app3.py` — cross-app calls - `versioning.py` — workflow versioning with `is_patched()` - `simple_aio_client.py` — async client variant + - `async_activities.py` — `async def` activities (fan-out/fan-in with simulated I/O, configurable payload sizes) ## Testing diff --git a/dapr/ext/workflow/_durabletask/aio/client.py b/dapr/ext/workflow/_durabletask/aio/client.py index 31613ea3f..fe6a3ab53 100644 --- a/dapr/ext/workflow/_durabletask/aio/client.py +++ b/dapr/ext/workflow/_durabletask/aio/client.py @@ -71,18 +71,32 @@ def __init__( else: interceptors = None - channel = get_grpc_aio_channel( - host_address=host_address, - secure_channel=secure_channel, - interceptors=interceptors, - options=channel_options, - ) - self._channel = channel - self._stub = stubs.TaskHubSidecarServiceStub(channel) + self._host_address = host_address + self._secure_channel = secure_channel + self._interceptors = interceptors + self._channel_options = channel_options + self._channel: grpc.aio.Channel | None = None + self._stub: stubs.TaskHubSidecarServiceStub | None = None self._logger = shared.get_logger('client', log_handler, log_formatter) + def _get_stub(self) -> stubs.TaskHubSidecarServiceStub: + """Lazily create the channel and stub on first use. + + Async grpc binds a channel to the loop active at creation, deferring it avoids binding to the wrong loop. + """ + if self._stub is None: + self._channel = get_grpc_aio_channel( + host_address=self._host_address, + secure_channel=self._secure_channel, + interceptors=self._interceptors, + options=self._channel_options, + ) + self._stub = stubs.TaskHubSidecarServiceStub(self._channel) + return self._stub + async def aclose(self): - await self._channel.close() + if self._channel is not None: + await self._channel.close() async def __aenter__(self): return self @@ -113,14 +127,14 @@ async def schedule_new_orchestration( ) self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") - res: pb.CreateInstanceResponse = await self._stub.StartInstance(req) + res: pb.CreateInstanceResponse = await self._get_stub().StartInstance(req) return res.instanceId async def get_orchestration_state( self, instance_id: str, *, fetch_payloads: bool = True ) -> Optional[WorkflowState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - res: pb.GetInstanceResponse = await self._stub.GetInstance(req) + res: pb.GetInstanceResponse = await self._get_stub().GetInstance(req) return new_orchestration_state(req.instanceId, res) async def wait_for_orchestration_start( @@ -132,7 +146,7 @@ async def wait_for_orchestration_start( ) async def _call(grpc_timeout): - res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart( + res: pb.GetInstanceResponse = await self._get_stub().WaitForInstanceStart( req, timeout=grpc_timeout ) return new_orchestration_state(req.instanceId, res) @@ -151,7 +165,7 @@ async def wait_for_orchestration_completion( ) async def _call(grpc_timeout): - res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion( + res: pb.GetInstanceResponse = await self._get_stub().WaitForInstanceCompletion( req, timeout=grpc_timeout ) state = new_orchestration_state(req.instanceId, res) @@ -262,7 +276,7 @@ async def raise_orchestration_event( ) self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") - await self._stub.RaiseEvent(req) + await self._get_stub().RaiseEvent(req) async def terminate_orchestration( self, instance_id: str, *, output: Optional[Any] = None, recursive: bool = True @@ -274,19 +288,19 @@ async def terminate_orchestration( ) self._logger.info(f"Terminating instance '{instance_id}'.") - await self._stub.TerminateInstance(req) + await self._get_stub().TerminateInstance(req) async def suspend_orchestration(self, instance_id: str): req = pb.SuspendRequest(instanceId=instance_id) self._logger.info(f"Suspending instance '{instance_id}'.") - await self._stub.SuspendInstance(req) + await self._get_stub().SuspendInstance(req) async def resume_orchestration(self, instance_id: str): req = pb.ResumeRequest(instanceId=instance_id) self._logger.info(f"Resuming instance '{instance_id}'.") - await self._stub.ResumeInstance(req) + await self._get_stub().ResumeInstance(req) async def purge_orchestration(self, instance_id: str, recursive: bool = True): req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) self._logger.info(f"Purging instance '{instance_id}'.") - await self._stub.PurgeInstances(req) + await self._get_stub().PurgeInstances(req) diff --git a/dapr/ext/workflow/_durabletask/aio/internal/shared.py b/dapr/ext/workflow/_durabletask/aio/internal/shared.py index c6375294b..e4d39eb58 100644 --- a/dapr/ext/workflow/_durabletask/aio/internal/shared.py +++ b/dapr/ext/workflow/_durabletask/aio/internal/shared.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import Optional, Sequence, Union import grpc @@ -28,6 +29,30 @@ grpc_aio.StreamStreamClientInterceptor, ] +_POLLER_NOISE_MARKER = 'PollerCompletionQueue._handle_events' + + +class _GrpcAioPollerNoiseFilter(logging.Filter): + """Drops the harmless grpc.aio poller BlockingIOError (EAGAIN) records. + + The poller does a non-blocking read on its wake-up fd and can get EAGAIN, which + asyncio logs at ERROR even though the read is retried and nothing is lost. + """ + + def filter(self, record: logging.LogRecord) -> bool: + exc = record.exc_info[1] if record.exc_info else None + is_poller_noise = isinstance(exc, BlockingIOError) and ( + _POLLER_NOISE_MARKER in record.getMessage() + ) + return not is_poller_noise + + +def _silence_grpc_aio_poller_noise() -> None: + """Install the poller-noise filter on the asyncio logger if not already present.""" + asyncio_logger = logging.getLogger('asyncio') + if not any(isinstance(f, _GrpcAioPollerNoiseFilter) for f in asyncio_logger.filters): + asyncio_logger.addFilter(_GrpcAioPollerNoiseFilter()) + def get_grpc_aio_channel( host_address: Optional[str], @@ -43,6 +68,8 @@ def get_grpc_aio_channel( interceptors: Optional sequence of client interceptors to apply to the channel. options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html """ + _silence_grpc_aio_poller_noise() + if host_address is None: host_address = get_default_host_address() diff --git a/dapr/ext/workflow/_durabletask/internal/shared.py b/dapr/ext/workflow/_durabletask/internal/shared.py index fdcee8840..f91762355 100644 --- a/dapr/ext/workflow/_durabletask/internal/shared.py +++ b/dapr/ext/workflow/_durabletask/internal/shared.py @@ -10,6 +10,8 @@ # limitations under the License. import dataclasses +import functools +import inspect import json import logging import os @@ -20,6 +22,32 @@ from dapr.ext.workflow import _model_protocol +logger = logging.getLogger(__name__) + + +def is_async_callable(fn: Any) -> bool: + """Return True if ``fn`` is async. Catches ``functools.partial`` of coroutines, + sync decorators that wrap async functions, and callable instances with ``async __call__``. + """ + candidate = fn + while isinstance(candidate, functools.partial): + candidate = candidate.func + if callable(candidate): + try: + candidate = inspect.unwrap(candidate) + except ValueError: + # Cyclic ``__wrapped__`` chain from a malformed decorator. Fall back to the + # outermost callable; misclassification is preferable to crashing dispatch. + logger.warning( + f'Cyclic __wrapped__ on {fn!r}, using outermost callable for async detection.' + ) + if inspect.iscoroutinefunction(candidate): + return True + if not inspect.isfunction(candidate) and hasattr(candidate, '__call__'): + return inspect.iscoroutinefunction(candidate.__call__) + return False + + ClientInterceptor = Union[ grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, diff --git a/dapr/ext/workflow/_durabletask/worker.py b/dapr/ext/workflow/_durabletask/worker.py index 4b742b8aa..bdc9ff486 100644 --- a/dapr/ext/workflow/_durabletask/worker.py +++ b/dapr/ext/workflow/_durabletask/worker.py @@ -32,6 +32,7 @@ import dapr.ext.workflow._durabletask.internal.shared as shared from dapr.ext.workflow._durabletask import deterministic, task from dapr.ext.workflow._durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl +from dapr.ext.workflow._durabletask.internal.shared import is_async_callable from dapr.ext.workflow.propagation import PropagatedHistory, PropagationScope TInput = TypeVar('TInput') @@ -66,10 +67,10 @@ def _log_all_threads(logger: logging.Logger, context: str = ''): class ConcurrencyOptions: - """Configuration options for controlling concurrency of different work item types and the thread pool size. + """Concurrency limits for the worker. - This class provides fine-grained control over concurrent processing limits for - activities, orchestrations and the thread pool size. + ``maximum_thread_pool_workers`` sizes the pool used to run sync activities and to + deliver async-activity responses to the sidecar. """ def __init__( @@ -81,11 +82,13 @@ def __init__( """Initialize concurrency options. Args: - maximum_concurrent_activity_work_items: Maximum number of activity work items - that can be processed concurrently. Defaults to 100 * processor_count. - maximum_concurrent_orchestration_work_items: Maximum number of orchestration work items - that can be processed concurrently. Defaults to 100 * processor_count. - maximum_thread_pool_workers: Maximum number of thread pool workers to use. + maximum_concurrent_activity_work_items: Cap on concurrent activity work items. + Defaults to ``100 * cpu_count``. + maximum_concurrent_orchestration_work_items: Cap on concurrent orchestration work + items. Defaults to ``100 * cpu_count``. + maximum_thread_pool_workers: Size of the worker thread pool. Sync activities run + on this pool, and async-activity gRPC response sends also borrow a thread + from it. Defaults to ``cpu_count + 4``. """ processor_count = os.cpu_count() or 1 default_concurrency = 100 * processor_count @@ -350,6 +353,7 @@ def __init__( self._interceptors = None self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger) + self._activity_executor = _ActivityExecutor(self._logger) @property def concurrency_options(self) -> ConcurrencyOptions: @@ -659,8 +663,19 @@ def stream_reader(): work_item.completionToken, ) elif work_item.HasField('activityRequest'): + # Async user activities run on the event loop. Sync ones fall through + # to the thread pool via _execute_activity. + activity_fn = self._registry.get_activity( + work_item.activityRequest.name + ) + activity_handler = ( + self._execute_activity_async + if activity_fn is not None and is_async_callable(activity_fn) + else self._execute_activity + ) self._async_worker_manager.submit_activity( - self._execute_activity, + activity_handler, + activity_fn, work_item.activityRequest, stub, work_item.completionToken, @@ -965,98 +980,178 @@ def _execute_orchestrator( f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}" ) + def _activity_span(self, req: pb.ActivityRequest, instance_id: str): + """Return an OTel span context manager, or a nullcontext if OTel is not installed.""" + if otel_tracer is None: + return contextlib.nullcontext() + return otel_tracer.start_as_current_span( + name=f'activity: {req.name}', + context=otel_propagator.extract( + carrier={'traceparent': req.parentTraceContext.traceParent} + ), + attributes={ + 'dapr.ext.workflow._durabletask.task.instance_id': instance_id, + 'dapr.ext.workflow._durabletask.task.id': req.taskId, + 'dapr.ext.workflow._durabletask.activity.name': req.name, + }, + ) + + def _propagated_history(self, req: pb.ActivityRequest) -> PropagatedHistory | None: + if req.HasField('propagatedHistory'): + return PropagatedHistory.from_proto(req.propagatedHistory) + return None + + def _build_activity_result_response( + self, + req: pb.ActivityRequest, + instance_id: str, + result: str | None, + completion_token, + ) -> pb.ActivityResponse: + return pb.ActivityResponse( + instanceId=instance_id, + taskId=req.taskId, + result=ph.get_string_value(result), + completionToken=completion_token, + ) + + def _build_activity_failure_response( + self, + req: pb.ActivityRequest, + instance_id: str, + ex: BaseException, + completion_token, + ) -> pb.ActivityResponse: + return pb.ActivityResponse( + instanceId=instance_id, + taskId=req.taskId, + failureDetails=ph.new_failure_details(ex), + completionToken=completion_token, + ) + + def _send_activity_response( + self, + req: pb.ActivityRequest, + stub: stubs.TaskHubSidecarServiceStub, + res: pb.ActivityResponse, + completion_token, + instance_id: str, + ): + """Send an activity response, falling back to a failure response when the + result is too large to deliver.""" + try: + stub.CompleteActivityTask(res) + except grpc.RpcError as rpc_error: # type: ignore + if _is_message_too_large(rpc_error): + # Result is too large to deliver - fail the activity immediately. + # This can only be fixed with infrastructure changes (increasing gRPC max message size). + self._logger.error( + f"Activity '{req.name}#{req.taskId}' result is too large to deliver " + f'(RESOURCE_EXHAUSTED). Failing the activity task: {rpc_error.details()}' + ) + oversize_error = RuntimeError( + f'Activity result exceeds gRPC max message size: {rpc_error.details()}' + ) + failure_res = self._build_activity_failure_response( + req, instance_id, oversize_error, completion_token + ) + try: + stub.CompleteActivityTask(failure_res) + except Exception as ex: + self._logger.exception( + f"Failed to deliver activity failure response for '{req.name}#{req.taskId}' " + f"of orchestration ID '{instance_id}': {ex}" + ) + else: + self._handle_grpc_execution_error(rpc_error, 'activity') + except ValueError: + # gRPC raises ValueError when the underlying channel has been closed (e.g. during reconnection). + self._logger.debug( + f"Could not deliver activity response for '{req.name}#{req.taskId}' of " + f"orchestration ID '{instance_id}': channel was closed (likely due to " + f'reconnection). The sidecar will re-dispatch this work item.' + ) + except Exception as ex: + self._logger.exception( + f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" + ) + def _execute_activity( self, + fn: task.Activity | None, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken, ): instance_id = req.workflowInstance.instanceId - - if otel_tracer is not None: - span_context = otel_tracer.start_as_current_span( - name=f'activity: {req.name}', - context=otel_propagator.extract( - carrier={'traceparent': req.parentTraceContext.traceParent} - ), - attributes={ - 'dapr.ext.workflow._durabletask.task.instance_id': instance_id, - 'dapr.ext.workflow._durabletask.task.id': req.taskId, - 'dapr.ext.workflow._durabletask.activity.name': req.name, - }, - ) - else: - span_context = contextlib.nullcontext() - - with span_context: + with self._activity_span(req, instance_id): try: - executor = _ActivityExecutor(self._registry, self._logger) - propagated = ( - PropagatedHistory.from_proto(req.propagatedHistory) - if req.HasField('propagatedHistory') - else None - ) - result = executor.execute( + result = self._activity_executor.execute( + fn, instance_id, req.name, req.taskId, req.input.value, req.taskExecutionId, - propagated_history=propagated, + propagated_history=self._propagated_history(req), ) - res = pb.ActivityResponse( - instanceId=instance_id, - taskId=req.taskId, - result=ph.get_string_value(result), - completionToken=completionToken, + res = self._build_activity_result_response( + req, instance_id, result, completionToken ) except Exception as ex: - res = pb.ActivityResponse( - instanceId=instance_id, - taskId=req.taskId, - failureDetails=ph.new_failure_details(ex), - completionToken=completionToken, - ) + res = self._build_activity_failure_response(req, instance_id, ex, completionToken) + self._send_activity_response(req, stub, res, completionToken, instance_id) + async def _execute_activity_async( + self, + fn: task.Activity, + req: pb.ActivityRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): + """Run an async activity on the event loop and send its result to the sidecar. + The gRPC send runs on the worker thread pool to avoid blocking the loop. + """ + instance_id = req.workflowInstance.instanceId + with self._activity_span(req, instance_id): try: - stub.CompleteActivityTask(res) - except grpc.RpcError as rpc_error: # type: ignore - if _is_message_too_large(rpc_error): - # Result is too large to deliver - fail the activity immediately. - # This can only be fixed with infrastructure changes (increasing gRPC max message size). - self._logger.error( - f"Activity '{req.name}#{req.taskId}' result is too large to deliver " - f'(RESOURCE_EXHAUSTED). Failing the activity task: {rpc_error.details()}' - ) - failure_res = pb.ActivityResponse( - instanceId=instance_id, - taskId=req.taskId, - failureDetails=ph.new_failure_details( - RuntimeError( - f'Activity result exceeds gRPC max message size: {rpc_error.details()}' - ) - ), - completionToken=completionToken, - ) - try: - stub.CompleteActivityTask(failure_res) - except Exception as ex: - self._logger.exception( - f"Failed to deliver activity failure response for '{req.name}#{req.taskId}' " - f"of orchestration ID '{instance_id}': {ex}" - ) - else: - self._handle_grpc_execution_error(rpc_error, 'activity') - except ValueError: - # gRPC raises ValueError when the underlying channel has been closed (e.g. during reconnection). - self._logger.debug( - f"Could not deliver activity response for '{req.name}#{req.taskId}' of " - f"orchestration ID '{instance_id}': channel was closed (likely due to " - f'reconnection). The sidecar will re-dispatch this work item.' + result = await self._activity_executor.execute_async( + fn, + instance_id, + req.name, + req.taskId, + req.input.value, + req.taskExecutionId, + propagated_history=self._propagated_history(req), ) + res = self._build_activity_result_response( + req, instance_id, result, completionToken + ) + except asyncio.CancelledError: + raise except Exception as ex: - self._logger.exception( - f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" + res = self._build_activity_failure_response(req, instance_id, ex, completionToken) + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor( + self._async_worker_manager.thread_pool, + self._send_activity_response, + req, + stub, + res, + completionToken, + instance_id, + ) + except RuntimeError as exc: + # Swallow only when the thread pool itself is shut down (worker tearing down). + # Other RuntimeErrors are unexpected and propagate to the work-item processor. + # The sidecar will re-dispatch this work item once the worker reconnects. + pool = self._async_worker_manager.thread_pool + if not getattr(pool, '_shutdown', False): + raise + self._logger.warning( + f"Could not deliver activity response for '{req.name}#{req.taskId}': " + f'{exc}. The sidecar will re-dispatch this work item.' ) @@ -1999,27 +2094,25 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven class _ActivityExecutor: - def __init__(self, registry: _Registry, logger: logging.Logger): - self._registry = registry + def __init__(self, logger: logging.Logger): self._logger = logger - def execute( + def _resolve( self, + fn: task.Activity | None, orchestration_id: str, name: str, task_id: int, - encoded_input: Optional[str], - task_execution_id: str = '', - propagated_history: Optional[PropagatedHistory] = None, - ) -> Optional[str]: - """Executes an activity function and returns the serialized result, if any.""" + encoded_input: str | None, + task_execution_id: str, + propagated_history: PropagatedHistory | None, + ) -> tuple[task.Activity, task.ActivityContext, Any]: + """Validate ``fn`` and build its ``(fn, ctx, input)`` call args.""" self._logger.debug(f"{orchestration_id}/{task_id}: Executing activity '{name}'...") - fn = self._registry.get_activity(name) - if not fn: + if fn is None: raise ActivityNotRegisteredError( f"Activity function named '{name}' was not registered!" ) - activity_input = shared.from_json(encoded_input) if encoded_input else None ctx = task.ActivityContext( orchestration_id, @@ -2027,10 +2120,11 @@ def execute( task_execution_id, propagated_history=propagated_history, ) + return fn, ctx, activity_input - # Execute the activity function - activity_output = fn(ctx, activity_input) - + def _encode_output( + self, orchestration_id: str, name: str, task_id: int, activity_output: Any + ) -> str | None: encoded_output = shared.to_json(activity_output) if activity_output is not None else None chars = len(encoded_output) if encoded_output else 0 self._logger.debug( @@ -2038,6 +2132,64 @@ def execute( ) return encoded_output + def execute( + self, + fn: task.Activity | None, + orchestration_id: str, + name: str, + task_id: int, + encoded_input: str | None, + task_execution_id: str = '', + propagated_history: PropagatedHistory | None = None, + ) -> str | None: + """Run a sync activity function and return the serialized result, if any. + + Raises ``RuntimeError`` if the activity returns a coroutine, which happens when + ``is_async_callable`` fails to detect an async callable at registration. + """ + resolved_fn, ctx, activity_input = self._resolve( + fn, + orchestration_id, + name, + task_id, + encoded_input, + task_execution_id, + propagated_history, + ) + activity_output = resolved_fn(ctx, activity_input) + if inspect.iscoroutine(activity_output): + activity_output.close() + raise RuntimeError( + f"Activity '{name}' returned a coroutine on the sync path. " + f'Declare it with ``async def``, or if it already is, ensure any decorator ' + f'wrapping it uses ``@functools.wraps(fn)`` so the runtime can detect the ' + f'underlying async function.' + ) + return self._encode_output(orchestration_id, name, task_id, activity_output) + + async def execute_async( + self, + fn: task.Activity, + orchestration_id: str, + name: str, + task_id: int, + encoded_input: str | None, + task_execution_id: str = '', + propagated_history: PropagatedHistory | None = None, + ) -> str | None: + """Await a coroutine activity function and return the serialized result, if any.""" + resolved_fn, ctx, activity_input = self._resolve( + fn, + orchestration_id, + name, + task_id, + encoded_input, + task_execution_id, + propagated_history, + ) + activity_output = await resolved_fn(ctx, activity_input) + return self._encode_output(orchestration_id, name, task_id, activity_output) + def _get_non_determinism_error(task_id: int, action_name: str) -> task.NonDeterminismError: return task.NonDeterminismError( @@ -2275,7 +2427,7 @@ async def _process_work_item( queue.task_done() async def _run_func(self, func, *args, **kwargs): - if inspect.iscoroutinefunction(func): + if is_async_callable(func): return await func(*args, **kwargs) else: loop = asyncio.get_running_loop() diff --git a/dapr/ext/workflow/docs/concurrency.md b/dapr/ext/workflow/docs/concurrency.md new file mode 100644 index 000000000..50353dc64 --- /dev/null +++ b/dapr/ext/workflow/docs/concurrency.md @@ -0,0 +1,86 @@ +# Concurrency configuration for `dapr.ext.workflow` + +Sizing notes for the worker's concurrency knobs. + +## Knobs + +| Setting | Default | Effect | +| --- | --- | --- | +| `maximum_concurrent_activity_work_items` | `100 × cpu_count` | Async semaphore cap on in-flight activity work items. | +| `maximum_concurrent_orchestration_work_items` | `100 × cpu_count` | Same, for orchestrations. | +| `maximum_thread_pool_workers` | `cpu_count + 4` | Worker thread pool size. Sync activities run on this pool, and async-activity gRPC response sends also borrow a thread from it. | + +A `def` activity consumes a semaphore slot **and** a thread pool worker. An +`async def` activity consumes only a semaphore slot. + +## Choosing sync vs async + +Sync (`def`) activities are fully supported and unchanged: they run on the thread +pool. Keep CPU-bound work sync. An `async def` that burns CPU blocks the event loop +and starves every other activity. + +For **I/O-bound** activities (HTTP calls, database queries, anything that waits), +prefer `async def`. A sync activity holds a thread for the whole wait, so concurrency +is capped at the pool size (`cpu_count + 4`); an async activity holds only a semaphore +slot, so in-flight concurrency scales to `maximum_concurrent_activity_work_items`. The +gap widens with fan-out width. If your activities wait on I/O, moving them to `async def` +is the single biggest concurrency win available. + +Raising `maximum_thread_pool_workers` lifts the ceiling for a sync I/O activity you can't +convert yet, but threads scale worse than the loop. Each costs stack memory and contends +on the GIL, so the activity semaphore reaches `100 × cpu_count` in flight where a thread +pool that size would not. It buys headroom, not the async ceiling. + +Async helps concurrent activities, not sequential chains. A chain of dependent steps +costs the sum of its steps either way, sync or async. + +## Sizing the activity cap + +The cap is the lever for throughput and queue wait. Below the cap, in-flight work +runs concurrently; past it, submissions wait in the queue. Rule of thumb: set the +cap to ~2x the expected steady-state in-flight count to absorb bursts. + +If activities call a downstream with a hard concurrency limit (e.g. a database +with a 100-connection pool), set the cap below that limit so it doubles as +backpressure. + +## Sizing the thread pool + +The worker thread pool, sized by `maximum_thread_pool_workers`, has two uses. + +**Sync activity execution.** Each `def` activity holds one thread for its +duration. Size to peak concurrent sync-activity count. + +**Async response delivery.** Each async activity, on completion, schedules +`stub.CompleteActivityTask` on the same pool to avoid blocking the loop during +the gRPC send. If the sidecar takes >5 ms to acknowledge and the worker runs +many concurrent async activities, response delivery can serialize through the +pool and tail latency inflates. Raise `maximum_thread_pool_workers` to widen +response-delivery throughput. + +Mixed workloads with long-running sync activities can starve async response +delivery (and vice versa) since they share the pool. If that becomes an issue, +size `maximum_thread_pool_workers` to the sum of peak sync activity concurrency +and peak in-flight async response sends. + +This thread hop goes away when the worker migrates to `grpc.aio`. + +## Reusing clients in async activities + +When async activities call out over the network (HTTP, a database), a fresh client per +call bounds throughput by connection setup, not the I/O. A per-call `httpx.AsyncClient` +plateaus around a few hundred req/s. Reuse one client and size its pool to the activity +cap: + +```python +_shared_client: httpx.AsyncClient | None = None + +def _get_client() -> httpx.AsyncClient: + global _shared_client + if _shared_client is None: + _shared_client = httpx.AsyncClient(timeout=30.0) + return _shared_client +``` + +The caller owns closing it during worker shutdown. For activities that hit many +hosts or need per-call timeout isolation, stick with per-call clients. diff --git a/dapr/ext/workflow/workflow_runtime.py b/dapr/ext/workflow/workflow_runtime.py index be0c3406d..4f70f953b 100644 --- a/dapr/ext/workflow/workflow_runtime.py +++ b/dapr/ext/workflow/workflow_runtime.py @@ -16,7 +16,7 @@ import inspect import time from functools import wraps -from typing import Optional, Sequence, TypeVar, Union +from typing import Any, Awaitable, Callable, Optional, Sequence, TypeVar, Union import grpc @@ -26,6 +26,7 @@ from dapr.conf import settings from dapr.conf.helpers import GrpcEndpoint from dapr.ext.workflow._durabletask import task, worker +from dapr.ext.workflow._durabletask.internal.shared import is_async_callable as _is_async_callable from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.util import getAddress @@ -45,6 +46,64 @@ grpc.StreamStreamClientInterceptor, ] +# Durabletask returns decoded JSON, so we type the input as ``object | None`` and let the +# wrapper narrow it via the activity's declared model. +SyncActivityWrapper = Callable[[task.ActivityContext, object | None], object] +AsyncActivityWrapper = Callable[[task.ActivityContext, object | None], Awaitable[object]] +ActivityWrapper = SyncActivityWrapper | AsyncActivityWrapper + + +def _coerce_activity_input(inp: object | None, input_model: type | None) -> object | None: + """Coerce the raw input to the activity's declared model, if it has one.""" + if inp is None or input_model is None or isinstance(inp, input_model): + return inp + return _model_protocol.coerce_to_model(inp, input_model) + + +def _make_activity_wrapper(fn: Activity, logger: Logger) -> ActivityWrapper: + """Wrap a user activity for the durabletask worker. + + Returns: + An ``async def`` wrapper for async activities, a plain ``def`` for sync. + """ + accepts_input, input_model = _model_protocol.resolve_input(fn) + + def _call_args(ctx: task.ActivityContext, inp: object | None) -> tuple: + wf_ctx = WorkflowActivityContext(ctx) + if not accepts_input: + return (wf_ctx,) + return (wf_ctx, _coerce_activity_input(inp, input_model)) + + def _log_failure(ctx: task.ActivityContext, exc: Exception) -> None: + activity_id = getattr(ctx, 'task_id', 'unknown') + logger.warning(f'Activity execution failed - task_id: {activity_id}, error: {exc}') + + is_async = _is_async_callable(fn) + activity_name = getattr(fn, '__name__', repr(fn)) + kind = 'async' if is_async else 'sync' + logger.debug(f"Registering activity '{activity_name}' on the {kind} dispatch path.") + if is_async: + + async def async_activity_wrapper( + ctx: task.ActivityContext, inp: object | None = None + ) -> object: + try: + return await fn(*_call_args(ctx, inp)) + except Exception as exc: + _log_failure(ctx, exc) + raise + + return async_activity_wrapper + + def sync_activity_wrapper(ctx: task.ActivityContext, inp: object | None = None) -> object: + try: + return fn(*_call_args(ctx, inp)) + except Exception as exc: + _log_failure(ctx, exc) + raise + + return sync_activity_wrapper + class WorkflowRuntime: """WorkflowRuntime is the entry point for registering workflows and activities.""" @@ -180,36 +239,14 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = fn.__dict__['_workflow_registered'] = True def register_activity(self, fn: Activity, *, name: Optional[str] = None): - """Registers a workflow activity as a function that takes - a specified input type and returns a specified output type. + """Register a workflow activity. ``def`` and ``async def`` are both supported. + Async activities run on the worker's event loop. Sync activities run in the + thread pool sized by ``maximum_thread_pool_workers``. """ effective_name = name or fn.__name__ self._logger.info(f"Registering activity '{effective_name}' with runtime") - accepts_input, input_model = _model_protocol.resolve_input(fn) - - def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): - """Responsible to call Activity function in activityWrapper""" - activity_id = getattr(ctx, 'task_id', 'unknown') - - try: - wfActivityContext = WorkflowActivityContext(ctx) - if not accepts_input: - result = fn(wfActivityContext) - else: - if ( - (inp is not None) - and (input_model is not None) - and not isinstance(inp, input_model) - ): - inp = _model_protocol.coerce_to_model(inp, input_model) - result = fn(wfActivityContext, inp) - return result - except Exception as e: - self._logger.warning( - f'Activity execution failed - task_id: {activity_id}, error: {e}' - ) - raise + activity_wrapper = _make_activity_wrapper(fn, self._logger) if hasattr(fn, '_activity_registered'): # whenever an activity is registered, it has a _dapr_alternate_name attribute @@ -224,7 +261,7 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ self.__worker._registry.add_named_activity( - fn.__dict__['_dapr_alternate_name'], activityWrapper + fn.__dict__['_dapr_alternate_name'], activity_wrapper ) fn.__dict__['_activity_registered'] = True @@ -446,16 +483,23 @@ def add(ctx, x: int, y: int) -> int: the workflow runtime. Defaults to None. """ - def wrapper(fn: any): + def wrapper(fn: Any): if hasattr(fn, '_dapr_alternate_name'): raise ValueError( f'Function {fn.__name__} already has an alternate name {fn._dapr_alternate_name}' ) fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ - @wraps(fn) - def innerfn(*args, **kwargs): - return fn(*args, **kwargs) + if _is_async_callable(fn): + + @wraps(fn) + async def innerfn(*args, **kwargs): + return await fn(*args, **kwargs) + else: + + @wraps(fn) + def innerfn(*args, **kwargs): + return fn(*args, **kwargs) innerfn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ innerfn.__signature__ = inspect.signature(fn) diff --git a/examples/workflow/README.md b/examples/workflow/README.md index 26601d6fa..85a8dabe7 100644 --- a/examples/workflow/README.md +++ b/examples/workflow/README.md @@ -559,3 +559,35 @@ It shows: ```sh dapr run --app-id workflow-history-propagation -- python3 history_propagation.py ``` + +### Async Activities + +This example fans out several `async def` activities, then aggregates their +results in a sync activity. Each async activity awaits a delay to stand in for +an I/O call, so the instances run concurrently on the worker's event loop +instead of taking a thread each. + +Fan-out width and payload sizes are set with environment variables: +`WORKFLOW_FAN_OUT` (default 5), `WORKFLOW_INPUT_BYTES` (default 2048), +`WORKFLOW_OUTPUT_BYTES` (default 1024), and `WORKFLOW_IO_SECONDS` (default 1.0). + +See [concurrency.md](../../dapr/ext/workflow/docs/concurrency.md) for when to +prefer async over sync activities and how to size the concurrency knobs. + +```sh +dapr run --app-id workflow-async-activities -- python3 async_activities.py +``` + +The output should look like this (the async lines can arrive in any order): + +``` +Workflow started. Instance ID: 7b3e9c1f... +[async] payload 0: 2048B in -> 1024B out +[async] payload 1: 2048B in -> 1024B out +[async] payload 2: 2048B in -> 1024B out +[async] payload 3: 2048B in -> 1024B out +[async] payload 4: 2048B in -> 1024B out +[sync] 5 results, 5120 bytes +Workflow completed! Status: COMPLETED +Workflow result: 5 results, 5120 bytes +``` diff --git a/examples/workflow/async_activities.py b/examples/workflow/async_activities.py new file mode 100644 index 000000000..84be3df76 --- /dev/null +++ b/examples/workflow/async_activities.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async activities running alongside a sync one in a fan-out/fan-in workflow. + +Each async activity simulates an I/O-bound call: it takes a payload, awaits a fixed +delay (standing in for a network round-trip), and returns a result payload. The async +instances run concurrently on the worker's event loop; a final sync activity aggregates +the results. Fan-out width, input/output payload sizes, and the delay are configurable +via environment variables. + +Run with: + + dapr run --app-id async-activities --app-protocol grpc --dapr-grpc-port 50001 \\ + -- python async_activities.py +""" + +from __future__ import annotations + +import asyncio +import os +import random +import string +from time import sleep + +from pydantic import BaseModel + +import dapr.ext.workflow as wf + +FAN_OUT = int(os.environ.get('WORKFLOW_FAN_OUT', '5')) +INPUT_BYTES = int(os.environ.get('WORKFLOW_INPUT_BYTES', '2048')) +OUTPUT_BYTES = int(os.environ.get('WORKFLOW_OUTPUT_BYTES', '1024')) +IO_SECONDS = float(os.environ.get('WORKFLOW_IO_SECONDS', '1.0')) + +wfr = wf.WorkflowRuntime() + + +def _random_digits(n: int) -> str: + return ''.join(random.choices(string.digits, k=n)) + + +class Payload(BaseModel): + index: int + data: str + + +@wfr.workflow(name='fan_out_fan_in_workflow') +def fan_out_fan_in_workflow(ctx: wf.DaprWorkflowContext, payloads: list[dict]): + tasks = [ctx.call_activity(process_payload, input=p) for p in payloads] + results = yield wf.when_all(tasks) + summary = yield ctx.call_activity(summarize, input=results) + return summary + + +@wfr.activity(name='process_payload') +async def process_payload(ctx: wf.WorkflowActivityContext, payload: Payload) -> str: + """Async activity: simulate an I/O-bound call. Instances run concurrently on the loop.""" + await asyncio.sleep(IO_SECONDS) + result = _random_digits(OUTPUT_BYTES) + print( + f'[async] payload {payload.index}: {len(payload.data)}B in -> {len(result)}B out', + flush=True, + ) + return result + + +@wfr.activity(name='summarize') +def summarize(ctx: wf.WorkflowActivityContext, results: list[str]) -> str: + """Sync activity: aggregate the fan-out results on the thread pool.""" + summary = f'{len(results)} results, {sum(len(r) for r in results)} bytes' + print(f'[sync] {summary}', flush=True) + return summary + + +def main() -> None: + payloads = [ + Payload(index=i, data=_random_digits(INPUT_BYTES)).model_dump() for i in range(FAN_OUT) + ] + + wfr.start() + sleep(5) # wait for workflow runtime to start + + wf_client = wf.DaprWorkflowClient() + instance_id = wf_client.schedule_new_workflow(workflow=fan_out_fan_in_workflow, input=payloads) + print(f'Workflow started. Instance ID: {instance_id}') + + state = wf_client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + assert state is not None + print(f'Workflow completed! Status: {state.runtime_status.name}') + print(f'Workflow result: {state.serialized_output.strip(chr(34))}') + + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/pyproject.toml b/pyproject.toml index a9587f906..44e2c6422 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -207,6 +207,7 @@ ignore_errors = true [tool.pytest.ini_options] markers = [ 'example_dir(name): set the example directory for the dapr fixture', + 'perf: timing-sensitive dispatch regression tests', ] pythonpath = ["."] asyncio_mode = "auto" diff --git a/tests/examples/test_workflow.py b/tests/examples/test_workflow.py index 45aa0e855..2d5ec18eb 100644 --- a/tests/examples/test_workflow.py +++ b/tests/examples/test_workflow.py @@ -63,3 +63,22 @@ def test_history_propagation(dapr): ) for line in EXPECTED_HISTORY_PROPAGATION: assert line in output, f'Missing in output: {line}' + + +# Defaults: 5 async activities, 2048B in / 1024B out each, so 5 * 1024 = 5120 bytes aggregated. +EXPECTED_ASYNC_ACTIVITIES = [ + '[async] payload 0: 2048B in -> 1024B out', + '[sync] 5 results, 5120 bytes', + 'Workflow completed! Status: COMPLETED', + 'Workflow result: 5 results, 5120 bytes', +] + + +@pytest.mark.example_dir('workflow') +def test_async_activities(dapr): + output = dapr.run( + '--app-id workflow-async-activities -- python3 async_activities.py', + timeout=60, + ) + for line in EXPECTED_ASYNC_ACTIVITIES: + assert line in output, f'Missing in output: {line}' diff --git a/tests/ext/workflow/durabletask/test_activity_dispatch_routing.py b/tests/ext/workflow/durabletask/test_activity_dispatch_routing.py new file mode 100644 index 000000000..afc590f62 --- /dev/null +++ b/tests/ext/workflow/durabletask/test_activity_dispatch_routing.py @@ -0,0 +1,93 @@ +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contract tests for the activity dispatch handlers on ``TaskHubGrpcWorker``. + +The work-item dispatcher at the top of ``worker.py``'s gRPC loop selects between +``_execute_activity`` (sync, runs in the thread pool) and ``_execute_activity_async`` +(coroutine, awaited on the event loop) using ``is_async_callable(handler)`` via +``_AsyncWorkerManager._run_func``. These tests pin the async-ness of each handler so +the dispatch routing stays correct. +""" + +import asyncio +import inspect +import logging +import threading +from typing import Iterator + +import pytest + +from dapr.ext.workflow._durabletask.worker import ( + ConcurrencyOptions, + TaskHubGrpcWorker, + _AsyncWorkerManager, +) + + +@pytest.fixture +def worker() -> Iterator[TaskHubGrpcWorker]: + instance = TaskHubGrpcWorker() + try: + yield instance + finally: + # The worker was never started, so ``stop()`` early-returns; shut the manager + # down directly so the test doesn't leak threads if any work was submitted. + instance.stop() + instance._async_worker_manager.shutdown() + + +@pytest.fixture +def manager() -> Iterator[_AsyncWorkerManager]: + instance = _AsyncWorkerManager(ConcurrencyOptions(), logger=logging.getLogger()) + try: + yield instance + finally: + instance.shutdown() + + +def test_sync_activity_handler_is_not_a_coroutine_function(worker: TaskHubGrpcWorker): + assert not inspect.iscoroutinefunction(worker._execute_activity) + + +def test_async_activity_handler_is_a_coroutine_function(worker: TaskHubGrpcWorker): + assert inspect.iscoroutinefunction(worker._execute_activity_async) + + +def test_run_func_awaits_coroutines_directly(manager: _AsyncWorkerManager): + """``_AsyncWorkerManager._run_func`` is the single point that branches on async-ness. + + A coroutine handler returns its value without going through the thread pool. + """ + + async def coroutine_handler(value: int) -> int: + return value + 1 + + async def driver() -> int: + return await manager._run_func(coroutine_handler, 41) + + assert asyncio.run(driver()) == 42 + + +def test_run_func_dispatches_sync_callables_to_thread_pool(manager: _AsyncWorkerManager): + main_thread_id = threading.get_ident() + captured: dict[str, int] = {} + + def sync_handler(value: int) -> int: + captured['thread_id'] = threading.get_ident() + return value + 1 + + async def driver() -> int: + return await manager._run_func(sync_handler, 41) + + result = asyncio.run(driver()) + assert result == 42 + assert captured['thread_id'] != main_thread_id diff --git a/tests/ext/workflow/durabletask/test_activity_executor.py b/tests/ext/workflow/durabletask/test_activity_executor.py index f65aaf3f6..8a0b3fe63 100644 --- a/tests/ext/workflow/durabletask/test_activity_executor.py +++ b/tests/ext/workflow/durabletask/test_activity_executor.py @@ -34,7 +34,9 @@ def test_activity(ctx: task.ActivityContext, test_input: Any): activity_input = 'Hello, 世界!' executor, name = _get_activity_executor(test_activity) - result = executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input)) + result = executor.execute( + test_activity, TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input) + ) assert result is not None result_input, result_orchestration_id, result_task_id = json.loads(result) @@ -44,14 +46,14 @@ def test_activity(ctx: task.ActivityContext, test_input: Any): def test_activity_not_registered(): - def test_activity(ctx: task.ActivityContext, _): - pass # not used - - executor, _ = _get_activity_executor(test_activity) + """Dispatch site passes ``fn=None`` for unknown activity names. Executor surfaces + that as ``ActivityNotRegisteredError`` carrying the requested name. + """ + executor = worker._ActivityExecutor(TEST_LOGGER) caught_exception: Optional[Exception] = None try: - executor.execute(TEST_INSTANCE_ID, 'Bogus', TEST_TASK_ID, None) + executor.execute(None, TEST_INSTANCE_ID, 'Bogus', TEST_TASK_ID, None) except Exception as ex: caught_exception = ex @@ -59,8 +61,29 @@ def test_activity(ctx: task.ActivityContext, _): assert 'Bogus' in str(caught_exception) +def test_sync_execute_rejects_async_activity(): + """Sync ``execute`` must raise a clear RuntimeError when the activity returns a + coroutine. Guards against ``_is_async_callable`` missing an async callable at + registration; without this, JSON encoding would fail with a confusing TypeError. + """ + + async def async_activity(ctx: task.ActivityContext, _): + return 'never reached' + + executor, name = _get_activity_executor(async_activity) + + caught_exception: Optional[Exception] = None + try: + executor.execute(async_activity, TEST_INSTANCE_ID, name, TEST_TASK_ID, None) + except Exception as ex: + caught_exception = ex + + assert type(caught_exception) is RuntimeError + assert 'returned a coroutine' in str(caught_exception) + + def _get_activity_executor(fn: task.Activity) -> Tuple[worker._ActivityExecutor, str]: registry = worker._Registry() name = registry.add_activity(fn) - executor = worker._ActivityExecutor(registry, TEST_LOGGER) + executor = worker._ActivityExecutor(TEST_LOGGER) return executor, name diff --git a/tests/ext/workflow/durabletask/test_activity_executor_async.py b/tests/ext/workflow/durabletask/test_activity_executor_async.py new file mode 100644 index 000000000..a8758dca2 --- /dev/null +++ b/tests/ext/workflow/durabletask/test_activity_executor_async.py @@ -0,0 +1,99 @@ +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the async branch of ``_ActivityExecutor``. + +These mirror ``test_activity_executor.py`` but exercise the ``execute_async`` path used +when a registered activity is a coroutine function. +""" + +import asyncio +import inspect +import json +import logging +from typing import Any + +import pytest + +from dapr.ext.workflow._durabletask import task, worker + +logging.basicConfig( + format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.DEBUG, +) +TEST_LOGGER = logging.getLogger('tests') +TEST_INSTANCE_ID = 'abc123' +TEST_TASK_ID = 42 + + +def _get_activity_executor(fn: task.Activity) -> tuple[worker._ActivityExecutor, str]: + registry = worker._Registry() + name = registry.add_activity(fn) + executor = worker._ActivityExecutor(TEST_LOGGER) + return executor, name + + +def test_async_activity_inputs(): + """Validates that execute_async awaits the activity and returns the encoded result.""" + + async def test_async_activity(ctx: task.ActivityContext, test_input: Any): + await asyncio.sleep(0) + return test_input, ctx.orchestration_id, ctx.task_id + + activity_input = 'Hello, 世界!' + executor, name = _get_activity_executor(test_async_activity) + result = asyncio.run( + executor.execute_async( + test_async_activity, + TEST_INSTANCE_ID, + name, + TEST_TASK_ID, + json.dumps(activity_input), + ) + ) + assert result is not None + + result_input, result_orchestration_id, result_task_id = json.loads(result) + assert activity_input == result_input + assert TEST_INSTANCE_ID == result_orchestration_id + assert TEST_TASK_ID == result_task_id + + +def test_async_activity_exception_propagates(): + async def test_async_activity(ctx: task.ActivityContext, _): + raise RuntimeError('boom') + + executor, name = _get_activity_executor(test_async_activity) + + with pytest.raises(RuntimeError) as exc_info: + asyncio.run( + executor.execute_async(test_async_activity, TEST_INSTANCE_ID, name, TEST_TASK_ID, None) + ) + assert 'boom' in str(exc_info.value) + + +def test_async_activity_registry_preserves_coroutine_function(): + """The dispatcher relies on iscoroutinefunction(fn) at the registry lookup level. + + If the registry's add_activity ever wraps coroutine functions in a way that hides their + async-ness (e.g. functools.wraps with a sync decorator), the dispatcher would route + them to the thread pool and break I/O concurrency. This test pins that contract. + """ + + async def test_async_activity(ctx: task.ActivityContext, _): + return None + + registry = worker._Registry() + name = registry.add_activity(test_async_activity) + + retrieved = registry.get_activity(name) + assert inspect.iscoroutinefunction(retrieved) diff --git a/tests/ext/workflow/durabletask/test_async_dispatch_regression.py b/tests/ext/workflow/durabletask/test_async_dispatch_regression.py new file mode 100644 index 000000000..8ea14172b --- /dev/null +++ b/tests/ext/workflow/durabletask/test_async_dispatch_regression.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Perf regression test for the async activity dispatch path. + +Drives ``_execute_activity_async`` through ``_AsyncWorkerManager`` against an in-process +stub. A timeout fails the batch if async activities serialize instead of overlapping. +""" + +import asyncio + +import pytest + +import dapr.ext.workflow._durabletask.internal.protos as pb +from dapr.ext.workflow._durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker + +pytestmark = pytest.mark.perf + +ACTIVITY_DURATION_SECONDS = 0.02 +N_ITEMS = 1000 +SEMAPHORE_CAP = 2000 +THREAD_POOL = 16 + +# Generous fraction of the time to run 1000 activities serially. Should trip fast if async I/O serializes +TIMEOUT_S = 2.0 + + +class _MockSidecarStub: + """In-process stand-in for the gRPC stub; records activity completions.""" + + def __init__(self) -> None: + self.completions = 0 + + def CompleteActivityTask(self, _response: pb.ActivityResponse) -> None: # noqa: N802 + self.completions += 1 + + +def _activity_request(task_id: int) -> pb.ActivityRequest: + return pb.ActivityRequest( + name='regression_async', + taskId=task_id, + workflowInstance=pb.WorkflowInstance(instanceId='regression'), + parentTraceContext=pb.TraceContext(traceParent=''), + taskExecutionId='', + ) + + +async def _run_async_batch(n_items: int, timeout_s: float) -> int: + """Submit ``n_items`` async sleep activities through the dispatch path and drain them. + + Raises ``asyncio.TimeoutError`` if the batch does not drain within ``timeout_s``. + """ + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=SEMAPHORE_CAP, + maximum_concurrent_orchestration_work_items=SEMAPHORE_CAP, + maximum_thread_pool_workers=THREAD_POOL, + ) + worker = TaskHubGrpcWorker(host_address='in-process-mock', concurrency_options=options) + manager = worker._async_worker_manager + stub = _MockSidecarStub() + + async def activity(ctx, _inp) -> None: + await asyncio.sleep(ACTIVITY_DURATION_SECONDS) + + worker_task = asyncio.create_task(manager.run()) + # Non-blocking poll: yield to the event loop until the worker creates the activity queue + while manager.activity_queue is None: + await asyncio.sleep(0) + try: + for task_id in range(n_items): + req = _activity_request(task_id) + manager.submit_activity(worker._execute_activity_async, activity, req, stub, '') + await asyncio.wait_for(manager.activity_queue.join(), timeout=timeout_s) + finally: + manager._shutdown = True + worker_task.cancel() + await asyncio.gather(worker_task, return_exceptions=True) + manager.shutdown() + + return stub.completions + + +def test_async_activities_overlap_instead_of_serializing(): + """A batch of async activities drains in ~one I/O window, not N of them. + + Fails if the batch cannot finish within ``TIMEOUT_S``, meaning the async path is + serializing instead of overlapping I/O on the event loop. + """ + try: + completions = asyncio.run(_run_async_batch(N_ITEMS, TIMEOUT_S)) + except asyncio.TimeoutError: + serial_s = N_ITEMS * ACTIVITY_DURATION_SECONDS + pytest.fail( + f'{N_ITEMS} async activities did not drain within {TIMEOUT_S:.1f}s. Serialized' + f' they would cost ~{serial_s:.0f}s, so the async path is not overlapping I/O.' + ) + assert completions == N_ITEMS, f'only {completions}/{N_ITEMS} activities completed' diff --git a/tests/ext/workflow/durabletask/test_client_async.py b/tests/ext/workflow/durabletask/test_client_async.py index 57e7374e7..ec56c21b4 100644 --- a/tests/ext/workflow/durabletask/test_client_async.py +++ b/tests/ext/workflow/durabletask/test_client_async.py @@ -165,7 +165,11 @@ def test_async_client_construct_with_metadata(): with patch( 'dapr.ext.workflow._durabletask.aio.internal.shared.grpc_aio.insecure_channel' ) as mock_channel: - AsyncTaskHubGrpcClient(host_address=HOST_ADDRESS, metadata=METADATA) + client = AsyncTaskHubGrpcClient(host_address=HOST_ADDRESS, metadata=METADATA) + assert mock_channel.call_count == 0 # channel is built lazily, not at construction + + client._get_stub() + # Ensure channel created with an interceptor that has the expected metadata args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS @@ -175,6 +179,18 @@ def test_async_client_construct_with_metadata(): assert interceptors[0]._metadata == METADATA +def test_async_client_channel_is_lazy(): + with patch( + 'dapr.ext.workflow._durabletask.aio.internal.shared.grpc_aio.insecure_channel' + ) as mock_channel: + client = AsyncTaskHubGrpcClient(host_address=HOST_ADDRESS) + assert mock_channel.call_count == 0 # not built at construction + + client._get_stub() + client._get_stub() + assert mock_channel.call_count == 1 # built once on first use, then cached + + def test_aio_channel_passes_base_options_and_max_lengths(): base_options = [ ('grpc.max_send_message_length', 4321), diff --git a/tests/ext/workflow/durabletask/test_grpc_aio_log_filter.py b/tests/ext/workflow/durabletask/test_grpc_aio_log_filter.py new file mode 100644 index 000000000..fbc3fbfcc --- /dev/null +++ b/tests/ext/workflow/durabletask/test_grpc_aio_log_filter.py @@ -0,0 +1,71 @@ +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +from unittest.mock import patch + +from dapr.ext.workflow._durabletask.aio.internal import shared + +HOST_ADDRESS = 'localhost:50051' + + +def _record(msg: str, exc: BaseException | None) -> logging.LogRecord: + exc_info = None + if exc is not None: + try: + raise exc + except BaseException: + exc_info = sys.exc_info() + return logging.LogRecord('asyncio', logging.ERROR, __file__, 1, msg, (), exc_info) + + +def test_filter_drops_poller_eagain_record(): + record = _record( + 'Exception in callback PollerCompletionQueue._handle_events()', + BlockingIOError(11, 'Resource temporarily unavailable'), + ) + assert shared._GrpcAioPollerNoiseFilter().filter(record) is False + + +def test_filter_keeps_record_without_exception(): + assert shared._GrpcAioPollerNoiseFilter().filter(_record('some other error', None)) is True + + +def test_filter_keeps_blockingioerror_without_marker(): + record = _record('unrelated message', BlockingIOError(11, 'nope')) + assert shared._GrpcAioPollerNoiseFilter().filter(record) is True + + +def test_get_grpc_aio_channel_installs_filter_on_asyncio_logger(): + asyncio_logger = logging.getLogger('asyncio') + for existing in [ + f for f in asyncio_logger.filters if isinstance(f, shared._GrpcAioPollerNoiseFilter) + ]: + asyncio_logger.removeFilter(existing) + + with patch('dapr.ext.workflow._durabletask.aio.internal.shared.grpc_aio.insecure_channel'): + shared.get_grpc_aio_channel(HOST_ADDRESS, False) + + installed = [ + f for f in asyncio_logger.filters if isinstance(f, shared._GrpcAioPollerNoiseFilter) + ] + assert len(installed) == 1 # installed once, not duplicated + + +def test_install_is_idempotent(): + asyncio_logger = logging.getLogger('asyncio') + shared._silence_grpc_aio_poller_noise() + shared._silence_grpc_aio_poller_noise() + installed = [ + f for f in asyncio_logger.filters if isinstance(f, shared._GrpcAioPollerNoiseFilter) + ] + assert len(installed) == 1 diff --git a/tests/ext/workflow/durabletask/test_propagation_wiring.py b/tests/ext/workflow/durabletask/test_propagation_wiring.py index f74b8fb8e..010f54aae 100644 --- a/tests/ext/workflow/durabletask/test_propagation_wiring.py +++ b/tests/ext/workflow/durabletask/test_propagation_wiring.py @@ -191,12 +191,13 @@ def reading_activity(ctx: task.ActivityContext, _): registry = worker._Registry() activity_name = registry.add_activity(reading_activity) - executor = worker._ActivityExecutor(registry, TEST_LOGGER) + executor = worker._ActivityExecutor(TEST_LOGGER) propagated = PropagatedHistory.from_proto(_single_chunk_history('Caller')) assert propagated is not None encoded_output = executor.execute( + reading_activity, orchestration_id='wf-1', name=activity_name, task_id=1, @@ -221,8 +222,9 @@ def reading_activity(ctx: task.ActivityContext, _): registry = worker._Registry() activity_name = registry.add_activity(reading_activity) - executor = worker._ActivityExecutor(registry, TEST_LOGGER) + executor = worker._ActivityExecutor(TEST_LOGGER) executor.execute( + reading_activity, orchestration_id='wf-1', name=activity_name, task_id=1, diff --git a/tests/ext/workflow/test_async_activity_registration.py b/tests/ext/workflow/test_async_activity_registration.py new file mode 100644 index 000000000..8347a2f9e --- /dev/null +++ b/tests/ext/workflow/test_async_activity_registration.py @@ -0,0 +1,282 @@ +# -*- coding: utf-8 -*- + +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sync/async activity registration and the resulting wrappers. + +These tests exercise the helpers in workflow_runtime that decide whether an activity +runs in a thread pool (sync) or as a coroutine on the event loop (async). The +WorkflowRuntime is constructed against a fake registry so we don't need a sidecar. +""" + +import asyncio +import functools +import inspect +import unittest +from unittest import mock + +from pydantic import BaseModel + +from dapr.ext.workflow._durabletask.internal.shared import is_async_callable +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class OrderInput(BaseModel): + order_id: str + amount: float + + +class FakeRegistry: + def __init__(self): + self.activities: dict[str, object] = {} + + def add_named_activity(self, name: str, fn) -> None: + self.activities[name] = fn + + +class _AsyncActivityRegistrationTestBase(unittest.TestCase): + def setUp(self) -> None: + self._registry_patch = mock.patch( + 'dapr.ext.workflow._durabletask.worker._Registry', return_value=FakeRegistry() + ) + self._registry_patch.start() + self.runtime = WorkflowRuntime() + # Reach into the runtime to grab its registry for assertions. + self.registry: FakeRegistry = self.runtime._WorkflowRuntime__worker._registry + + def tearDown(self) -> None: + # Tear down the worker's ThreadPoolExecutor so each test doesn't leak threads/fds. + # The runtime never started, so ``shutdown()`` -> ``stop()`` early-returns; + # shut the manager down directly to actually close the executor. + worker = self.runtime._WorkflowRuntime__worker + self.runtime.shutdown() + worker._async_worker_manager.shutdown() + self._registry_patch.stop() + + +class AsyncActivityRegistrationTest(_AsyncActivityRegistrationTestBase): + def test_async_activity_registers_coroutine_wrapper(self) -> None: + async def my_async_activity(ctx: WorkflowActivityContext, payload: str) -> str: + return payload.upper() + + self.runtime.register_activity(my_async_activity) + + wrapper = self.registry.activities['my_async_activity'] + self.assertTrue(inspect.iscoroutinefunction(wrapper)) + + def test_sync_activity_registers_plain_wrapper(self) -> None: + def my_sync_activity(ctx: WorkflowActivityContext, payload: str) -> str: + return payload.upper() + + self.runtime.register_activity(my_sync_activity) + + wrapper = self.registry.activities['my_sync_activity'] + self.assertFalse(inspect.iscoroutinefunction(wrapper)) + self.assertTrue(callable(wrapper)) + + def test_async_wrapper_awaits_user_function(self) -> None: + recorded: list[tuple[WorkflowActivityContext, str]] = [] + + async def my_async_activity(ctx: WorkflowActivityContext, payload: str) -> str: + await asyncio.sleep(0) + recorded.append((ctx, payload)) + return payload.upper() + + self.runtime.register_activity(my_async_activity) + wrapper = self.registry.activities['my_async_activity'] + + fake_ctx = mock.MagicMock(spec=['task_id']) + fake_ctx.task_id = 7 + result = asyncio.run(wrapper(fake_ctx, 'hello')) + + self.assertEqual(result, 'HELLO') + self.assertEqual(len(recorded), 1) + self.assertEqual(recorded[0][1], 'hello') + self.assertIsInstance(recorded[0][0], WorkflowActivityContext) + + def test_sync_wrapper_calls_user_function(self) -> None: + recorded: list[tuple[WorkflowActivityContext, str]] = [] + + def my_sync_activity(ctx: WorkflowActivityContext, payload: str) -> str: + recorded.append((ctx, payload)) + return payload.upper() + + self.runtime.register_activity(my_sync_activity) + wrapper = self.registry.activities['my_sync_activity'] + + fake_ctx = mock.MagicMock(spec=['task_id']) + fake_ctx.task_id = 3 + result = wrapper(fake_ctx, 'world') + + self.assertEqual(result, 'WORLD') + self.assertEqual(len(recorded), 1) + self.assertEqual(recorded[0][1], 'world') + self.assertIsInstance(recorded[0][0], WorkflowActivityContext) + + def test_async_wrapper_coerces_input_to_declared_model(self) -> None: + seen: list[OrderInput] = [] + + async def place_order(ctx: WorkflowActivityContext, order: OrderInput) -> str: + seen.append(order) + return order.order_id + + self.runtime.register_activity(place_order) + wrapper = self.registry.activities['place_order'] + + fake_ctx = mock.MagicMock(spec=['task_id']) + fake_ctx.task_id = 99 + raw_input = {'order_id': 'abc-1', 'amount': 9.5} + result = asyncio.run(wrapper(fake_ctx, raw_input)) + + self.assertEqual(result, 'abc-1') + self.assertEqual(len(seen), 1) + self.assertIsInstance(seen[0], OrderInput) + self.assertEqual(seen[0].amount, 9.5) + + def test_async_wrapper_propagates_exceptions(self) -> None: + async def failing(ctx: WorkflowActivityContext, payload: str) -> str: + raise RuntimeError('boom') + + self.runtime.register_activity(failing) + wrapper = self.registry.activities['failing'] + + fake_ctx = mock.MagicMock(spec=['task_id']) + fake_ctx.task_id = 1 + with self.assertRaises(RuntimeError) as caught: + asyncio.run(wrapper(fake_ctx, 'x')) + self.assertEqual(str(caught.exception), 'boom') + + def test_async_wrapper_supports_no_input_parameter(self) -> None: + async def heartbeat(ctx: WorkflowActivityContext) -> str: + return 'ok' + + self.runtime.register_activity(heartbeat) + wrapper = self.registry.activities['heartbeat'] + + fake_ctx = mock.MagicMock(spec=['task_id']) + fake_ctx.task_id = 0 + result = asyncio.run(wrapper(fake_ctx, None)) + self.assertEqual(result, 'ok') + + +class IsAsyncCallableTest(unittest.TestCase): + """Pin the contract of ``is_async_callable`` against decorator shapes that a bare + ``inspect.iscoroutinefunction`` would miss. These are the patterns the fix for finding + #5 was meant to address. Without coverage, a future refactor can silently regress + async-activity routing for any of them. + """ + + def test_plain_async_function_is_async(self) -> None: + async def fn() -> None: ... + + self.assertTrue(is_async_callable(fn)) + + def test_plain_sync_function_is_not_async(self) -> None: + def fn() -> None: ... + + self.assertFalse(is_async_callable(fn)) + + def test_functools_partial_of_async_is_async(self) -> None: + async def fn(prefix: str, payload: str) -> str: + return prefix + payload + + partial_fn = functools.partial(fn, 'hello-') + self.assertTrue(is_async_callable(partial_fn)) + + def test_functools_partial_of_sync_is_not_async(self) -> None: + def fn(prefix: str, payload: str) -> str: + return prefix + payload + + partial_fn = functools.partial(fn, 'hello-') + self.assertFalse(is_async_callable(partial_fn)) + + def test_wraps_chain_over_async_is_async(self) -> None: + """A sync decorator that uses @functools.wraps exposes the inner via __wrapped__.""" + + async def inner(ctx: object, inp: object) -> None: ... + + @functools.wraps(inner) + def outer(ctx: object, inp: object) -> object: + return inner(ctx, inp) + + self.assertTrue(is_async_callable(outer)) + + def test_nested_partial_and_wraps_chain_is_async(self) -> None: + """partial(@wraps over async). Exercises both unwrap stages in order.""" + + async def inner(prefix: str, payload: str) -> str: + return prefix + payload + + @functools.wraps(inner) + def wrapped(prefix: str, payload: str) -> str: + return inner(prefix, payload) + + partial_wrapped = functools.partial(wrapped, 'hi-') + self.assertTrue(is_async_callable(partial_wrapped)) + + def test_callable_class_instance_with_async_call_is_async(self) -> None: + class AsyncCallable: + async def __call__(self, ctx: object, inp: object) -> str: + return 'ok' + + self.assertTrue(is_async_callable(AsyncCallable())) + + def test_callable_class_instance_with_sync_call_is_not_async(self) -> None: + class SyncCallable: + def __call__(self, ctx: object, inp: object) -> str: + return 'ok' + + self.assertFalse(is_async_callable(SyncCallable())) + + def test_cyclic_wrapped_chain_does_not_crash(self) -> None: + """A self-referential ``__wrapped__`` makes ``inspect.unwrap`` raise; detection must + fall back to the outermost callable instead of propagating the error.""" + + async def async_cyclic() -> None: ... + + async_cyclic.__wrapped__ = async_cyclic # type: ignore[attr-defined] + + def sync_cyclic() -> None: ... + + sync_cyclic.__wrapped__ = sync_cyclic # type: ignore[attr-defined] + + self.assertTrue(is_async_callable(async_cyclic)) + self.assertFalse(is_async_callable(sync_cyclic)) + + def test_non_callable_input_is_not_async(self) -> None: + """The worker passes ``None`` for an unregistered activity and relies on a False + result to route to the sync handler.""" + self.assertFalse(is_async_callable(None)) + self.assertFalse(is_async_callable(42)) + + +class AsyncAndSyncCoexistTest(_AsyncActivityRegistrationTestBase): + def test_runtime_registers_mixed_sync_and_async_activities(self) -> None: + async def async_activity(ctx: WorkflowActivityContext, payload: int) -> int: + return payload + 1 + + def sync_activity(ctx: WorkflowActivityContext, payload: int) -> int: + return payload * 2 + + self.runtime.register_activity(async_activity) + self.runtime.register_activity(sync_activity) + + async_wrapper = self.registry.activities['async_activity'] + sync_wrapper = self.registry.activities['sync_activity'] + + self.assertTrue(inspect.iscoroutinefunction(async_wrapper)) + self.assertFalse(inspect.iscoroutinefunction(sync_wrapper)) + + +if __name__ == '__main__': + unittest.main()