diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 4cee6ba6d1..5083aac4f2 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -250,9 +250,44 @@ def convert_schema(schema: dict) -> dict: "integer": {"int32", "int64"}, "number": {"float", "double"}, } + support_fields = { + "title", + "description", + "enum", + "minimum", + "maximum", + "maxItems", + "minItems", + "nullable", + "required", + } - if "anyOf" in schema: - return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]} + def apply_supported_fields(result: dict, source: dict) -> None: + for key in support_fields: + if key in source and key not in result: + result[key] = source[key] + + for union_key in ("anyOf", "oneOf"): + union_value = schema.get(union_key) + if isinstance(union_value, list): + converted_branches = [ + convert_schema(item) if isinstance(item, dict) else item + for item in union_value + ] + non_null_branches = [ + item + for item in converted_branches + if not (isinstance(item, dict) and item.get("type") == "null") + ] + if len(non_null_branches) == 1 and isinstance( + non_null_branches[0], dict + ): + result = non_null_branches[0].copy() + if len(converted_branches) > 1: + result["nullable"] = True + apply_supported_fields(result, schema) + return result + return {union_key: converted_branches} result = {} @@ -268,6 +303,12 @@ def convert_schema(schema: dict) -> dict: if target_type in supported_types: result["type"] = target_type + if ( + isinstance(origin_type, list) + and "null" in origin_type + and target_type != "null" + ): + result["nullable"] = True if "format" in schema and schema["format"] in supported_formats.get( result["type"], set(), @@ -276,18 +317,7 @@ def convert_schema(schema: dict) -> dict: else: result["type"] = "null" - support_fields = { - "title", - "description", - "enum", - "minimum", - "maximum", - "maxItems", - "minItems", - "nullable", - "required", - } - result.update({k: schema[k] for k in support_fields if k in schema}) + apply_supported_fields(result, schema) if "properties" in schema: properties = {} diff --git a/tests/unit/test_tool_google_schema.py b/tests/unit/test_tool_google_schema.py index f1046e6af3..a2be6c2fb9 100644 --- a/tests/unit/test_tool_google_schema.py +++ b/tests/unit/test_tool_google_schema.py @@ -75,3 +75,131 @@ def test_google_schema_fills_missing_array_items_with_string_schema(): assert source_uuids["type"] == "array" assert source_uuids["items"] == {"type": "string"} + + +def test_google_schema_collapses_nullable_anyof_property(): + tool_module = load_tool_module() + FunctionTool = tool_module.FunctionTool + ToolSet = tool_module.ToolSet + + tool = FunctionTool( + name="search_sources", + description="Search sources by recency.", + parameters={ + "type": "object", + "properties": { + "time_range": { + "description": "Optional recency filter.", + "anyOf": [ + { + "type": "string", + "enum": ["day", "week", "month", "year"], + }, + {"type": "null"}, + ], + "default": None, + } + }, + }, + ) + + schema = ToolSet([tool]).google_schema() + time_range = schema["function_declarations"][0]["parameters"]["properties"][ + "time_range" + ] + + assert time_range["type"] == "string" + assert time_range["description"] == "Optional recency filter." + assert time_range["enum"] == ["day", "week", "month", "year"] + assert time_range["nullable"] is True + assert "anyOf" not in time_range + assert "default" not in time_range + + +def test_google_schema_collapses_single_branch_anyof_property(): + tool_module = load_tool_module() + FunctionTool = tool_module.FunctionTool + ToolSet = tool_module.ToolSet + + tool = FunctionTool( + name="search_sources", + description="Search sources by query.", + parameters={ + "type": "object", + "properties": { + "query": { + "description": "Search query.", + "anyOf": [ + { + "type": "string", + } + ], + } + }, + }, + ) + + schema = ToolSet([tool]).google_schema() + query = schema["function_declarations"][0]["parameters"]["properties"]["query"] + + assert query["type"] == "string" + assert query["description"] == "Search query." + assert "nullable" not in query + assert "anyOf" not in query + + +def test_google_schema_preserves_non_dict_union_branches(): + tool_module = load_tool_module() + FunctionTool = tool_module.FunctionTool + ToolSet = tool_module.ToolSet + + tool = FunctionTool( + name="search_sources", + description="Search sources by literal value.", + parameters={ + "type": "object", + "properties": { + "value": { + "anyOf": [ + {"type": "string"}, + False, + ], + } + }, + }, + ) + + schema = ToolSet([tool]).google_schema() + value = schema["function_declarations"][0]["parameters"]["properties"]["value"] + + assert value["anyOf"] == [ + {"type": "string"}, + False, + ] + + +def test_google_schema_marks_type_list_with_null_as_nullable(): + tool_module = load_tool_module() + FunctionTool = tool_module.FunctionTool + ToolSet = tool_module.ToolSet + + tool = FunctionTool( + name="search_sources", + description="Search sources by recency.", + parameters={ + "type": "object", + "properties": { + "query": { + "type": ["string", "null"], + "description": "Optional query.", + } + }, + }, + ) + + schema = ToolSet([tool]).google_schema() + query = schema["function_declarations"][0]["parameters"]["properties"]["query"] + + assert query["type"] == "string" + assert query["description"] == "Optional query." + assert query["nullable"] is True