diff --git a/CHANGELOG.md b/CHANGELOG.md index f002d37d8..32527c303 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ to include examples, links to docs, or any other relevant information. loop. - Relaxed the protobuf dependency bounds to allow protobuf 7 where compatible with the selected optional dependencies. +- Standalone Nexus operation links are now forwarded on start workflow and signal requests. ### Deprecated diff --git a/temporalio/client/_impl.py b/temporalio/client/_impl.py index 4e7049985..8e33ff910 100644 --- a/temporalio/client/_impl.py +++ b/temporalio/client/_impl.py @@ -217,28 +217,18 @@ async def _build_start_workflow_execution_request( if input.request_id: req.request_id = input.request_id - # Server currently only supports workflow_event and batch_job - # link types. This filter should be removed or adapted as - # server-side support comes online. - # See https://github.com/temporalio/temporal/issues/10345 - links = [ - link - for link in input.links - if link.HasField("workflow_event") or link.HasField("batch_job") - ] - req.completion_callbacks.extend( temporalio.api.common.v1.Callback( nexus=temporalio.api.common.v1.Callback.Nexus( url=callback.url, header=callback.headers, ), - links=links, + links=input.links, ) for callback in input.callbacks ) # Links are duplicated on request for compatibility with older server versions. - req.links.extend(links) + req.links.extend(input.links) nexus_ctx = temporalio.nexus._operation_context._try_start_operation_context() if nexus_ctx is not None: diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py index d3142f74a..0af468b44 100644 --- a/tests/helpers/nexus.py +++ b/tests/helpers/nexus.py @@ -1,3 +1,102 @@ +from collections.abc import Sequence + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.api.history.v1 +from temporalio.client import WorkflowHistory + + def make_nexus_endpoint_name(task_queue: str) -> str: # Create endpoints for different task queues without name collisions. return f"nexus-endpoint-{task_queue}" + + +def events_of_type( + history: WorkflowHistory, + event_type: temporalio.api.enums.v1.EventType.ValueType, +) -> list[temporalio.api.history.v1.HistoryEvent]: + return [event for event in history.events if event.event_type == event_type] + + +def links_from_workflow_execution_started_event( + event: temporalio.api.history.v1.HistoryEvent, +) -> list[temporalio.api.common.v1.Link]: + callback_links = [ + link + for callback in event.workflow_execution_started_event_attributes.completion_callbacks + for link in callback.links + ] + if callback_links: + return list(callback_links) + return list(event.links) + + +def workflow_event_link_event_type( + workflow_event: temporalio.api.common.v1.Link.WorkflowEvent, +) -> temporalio.api.enums.v1.EventType.ValueType: + if workflow_event.HasField("request_id_ref"): + return workflow_event.request_id_ref.event_type + return workflow_event.event_ref.event_type + + +def expected_nexus_operation_link( + *, + namespace: str, + operation_id: str, + run_id: str, +) -> temporalio.api.common.v1.Link: + return temporalio.api.common.v1.Link( + nexus_operation=temporalio.api.common.v1.Link.NexusOperation( + namespace=namespace, + operation_id=operation_id, + run_id=run_id, + ) + ) + + +def expected_workflow_event_link( + *, + namespace: str, + workflow_id: str, + run_id: str, + event_type: temporalio.api.enums.v1.EventType.ValueType, + event_id: int = 0, + request_id: str | None = None, +) -> temporalio.api.common.v1.Link: + if request_id is not None: + return temporalio.api.common.v1.Link( + workflow_event=temporalio.api.common.v1.Link.WorkflowEvent( + namespace=namespace, + workflow_id=workflow_id, + run_id=run_id, + request_id_ref=temporalio.api.common.v1.Link.WorkflowEvent.RequestIdReference( + request_id=request_id, + event_type=event_type, + ), + ) + ) + + return temporalio.api.common.v1.Link( + workflow_event=temporalio.api.common.v1.Link.WorkflowEvent( + namespace=namespace, + workflow_id=workflow_id, + run_id=run_id, + event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference( + event_id=event_id, + event_type=event_type, + ), + ) + ) + + +def assert_links_match( + links: Sequence[temporalio.api.common.v1.Link], + *expected_links: temporalio.api.common.v1.Link, +) -> None: + actual = sorted(list(links), key=_link_sort_key) + expected = sorted(list(expected_links), key=_link_sort_key) + assert actual == expected + + +def _link_sort_key(link: temporalio.api.common.v1.Link) -> bytes: + return link.SerializeToString(deterministic=True) diff --git a/tests/nexus/test_signal_link_propagation_e2e.py b/tests/nexus/test_signal_link_propagation_e2e.py index 9795fd841..e489ad8a7 100644 --- a/tests/nexus/test_signal_link_propagation_e2e.py +++ b/tests/nexus/test_signal_link_propagation_e2e.py @@ -14,8 +14,8 @@ ``history.enableCHASMSignalBacklinks=true`` (added to the local dev-server args in ``tests/conftest.py``). The server populates the backlink's reference via ``RequestIdReference`` rather than ``EventReference``, so backlink assertions tolerate both oneof variants of -``common.v1.Link.WorkflowEvent.reference`` (see ``_backlink_event_type``). When run against a -server that does not emit the backlink, the backward assertions are skipped. +``common.v1.Link.WorkflowEvent.reference`` (see ``workflow_event_link_event_type``). When run +against a server that does not emit the backlink, the backward assertions are skipped. The forward/backward description above applies to operations scheduled by a caller workflow. The file also covers the same handlers invoked as standalone (client-initiated) operations via @@ -41,7 +41,6 @@ ) from nexusrpc.handler._decorators import operation_handler -import temporalio.api.common.v1 import temporalio.api.enums.v1 import temporalio.api.history.v1 import temporalio.common @@ -51,7 +50,11 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers import assert_eventually -from tests.helpers.nexus import make_nexus_endpoint_name +from tests.helpers.nexus import ( + events_of_type, + make_nexus_endpoint_name, + workflow_event_link_event_type, +) EventType = temporalio.api.enums.v1.EventType @@ -216,29 +219,12 @@ async def run(self, callee_id: str, task_queue: str) -> str: # ── Assertion helpers ─────────────────────────────────────────────────────────────────────── -def _events_of_type( - history: WorkflowHistory, - event_type: temporalio.api.enums.v1.EventType.ValueType, -) -> list[temporalio.api.history.v1.HistoryEvent]: - return [e for e in history.events if e.event_type == event_type] - - -def _backlink_event_type( - we: temporalio.api.common.v1.Link.WorkflowEvent, -) -> temporalio.api.enums.v1.EventType.ValueType: - # Server PR #9897 keys backlinks via RequestIdReference rather than EventReference; accept - # either oneof variant (matches Java SignalOperationLinkingTest.assertBacklink). - if we.HasField("request_id_ref"): - return we.request_id_ref.event_type - return we.event_ref.event_type - - def _assert_forward_link( callee_history: WorkflowHistory, caller_id: str, expected_count: int, ) -> None: - signaled = _events_of_type( + signaled = events_of_type( callee_history, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED ) assert len(signaled) == expected_count, ( @@ -267,7 +253,10 @@ def _assert_backlink( return False we = event.links[0].workflow_event assert we.workflow_id == callee_id, "backlink should reference the callee workflow" - assert _backlink_event_type(we) == EventType.EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED + assert ( + workflow_event_link_event_type(we) + == EventType.EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED + ) return True @@ -310,7 +299,7 @@ async def test_sync_signal_operation_links( _assert_forward_link(callee_history, caller_id, expected_count=2) # Backward: the single NexusOperationCompleted carries backlinks to the callee. - completed = _events_of_type( + completed = events_of_type( caller_history, EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED ) assert len(completed) == 1, ( @@ -360,7 +349,7 @@ async def test_async_signal_operation_links( _assert_forward_link(callee_history, caller_id, expected_count=1) # Backward: the backlink lands on NexusOperationStarted for the async response path. - started = _events_of_type( + started = events_of_type( caller_history, EventType.EVENT_TYPE_NEXUS_OPERATION_STARTED ) assert len(started) == 1, ( @@ -389,7 +378,7 @@ def _assert_standalone_forward_link( operation_id: str, expected_count: int, ) -> None: - signaled = _events_of_type( + signaled = events_of_type( callee_history, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED ) assert len(signaled) == expected_count, ( @@ -549,7 +538,7 @@ async def test_start_from_handler_attaches_on_conflict_options( assert await callee_handle.result() == "done" callee_history = await callee_handle.fetch_history() - updated = _events_of_type( + updated = events_of_type( callee_history, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_OPTIONS_UPDATED ) if not updated: diff --git a/tests/nexus/test_standalone_operations.py b/tests/nexus/test_standalone_operations.py index 10b2f17fa..8193ba7ba 100644 --- a/tests/nexus/test_standalone_operations.py +++ b/tests/nexus/test_standalone_operations.py @@ -21,6 +21,7 @@ sync_operation, ) +import temporalio.api.enums.v1 from temporalio import nexus, workflow from temporalio.client import ( CancelNexusOperationInput, @@ -36,6 +37,7 @@ OutboundInterceptor, StartNexusOperationInput, TerminateNexusOperationInput, + WorkflowHistory, WorkflowUpdateStage, ) from temporalio.common import ( @@ -56,7 +58,13 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers import assert_eventually -from tests.helpers.nexus import make_nexus_endpoint_name +from tests.helpers.nexus import ( + assert_links_match, + expected_nexus_operation_link, + expected_workflow_event_link, + links_from_workflow_execution_started_event, + make_nexus_endpoint_name, +) # --------------------------------------------------------------------------- # Data types @@ -259,6 +267,61 @@ async def test_start_async_operation_and_poll_result( assert result.value == "async-hello" +async def test_started_workflow_has_link_to_standalone_nexus_operation( + client: Client, env: WorkflowEnvironment +): + """Start a workflow_run operation and verify its workflow links back to the Nexus op.""" + if env.supports_time_skipping: + pytest.skip( + "Standalone Nexus Operation tests don't work with time-skipping server" + ) + + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + service_handler = StandaloneTestServiceHandler() + + async with Worker( + client, + task_queue=task_queue, + nexus_service_handlers=[service_handler], + workflows=[EchoHandlerWorkflow, BlockingHandlerWorkflow], + ): + await env.create_nexus_endpoint(endpoint_name, task_queue) + + nexus_client = client.create_nexus_client( + service=StandaloneTestService, endpoint=endpoint_name + ) + op_id = str(uuid.uuid4()) + input_value = f"link-test-{uuid.uuid4()}" + workflow_id = f"blocking_async-{input_value}" + + handle = await nexus_client.start_operation( + StandaloneTestService.blocking_async, + EchoInput(value=input_value), + id=op_id, + id_reuse_policy=NexusOperationIDReusePolicy.REJECT_DUPLICATE, + id_conflict_policy=NexusOperationIDConflictPolicy.FAIL, + schedule_to_close_timeout=timedelta(seconds=30), + ) + + await service_handler.started_blocking.wait() + workflow_history = await _assert_workflow_started_with_nexus_operation_link( + client, workflow_id, handle + ) + await _assert_nexus_operation_has_link_to_started_workflow( + client, workflow_history, handle + ) + + workflow_handle = client.get_workflow_handle(workflow_id) + await workflow_handle.start_update( + BlockingHandlerWorkflow.unblock, + wait_for_stage=WorkflowUpdateStage.COMPLETED, + ) + result = await handle.result() + assert isinstance(result, EchoOutput) + assert result.value == input_value + + async def test_execute_operation(client: Client, env: WorkflowEnvironment): """Use execute_operation convenience method, verify it returns result directly.""" if env.supports_time_skipping: @@ -949,3 +1012,52 @@ async def test_interceptor_receives_inputs(client: Client, env: WorkflowEnvironm count_input = interceptor.count_calls[-1] assert isinstance(count_input, CountNexusOperationsInput) assert count_input.query == query + + +async def _assert_workflow_started_with_nexus_operation_link( + client: Client, + workflow_id: str, + operation_handle: NexusOperationHandle[Any], +) -> WorkflowHistory: + history = await client.get_workflow_handle(workflow_id).fetch_history() + started_event = next( + ( + e + for e in history.events + if ( + e.event_type + == temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED + ) + ), + None, + ) + assert started_event is not None + + assert operation_handle.run_id is not None + assert_links_match( + links_from_workflow_execution_started_event(started_event), + expected_nexus_operation_link( + namespace=client.namespace, + operation_id=operation_handle.operation_id, + run_id=operation_handle.run_id, + ), + ) + return history + + +async def _assert_nexus_operation_has_link_to_started_workflow( + client: Client, + workflow_history: WorkflowHistory, + operation_handle: NexusOperationHandle[Any], +) -> None: + desc = await operation_handle.describe() + assert_links_match( + desc.raw_description.links, + expected_workflow_event_link( + namespace=client.namespace, + workflow_id=workflow_history.workflow_id, + run_id=workflow_history.run_id, + event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, + event_id=workflow_history.events[0].event_id, + ), + ) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index df6ace9fa..89ce2719a 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -27,9 +27,7 @@ from nexusrpc.handler._decorators import operation_handler import temporalio.api -import temporalio.api.common.v1 import temporalio.api.enums.v1 -import temporalio.api.history.v1 import temporalio.nexus._operation_handlers from temporalio import nexus, workflow from temporalio.client import ( @@ -69,7 +67,10 @@ ) from tests.helpers import find_free_port, new_worker from tests.helpers.metrics import PromMetricMatcher -from tests.helpers.nexus import make_nexus_endpoint_name +from tests.helpers.nexus import ( + links_from_workflow_execution_started_event, + make_nexus_endpoint_name, +) # TODO(nexus-preview): test worker shutdown, wait_all_completed, drain etc @@ -1285,7 +1286,7 @@ async def test_untyped_caller( task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - if type(response_type) == SyncResponse: + if type(response_type) is SyncResponse: response_type = SyncResponse( op_definition_type=op_definition_type, use_async_def=True, @@ -1696,7 +1697,7 @@ async def assert_handler_workflow_has_link_to_caller_workflow( == temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED ) ) - links = _get_links_from_workflow_execution_started_event(wf_started_event) + links = links_from_workflow_execution_started_event(wf_started_event) if not len(links) == 1: pytest.fail( f"Expected 1 link on WorkflowExecutionStarted event, got {len(links)}" @@ -1712,16 +1713,6 @@ async def assert_handler_workflow_has_link_to_caller_workflow( ) -def _get_links_from_workflow_execution_started_event( - event: temporalio.api.history.v1.HistoryEvent, -) -> list[temporalio.api.common.v1.Link]: - [callback] = event.workflow_execution_started_event_attributes.completion_callbacks - if links := callback.links: - return list(links) - else: - return list(event.links) - - # When request_cancel is True, the NexusOperationHandle in the workflow evolves # through the following states: # start_fut result_fut handle_task w/ fut_waiter (task._must_cancel)