Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 64 additions & 49 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,7 @@ async def _process_quote_message(
quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS,
config: MainAgentBuildConfig | None = None,
main_provider_supports_image: bool = False,
skip_quote_image_caption: bool = False,
) -> None:
quote = None
for comp in event.message_obj.message:
Expand Down Expand Up @@ -805,54 +806,63 @@ async def _process_quote_message(
image_seg = comp
break

if image_seg and main_provider_supports_image:
logger.debug(
"Skipping quote image captioning because the main provider supports image input."
)
elif image_seg and not img_cap_prov_id:
logger.debug(
"No dedicated image caption provider configured. "
"Skipping quote image captioning."
)
elif image_seg:
try:
prov = None
path = None
compress_path = None
prov = plugin_context.get_provider_by_id(img_cap_prov_id)
if prov is None:
prov = plugin_context.get_using_provider(event.unified_msg_origin)

if prov and isinstance(prov, Provider):
path = await image_seg.convert_to_file_path()
compress_path = await _compress_image_for_provider(
path,
config.provider_settings if config else None,
)
if path and _is_generated_compressed_image_path(path, compress_path):
event.track_temporary_local_file(compress_path)
llm_resp = await prov.text_chat(
prompt="Please describe the image content.",
image_urls=[compress_path],
)
if llm_resp.completion_text:
content_parts.append(
f"[Image Caption in quoted message]: {llm_resp.completion_text}"
if image_seg:
if skip_quote_image_caption:
logger.debug(
"Skipping quote image captioning because image captioning already handled this request."
)
elif main_provider_supports_image:
logger.debug(
"Skipping quote image captioning because the main provider supports image input."
)
elif not img_cap_prov_id:
logger.debug(
"No dedicated image caption provider configured. "
"Skipping quote image captioning."
)
else:
try:
prov = None
path = None
compress_path = None
prov = plugin_context.get_provider_by_id(img_cap_prov_id)
if prov is None:
prov = plugin_context.get_using_provider(event.unified_msg_origin)

if prov and isinstance(prov, Provider):
path = await image_seg.convert_to_file_path()
compress_path = await _compress_image_for_provider(
path,
config.provider_settings if config else None,
)
else:
logger.warning("No provider found for image captioning in quote.")
except BaseException as exc:
logger.error("处理引用图片失败: %s", exc)
finally:
if (
compress_path
and compress_path != path
and os.path.exists(compress_path)
):
try:
os.remove(compress_path)
except Exception as exc: # noqa: BLE001
logger.warning("Fail to remove temporary compressed image: %s", exc)
if path and _is_generated_compressed_image_path(
path, compress_path
):
event.track_temporary_local_file(compress_path)
llm_resp = await prov.text_chat(
prompt="Please describe the image content.",
image_urls=[compress_path],
)
if llm_resp.completion_text:
content_parts.append(
f"[Image Caption in quoted message]: {llm_resp.completion_text}"
)
else:
logger.warning("No provider found for image captioning in quote.")
except BaseException as exc:
logger.error("处理引用图片失败: %s", exc)
finally:
if (
compress_path
and compress_path != path
and os.path.exists(compress_path)
):
try:
os.remove(compress_path)
except Exception as exc: # noqa: BLE001
logger.warning(
"Fail to remove temporary compressed image: %s", exc
)

quoted_content = "\n".join(content_parts)
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
Expand Down Expand Up @@ -918,11 +928,12 @@ async def _decorate_llm_request(
main_provider_supports_image = provider is not None and _provider_supports_modality(
provider, "image"
)
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
quote_images_already_captioned = False

if req.conversation:
await _ensure_persona_and_skills(req, cfg, plugin_context, event)

img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
if img_cap_prov_id and req.image_urls and not main_provider_supports_image:
await _ensure_img_caption(
event,
Expand All @@ -931,8 +942,11 @@ async def _decorate_llm_request(
plugin_context,
img_cap_prov_id,
)
quote_images_already_captioned = any(
"<image_caption>" in getattr(part, "text", "")
for part in req.extra_user_content_parts
)

img_cap_prov_id = cfg.get("default_image_caption_provider_id") or ""
quoted_message_settings = _get_quoted_message_parser_settings(cfg)
await _process_quote_message(
event,
Expand All @@ -942,6 +956,7 @@ async def _decorate_llm_request(
quoted_message_settings,
config,
main_provider_supports_image=main_provider_supports_image,
skip_quote_image_caption=quote_images_already_captioned,
)

tz = config.timezone
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/test_astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,77 @@ async def test_build_main_agent_skips_caption_when_main_provider_supports_images
)
mock_provider.text_chat.assert_not_called()

@pytest.mark.asyncio
async def test_build_main_agent_does_not_caption_quoted_image_twice(
self, mock_event, mock_context
):
"""Quoted images should not be captioned again after request image captioning."""
module = ama
text_provider = MagicMock(spec=Provider)
text_provider.provider_config = {
"id": "text-provider",
"modalities": ["text", "tool_use"],
}
text_provider.get_model.return_value = "text-model"

caption_provider = MagicMock(spec=Provider)
caption_provider.text_chat = AsyncMock(
return_value=MagicMock(completion_text="quoted image caption")
)

mock_reply = Reply(
id="reply-1",
chain=[Plain(text="quoted text"), Image(file="file:///tmp/quoted.jpg")],
sender_nickname="Alice",
message_str="quoted text",
)
mock_event.message_obj.message = [Plain(text="Hello"), mock_reply]

mock_context.get_provider_by_id.return_value = caption_provider
mock_context.get_using_provider.return_value = text_provider
mock_context.get_config.return_value = {}

conv_mgr = mock_context.conversation_manager
_setup_conversation_for_build(conv_mgr)

with (
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
patch.object(
Image,
"convert_to_file_path",
AsyncMock(return_value="/tmp/quoted.jpg"),
),
patch(
"astrbot.core.astr_main_agent._compress_image_for_provider",
AsyncMock(side_effect=lambda path, _settings: path),
),
):
mock_runner = MagicMock()
mock_runner.reset = AsyncMock()
mock_runner_cls.return_value = mock_runner

result = await module.build_main_agent(
event=mock_event,
plugin_context=mock_context,
config=module.MainAgentBuildConfig(
tool_call_timeout=60,
provider_settings={
"default_image_caption_provider_id": "caption-provider",
},
),
provider=text_provider,
)

assert result is not None
assert caption_provider.text_chat.await_count == 1

extra_text = "\n".join(
part.text for part in result.provider_request.extra_user_content_parts
)
assert "<image_caption>quoted image caption</image_caption>" in extra_text
assert "[Image Caption in quoted message]" not in extra_text

@pytest.mark.asyncio
async def test_build_main_agent_uses_image_fallback_provider(
self, mock_event, mock_context
Expand Down
Loading