Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions astrbot/builtin_stars/astrbot/group_chat_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import datetime
import random
import time
import uuid
from collections import defaultdict, deque

Expand Down Expand Up @@ -79,21 +80,39 @@ async def get_image_caption(
image_url: str,
image_caption_provider_id: str,
image_caption_prompt: str,
event: AstrMessageEvent | None = None,
) -> str:
provider_id = image_caption_provider_id
if not image_caption_provider_id:
provider = self.context.get_using_provider()
provider_id = provider.meta().id if hasattr(provider, "meta") else ""
else:
provider = self.context.get_provider_by_id(image_caption_provider_id)
if not provider:
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
if not isinstance(provider, Provider):
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")

start_time = time.time()
response = await provider.text_chat(
prompt=image_caption_prompt,
session_id=uuid.uuid4().hex,
image_urls=[image_url],
persist=False,
)
end_time = time.time()

# 记录图片转述模型的调用统计
if event is not None:
await provider.record_image_caption_stat(
umo=event.unified_msg_origin,
provider_id=provider_id,
conversation_id=None,
start_time=start_time,
end_time=end_time,
response=response,
)

return response.completion_text
Comment on lines +96 to 116

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Issues Identified:

  1. API Failure Statistics are Not Tracked: If provider.text_chat raises an exception (e.g., due to network timeout, API error, etc.), the exception propagates immediately, and the statistics tracking block is bypassed. This means failed requests are never recorded in the database, making error rate calculations impossible.
  2. Potential AttributeError: Accessing response.usage.input_other and response.role directly can raise an AttributeError if the provider returns a different structure or if the response is malformed. Using getattr is a safer, more defensive approach.
  3. Redundant Method Call: provider.get_model() is called twice. We can simplify this to a single call.

Solution:

Wrap the provider.text_chat call in a try...except block to capture any exceptions, record the "error" status in the database, and then re-raise the exception. Use getattr to safely access response attributes.

        start_time = time.time()
        response = None
        exception = None
        try:
            response = await provider.text_chat(
                prompt=image_caption_prompt,
                session_id=uuid.uuid4().hex,
                image_urls=[image_url],
                persist=False,
            )
        except Exception as e:
            exception = e
        end_time = time.time()

        # 记录图片转述模型的调用统计
        if event is not None:
            try:
                provider_model = provider.get_model() or None
                usage_dict: dict = {}
                if response and getattr(response, "usage", None):
                    usage_dict = {
                        "input_other": getattr(response.usage, "input_other", 0),
                        "input_cached": getattr(response.usage, "input_cached", 0),
                        "output": getattr(response.usage, "output", 0),
                    }

                status = "error"
                if exception is None and response:
                    status = "completed" if getattr(response, "role", "") != "err" else "error"

                await db_helper.insert_provider_stat(
                    umo=event.unified_msg_origin,
                    provider_id=provider_id,
                    provider_model=provider_model,
                    conversation_id=None,
                    status=status,
                    stats={
                        "token_usage": usage_dict,
                        "start_time": start_time,
                        "end_time": end_time,
                        "time_to_first_token": 0.0,
                    },
                    agent_type="internal",
                )
            except Exception:
                logger.debug(
                    "Failed to record group chat image caption provider stat",
                    exc_info=True,
                )

        if exception is not None:
            raise exception

        return response.completion_text


async def need_active_reply(self, event: AstrMessageEvent) -> bool:
Expand Down Expand Up @@ -200,6 +219,7 @@ async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str:
url,
cfg["image_caption_provider_id"],
cfg["image_caption_prompt"],
event=event,
)
parts.append(f" [Image: {caption}]")
except Exception as e:
Expand Down
19 changes: 19 additions & 0 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import os
import platform
import time
import zoneinfo
from collections.abc import Coroutine
from dataclasses import dataclass, field
Expand Down Expand Up @@ -606,6 +607,8 @@ async def _request_img_caption(
cfg: dict,
image_urls: list[str],
plugin_context: Context,
event: AstrMessageEvent | None = None,
conversation_id: str | None = None,
) -> str:
prov = plugin_context.get_provider_by_id(provider_id)
if prov is None:
Expand All @@ -622,10 +625,24 @@ async def _request_img_caption(
"Please describe the image.",
)
logger.debug("Processing image caption with provider: %s", provider_id)

start_time = time.time()
llm_resp = await prov.text_chat(
prompt=img_cap_prompt,
image_urls=image_urls,
)
end_time = time.time()

if event is not None:
await prov.record_image_caption_stat(
umo=event.unified_msg_origin,
provider_id=provider_id,
conversation_id=conversation_id,
start_time=start_time,
end_time=end_time,
response=llm_resp,
)

return llm_resp.completion_text


Expand All @@ -648,6 +665,8 @@ async def _ensure_img_caption(
cfg,
compressed_urls,
plugin_context,
event=event,
conversation_id=req.conversation.cid if req.conversation else None,
)
if caption:
req.extra_user_content_parts.append(
Expand Down
44 changes: 44 additions & 0 deletions astrbot/core/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,50 @@ async def test(self) -> None:
"""
...

async def record_image_caption_stat(
self,
*,
umo: str,
provider_id: str,
conversation_id: str | None,
start_time: float,
end_time: float,
response: LLMResponse,
) -> None:
"""记录图片转述模型调用统计,由子类调用。"""
try:
from astrbot.core import db_helper, logger
except ImportError:
return

try:
model = self.get_model()
provider_model = model or None
usage_dict: dict = {}
if response.usage:
usage_dict = {
"input_other": response.usage.input_other,
"input_cached": response.usage.input_cached,
"output": response.usage.output,
}

await db_helper.insert_provider_stat(
umo=umo,
provider_id=provider_id,
provider_model=provider_model,
conversation_id=conversation_id,
status="completed" if response.role != "err" else "error",
stats={
"token_usage": usage_dict,
"start_time": start_time,
"end_time": end_time,
"time_to_first_token": 0.0,
},
agent_type="internal",
)
except Exception:
logger.debug("Failed to record image caption provider stat", exc_info=True)


class Provider(AbstractProvider):
"""Chat Provider"""
Expand Down
Loading