Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions astrbot/core/pipeline/preprocess_stage/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ async def process(
logger.debug(f"路径映射: {url} -> {component.url}")
message_chain[idx] = component

failed_record_ids: set[int] = set()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is significant code duplication between processing direct Record components (lines 70-84) and processing Record components inside Reply chains (lines 87-103). Both blocks perform the exact same WAV conversion, temporary file tracking, and error handling.

Following the general rule to avoid code duplication when implementing similar functionality for direct vs. quoted attachments, we should refactor this logic into a shared helper function.

Here is an example of how you can refactor this:

        async def _process_record(record_comp: Record, context: str = "") -> bool:
            """Convert one Record component to WAV and update it in place.

            The component's file path is resolved with ``convert_to_file_path``
            and passed through ``ensure_wav``. When ``ensure_wav`` returns a
            path different from the resolved one, that new path is registered
            with ``event.track_temporary_local_file``. Both ``file`` and
            ``path`` of the component are then pointed at the resulting path.

            Args:
                record_comp: The Record component to process.
                context: Text inserted into the warning message between
                    "Voice processing " and "failed" (e.g. "in reply chain ").

            Returns:
                True when every step succeeded; False when any step raised, in
                which case ``id(record_comp)`` is added to
                ``failed_record_ids`` and a warning is logged.
            """
            try:
                source_path = await record_comp.convert_to_file_path()
                wav_path = await ensure_wav(source_path)
                # A different path means ensure_wav handed back another file;
                # register it with the event as a temporary local file.
                if wav_path != source_path:
                    event.track_temporary_local_file(wav_path)
                record_comp.file = record_comp.path = wav_path
            except Exception as e:
                # Remember the failure by object identity so later passes can
                # skip this component.
                failed_record_ids.add(id(record_comp))
                logger.warning(f"Voice processing {context}failed: {e}")
                return False
            return True

Then, the loops can be simplified to:

        # Process direct Record components
        message_chain = event.get_messages()
        for idx, component in enumerate(message_chain):
            if isinstance(component, Record):
                if await _process_record(component):
                    message_chain[idx] = component

        # Process Record components inside Reply chains
        for component in event.get_messages():
            if isinstance(component, Reply) and component.chain:
                for idx, reply_comp in enumerate(component.chain):
                    if isinstance(reply_comp, Record):
                        if await _process_record(reply_comp, "in reply chain "):
                            component.chain[idx] = reply_comp
References
  1. When implementing similar functionality for different cases (e.g., direct vs. quoted attachments), refactor the logic into a shared helper function to avoid code duplication.


# In here, we convert all Record components to wav format and update the file path.
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
Comment on lines +68 to 72

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Tracking failed records by id() may be brittle if components are reconstructed or copied.

This relies on the same Record instances being reused throughout the STT pass. If event.get_messages() or related logic can recreate or copy Records, their id() values will change and failed_record_ids will no longer prevent reprocessing failed items. In that case, prefer a stable identifier (e.g., a uid field, file path, or a boolean flag on the record) instead of id() and confirm which invariant actually holds in this pipeline.

Expand All @@ -78,6 +80,7 @@ async def process(
component.path = record_path
message_chain[idx] = component
except Exception as e:
failed_record_ids.add(id(component))
logger.warning(f"Voice processing failed: {e}")

# Also process Record components inside Reply chains (wav conversion)
Expand All @@ -94,6 +97,7 @@ async def process(
reply_comp.path = record_path
component.chain[idx] = reply_comp
except Exception as e:
failed_record_ids.add(id(reply_comp))
logger.warning(
f"Voice processing in reply chain failed: {e}"
)
Expand Down Expand Up @@ -141,7 +145,10 @@ async def _stt_record(record_comp: Record, is_reply: bool = False):

message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record):
if (
isinstance(component, Record)
and id(component) not in failed_record_ids
):
plain_comp = await _stt_record(component)
if plain_comp:
message_chain[idx] = plain_comp
Expand All @@ -152,7 +159,10 @@ async def _stt_record(record_comp: Record, is_reply: bool = False):
for component in event.get_messages():
if isinstance(component, Reply) and component.chain:
for idx, reply_comp in enumerate(component.chain):
if isinstance(reply_comp, Record):
if (
isinstance(reply_comp, Record)
and id(reply_comp) not in failed_record_ids
):
plain_comp = await _stt_record(reply_comp, is_reply=True)
if plain_comp:
component.chain[idx] = plain_comp
Expand Down
83 changes: 83 additions & 0 deletions tests/test_preprocess_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

from astrbot.core.message.components import Plain, Record, Reply
from astrbot.core.pipeline.preprocess_stage.stage import PreProcessStage


def _make_stage(stt_provider: AsyncMock) -> PreProcessStage:
    """Build a PreProcessStage with STT enabled and wired to *stt_provider*.

    The stage attributes are assigned directly: ``config`` and
    ``platform_settings`` are empty dicts and ``stt_settings`` has ``enable``
    set to True. ``plugin_manager`` is replaced by a plain namespace whose
    ``context.get_using_stt_provider`` ignores its argument and always returns
    the given provider.

    Args:
        stt_provider: The mock the stage should receive as its STT provider.

    Returns:
        The configured stage instance.
    """
    stage = PreProcessStage()
    stage.config = {}
    stage.platform_settings = {}
    stage.stt_settings = {"enable": True}
    # Only the attribute path stubbed here is provided; any other access on
    # plugin_manager raises AttributeError.
    stage.plugin_manager = SimpleNamespace(
        context=SimpleNamespace(get_using_stt_provider=lambda _: stt_provider)
    )
    return stage


def _make_event(messages: list) -> MagicMock:
event = MagicMock()
event.get_platform_name.return_value = "test"
event.is_at_or_wake_command = False
event.get_messages.return_value = messages
event.unified_msg_origin = "test:friend:test"
event.message_str = ""
event.message_obj.message_str = ""
return event


@pytest.mark.asyncio
async def test_failed_audio_conversion_is_not_sent_to_stt(monkeypatch):
    """A Record whose WAV conversion fails is left in place and skipped by STT.

    The chain holds one record whose conversion raises and one that converts
    cleanly. Only the second may reach the STT provider and be replaced by the
    transcribed Plain component.
    """
    failed_record = Record(file="failed.amr")
    valid_record = Record(file="valid.wav")
    messages = [failed_record, valid_record]
    stt_provider = AsyncMock()
    stt_provider.get_text.return_value = "transcribed"

    async def convert_to_file_path(record):
        # Stand-in for Record.convert_to_file_path: hand back the raw file name.
        return record.file

    async def convert_to_wav(path):
        # Stand-in for ensure_wav: fail for the .amr input, pass others through.
        if path == "failed.amr":
            raise RuntimeError("ffmpeg not found")
        return path

    monkeypatch.setattr(Record, "convert_to_file_path", convert_to_file_path)
    monkeypatch.setattr(
        "astrbot.core.pipeline.preprocess_stage.stage.ensure_wav",
        convert_to_wav,
    )

    await _make_stage(stt_provider).process(_make_event(messages))

    # The failed record stays untouched; the valid one becomes transcribed text.
    assert messages[0] is failed_record
    assert isinstance(messages[1], Plain)
    assert messages[1].text == "transcribed"
    stt_provider.get_text.assert_awaited_once_with(audio_url="valid.wav")


@pytest.mark.asyncio
async def test_failed_reply_audio_conversion_is_not_sent_to_stt(monkeypatch):
    """A Record inside a Reply chain that fails conversion is skipped by STT.

    The only record lives in the quoted reply and its conversion always raises,
    so the reply chain must be unchanged and the STT provider never awaited.
    """
    failed_record = Record(file="failed.amr")
    reply = Reply(id="reply-id", chain=[failed_record])
    stt_provider = AsyncMock()

    async def convert_to_file_path(record):
        # Stand-in for Record.convert_to_file_path: hand back the raw file name.
        return record.file

    async def convert_to_wav(_):
        # Stand-in for ensure_wav: fail unconditionally.
        raise RuntimeError("ffmpeg not found")

    monkeypatch.setattr(Record, "convert_to_file_path", convert_to_file_path)
    monkeypatch.setattr(
        "astrbot.core.pipeline.preprocess_stage.stage.ensure_wav",
        convert_to_wav,
    )

    await _make_stage(stt_provider).process(_make_event([reply]))

    assert reply.chain == [failed_record]
    stt_provider.get_text.assert_not_awaited()
Loading