diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index c2daacfdadb..fd14e05beb4 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -128,6 +128,16 @@ def get_tool_workflow_state(workflow): return State.SUCCESS +class ToolWorkflowCallPostHandler(WorkFlowPostHandler): + def __init__(self, chat_info, tool_id): + super().__init__(chat_info) + self.tool_id = tool_id + + def handler(self, workflow): + self.chat_info = None + self.tool_id = None + + class ToolWorkflowPostHandler(WorkFlowPostHandler): def __init__(self, chat_info, tool_id): super().__init__(chat_info) diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 53dfe7a0cd4..718c10ac06a 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -9,23 +9,123 @@ import json import re import time +import uuid from functools import reduce from typing import List, Dict -from application.flow.i_step_node import NodeResult, INode +from langchain_core.tools import StructuredTool + +from application.flow.common import Workflow, WorkflowMode +from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler, ToolWorkflowCallPostHandler from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode from application.flow.tools import Reasoning, mcp_response_generator from application.models import Application, ApplicationApiKey, ApplicationAccessToken +from application.serializers.common import ToolExecute from common.exception.app_exception import AppApiException from common.utils.rsa_util import rsa_long_decrypt from common.utils.shared_resource_auth import filter_authorized_ids from common.utils.tool_code import ToolExecutor -from django.db.models import QuerySet +from django.db.models import QuerySet, OuterRef, Subquery from django.utils.translation import gettext as _ from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage from models_provider.models import Model from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id -from tools.models import Tool +from tools.models import Tool, ToolWorkflowVersion, ToolType +from pydantic import BaseModel, Field, create_model +import uuid_utils.compat as uuid + + +def build_schema(fields: dict): + return create_model("dynamicSchema", **fields) + + +def get_type(_type: str): + if _type == 'float': + return float + if _type == 'string': + return str + if _type == 'int': + return int + if _type == 'dict': + return dict + if _type == 'array': + return list + if _type == 'boolean': + return bool + return object + + +def get_workflow_args(tool, qv): + for node in qv.work_flow.get('nodes'): + if node.get('type') == 'tool-base-node': + input_field_list = node.get('properties').get('user_input_field_list') + return build_schema( + {field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc'))) + for field in input_field_list}) + + return build_schema({}) + + +def get_workflow_func(tool, qv, workspace_id): + tool_id = tool.id + tool_record_id = str(uuid.uuid7()) + took_execute = ToolExecute(tool_id, tool_record_id, + workspace_id, + None, + None, + True) + + def inner(**kwargs): + from application.flow.tool_workflow_manage import ToolWorkflowManage + work_flow_manage = ToolWorkflowManage( + Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL), + { + 'chat_record_id': tool_record_id, + 'tool_id': tool_id, + 'stream': True, + 'workspace_id': workspace_id, + **kwargs}, + + ToolWorkflowCallPostHandler(took_execute, tool_id), + is_the_task_interrupted=lambda: False, + child_node=None, + start_node_id=None, + start_node_data=None, + chat_record=None + ) + res = work_flow_manage.run() + for r in res: + pass + return work_flow_manage.out_context + + return inner + + +def get_tools(tool_workflow_ids, workspace_id): + tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id) + latest_subquery = ToolWorkflowVersion.objects.filter( + tool_id=OuterRef('tool_id') + ).order_by('-create_time') + + qs = ToolWorkflowVersion.objects.filter( + tool_id__in=[t.id for t in tools], + id=Subquery(latest_subquery.values('id')[:1]) + ) + qd = {q.tool_id: q for q in qs} + results = [] + for tool in tools: + qv = qd.get(tool.id) + func = get_workflow_func(tool, qv, workspace_id) + args = get_workflow_args(tool, qv) + tool = StructuredTool.from_function( + func=func, + name=tool.name, + description=tool.desc, + args_schema=args, + ) + results.append(tool) + + return results def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, @@ -178,7 +278,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record model_id = reference_data.get('model_id', model_id) model_params_setting = reference_data.get('model_params_setting') - if model_params_setting is None and model_id: + if model_params_setting is None and model_id: model_params_setting = get_default_model_params_setting(model_id) if model_setting is None: @@ -216,7 +316,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record mcp_result = self._handle_mcp_request( mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids, application_ids, skill_tool_ids, mcp_output_enable, - chat_model, message_list, history_message, question, chat_id + chat_model, message_list, history_message, question, chat_id, workspace_id ) if mcp_result: return mcp_result @@ -236,7 +336,8 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids, application_ids, skill_tool_ids, - mcp_output_enable, chat_model, message_list, history_message, question, chat_id): + mcp_output_enable, chat_model, message_list, history_message, question, chat_id, + workspace_id): mcp_servers_config = {} @@ -259,11 +360,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])} mcp_servers_config = self.handle_variables(mcp_servers_config) tool_init_params = {} + tools = get_tools(tool_ids, workspace_id) if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP self.context['tool_ids'] = tool_ids for tool_id in tool_ids: - tool = QuerySet(Tool).filter(id=tool_id).first() - if not tool.is_active: + tool = QuerySet(Tool).filter(id=tool_id, tool_type=ToolType.CUSTOM).first() + if tool is None or not tool.is_active: continue executor = ToolExecutor() if tool.init_params is not None: @@ -323,7 +425,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids }) mcp_servers_config['skills'] = skill_file_items - if len(mcp_servers_config) > 0: + if len(mcp_servers_config) > 0 or len(tools) > 0: # 安全获取 application application_id = None if (self.workflow_manage and @@ -334,7 +436,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids source_id = application_id or knowledge_id source_type = 'APPLICATION' if application_id else 'KNOWLEDGE' r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable, - tool_init_params, source_id, source_type, chat_id) + tool_init_params, source_id, source_type, chat_id, tools) return NodeResult( {'result': r, 'chat_model': chat_model, 'message_list': message_list, 'history_message': [{'content': message.content, 'role': message.type} for message in diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index a986e97c760..77117769769 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -56,6 +56,7 @@ def _merge_lists_normalize_empty_tool_chunk_ids(left, *others): """Wrapper around merge_lists that normalises empty-string IDs to None in tool_call_chunk items (those with an 'index' key) so that qwen streaming chunks with id='' are merged correctly by index.""" + def _norm(lst): if lst is None: return lst @@ -158,9 +159,9 @@ def get_reasoning_content(self, chunk): self.reasoning_content_end_tag) if reasoning_content_end_tag_index > -1: reasoning_content_chunk = self.reasoning_content_chunk[ - 0:reasoning_content_end_tag_index] + 0:reasoning_content_end_tag_index] content_chunk = self.reasoning_content_chunk[ - reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:] + reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:] self.reasoning_content += reasoning_content_chunk self.content += content_chunk self.reasoning_content_chunk = "" @@ -168,7 +169,7 @@ def get_reasoning_content(self, chunk): return {'content': content_chunk, 'reasoning_content': reasoning_content_chunk} else: reasoning_content_chunk = self.reasoning_content_chunk[ - 0:reasoning_content_end_tag_prefix_index + 1] + 0:reasoning_content_end_tag_prefix_index + 1] self.reasoning_content_chunk = self.reasoning_content_chunk.replace( reasoning_content_chunk, '') self.reasoning_content += reasoning_content_chunk @@ -401,11 +402,14 @@ async def _initialize_skills(mcp_servers, temp_dir): async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={}, - source_id=None, source_type=None, temp_dir=None, chat_id=None): + source_id=None, source_type=None, temp_dir=None, chat_id=None, extra_tools=None): try: checkpointer = MemorySaver() client = await _initialize_skills(mcp_servers, temp_dir) tools = await client.get_tools() + if extra_tools: + for tool in extra_tools: + tools.append(tool) agent = create_deep_agent( model=chat_model, backend=SandboxShellBackend(root_dir=temp_dir, virtual_mode=True), @@ -517,7 +521,7 @@ def _upsert_fragment(key, raw_id, func_name, part_args): # qwen-plus often emits {} here as a placeholder while # the real args are split in tool_call_chunks/invalid_tool_calls. if has_tool_call_chunks and ( - part_args == '' or part_args == {} or part_args == [] + part_args == '' or part_args == {} or part_args == [] ): part_args = '' key = _get_fragment_key(tool_call.get('index'), raw_id) @@ -563,9 +567,9 @@ def _upsert_fragment(key, raw_id, func_name, part_args): # 3. 检测工具调用结束,更新 tool_calls_info # ---------------------------------------------------------------- is_finish_chunk = ( - chunk[0].response_metadata.get( - 'finish_reason') == 'tool_calls' - or chunk[0].chunk_position == 'last' + chunk[0].response_metadata.get( + 'finish_reason') == 'tool_calls' + or chunk[0].chunk_position == 'last' ) if is_finish_chunk: @@ -734,7 +738,7 @@ async def save_tool_record(tool_id, tool_info, tool_result, source_id, source_ty def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={}, - source_id=None, source_type=None, chat_id=None): + source_id=None, source_type=None, chat_id=None, extra_tools=None): """使用全局事件循环,不创建新实例""" result_queue = queue.Queue() loop = get_global_loop() # 使用共享循环 @@ -751,7 +755,7 @@ def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_ena async def _run(): try: async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable, tool_init_params, - source_id, source_type, temp_dir, chat_id) + source_id, source_type, temp_dir, chat_id, extra_tools) async for chunk in async_gen: result_queue.put(('data', chunk)) except Exception as e: diff --git a/ui/src/locales/lang/en-US/dynamics-form.ts b/ui/src/locales/lang/en-US/dynamics-form.ts index 006060eb588..5f6ad5a4772 100644 --- a/ui/src/locales/lang/en-US/dynamics-form.ts +++ b/ui/src/locales/lang/en-US/dynamics-form.ts @@ -51,6 +51,10 @@ export default { placeholder: 'Please select a type', requiredMessage: 'Type is a required property', }, + desc: { + label: 'description', + placeholder: 'Please enter a description', + }, }, DatePicker: { placeholder: 'Select Date', diff --git a/ui/src/locales/lang/zh-CN/dynamics-form.ts b/ui/src/locales/lang/zh-CN/dynamics-form.ts index 64e8ab35226..304a43b560d 100644 --- a/ui/src/locales/lang/zh-CN/dynamics-form.ts +++ b/ui/src/locales/lang/zh-CN/dynamics-form.ts @@ -51,6 +51,10 @@ export default { placeholder: '请选择组件类型', requiredMessage: '组建类型 为必填属性', }, + desc: { + label: '描述', + placeholder: '请输入描述', + }, }, DatePicker: { placeholder: '选择日期', diff --git a/ui/src/locales/lang/zh-Hant/dynamics-form.ts b/ui/src/locales/lang/zh-Hant/dynamics-form.ts index 3bdd59aec14..3bbaaffbe50 100644 --- a/ui/src/locales/lang/zh-Hant/dynamics-form.ts +++ b/ui/src/locales/lang/zh-Hant/dynamics-form.ts @@ -51,6 +51,10 @@ export default { placeholder: '請選擇組件類型', requiredMessage: '組件類型 為必填屬性', }, + desc: { + label: '描述', + placeholder: '請輸入描述', + }, }, DatePicker: { placeholder: '選擇日期', diff --git a/ui/src/views/application/component/ToolDialog.vue b/ui/src/views/application/component/ToolDialog.vue index aecd4032532..40933193424 100644 --- a/ui/src/views/application/component/ToolDialog.vue +++ b/ui/src/views/application/component/ToolDialog.vue @@ -279,6 +279,12 @@ function getFolder() { function getList() { const folder_id = currentFolder.value?.id || user.getWorkspaceId() + const query: any = {} + if (props.tool_type.includes(',')) { + query['tool_type_list'] = props.tool_type.split(',') + } else { + query['tool_type'] = props.tool_type + } loadSharedApi({ type: 'tool', isShared: folder_id === 'share', @@ -286,7 +292,7 @@ function getList() { }) .getToolList({ folder_id: folder_id, - tool_type: props.tool_type, + ...query, }) .then((res: any) => { toolList.value = res.data?.tools || res.data || [] diff --git a/ui/src/workflow/nodes/ai-chat-node/index.vue b/ui/src/workflow/nodes/ai-chat-node/index.vue index fa37aee59b8..98ce17ab00c 100644 --- a/ui/src/workflow/nodes/ai-chat-node/index.vue +++ b/ui/src/workflow/nodes/ai-chat-node/index.vue @@ -488,7 +488,7 @@ @refresh="submitReasoningDialog" /> - + @@ -724,12 +724,12 @@ function getToolSelectOptions() { apiType.value === 'systemManage' ? { scope: 'WORKSPACE', - tool_type: 'CUSTOM', + tool_type_list: ['CUSTOM', 'WORKFLOW'], workspace_id: resource.value?.workspace_id, } : { scope: 'WORKSPACE', - tool_type: 'CUSTOM', + tool_type_list: ['CUSTOM', 'WORKFLOW'], } loadSharedApi({ type: 'tool', systemType: apiType.value }) diff --git a/ui/src/workflow/nodes/tool-base-node/component/input/InputFieldFormDialog.vue b/ui/src/workflow/nodes/tool-base-node/component/input/InputFieldFormDialog.vue index ed59fd24f75..c80a85952f3 100644 --- a/ui/src/workflow/nodes/tool-base-node/component/input/InputFieldFormDialog.vue +++ b/ui/src/workflow/nodes/tool-base-node/component/input/InputFieldFormDialog.vue @@ -32,6 +32,15 @@ @blur="form.label = form.label?.trim()" /> + + + @@ -66,6 +75,7 @@ const form = ref({ field: '', type: typeOptions[0], label: '', + desc: '', is_required: true, }) diff --git a/ui/src/workflow/nodes/tool-base-node/component/input/InputFieldTable.vue b/ui/src/workflow/nodes/tool-base-node/component/input/InputFieldTable.vue index ee57d5fcd14..5ccad89df9f 100644 --- a/ui/src/workflow/nodes/tool-base-node/component/input/InputFieldTable.vue +++ b/ui/src/workflow/nodes/tool-base-node/component/input/InputFieldTable.vue @@ -79,6 +79,7 @@ function openChangeTitleDialog() { function deleteField(index: any) { inputFieldList.value.splice(index, 1) + set(props.nodeModel.properties, 'user_input_field_list', cloneDeep(inputFieldList.value)) props.nodeModel.graphModel.eventCenter.emit('refreshFieldList') onDragHandle() }