-
Notifications
You must be signed in to change notification settings - Fork 2.7k
feat: AI node supports workflow calling tools #4955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,23 +9,123 @@ | |
| import json | ||
| import re | ||
| import time | ||
| import uuid | ||
| from functools import reduce | ||
| from typing import List, Dict | ||
|
|
||
| from application.flow.i_step_node import NodeResult, INode | ||
| from langchain_core.tools import StructuredTool | ||
|
|
||
| from application.flow.common import Workflow, WorkflowMode | ||
| from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler, ToolWorkflowCallPostHandler | ||
| from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode | ||
| from application.flow.tools import Reasoning, mcp_response_generator | ||
| from application.models import Application, ApplicationApiKey, ApplicationAccessToken | ||
| from application.serializers.common import ToolExecute | ||
| from common.exception.app_exception import AppApiException | ||
| from common.utils.rsa_util import rsa_long_decrypt | ||
| from common.utils.shared_resource_auth import filter_authorized_ids | ||
| from common.utils.tool_code import ToolExecutor | ||
| from django.db.models import QuerySet | ||
| from django.db.models import QuerySet, OuterRef, Subquery | ||
| from django.utils.translation import gettext as _ | ||
| from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage | ||
| from models_provider.models import Model | ||
| from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id | ||
| from tools.models import Tool | ||
| from tools.models import Tool, ToolWorkflowVersion, ToolType | ||
| from pydantic import BaseModel, Field, create_model | ||
| import uuid_utils.compat as uuid | ||
|
|
||
|
|
||
| def build_schema(fields: dict): | ||
| return create_model("dynamicSchema", **fields) | ||
|
|
||
|
|
||
| def get_type(_type: str): | ||
| if _type == 'float': | ||
| return float | ||
| if _type == 'string': | ||
| return str | ||
| if _type == 'int': | ||
| return int | ||
| if _type == 'dict': | ||
| return dict | ||
| if _type == 'array': | ||
| return list | ||
| if _type == 'boolean': | ||
| return bool | ||
| return object | ||
|
|
||
|
|
||
| def get_workflow_args(tool, qv): | ||
| for node in qv.work_flow.get('nodes'): | ||
| if node.get('type') == 'tool-base-node': | ||
| input_field_list = node.get('properties').get('user_input_field_list') | ||
| return build_schema( | ||
| {field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc'))) | ||
| for field in input_field_list}) | ||
|
|
||
| return build_schema({}) | ||
|
|
||
|
|
||
| def get_workflow_func(tool, qv, workspace_id): | ||
| tool_id = tool.id | ||
| tool_record_id = str(uuid.uuid7()) | ||
| took_execute = ToolExecute(tool_id, tool_record_id, | ||
| workspace_id, | ||
| None, | ||
| None, | ||
| True) | ||
|
|
||
| def inner(**kwargs): | ||
| from application.flow.tool_workflow_manage import ToolWorkflowManage | ||
| work_flow_manage = ToolWorkflowManage( | ||
| Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL), | ||
| { | ||
| 'chat_record_id': tool_record_id, | ||
| 'tool_id': tool_id, | ||
| 'stream': True, | ||
| 'workspace_id': workspace_id, | ||
| **kwargs}, | ||
|
|
||
| ToolWorkflowCallPostHandler(took_execute, tool_id), | ||
| is_the_task_interrupted=lambda: False, | ||
| child_node=None, | ||
| start_node_id=None, | ||
| start_node_data=None, | ||
| chat_record=None | ||
| ) | ||
| res = work_flow_manage.run() | ||
| for r in res: | ||
| pass | ||
| return work_flow_manage.out_context | ||
|
|
||
| return inner | ||
|
|
||
|
|
||
| def get_tools(tool_workflow_ids, workspace_id): | ||
| tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id) | ||
| latest_subquery = ToolWorkflowVersion.objects.filter( | ||
| tool_id=OuterRef('tool_id') | ||
| ).order_by('-create_time') | ||
|
|
||
| qs = ToolWorkflowVersion.objects.filter( | ||
| tool_id__in=[t.id for t in tools], | ||
| id=Subquery(latest_subquery.values('id')[:1]) | ||
| ) | ||
| qd = {q.tool_id: q for q in qs} | ||
| results = [] | ||
| for tool in tools: | ||
| qv = qd.get(tool.id) | ||
| func = get_workflow_func(tool, qv, workspace_id) | ||
| args = get_workflow_args(tool, qv) | ||
| tool = StructuredTool.from_function( | ||
| func=func, | ||
| name=tool.name, | ||
| description=tool.desc, | ||
| args_schema=args, | ||
| ) | ||
| results.append(tool) | ||
|
|
||
| return results | ||
|
|
||
|
|
||
| def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, | ||
|
|
@@ -178,7 +278,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record | |
| model_id = reference_data.get('model_id', model_id) | ||
| model_params_setting = reference_data.get('model_params_setting') | ||
|
|
||
| if model_params_setting is None and model_id: | ||
| if model_params_setting is None and model_id: | ||
| model_params_setting = get_default_model_params_setting(model_id) | ||
|
|
||
| if model_setting is None: | ||
|
|
@@ -216,7 +316,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record | |
| mcp_result = self._handle_mcp_request( | ||
| mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids, | ||
| application_ids, skill_tool_ids, mcp_output_enable, | ||
| chat_model, message_list, history_message, question, chat_id | ||
| chat_model, message_list, history_message, question, chat_id, workspace_id | ||
| ) | ||
| if mcp_result: | ||
| return mcp_result | ||
|
|
@@ -236,7 +336,8 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record | |
|
|
||
| def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids, | ||
| application_ids, skill_tool_ids, | ||
| mcp_output_enable, chat_model, message_list, history_message, question, chat_id): | ||
| mcp_output_enable, chat_model, message_list, history_message, question, chat_id, | ||
| workspace_id): | ||
|
|
||
| mcp_servers_config = {} | ||
|
|
||
|
|
@@ -259,11 +360,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids | |
| mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])} | ||
| mcp_servers_config = self.handle_variables(mcp_servers_config) | ||
| tool_init_params = {} | ||
| tools = get_tools(tool_ids, workspace_id) | ||
| if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP | ||
| self.context['tool_ids'] = tool_ids | ||
| for tool_id in tool_ids: | ||
| tool = QuerySet(Tool).filter(id=tool_id).first() | ||
| if not tool.is_active: | ||
| tool = QuerySet(Tool).filter(id=tool_id, tool_type=ToolType.CUSTOM).first() | ||
| if tool is None or not tool.is_active: | ||
| continue | ||
| executor = ToolExecutor() | ||
| if tool.init_params is not None: | ||
|
|
@@ -323,7 +425,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids | |
| }) | ||
| mcp_servers_config['skills'] = skill_file_items | ||
|
|
||
| if len(mcp_servers_config) > 0: | ||
| if len(mcp_servers_config) > 0 or len(tools) > 0: | ||
| # 安全获取 application | ||
| application_id = None | ||
| if (self.workflow_manage and | ||
|
|
@@ -334,7 +436,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids | |
| source_id = application_id or knowledge_id | ||
| source_type = 'APPLICATION' if application_id else 'KNOWLEDGE' | ||
| r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable, | ||
| tool_init_params, source_id, source_type, chat_id) | ||
| tool_init_params, source_id, source_type, chat_id, tools) | ||
| return NodeResult( | ||
| {'result': r, 'chat_model': chat_model, 'message_list': message_list, | ||
| 'history_message': [{'content': message.content, 'role': message.type} for message in | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Here's an optimized version of the script incorporating some suggested improvements: |
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,6 +56,7 @@ def _merge_lists_normalize_empty_tool_chunk_ids(left, *others): | |
| """Wrapper around merge_lists that normalises empty-string IDs to None in | ||
| tool_call_chunk items (those with an 'index' key) so that qwen streaming | ||
| chunks with id='' are merged correctly by index.""" | ||
|
|
||
| def _norm(lst): | ||
| if lst is None: | ||
| return lst | ||
|
|
@@ -158,17 +159,17 @@ def get_reasoning_content(self, chunk): | |
| self.reasoning_content_end_tag) | ||
| if reasoning_content_end_tag_index > -1: | ||
| reasoning_content_chunk = self.reasoning_content_chunk[ | ||
| 0:reasoning_content_end_tag_index] | ||
| 0:reasoning_content_end_tag_index] | ||
| content_chunk = self.reasoning_content_chunk[ | ||
| reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:] | ||
| reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:] | ||
| self.reasoning_content += reasoning_content_chunk | ||
| self.content += content_chunk | ||
| self.reasoning_content_chunk = "" | ||
| self.reasoning_content_is_end = True | ||
| return {'content': content_chunk, 'reasoning_content': reasoning_content_chunk} | ||
| else: | ||
| reasoning_content_chunk = self.reasoning_content_chunk[ | ||
| 0:reasoning_content_end_tag_prefix_index + 1] | ||
| 0:reasoning_content_end_tag_prefix_index + 1] | ||
| self.reasoning_content_chunk = self.reasoning_content_chunk.replace( | ||
| reasoning_content_chunk, '') | ||
| self.reasoning_content += reasoning_content_chunk | ||
|
|
@@ -401,11 +402,14 @@ async def _initialize_skills(mcp_servers, temp_dir): | |
|
|
||
|
|
||
| async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={}, | ||
| source_id=None, source_type=None, temp_dir=None, chat_id=None): | ||
| source_id=None, source_type=None, temp_dir=None, chat_id=None, extra_tools=None): | ||
| try: | ||
| checkpointer = MemorySaver() | ||
| client = await _initialize_skills(mcp_servers, temp_dir) | ||
| tools = await client.get_tools() | ||
| if extra_tools: | ||
| for tool in extra_tools: | ||
| tools.append(tool) | ||
| agent = create_deep_agent( | ||
| model=chat_model, | ||
| backend=SandboxShellBackend(root_dir=temp_dir, virtual_mode=True), | ||
|
|
@@ -517,7 +521,7 @@ def _upsert_fragment(key, raw_id, func_name, part_args): | |
| # qwen-plus often emits {} here as a placeholder while | ||
| # the real args are split in tool_call_chunks/invalid_tool_calls. | ||
| if has_tool_call_chunks and ( | ||
| part_args == '' or part_args == {} or part_args == [] | ||
| part_args == '' or part_args == {} or part_args == [] | ||
| ): | ||
| part_args = '' | ||
| key = _get_fragment_key(tool_call.get('index'), raw_id) | ||
|
|
@@ -563,9 +567,9 @@ def _upsert_fragment(key, raw_id, func_name, part_args): | |
| # 3. 检测工具调用结束,更新 tool_calls_info | ||
| # ---------------------------------------------------------------- | ||
| is_finish_chunk = ( | ||
| chunk[0].response_metadata.get( | ||
| 'finish_reason') == 'tool_calls' | ||
| or chunk[0].chunk_position == 'last' | ||
| chunk[0].response_metadata.get( | ||
| 'finish_reason') == 'tool_calls' | ||
| or chunk[0].chunk_position == 'last' | ||
| ) | ||
|
|
||
| if is_finish_chunk: | ||
|
|
@@ -734,7 +738,7 @@ async def save_tool_record(tool_id, tool_info, tool_result, source_id, source_ty | |
|
|
||
|
|
||
| def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={}, | ||
| source_id=None, source_type=None, chat_id=None): | ||
| source_id=None, source_type=None, chat_id=None, extra_tools=None): | ||
| """使用全局事件循环,不创建新实例""" | ||
| result_queue = queue.Queue() | ||
| loop = get_global_loop() # 使用共享循环 | ||
|
|
@@ -751,7 +755,7 @@ def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_ena | |
| async def _run(): | ||
| try: | ||
| async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable, tool_init_params, | ||
| source_id, source_type, temp_dir, chat_id) | ||
| source_id, source_type, temp_dir, chat_id, extra_tools) | ||
| async for chunk in async_gen: | ||
| result_queue.put(('data', chunk)) | ||
| except Exception as e: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here's some advice and comments on your provided code:
Specific Recommendations:

    # Ensure consistent bracket usage throughout the code for readability.
def _upsert_fragment(key, raw_id, func_name, part_args):
if part_args == '':
part_args = {}
... rest of methods...
...
async def save_tool_record(tool_id, tool_info, tool_result, source_id, source_type):
...
    ... add more context-based improvements where appropriate ...

Feel free to adapt these comments based on further review and understanding of your specific requirements!
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The provided code seems to have several issues:
1. Class Duplication: The third class definition of `ToolWorkflowPostHandler` is identical to the first one, which can lead to confusion and errors if not resolved.
2. Variable Initialization Issue: In the second class definition, there are no values assigned to `self.chat_info`, making it impossible to use or modify this variable within the handlers.
3. Method Overriding and Reinitialization: Even though this might be an oversight, reassigning `self.chat_info` and `self.tool_id` after calling the base class method can cause unexpected behavior.
4. Documentation Lack: Comments are sparse, making it difficult to understand the purpose and functionality of each part of the code.
To improve the code:
Keep only one implementation of `ToolWorkflowPostHandler`, and initialize `self.chat_info` and `self.tool_id` before the call to `super().__init__()`. Here's a corrected version (assuming you want to keep only one implementation):
This revision ensures that `chat_info` and `tool_id` are initialized correctly during the initialization process and provides a clear structure for handling work items associated with tools through workflows.