diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 8a586eaf44f..6eec69e7bd6 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -22,7 +22,7 @@ from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.pipeline_manage import PipelineManage from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler -from application.flow.tools import Reasoning, mcp_response_generator +from application.flow.tools import Reasoning, mcp_response_generator, get_tools from application.models import ApplicationChatUserStats, ChatUserType, Application, ApplicationApiKey, \ ApplicationAccessToken from common.exception.app_exception import AppApiException @@ -31,7 +31,7 @@ from common.utils.shared_resource_auth import filter_authorized_ids from common.utils.tool_code import ToolExecutor from models_provider.tools import get_model_instance_by_model_workspace_id -from tools.models import Tool +from tools.models import Tool, ToolType def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None): @@ -232,7 +232,7 @@ def reset_message_list(message_list: List[BaseMessage], answer_text): def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, application_ids, skill_tool_ids, mcp_output_enable, chat_model, message_list, agent_id, - chat_id): + chat_id, workspace_id): mcp_servers_config = {} @@ -252,10 +252,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])} tool_init_params = {} + tools = get_tools("APPLICATION", agent_id, tool_ids, + workspace_id) if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP self.context['tool_ids'] = tool_ids for tool_id in tool_ids: - tool = QuerySet(Tool).filter(id=tool_id).first() + tool = QuerySet(Tool).filter(id=tool_id, tool_type=ToolType.CUSTOM).first() if tool is None or tool.is_active is False: continue executor = ToolExecutor() @@ -316,12 +318,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, }) mcp_servers_config['skills'] = skill_file_items - if len(mcp_servers_config) > 0: + if len(mcp_servers_config) > 0 or len(tools) > 0: source_id = agent_id source_type = 'APPLICATION' return mcp_response_generator( chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable, - tool_init_params, source_id, source_type, chat_id + tool_init_params, source_id, source_type, chat_id, tools ) return None @@ -372,7 +374,7 @@ def get_stream_result(self, message_list: List[BaseMessage], mcp_result = self._handle_mcp_request( mcp_source, mcp_servers, mcp_tool_ids, tool_ids, application_ids, skill_tool_ids, mcp_output_enable, chat_model, - message_list, agent_id, chat_id + message_list, agent_id, chat_id, workspace_id ) if mcp_result: return mcp_result, True @@ -461,7 +463,7 @@ def get_block_result(self, message_list: List[BaseMessage], mcp_result = self._handle_mcp_request( mcp_source, mcp_servers, mcp_tool_ids, tool_ids, application_ids, skill_tool_ids, mcp_output_enable, - chat_model, message_list, application_id, chat_id + chat_model, message_list, application_id, chat_id, workspace_id ) if mcp_result: return mcp_result, True @@ -496,7 +498,7 @@ def execute_block(self, message_list: List[BaseMessage], chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting, problem_text, mcp_tool_ids, mcp_servers, mcp_source, - tool_ids, application_ids, skill_tool_ids,workspace_id, + tool_ids, application_ids, skill_tool_ids, workspace_id, mcp_output_enable, manage.context.get('application_id'), chat_id) if is_ai_chat: diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 1b3020bca0b..664b696ed63 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -9,123 +9,25 @@ import json import re import time -import uuid from functools import reduce from typing import List, Dict -import uuid_utils.compat as uuid -from django.db.models import QuerySet, OuterRef, Subquery +from django.db.models import QuerySet from django.utils.translation import gettext as _ from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage -from langchain_core.tools import StructuredTool -from pydantic import Field, create_model -from application.flow.common import Workflow, WorkflowMode -from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler +from application.flow.common import WorkflowMode +from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode -from application.flow.tools import Reasoning, mcp_response_generator +from application.flow.tools import Reasoning, mcp_response_generator, get_tools from application.models import Application, ApplicationApiKey, ApplicationAccessToken -from application.serializers.common import ToolExecute from common.exception.app_exception import AppApiException from common.utils.rsa_util import rsa_long_decrypt from common.utils.shared_resource_auth import filter_authorized_ids from common.utils.tool_code import ToolExecutor from models_provider.models import Model from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id -from tools.models import Tool, ToolWorkflowVersion, ToolType - - -def build_schema(fields: dict): - return create_model("dynamicSchema", **fields) - - -def get_type(_type: str): - if _type == 'float': - return float - if _type == 'string': - return str - if _type == 'int': - return int - if _type == 'dict': - return dict - if _type == 'array': - return list - if _type == 'boolean': - return bool - return object - - -def get_workflow_args(tool, qv): - for node in qv.work_flow.get('nodes'): - if node.get('type') == 'tool-base-node': - input_field_list = node.get('properties').get('user_input_field_list') - return build_schema( - {field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc'))) - for field in input_field_list}) - - return build_schema({}) - - -def get_workflow_func(node, tool, qv, workspace_id): - tool_id = tool.id - tool_record_id = str(uuid.uuid7()) - took_execute = ToolExecute(tool_id, tool_record_id, - workspace_id, - node.workflow_manage.get_source_type(), - node.workflow_manage.get_source_id(), - False) - - def inner(**kwargs): - from application.flow.tool_workflow_manage import ToolWorkflowManage - work_flow_manage = ToolWorkflowManage( - Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL), - { - 'chat_record_id': tool_record_id, - 'tool_id': tool_id, - 'stream': True, - 'workspace_id': workspace_id, - **kwargs}, - - ToolWorkflowPostHandler(took_execute, tool_id), - is_the_task_interrupted=lambda: False, - child_node=None, - start_node_id=None, - start_node_data=None, - chat_record=None - ) - res = work_flow_manage.run() - for r in res: - pass - return work_flow_manage.out_context - - return inner - - -def get_tools(node, tool_workflow_ids, workspace_id): - tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id) - latest_subquery = ToolWorkflowVersion.objects.filter( - tool_id=OuterRef('tool_id') - ).order_by('-create_time') - - qs = ToolWorkflowVersion.objects.filter( - tool_id__in=[t.id for t in tools], - id=Subquery(latest_subquery.values('id')[:1]) - ) - qd = {q.tool_id: q for q in qs} - results = [] - for tool in tools: - qv = qd.get(tool.id) - func = get_workflow_func(node, tool, qv, workspace_id) - args = get_workflow_args(tool, qv) - tool = StructuredTool.from_function( - func=func, - name=tool.name, - description=tool.desc, - args_schema=args, - ) - results.append(tool) - - return results +from tools.models import Tool, ToolType def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, @@ -362,7 +264,8 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])} mcp_servers_config = self.handle_variables(mcp_servers_config) tool_init_params = {} - tools = get_tools(self, tool_ids, workspace_id) + tools = get_tools(self.workflow_manage.get_source_type(), self.workflow_manage.get_source_id(), tool_ids, + workspace_id) if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP self.context['tool_ids'] = tool_ids for tool_id in tool_ids: diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index 86e6d6cdc63..2bd224da347 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -6,13 +6,17 @@ @date:2024/6/6 15:15 @desc: """ -from tools.models import ToolRecord, Tool, ToolScope +from langchain_core.tools import StructuredTool + +from application.flow.common import Workflow, WorkflowMode +from application.serializers.common import ToolExecute +from tools.models import ToolRecord, Tool, ToolScope, ToolWorkflowVersion, ToolType from maxkb.const import CONFIG from knowledge.models.knowledge_action import State from knowledge.models import File from common.utils.logger import maxkb_logger from common.result import result -from application.flow.i_step_node import WorkFlowPostHandler +from application.flow.i_step_node import WorkFlowPostHandler, ToolWorkflowPostHandler from application.flow.backend.sandbox_shell import SandboxShellBackend import asyncio import io @@ -25,11 +29,11 @@ import zipfile from functools import reduce from typing import Iterator - +from pydantic import Field, create_model import uuid_utils.compat as uuid from asgiref.sync import sync_to_async from deepagents import create_deep_agent -from django.db.models import QuerySet +from django.db.models import QuerySet, OuterRef, Subquery from django.http import StreamingHttpResponse from langchain_core.messages import BaseMessageChunk, BaseMessage, ToolMessage, AIMessageChunk, SystemMessage from langchain_mcp_adapters.client import MultiServerMCPClient @@ -990,3 +994,97 @@ def get_child_tool_id_list(work_flow, response): for tool in tool_list: response.append(str(tool.id)) return response + + +def build_schema(fields: dict): + return create_model("dynamicSchema", **fields) + + +def get_type(_type: str): + if _type == 'float': + return float + if _type == 'string': + return str + if _type == 'int': + return int + if _type == 'dict': + return dict + if _type == 'array': + return list + if _type == 'boolean': + return bool + return object + + +def get_workflow_args(tool, qv): + for node in qv.work_flow.get('nodes'): + if node.get('type') == 'tool-base-node': + input_field_list = node.get('properties').get('user_input_field_list') + return build_schema( + {field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc'))) + for field in input_field_list}) + + return build_schema({}) + + +def get_workflow_func(source_type, source_id, tool, qv, workspace_id): + tool_id = tool.id + tool_record_id = str(uuid.uuid7()) + took_execute = ToolExecute(tool_id, tool_record_id, + workspace_id, + source_type, + source_id, + False) + + def inner(**kwargs): + from application.flow.tool_workflow_manage import ToolWorkflowManage + work_flow_manage = ToolWorkflowManage( + Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL), + { + 'chat_record_id': tool_record_id, + 'tool_id': tool_id, + 'stream': True, + 'workspace_id': workspace_id, + **kwargs}, + + ToolWorkflowPostHandler(took_execute, tool_id), + is_the_task_interrupted=lambda: False, + child_node=None, + start_node_id=None, + start_node_data=None, + chat_record=None + ) + res = work_flow_manage.run() + for r in res: + pass + return work_flow_manage.out_context + + return inner + + +def get_tools(source_type, source_id, tool_workflow_ids, workspace_id): + tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id) + latest_subquery = ToolWorkflowVersion.objects.filter( + tool_id=OuterRef('tool_id') + ).order_by('-create_time') + + qs = ToolWorkflowVersion.objects.filter( + tool_id__in=[t.id for t in tools], + id=Subquery(latest_subquery.values('id')[:1]) + ) + qd = {q.tool_id: q for q in qs} + results = [] + for tool in tools: + qv = qd.get(tool.id) + func = get_workflow_func(source_type, source_id, tool, qv, + workspace_id) + args = get_workflow_args(tool, qv) + tool = StructuredTool.from_function( + func=func, + name=tool.name, + description=tool.desc, + args_schema=args, + ) + results.append(tool) + + return results