diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py index cb455e5..04e299d 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py @@ -9,12 +9,14 @@ ) from ldai_langchain.langchain_model_runner import LangChainModelRunner from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory +from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner __version__ = "0.1.0" __all__ = [ '__version__', 'LangChainRunnerFactory', + 'LangGraphAgentGraphRunner', 'LangChainModelRunner', 'convert_messages_to_langchain', 'create_langchain_model', diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py index 402e295..43febb5 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py @@ -1,5 +1,7 @@ +from typing import Any + from ldai.models import AIConfigKind -from ldai.providers import AIProvider +from ldai.providers import AIProvider, ToolRegistry from ldai_langchain.langchain_helper import create_langchain_model from ldai_langchain.langchain_model_runner import LangChainModelRunner @@ -8,6 +10,19 @@ class LangChainRunnerFactory(AIProvider): """LangChain ``AIProvider`` implementation for the LaunchDarkly AI SDK.""" + def create_agent_graph(self, graph_def: Any, tools: ToolRegistry) -> Any: + """ + Create a configured LangGraphAgentGraphRunner for the given graph definition. + + :param graph_def: The AgentGraphDefinition to execute + :param tools: Registry mapping tool names to callables (langchain-compatible) + :return: LangGraphAgentGraphRunner ready to execute the graph + """ + from ldai_langchain.langgraph_agent_graph_runner import ( + LangGraphAgentGraphRunner, + ) + return LangGraphAgentGraphRunner(graph_def, tools) + def create_model(self, config: AIConfigKind) -> LangChainModelRunner: """ Create a configured LangChainModelRunner for the given AI config. diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py new file mode 100644 index 0000000..c0c0b5c --- /dev/null +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py @@ -0,0 +1,163 @@ +"""LangGraph agent graph runner for LaunchDarkly AI SDK.""" + +import operator +import time +from typing import Annotated, Any, List + +from ldai import log +from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode +from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry +from ldai.providers.types import LDAIMetrics + +from ldai_langchain.langchain_helper import ( + create_langchain_model, + get_ai_metrics_from_response, + get_ai_usage_from_response, + get_tool_calls_from_response, + sum_token_usage_from_messages, +) + + +class LangGraphAgentGraphRunner(AgentGraphRunner): + """ + AgentGraphRunner implementation for LangGraph. + + Compiles and runs the agent graph with LangGraph and automatically records + graph- and node-level AI metric data to the LaunchDarkly trackers on the + graph definition and each node. 
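+
+    Example (illustrative sketch; ``graph_def`` and the ``get_weather``
+    callable are hypothetical placeholders, not part of this SDK)::
+
+        runner = LangGraphAgentGraphRunner(graph_def, {'get_weather': get_weather})
+        result = await runner.run('What is the weather in Seattle?')
+        print(result.output, result.metrics.success)
+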
+ + Requires ``langgraph`` to be installed. + """ + + def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry): + """ + Initialize the runner. + + :param graph: The AgentGraphDefinition to execute + :param tools: Registry mapping tool names to callables (langchain-compatible) + """ + self._graph = graph + self._tools = tools + + async def run(self, input: Any) -> AgentGraphResult: + """ + Run the agent graph with the given input. + + Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles + it, and invokes it. Tracks latency and invocation success/failure. + + :param input: The string prompt to send to the agent graph + :return: AgentGraphResult with the final output and metrics + """ + tracker = self._graph.get_tracker() + start_ns = time.perf_counter_ns() + try: + from langchain_core.messages import AnyMessage, HumanMessage + from langgraph.graph import END, START, StateGraph + from typing_extensions import TypedDict + + class WorkflowState(TypedDict): + messages: Annotated[List[Any], operator.add] + + agent_builder: StateGraph = StateGraph(WorkflowState) + root_node = self._graph.root() + root_key = root_node.get_key() if root_node else None + tools_ref = self._tools + exec_path: List[str] = [] + + def handle_traversal(node: AgentGraphNode, ctx: dict) -> None: + node_config = node.get_config() + node_key = node.get_key() + node_tracker = node_config.tracker + + model = None + if node_config.model: + lc_model = create_langchain_model(node_config) + tool_defs = node_config.model.get_parameter('tools') or [] + tool_fns = [ + tools_ref[t.get('name', '')] + for t in tool_defs + if t.get('name', '') in tools_ref + ] + model = lc_model.bind_tools(tool_fns) if tool_fns else lc_model + + def invoke(state: WorkflowState) -> WorkflowState: + exec_path.append(node_key) + if not model: + return {'messages': []} + gk = tracker.graph_key if tracker is not None else None + if node_tracker: + response = node_tracker.track_metrics_of( + lambda: model.invoke(state['messages']), + get_ai_metrics_from_response, + graph_key=gk, + ) + node_tracker.track_tool_calls( + get_tool_calls_from_response(response), + graph_key=tracker.graph_key if tracker is not None else None, + ) + else: + response = model.invoke(state['messages']) + + return {'messages': [response]} + + invoke.__name__ = node_key + + agent_builder.add_node(node_key, invoke) + + if node_key == root_key: + agent_builder.add_edge(START, node_key) + + if node.is_terminal(): + agent_builder.add_edge(node_key, END) + + for edge in node.get_edges(): + agent_builder.add_edge(node_key, edge.target_config) + + return None + + self._graph.traverse(fn=handle_traversal) + compiled = agent_builder.compile() + + result = await compiled.ainvoke( # type: ignore[call-overload] + {'messages': [HumanMessage(content=str(input))]} + ) + duration = (time.perf_counter_ns() - start_ns) // 1_000_000 + + output = '' + messages = result.get('messages', []) + if messages: + last = messages[-1] + if hasattr(last, 'content'): + output = str(last.content) + + if tracker: + tracker.track_path(exec_path) + tracker.track_latency(duration) + tracker.track_invocation_success() + tracker.track_total_tokens( + sum_token_usage_from_messages(messages) + ) + + return AgentGraphResult( + output=output, + raw=result, + metrics=LDAIMetrics(success=True), + ) + except Exception as exc: + if isinstance(exc, ImportError): + log.warning( + "langgraph is required for LangGraphAgentGraphRunner. 
" + "Install it with: pip install langgraph" + ) + else: + log.warning(f'LangGraphAgentGraphRunner run failed: {exc}') + duration = (time.perf_counter_ns() - start_ns) // 1_000_000 + if tracker: + tracker.track_latency(duration) + tracker.track_invocation_failure() + return AgentGraphResult( + output='', + raw=None, + metrics=LDAIMetrics(success=False), + ) diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py new file mode 100644 index 0000000..de5e8d9 --- /dev/null +++ b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py @@ -0,0 +1,151 @@ +"""Tests for LangGraphAgentGraphRunner and LangChainRunnerFactory.create_agent_graph().""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from ldai.agent_graph import AgentGraphDefinition +from ldai.models import AIAgentGraphConfig, AIAgentConfig, ModelConfig, ProviderConfig +from ldai.providers import AgentGraphResult, ToolRegistry +from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner +from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory + + +def _make_graph(enabled: bool = True) -> AgentGraphDefinition: + root_config = AIAgentConfig( + key='root-agent', + enabled=enabled, + model=ModelConfig(name='gpt-4'), + provider=ProviderConfig(name='openai'), + instructions='You are a helpful assistant.', + tracker=MagicMock(), + ) + graph_config = AIAgentGraphConfig( + key='test-graph', + root_config_key='root-agent', + edges=[], + enabled=enabled, + ) + nodes = AgentGraphDefinition.build_nodes(graph_config, {'root-agent': root_config}) + return AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=MagicMock(), + enabled=enabled, + tracker=MagicMock(), + ) + + +# --- Factory --- + +def test_langchain_runner_factory_create_agent_graph_returns_runner(): + graph = _make_graph() + tools: ToolRegistry = {'fetch_weather': lambda loc: f'weather in {loc}'} + factory = LangChainRunnerFactory() + runner = factory.create_agent_graph(graph, tools) + assert isinstance(runner, LangGraphAgentGraphRunner) + + +def test_langchain_runner_factory_create_agent_graph_wires_graph_and_tools(): + graph = _make_graph() + tools: ToolRegistry = {} + factory = LangChainRunnerFactory() + runner = factory.create_agent_graph(graph, tools) + assert runner._graph is graph + assert runner._tools is tools + + +# --- LangGraphAgentGraphRunner --- + +def test_langgraph_runner_stores_graph_and_tools(): + graph = _make_graph() + tools: ToolRegistry = {} + runner = LangGraphAgentGraphRunner(graph, tools) + assert runner._graph is graph + assert runner._tools is tools + + +@pytest.mark.asyncio +async def test_langgraph_runner_run_raises_when_langgraph_not_installed(): + graph = _make_graph() + runner = LangGraphAgentGraphRunner(graph, {}) + + with patch.dict('sys.modules', {'langgraph': None, 'langgraph.graph': None}): + result = await runner.run("test") + assert isinstance(result, AgentGraphResult) + assert result.metrics.success is False + + +@pytest.mark.asyncio +async def test_langgraph_runner_run_tracks_failure_on_exception(): + graph = _make_graph() + tracker = graph.get_tracker() + runner = LangGraphAgentGraphRunner(graph, {}) + + with patch.dict('sys.modules', {'langgraph': None, 'langgraph.graph': None}): + result = await runner.run("fail") + + assert result.metrics.success is False + 
tracker.track_invocation_failure.assert_called_once() + tracker.track_latency.assert_called_once() + + +@pytest.mark.asyncio +async def test_langgraph_runner_run_success(): + graph = _make_graph() + tracker = graph.get_tracker() + + mock_message = MagicMock() + mock_message.content = "langgraph answer" + mock_message.usage_metadata = None + mock_message.response_metadata = None + + mock_compiled = MagicMock() + mock_compiled.ainvoke = AsyncMock(return_value={'messages': [mock_message]}) + + mock_state_graph_instance = MagicMock() + mock_state_graph_instance.add_node = MagicMock() + mock_state_graph_instance.add_edge = MagicMock() + mock_state_graph_instance.compile = MagicMock(return_value=mock_compiled) + + mock_langgraph_graph = MagicMock() + mock_langgraph_graph.END = 'END' + mock_langgraph_graph.START = 'START' + mock_langgraph_graph.StateGraph = MagicMock(return_value=mock_state_graph_instance) + + mock_human_message = MagicMock() + mock_lc_core_messages = MagicMock() + mock_lc_core_messages.HumanMessage = MagicMock(return_value=mock_human_message) + mock_lc_core_messages.AnyMessage = MagicMock() + + mock_model_response = MagicMock() + mock_model_response.content = 'langgraph answer' + mock_model_response.usage_metadata = None + mock_model_response.response_metadata = None + mock_model_response.tool_calls = None + + mock_llm = MagicMock() + mock_llm.invoke = MagicMock(return_value=mock_model_response) + + mock_init_model = MagicMock() + mock_init_model.return_value = mock_llm + mock_langchain_chat = MagicMock() + mock_langchain_chat.init_chat_model = mock_init_model + + with patch.dict('sys.modules', { + 'langgraph': MagicMock(), + 'langgraph.graph': mock_langgraph_graph, + 'langchain_core': MagicMock(), + 'langchain_core.messages': mock_lc_core_messages, + 'langchain': MagicMock(), + 'langchain.chat_models': mock_langchain_chat, + 'typing_extensions': __import__('typing_extensions'), + }): + runner = LangGraphAgentGraphRunner(graph, {}) + result = await runner.run("find restaurants") + + assert isinstance(result, AgentGraphResult) + assert result.output == "langgraph answer" + assert result.metrics.success is True + tracker.track_path.assert_called_once_with([]) + tracker.track_invocation_success.assert_called_once() + tracker.track_latency.assert_called_once() diff --git a/packages/ai-providers/server-ai-openai/src/ldai_openai/__init__.py b/packages/ai-providers/server-ai-openai/src/ldai_openai/__init__.py index 8a8199b..422c059 100644 --- a/packages/ai-providers/server-ai-openai/src/ldai_openai/__init__.py +++ b/packages/ai-providers/server-ai-openai/src/ldai_openai/__init__.py @@ -1,3 +1,4 @@ +from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner from ldai_openai.openai_helper import ( convert_messages_to_openai, get_ai_metrics_from_response, @@ -8,6 +9,7 @@ __all__ = [ 'OpenAIRunnerFactory', + 'OpenAIAgentGraphRunner', 'OpenAIModelRunner', 'convert_messages_to_openai', 'get_ai_metrics_from_response', diff --git a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py new file mode 100644 index 0000000..df2acf6 --- /dev/null +++ b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py @@ -0,0 +1,341 @@ +"""OpenAI agent graph runner for LaunchDarkly AI SDK.""" + +import time +from typing import Any, List, Optional + +from ldai import log +from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode +from 
ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry +from ldai.providers.types import LDAIMetrics +from ldai.tracker import TokenUsage + + +def _to_openai_name(name: str) -> str: + """Convert a hyphenated tool/node name to an underscore-separated OpenAI function name.""" + return name.replace('-', '_') + + +def _build_native_tool_map() -> dict: + try: + from agents import ( + CodeInterpreterTool, + FileSearchTool, + ImageGenerationTool, + WebSearchTool, + ) + return { + 'web_search_tool': lambda _: WebSearchTool(), + 'file_search_tool': lambda _: FileSearchTool(), + 'code_interpreter': lambda _: CodeInterpreterTool(), + 'image_generation': lambda _: ImageGenerationTool(), + } + except ImportError: + return {} + + +_NATIVE_OPENAI_TOOLS = _build_native_tool_map() + + +class _RunState: + """Mutable state shared across handoff and tool callbacks during a single run.""" + + def __init__(self, last_handoff_ns: int, last_node_key: str) -> None: + self.last_handoff_ns = last_handoff_ns + self.last_node_key = last_node_key + + +class OpenAIAgentGraphRunner(AgentGraphRunner): + """ + AgentGraphRunner implementation for the OpenAI Agents SDK. + + Runs the agent graph with the OpenAI Agents SDK and automatically records + graph- and node-level AI metric data to the LaunchDarkly trackers on the + graph definition and each node. + + Requires ``openai-agents`` to be installed. + """ + + def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry): + """ + Initialize the runner. + + :param graph: The AgentGraphDefinition to execute + :param tools: Registry mapping OpenAI-formatted tool names to callables + """ + self._graph = graph + self._tools = tools + + async def run(self, input: Any) -> AgentGraphResult: + """ + Run the agent graph with the given input. + + Builds the agent tree via reverse_traverse, then invokes the root + agent with Runner.run(). Tracks path, latency, and invocation + success/failure. + + :param input: The string prompt to send to the agent graph + :return: AgentGraphResult with the final output and metrics + """ + tracker = self._graph.get_tracker() + path: List[str] = [] + root_node = self._graph.root() + root_key = root_node.get_key() if root_node else '' + if root_key: + path.append(root_key) + + start_ns = time.perf_counter_ns() + state = _RunState(last_handoff_ns=start_ns, last_node_key=root_key) + try: + from agents import Runner + root_agent = self._build_agents(path, state) + result = await Runner.run(root_agent, str(input)) + self._flush_final_segment(state, tracker, result) + + duration = (time.perf_counter_ns() - start_ns) // 1_000_000 + + if tracker: + tracker.track_path(path) + tracker.track_latency(duration) + tracker.track_invocation_success() + try: + usage = result.context_wrapper.usage + tracker.track_total_tokens( + TokenUsage( + total=usage.total_tokens, + input=usage.input_tokens, + output=usage.output_tokens, + ) + ) + except Exception: + pass + + return AgentGraphResult( + output=str(result.final_output), + raw=result, + metrics=LDAIMetrics(success=True), + ) + except Exception as exc: + if isinstance(exc, ImportError): + log.warning( + "openai-agents is required for OpenAIAgentGraphRunner. 
" + "Install it with: pip install openai-agents" + ) + else: + log.warning(f'OpenAIAgentGraphRunner run failed: {exc}') + duration = (time.perf_counter_ns() - start_ns) // 1_000_000 + if tracker: + tracker.track_latency(duration) + tracker.track_invocation_failure() + return AgentGraphResult( + output='', + raw=None, + metrics=LDAIMetrics(success=False), + ) + + def _flush_final_segment( + self, + state: _RunState, + tracker: Any, + result: Any, + ) -> None: + """Record duration/tokens for the last active agent (no handoff after it).""" + if not state.last_node_key: + return + node = self._graph.get_node(state.last_node_key) + if node is None: + return + config_tracker = node.get_config().tracker + if config_tracker is None: + return + + now_ns = time.perf_counter_ns() + duration_ms = (now_ns - state.last_handoff_ns) // 1_000_000 + + usage: Optional[TokenUsage] = None + try: + usage_entry = result.context_wrapper.usage.request_usage_entries[-1] + usage = TokenUsage( + total=usage_entry.total_tokens, + input=usage_entry.input_tokens, + output=usage_entry.output_tokens, + ) + except Exception: + pass + + gk = tracker.graph_key if tracker is not None else None + if usage is not None: + config_tracker.track_tokens(usage, graph_key=gk) + config_tracker.track_duration(int(duration_ms), graph_key=gk) + config_tracker.track_success(graph_key=gk) + + def _handle_handoff( + self, + run_ctx: Any, + src: str, + tgt: str, + path: List[str], + tracker: Any, + config_tracker: Any, + state: _RunState, + ) -> None: + path.append(tgt) + state.last_node_key = tgt + if tracker: + tracker.track_handoff_success(src, tgt) + + usage: Optional[TokenUsage] = None + now_ns = time.perf_counter_ns() + duration_ms = (now_ns - state.last_handoff_ns) // 1_000_000 + state.last_handoff_ns = now_ns + try: + usage_entry = run_ctx.usage.request_usage_entries[-1] + usage = TokenUsage( + total=usage_entry.total_tokens, + input=usage_entry.input_tokens, + output=usage_entry.output_tokens, + ) + except Exception: + pass + + gk = tracker.graph_key if tracker is not None else None + if config_tracker is not None: + if usage is not None: + config_tracker.track_tokens(usage, graph_key=gk) + if duration_ms is not None: + config_tracker.track_duration(int(duration_ms), graph_key=gk) + config_tracker.track_success(graph_key=gk) + + def _make_on_handoff( + self, + src: str, + tgt: str, + path: List[str], + tracker: Any, + config_tracker: Any, + state: _RunState, + ): + def on_handoff(run_ctx: Any) -> None: + self._handle_handoff( + run_ctx, src, tgt, path, tracker, config_tracker, state + ) + return on_handoff + + def _build_agents(self, path: List[str], state: _RunState) -> Any: + """ + Build the agent tree from the graph definition via reverse_traverse. + + Agents are constructed from terminal nodes upward so that handoff + targets exist before the agents that hand off to them. + + :param path: Mutable list to accumulate the execution path + :param state: Shared run state for tracking handoff timing and last node + :return: The root Agent instance + """ + try: + from agents import ( + Agent, + FunctionTool, + Handoff, + RunContextWrapper, + Tool, + handoff, + ) + from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX + from agents.tool_context import ToolContext + except ImportError as exc: + raise ImportError( + "openai-agents is required for OpenAIAgentGraphRunner. 
" + "Install it with: pip install openai-agents" + ) from exc + + tracker = self._graph.get_tracker() + + def build_node(node: AgentGraphNode, ctx: dict) -> Any: + node_config = node.get_config() + config_tracker = node_config.tracker + model = node_config.model + + if not model: + raise ValueError(f"Model not set for node '{node_config.key}'") + + tool_defs = model.get_parameter('tools') or [] + + # --- handoffs --- + agent_handoffs: List[Handoff] = [] + for edge in node.get_edges(): + target_key = edge.target_config + agent_handoffs.append( + handoff( + agent=ctx[target_key], + on_handoff=self._make_on_handoff( + node_config.key, + target_key, + path, + tracker, + config_tracker, + state, + ), + ) + ) + + # --- tools --- + agent_tools: List[Tool] = [] + for tool_def in tool_defs: + tool_name_raw = tool_def.get('name', '') + tool_name = _to_openai_name(tool_name_raw) + + # Check native OpenAI tools first, then fall back to ToolRegistry + if tool_name in _NATIVE_OPENAI_TOOLS: + agent_tools.append(_NATIVE_OPENAI_TOOLS[tool_name](tool_def)) + continue + + tool_fn = self._tools.get(tool_name) or self._tools.get(tool_name_raw) + if not tool_fn: + continue + + def _make_tool( + name: str, + raw_name: str, + fn: Any, + description: str, + params_schema: dict, + ) -> FunctionTool: + def wrapped(tool_ctx: ToolContext, tool_args: str) -> Any: + import json + try: + args = json.loads(tool_args) + except Exception: + args = {} + path.append(raw_name) + if config_tracker is not None: + config_tracker.track_tool_call( + name, + graph_key=tracker.graph_key if tracker is not None else None, + ) + return fn(**args) + + return FunctionTool( + name=f'tool_{name}', + description=description, + params_json_schema=params_schema, + on_invoke_tool=wrapped, + ) + + agent_tools.append( + _make_tool( + tool_name, + tool_name_raw, + tool_fn, + tool_def.get('description', ''), + tool_def.get('parameters', {}), + ) + ) + + return Agent( + name=_to_openai_name(node_config.key), + instructions=f'{RECOMMENDED_PROMPT_PREFIX} {node_config.instructions or ""}', + handoffs=list(agent_handoffs), + tools=list(agent_tools), + ) + + return self._graph.reverse_traverse(fn=build_node) diff --git a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_runner_factory.py b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_runner_factory.py index d80fc01..ae4094c 100644 --- a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_runner_factory.py +++ b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_runner_factory.py @@ -1,8 +1,8 @@ import os -from typing import Optional +from typing import Any, Optional from ldai.models import AIConfigKind -from ldai.providers import AIProvider +from ldai.providers import AIProvider, ToolRegistry from openai import AsyncOpenAI from ldai_openai.openai_model_runner import OpenAIModelRunner @@ -36,6 +36,17 @@ def create_model(self, config: AIConfigKind) -> OpenAIModelRunner: parameters = model_dict.get('parameters') or {} return OpenAIModelRunner(self._client, model_name, parameters) + def create_agent_graph(self, graph_def: Any, tools: ToolRegistry) -> Any: + """ + Create a configured OpenAIAgentGraphRunner for the given graph definition. 
+ + :param graph_def: The AgentGraphDefinition to execute + :param tools: Registry mapping tool names to callables + :return: OpenAIAgentGraphRunner ready to execute the graph + """ + from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner + return OpenAIAgentGraphRunner(graph_def, tools) + def get_client(self) -> AsyncOpenAI: """ Return the underlying AsyncOpenAI client. diff --git a/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py b/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py new file mode 100644 index 0000000..56dd0c1 --- /dev/null +++ b/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py @@ -0,0 +1,142 @@ +"""Tests for OpenAIAgentGraphRunner and OpenAIRunnerFactory.create_agent_graph().""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from ldai.agent_graph import AgentGraphDefinition +from ldai.models import AIAgentGraphConfig, AIAgentConfig, Edge, ModelConfig, ProviderConfig +from ldai.providers import AgentGraphResult, ToolRegistry +from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner +from ldai_openai.openai_runner_factory import OpenAIRunnerFactory + + +def _make_graph(enabled: bool = True) -> AgentGraphDefinition: + """Build a minimal single-node AgentGraphDefinition for testing.""" + root_config = AIAgentConfig( + key='root-agent', + enabled=enabled, + model=ModelConfig(name='gpt-4'), + provider=ProviderConfig(name='openai'), + instructions='You are a helpful assistant.', + tracker=MagicMock(), + ) + graph_config = AIAgentGraphConfig( + key='test-graph', + root_config_key='root-agent', + edges=[], + enabled=enabled, + ) + nodes = AgentGraphDefinition.build_nodes(graph_config, {'root-agent': root_config}) + return AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=MagicMock(), + enabled=enabled, + tracker=MagicMock(), + ) + + +# --- Factory --- + +def test_openai_runner_factory_create_agent_graph_returns_runner(): + graph = _make_graph() + tools: ToolRegistry = {'search': lambda q: q} + factory = OpenAIRunnerFactory(client=MagicMock()) + runner = factory.create_agent_graph(graph, tools) + assert isinstance(runner, OpenAIAgentGraphRunner) + + +def test_openai_runner_factory_create_agent_graph_wires_graph_and_tools(): + graph = _make_graph() + tools: ToolRegistry = {'my_tool': lambda: None} + factory = OpenAIRunnerFactory(client=MagicMock()) + runner = factory.create_agent_graph(graph, tools) + assert runner._graph is graph + assert runner._tools is tools + + +# --- OpenAIAgentGraphRunner --- + +def test_openai_agent_graph_runner_stores_graph_and_tools(): + graph = _make_graph() + tools: ToolRegistry = {} + runner = OpenAIAgentGraphRunner(graph, tools) + assert runner._graph is graph + assert runner._tools is tools + + +@pytest.mark.asyncio +async def test_openai_agent_graph_runner_run_raises_when_agents_not_installed(): + graph = _make_graph() + runner = OpenAIAgentGraphRunner(graph, {}) + + with patch.dict('sys.modules', {'agents': None}): + # The import inside run() will fail — runner should return failure result + # rather than propagate the ImportError, since it's caught by the except block + result = await runner.run("test input") + assert isinstance(result, AgentGraphResult) + assert result.metrics.success is False + + +@pytest.mark.asyncio +async def test_openai_agent_graph_runner_run_tracks_invocation_failure_on_exception(): + graph = _make_graph() + tracker = graph.get_tracker() + runner = 
OpenAIAgentGraphRunner(graph, {}) + + with patch.dict('sys.modules', {'agents': None}): + result = await runner.run("fail") + + assert result.metrics.success is False + tracker.track_invocation_failure.assert_called_once() + tracker.track_latency.assert_called_once() + + +@pytest.mark.asyncio +async def test_openai_agent_graph_runner_run_success(): + graph = _make_graph() + tracker = graph.get_tracker() + + mock_result = MagicMock() + mock_result.final_output = "agent answer" + mock_result.context_wrapper.usage.total_tokens = 0 + mock_result.context_wrapper.usage.input_tokens = 0 + mock_result.context_wrapper.usage.output_tokens = 0 + + mock_runner_module = MagicMock() + mock_runner_module.run = AsyncMock(return_value=mock_result) + + mock_agents = MagicMock() + mock_agents.Runner = mock_runner_module + mock_agents.Agent = MagicMock(return_value=MagicMock()) + mock_agents.FunctionTool = MagicMock() + mock_agents.Handoff = MagicMock() + mock_agents.RunContextWrapper = MagicMock() + mock_agents.Tool = MagicMock() + mock_agents.handoff = MagicMock() + + mock_agents_ext = MagicMock() + mock_agents_ext.RECOMMENDED_PROMPT_PREFIX = '[PREFIX]' + + mock_tool_context = MagicMock() + + with patch.dict('sys.modules', { + 'agents': mock_agents, + 'agents.extensions': MagicMock(), + 'agents.extensions.handoff_prompt': mock_agents_ext, + 'agents.tool_context': mock_tool_context, + }): + runner = OpenAIAgentGraphRunner(graph, {}) + result = await runner.run("find restaurants") + + assert isinstance(result, AgentGraphResult) + assert result.output == "agent answer" + assert result.metrics.success is True + tracker.track_invocation_success.assert_called_once() + tracker.track_path.assert_called_once() + tracker.track_latency.assert_called_once() + + root_tracker = graph.get_node('root-agent').get_config().tracker + root_tracker.track_duration.assert_called_once() + root_tracker.track_tokens.assert_called_once() + root_tracker.track_success.assert_called_once() diff --git a/packages/sdk/server-ai/src/ldai/__init__.py b/packages/sdk/server-ai/src/ldai/__init__.py index 944a0cb..25295a7 100644 --- a/packages/sdk/server-ai/src/ldai/__init__.py +++ b/packages/sdk/server-ai/src/ldai/__init__.py @@ -6,6 +6,7 @@ from ldai.chat import Chat # Deprecated — use ManagedModel from ldai.client import LDAIClient from ldai.judge import Judge +from ldai.managed_agent_graph import ManagedAgentGraph from ldai.managed_model import ManagedModel from ldai.models import ( # Deprecated aliases for backward compatibility AIAgentConfig, @@ -56,6 +57,7 @@ 'AIJudgeConfig', 'AIJudgeConfigDefault', 'ManagedModel', + 'ManagedAgentGraph', 'EvalScore', 'AgentGraphDefinition', 'Judge', diff --git a/packages/sdk/server-ai/src/ldai/client.py b/packages/sdk/server-ai/src/ldai/client.py index 358f9eb..6aef7c6 100644 --- a/packages/sdk/server-ai/src/ldai/client.py +++ b/packages/sdk/server-ai/src/ldai/client.py @@ -7,6 +7,7 @@ from ldai import log from ldai.agent_graph import AgentGraphDefinition from ldai.judge import Judge +from ldai.managed_agent_graph import ManagedAgentGraph from ldai.managed_model import ManagedModel from ldai.models import ( AIAgentConfig, @@ -24,6 +25,7 @@ ModelConfig, ProviderConfig, ) +from ldai.providers import ToolRegistry from ldai.providers.runner_factory import RunnerFactory from ldai.sdk_info import AI_SDK_LANGUAGE, AI_SDK_NAME, AI_SDK_VERSION from ldai.tracker import AIGraphTracker, LDAIConfigTracker @@ -35,6 +37,7 @@ _TRACK_USAGE_CREATE_JUDGE = '$ld:ai:usage:create-judge' _TRACK_USAGE_AGENT_CONFIG = 
'$ld:ai:usage:agent-config' _TRACK_USAGE_AGENT_CONFIGS = '$ld:ai:usage:agent-configs' +_TRACK_USAGE_CREATE_AGENT_GRAPH = '$ld:ai:usage:create-agent-graph' _INIT_TRACK_CONTEXT = Context.builder('ld-internal-tracking').kind('ld_ai').anonymous(True).build() @@ -609,6 +612,55 @@ def agent_graph( tracker=tracker, ) + async def create_agent_graph( + self, + key: str, + context: Context, + tools: Optional[ToolRegistry] = None, + default_ai_provider: Optional[str] = None, + ) -> Optional[ManagedAgentGraph]: + """ + Creates and returns a new ManagedAgentGraph for AI agent graph execution. + + Resolves the graph configuration via ``agent_graph()``, creates a + provider-specific runner, and wraps it in a ``ManagedAgentGraph``. + + :param key: The key identifying the agent graph configuration + :param context: Standard Context used when evaluating flags + :param tools: Registry mapping tool names to callables + :param default_ai_provider: Optional provider override ('openai', 'langchain', …) + :return: ManagedAgentGraph instance, or None if the graph is disabled or unsupported + + Example:: + + graph = await client.create_agent_graph( + "travel-assistant-graph", + context, + tools={ + "web_search_tool": my_search_fn, + "get_weather": my_weather_fn, + } + ) + + if graph: + result = await graph.run("Find me restaurants in Seattle") + print(result.output) + """ + self._client.track(_TRACK_USAGE_CREATE_AGENT_GRAPH, context, key, 1) + log.debug(f"Creating managed agent graph for key: {key}") + + graph = self.agent_graph(key, context) + if not graph.enabled: + return None + + runner = RunnerFactory.create_agent_graph( + graph, tools or {}, default_ai_provider + ) + if not runner: + return None + + return ManagedAgentGraph(runner, graph.get_tracker()) + def agents( self, agent_configs: List[AIAgentConfigRequest], diff --git a/packages/sdk/server-ai/src/ldai/managed_agent_graph.py b/packages/sdk/server-ai/src/ldai/managed_agent_graph.py new file mode 100644 index 0000000..bb04add --- /dev/null +++ b/packages/sdk/server-ai/src/ldai/managed_agent_graph.py @@ -0,0 +1,60 @@ +"""ManagedAgentGraph — LaunchDarkly managed wrapper for agent graph execution.""" + +from typing import Any, Optional + +from ldai.providers import AgentGraphResult, AgentGraphRunner +from ldai.tracker import AIGraphTracker + + +class ManagedAgentGraph: + """ + LaunchDarkly managed wrapper for AI agent graph execution. + + Holds an AgentGraphRunner and an AIGraphTracker. Auto-tracking of path, + tool calls, handoffs, latency, and invocation success/failure is handled + by the runner implementation. + + Obtain an instance via ``LDAIClient.create_agent_graph()``. + """ + + def __init__( + self, + runner: AgentGraphRunner, + tracker: Optional[AIGraphTracker] = None, + ): + """ + Initialize ManagedAgentGraph. + + :param runner: The AgentGraphRunner to delegate execution to + :param tracker: The AIGraphTracker for this graph + """ + self._runner = runner + self._tracker = tracker + + async def run(self, input: Any) -> AgentGraphResult: + """ + Run the agent graph with the given input. + + Delegates to the underlying AgentGraphRunner, which handles + execution and all auto-tracking internally. + + :param input: The input prompt or structured input for the graph + :return: AgentGraphResult containing the output, raw response, and metrics + """ + return await self._runner.run(input) + + def get_agent_graph_runner(self) -> AgentGraphRunner: + """ + Return the underlying AgentGraphRunner for advanced use. 
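+
+        Example (illustrative; ``managed`` is an instance returned by
+        ``LDAIClient.create_agent_graph()``)::
+
+            runner = managed.get_agent_graph_runner()
+            result = await runner.run('Hello')
+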
+ + :return: The AgentGraphRunner instance + """ + return self._runner + + def get_tracker(self) -> Optional[AIGraphTracker]: + """ + Return the AIGraphTracker for this graph. + + :return: The AIGraphTracker instance, or None if not available + """ + return self._tracker diff --git a/packages/sdk/server-ai/tests/test_managed_agent_graph.py b/packages/sdk/server-ai/tests/test_managed_agent_graph.py new file mode 100644 index 0000000..476ac02 --- /dev/null +++ b/packages/sdk/server-ai/tests/test_managed_agent_graph.py @@ -0,0 +1,195 @@ +"""Tests for ManagedAgentGraph and LDAIClient.create_agent_graph().""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from ldclient import Config, Context, LDClient +from ldclient.integrations.test_data import TestData + +from ldai import LDAIClient, ManagedAgentGraph +from ldai.providers.types import LDAIMetrics +from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry +from ldai.tracker import AIGraphTracker + + +# --- Test double --- + +class StubAgentGraphRunner(AgentGraphRunner): + def __init__(self, output: str = "stub output"): + self._output = output + + async def run(self, input) -> AgentGraphResult: + return AgentGraphResult( + output=self._output, + raw={"input": input}, + metrics=LDAIMetrics(success=True), + ) + + +# --- ManagedAgentGraph unit tests --- + +@pytest.mark.asyncio +async def test_managed_agent_graph_run_delegates_to_runner(): + runner = StubAgentGraphRunner("hello world") + managed = ManagedAgentGraph(runner) + result = await managed.run("test input") + assert result.output == "hello world" + assert result.metrics.success is True + + +def test_managed_agent_graph_get_runner(): + runner = StubAgentGraphRunner() + managed = ManagedAgentGraph(runner) + assert managed.get_agent_graph_runner() is runner + + +def test_managed_agent_graph_get_tracker_none_by_default(): + runner = StubAgentGraphRunner() + managed = ManagedAgentGraph(runner) + assert managed.get_tracker() is None + + +def test_managed_agent_graph_get_tracker_returns_tracker(): + runner = StubAgentGraphRunner() + tracker = MagicMock(spec=AIGraphTracker) + managed = ManagedAgentGraph(runner, tracker) + assert managed.get_tracker() is tracker + + +# --- LDAIClient.create_agent_graph() integration tests --- + +@pytest.fixture +def td() -> TestData: + td = TestData.data_source() + + td.update( + td.flag('travel-graph') + .variations({ + 'root': 'triage-agent', + 'edges': { + 'triage-agent': [{'key': 'specialist-agent'}], + }, + '_ldMeta': {'enabled': True, 'variationKey': 'v1', 'version': 1}, + }) + .variation_for_all(0) + ) + + td.update( + td.flag('triage-agent') + .variations({ + 'model': {'name': 'gpt-4'}, + 'provider': {'name': 'openai'}, + 'instructions': 'You are a triage agent.', + '_ldMeta': {'enabled': True, 'variationKey': 'triage-v1', 'version': 1}, + }) + .variation_for_all(0) + ) + + td.update( + td.flag('specialist-agent') + .variations({ + 'model': {'name': 'gpt-4'}, + 'provider': {'name': 'openai'}, + 'instructions': 'You are a specialist.', + '_ldMeta': {'enabled': True, 'variationKey': 'specialist-v1', 'version': 1}, + }) + .variation_for_all(0) + ) + + td.update( + td.flag('disabled-graph') + .variations({ + '_ldMeta': {'enabled': False, 'variationKey': 'disabled-v1', 'version': 1}, + }) + .variation_for_all(0) + ) + + return td + + +@pytest.fixture +def client(td: TestData) -> LDClient: + config = Config('sdk-key', update_processor_class=td, send_events=False) + return LDClient(config=config) + + +@pytest.fixture 
+def ldai_client(client: LDClient) -> LDAIClient: + return LDAIClient(client) + + +@pytest.mark.asyncio +async def test_create_agent_graph_returns_managed_agent_graph(ldai_client: LDAIClient): + context = Context.create('user-key') + stub_runner = StubAgentGraphRunner("result") + + with patch( + 'ldai.providers.runner_factory.RunnerFactory.create_agent_graph', + new=MagicMock(return_value=stub_runner), + ): + managed = await ldai_client.create_agent_graph('travel-graph', context) + + assert managed is not None + assert isinstance(managed, ManagedAgentGraph) + assert managed.get_agent_graph_runner() is stub_runner + + +@pytest.mark.asyncio +async def test_create_agent_graph_returns_none_when_disabled(ldai_client: LDAIClient): + context = Context.create('user-key') + managed = await ldai_client.create_agent_graph('disabled-graph', context) + assert managed is None + + +@pytest.mark.asyncio +async def test_create_agent_graph_returns_none_when_runner_factory_fails(ldai_client: LDAIClient): + context = Context.create('user-key') + + with patch( + 'ldai.providers.runner_factory.RunnerFactory.create_agent_graph', + new=MagicMock(return_value=None), + ): + managed = await ldai_client.create_agent_graph('travel-graph', context) + + assert managed is None + + +@pytest.mark.asyncio +async def test_create_agent_graph_passes_tools_to_factory(ldai_client: LDAIClient): + context = Context.create('user-key') + tools: ToolRegistry = {'search': lambda q: f'results for {q}'} + captured = {} + + def fake_create_agent_graph(graph_def, tools_arg, default_ai_provider=None): + captured['tools'] = tools_arg + return StubAgentGraphRunner() + + with patch( + 'ldai.providers.runner_factory.RunnerFactory.create_agent_graph', + new=fake_create_agent_graph, + ): + await ldai_client.create_agent_graph('travel-graph', context, tools=tools) + + assert captured['tools'] is tools + + +@pytest.mark.asyncio +async def test_create_agent_graph_run_produces_result(ldai_client: LDAIClient): + context = Context.create('user-key') + + with patch( + 'ldai.providers.runner_factory.RunnerFactory.create_agent_graph', + new=MagicMock(return_value=StubAgentGraphRunner("final answer")), + ): + managed = await ldai_client.create_agent_graph('travel-graph', context) + + assert managed is not None + result = await managed.run("find restaurants") + assert result.output == "final answer" + assert result.metrics.success is True + + +# --- Top-level export --- + +def test_managed_agent_graph_exported(): + import ldai + assert hasattr(ldai, 'ManagedAgentGraph')
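
End-to-end usage sketch (illustrative only, not part of the diff above). It assumes a provider package such as ldai-openai or ldai-langchain is installed and resolvable by RunnerFactory; the SDK key, flag key, and tool callable are hypothetical placeholders.

import asyncio

from ldclient import Config, Context, LDClient

from ldai import LDAIClient


def get_weather(location: str) -> str:
    # Hypothetical tool; any plain callable registered in the ToolRegistry works.
    return f'Sunny in {location}'


async def main() -> None:
    ld_client = LDClient(config=Config('sdk-key'))
    ai_client = LDAIClient(ld_client)
    context = Context.create('user-key')

    # Resolve the graph config, build a provider-specific runner, and wrap it
    # in a ManagedAgentGraph; returns None if the graph is disabled.
    graph = await ai_client.create_agent_graph(
        'travel-assistant-graph',
        context,
        tools={'get_weather': get_weather},
    )
    if graph:
        result = await graph.run('Find me restaurants in Seattle')
        print(result.output, result.metrics.success)


asyncio.run(main())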