Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion splunklib/ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ class. The default for that limit is suppressed automatically - the other defaults
remain active:

```py
from splunklib.ai.hooks import (
from splunklib.ai.limits import (
TokenLimitMiddleware,
StepLimitMiddleware,
TimeoutLimitMiddleware,
Expand Down
2 changes: 1 addition & 1 deletion splunklib/ai/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic import BaseModel

from splunklib.ai.conversation_store import ConversationStore
from splunklib.ai.hooks import (
from splunklib.ai.limits import (
DEFAULT_STEP_LIMIT,
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT,
DEFAULT_TIMEOUT_SECONDS,
Expand Down
159 changes: 0 additions & 159 deletions splunklib/ai/hooks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect
from collections.abc import Awaitable, Callable
from time import monotonic
from typing import Any, override

from splunklib.ai.messages import AgentResponse
Expand All @@ -12,44 +11,6 @@
ModelRequest,
ModelResponse,
)
from splunklib.ai.structured_output import StructuredOutputGenerationException

# Default values for the agent execution limits handled in this module.
# NOTE(review): presumably applied when the caller does not configure the
# matching limit middleware - confirm in splunklib/ai/base_agent.py, which
# imports these names.
DEFAULT_TIMEOUT_SECONDS: float = 600.0  # in seconds (600.0 s == 10 minutes)
DEFAULT_STEP_LIMIT: int = 100
DEFAULT_TOKEN_LIMIT: int = 200_000
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3


class AgentStopException(Exception):
    """Base class for exceptions that signal a conversation must stop.

    The limit-specific exceptions in this module derive from this class, so
    catching it handles every stopping condition at once.
    """


class TokenLimitExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the token limit is exceeded.

    Attributes:
        token_limit: The limit that was hit, kept on the exception so that
            handlers can read it without parsing the message.
    """

    token_limit: int

    def __init__(self, token_limit: int) -> None:
        """Build the exception.

        Args:
            token_limit: The configured token limit. Annotated as `int` to
                match `TokenLimitMiddleware`, which stores an integer limit.
        """
        super().__init__(f"Token limit of {token_limit} exceeded.")
        self.token_limit = token_limit


class StepsLimitExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the steps limit is exceeded.

    Attributes:
        steps_limit: The limit that was hit, kept on the exception so that
            handlers can read it without parsing the message.
    """

    steps_limit: int

    def __init__(self, steps_limit: int) -> None:
        """Build the exception.

        Args:
            steps_limit: The configured maximum number of steps.
        """
        super().__init__(f"Steps limit of {steps_limit} exceeded.")
        self.steps_limit = steps_limit


class TimeoutExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the timeout is exceeded.

    Attributes:
        timeout_seconds: The timeout that elapsed, kept on the exception so
            that handlers can read it without parsing the message.
    """

    timeout_seconds: float

    def __init__(self, timeout_seconds: float) -> None:
        """Build the exception.

        Args:
            timeout_seconds: The configured timeout, in seconds.
        """
        super().__init__(f"Timed out after {timeout_seconds} seconds.")
        self.timeout_seconds = timeout_seconds


class StructuredOutputRetryLimitExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the structured output retry limit is exceeded.

    Attributes:
        retry_count: The value passed to the constructor, kept on the
            exception so that handlers can read it without parsing the
            message. The retry-limit middleware in this module passes its
            configured limit here.
    """

    retry_count: int

    def __init__(self, retry_count: int) -> None:
        """Build the exception.

        Args:
            retry_count: The retry limit reported in the message.
        """
        super().__init__(f"Structured output retry limit of {retry_count} exceeded")
        self.retry_count = retry_count


def before_model(
Expand Down Expand Up @@ -132,123 +93,3 @@ async def agent_middleware(
return handler_response

return _Middleware()


class TokenLimitMiddleware(AgentMiddleware):
    """Abort agent execution once the request's token count hits the limit.

    The check runs before the request is passed on to the handler: when
    `request.state.token_count` is greater than or equal to the configured
    limit, the handler is not called and `TokenLimitExceededException` is
    raised instead.
    """

    _limit: int

    def __init__(self, limit: int) -> None:
        """Remember the token budget.

        Args:
            limit: Token count at which execution is stopped.
        """
        self._limit = limit

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        """Forward `request` to `handler` unless the token budget is spent.

        Raises:
            TokenLimitExceededException: If the token count recorded on the
                request state has reached the limit.
        """
        tokens_used = request.state.token_count
        budget = self._limit
        if tokens_used >= budget:
            raise TokenLimitExceededException(token_limit=budget)
        return await handler(request)


class StepLimitMiddleware(AgentMiddleware):
    """Abort agent execution once the number of steps taken hits the limit.

    The check runs before the request is passed on to the handler: when
    `request.state.total_steps` is greater than or equal to the configured
    limit, the handler is not called and `StepsLimitExceededException` is
    raised instead.
    """

    _limit: int

    def __init__(self, limit: int) -> None:
        """Remember the step budget.

        Args:
            limit: Step count at which execution is stopped.
        """
        self._limit = limit

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        """Forward `request` to `handler` unless the step budget is spent.

        Raises:
            StepsLimitExceededException: If the step count recorded on the
                request state has reached the limit.
        """
        steps_taken = request.state.total_steps
        budget = self._limit
        if steps_taken >= budget:
            raise StepsLimitExceededException(steps_limit=budget)
        return await handler(request)


class TimeoutLimitMiddleware(AgentMiddleware):
    """Stops agent execution when the time elapsed within an invoke exceeds the given seconds.

    The deadline resets on every invoke call - it measures time from the start of
    each invocation, not from agent construction.

    The deadline is only checked right before a request is passed on to the
    model handler; a handler call that is already running is not interrupted.

    Do not share instances between agents.
    """

    _seconds: float
    # Absolute `monotonic()` deadline for every thread_id with an invoke in flight.
    _deadline_per_thread_id: dict[str, float]

    def __init__(self, seconds: float) -> None:
        """Remember the time budget.

        Args:
            seconds: Number of seconds a single invocation may run.
        """
        self._seconds = seconds
        self._deadline_per_thread_id = {}

    @override
    async def agent_middleware(
        self,
        request: AgentRequest,
        handler: AgentMiddlewareHandler,
    ) -> AgentResponse[Any | None]:
        """Register a fresh deadline for this invocation and run the handler.

        The deadline entry is removed again when the handler finishes,
        whether it returned or raised.
        """
        thread_id = request.thread_id
        # Agent loop starting. Registered before the `try`, so the cleanup
        # below only runs once the entry actually exists.
        self._deadline_per_thread_id[thread_id] = monotonic() + self._seconds
        try:
            return await handler(request)
        finally:
            # Don't leak memory. `pop` with a default instead of `del`: if the
            # entry is already gone, a KeyError raised here would replace the
            # exception coming out of the handler.
            self._deadline_per_thread_id.pop(thread_id, None)

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        """Forward `request` to `handler` unless the deadline has passed.

        Raises:
            TimeoutExceededException: If the deadline registered for the
                request's thread_id has been reached.
            KeyError: If no deadline is registered for the request's
                thread_id, i.e. `agent_middleware` is not active for it.
        """
        now = monotonic()
        if now >= self._deadline_per_thread_id[request.state.thread_id]:
            raise TimeoutExceededException(timeout_seconds=self._seconds)
        return await handler(request)


class StructuredOutputRetryLimitMiddleware(AgentMiddleware):
    """Stops agent execution when the agent exceeds structured output
    retry limit during a single agent loop invocation. Pass 0 to disable retries.
    """

    _limit: int
    # Number of structured output failures seen so far, per thread_id with an
    # invoke in flight.
    _retries_per_thread_id: dict[str, int]

    def __init__(self, limit: int) -> None:
        """Remember the retry budget.

        Args:
            limit: Number of structured output failures tolerated per
                invocation; the next failure stops execution.
        """
        self._limit = limit
        self._retries_per_thread_id = {}

    @override
    async def agent_middleware(
        self,
        request: AgentRequest,
        handler: AgentMiddlewareHandler,
    ) -> AgentResponse[Any | None]:
        """Reset the failure counter for this invocation and run the handler.

        The counter entry is removed again when the handler finishes,
        whether it returned or raised.
        """
        thread_id = request.thread_id
        # Agent loop starting. Registered before the `try`, so the cleanup
        # below only runs once the entry actually exists.
        self._retries_per_thread_id[thread_id] = 0
        try:
            return await handler(request)
        finally:
            # Don't leak memory. `pop` with a default instead of `del`: if the
            # entry is already gone, a KeyError raised here would replace the
            # exception coming out of the handler.
            self._retries_per_thread_id.pop(thread_id, None)

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        """Run the handler and count structured output generation failures.

        Raises:
            StructuredOutputRetryLimitExceededException: If the number of
                failures for the request's thread_id exceeds the limit. The
                triggering failure is attached as its cause.
            StructuredOutputGenerationException: Re-raised unchanged while
                the limit has not been exceeded.
        """
        try:
            return await handler(request)
        except StructuredOutputGenerationException as exc:
            thread_id = request.state.thread_id
            self._retries_per_thread_id[thread_id] += 1
            if self._retries_per_thread_id[thread_id] > self._limit:
                # Chain explicitly so the failure that broke the limit is
                # reported as the direct cause.
                raise StructuredOutputRetryLimitExceededException(self._limit) from exc
            raise  # re-raise, to retry structured output generation
184 changes: 184 additions & 0 deletions splunklib/ai/limits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright © 2011-2026 Splunk, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"): you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from time import monotonic
Comment thread
mateusz834 marked this conversation as resolved.
from typing import Any, override

from splunklib.ai.messages import AgentResponse
from splunklib.ai.middleware import (
AgentMiddleware,
AgentMiddlewareHandler,
AgentRequest,
ModelMiddlewareHandler,
ModelRequest,
ModelResponse,
)
from splunklib.ai.structured_output import StructuredOutputGenerationException

# Default values for the agent execution limits handled in this module.
# NOTE(review): presumably applied when the caller does not configure the
# matching limit middleware - confirm in splunklib/ai/base_agent.py, which
# imports these names.
DEFAULT_TIMEOUT_SECONDS: float = 600.0  # in seconds (600.0 s == 10 minutes)
DEFAULT_STEP_LIMIT: int = 100
DEFAULT_TOKEN_LIMIT: int = 200_000
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3


class AgentStopException(Exception):
    """Base class for exceptions that signal a conversation must stop.

    The limit-specific exceptions in this module derive from this class, so
    catching it handles every stopping condition at once.
    """


class TokenLimitExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the token limit is exceeded.

    Attributes:
        token_limit: The limit that was hit, kept on the exception so that
            handlers can read it without parsing the message.
    """

    token_limit: int

    def __init__(self, token_limit: int) -> None:
        """Build the exception.

        Args:
            token_limit: The configured token limit.
        """
        super().__init__(f"Token limit of {token_limit} exceeded.")
        self.token_limit = token_limit


class StepsLimitExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the steps limit is exceeded.

    Attributes:
        steps_limit: The limit that was hit, kept on the exception so that
            handlers can read it without parsing the message.
    """

    steps_limit: int

    def __init__(self, steps_limit: int) -> None:
        """Build the exception.

        Args:
            steps_limit: The configured maximum number of steps.
        """
        super().__init__(f"Steps limit of {steps_limit} exceeded.")
        self.steps_limit = steps_limit


class TimeoutExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the timeout is exceeded.

    Attributes:
        timeout_seconds: The timeout that elapsed, kept on the exception so
            that handlers can read it without parsing the message.
    """

    timeout_seconds: float

    def __init__(self, timeout_seconds: float) -> None:
        """Build the exception.

        Args:
            timeout_seconds: The configured timeout, in seconds.
        """
        super().__init__(f"Timed out after {timeout_seconds} seconds.")
        self.timeout_seconds = timeout_seconds


class StructuredOutputRetryLimitExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the structured output retry limit is exceeded.

    Attributes:
        retry_count: The value passed to the constructor, kept on the
            exception so that handlers can read it without parsing the
            message. The retry-limit middleware in this module passes its
            configured limit here.
    """

    retry_count: int

    def __init__(self, retry_count: int) -> None:
        """Build the exception.

        Args:
            retry_count: The retry limit reported in the message.
        """
        super().__init__(f"Structured output retry limit of {retry_count} exceeded")
        self.retry_count = retry_count


class TokenLimitMiddleware(AgentMiddleware):
    """Abort agent execution once the request's token count hits the limit.

    The check runs before the request is passed on to the handler: when
    `request.state.token_count` is greater than or equal to the configured
    limit, the handler is not called and `TokenLimitExceededException` is
    raised instead.
    """

    _limit: int

    def __init__(self, limit: int) -> None:
        """Remember the token budget.

        Args:
            limit: Token count at which execution is stopped.
        """
        self._limit = limit

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        """Forward `request` to `handler` unless the token budget is spent.

        Raises:
            TokenLimitExceededException: If the token count recorded on the
                request state has reached the limit.
        """
        tokens_used = request.state.token_count
        budget = self._limit
        if tokens_used >= budget:
            raise TokenLimitExceededException(token_limit=budget)
        return await handler(request)


class StepLimitMiddleware(AgentMiddleware):
    """Abort agent execution once the number of steps taken hits the limit.

    The check runs before the request is passed on to the handler: when
    `request.state.total_steps` is greater than or equal to the configured
    limit, the handler is not called and `StepsLimitExceededException` is
    raised instead.
    """

    _limit: int

    def __init__(self, limit: int) -> None:
        """Remember the step budget.

        Args:
            limit: Step count at which execution is stopped.
        """
        self._limit = limit

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        """Forward `request` to `handler` unless the step budget is spent.

        Raises:
            StepsLimitExceededException: If the step count recorded on the
                request state has reached the limit.
        """
        steps_taken = request.state.total_steps
        budget = self._limit
        if steps_taken >= budget:
            raise StepsLimitExceededException(steps_limit=budget)
        return await handler(request)


class TimeoutLimitMiddleware(AgentMiddleware):
    """Stops agent execution when the time elapsed within an invoke exceeds the given seconds.

    The deadline resets on every invoke call - it measures time from the start of
    each invocation, not from agent construction.

    The deadline is only checked right before a request is passed on to the
    model handler; a handler call that is already running is not interrupted.

    Do not share instances between agents.
    """

    _seconds: float
    # Absolute `monotonic()` deadline for every thread_id with an invoke in flight.
    _deadline_per_thread_id: dict[str, float]

    def __init__(self, seconds: float) -> None:
        """Remember the time budget.

        Args:
            seconds: Number of seconds a single invocation may run.
        """
        self._seconds = seconds
        self._deadline_per_thread_id = {}

    @override
    async def agent_middleware(
        self,
        request: AgentRequest,
        handler: AgentMiddlewareHandler,
    ) -> AgentResponse[Any | None]:
        """Register a fresh deadline for this invocation and run the handler.

        The deadline entry is removed again when the handler finishes,
        whether it returned or raised.
        """
        thread_id = request.thread_id
        # Agent loop starting. Registered before the `try`, so the cleanup
        # below only runs once the entry actually exists.
        self._deadline_per_thread_id[thread_id] = monotonic() + self._seconds
        try:
            return await handler(request)
        finally:
            # Don't leak memory. `pop` with a default instead of `del`: if the
            # entry is already gone, a KeyError raised here would replace the
            # exception coming out of the handler.
            self._deadline_per_thread_id.pop(thread_id, None)

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        """Forward `request` to `handler` unless the deadline has passed.

        Raises:
            TimeoutExceededException: If the deadline registered for the
                request's thread_id has been reached.
            KeyError: If no deadline is registered for the request's
                thread_id, i.e. `agent_middleware` is not active for it.
        """
        now = monotonic()
        if now >= self._deadline_per_thread_id[request.state.thread_id]:
            raise TimeoutExceededException(timeout_seconds=self._seconds)
        return await handler(request)


class StructuredOutputRetryLimitMiddleware(AgentMiddleware):
    """Stops agent execution when the agent exceeds structured output
    retry limit during a single agent loop invocation. Pass 0 to disable retries.
    """

    _limit: int
    # Number of structured output failures seen so far, per thread_id with an
    # invoke in flight.
    _retries_per_thread_id: dict[str, int]

    def __init__(self, limit: int) -> None:
        """Remember the retry budget.

        Args:
            limit: Number of structured output failures tolerated per
                invocation; the next failure stops execution.
        """
        self._limit = limit
        self._retries_per_thread_id = {}

    @override
    async def agent_middleware(
        self,
        request: AgentRequest,
        handler: AgentMiddlewareHandler,
    ) -> AgentResponse[Any | None]:
        """Reset the failure counter for this invocation and run the handler.

        The counter entry is removed again when the handler finishes,
        whether it returned or raised.
        """
        thread_id = request.thread_id
        # Agent loop starting. Registered before the `try`, so the cleanup
        # below only runs once the entry actually exists.
        self._retries_per_thread_id[thread_id] = 0
        try:
            return await handler(request)
        finally:
            # Don't leak memory. `pop` with a default instead of `del`: if the
            # entry is already gone, a KeyError raised here would replace the
            # exception coming out of the handler.
            self._retries_per_thread_id.pop(thread_id, None)

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        """Run the handler and count structured output generation failures.

        Raises:
            StructuredOutputRetryLimitExceededException: If the number of
                failures for the request's thread_id exceeds the limit. The
                triggering failure is attached as its cause.
            StructuredOutputGenerationException: Re-raised unchanged while
                the limit has not been exceeded.
        """
        try:
            return await handler(request)
        except StructuredOutputGenerationException as exc:
            thread_id = request.state.thread_id
            self._retries_per_thread_id[thread_id] += 1
            if self._retries_per_thread_id[thread_id] > self._limit:
                # Chain explicitly so the failure that broke the limit is
                # reported as the direct cause.
                raise StructuredOutputRetryLimitExceededException(self._limit) from exc
            raise  # re-raise, to retry structured output generation
Loading