From 0676acf124a35afcccc129fa12219d9b522e8fd5 Mon Sep 17 00:00:00 2001 From: Dawid Taborski Date: Fri, 24 Apr 2026 15:34:35 +0200 Subject: [PATCH 1/3] Changing AgentState.response -> AgentState.messages --- splunklib/ai/engines/langchain.py | 50 ++++++++------- splunklib/ai/middleware.py | 6 +- tests/integration/ai/test_agent.py | 38 ++++++------ .../integration/ai/test_conversation_store.py | 20 +++--- tests/integration/ai/test_hooks.py | 36 +++++++---- tests/integration/ai/test_middleware.py | 13 ++-- tests/unit/ai/test_default_limits.py | 61 ++++++++++++++----- 7 files changed, 135 insertions(+), 89 deletions(-) diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index 76fa100b..4424816c 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -279,9 +279,13 @@ async def awrap_tool_call( assert resp.artifact is None, "artifact is already populated" if resp.name.startswith(AGENT_PREFIX): - resp.artifact = SubagentFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType] + resp.artifact = SubagentFailureResult( + str(resp.content) + ) # pyright: ignore[reportUnknownArgumentType] else: - resp.artifact = ToolFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType] + resp.artifact = ToolFailureResult( + str(resp.content) + ) # pyright: ignore[reportUnknownArgumentType] return resp @@ -967,9 +971,9 @@ async def _sdk_handler(request: ToolRequest) -> ToolResponse: lc_request = _convert_tool_request_to_lc(request, original_request) result = await handler(lc_request) sdk_result = _convert_tool_message_from_lc(result) - assert isinstance(sdk_result, ToolMessage), ( - "Expected tool response from tool middleware handler" - ) + assert isinstance( + sdk_result, ToolMessage + ), "Expected tool response from tool middleware handler" return ToolResponse(sdk_result.result) return _sdk_handler @@ -987,9 +991,9 @@ async def _sdk_handler( lc_request = _convert_subagent_request_to_lc(request, original_request) result = await handler(lc_request) sdk_result = _convert_tool_message_from_lc(result) - assert isinstance(sdk_result, SubagentMessage), ( - "Expected subagent response from subagent middleware handler" - ) + assert isinstance( + sdk_result, SubagentMessage + ), "Expected subagent response from subagent middleware handler" return SubagentResponse(sdk_result.result) return _sdk_handler @@ -1182,16 +1186,18 @@ def _convert_tool_message_from_lc( ) case LC_ToolMessage(): # If this is reached, we likely passed an invalid tool name to LangChain. - assert message.name is not None, ( - "LangChain responded with a nameless tool call" - ) + assert ( + message.name is not None + ), "LangChain responded with a nameless tool call" if message.name.startswith(TOOL_STRATEGY_TOOL_PREFIX): return StructuredOutputMessage( name=message.name.removeprefix(TOOL_STRATEGY_TOOL_PREFIX), call_id=message.tool_call_id, status=message.status, - content=str(message.content), # pyright: ignore[reportUnknownArgumentType] + content=str( + message.content + ), # pyright: ignore[reportUnknownArgumentType] ) assert isinstance(message.artifact, ToolResult) or isinstance( @@ -1266,7 +1272,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe def _convert_agent_state_to_lc(state: AgentState) -> LC_AgentState[Any]: - messages = [_map_message_to_langchain(m) for m in state.response.messages] + messages = [_map_message_to_langchain(m) for m in state.messages] return LC_AgentState(messages=messages) @@ -1351,7 +1357,9 @@ async def _tool_call( except ToolException as e: raise LC_ToolException(*e.args) from e except LC_ToolException: - assert False, ( # noqa: PT015 + assert ( + False + ), ( # noqa: PT015 "ToolException from LangChain should not be raised in tool.func" ) @@ -1454,6 +1462,7 @@ async def _run( # pyright: ignore[reportRedeclaration] content: str, thread_id: str ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: return await invoke_agent(HumanMessage(content=content), thread_id) + else: async def _run( # pyright: ignore[reportRedeclaration] @@ -1627,14 +1636,9 @@ def _convert_agent_state_from_langchain( messages = state["messages"] total_tokens_counter = _get_approximate_token_counter(model) total_tokens = total_tokens_counter(messages) - - response = AgentResponse[Any | None]( - messages=[_map_message_from_langchain(m) for m in state["messages"]], - structured_output=state.get("structured_response"), - ) - + messages = [_map_message_from_langchain(m) for m in state["messages"]] return AgentState( - response=response, + messages=messages, total_steps=len(messages), token_count=total_tokens, ) @@ -1646,7 +1650,9 @@ def _get_approximate_token_counter(model: BaseChatModel) -> LC_TokenCounter: # NOTE: This is adapted from the backend provider library # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting # API: https://platform.claude.com/docs/en/build-with-claude/token-counting - if model._llm_type == ANTHROPIC_CHAT_MODEL_TYPE: # pyright: ignore[reportPrivateUsage] + if ( + model._llm_type == ANTHROPIC_CHAT_MODEL_TYPE + ): # pyright: ignore[reportPrivateUsage] return partial(count_tokens_approximately, chars_per_token=3.3) return count_tokens_approximately diff --git a/splunklib/ai/middleware.py b/splunklib/ai/middleware.py index 8814c5d6..0231dbb6 100644 --- a/splunklib/ai/middleware.py +++ b/splunklib/ai/middleware.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -from collections.abc import Awaitable, Callable +from collections.abc import Sequence, Awaitable, Callable from dataclasses import dataclass from typing import Any, override @@ -35,7 +35,7 @@ class AgentState: """AgentState is available through certain middlewares and contains information about the current state of an agent execution.""" # holds messages exchanged so far in the conversation - response: AgentResponse[Any | None] + messages: Sequence[BaseMessage] # steps taken so far in the conversation total_steps: int # tokens used so far in the conversation @@ -96,7 +96,7 @@ def __post_init__(self) -> None: @dataclass(frozen=True) class AgentRequest: - messages: list[BaseMessage] + messages: Sequence[BaseMessage] AgentMiddlewareHandler = Callable[[AgentRequest], Awaitable[AgentResponse[Any | None]]] diff --git a/tests/integration/ai/test_agent.py b/tests/integration/ai/test_agent.py index 1f9ea591..52e2262b 100644 --- a/tests/integration/ai/test_agent.py +++ b/tests/integration/ai/test_agent.py @@ -65,9 +65,9 @@ async def test_agent_with_openai_round_trip(self): ) response = result.final_message.content.strip().lower().replace(".", "") - assert result.structured_output is None, ( - "The structured output should not be populated" - ) + assert ( + result.structured_output is None + ), "The structured output should not be populated" assert "stefan" in response @pytest.mark.asyncio @@ -160,9 +160,9 @@ class Person(BaseModel): # check if the last message contains the response in natural language assert response.name in last_message, "Name field not found in the message" - assert str(response.age) in last_message, ( - "Age field not found in the message" - ) + assert ( + str(response.age) in last_message + ), "Age field not found in the message" async def test_agent_uses_subagent(self): pytest.importorskip("langchain_openai") @@ -215,9 +215,9 @@ class NicknameGeneratorInput(BaseModel): subagent_message = next( filter(lambda m: m.role == "subagent", result.messages), None ) - assert isinstance(subagent_message, SubagentMessage), ( - "Invalid subagent message" - ) + assert isinstance( + subagent_message, SubagentMessage + ), "Invalid subagent message" assert subagent_message, "No subagent message found in response" response = result.final_message.content @@ -366,12 +366,12 @@ class SupervisorOutput(BaseModel): ) response = result.structured_output - assert type(response) == SupervisorOutput, ( - "Response is not of type Team" - ) - assert len(response.member_descriptions) == 3, ( - "Team does not have 3 members" - ) + assert ( + type(response) == SupervisorOutput + ), "Response is not of type Team" + assert ( + len(response.member_descriptions) == 3 + ), "Team does not have 3 members" @pytest.mark.asyncio async def test_duplicated_subagent_name(self) -> None: @@ -520,9 +520,9 @@ async def _subagent_call_middleware( # Override the arguments, such that are invalid. resp = await handler(replace(request, call=replace(request.call, args={}))) - assert isinstance(resp.result, SubagentFailureResult), ( - "subagent call did not fail" - ) + assert isinstance( + resp.result, SubagentFailureResult + ), "subagent call did not fail" after_subagent_call = True return resp @@ -532,7 +532,7 @@ async def _model_call_middleware( req: ModelRequest, _handler: ModelMiddlewareHandler ) -> ModelResponse: if after_subagent_call: - msgs = req.state.response.messages + msgs = req.state.messages assert isinstance(msgs[-1], SubagentMessage) assert isinstance(msgs[-1].result, SubagentFailureResult) diff --git a/tests/integration/ai/test_conversation_store.py b/tests/integration/ai/test_conversation_store.py index a5c10b34..28211e29 100644 --- a/tests/integration/ai/test_conversation_store.py +++ b/tests/integration/ai/test_conversation_store.py @@ -66,9 +66,9 @@ async def _model_middleware( if after_first_call: # Previous messages included. - assert len(request.state.response.messages) == 3 + assert len(request.state.messages) == 3 else: - assert len(request.state.response.messages) == 1 + assert len(request.state.messages) == 1 return await handler(request) @agent_middleware @@ -166,7 +166,7 @@ async def _model_middleware( nonlocal model_middleware_called model_middleware_called = True - assert len(request.state.response.messages) == 1 + assert len(request.state.messages) == 1 return await handler(request) async with Agent( @@ -186,9 +186,9 @@ async def _model_middleware( thread_id="2", ) response = result.final_message.content - assert "Mike" not in response, ( - "Agent remembered the name from a different thread_id" - ) + assert ( + "Mike" not in response + ), "Agent remembered the name from a different thread_id" assert model_middleware_called @@ -276,9 +276,9 @@ async def _model_middleware( nonlocal after_first_call if after_first_call: - assert len(request.state.response.messages) == 3 + assert len(request.state.messages) == 3 else: - assert len(request.state.response.messages) == 1 + assert len(request.state.messages) == 1 after_first_call = True return await handler(request) @@ -347,9 +347,9 @@ async def _model_middleware( nonlocal after_first_call if after_first_call: - assert len(request.state.response.messages) == 3 + assert len(request.state.messages) == 3 else: - assert len(request.state.response.messages) == 1 + assert len(request.state.messages) == 1 after_first_call = True return await handler(request) diff --git a/tests/integration/ai/test_hooks.py b/tests/integration/ai/test_hooks.py index ad22a75b..71c4f3d5 100644 --- a/tests/integration/ai/test_hooks.py +++ b/tests/integration/ai/test_hooks.py @@ -30,7 +30,13 @@ before_model, ) from splunklib.ai.messages import AIMessage, AgentResponse, HumanMessage -from splunklib.ai.middleware import AgentRequest, ModelMiddlewareHandler, ModelRequest, ModelResponse, model_middleware +from splunklib.ai.middleware import ( + AgentRequest, + ModelMiddlewareHandler, + ModelRequest, + ModelResponse, + model_middleware, +) from tests.ai_testlib import AITestCase @@ -47,7 +53,7 @@ def test_hook_before(req: ModelRequest) -> None: hook_calls += 1 assert req.system_message.startswith("Your name is stefan") - assert len(req.state.response.messages) == 1 + assert len(req.state.messages) == 1 @before_model async def test_async_hook_before(req: ModelRequest) -> None: @@ -55,7 +61,7 @@ async def test_async_hook_before(req: ModelRequest) -> None: hook_calls += 1 assert req.system_message.startswith("Your name is stefan") - assert len(req.state.response.messages) == 1 + assert len(req.state.messages) == 1 @after_model def test_hook_after(resp: ModelResponse) -> None: @@ -197,10 +203,12 @@ async def test_agent_loop_stop_conditions_conversation_limit(self) -> None: with pytest.raises( StepsLimitExceededException, match="Steps limit of 2 exceeded" ): - _ = await agent.invoke([ - HumanMessage(content="hi, my name is Chris"), - HumanMessage(content="What is my name?"), - ]) + _ = await agent.invoke( + [ + HumanMessage(content="hi, my name is Chris"), + HumanMessage(content="What is my name?"), + ] + ) @pytest.mark.asyncio async def test_agent_loop_stop_conditions_conversation_limit_with_checkpointer( @@ -220,13 +228,17 @@ async def test_agent_loop_stop_conditions_conversation_limit_with_checkpointer( with pytest.raises( StepsLimitExceededException, match="Steps limit of 2 exceeded" ): - _ = await agent.invoke([ - HumanMessage(content="What is my name?"), - HumanMessage(content="Are you sure?"), - ]) + _ = await agent.invoke( + [ + HumanMessage(content="What is my name?"), + HumanMessage(content="Are you sure?"), + ] + ) @pytest.mark.asyncio - async def test_agent_loop_stop_conditions_steps_accumulate_across_invokes(self) -> None: + async def test_agent_loop_stop_conditions_steps_accumulate_across_invokes( + self, + ) -> None: pytest.importorskip("langchain_openai") step_limit = StepLimitMiddleware(2) diff --git a/tests/integration/ai/test_middleware.py b/tests/integration/ai/test_middleware.py index d699bb5b..eec31db4 100644 --- a/tests/integration/ai/test_middleware.py +++ b/tests/integration/ai/test_middleware.py @@ -78,7 +78,7 @@ async def test_middleware( assert call.args == {"city": "Krakow"} state = request.state - assert len(state.response.messages) == 2 + assert len(state.messages) == 2 response = await handler(request) assert isinstance(response.result, ToolResult) @@ -500,9 +500,9 @@ async def test_middleware( ) assert subagent_message, "SubagentMessage not found in messages" assert isinstance(subagent_message.result, SubagentTextResult) - assert subagent_message.result.content == "Chris-superstar", ( - "Invalid response from subagent" - ) + assert ( + subagent_message.result.content == "Chris-superstar" + ), "Invalid response from subagent" assert middleware_called, "Middleware was not called" @pytest.mark.asyncio @@ -699,10 +699,7 @@ async def mutating_middleware( ) -> ModelResponse: new_state = replace( request.state, - response=replace( - request.state.response, - messages=[HumanMessage(content="What is the capital of France?")], - ), + messages=[HumanMessage(content="What is the capital of France?")], ) return await handler(replace(request, state=new_state)) diff --git a/tests/unit/ai/test_default_limits.py b/tests/unit/ai/test_default_limits.py index e97c67c7..6eb2d076 100644 --- a/tests/unit/ai/test_default_limits.py +++ b/tests/unit/ai/test_default_limits.py @@ -28,7 +28,13 @@ TokenLimitMiddleware, ) from splunklib.ai.messages import AIMessage, AgentResponse -from splunklib.ai.middleware import AgentMiddleware, AgentRequest, AgentState, ModelRequest, ModelResponse +from splunklib.ai.middleware import ( + AgentMiddleware, + AgentRequest, + AgentState, + ModelRequest, + ModelResponse, +) from splunklib.ai.model import OpenAIModel from splunklib.client import Service @@ -48,7 +54,7 @@ def _make_agent_request() -> AgentRequest: def _make_model_request(token_count: int = 0, total_steps: int = 0) -> ModelRequest: state = AgentState( - response=AgentResponse(messages=[], structured_output=None), + messages=[], total_steps=total_steps, token_count=token_count, ) @@ -69,9 +75,13 @@ def test_default_values_match_constants(self) -> None: token = next(m for m in mw if isinstance(m, TokenLimitMiddleware)) step = next(m for m in mw if isinstance(m, StepLimitMiddleware)) timeout = next(m for m in mw if isinstance(m, TimeoutLimitMiddleware)) - assert token._limit == DEFAULT_TOKEN_LIMIT # pyright: ignore[reportPrivateUsage] + assert ( + token._limit == DEFAULT_TOKEN_LIMIT + ) # pyright: ignore[reportPrivateUsage] assert step._limit == DEFAULT_STEP_LIMIT # pyright: ignore[reportPrivateUsage] - assert timeout._seconds == DEFAULT_TIMEOUT_SECONDS # pyright: ignore[reportPrivateUsage] + assert ( + timeout._seconds == DEFAULT_TIMEOUT_SECONDS + ) # pyright: ignore[reportPrivateUsage] def test_user_token_limit_suppresses_default(self) -> None: agent = _make_agent(middleware=[TokenLimitMiddleware(50_000)]) @@ -102,7 +112,11 @@ def test_user_timeout_limit_suppresses_default(self) -> None: def test_all_user_limits_suppress_all_defaults(self) -> None: agent = _make_agent( - middleware=[TokenLimitMiddleware(50_000), StepLimitMiddleware(10), TimeoutLimitMiddleware(30.0)] + middleware=[ + TokenLimitMiddleware(50_000), + StepLimitMiddleware(10), + TimeoutLimitMiddleware(30.0), + ] ) mw = list(agent.middleware or []) assert len([m for m in mw if isinstance(m, TokenLimitMiddleware)]) == 1 @@ -139,9 +153,15 @@ async def test_deadline_is_none_before_first_invoke(self) -> None: async def test_timeout_fires_when_deadline_exceeded(self) -> None: mw = TimeoutLimitMiddleware(60.0) - mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past - - state = AgentState(response=AgentResponse(messages=[], structured_output=None), total_steps=0, token_count=0) + mw._deadline = ( + monotonic() - 1.0 + ) # pyright: ignore[reportPrivateUsage] # already in the past + + state = AgentState( + messages=[], + total_steps=0, + token_count=0, + ) request = ModelRequest(system_message="", state=state) with self.assertRaises(TimeoutExceededException): @@ -152,18 +172,29 @@ class TestTokenLimitMiddleware(unittest.IsolatedAsyncioTestCase): async def test_raises_when_token_count_in_request_exceeds_limit(self) -> None: mw = TokenLimitMiddleware(200) - await mw.model_middleware(_make_model_request(token_count=100), _noop_model_handler) - await mw.model_middleware(_make_model_request(token_count=199), _noop_model_handler) + await mw.model_middleware( + _make_model_request(token_count=100), _noop_model_handler + ) + await mw.model_middleware( + _make_model_request(token_count=199), _noop_model_handler + ) with self.assertRaises(TokenLimitExceededException): - await mw.model_middleware(_make_model_request(token_count=200), _noop_model_handler) + await mw.model_middleware( + _make_model_request(token_count=200), _noop_model_handler + ) class TestStepLimitMiddleware(unittest.IsolatedAsyncioTestCase): async def test_raises_when_steps_in_request_reach_limit(self) -> None: mw = StepLimitMiddleware(3) - await mw.model_middleware(_make_model_request(total_steps=1), _noop_model_handler) - await mw.model_middleware(_make_model_request(total_steps=2), _noop_model_handler) + await mw.model_middleware( + _make_model_request(total_steps=1), _noop_model_handler + ) + await mw.model_middleware( + _make_model_request(total_steps=2), _noop_model_handler + ) with self.assertRaises(StepsLimitExceededException): - await mw.model_middleware(_make_model_request(total_steps=3), _noop_model_handler) - + await mw.model_middleware( + _make_model_request(total_steps=3), _noop_model_handler + ) From cbfe46e5cec419cd879d86991edca863abce346b Mon Sep 17 00:00:00 2001 From: Dawid Taborski Date: Fri, 24 Apr 2026 16:00:32 +0200 Subject: [PATCH 2/3] fixing lints --- splunklib/ai/engines/langchain.py | 38 ++++++++++++---------------- tests/unit/ai/test_default_limits.py | 12 +++------ 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index 4424816c..2f58a2fb 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -280,12 +280,12 @@ async def awrap_tool_call( if resp.name.startswith(AGENT_PREFIX): resp.artifact = SubagentFailureResult( - str(resp.content) - ) # pyright: ignore[reportUnknownArgumentType] + str(resp.content) # pyright: ignore[reportUnknownArgumentType] + ) else: resp.artifact = ToolFailureResult( - str(resp.content) - ) # pyright: ignore[reportUnknownArgumentType] + str(resp.content) # pyright: ignore[reportUnknownArgumentType] + ) return resp @@ -971,9 +971,9 @@ async def _sdk_handler(request: ToolRequest) -> ToolResponse: lc_request = _convert_tool_request_to_lc(request, original_request) result = await handler(lc_request) sdk_result = _convert_tool_message_from_lc(result) - assert isinstance( - sdk_result, ToolMessage - ), "Expected tool response from tool middleware handler" + assert isinstance(sdk_result, ToolMessage), ( + "Expected tool response from tool middleware handler" + ) return ToolResponse(sdk_result.result) return _sdk_handler @@ -991,9 +991,9 @@ async def _sdk_handler( lc_request = _convert_subagent_request_to_lc(request, original_request) result = await handler(lc_request) sdk_result = _convert_tool_message_from_lc(result) - assert isinstance( - sdk_result, SubagentMessage - ), "Expected subagent response from subagent middleware handler" + assert isinstance(sdk_result, SubagentMessage), ( + "Expected subagent response from subagent middleware handler" + ) return SubagentResponse(sdk_result.result) return _sdk_handler @@ -1186,18 +1186,16 @@ def _convert_tool_message_from_lc( ) case LC_ToolMessage(): # If this is reached, we likely passed an invalid tool name to LangChain. - assert ( - message.name is not None - ), "LangChain responded with a nameless tool call" + assert message.name is not None, ( + "LangChain responded with a nameless tool call" + ) if message.name.startswith(TOOL_STRATEGY_TOOL_PREFIX): return StructuredOutputMessage( name=message.name.removeprefix(TOOL_STRATEGY_TOOL_PREFIX), call_id=message.tool_call_id, status=message.status, - content=str( - message.content - ), # pyright: ignore[reportUnknownArgumentType] + content=str(message.content), # pyright: ignore[reportUnknownArgumentType] ) assert isinstance(message.artifact, ToolResult) or isinstance( @@ -1357,9 +1355,7 @@ async def _tool_call( except ToolException as e: raise LC_ToolException(*e.args) from e except LC_ToolException: - assert ( - False - ), ( # noqa: PT015 + assert False, ( # noqa: PT015 "ToolException from LangChain should not be raised in tool.func" ) @@ -1650,9 +1646,7 @@ def _get_approximate_token_counter(model: BaseChatModel) -> LC_TokenCounter: # NOTE: This is adapted from the backend provider library # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting # API: https://platform.claude.com/docs/en/build-with-claude/token-counting - if ( - model._llm_type == ANTHROPIC_CHAT_MODEL_TYPE - ): # pyright: ignore[reportPrivateUsage] + if model._llm_type == ANTHROPIC_CHAT_MODEL_TYPE: # pyright: ignore[reportPrivateUsage] return partial(count_tokens_approximately, chars_per_token=3.3) return count_tokens_approximately diff --git a/tests/unit/ai/test_default_limits.py b/tests/unit/ai/test_default_limits.py index 6eb2d076..1d4bd505 100644 --- a/tests/unit/ai/test_default_limits.py +++ b/tests/unit/ai/test_default_limits.py @@ -75,13 +75,9 @@ def test_default_values_match_constants(self) -> None: token = next(m for m in mw if isinstance(m, TokenLimitMiddleware)) step = next(m for m in mw if isinstance(m, StepLimitMiddleware)) timeout = next(m for m in mw if isinstance(m, TimeoutLimitMiddleware)) - assert ( - token._limit == DEFAULT_TOKEN_LIMIT - ) # pyright: ignore[reportPrivateUsage] + assert token._limit == DEFAULT_TOKEN_LIMIT # pyright: ignore[reportPrivateUsage] assert step._limit == DEFAULT_STEP_LIMIT # pyright: ignore[reportPrivateUsage] - assert ( - timeout._seconds == DEFAULT_TIMEOUT_SECONDS - ) # pyright: ignore[reportPrivateUsage] + assert timeout._seconds == DEFAULT_TIMEOUT_SECONDS # pyright: ignore[reportPrivateUsage] def test_user_token_limit_suppresses_default(self) -> None: agent = _make_agent(middleware=[TokenLimitMiddleware(50_000)]) @@ -153,9 +149,7 @@ async def test_deadline_is_none_before_first_invoke(self) -> None: async def test_timeout_fires_when_deadline_exceeded(self) -> None: mw = TimeoutLimitMiddleware(60.0) - mw._deadline = ( - monotonic() - 1.0 - ) # pyright: ignore[reportPrivateUsage] # already in the past + mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past state = AgentState( messages=[], From a9c1179aceeb5bffff4a93a97f9e2a91d9e87960 Mon Sep 17 00:00:00 2001 From: Dawid Taborski Date: Mon, 27 Apr 2026 10:18:28 +0200 Subject: [PATCH 3/3] reverting formatting --- splunklib/ai/engines/langchain.py | 9 +--- tests/integration/ai/test_agent.py | 36 +++++++-------- .../integration/ai/test_conversation_store.py | 6 +-- tests/integration/ai/test_hooks.py | 32 +++++--------- tests/integration/ai/test_middleware.py | 6 +-- tests/unit/ai/test_default_limits.py | 44 ++++--------------- 6 files changed, 45 insertions(+), 88 deletions(-) diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index 2f58a2fb..0052c30d 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -279,13 +279,9 @@ async def awrap_tool_call( assert resp.artifact is None, "artifact is already populated" if resp.name.startswith(AGENT_PREFIX): - resp.artifact = SubagentFailureResult( - str(resp.content) # pyright: ignore[reportUnknownArgumentType] - ) + resp.artifact = SubagentFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType] else: - resp.artifact = ToolFailureResult( - str(resp.content) # pyright: ignore[reportUnknownArgumentType] - ) + resp.artifact = ToolFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType] return resp @@ -1458,7 +1454,6 @@ async def _run( # pyright: ignore[reportRedeclaration] content: str, thread_id: str ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: return await invoke_agent(HumanMessage(content=content), thread_id) - else: async def _run( # pyright: ignore[reportRedeclaration] diff --git a/tests/integration/ai/test_agent.py b/tests/integration/ai/test_agent.py index 52e2262b..ad906dbd 100644 --- a/tests/integration/ai/test_agent.py +++ b/tests/integration/ai/test_agent.py @@ -65,9 +65,9 @@ async def test_agent_with_openai_round_trip(self): ) response = result.final_message.content.strip().lower().replace(".", "") - assert ( - result.structured_output is None - ), "The structured output should not be populated" + assert result.structured_output is None, ( + "The structured output should not be populated" + ) assert "stefan" in response @pytest.mark.asyncio @@ -160,9 +160,9 @@ class Person(BaseModel): # check if the last message contains the response in natural language assert response.name in last_message, "Name field not found in the message" - assert ( - str(response.age) in last_message - ), "Age field not found in the message" + assert str(response.age) in last_message, ( + "Age field not found in the message" + ) async def test_agent_uses_subagent(self): pytest.importorskip("langchain_openai") @@ -215,9 +215,9 @@ class NicknameGeneratorInput(BaseModel): subagent_message = next( filter(lambda m: m.role == "subagent", result.messages), None ) - assert isinstance( - subagent_message, SubagentMessage - ), "Invalid subagent message" + assert isinstance(subagent_message, SubagentMessage), ( + "Invalid subagent message" + ) assert subagent_message, "No subagent message found in response" response = result.final_message.content @@ -366,12 +366,12 @@ class SupervisorOutput(BaseModel): ) response = result.structured_output - assert ( - type(response) == SupervisorOutput - ), "Response is not of type Team" - assert ( - len(response.member_descriptions) == 3 - ), "Team does not have 3 members" + assert type(response) == SupervisorOutput, ( + "Response is not of type Team" + ) + assert len(response.member_descriptions) == 3, ( + "Team does not have 3 members" + ) @pytest.mark.asyncio async def test_duplicated_subagent_name(self) -> None: @@ -520,9 +520,9 @@ async def _subagent_call_middleware( # Override the arguments, such that are invalid. resp = await handler(replace(request, call=replace(request.call, args={}))) - assert isinstance( - resp.result, SubagentFailureResult - ), "subagent call did not fail" + assert isinstance(resp.result, SubagentFailureResult), ( + "subagent call did not fail" + ) after_subagent_call = True return resp diff --git a/tests/integration/ai/test_conversation_store.py b/tests/integration/ai/test_conversation_store.py index 28211e29..77a756f2 100644 --- a/tests/integration/ai/test_conversation_store.py +++ b/tests/integration/ai/test_conversation_store.py @@ -186,9 +186,9 @@ async def _model_middleware( thread_id="2", ) response = result.final_message.content - assert ( - "Mike" not in response - ), "Agent remembered the name from a different thread_id" + assert "Mike" not in response, ( + "Agent remembered the name from a different thread_id" + ) assert model_middleware_called diff --git a/tests/integration/ai/test_hooks.py b/tests/integration/ai/test_hooks.py index 71c4f3d5..8ad1601e 100644 --- a/tests/integration/ai/test_hooks.py +++ b/tests/integration/ai/test_hooks.py @@ -30,13 +30,7 @@ before_model, ) from splunklib.ai.messages import AIMessage, AgentResponse, HumanMessage -from splunklib.ai.middleware import ( - AgentRequest, - ModelMiddlewareHandler, - ModelRequest, - ModelResponse, - model_middleware, -) +from splunklib.ai.middleware import AgentRequest, ModelMiddlewareHandler, ModelRequest, ModelResponse, model_middleware from tests.ai_testlib import AITestCase @@ -203,12 +197,10 @@ async def test_agent_loop_stop_conditions_conversation_limit(self) -> None: with pytest.raises( StepsLimitExceededException, match="Steps limit of 2 exceeded" ): - _ = await agent.invoke( - [ - HumanMessage(content="hi, my name is Chris"), - HumanMessage(content="What is my name?"), - ] - ) + _ = await agent.invoke([ + HumanMessage(content="hi, my name is Chris"), + HumanMessage(content="What is my name?"), + ]) @pytest.mark.asyncio async def test_agent_loop_stop_conditions_conversation_limit_with_checkpointer( @@ -228,17 +220,13 @@ async def test_agent_loop_stop_conditions_conversation_limit_with_checkpointer( with pytest.raises( StepsLimitExceededException, match="Steps limit of 2 exceeded" ): - _ = await agent.invoke( - [ - HumanMessage(content="What is my name?"), - HumanMessage(content="Are you sure?"), - ] - ) + _ = await agent.invoke([ + HumanMessage(content="What is my name?"), + HumanMessage(content="Are you sure?"), + ]) @pytest.mark.asyncio - async def test_agent_loop_stop_conditions_steps_accumulate_across_invokes( - self, - ) -> None: + async def test_agent_loop_stop_conditions_steps_accumulate_across_invokes(self) -> None: pytest.importorskip("langchain_openai") step_limit = StepLimitMiddleware(2) diff --git a/tests/integration/ai/test_middleware.py b/tests/integration/ai/test_middleware.py index eec31db4..b2adfed9 100644 --- a/tests/integration/ai/test_middleware.py +++ b/tests/integration/ai/test_middleware.py @@ -500,9 +500,9 @@ async def test_middleware( ) assert subagent_message, "SubagentMessage not found in messages" assert isinstance(subagent_message.result, SubagentTextResult) - assert ( - subagent_message.result.content == "Chris-superstar" - ), "Invalid response from subagent" + assert subagent_message.result.content == "Chris-superstar", ( + "Invalid response from subagent" + ) assert middleware_called, "Middleware was not called" @pytest.mark.asyncio diff --git a/tests/unit/ai/test_default_limits.py b/tests/unit/ai/test_default_limits.py index 1d4bd505..bd998075 100644 --- a/tests/unit/ai/test_default_limits.py +++ b/tests/unit/ai/test_default_limits.py @@ -28,13 +28,7 @@ TokenLimitMiddleware, ) from splunklib.ai.messages import AIMessage, AgentResponse -from splunklib.ai.middleware import ( - AgentMiddleware, - AgentRequest, - AgentState, - ModelRequest, - ModelResponse, -) +from splunklib.ai.middleware import AgentMiddleware, AgentRequest, AgentState, ModelRequest, ModelResponse from splunklib.ai.model import OpenAIModel from splunklib.client import Service @@ -108,11 +102,7 @@ def test_user_timeout_limit_suppresses_default(self) -> None: def test_all_user_limits_suppress_all_defaults(self) -> None: agent = _make_agent( - middleware=[ - TokenLimitMiddleware(50_000), - StepLimitMiddleware(10), - TimeoutLimitMiddleware(30.0), - ] + middleware=[TokenLimitMiddleware(50_000), StepLimitMiddleware(10), TimeoutLimitMiddleware(30.0)] ) mw = list(agent.middleware or []) assert len([m for m in mw if isinstance(m, TokenLimitMiddleware)]) == 1 @@ -151,11 +141,7 @@ async def test_timeout_fires_when_deadline_exceeded(self) -> None: mw = TimeoutLimitMiddleware(60.0) mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past - state = AgentState( - messages=[], - total_steps=0, - token_count=0, - ) + state = AgentState(messages=[], total_steps=0, token_count=0) request = ModelRequest(system_message="", state=state) with self.assertRaises(TimeoutExceededException): @@ -166,29 +152,17 @@ class TestTokenLimitMiddleware(unittest.IsolatedAsyncioTestCase): async def test_raises_when_token_count_in_request_exceeds_limit(self) -> None: mw = TokenLimitMiddleware(200) - await mw.model_middleware( - _make_model_request(token_count=100), _noop_model_handler - ) - await mw.model_middleware( - _make_model_request(token_count=199), _noop_model_handler - ) + await mw.model_middleware(_make_model_request(token_count=100), _noop_model_handler) + await mw.model_middleware(_make_model_request(token_count=199), _noop_model_handler) with self.assertRaises(TokenLimitExceededException): - await mw.model_middleware( - _make_model_request(token_count=200), _noop_model_handler - ) + await mw.model_middleware(_make_model_request(token_count=200), _noop_model_handler) class TestStepLimitMiddleware(unittest.IsolatedAsyncioTestCase): async def test_raises_when_steps_in_request_reach_limit(self) -> None: mw = StepLimitMiddleware(3) - await mw.model_middleware( - _make_model_request(total_steps=1), _noop_model_handler - ) - await mw.model_middleware( - _make_model_request(total_steps=2), _noop_model_handler - ) + await mw.model_middleware(_make_model_request(total_steps=1), _noop_model_handler) + await mw.model_middleware(_make_model_request(total_steps=2), _noop_model_handler) with self.assertRaises(StepsLimitExceededException): - await mw.model_middleware( - _make_model_request(total_steps=3), _noop_model_handler - ) + await mw.model_middleware(_make_model_request(total_steps=3), _noop_model_handler)