diff --git a/tests/a2a/test_registry_client.py b/tests/a2a/test_registry_client.py new file mode 100644 index 00000000..6cabaddb --- /dev/null +++ b/tests/a2a/test_registry_client.py @@ -0,0 +1,411 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import json +from unittest.mock import Mock, patch + +import pytest +import requests + +from veadk.a2a.registry_client import ( + AgentKitA2ARegistryConfig, + RegistryError, + _agent_auth_headers, + _volc_sign_v4, + create_task, + poll_task, + search_agent_cards, +) +from veadk.tools.builtin_tools.a2a_registry import build_a2a_registry_tools + + +def _mock_response(payload: dict, status_code: int = 200) -> Mock: + response = Mock() + response.status_code = status_code + response.raise_for_status.return_value = None + response.json.return_value = payload + return response + + +def _agent_card() -> dict: + return { + "name": "Weather-A2A-Agent", + "description": "Weather agent", + "version": "1.0.0", + "url": "https://example.test/a2a", + "security": [{"bearer": ["Bearer secret-token"]}], + "securitySchemes": { + "bearer": { + "type": "apiKey", + "in": "header", + "name": "Authorization", + } + }, + "skills": [ + { + "id": "weather", + "name": "Weather", + "description": "Query weather", + "tags": ["weather"], + } + ], + } + + +@patch.dict( + "os.environ", + { + "AGENTKIT_ACCESS_KEY": "ak-test", + "AGENTKIT_SECRET_KEY": 
"sk-test", + }, + clear=False, +) +@patch("veadk.a2a.registry_client.requests.post") +def test_search_agent_cards_sanitizes_and_signs_request(post: Mock): + card = _agent_card() + post.return_value = _mock_response( + { + "ResponseMetadata": {"RequestId": "req-1"}, + "Result": {"AgentCards": [json.dumps(card)], "TotalCount": 1}, + } + ) + + result = search_agent_cards( + "北京天气", + 3, + AgentKitA2ARegistryConfig(space_id="space-test"), + ) + + assert result["outcome"] == "success" + assert result["agents"][0]["name"] == "Weather-A2A-Agent" + + request_headers = post.call_args.kwargs["headers"] + assert "X-Content-Sha256" in request_headers + assert ( + "SignedHeaders=content-type;host;x-content-sha256;x-date" + in request_headers["Authorization"] + ) + assert isinstance(post.call_args.kwargs["data"], bytes) + assert "北京天气" in post.call_args.kwargs["data"].decode("utf-8") + + serialized = json.dumps(result, ensure_ascii=False) + assert "secret-token" not in serialized + assert "Authorization" not in serialized + assert "https://example.test/a2a" not in serialized + + +@patch.dict( + "os.environ", + { + "AGENTKIT_ACCESS_KEY": "ak-test", + "AGENTKIT_SECRET_KEY": "sk-test", + }, + clear=False, +) +@patch("veadk.a2a.registry_client.requests.post") +def test_create_task_gets_agent_and_sends_message(post: Mock): + card = _agent_card() + post.side_effect = [ + _mock_response( + { + "ResponseMetadata": {"RequestId": "get-req"}, + "Result": { + "Id": "agent-id", + "Status": "running", + "AgentCard": json.dumps(card), + }, + } + ), + _mock_response( + { + "result": { + "kind": "message", + "parts": [{"kind": "text", "text": "今天北京晴。"}], + } + } + ), + ] + + result = create_task( + "Weather-A2A-Agent", + "北京天气", + config=AgentKitA2ARegistryConfig(space_id="space-test"), + ) + + assert result["outcome"] == "success" + assert result["selected_agent"]["name"] == "Weather-A2A-Agent" + assert result["response"]["text"] == "今天北京晴。" + assert 
post.call_args_list[0].kwargs["params"]["Action"] == "GetA2aAgent" + assert post.call_args_list[1].args[0] == "https://example.test/a2a" + + serialized = json.dumps(result, ensure_ascii=False) + assert "secret-token" not in serialized + assert "Authorization" not in serialized + + +@patch.dict( + "os.environ", + { + "AGENTKIT_ACCESS_KEY": "ak-test", + "AGENTKIT_SECRET_KEY": "sk-test", + }, + clear=False, +) +@patch("veadk.a2a.registry_client.time.sleep") +@patch("veadk.a2a.registry_client.requests.post") +def test_poll_task_sleeps_5_seconds_when_not_terminal(post: Mock, sleep: Mock): + card = _agent_card() + post.side_effect = [ + _mock_response( + { + "ResponseMetadata": {"RequestId": "get-req"}, + "Result": { + "Id": "agent-id", + "Status": "running", + "AgentCard": json.dumps(card), + }, + } + ), + _mock_response( + { + "result": { + "id": "task-1", + "status": {"state": "working"}, + } + } + ), + ] + + result = poll_task( + "Weather-A2A-Agent", + "task-1", + config=AgentKitA2ARegistryConfig(space_id="space-test"), + ) + + assert result["outcome"] == "success" + assert result["task"]["status"] == "working" + assert result["is_terminal"] is False + assert result["diagnostics"]["sleep_seconds"] == 5 + assert result["diagnostics"]["next_action"] + sleep.assert_called_once_with(5) + assert post.call_args_list[0].kwargs["params"]["Action"] == "GetA2aAgent" + assert post.call_args_list[1].args[0] == "https://example.test/a2a" + + serialized = json.dumps(result, ensure_ascii=False) + assert "secret-token" not in serialized + assert "Authorization" not in serialized + + +@patch.dict( + "os.environ", + { + "AGENTKIT_ACCESS_KEY": "ak-test", + "AGENTKIT_SECRET_KEY": "sk-test", + }, + clear=False, +) +@patch("veadk.a2a.registry_client.time.sleep") +@patch("veadk.a2a.registry_client.requests.post") +def test_poll_task_returns_terminal_without_sleep(post: Mock, sleep: Mock): + card = _agent_card() + post.side_effect = [ + _mock_response( + { + "ResponseMetadata": 
{"RequestId": "get-req"}, + "Result": { + "Id": "agent-id", + "Status": "running", + "AgentCard": json.dumps(card), + }, + } + ), + _mock_response( + { + "result": { + "id": "task-1", + "status": {"state": "completed"}, + "artifacts": [ + {"parts": [{"kind": "text", "text": "任务完成。"}]} + ], + } + } + ), + ] + + result = poll_task( + "Weather-A2A-Agent", + "task-1", + config=AgentKitA2ARegistryConfig(space_id="space-test"), + ) + + assert result["outcome"] == "success" + assert result["task"]["status"] == "completed" + assert result["is_terminal"] is True + assert result["response"]["text"] == "任务完成。" + sleep.assert_not_called() + + +def test_build_a2a_registry_tools_exposes_mcp_compatible_names(): + tools = build_a2a_registry_tools(AgentKitA2ARegistryConfig(space_id="space-test")) + + assert [tool.__name__ for tool in tools] == [ + "a2a_registry_search_agent_cards", + "a2a_registry_task_create", + "a2a_registry_task_poll", + ] + + +def test_a2a_registry_tool_descriptions_guide_model_flow(): + search_tool, create_tool, poll_tool = build_a2a_registry_tools( + AgentKitA2ARegistryConfig(space_id="space-test") + ) + + search_doc = " ".join((search_tool.__doc__ or "").split()) + assert "Use this first" in search_doc + assert "concise search prompt" in search_doc + assert "must not exceed 2048 bytes" in search_doc + assert "agents" in search_doc + assert "a2a_registry_task_create" in search_doc + + create_doc = " ".join((create_tool.__doc__ or "").split()) + assert "selected `agents[].name`" in create_doc + assert "message/send" in create_doc + assert "a2a_registry_task_poll" in create_doc + + poll_doc = " ".join((poll_tool.__doc__ or "").split()) + assert "tasks/get" in poll_doc + assert "do not create a new task" in poll_doc + assert "completed" in poll_doc + assert "rejected" in poll_doc + + +@patch("veadk.tools.builtin_tools.a2a_registry.search_agent_cards") +def test_search_tool_accepts_prompt(search: Mock): + config = AgentKitA2ARegistryConfig(space_id="space-test") 
+ search.return_value = {"outcome": "success", "agents": []} + tool = build_a2a_registry_tools(config)[0] + + result = tool(prompt="三亚五日游") + + assert result["outcome"] == "success" + search.assert_called_once_with("三亚五日游", None, config) + + +@patch("veadk.tools.builtin_tools.a2a_registry.search_agent_cards") +def test_search_tool_does_not_expose_top_k_to_model(search: Mock): + config = AgentKitA2ARegistryConfig(space_id="space-test", top_k=7) + search.return_value = {"outcome": "success", "agents": []} + tool = build_a2a_registry_tools(config)[0] + + assert "top_k" not in inspect.signature(tool).parameters + assert "query" not in inspect.signature(tool).parameters + + result = tool(prompt="财务报销") + + assert result["outcome"] == "success" + search.assert_called_once_with("财务报销", None, config) + + +def test_agent_auth_headers_extracts_api_key_header(): + assert _agent_auth_headers(_agent_card()) == { + "Authorization": "Bearer secret-token" + } + + +def test_agent_auth_headers_rejects_unusable_security(): + with pytest.raises(RegistryError) as ctx: + _agent_auth_headers( + { + "security": [{"bearer": []}], + "securitySchemes": { + "bearer": { + "type": "apiKey", + "in": "header", + "name": "Authorization", + } + }, + } + ) + assert ctx.value.code == "AGENT_AUTH_MISSING" + + +def test_agentkit_http_error_uses_safe_diagnostics(): + response = _mock_response( + { + "ResponseMetadata": { + "RequestId": "req-401", + "Action": "SearchAgentCards", + "Version": "2025-10-30", + "Service": "agentkit", + "Region": "cn-beijing", + "Error": { + "Code": "SignatureDoesNotMatch", + "CodeN": 100010, + "Message": "signature mismatch", + }, + } + }, + status_code=401, + ) + response.raise_for_status.side_effect = requests.HTTPError( + "401 Client Error", response=response + ) + + with patch.dict( + "os.environ", + {"AGENTKIT_ACCESS_KEY": "ak-test", "AGENTKIT_SECRET_KEY": "sk-test"}, + clear=False, + ), patch("veadk.a2a.registry_client.requests.post", return_value=response): + with 
pytest.raises(RegistryError) as ctx: + search_agent_cards( + "weather", + 3, + AgentKitA2ARegistryConfig(space_id="space-test"), + ) + + assert ctx.value.code == "AGENTKIT_OPENAPI_FAILED" + assert ctx.value.diagnostics["status_code"] == 401 + assert ctx.value.diagnostics["request_id"] == "req-401" + assert ctx.value.diagnostics["response_error"]["Code"] == ( + "SignatureDoesNotMatch" + ) + serialized = json.dumps(ctx.value.diagnostics, ensure_ascii=False) + assert "Authorization" not in serialized + assert "ak-test" not in serialized + assert "sk-test" not in serialized + + +def test_volc_sign_v4_signs_openapi_headers(): + headers = _volc_sign_v4( + access_key="ak-test", + secret_key="sk-test", + service="agentkit", + region="cn-beijing", + method="POST", + path="/", + query={"Action": "SearchAgentCards", "Version": "2025-10-30"}, + headers={ + "Host": "open.volcengineapi.com", + "Content-Type": "application/json", + }, + body='{"SpaceId":"space-test"}', + ) + + assert "X-Date" in headers + assert "X-Content-Sha256" in headers + assert ( + "SignedHeaders=content-type;host;x-content-sha256;x-date" + in headers["Authorization"] + ) diff --git a/tests/cli/test_cli_harness_registry.py b/tests/cli/test_cli_harness_registry.py new file mode 100644 index 00000000..83aaad42 --- /dev/null +++ b/tests/cli/test_cli_harness_registry.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import yaml +from click.testing import CliRunner + +from veadk.cli.cli_harness import harness + + +def test_harness_add_no_longer_exposes_registry_flags(): + runner = CliRunner() + with runner.isolated_filesystem(): + create_result = runner.invoke(harness, ["create", "harness-app"]) + assert create_result.exit_code == 0 + + result = runner.invoke( + harness, + [ + "add", + "--path", + "harness-app", + "--registry-space-id", + "space-test", + ], + ) + + assert result.exit_code != 0 + assert "No such option" in result.output + + +def test_harness_show_does_not_list_registry_override_flags(): + runner = CliRunner() + with runner.isolated_filesystem(): + runner.invoke(harness, ["create", "harness-app"]) + + result = runner.invoke(harness, ["show", "--path", "harness-app"]) + + assert result.exit_code == 0, result.output + assert "--registry-space-id" not in result.output + assert "--registry-top-k" not in result.output + + +def test_harness_add_tool_calling_flags_write_top_level_config(): + runner = CliRunner() + with runner.isolated_filesystem(): + runner.invoke(harness, ["create", "harness-app"]) + + result = runner.invoke( + harness, + [ + "add", + "--path", + "harness-app", + "--structured-tool-calls", + "--include-tools-every-turn", + ], + ) + + assert result.exit_code == 0, result.output + data = yaml.safe_load((Path("harness-app") / "harness.yaml").read_text()) + assert data["structured_tool_calls"] is True + assert data["include_tools_every_turn"] is True + + +def test_harness_add_removes_old_responses_config_names(): + runner = CliRunner() + with runner.isolated_filesystem(): + runner.invoke(harness, ["create", "harness-app"]) + yaml_path = Path("harness-app") / "harness.yaml" + data = yaml.safe_load(yaml_path.read_text()) + data["enable_responses"] = True + data["enable_responses_cache"] = False + 
yaml_path.write_text(yaml.safe_dump(data, sort_keys=False)) + + result = runner.invoke( + harness, + [ + "add", + "--path", + "harness-app", + "--structured-tool-calls", + ], + ) + + assert result.exit_code == 0, result.output + data = yaml.safe_load(yaml_path.read_text()) + assert "enable_responses" not in data + assert "enable_responses_cache" not in data + assert data["structured_tool_calls"] is True diff --git a/tests/cloud/test_harness_app_contract.py b/tests/cloud/test_harness_app_contract.py index cca27ef8..5383d812 100644 --- a/tests/cloud/test_harness_app_contract.py +++ b/tests/cloud/test_harness_app_contract.py @@ -24,6 +24,8 @@ import time, so it is intentionally left out to keep these tests offline. """ +from pathlib import Path + from veadk.cloud.harness_app.types import ( HarnessConfig, HarnessOverrides, @@ -31,7 +33,8 @@ InvokeHarnessResponse, RunAgentRequest, ) -from veadk.cloud.harness_app.utils import split_csv +from veadk.cloud.harness_app.env_mapping import to_runtime_env +from veadk.cloud.harness_app.utils import config_from_env, split_csv from veadk.consts import DEFAULT_MODEL_AGENT_NAME from veadk.prompts.agent_default_prompt import DEFAULT_INSTRUCTION @@ -49,6 +52,10 @@ def test_fields(self): "skills", "system_prompt", "runtime", + "registry_space_id", + "registry_endpoint", + "registry_region", + "registry_top_k", } def test_defaults(self): @@ -58,6 +65,10 @@ def test_defaults(self): assert fields["skills"].default == "" assert fields["system_prompt"].default == "You are a helpful assistant." 
assert fields["runtime"].default == "adk" + assert fields["registry_space_id"].default == "" + assert fields["registry_endpoint"].default == "" + assert fields["registry_region"].default == "" + assert fields["registry_top_k"].default == 3 def test_tools_and_skills_are_csv_strings(self): # The server splits these with split_csv(); they must stay plain strings, @@ -84,6 +95,13 @@ def test_adds_creation_time_fields(self): "longterm_memory_type", "shortterm_memory_type", "max_llm_calls", + "structured_tool_calls", + "include_tools_every_turn", + "registry_type", + "registry_version", + "registry_service_name", + "registry_timeout_ms", + "registry_poll_interval_ms", } def test_component_defaults(self): @@ -92,6 +110,12 @@ def test_component_defaults(self): assert fields["knowledgebase_type"].default == "" assert fields["longterm_memory_type"].default == "" assert fields["shortterm_memory_type"].default == "local" + assert fields["structured_tool_calls"].default is False + assert fields["include_tools_every_turn"].default is True + assert fields["registry_type"].default == "" + assert fields["registry_top_k"].default == 3 + assert fields["registry_timeout_ms"].default == 60000 + assert fields["registry_poll_interval_ms"].default == 5000 def test_system_prompt_default_is_veadk_instruction(self): # HarnessConfig overrides the override-layer default with VeADK's own. 
@@ -101,6 +125,63 @@ def test_app_name_populated_via_name_alias(self): assert HarnessConfig(name="research-agent").app_name == "research-agent" assert HarnessConfig().app_name == "harness_app" + def test_registry_yaml_maps_to_runtime_env(self): + envs = to_runtime_env( + { + "registry": { + "type": "agentkit_a2a", + "space_id": "space-test", + "top_k": 5, + "region": "cn-beijing", + } + } + ) + + assert envs["REGISTRY_TYPE"] == "agentkit_a2a" + assert envs["REGISTRY_SPACE_ID"] == "space-test" + assert envs["REGISTRY_TOP_K"] == "5" + assert envs["REGISTRY_REGION"] == "cn-beijing" + + def test_tool_calling_yaml_maps_to_runtime_env(self): + envs = to_runtime_env( + { + "structured_tool_calls": True, + "include_tools_every_turn": True, + } + ) + + assert envs["STRUCTURED_TOOL_CALLS"] == "true" + assert envs["INCLUDE_TOOLS_EVERY_TURN"] == "true" + + def test_config_from_env_reads_registry_fields(self, monkeypatch): + monkeypatch.setenv("REGISTRY_TYPE", "agentkit_a2a") + monkeypatch.setenv("REGISTRY_SPACE_ID", "space-test") + monkeypatch.setenv("REGISTRY_TOP_K", "5") + monkeypatch.setenv("REGISTRY_REGION", "cn-beijing") + + config = config_from_env() + + assert config.registry_type == "agentkit_a2a" + assert config.registry_space_id == "space-test" + assert config.registry_top_k == 5 + assert config.registry_region == "cn-beijing" + + def test_config_from_env_reads_tool_calling_fields(self, monkeypatch): + monkeypatch.setenv("STRUCTURED_TOOL_CALLS", "true") + monkeypatch.setenv("INCLUDE_TOOLS_EVERY_TURN", "false") + + config = config_from_env() + + assert config.structured_tool_calls is True + assert config.include_tools_every_turn is False + + def test_registry_overrides_remount_registry_tools(self): + source = Path("veadk/cloud/harness_app/utils.py").read_text() + + assert "_apply_registry_overrides(" in source + assert "_remove_a2a_registry_tools(" in source + assert "build_a2a_registry_tools(overridden_config)" in source + class TestRequestResponseSchemas: def 
test_run_agent_request_fields(self): diff --git a/veadk/a2a/registry_client.py b/veadk/a2a/registry_client.py new file mode 100644 index 00000000..dcaf46b5 --- /dev/null +++ b/veadk/a2a/registry_client.py @@ -0,0 +1,833 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import hashlib +import hmac +import json +import os +import time +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any +from urllib.parse import quote, urlparse + +import requests + +from veadk.auth.veauth.utils import get_credential_from_vefaas_iam + +DEFAULT_ENDPOINT = "http://volcengineapi.byted.org/" +DEFAULT_VERSION = "2025-10-30" +DEFAULT_SERVICE_NAME = "agentkit" +DEFAULT_REGION = "cn-beijing" +DEFAULT_TOP_K = 3 +DEFAULT_TIMEOUT_MS = 60000 +DEFAULT_POLL_INTERVAL_MS = 5000 +TERMINAL_STATES = {"completed", "failed", "canceled", "rejected"} + + +class RegistryError(Exception): + """A safe, structured error from the AgentKit A2A registry client.""" + + def __init__( + self, code: str, message: str, diagnostics: dict[str, Any] | None = None + ): + super().__init__(message) + self.code = code + self.message = message + self.diagnostics = diagnostics or {} + + +@dataclass(frozen=True) +class AgentKitA2ARegistryConfig: + space_id: str = "" + endpoint: str = DEFAULT_ENDPOINT + version: str = DEFAULT_VERSION + 
service_name: str = DEFAULT_SERVICE_NAME + region: str = DEFAULT_REGION + top_k: int = DEFAULT_TOP_K + timeout_ms: int = DEFAULT_TIMEOUT_MS + poll_interval_ms: int = DEFAULT_POLL_INTERVAL_MS + + +@dataclass(frozen=True) +class _RegistryCredentials: + access_key: str + secret_key: str + session_token: str = "" + + +def registry_config_from_env() -> AgentKitA2ARegistryConfig: + """Read AgentKit A2A registry config from Harness-compatible env vars.""" + + return AgentKitA2ARegistryConfig( + space_id=_first_env( + ["REGISTRY_SPACE_ID", "AGENTKIT_A2A_SPACE_ID", "A2A_REGISTRY_SPACE_ID"] + ), + endpoint=_first_env( + ["REGISTRY_ENDPOINT", "AGENTKIT_OPENAPI_ENDPOINT"], DEFAULT_ENDPOINT + ), + version=_first_env( + ["REGISTRY_VERSION", "AGENTKIT_OPENAPI_VERSION"], DEFAULT_VERSION + ), + service_name=_first_env( + ["REGISTRY_SERVICE_NAME", "AGENTKIT_SERVICE_NAME"], DEFAULT_SERVICE_NAME + ), + region=_first_env(["REGISTRY_REGION", "AGENTKIT_REGION"], DEFAULT_REGION), + top_k=_int_env("REGISTRY_TOP_K", DEFAULT_TOP_K, minimum=1), + timeout_ms=_int_env("REGISTRY_TIMEOUT_MS", DEFAULT_TIMEOUT_MS, minimum=1000), + poll_interval_ms=_int_env( + "REGISTRY_POLL_INTERVAL_MS", DEFAULT_POLL_INTERVAL_MS, minimum=100 + ), + ) + + +def search_agent_cards( + prompt: str, + top_k: int | None = None, + config: AgentKitA2ARegistryConfig | None = None, +) -> dict[str, Any]: + """Search AgentKit A2A registry by prompt and return sanitized AgentCards.""" + + started = time.monotonic() + config = _resolve_config(config) + if not prompt or not prompt.strip(): + raise RegistryError("INVALID_ARGUMENT", "prompt is required") + _require_space_id(config) + + safe_top_k = max(1, min(int(top_k or config.top_k or DEFAULT_TOP_K), 20)) + response, request_duration_ms = _agentkit_post( + config, + "SearchAgentCards", + {"SpaceId": config.space_id, "Prompt": prompt.strip(), "TopK": safe_top_k}, + ) + result = response.get("Result") or {} + raw_cards = result.get("AgentCards") or [] + + agents = [] + for index, 
raw_card in enumerate(raw_cards[:safe_top_k]): + card = _parse_json_object( + raw_card, "AGENT_CARD_PARSE_FAILED", f"AgentCards[{index}]" + ) + agents.append(_sanitize_agent_card(card)) + + duration_ms = int((time.monotonic() - started) * 1000) + if not agents: + raise RegistryError( + "AGENT_NOT_FOUND", + "SearchAgentCards did not return usable agents", + {"duration_ms": duration_ms}, + ) + + return _success( + { + "agents": agents, + "total_count": result.get("TotalCount", len(agents)), + "diagnostics": { + "search_request_id": _request_id(response), + "request_duration_ms": request_duration_ms, + "duration_ms": duration_ms, + }, + } + ) + + +def create_task( + agent_name: str, + input_text: str, + task_id: str | None = None, + config: AgentKitA2ARegistryConfig | None = None, +) -> dict[str, Any]: + """Create a remote A2A task by AgentKit A2A agent name.""" + + started = time.monotonic() + config = _resolve_config(config) + if not agent_name or not agent_name.strip(): + raise RegistryError("INVALID_ARGUMENT", "agent_name is required") + if not input_text or not input_text.strip(): + raise RegistryError("INVALID_ARGUMENT", "input is required") + + result, card, raw_response, get_duration_ms = _get_a2a_agent( + agent_name.strip(), config + ) + a2a_result = _send_message(card, input_text, config, task_id=task_id) + return _task_or_message_success( + a2a_result, + _sanitize_get_agent_result(result, card), + { + "get_request_id": _request_id(raw_response), + "get_duration_ms": get_duration_ms, + "duration_ms": int((time.monotonic() - started) * 1000), + }, + ) + + +def poll_task( + agent_name: str, + task_id: str, + history_length: int = 10, + config: AgentKitA2ARegistryConfig | None = None, +) -> dict[str, Any]: + """Poll a remote A2A task by AgentKit A2A agent name.""" + + started = time.monotonic() + config = _resolve_config(config) + if not agent_name or not agent_name.strip(): + raise RegistryError("INVALID_ARGUMENT", "agent_name is required") + if not task_id or 
not task_id.strip(): + raise RegistryError("INVALID_ARGUMENT", "task_id is required") + + _, card, _, _ = _get_a2a_agent(agent_name.strip(), config) + return _poll_card(card, task_id, history_length, config, started) + + +def failure( + code: str, message: str, diagnostics: dict[str, Any] | None = None +) -> dict[str, Any]: + """Return a safe failure payload suitable for tool output.""" + + return { + "outcome": "failure", + "error_code": code, + "error_message": message, + "diagnostics": diagnostics or {}, + } + + +def _resolve_config( + config: AgentKitA2ARegistryConfig | None, +) -> AgentKitA2ARegistryConfig: + env_config = registry_config_from_env() + config = config or env_config + return AgentKitA2ARegistryConfig( + space_id=config.space_id or env_config.space_id, + endpoint=config.endpoint or env_config.endpoint or DEFAULT_ENDPOINT, + version=config.version or env_config.version or DEFAULT_VERSION, + service_name=config.service_name or env_config.service_name + or DEFAULT_SERVICE_NAME, + region=config.region or env_config.region or DEFAULT_REGION, + top_k=max(1, min(int(config.top_k or env_config.top_k or DEFAULT_TOP_K), 20)), + timeout_ms=max( + 1000, int(config.timeout_ms or env_config.timeout_ms or DEFAULT_TIMEOUT_MS) + ), + poll_interval_ms=max( + 100, + int( + config.poll_interval_ms + or env_config.poll_interval_ms + or DEFAULT_POLL_INTERVAL_MS + ), + ), + ) + + +def _require_space_id(config: AgentKitA2ARegistryConfig) -> None: + if not config.space_id: + raise RegistryError( + "CONFIG_MISSING", "Missing required registry config: space_id" + ) + + +def _resolve_credentials() -> _RegistryCredentials: + access_key = _first_env( + [ + "AGENTKIT_ACCESS_KEY", + "A2A_REGISTRY_ACCESS_KEY", + "ACCESS_KEY", + "VOLCENGINE_ACCESS_KEY", + ] + ) + secret_key = _first_env( + [ + "AGENTKIT_SECRET_KEY", + "A2A_REGISTRY_SECRET_KEY", + "SECRET_KEY", + "VOLCENGINE_SECRET_KEY", + ] + ) + session_token = _first_env( + [ + "AGENTKIT_SESSION_TOKEN", + 
"A2A_REGISTRY_SESSION_TOKEN", + "VOLCENGINE_SESSION_TOKEN", + ] + ) + + if not (access_key and secret_key): + try: + credential = get_credential_from_vefaas_iam() + access_key = credential.access_key_id + secret_key = credential.secret_access_key + session_token = credential.session_token + except Exception as exc: + raise RegistryError( + "CONFIG_MISSING", + "Missing required registry credentials: access key and secret key", + {"source": "env_or_iam", "reason": exc.__class__.__name__}, + ) from exc + + return _RegistryCredentials( + access_key=access_key, + secret_key=secret_key, + session_token=session_token, + ) + + +def _first_env(names: list[str], default: str = "") -> str: + for name in names: + value = os.getenv(name) + if value: + return value + return default + + +def _int_env(name: str, default: int, minimum: int) -> int: + raw = os.getenv(name, str(default)) + try: + return max(minimum, int(raw)) + except ValueError: + return default + + +def _success(payload: dict[str, Any]) -> dict[str, Any]: + return {"outcome": "success", **payload} + + +def _timeout_seconds(config: AgentKitA2ARegistryConfig) -> float: + return max(1, config.timeout_ms) / 1000 + + +def _request_id(response: dict[str, Any]) -> str | None: + return (response.get("ResponseMetadata") or {}).get("RequestId") + + +def _agentkit_post( + config: AgentKitA2ARegistryConfig, action: str, body: dict[str, Any] +) -> tuple[dict[str, Any], int]: + _require_space_id(config) + credentials = _resolve_credentials() + started = time.monotonic() + body_str = json.dumps(body, ensure_ascii=False) + body_bytes = body_str.encode("utf-8") + parsed = urlparse(config.endpoint) + path = parsed.path or "/" + query = {"Action": action, "Version": config.version} + headers_to_sign = { + "Host": parsed.netloc, + "Content-Type": "application/json", + } + auth_headers = _volc_sign_v4( + access_key=credentials.access_key, + secret_key=credentials.secret_key, + service=config.service_name, + region=config.region, + 
method="POST", + path=path, + query=query, + headers=headers_to_sign, + body=body_str, + ) + request_headers = { + "Content-Type": "application/json", + "Host": parsed.netloc, + **auth_headers, + } + if credentials.session_token: + request_headers["X-Security-Token"] = credentials.session_token + + response = None + try: + response = requests.post( + config.endpoint, + params=query, + headers=request_headers, + data=body_bytes, + timeout=_timeout_seconds(config), + ) + response.raise_for_status() + data = response.json() + except requests.RequestException as exc: + raise RegistryError( + "AGENTKIT_OPENAPI_FAILED", + f"Agent-A2A center request failed: {exc}", + _agentkit_http_diagnostics(exc, response), + ) from exc + except ValueError as exc: + raise RegistryError( + "AGENTKIT_RESPONSE_PARSE_FAILED", + "Agent-A2A center returned non-JSON response", + ) from exc + + duration_ms = int((time.monotonic() - started) * 1000) + if data.get("Error"): + raise RegistryError( + "AGENTKIT_OPENAPI_ERROR", + "Agent-A2A center returned an error", + {"response": data.get("Error")}, + ) + if "Result" not in data: + raise RegistryError( + "AGENTKIT_RESPONSE_INVALID", "Agent-A2A center response missing Result" + ) + return data, duration_ms + + +def _agentkit_http_diagnostics( + exc: requests.RequestException, + response: requests.Response | None, +) -> dict[str, Any]: + response = getattr(exc, "response", None) or response + if response is None: + return {} + + diagnostics: dict[str, Any] = {"status_code": response.status_code} + try: + data = response.json() + except ValueError: + return diagnostics + + metadata = data.get("ResponseMetadata") if isinstance(data, dict) else None + if not isinstance(metadata, dict): + return diagnostics + + for source_key, target_key in [ + ("RequestId", "request_id"), + ("Action", "action"), + ("Version", "version"), + ("Service", "service"), + ("Region", "region"), + ]: + value = metadata.get(source_key) + if value: + diagnostics[target_key] = 
value + + error = metadata.get("Error") + if isinstance(error, dict): + diagnostics["response_error"] = { + key: error[key] for key in ["Code", "CodeN", "Message"] if key in error + } + + return diagnostics + + +def _get_a2a_agent( + agent_name: str, + config: AgentKitA2ARegistryConfig, +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any], int]: + response, duration_ms = _agentkit_post( + config, + "GetA2aAgent", + {"Name": agent_name, "SpaceId": config.space_id}, + ) + result = response.get("Result") or {} + status = result.get("Status", "") + if status and status != "running": + raise RegistryError( + "AGENT_NOT_RUNNING", + f"Agent {agent_name} status is {status}", + {"status": status}, + ) + + card = _parse_json_object( + result.get("AgentCard"), "AGENT_CARD_PARSE_FAILED", "Result.AgentCard" + ) + if not card.get("url"): + raise RegistryError( + "AGENT_URL_MISSING", f"Agent {agent_name} AgentCard missing url" + ) + return result, card, response, duration_ms + + +def _send_message( + card: dict[str, Any], + input_text: str, + config: AgentKitA2ARegistryConfig, + task_id: str | None = None, +) -> dict[str, Any]: + message: dict[str, Any] = { + "kind": "message", + "messageId": str(uuid.uuid4()), + "role": "user", + "parts": [{"kind": "text", "text": input_text}], + } + if task_id: + message["taskId"] = task_id + + try: + return _a2a_jsonrpc( + card["url"], + "message/send", + {"message": message, "configuration": {"blocking": False}}, + _agent_auth_headers(card), + config, + ) + except RegistryError as exc: + if exc.code in { + "A2A_HTTP_FAILED", + "A2A_RESPONSE_PARSE_FAILED", + "A2A_REMOTE_ERROR", + "A2A_RESPONSE_INVALID", + }: + raise RegistryError( + "A2A_TASK_CREATE_FAILED", exc.message, exc.diagnostics + ) from exc + raise + + +def _poll_card( + card: dict[str, Any], + task_id: str, + history_length: int, + config: AgentKitA2ARegistryConfig, + started: float | None = None, +) -> dict[str, Any]: + started = started or time.monotonic() + a2a_result = 
_a2a_jsonrpc( + card["url"], + "tasks/get", + {"id": task_id.strip(), "historyLength": max(0, int(history_length))}, + _agent_auth_headers(card), + config, + ) + state = _task_state(a2a_result) + is_terminal = state in TERMINAL_STATES + payload: dict[str, Any] = { + "task": _task_summary(a2a_result), + "is_terminal": is_terminal, + "diagnostics": {"duration_ms": int((time.monotonic() - started) * 1000)}, + } + response_text = _task_response_text(a2a_result) + if response_text: + payload["response"] = {"text": response_text} + + if not is_terminal: + sleep_seconds = config.poll_interval_ms / 1000 + time.sleep(sleep_seconds) + payload["diagnostics"]["sleep_seconds"] = sleep_seconds + payload["diagnostics"]["next_action"] = ( + "call a2a_registry_task_poll again until task status is terminal" + ) + + return _success(payload) + + +def _task_or_message_success( + a2a_result: dict[str, Any], + selected_agent: dict[str, Any], + diagnostics: dict[str, Any], +) -> dict[str, Any]: + if a2a_result.get("kind") == "message": + return _success( + { + "selected_agent": selected_agent, + "task": None, + "response": {"text": _message_text(a2a_result)}, + "diagnostics": diagnostics, + } + ) + + task = _task_summary(a2a_result) + if not task["id"]: + raise RegistryError( + "A2A_TASK_CREATE_FAILED", + "A2A task created but response has no task id", + diagnostics, + ) + + return _success( + { + "selected_agent": selected_agent, + "task": task, + "diagnostics": diagnostics, + } + ) + + +def _a2a_jsonrpc( + url: str, + method: str, + params: dict[str, Any], + headers: dict[str, str], + config: AgentKitA2ARegistryConfig, +) -> dict[str, Any]: + payload = { + "jsonrpc": "2.0", + "id": str(uuid.uuid4()), + "method": method, + "params": params, + } + request_headers = {"Content-Type": "application/json", **headers} + response = None + + try: + response = requests.post( + url, + headers=request_headers, + json=payload, + timeout=_timeout_seconds(config), + ) + response.raise_for_status() + 
data = response.json() + except requests.RequestException as exc: + raise RegistryError( + "A2A_HTTP_FAILED", + f"A2A JSON-RPC request failed: {exc}", + _http_response_diagnostics(exc, response), + ) from exc + except ValueError as exc: + raise RegistryError( + "A2A_RESPONSE_PARSE_FAILED", "A2A endpoint returned non-JSON response" + ) from exc + + if data.get("error"): + error = data["error"] + message = error.get("message") if isinstance(error, dict) else str(error) + raise RegistryError("A2A_REMOTE_ERROR", f"A2A JSON-RPC error: {message}") + + result = data.get("result") + if not isinstance(result, dict): + raise RegistryError( + "A2A_RESPONSE_INVALID", "A2A JSON-RPC response missing object result" + ) + return result + + +def _http_response_diagnostics( + exc: requests.RequestException, + response: requests.Response | None, +) -> dict[str, Any]: + response = getattr(exc, "response", None) or response + if response is None: + return {} + return {"status_code": response.status_code} + + +def _volc_sign_v4( + access_key: str, + secret_key: str, + service: str, + region: str, + method: str, + path: str, + query: dict[str, str], + headers: dict[str, str], + body: str, +) -> dict[str, str]: + now = datetime.now(timezone.utc) + x_date = now.strftime("%Y%m%dT%H%M%SZ") + date_short = now.strftime("%Y%m%d") + + canonical_query = "&".join( + f"{_uri_encode(k)}={_uri_encode(v)}" for k, v in sorted(query.items()) + ) + body_hash = hashlib.sha256(body.encode("utf-8")).hexdigest() + headers_to_sign = {**headers, "X-Content-Sha256": body_hash, "X-Date": x_date} + signed_headers_keys: list[str] = [] + canonical_headers_parts: list[str] = [] + for key in sorted(headers_to_sign.keys(), key=str.lower): + lower_key = key.lower() + signed_headers_keys.append(lower_key) + canonical_headers_parts.append(f"{lower_key}:{headers_to_sign[key].strip()}") + + canonical_headers = "\n".join(canonical_headers_parts) + "\n" + signed_headers = ";".join(signed_headers_keys) + canonical_request = 
"\n".join( + [ + method.upper(), + path or "/", + canonical_query, + canonical_headers, + signed_headers, + body_hash, + ] + ) + + algorithm = "HMAC-SHA256" + credential_scope = f"{date_short}/{region}/{service}/request" + string_to_sign = "\n".join( + [ + algorithm, + x_date, + credential_scope, + hashlib.sha256(canonical_request.encode("utf-8")).hexdigest(), + ] + ) + + k_date = _hmac_sha256(secret_key.encode("utf-8"), date_short.encode("utf-8")) + k_region = _hmac_sha256(k_date, region.encode("utf-8")) + k_service = _hmac_sha256(k_region, service.encode("utf-8")) + signing_key = _hmac_sha256(k_service, b"request") + signature = hmac.new( + signing_key, string_to_sign.encode("utf-8"), hashlib.sha256 + ).hexdigest() + authorization = ( + f"{algorithm} Credential={access_key}/{credential_scope}, " + f"SignedHeaders={signed_headers}, Signature={signature}" + ) + return { + "X-Content-Sha256": body_hash, + "X-Date": x_date, + "Authorization": authorization, + } + + +def _hmac_sha256(key: bytes, msg: bytes) -> bytes: + return hmac.new(key, msg, hashlib.sha256).digest() + + +def _uri_encode(value: str) -> str: + return quote(value, safe="-_.~") + + +def _parse_json_object(raw: Any, code: str, label: str) -> dict[str, Any]: + if isinstance(raw, dict): + return raw + if not isinstance(raw, str): + raise RegistryError(code, f"{label} is not a JSON string") + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + raise RegistryError(code, f"Failed to parse {label}: {exc}") from exc + if not isinstance(parsed, dict): + raise RegistryError(code, f"{label} parsed value is not an object") + return parsed + + +def _sanitize_skill(skill: dict[str, Any]) -> dict[str, Any]: + return { + "id": skill.get("id", ""), + "name": skill.get("name", ""), + "description": skill.get("description", ""), + "tags": skill.get("tags") or [], + } + + +def _sanitize_agent_card(card: dict[str, Any]) -> dict[str, Any]: + return { + "name": card.get("name", ""), + "description": 
card.get("description", ""), + "version": card.get("version") or card.get("latestPublishedVersion") or "", + "protocol_version": card.get("protocolVersion", ""), + "preferred_transport": card.get("preferredTransport", ""), + "registration_type": card.get("registrationType", ""), + "skills": [ + _sanitize_skill(skill) + for skill in card.get("skills") or [] + if isinstance(skill, dict) + ], + } + + +def _sanitize_get_agent_result( + result: dict[str, Any], card: dict[str, Any] +) -> dict[str, Any]: + runtime_config = result.get("RuntimeConfig") or {} + return { + **_sanitize_agent_card(card), + "id": result.get("Id", ""), + "status": result.get("Status", ""), + "source": result.get("Source", ""), + "default_version": result.get("DefaultVersion", ""), + "runtime_id": runtime_config.get("RuntimeId", ""), + "network_type": runtime_config.get("NetworkType", ""), + } + + +def _agent_auth_headers(card: dict[str, Any]) -> dict[str, str]: + security = card.get("security") or [] + schemes = card.get("securitySchemes") or {} + headers: dict[str, str] = {} + + for requirement in security: + if not isinstance(requirement, dict): + continue + for scheme_name, credentials in requirement.items(): + scheme = schemes.get(scheme_name) or {} + if scheme.get("type") != "apiKey" or scheme.get("in") != "header": + continue + header_name = scheme.get("name") or "Authorization" + token = ( + credentials[0] + if isinstance(credentials, list) and credentials + else credentials + ) + if isinstance(token, str) and token: + headers[header_name] = token + + if security and not headers: + raise RegistryError( + "AGENT_AUTH_MISSING", + "AgentCard has security config but no usable header credential", + ) + return headers + + +def _text_from_parts(parts: list[Any]) -> str: + texts: list[str] = [] + for part in parts: + if not isinstance(part, dict): + continue + kind = part.get("kind") or part.get("type") + if kind == "text": + texts.append(part.get("text", "")) + elif kind == "data": + 
texts.append(json.dumps(part.get("data") or {}, ensure_ascii=False)) + elif kind == "file": + file_obj = part.get("file") or {} + texts.append( + f"File: {file_obj['uri']}" if file_obj.get("uri") else "File attachment" + ) + return "\n".join(text for text in texts if text) + + +def _message_text(message: Any) -> str: + if isinstance(message, str): + return message + if isinstance(message, dict): + return _text_from_parts(message.get("parts") or []) + return "" + + +def _task_state(task: dict[str, Any]) -> str: + status = task.get("status") or {} + if isinstance(status, dict): + return status.get("state") or "unknown" + if isinstance(status, str): + return status + return "unknown" + + +def _task_response_text(task: dict[str, Any]) -> str: + artifacts = task.get("artifacts") or [] + artifact_texts = [] + for artifact in artifacts: + if isinstance(artifact, dict): + artifact_texts.append(_text_from_parts(artifact.get("parts") or [])) + artifact_text = "\n".join(text for text in artifact_texts if text) + if artifact_text: + return artifact_text + + status = task.get("status") or {} + if isinstance(status, dict): + return _message_text(status.get("message")) + return "" + + +def _task_summary(task: dict[str, Any]) -> dict[str, Any]: + return { + "id": task.get("id", ""), + "status": _task_state(task), + } diff --git a/veadk/cli/cli_harness.py b/veadk/cli/cli_harness.py index d08aee67..0288fad5 100644 --- a/veadk/cli/cli_harness.py +++ b/veadk/cli/cli_harness.py @@ -89,6 +89,12 @@ # env: RUNTIME flag: --runtime runtime: adk +# Structured tool calls via Ark Responses API. +# env: STRUCTURED_TOOL_CALLS flag: --structured-tool-calls +# env: INCLUDE_TOOLS_EVERY_TURN flag: --include-tools-every-turn +structured_tool_calls: false +include_tools_every_turn: true + # --- Knowledge base ---------------------------------------------------------- # type -> env: KNOWLEDGEBASE_TYPE flag: --knowledgebase-type # "" disables it. 
Supported: viking | opensearch | redis | tos_vector | context_search @@ -319,7 +325,10 @@ def _prune_empty(data: dict) -> None: for sub in list(value): if _is_blank(value[sub]): del value[sub] - if (key in COMPONENT_TYPE_ENV and not value.get("type")) or not value: + if ( + (key in COMPONENT_TYPE_ENV or key == "registry") + and not value.get("type") + ) or not value: del data[key] elif _is_blank(value): del data[key] @@ -348,11 +357,14 @@ def _override_options(func): """Attach a ``--flag`` for every :class:`HarnessOverrides` field. Shared by ``add`` and ``invoke`` so their model / tools / skills / - system-prompt / runtime flags stay identical and in sync with the model — - adding a field to ``HarnessOverrides`` exposes the flag in both. Each flag - defaults to ``None`` (unset → not applied). + system-prompt / runtime flags stay identical and in sync with the model. + ``registry_*`` overrides are accepted by the HTTP API for AgentKit, but are + intentionally hidden from the VeADK CLI. Each exposed flag defaults to + ``None`` (unset → not applied). 
""" for name, field in reversed(list(HarnessOverrides.model_fields.items())): + if name.startswith("registry_"): + continue option: dict = { "default": None, "help": field.description or f"`{name}`.", @@ -391,6 +403,18 @@ def _override_options(func): default=None, help="Default max LLM calls per run (overridable per invocation).", ) +@click.option( + "--structured-tool-calls", + is_flag=True, + default=None, + help="Use Ark Responses API for structured tool calling.", +) +@click.option( + "--include-tools-every-turn", + is_flag=True, + default=None, + help="Include tool definitions on every model turn.", +) @_connection_options @click.option( "--path", @@ -403,6 +427,8 @@ def add( long_term_memory_type: str | None, short_term_memory_type: str | None, max_llm_calls: int | None, + structured_tool_calls: bool | None, + include_tools_every_turn: bool | None, path: str, model_name: str | None, tools: str | None, @@ -420,11 +446,17 @@ def add( """ yaml_path = Path(path).resolve() / "harness.yaml" data = _load_harness_yaml(yaml_path) + data.pop("enable_responses", None) + data.pop("enable_responses_cache", None) if harness_name is not None: data["harness_name"] = harness_name if max_llm_calls is not None: data["max_llm_calls"] = max_llm_calls + if structured_tool_calls is not None: + data["structured_tool_calls"] = structured_tool_calls + if include_tools_every_turn is not None: + data["include_tools_every_turn"] = include_tools_every_turn if model_name is not None: model = data.get("model") if not isinstance(model, dict): @@ -502,12 +534,14 @@ def show(path: str) -> None: click.echo("") click.secho("Overridable at invoke time:", fg="green", bold=True) for name, field in HarnessOverrides.model_fields.items(): + if name.startswith("registry_"): + continue flag = "--" + name.replace("_", "-") click.echo(f" {flag}: {field.description or name}") click.echo("") click.echo( "Override per call via `veadk harness invoke ... --`. " - "Memory and knowledgebase are NOT overridable." 
+ "Memory, knowledgebase, and registry are not exposed as VeADK CLI overrides." ) diff --git a/veadk/cloud/harness_app/agent.py b/veadk/cloud/harness_app/agent.py index 6222c9ba..7d4ed9f7 100644 --- a/veadk/cloud/harness_app/agent.py +++ b/veadk/cloud/harness_app/agent.py @@ -30,6 +30,9 @@ KNOWLEDGEBASE_TYPE Knowledge base backend (e.g. "viking"). Unset disables it. LONG_TERM_MEMORY_TYPE Long-term memory backend (e.g. "viking"). Unset disables it. SHORT_TERM_MEMORY_TYPE Short-term memory backend (e.g. "sqlite"). Default: "local". + REGISTRY_TYPE Remote Agent discovery backend. Currently: "agentkit_a2a". + REGISTRY_SPACE_ID AgentKit A2A SpaceId used by SearchAgentCards/GetA2aAgent. + REGISTRY_TOP_K Candidate AgentCard count for semantic search. Default: 3. """ from veadk.cloud.harness_app.utils import init_harness_agent diff --git a/veadk/cloud/harness_app/env_mapping.py b/veadk/cloud/harness_app/env_mapping.py index 95ea5b1e..1330140e 100644 --- a/veadk/cloud/harness_app/env_mapping.py +++ b/veadk/cloud/harness_app/env_mapping.py @@ -26,11 +26,11 @@ Two kinds of fields are converted differently: * **Everything except the component sections** (``harness_name``, ``model``, - ``tools``, ``skills``, ``system_prompt``, ``runtime``) is flattened with VeADK's - own :func:`veadk.utils.misc.flatten_dict` (the flattener ``set_envs`` uses for - ``config.yaml``): nested keys joined with ``_``, then upper-cased, lists - comma-joined. So ``model: {name: x}`` -> ``MODEL_NAME``, ``tools: [a, b]`` -> - ``TOOLS``. + ``tools``, ``skills``, ``system_prompt``, ``runtime``, ``registry``) is + flattened with VeADK's own :func:`veadk.utils.misc.flatten_dict` (the flattener + ``set_envs`` uses for ``config.yaml``): nested keys joined with ``_``, then + upper-cased, lists comma-joined. So ``model: {name: x}`` -> ``MODEL_NAME``, + ``tools: [a, b]`` -> ``TOOLS``. 
* **Component sections** (``knowledgebase`` / ``long_term_memory`` / ``short_term_memory``): ``type`` becomes the harness selector env, and the remaining connection params are mapped to the VeADK env vars the backend diff --git a/veadk/cloud/harness_app/types.py b/veadk/cloud/harness_app/types.py index cfea0ea1..a310aec5 100644 --- a/veadk/cloud/harness_app/types.py +++ b/veadk/cloud/harness_app/types.py @@ -36,9 +36,9 @@ class HarnessOverrides(BaseModel): """Harness parameters that may be overridden on a per-invocation basis. - Field descriptions are the single source of truth for both the FastAPI schema - and the ``veadk harness invoke`` CLI flags (which are generated from these - fields), so adding a field here exposes a new override everywhere. + Field descriptions are the source of truth for the FastAPI schema and most + ``veadk harness invoke`` CLI flags. ``registry_*`` fields are accepted for + AgentKit's harness invoke API but intentionally hidden from the VeADK CLI. """ model_name: str = Field( @@ -56,6 +56,18 @@ class HarnessOverrides(BaseModel): runtime: Literal["adk", "codex"] = Field( default="adk", description="Agent runtime backend." ) + registry_space_id: str = Field( + default="", description="Override the AgentKit A2A registry space id." + ) + registry_endpoint: str = Field( + default="", description="Override the AgentKit A2A registry OpenAPI endpoint." + ) + registry_region: str = Field( + default="", description="Override the AgentKit A2A registry OpenAPI region." + ) + registry_top_k: int = Field( + default=3, description="Override the number of A2A AgentCards to retrieve." + ) class HarnessConfig(HarnessOverrides): @@ -78,6 +90,13 @@ class HarnessConfig(HarnessOverrides): default=None, description="Default max LLM calls per run; unset follows ADK RunConfig's default. 
Overridable per invocation.", ) + structured_tool_calls: bool = Field(default=False) + include_tools_every_turn: bool = Field(default=True) + registry_type: Literal["", "agentkit_a2a"] = Field(default="") + registry_version: str = Field(default="") + registry_service_name: str = Field(default="") + registry_timeout_ms: int = Field(default=60000) + registry_poll_interval_ms: int = Field(default=5000) class RunAgentRequest(BaseModel): diff --git a/veadk/cloud/harness_app/utils.py b/veadk/cloud/harness_app/utils.py index 3e7ec26a..cdef4175 100644 --- a/veadk/cloud/harness_app/utils.py +++ b/veadk/cloud/harness_app/utils.py @@ -28,6 +28,7 @@ import shutil import tempfile import zipfile +from dataclasses import replace from pathlib import Path from typing import Any @@ -46,6 +47,19 @@ logger = get_logger(__name__) +_REGISTRY_CONFIG_ATTR = "_veadk_a2a_registry_config" +_REGISTRY_TOOL_NAMES = { + "a2a_registry_search_agent_cards", + "a2a_registry_task_create", + "a2a_registry_task_poll", +} +_REGISTRY_OVERRIDE_FIELDS = { + "registry_space_id", + "registry_endpoint", + "registry_region", + "registry_top_k", +} + __all__ = [ "HarnessConfig", "HarnessOverrides", @@ -94,11 +108,22 @@ def _load_builtin_tool(name: str) -> Any: "skills": "SKILLS", "system_prompt": "SYSTEM_PROMPT", "runtime": "RUNTIME", + "structured_tool_calls": "STRUCTURED_TOOL_CALLS", + "include_tools_every_turn": "INCLUDE_TOOLS_EVERY_TURN", "name": "HARNESS_NAME", "knowledgebase_type": "KNOWLEDGEBASE_TYPE", "longterm_memory_type": "LONG_TERM_MEMORY_TYPE", "shortterm_memory_type": "SHORT_TERM_MEMORY_TYPE", "max_llm_calls": "MAX_LLM_CALLS", + "registry_type": "REGISTRY_TYPE", + "registry_space_id": "REGISTRY_SPACE_ID", + "registry_endpoint": "REGISTRY_ENDPOINT", + "registry_version": "REGISTRY_VERSION", + "registry_service_name": "REGISTRY_SERVICE_NAME", + "registry_region": "REGISTRY_REGION", + "registry_top_k": "REGISTRY_TOP_K", + "registry_timeout_ms": "REGISTRY_TIMEOUT_MS", + "registry_poll_interval_ms": 
"REGISTRY_POLL_INTERVAL_MS", } @@ -236,6 +261,26 @@ def _assemble_agent(config: HarnessConfig) -> tuple[Agent, ShortTermMemory]: if skill_toolset is not None: tools.append(skill_toolset) + registry_config = None + if config.registry_type: + from veadk.a2a.registry_client import AgentKitA2ARegistryConfig + from veadk.tools.builtin_tools.a2a_registry import ( + build_a2a_registry_tools, + ) + + logger.info(f"Mounting A2A registry tools: type={config.registry_type}") + registry_config = AgentKitA2ARegistryConfig( + space_id=config.registry_space_id, + endpoint=config.registry_endpoint, + version=config.registry_version, + service_name=config.registry_service_name, + region=config.registry_region, + top_k=config.registry_top_k, + timeout_ms=config.registry_timeout_ms, + poll_interval_ms=config.registry_poll_interval_ms, + ) + tools.extend(build_a2a_registry_tools(registry_config)) + knowledgebase = None if config.knowledgebase_type: logger.info( @@ -271,10 +316,14 @@ def _assemble_agent(config: HarnessConfig) -> tuple[Agent, ShortTermMemory]: instruction=config.system_prompt, tools=tools, runtime=config.runtime, + enable_responses=config.structured_tool_calls, + enable_responses_cache=not config.include_tools_every_turn, knowledgebase=knowledgebase, long_term_memory=long_term_memory, short_term_memory=short_term_memory, ) + if registry_config is not None: + setattr(agent, _REGISTRY_CONFIG_ATTR, registry_config) return agent, short_term_memory @@ -338,6 +387,41 @@ def _add_incremental_skills( agent.tools.append(SkillToolset(skills=existing_skills + new_skills)) +def _remove_a2a_registry_tools(agent: Agent) -> None: + agent.tools = [ + tool for tool in agent.tools if _tool_name(tool) not in _REGISTRY_TOOL_NAMES + ] + + +def _apply_registry_overrides( + agent: Agent, + base_config, + overrides: HarnessOverrides, +) -> None: + set_fields = overrides.model_fields_set + if not (_REGISTRY_OVERRIDE_FIELDS & set_fields): + return + + from veadk.a2a.registry_client import 
AgentKitA2ARegistryConfig + from veadk.tools.builtin_tools.a2a_registry import build_a2a_registry_tools + + config = base_config or AgentKitA2ARegistryConfig() + updates: dict[str, Any] = {} + if "registry_space_id" in set_fields: + updates["space_id"] = overrides.registry_space_id + if "registry_endpoint" in set_fields: + updates["endpoint"] = overrides.registry_endpoint + if "registry_region" in set_fields: + updates["region"] = overrides.registry_region + if "registry_top_k" in set_fields: + updates["top_k"] = overrides.registry_top_k + + overridden_config = replace(config, **updates) + _remove_a2a_registry_tools(agent) + agent.tools.extend(build_a2a_registry_tools(overridden_config)) + setattr(agent, _REGISTRY_CONFIG_ATTR, overridden_config) + + def spawn_harness_agent( base_agent: Agent, overrides: HarnessOverrides, download_dir: Path | None = None ) -> Agent: @@ -371,4 +455,10 @@ def spawn_harness_agent( if "skills" in set_fields: _add_incremental_skills(cloned, split_csv(overrides.skills), download_dir) + _apply_registry_overrides( + cloned, + getattr(base_agent, _REGISTRY_CONFIG_ATTR, None), + overrides, + ) + return cloned diff --git a/veadk/tools/builtin_tools/a2a_registry.py b/veadk/tools/builtin_tools/a2a_registry.py new file mode 100644 index 00000000..3d29c502 --- /dev/null +++ b/veadk/tools/builtin_tools/a2a_registry.py @@ -0,0 +1,103 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from veadk.a2a.registry_client import (
    AgentKitA2ARegistryConfig,
    RegistryError,
    create_task,
    failure,
    poll_task,
    registry_config_from_env,
    search_agent_cards,
)


def _run_registry_call(
    operation: Callable[..., dict[str, Any]], *args: Any
) -> dict[str, Any]:
    """Invoke a registry client function and never let an exception escape.

    Args:
        operation: The registry client callable to run.
        *args: Positional arguments forwarded to ``operation`` unchanged.

    Returns:
        The value returned by ``operation`` on success. On a ``RegistryError``
        the result of ``failure(code, message, diagnostics)`` built from the
        error; on any other exception the result of
        ``failure("INTERNAL_ERROR", ...)``.
    """
    try:
        return operation(*args)
    except RegistryError as exc:
        return failure(exc.code, exc.message, exc.diagnostics)
    except Exception as exc:  # noqa: BLE001 - tool calls should return safely.
        # Some exceptions stringify to "" (e.g. a bare TimeoutError()); fall
        # back to the class name so the failure message is never blank.
        return failure("INTERNAL_ERROR", str(exc) or type(exc).__name__)


def build_a2a_registry_tools(
    config: AgentKitA2ARegistryConfig | None = None,
) -> list[Callable[..., dict[str, Any]]]:
    """Build the three AgentKit A2A registry tools for a harness agent.

    Args:
        config: Registry configuration shared by all three tools. When it is
            ``None`` the configuration is read once, at build time, via
            ``registry_config_from_env()``.

    Returns:
        Three callables, in this order: ``a2a_registry_search_agent_cards``,
        ``a2a_registry_task_create`` and ``a2a_registry_task_poll``. Each one
        returns a dict and converts every exception into a ``failure(...)``
        payload instead of raising.
    """

    resolved_config = config if config is not None else registry_config_from_env()

    # NOTE(review): the tool names, parameter names and docstrings below are
    # kept verbatim -- presumably they are surfaced to the model as the tool
    # schema/description; confirm before rewording any of them.

    def a2a_registry_search_agent_cards(prompt: str = "") -> dict[str, Any]:
        """Search the AgentKit A2A registry for remote agents that can handle a task.

        Use this first when you determine that a remote A2A Agent may be needed
        for the task, such as when specialist capabilities, delegation, or agent
        discovery could improve the result. Pass a concise search prompt, not
        the complete user request. If the user's request is long, summarize it
        into keywords or a short task description before calling this tool.
        The UTF-8 encoded `prompt` must not exceed 2048 bytes. Inspect the
        returned `agents` list, compare each agent's `name`, `description`, and
        `skills`, then choose the best `agent_name` for
        `a2a_registry_task_create` if a suitable agent is available.
        """

        # The client functions are looked up by name at call time, so patching
        # them on this module keeps working.
        return _run_registry_call(search_agent_cards, prompt, None, resolved_config)

    def a2a_registry_task_create(
        agent_name: str, input: str, task_id: str | None = None
    ) -> dict[str, Any]:
        """Send the user's task to the selected remote A2A agent.

        Use this after `a2a_registry_search_agent_cards` and pass the exact
        selected `agents[].name` as `agent_name`. Put the full user request in
        `input`. This calls the remote agent with A2A `message/send` and may
        return either a final `response.text` or a `task.id`. If it returns a
        `task.id` without a final response, call `a2a_registry_task_poll` with
        the same `agent_name` and `task_id`.
        """

        return _run_registry_call(
            create_task, agent_name, input, task_id, resolved_config
        )

    def a2a_registry_task_poll(
        agent_name: str, task_id: str, history_length: int = 10
    ) -> dict[str, Any]:
        """Check the status of an existing remote A2A task.

        Use this after `a2a_registry_task_create` returns a `task.id` without a
        final response. This tool calls A2A `tasks/get` once with the same
        `agent_name` and `task_id`. If `is_terminal` is false, do not create a
        new task; call this tool again with the same `task_id` until the task
        reaches a terminal state. When the task is terminal, return the A2A
        task's query result to the user.
        """

        return _run_registry_call(
            poll_task, agent_name, task_id, history_length, resolved_config
        )

    return [
        a2a_registry_search_agent_cards,
        a2a_registry_task_create,
        a2a_registry_task_poll,
    ]