From 9fa22171f841c00f7e404f13d253f5da5b9f01aa Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 4 Jun 2026 18:41:37 -0400 Subject: [PATCH 01/15] PYTHON-5676 Add command_runner.run_command; route network.command() through it Introduce pymongo/asynchronous/command_runner.py (auto-generates the sync mirror), the single async code path for command execution. run_command owns the full skeleton: STARTED/SUCCEEDED/FAILED command logging AND APM publishing together, the network round trip, $clusterTime gossip, _process_response, _check_command_response, failure conversion, and auto-encryption decryption. Route network.command() through it, removing the duplicated logging/APM/send/ receive/decrypt block. Behavior is preserved byte-for-byte (logging and APM event documents unchanged); no per-command object is allocated, so the hot path is unchanged. --- pymongo/asynchronous/command_runner.py | 254 +++++++++++++++++++++++++ pymongo/asynchronous/network.py | 175 +++-------------- pymongo/synchronous/command_runner.py | 254 +++++++++++++++++++++++++ pymongo/synchronous/network.py | 176 +++-------------- 4 files changed, 557 insertions(+), 302 deletions(-) create mode 100644 pymongo/asynchronous/command_runner.py create mode 100644 pymongo/synchronous/command_runner.py diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py new file mode 100644 index 0000000000..d70d926d84 --- /dev/null +++ b/pymongo/asynchronous/command_runner.py @@ -0,0 +1,254 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""The single code path for executing a command over a connection. + +Every database operation -- standard commands, cursor ``find``/``getMore`` +operations, and (collection-level and client-level) bulk writes -- runs its +network round trip through :func:`run_command`. The function owns the entire +shared skeleton: command logging, APM event publishing, ``send``/``receive``, +``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, +failure conversion, and auto-encryption decryption. Callers supply only the +parts that vary (the encoded message and a handful of transport/output hooks). +""" +from __future__ import annotations + +import datetime +import logging +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + MutableMapping, + Optional, + Sequence, + Union, + cast, +) + +from bson import _decode_all_selective +from pymongo import helpers_shared +from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception +from pymongo.network_layer import async_receive_message, async_sendall + +if TYPE_CHECKING: + from bson import CodecOptions + from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.mongo_client import AsyncMongoClient + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.message import _OpMsg, _OpReply + from pymongo.monitoring import _EventListeners + from pymongo.typings import _Address, _DocumentOut, _DocumentType + +_IS_SYNC = False + + +async def run_command( + conn: AsyncConnection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + session: Optional[AsyncClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: 
datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + unacknowledged: bool = False, + speculative_hello: bool = False, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. + + This is the single code path for command execution. It publishes the + ``STARTED``/``SUCCEEDED``/``FAILED`` command log and APM events, performs + the network round trip, gossips ``$clusterTime``, runs + ``client._process_response`` and ``_check_command_response``, and decrypts + the reply when auto-encryption is enabled. + + :param conn: The AsyncConnection to send on. + :param cmd: The command document, used for the ``STARTED`` log event. + :param dbname: The database the command runs against. + :param request_id: The request id of the encoded message. + :param msg: The encoded OP_MSG bytes to send. + :param client: The AsyncMongoClient, for ``$clusterTime`` gossip, logging, + and decryption. ``None`` disables those steps (e.g. during handshake). + :param session: The session to update from the response. + :param listeners: The event listeners, or ``None`` to disable APM. + :param address: The (host, port) of ``conn`` for APM events. + :param start: The ``datetime`` the operation began, for duration timing. + :param codec_options: The CodecOptions used to decode the reply. + :param user_fields: Response fields decoded with the codec's TypeDecoders. + :param orig: The command document published in the ``STARTED`` APM event; + defaults to ``cmd`` (differs only when the wire command was mutated, + e.g. with a read preference or after encryption). + :param op_id: The APM operation id; defaults to ``request_id``. 
+ :param check: Raise OperationFailure on a command error. + :param allowable_errors: Errors to ignore when ``check`` is True. + :param parse_write_concern_error: Parse the ``writeConcernError`` field. + :param unacknowledged: True for an unacknowledged write: send only and fake + an ``{"ok": 1}`` reply. + :param speculative_hello: True if the command carried speculative auth, for + APM redaction. + """ + name = next(iter(cmd)) + if orig is None: + orig = cmd + publish = listeners is not None and listeners.enabled_for_commands + + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig, + dbname, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + ) + + try: + await async_sendall(conn.conn.get_conn, msg) + if unacknowledged: + # Unacknowledged, fake a successful command response. 
+ reply = None + docs: list[dict[str, Any]] = [{"ok": 1}] + else: + reply = await async_receive_message(conn, request_id) + conn.more_to_come = reply.more_to_come + docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + await client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + + duration = datetime.datetime.now() - start + response_doc = docs[0] + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + 
durationMS=duration, + reply=response_doc, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + speculative_authenticate="speculativeAuthenticate" in orig, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + response_doc, + name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + if client and client._encrypter and reply: + decrypted = await client._encrypter.decrypt(reply.raw_command_response()) + docs = cast( + "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) + ) + + return docs, reply, duration diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index b7de9253f6..d37fb2acff 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -16,7 +16,6 @@ from __future__ import annotations import datetime -import logging from typing import ( TYPE_CHECKING, Any, @@ -25,23 +24,13 @@ Optional, Sequence, Union, - cast, ) -from bson import _decode_all_selective -from pymongo import _csot, helpers_shared, message +from pymongo import _csot, message +from pymongo.asynchronous.command_runner import run_command from pymongo.compression_support import _NO_COMPRESSION -from pymongo.errors import ( - NotPrimaryError, - OperationFailure, -) -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate -from pymongo.network_layer import ( - async_receive_message, - async_sendall, -) if TYPE_CHECKING: from bson import CodecOptions @@ -52,7 +41,7 @@ from pymongo.monitoring import _EventListeners 
from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode - from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.typings import _Address, _CollationIn, _DocumentType from pymongo.write_concern import WriteConcern _IS_SYNC = False @@ -148,140 +137,24 @@ async def command( if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=spec, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - await async_sendall(conn.conn.get_conn, msg) - if unacknowledged: - # Unacknowledged, fake a successful command response. 
- reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = await async_receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) - - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - await client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - 
durationMS=duration, - reply=response_doc, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - speculative_authenticate="speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - if client and client._encrypter and reply: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) - - return response_doc # type: ignore[return-value] + docs, _, _ = await run_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + unacknowledged=unacknowledged, + speculative_hello=speculative_hello, + ) + return docs[0] # type: ignore[return-value] diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py new file mode 100644 index 0000000000..2e272a5eab --- /dev/null +++ b/pymongo/synchronous/command_runner.py @@ -0,0 +1,254 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The single code path for executing a command over a connection. + +Every database operation -- standard commands, cursor ``find``/``getMore`` +operations, and (collection-level and client-level) bulk writes -- runs its +network round trip through :func:`run_command`. The function owns the entire +shared skeleton: command logging, APM event publishing, ``send``/``receive``, +``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, +failure conversion, and auto-encryption decryption. Callers supply only the +parts that vary (the encoded message and a handful of transport/output hooks). 
+""" +from __future__ import annotations + +import datetime +import logging +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + MutableMapping, + Optional, + Sequence, + Union, + cast, +) + +from bson import _decode_all_selective +from pymongo import helpers_shared +from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception +from pymongo.network_layer import receive_message, sendall + +if TYPE_CHECKING: + from bson import CodecOptions + from pymongo.message import _OpMsg, _OpReply + from pymongo.monitoring import _EventListeners + from pymongo.synchronous.client_session import ClientSession + from pymongo.synchronous.mongo_client import MongoClient + from pymongo.synchronous.pool import Connection + from pymongo.typings import _Address, _DocumentOut, _DocumentType + +_IS_SYNC = True + + +def run_command( + conn: Connection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + session: Optional[ClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + unacknowledged: bool = False, + speculative_hello: bool = False, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. + + This is the single code path for command execution. 
It publishes the + ``STARTED``/``SUCCEEDED``/``FAILED`` command log and APM events, performs + the network round trip, gossips ``$clusterTime``, runs + ``client._process_response`` and ``_check_command_response``, and decrypts + the reply when auto-encryption is enabled. + + :param conn: The Connection to send on. + :param cmd: The command document, used for the ``STARTED`` log event. + :param dbname: The database the command runs against. + :param request_id: The request id of the encoded message. + :param msg: The encoded OP_MSG bytes to send. + :param client: The MongoClient, for ``$clusterTime`` gossip, logging, + and decryption. ``None`` disables those steps (e.g. during handshake). + :param session: The session to update from the response. + :param listeners: The event listeners, or ``None`` to disable APM. + :param address: The (host, port) of ``conn`` for APM events. + :param start: The ``datetime`` the operation began, for duration timing. + :param codec_options: The CodecOptions used to decode the reply. + :param user_fields: Response fields decoded with the codec's TypeDecoders. + :param orig: The command document published in the ``STARTED`` APM event; + defaults to ``cmd`` (differs only when the wire command was mutated, + e.g. with a read preference or after encryption). + :param op_id: The APM operation id; defaults to ``request_id``. + :param check: Raise OperationFailure on a command error. + :param allowable_errors: Errors to ignore when ``check`` is True. + :param parse_write_concern_error: Parse the ``writeConcernError`` field. + :param unacknowledged: True for an unacknowledged write: send only and fake + an ``{"ok": 1}`` reply. + :param speculative_hello: True if the command carried speculative auth, for + APM redaction. 
+ """ + name = next(iter(cmd)) + if orig is None: + orig = cmd + publish = listeners is not None and listeners.enabled_for_commands + + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig, + dbname, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + ) + + try: + sendall(conn.conn.get_conn, msg) + if unacknowledged: + # Unacknowledged, fake a successful command response. + reply = None + docs: list[dict[str, Any]] = [{"ok": 1}] + else: + reply = receive_message(conn, request_id) + conn.more_to_come = reply.more_to_come + docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + 
clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + + duration = datetime.datetime.now() - start + response_doc = docs[0] + if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=response_doc, + commandName=name, + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + speculative_authenticate="speculativeAuthenticate" in orig, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + response_doc, + name, + request_id, + address, + conn.server_connection_id, + op_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + if client and client._encrypter and reply: + decrypted = client._encrypter.decrypt(reply.raw_command_response()) + docs = cast( + "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) + ) + + return docs, reply, duration diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index b7516a523f..07b285e59e 100644 
--- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -16,7 +16,6 @@ from __future__ import annotations import datetime -import logging from typing import ( TYPE_CHECKING, Any, @@ -25,23 +24,13 @@ Optional, Sequence, Union, - cast, ) -from bson import _decode_all_selective -from pymongo import _csot, helpers_shared, message +from pymongo import _csot, message from pymongo.compression_support import _NO_COMPRESSION -from pymongo.errors import ( - NotPrimaryError, - OperationFailure, -) -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate -from pymongo.network_layer import ( - receive_message, - sendall, -) +from pymongo.synchronous.command_runner import run_command if TYPE_CHECKING: from bson import CodecOptions @@ -52,7 +41,7 @@ from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection - from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.typings import _Address, _CollationIn, _DocumentType from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -148,140 +137,24 @@ def command( if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=spec, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is
not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - sendall(conn.conn.get_conn, msg) - if unacknowledged: - # Unacknowledged, fake a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) - - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, -
database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=response_doc, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - speculative_authenticate="speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - if client and client._encrypter and reply: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) - - return response_doc # type: ignore[return-value] + docs, _, _ = run_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + unacknowledged=unacknowledged, + speculative_hello=speculative_hello, + ) + return docs[0] # type: ignore[return-value] From 41f35e1635cd02c57a7fc97798707baa4a6c53c4 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 4 Jun 2026 18:55:19 -0400 Subject: [PATCH 02/15] PYTHON-5676 Route
Server.run_operation() through run_command Extend run_command with the cursor transport (conn.send_message/receive_message, exhaust more_to_come receive-only) and output hooks (unpack_res, cursor_id, is_command_response for legacy OP_QUERY, pool_opts, command_name, ensure_db for $db gossip, and a reply_doc_builder for the find/getMore/explain APM reply format). run_operation now builds the message, supplies the reply-doc builder, and keeps the Response/PinnedResponse wrapping; everything between is the shared run_command path. The legacy OP_QUERY response shaping is preserved (is_command_response=use_cmd), not deleted -- that dead-code cleanup stays out of this consolidation. Behavior (logging, APM events, exhaust/pinning, decryption) is unchanged. --- pymongo/asynchronous/command_runner.py | 136 +++++++++++++---- pymongo/asynchronous/server.py | 189 ++++++----------------- pymongo/synchronous/command_runner.py | 140 ++++++++++++----- pymongo/synchronous/server.py | 199 +++++++------------------ 4 files changed, 311 insertions(+), 353 deletions(-) diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index d70d926d84..53a2ff803a 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -14,13 +14,13 @@ """The single code path for executing a command over a connection. -Every database operation -- standard commands, cursor ``find``/``getMore`` -operations, and (collection-level and client-level) bulk writes -- runs its -network round trip through :func:`run_command`. The function owns the entire -shared skeleton: command logging, APM event publishing, ``send``/``receive``, -``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, -failure conversion, and auto-encryption decryption. Callers supply only the -parts that vary (the encoded message and a handful of transport/output hooks). 
+Every database operation -- standard commands and cursor ``find``/``getMore`` +operations -- runs its network round trip through :func:`run_command`. The +function owns the entire shared skeleton: command logging, APM event +publishing, ``send``/``receive``, ``$clusterTime`` gossip, +``_process_response``, ``_check_command_response``, failure conversion, and +auto-encryption decryption. Callers supply only the parts that vary (the +encoded message and a handful of transport/output hooks). """ from __future__ import annotations @@ -29,6 +29,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Mapping, MutableMapping, Optional, @@ -51,6 +52,7 @@ from pymongo.asynchronous.pool import AsyncConnection from pymongo.message import _OpMsg, _OpReply from pymongo.monitoring import _EventListeners + from pymongo.pool_options import PoolOptions from pymongo.typings import _Address, _DocumentOut, _DocumentType _IS_SYNC = False @@ -72,11 +74,24 @@ async def run_command( user_fields: Optional[Mapping[str, Any]] = None, orig: Optional[MutableMapping[str, Any]] = None, op_id: Optional[int] = None, + command_name: Optional[str] = None, check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, parse_write_concern_error: bool = False, + pool_opts: Optional[PoolOptions] = None, unacknowledged: bool = False, speculative_hello: bool = False, + ensure_db: bool = False, + use_conn_transport: bool = False, + max_doc_size: int = 0, + more_to_come: bool = False, + set_conn_more_to_come: bool = True, + is_command_response: bool = True, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, + reply_doc_builder: Optional[ + Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + ] = None, ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. 
@@ -87,10 +102,11 @@ async def run_command( the reply when auto-encryption is enabled. :param conn: The AsyncConnection to send on. - :param cmd: The command document, used for the ``STARTED`` log event. + :param cmd: The command document, used for the ``STARTED`` log/APM event. :param dbname: The database the command runs against. - :param request_id: The request id of the encoded message. - :param msg: The encoded OP_MSG bytes to send. + :param request_id: The request id of the encoded message (``0`` when + ``more_to_come`` and no message is sent). + :param msg: The encoded bytes to send (ignored when ``more_to_come``). :param client: The AsyncMongoClient, for ``$clusterTime`` gossip, logging, and decryption. ``None`` disables those steps (e.g. during handshake). :param session: The session to update from the response. @@ -103,15 +119,40 @@ async def run_command( defaults to ``cmd`` (differs only when the wire command was mutated, e.g. with a read preference or after encryption). :param op_id: The APM operation id; defaults to ``request_id``. + :param command_name: The command name for the ``SUCCEEDED``/``FAILED`` APM + events; defaults to the first key of ``cmd``. :param check: Raise OperationFailure on a command error. :param allowable_errors: Errors to ignore when ``check`` is True. :param parse_write_concern_error: Parse the ``writeConcernError`` field. + :param pool_opts: PoolOptions forwarded to ``_check_command_response`` (the + cursor path uses this in place of ``allowable_errors``). :param unacknowledged: True for an unacknowledged write: send only and fake an ``{"ok": 1}`` reply. :param speculative_hello: True if the command carried speculative auth, for APM redaction. + :param ensure_db: Add ``$db`` to the published command if missing (cursor + path), after the ``STARTED`` log has been emitted. 
+ :param use_conn_transport: Send/receive via ``conn.send_message`` / + ``conn.receive_message`` (cursor path) instead of the raw + ``async_sendall`` / ``async_receive_message`` (network path). + :param max_doc_size: The largest document size, for ``conn.send_message``. + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the + network/streaming-monitor path); the cursor path manages exhaust + separately and must leave ``conn.more_to_come`` untouched. + :param is_command_response: True if the reply is an OP_MSG command response + (``_process_response``/``_check_command_response``/decryption apply); + False for a legacy OP_QUERY cursor response. + :param unpack_res: A callable decoding the wire response (cursor path); when + ``None`` the reply's own ``unpack_response`` is used. + :param cursor_id: The cursor id passed to ``unpack_res``. + :param reply_doc_builder: Builds the reply document published in the + ``SUCCEEDED`` event from ``(docs, reply)`` (cursor find/getMore format); + when ``None`` the first decoded document is published. 
""" name = next(iter(cmd)) + if command_name is None: + command_name = name if orig is None: orig = cmd publish = listeners is not None and listeners.enabled_for_commands @@ -135,6 +176,8 @@ async def run_command( if publish: assert listeners is not None assert address is not None + if ensure_db and "$db" not in orig: + orig["$db"] = dbname listeners.publish_command_start( orig, dbname, @@ -145,30 +188,53 @@ async def run_command( service_id=conn.service_id, ) + reply: Optional[Union[_OpReply, _OpMsg]] try: - await async_sendall(conn.conn.get_conn, msg) - if unacknowledged: + if more_to_come: + reply = await conn.receive_message(None) + elif use_conn_transport: + if session is not None and session._starting_transaction: + session._transaction.set_in_progress() + await conn.send_message(msg, max_doc_size) + reply = await conn.receive_message(request_id) + elif unacknowledged: + await async_sendall(conn.conn.get_conn, msg) # Unacknowledged, fake a successful command response. reply = None docs: list[dict[str, Any]] = [{"ok": 1}] else: + await async_sendall(conn.conn.get_conn, msg) reply = await async_receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) - response_doc = docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - await client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, + + if reply is not None: + if set_conn_more_to_come: + conn.more_to_come = reply.more_to_come + if unpack_res is not None: + docs = unpack_res( + reply, + cursor_id, + codec_options, + legacy_response=not is_command_response, + user_fields=user_fields, ) + else: + docs = reply.unpack_response(codec_options=codec_options, 
user_fields=user_fields) + if is_command_response: + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + await client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + pool_opts=pool_opts, + ) except Exception as exc: duration = datetime.datetime.now() - start if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -199,7 +265,7 @@ async def run_command( listeners.publish_command_failure( duration, failure, - name, + command_name, request_id, address, conn.server_connection_id, @@ -210,14 +276,18 @@ async def run_command( raise duration = datetime.datetime.now() - start - response_doc = docs[0] + published_reply: _DocumentOut + if reply_doc_builder is not None: + published_reply = reply_doc_builder(docs, reply) + else: + published_reply = docs[0] if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, message=_CommandStatusMessage.SUCCEEDED, clientId=client._topology_settings._topology_id, durationMS=duration, - reply=response_doc, + reply=published_reply, commandName=name, databaseName=dbname, requestId=request_id, @@ -234,8 +304,8 @@ async def run_command( assert address is not None listeners.publish_command_success( duration, - response_doc, - name, + published_reply, + command_name, request_id, address, conn.server_connection_id, @@ -245,7 +315,7 @@ async def run_command( database_name=dbname, ) - if client and client._encrypter and reply: + if client and client._encrypter and reply and is_command_response: decrypted = await client._encrypter.decrypt(reply.raw_command_response()) docs = cast( "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) diff --git a/pymongo/asynchronous/server.py 
b/pymongo/asynchronous/server.py index 39d422d038..b18cf56c52 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -26,18 +26,14 @@ Union, ) -from bson import _decode_all_selective +from pymongo.asynchronous.command_runner import run_command from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers_shared import _check_command_response from pymongo.logger import ( - _COMMAND_LOGGER, _SDAM_LOGGER, - _CommandStatusMessage, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.message import _GetMore, _OpMsg, _OpReply, _Query from pymongo.response import PinnedResponse, Response if TYPE_CHECKING: @@ -158,7 +154,6 @@ async def run_operation( :param client: An AsyncMongoClient instance. """ assert listeners is not None - publish = listeners.enabled_for_commands start = datetime.now() use_cmd = operation.use_command(conn) @@ -166,144 +161,58 @@ async def run_operation( cmd, dbn = await self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 + data = b"" + max_doc_size = 0 else: message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - - if publish: - if "$db" not in cmd: - cmd["$db"] = dbn - assert listeners is not None - listeners.publish_command_start( - cmd, - dbn, - request_id, - conn.address, - conn.server_connection_id, - 
service_id=conn.service_id, - ) + user_fields = _CURSOR_DOC_FIELDS if use_cmd else None - try: - if more_to_come: - reply = await conn.receive_message(None) - else: - if operation.session is not None and operation.session._starting_transaction: - operation.session._transaction.set_in_progress() - await conn.send_message(data, max_doc_size) - reply = await conn.receive_message(request_id) - - # Unpack and check for command errors. - if use_cmd: - user_fields = _CURSOR_DOC_FIELDS - legacy_response = False - else: - user_fields = None - legacy_response = True - docs = unpack_res( - reply, - operation.cursor_id, - operation.codec_options, - legacy_response=legacy_response, - user_fields=user_fields, - ) + def _build_reply_doc( + docs: list[dict[str, Any]], reply: Optional[Union[_OpReply, _OpMsg]] + ) -> _DocumentOut: + # Must publish in find / getMore / explain command response format. if use_cmd: - first = docs[0] - await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] - _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] - except Exception as exc: - duration = datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] + return docs[0] + elif operation.name == "explain": + return docs[0] if docs else {} + res: dict[str, Any] = { + "cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, # type: ignore[union-attr] + "ok": 1, + } + if operation.name == "find": + res["cursor"]["firstBatch"] = docs else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - 
serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - listeners.publish_command_failure( - duration, - failure, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - raise - duration = datetime.now() - start - # Must publish in find / getMore / explain command response - # format. - res = docs[0] - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=res, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - listeners.publish_command_success( - duration, - res, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - - # Decrypt response. 
- client = operation.client # type: ignore[assignment] - if client and client._encrypter: - if use_cmd: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) + res["cursor"]["nextBatch"] = docs + return res + + docs, reply, duration = await run_command( + conn, + cmd, + dbn, + request_id, + data, + client=client, + session=operation.session, # type: ignore[arg-type] + listeners=listeners, + address=conn.address, + start=start, + codec_options=operation.codec_options, + user_fields=user_fields, + command_name=operation.name, + pool_opts=conn.opts, + ensure_db=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + more_to_come=bool(more_to_come), + set_conn_more_to_come=False, + is_command_response=use_cmd, + unpack_res=unpack_res, + cursor_id=operation.cursor_id, + reply_doc_builder=_build_reply_doc, + ) + assert reply is not None response: Response @@ -325,7 +234,7 @@ async def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] more_to_come=more_to_come, ) else: @@ -335,7 +244,7 @@ async def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] ) return response diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 2e272a5eab..4bd93400cf 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -14,13 +14,13 @@ """The single code path for executing a command over a connection. -Every database operation -- standard commands, cursor ``find``/``getMore`` -operations, and (collection-level and client-level) bulk writes -- runs its -network round trip through :func:`run_command`. 
The function owns the entire -shared skeleton: command logging, APM event publishing, ``send``/``receive``, -``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, -failure conversion, and auto-encryption decryption. Callers supply only the -parts that vary (the encoded message and a handful of transport/output hooks). +Every database operation -- standard commands and cursor ``find``/``getMore`` +operations -- runs its network round trip through :func:`run_command`. The +function owns the entire shared skeleton: command logging, APM event +publishing, ``send``/``receive``, ``$clusterTime`` gossip, +``_process_response``, ``_check_command_response``, failure conversion, and +auto-encryption decryption. Callers supply only the parts that vary (the +encoded message and a handful of transport/output hooks). """ from __future__ import annotations @@ -29,6 +29,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Mapping, MutableMapping, Optional, @@ -46,11 +47,12 @@ if TYPE_CHECKING: from bson import CodecOptions - from pymongo.message import _OpMsg, _OpReply - from pymongo.monitoring import _EventListeners from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection + from pymongo.message import _OpMsg, _OpReply + from pymongo.monitoring import _EventListeners + from pymongo.pool_options import PoolOptions from pymongo.typings import _Address, _DocumentOut, _DocumentType _IS_SYNC = True @@ -72,11 +74,24 @@ def run_command( user_fields: Optional[Mapping[str, Any]] = None, orig: Optional[MutableMapping[str, Any]] = None, op_id: Optional[int] = None, + command_name: Optional[str] = None, check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, parse_write_concern_error: bool = False, + pool_opts: Optional[PoolOptions] = None, unacknowledged: bool = False, speculative_hello: bool = False, + ensure_db: bool = False, + 
use_conn_transport: bool = False, + max_doc_size: int = 0, + more_to_come: bool = False, + set_conn_more_to_come: bool = True, + is_command_response: bool = True, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, + reply_doc_builder: Optional[ + Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + ] = None, ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. @@ -87,10 +102,11 @@ def run_command( the reply when auto-encryption is enabled. :param conn: The Connection to send on. - :param cmd: The command document, used for the ``STARTED`` log event. + :param cmd: The command document, used for the ``STARTED`` log/APM event. :param dbname: The database the command runs against. - :param request_id: The request id of the encoded message. - :param msg: The encoded OP_MSG bytes to send. + :param request_id: The request id of the encoded message (``0`` when + ``more_to_come`` and no message is sent). + :param msg: The encoded bytes to send (ignored when ``more_to_come``). :param client: The MongoClient, for ``$clusterTime`` gossip, logging, and decryption. ``None`` disables those steps (e.g. during handshake). :param session: The session to update from the response. @@ -103,15 +119,40 @@ def run_command( defaults to ``cmd`` (differs only when the wire command was mutated, e.g. with a read preference or after encryption). :param op_id: The APM operation id; defaults to ``request_id``. + :param command_name: The command name for the ``SUCCEEDED``/``FAILED`` APM + events; defaults to the first key of ``cmd``. :param check: Raise OperationFailure on a command error. :param allowable_errors: Errors to ignore when ``check`` is True. :param parse_write_concern_error: Parse the ``writeConcernError`` field. 
+ :param pool_opts: PoolOptions forwarded to ``_check_command_response`` (the + cursor path uses this in place of ``allowable_errors``). :param unacknowledged: True for an unacknowledged write: send only and fake an ``{"ok": 1}`` reply. :param speculative_hello: True if the command carried speculative auth, for APM redaction. + :param ensure_db: Add ``$db`` to the published command if missing (cursor + path), after the ``STARTED`` log has been emitted. + :param use_conn_transport: Send/receive via ``conn.send_message`` / + ``conn.receive_message`` (cursor path) instead of the raw + ``sendall`` / ``receive_message`` (network path). + :param max_doc_size: The largest document size, for ``conn.send_message``. + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the + network/streaming-monitor path); the cursor path manages exhaust + separately and must leave ``conn.more_to_come`` untouched. + :param is_command_response: True if the reply is an OP_MSG command response + (``_process_response``/``_check_command_response``/decryption apply); + False for a legacy OP_QUERY cursor response. + :param unpack_res: A callable decoding the wire response (cursor path); when + ``None`` the reply's own ``unpack_response`` is used. + :param cursor_id: The cursor id passed to ``unpack_res``. + :param reply_doc_builder: Builds the reply document published in the + ``SUCCEEDED`` event from ``(docs, reply)`` (cursor find/getMore format); + when ``None`` the first decoded document is published. 
""" name = next(iter(cmd)) + if command_name is None: + command_name = name if orig is None: orig = cmd publish = listeners is not None and listeners.enabled_for_commands @@ -135,6 +176,8 @@ def run_command( if publish: assert listeners is not None assert address is not None + if ensure_db and "$db" not in orig: + orig["$db"] = dbname listeners.publish_command_start( orig, dbname, @@ -145,30 +188,53 @@ def run_command( service_id=conn.service_id, ) + reply: Optional[Union[_OpReply, _OpMsg]] try: - sendall(conn.conn.get_conn, msg) - if unacknowledged: + if more_to_come: + reply = conn.receive_message(None) + elif use_conn_transport: + if session is not None and session._starting_transaction: + session._transaction.set_in_progress() + conn.send_message(msg, max_doc_size) + reply = conn.receive_message(request_id) + elif unacknowledged: + sendall(conn.conn.get_conn, msg) # Unacknowledged, fake a successful command response. reply = None docs: list[dict[str, Any]] = [{"ok": 1}] else: + sendall(conn.conn.get_conn, msg) reply = receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) - response_doc = docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, + + if reply is not None: + if set_conn_more_to_come: + conn.more_to_come = reply.more_to_come + if unpack_res is not None: + docs = unpack_res( + reply, + cursor_id, + codec_options, + legacy_response=not is_command_response, + user_fields=user_fields, ) + else: + docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + if is_command_response: + response_doc = docs[0] + if not 
conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + pool_opts=pool_opts, + ) except Exception as exc: duration = datetime.datetime.now() - start if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -199,7 +265,7 @@ def run_command( listeners.publish_command_failure( duration, failure, - name, + command_name, request_id, address, conn.server_connection_id, @@ -210,14 +276,18 @@ def run_command( raise duration = datetime.datetime.now() - start - response_doc = docs[0] + published_reply: _DocumentOut + if reply_doc_builder is not None: + published_reply = reply_doc_builder(docs, reply) + else: + published_reply = docs[0] if client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, message=_CommandStatusMessage.SUCCEEDED, clientId=client._topology_settings._topology_id, durationMS=duration, - reply=response_doc, + reply=published_reply, commandName=name, databaseName=dbname, requestId=request_id, @@ -234,8 +304,8 @@ def run_command( assert address is not None listeners.publish_command_success( duration, - response_doc, - name, + published_reply, + command_name, request_id, address, conn.server_connection_id, @@ -245,7 +315,7 @@ def run_command( database_name=dbname, ) - if client and client._encrypter and reply: + if client and client._encrypter and reply and is_command_response: decrypted = client._encrypter.decrypt(reply.raw_command_response()) docs = cast( "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 5297a9e297..47ae00af80 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ 
-20,37 +20,33 @@ from typing import ( TYPE_CHECKING, Any, - Callable, ContextManager, + Callable, Optional, Union, ) -from bson import _decode_all_selective -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers_shared import _check_command_response +from pymongo.synchronous.command_runner import run_command +from pymongo.synchronous.helpers import _handle_reauth from pymongo.logger import ( - _COMMAND_LOGGER, _SDAM_LOGGER, - _CommandStatusMessage, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.message import _GetMore, _OpMsg, _OpReply, _Query from pymongo.response import PinnedResponse, Response -from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: from queue import Queue from weakref import ReferenceType from bson.objectid import ObjectId - from pymongo.monitoring import _EventListeners - from pymongo.read_preferences import _ServerMode - from pymongo.server_description import ServerDescription from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler from pymongo.synchronous.monitor import Monitor from pymongo.synchronous.pool import Connection, Pool + from pymongo.monitoring import _EventListeners + from pymongo.read_preferences import _ServerMode + from pymongo.server_description import ServerDescription from pymongo.typings import _DocumentOut _IS_SYNC = True @@ -158,7 +154,6 @@ def run_operation( :param client: A MongoClient instance. 
""" assert listeners is not None - publish = listeners.enabled_for_commands start = datetime.now() use_cmd = operation.use_command(conn) @@ -166,144 +161,58 @@ def run_operation( cmd, dbn = self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 + data = b"" + max_doc_size = 0 else: message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - - if publish: - if "$db" not in cmd: - cmd["$db"] = dbn - assert listeners is not None - listeners.publish_command_start( - cmd, - dbn, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = conn.receive_message(None) - else: - if operation.session is not None and operation.session._starting_transaction: - operation.session._transaction.set_in_progress() - conn.send_message(data, max_doc_size) - reply = conn.receive_message(request_id) + user_fields = _CURSOR_DOC_FIELDS if use_cmd else None - # Unpack and check for command errors. + def _build_reply_doc( + docs: list[dict[str, Any]], reply: Optional[Union[_OpReply, _OpMsg]] + ) -> _DocumentOut: + # Must publish in find / getMore / explain command response format. 
if use_cmd: - user_fields = _CURSOR_DOC_FIELDS - legacy_response = False + return docs[0] + elif operation.name == "explain": + return docs[0] if docs else {} + res: dict[str, Any] = { + "cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, # type: ignore[union-attr] + "ok": 1, + } + if operation.name == "find": + res["cursor"]["firstBatch"] = docs else: - user_fields = None - legacy_response = True - docs = unpack_res( - reply, - operation.cursor_id, - operation.codec_options, - legacy_response=legacy_response, - user_fields=user_fields, - ) - if use_cmd: - first = docs[0] - operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] - _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] - except Exception as exc: - duration = datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - listeners.publish_command_failure( - duration, - failure, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - raise - duration = datetime.now() - start - # Must publish in find / getMore / explain command response - # format. 
- res = docs[0] - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=res, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - listeners.publish_command_success( - duration, - res, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - - # Decrypt response. - client = operation.client # type: ignore[assignment] - if client and client._encrypter: - if use_cmd: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) + res["cursor"]["nextBatch"] = docs + return res + + docs, reply, duration = run_command( + conn, + cmd, + dbn, + request_id, + data, + client=client, + session=operation.session, # type: ignore[arg-type] + listeners=listeners, + address=conn.address, + start=start, + codec_options=operation.codec_options, + user_fields=user_fields, + command_name=operation.name, + pool_opts=conn.opts, + ensure_db=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + more_to_come=bool(more_to_come), + set_conn_more_to_come=False, + is_command_response=use_cmd, + unpack_res=unpack_res, + cursor_id=operation.cursor_id, + reply_doc_builder=_build_reply_doc, + ) + assert reply is not None response: Response @@ -325,7 +234,7 @@ def run_operation( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] more_to_come=more_to_come, ) else: @@ -335,7 +244,7 @@ def run_operation( duration=duration, request_id=request_id, 
from_command=use_cmd, - docs=docs, + docs=docs, # type: ignore[arg-type] ) return response From 317cdbd01b28ce7c3a4b1102fd08237a525a12ef Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 4 Jun 2026 19:05:35 -0400 Subject: [PATCH 03/15] PYTHON-5676 Route collection bulk writes through run_command Add process_response and decrypt_reply flags plus the conn.unack_write transport to run_command, then route bulk.write_command (acknowledged) and bulk.unack_write through it. The bulk paths pass process_response=False (they run _process_response at the call site, preserving check -> APM-succeed -> process ordering) and decrypt_reply=False (their commands are encrypted up front). The unack path publishes a copy of the command carrying the docs field while logging the bare command, matching the prior asymmetry. Drops the duplicated logging/APM/failure-conversion blocks (and the unreachable _convert_write_result-on-failure branch for unacknowledged writes). --- pymongo/asynchronous/bulk.py | 207 ++++++------------------ pymongo/asynchronous/command_runner.py | 29 ++-- pymongo/synchronous/bulk.py | 213 +++++++------------------ pymongo/synchronous/command_runner.py | 29 ++-- 4 files changed, 148 insertions(+), 330 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 8da5ffcb47..f331eb5707 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -19,8 +19,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -36,10 +34,8 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo.asynchronous.client_session import ( - AsyncClientSession, - _validate_session_write_concern, -) +from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern +from 
pymongo.asynchronous.command_runner import run_command from pymongo.asynchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, @@ -60,14 +56,11 @@ OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, _UPDATE, _BulkWriteContext, - _convert_exception, - _convert_write_result, _EncryptedBulkWriteContext, _randint, ) @@ -253,83 +246,36 @@ async def write_command( docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> dict[str, Any]: - """A proxy for SocketInfo.write_command that handles event publishing.""" + """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, docs) try: - if bwc.session is not None and bwc.session._starting_transaction: - bwc.session._transaction.set_in_progress() - reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - 
serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + result_docs, _, _ = await run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + use_conn_transport=True, + process_response=False, + decrypt_reply=False, + ) + reply = result_docs[0] + # Process the response from the server. await client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) # Process the response from the server. 
if isinstance(exc, (NotPrimaryError, OperationFailure)): await client._process_response(exc.details, bwc.session) # type: ignore[arg-type] raise - return reply # type: ignore[return-value] + return reply async def unack_write( self, @@ -341,83 +287,34 @@ async def unack_write( docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for AsyncConnection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, docs) - try: - result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. 
- reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) - raise - return result # type: ignore[return-value] + """Send an unacknowledged batch write command.""" + # Historically the STARTED log omits the documents while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying the ``docs`` field. 
+ published = dict(cmd) + published[bwc.field] = docs + await run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + orig=published, + unacknowledged=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + process_response=False, + decrypt_reply=False, + ) + return None async def _execute_batch_unack( self, diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index 53a2ff803a..bb3860c5e7 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -82,6 +82,8 @@ async def run_command( unacknowledged: bool = False, speculative_hello: bool = False, ensure_db: bool = False, + process_response: bool = True, + decrypt_reply: bool = True, use_conn_transport: bool = False, max_doc_size: int = 0, more_to_come: bool = False, @@ -132,9 +134,15 @@ async def run_command( APM redaction. :param ensure_db: Add ``$db`` to the published command if missing (cursor path), after the ``STARTED`` log has been emitted. + :param process_response: Run ``client._process_response`` on success here; + the bulk paths pass False and process the reply at the call site to + keep their check -> APM-succeed -> process ordering. + :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; + the bulk paths pass False (their commands are encrypted up front). :param use_conn_transport: Send/receive via ``conn.send_message`` / - ``conn.receive_message`` (cursor path) instead of the raw - ``async_sendall`` / ``async_receive_message`` (network path). + ``conn.receive_message`` (cursor path) or ``conn.unack_write`` (bulk + unacknowledged) instead of the raw ``async_sendall`` / + ``async_receive_message`` (network path). 
:param max_doc_size: The largest document size, for ``conn.send_message``. :param more_to_come: Receive only, without sending (exhaust ``getMore``). :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the @@ -192,16 +200,19 @@ async def run_command( try: if more_to_come: reply = await conn.receive_message(None) + elif unacknowledged: + if use_conn_transport: + await conn.unack_write(msg, max_doc_size) + else: + await async_sendall(conn.conn.get_conn, msg) + # Unacknowledged, fake a successful command response. + reply = None + docs: list[dict[str, Any]] = [{"ok": 1}] elif use_conn_transport: if session is not None and session._starting_transaction: session._transaction.set_in_progress() await conn.send_message(msg, max_doc_size) reply = await conn.receive_message(request_id) - elif unacknowledged: - await async_sendall(conn.conn.get_conn, msg) - # Unacknowledged, fake a successful command response. - reply = None - docs: list[dict[str, Any]] = [{"ok": 1}] else: await async_sendall(conn.conn.get_conn, msg) reply = await async_receive_message(conn, request_id) @@ -225,7 +236,7 @@ async def run_command( cluster_time = response_doc.get("$clusterTime") if cluster_time: conn._cluster_time = cluster_time - if client: + if process_response and client: await client._process_response(response_doc, session) if check: helpers_shared._check_command_response( @@ -315,7 +326,7 @@ async def run_command( database_name=dbname, ) - if client and client._encrypter and reply and is_command_response: + if client and client._encrypter and reply and is_command_response and decrypt_reply: decrypted = await client._encrypter.decrypt(reply.raw_command_response()) docs = cast( "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index f6e1d1abe4..b8ef95165e 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -19,8 +19,6 @@ from __future__ import 
annotations import copy -import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -36,6 +34,9 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common +from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern +from pymongo.synchronous.command_runner import run_command +from pymongo.synchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, _DELETE_ALL, @@ -55,23 +56,15 @@ OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, _UPDATE, _BulkWriteContext, - _convert_exception, - _convert_write_result, _EncryptedBulkWriteContext, _randint, ) from pymongo.read_preferences import ReadPreference -from pymongo.synchronous.client_session import ( - ClientSession, - _validate_session_write_concern, -) -from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -253,83 +246,36 @@ def write_command( docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> dict[str, Any]: - """A proxy for SocketInfo.write_command that handles event publishing.""" + """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, docs) 
try: - if bwc.session is not None and bwc.session._starting_transaction: - bwc.session._transaction.set_in_progress() - reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + result_docs, _, _ = run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + use_conn_transport=True, + process_response=False, + decrypt_reply=False, + ) + reply = result_docs[0] + # Process the response from the server. 
client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) # Process the response from the server. if isinstance(exc, (NotPrimaryError, OperationFailure)): client._process_response(exc.details, bwc.session) # type: ignore[arg-type] raise - return reply # type: ignore[return-value] + return reply def unack_write( self, @@ -341,83 +287,34 @@ def unack_write( docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for Connection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, docs) - try: - result = 
bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) - raise - return result # type: ignore[return-value] + """Send an unacknowledged batch write 
command.""" + # Historically the STARTED log omits the documents while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying the ``docs`` field. + published = dict(cmd) + published[bwc.field] = docs + run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + orig=published, + unacknowledged=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + process_response=False, + decrypt_reply=False, + ) + return None def _execute_batch_unack( self, @@ -619,7 +516,9 @@ def retryable_bulk( _raise_bulk_write_error(full_result) return full_result - def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None: + def execute_op_msg_no_results( + self, conn: Connection, generator: Iterator[Any] + ) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name client = self.collection.database.client diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 4bd93400cf..317a6bddff 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -82,6 +82,8 @@ def run_command( unacknowledged: bool = False, speculative_hello: bool = False, ensure_db: bool = False, + process_response: bool = True, + decrypt_reply: bool = True, use_conn_transport: bool = False, max_doc_size: int = 0, more_to_come: bool = False, @@ -132,9 +134,15 @@ def run_command( APM redaction. :param ensure_db: Add ``$db`` to the published command if missing (cursor path), after the ``STARTED`` log has been emitted. 
+ :param process_response: Run ``client._process_response`` on success here; + the bulk paths pass False and process the reply at the call site to + keep their check -> APM-succeed -> process ordering. + :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; + the bulk paths pass False (their commands are encrypted up front). :param use_conn_transport: Send/receive via ``conn.send_message`` / - ``conn.receive_message`` (cursor path) instead of the raw - ``sendall`` / ``receive_message`` (network path). + ``conn.receive_message`` (cursor path) or ``conn.unack_write`` (bulk + unacknowledged) instead of the raw ``sendall`` / + ``receive_message`` (network path). :param max_doc_size: The largest document size, for ``conn.send_message``. :param more_to_come: Receive only, without sending (exhaust ``getMore``). :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the @@ -192,16 +200,19 @@ def run_command( try: if more_to_come: reply = conn.receive_message(None) + elif unacknowledged: + if use_conn_transport: + conn.unack_write(msg, max_doc_size) + else: + sendall(conn.conn.get_conn, msg) + # Unacknowledged, fake a successful command response. + reply = None + docs: list[dict[str, Any]] = [{"ok": 1}] elif use_conn_transport: if session is not None and session._starting_transaction: session._transaction.set_in_progress() conn.send_message(msg, max_doc_size) reply = conn.receive_message(request_id) - elif unacknowledged: - sendall(conn.conn.get_conn, msg) - # Unacknowledged, fake a successful command response. 
- reply = None - docs: list[dict[str, Any]] = [{"ok": 1}] else: sendall(conn.conn.get_conn, msg) reply = receive_message(conn, request_id) @@ -225,7 +236,7 @@ def run_command( cluster_time = response_doc.get("$clusterTime") if cluster_time: conn._cluster_time = cluster_time - if client: + if process_response and client: client._process_response(response_doc, session) if check: helpers_shared._check_command_response( @@ -315,7 +326,7 @@ def run_command( database_name=dbname, ) - if client and client._encrypter and reply and is_command_response: + if client and client._encrypter and reply and is_command_response and decrypt_reply: decrypted = client._encrypter.decrypt(reply.raw_command_response()) docs = cast( "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) From 49d828aa20f0bff3b657c99cde52361e9ded9288 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 4 Jun 2026 19:10:01 -0400 Subject: [PATCH 04/15] PYTHON-5676 Route client-level bulk writes through run_command Route client_bulk.write_command and client_bulk.unack_write through run_command (process_response=False, decrypt_reply=False, conn.unack_write transport for the unack path). The client-level swallow semantics stay at the call site: the except wraps the raised error into reply={"error": exc} and runs the $clusterTime gossip (exc.details for OperationFailure, else {}); the unack path publishes a copy carrying ops/nsInfo while logging the bare command. With this, all command execution -- standard commands, cursor find/getMore, and both bulk write families -- flows through the single run_command path. 
--- pymongo/asynchronous/client_bulk.py | 196 +++++++-------------------- pymongo/synchronous/client_bulk.py | 200 ++++++++-------------------- 2 files changed, 103 insertions(+), 293 deletions(-) diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index dcef4eea02..3c74bc7304 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -19,8 +19,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -41,6 +39,7 @@ ) from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.command_runner import run_command from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.helpers import _handle_reauth @@ -66,12 +65,9 @@ WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _ClientBulkWriteContext, _convert_client_bulk_exception, - _convert_exception, - _convert_write_result, _randint, ) from pymongo.read_preferences import ReadPreference @@ -239,80 +235,32 @@ async def write_command( ns_docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> dict[str, Any]: - """A proxy for AsyncConnection.write_command that handles event publishing.""" + """Run a client-level batch write command, returning the response as a dict.""" cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - 
serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, op_docs, ns_docs) try: - if bwc.session is not None and bwc.session._starting_transaction: - bwc.session._transaction.set_in_progress() - reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + result_docs, _, _ = await run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, # type: ignore[arg-type] + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + use_conn_transport=True, + process_response=False, + decrypt_reply=False, + ) + reply = result_docs[0] # Process the response from the server. 
await self.client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} # Process the response from the server. 
@@ -332,81 +280,37 @@ async def unack_write( ns_docs: list[Mapping[str, Any]], client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for AsyncConnection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, op_docs, ns_docs) + """Send an unacknowledged client-level batch write command.""" + # Historically the STARTED log omits the ops/nsInfo while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying those fields. + published = dict(cmd) + published["ops"] = op_docs + published["nsInfo"] = ns_docs + reply: Mapping[str, Any] = {"ok": 1} try: - result = await bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. 
- reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) + await run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + orig=published, + unacknowledged=True, + use_conn_transport=True, + max_doc_size=bwc.max_bson_size, + process_response=False, + decrypt_reply=False, + ) except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if 
bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} return reply diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 400b1a2170..0c9621f7d8 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -19,8 +19,6 @@ from __future__ import annotations import copy -import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -41,6 +39,7 @@ ) from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.command_runner import run_command from pymongo.synchronous.database import Database from pymongo.synchronous.helpers import _handle_reauth @@ -66,12 +65,9 @@ WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _ClientBulkWriteContext, _convert_client_bulk_exception, - _convert_exception, - _convert_write_result, _randint, ) from pymongo.read_preferences import ReadPreference @@ -239,80 +235,32 @@ def write_command( ns_docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> dict[str, Any]: - """A proxy for Connection.write_command that handles event publishing.""" + """Run a client-level batch write command, returning the response as a dict.""" cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - 
serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._start(cmd, request_id, op_docs, ns_docs) try: - if bwc.session is not None and bwc.session._starting_transaction: - bwc.session._transaction.set_in_progress() - reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] - duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + result_docs, _, _ = run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, # type: ignore[arg-type] + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + use_conn_transport=True, + process_response=False, + decrypt_reply=False, + ) + reply = result_docs[0] # Process the response from the server. 
self.client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if bwc.publish: - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} # Process the response from the server. 
@@ -332,81 +280,37 @@ def unack_write( ns_docs: list[Mapping[str, Any]], client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: - """A proxy for Connection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - cmd = bwc._start(cmd, request_id, op_docs, ns_docs) + """Send an unacknowledged client-level batch write command.""" + # Historically the STARTED log omits the ops/nsInfo while the published + # CommandStartedEvent includes them, so log ``cmd`` but publish a copy + # carrying those fields. + published = dict(cmd) + published["ops"] = op_docs + published["nsInfo"] = ns_docs + reply: Mapping[str, Any] = {"ok": 1} try: - result = bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override] - duration = datetime.datetime.now() - bwc.start_time - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. 
- reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) - if bwc.publish: - bwc._succeed(request_id, reply, duration) + run_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + orig=published, + unacknowledged=True, + use_conn_transport=True, + max_doc_size=bwc.max_bson_size, + process_response=False, + decrypt_reply=False, + ) except Exception as exc: - duration = datetime.datetime.now() - bwc.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if 
bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} return reply @@ -432,7 +336,9 @@ def _execute_batch( ) -> tuple[dict[str, Any], list[Mapping[str, Any]], list[Mapping[str, Any]]]: """Executes a batch of bulkWrite server commands (ack).""" request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces) - result = self.write_command(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type] + result = self.write_command( + bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client + ) # type: ignore[arg-type] return result, to_send_ops, to_send_ns # type: ignore[return-value] def _process_results_cursor( From e9f7e90c1598b5b431aad3daf9339c5d9787f296 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 4 Jun 2026 19:31:24 -0400 Subject: [PATCH 05/15] PYTHON-5676 Rename network.py to command_encoder.py After the consolidation, this module no longer does any networking -- the send/receive round trip moved into command_runner.run_command. It now only encodes a command and runs its pre-flight (read preference/concern, collation, $clusterTime, auto-encryption, CSOT, OP_MSG encoding), so 'network' was misleading and collided with the lower-level network_layer.py (raw sockets). Pure rename: git mv the async module (synchro regenerates the sync mirror) and update the two pool.py imports. No behavior change. 
--- pymongo/asynchronous/{network.py => command_encoder.py} | 9 ++++++++- pymongo/asynchronous/pool.py | 2 +- pymongo/synchronous/{network.py => command_encoder.py} | 9 ++++++++- pymongo/synchronous/pool.py | 2 +- 4 files changed, 18 insertions(+), 4 deletions(-) rename pymongo/asynchronous/{network.py => command_encoder.py} (94%) rename pymongo/synchronous/{network.py => command_encoder.py} (94%) diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/command_encoder.py similarity index 94% rename from pymongo/asynchronous/network.py rename to pymongo/asynchronous/command_encoder.py index d37fb2acff..a107c1ce54 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/command_encoder.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Internal network layer helper methods.""" +"""Encode a command and run it over a connection. + +This builds the wire-protocol message for a single command -- applying read +preference, read concern, collation, ``$clusterTime``, auto-encryption, CSOT, +and OP_MSG encoding -- then hands it to +:func:`pymongo.asynchronous.command_runner.run_command` for the network round +trip. The raw socket I/O lives in :mod:`pymongo.network_layer`. 
+""" from __future__ import annotations import datetime diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 07e528a607..d48f8310ee 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -39,8 +39,8 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared from pymongo.asynchronous.client_session import _validate_session_write_concern +from pymongo.asynchronous.command_encoder import command from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.asynchronous.network import command from pymongo.common import ( MAX_BSON_SIZE, MAX_MESSAGE_SIZE, diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/command_encoder.py similarity index 94% rename from pymongo/synchronous/network.py rename to pymongo/synchronous/command_encoder.py index 07b285e59e..3547e49075 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/command_encoder.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Internal network layer helper methods.""" +"""Encode a command and run it over a connection. + +This builds the wire-protocol message for a single command -- applying read +preference, read concern, collation, ``$clusterTime``, auto-encryption, CSOT, +and OP_MSG encoding -- then hands it to +:func:`pymongo.synchronous.command_runner.run_command` for the network round +trip. The raw socket I/O lives in :mod:`pymongo.network_layer`.
+""" from __future__ import annotations import datetime diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 280afebebe..4d95e0f489 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -88,8 +88,8 @@ from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker from pymongo.synchronous.client_session import _validate_session_write_concern +from pymongo.synchronous.command_encoder import command from pymongo.synchronous.helpers import _handle_reauth -from pymongo.synchronous.network import command if TYPE_CHECKING: from bson import CodecOptions From 83aa84587766bdf7f6e0e8ea9f2917caba3f3f89 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Fri, 5 Jun 2026 15:12:29 -0400 Subject: [PATCH 06/15] PYTHON-5676 Address Noah's feedback: split run_command into run_command/run_unacknowledged_command/run_cursor_command --- pymongo/asynchronous/bulk.py | 9 +- pymongo/asynchronous/client_bulk.py | 9 +- pymongo/asynchronous/command_encoder.py | 59 ++++--- pymongo/asynchronous/command_runner.py | 225 ++++++++++++++++++++++-- pymongo/asynchronous/pool.py | 18 -- pymongo/asynchronous/server.py | 7 +- pymongo/message.py | 102 ----------- pymongo/synchronous/bulk.py | 10 +- pymongo/synchronous/client_bulk.py | 9 +- pymongo/synchronous/command_encoder.py | 60 ++++--- pymongo/synchronous/command_runner.py | 225 ++++++++++++++++++++++-- pymongo/synchronous/pool.py | 18 -- pymongo/synchronous/server.py | 7 +- 13 files changed, 523 insertions(+), 235 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index f331eb5707..7bf05f5526 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -35,7 +35,7 @@ from bson.raw_bson import RawBSONDocument from pymongo import _csot, common from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern -from pymongo.asynchronous.command_runner import run_command +from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command from
pymongo.asynchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, @@ -293,7 +293,7 @@ async def unack_write( # carrying the ``docs`` field. published = dict(cmd) published[bwc.field] = docs - await run_command( + await run_unacknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -308,11 +308,8 @@ async def unack_write( op_id=bwc.op_id, command_name=bwc.name, orig=published, - unacknowledged=True, use_conn_transport=True, max_doc_size=max_doc_size, - process_response=False, - decrypt_reply=False, ) return None @@ -386,7 +383,7 @@ async def _execute_command( run = self.current_run # AsyncConnection.command validates the session, but we use - # AsyncConnection.write_command + # run_command/run_unacknowledged_command. conn.validate_session(client, session) last_run = False diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 3c74bc7304..f768143a19 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -39,7 +39,7 @@ ) from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.command_runner import run_command +from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.helpers import _handle_reauth @@ -289,7 +289,7 @@ async def unack_write( published["nsInfo"] = ns_docs reply: Mapping[str, Any] = {"ok": 1} try: - await run_command( + await run_unacknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -304,11 +304,8 @@ async def unack_write( op_id=bwc.op_id, command_name=bwc.name, orig=published, - unacknowledged=True, use_conn_transport=True, max_doc_size=bwc.max_bson_size, - process_response=False, - decrypt_reply=False, ) except Exception as exc: # Top-level error will be embedded in ClientBulkWriteException. 
@@ -409,7 +406,7 @@ async def _execute_command( listeners = self.client._event_listeners # AsyncConnection.command validates the session, but we use - # AsyncConnection.write_command + # run_command/run_unacknowledged_command. conn.validate_session(self.client, session) bwc = self.bulk_ctx_class( diff --git a/pymongo/asynchronous/command_encoder.py b/pymongo/asynchronous/command_encoder.py index a107c1ce54..25efcd9d8a 100644 --- a/pymongo/asynchronous/command_encoder.py +++ b/pymongo/asynchronous/command_encoder.py @@ -34,7 +34,7 @@ ) from pymongo import _csot, message -from pymongo.asynchronous.command_runner import run_command +from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command from pymongo.compression_support import _NO_COMPRESSION from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate @@ -144,24 +144,41 @@ async def command( if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - docs, _, _ = await run_command( - conn, - spec, - dbname, - request_id, - msg, - client=client, - session=session, - listeners=listeners, - address=address, - start=start, - codec_options=codec_options, - user_fields=user_fields, - orig=orig, - check=check, - allowable_errors=allowable_errors, - parse_write_concern_error=parse_write_concern_error, - unacknowledged=unacknowledged, - speculative_hello=speculative_hello, - ) + if unacknowledged: + docs, _, _ = await run_unacknowledged_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + speculative_hello=speculative_hello, + ) + else: + docs, _, _ = await run_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + 
address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + ) return docs[0] # type: ignore[return-value] diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index bb3860c5e7..0630957d2a 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The single code path for executing a command over a connection. - -Every database operation -- standard commands and cursor ``find``/``getMore`` -operations -- runs its network round trip through :func:`run_command`. The -function owns the entire shared skeleton: command logging, APM event -publishing, ``send``/``receive``, ``$clusterTime`` gossip, -``_process_response``, ``_check_command_response``, failure conversion, and -auto-encryption decryption. Callers supply only the parts that vary (the -encoded message and a handful of transport/output hooks). +"""The shared code path for executing a command over a connection. + +Every database operation runs its network round trip through one of three +public entry points -- :func:`run_command` (acknowledged commands and bulk +write batches), :func:`run_unacknowledged_command` (unacknowledged writes), and +:func:`run_cursor_command` (cursor ``find``/``getMore`` operations) -- each of +which wraps the private :func:`_run_command`. ``_run_command`` owns the entire +shared skeleton: command logging, APM event publishing, ``send``/``receive``, +``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, +failure conversion, and auto-encryption decryption. 
The three wrappers fix the +transport and response-shaping flags for their command type so call sites pass +only the parts that vary (the encoded message and a handful of hooks). """ from __future__ import annotations @@ -58,7 +61,7 @@ _IS_SYNC = False -async def run_command( +async def _run_command( conn: AsyncConnection, cmd: MutableMapping[str, Any], dbname: str, @@ -97,7 +100,13 @@ async def run_command( ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. - This is the single code path for command execution. It publishes the + This is the shared implementation behind :func:`run_command`, + :func:`run_unacknowledged_command`, and :func:`run_cursor_command`. Those + three public entry points each fix the transport and response-shaping flags + for their command type; the bare kwargs here should not be set directly by + new call sites. + + It publishes the ``STARTED``/``SUCCEEDED``/``FAILED`` command log and APM events, performs the network round trip, gossips ``$clusterTime``, runs ``client._process_response`` and ``_check_command_response``, and decrypts @@ -333,3 +342,197 @@ async def run_command( ) return docs, reply, duration + + +async def run_command( + conn: AsyncConnection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + session: Optional[AsyncClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + command_name: Optional[str] = None, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + speculative_hello: bool = False, + use_conn_transport: bool = False, + 
process_response: bool = True, + decrypt_reply: bool = True, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send an acknowledged command and return ``(docs, reply, duration)``. + + This is the entry point for standard commands and bulk write batches: it + sends ``msg``, receives the reply, runs ``_process_response`` and + ``_check_command_response``, decrypts the reply when auto-encryption is + enabled, and publishes the command log/APM events. + + :param use_conn_transport: Send/receive via ``conn.send_message`` / + ``conn.receive_message`` (bulk path) instead of the raw + ``async_sendall`` / ``async_receive_message`` (standard command path). + :param process_response: Run ``client._process_response`` here; the bulk + paths pass False and process the reply at the call site to keep their + check -> APM-succeed -> process ordering. + :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; the + bulk paths pass False (their commands are encrypted up front). + + See :func:`_run_command` for the remaining parameters. 
+ """ + return await _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + op_id=op_id, + command_name=command_name, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + use_conn_transport=use_conn_transport, + process_response=process_response, + decrypt_reply=decrypt_reply, + ) + + +async def run_unacknowledged_command( + conn: AsyncConnection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + session: Optional[AsyncClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + command_name: Optional[str] = None, + speculative_hello: bool = False, + use_conn_transport: bool = False, + max_doc_size: int = 0, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send an unacknowledged command and fake an ``{"ok": 1}`` reply. + + The message is sent only -- no reply is received -- so the response + processing, command checking, and decryption steps are skipped. + + :param use_conn_transport: Send via ``conn.unack_write`` (bulk path) instead + of the raw ``async_sendall`` (standard command path). + :param max_doc_size: The largest document size, for ``conn.unack_write``. + + See :func:`_run_command` for the remaining parameters. 
+ """ + return await _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + op_id=op_id, + command_name=command_name, + speculative_hello=speculative_hello, + unacknowledged=True, + use_conn_transport=use_conn_transport, + max_doc_size=max_doc_size, + process_response=False, + decrypt_reply=False, + ) + + +async def run_cursor_command( + conn: AsyncConnection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[AsyncMongoClient[Any]], + session: Optional[AsyncClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + command_name: str, + user_fields: Optional[Mapping[str, Any]] = None, + pool_opts: Optional[PoolOptions] = None, + max_doc_size: int = 0, + more_to_come: bool = False, + is_command_response: bool = True, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, + reply_doc_builder: Optional[ + Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + ] = None, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Run a cursor ``find``/``getMore`` operation over ``conn``. + + Uses the connection transport, leaves ``conn.more_to_come`` untouched (the + cursor path manages exhaust separately), and shapes the published reply in + the find/getMore command response format. + + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param is_command_response: True for an OP_MSG command response; False for a + legacy OP_QUERY cursor response. + :param unpack_res: A callable decoding the wire response. + :param cursor_id: The cursor id passed to ``unpack_res``. 
+ :param reply_doc_builder: Builds the reply document published in the + ``SUCCEEDED`` event from ``(docs, reply)``. + + See :func:`_run_command` for the remaining parameters. + """ + return await _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + command_name=command_name, + pool_opts=pool_opts, + ensure_db=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + more_to_come=more_to_come, + set_conn_more_to_come=False, + is_command_response=is_command_response, + unpack_res=unpack_res, + cursor_id=cursor_id, + reply_doc_builder=reply_doc_builder, + ) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index d48f8310ee..d22e300adc 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -473,24 +473,6 @@ async def unack_write(self, msg: bytes, max_doc_size: int) -> None: self._raise_if_not_writable(True) await self.send_message(msg, max_doc_size) - async def write_command( - self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]] - ) -> dict[str, Any]: - """Send "insert" etc. command, returning response as a dict. - - Can raise ConnectionFailure or OperationFailure. - - :param request_id: an int. - :param msg: bytes, the command message. - """ - await self.send_message(msg, 0) - reply = await self.receive_message(request_id) - result = reply.command_response(codec_options) - - # Raises NotPrimaryError or OperationFailure. - helpers_shared._check_command_response(result, self.max_wire_version) - return result - async def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. 
diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index b18cf56c52..0c4fbed00f 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -26,7 +26,7 @@ Union, ) -from pymongo.asynchronous.command_runner import run_command +from pymongo.asynchronous.command_runner import run_cursor_command from pymongo.asynchronous.helpers import _handle_reauth from pymongo.logger import ( _SDAM_LOGGER, @@ -187,7 +187,7 @@ def _build_reply_doc( res["cursor"]["nextBatch"] = docs return res - docs, reply, duration = await run_command( + docs, reply, duration = await run_cursor_command( conn, cmd, dbn, @@ -202,11 +202,8 @@ def _build_reply_doc( user_fields=user_fields, command_name=operation.name, pool_opts=conn.opts, - ensure_db=True, - use_conn_transport=True, max_doc_size=max_doc_size, more_to_come=bool(more_to_come), - set_conn_more_to_come=False, is_command_response=use_cmd, unpack_res=unpack_res, cursor_id=operation.cursor_id, diff --git a/pymongo/message.py b/pymongo/message.py index fdac2b4daa..b6209b9df0 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -69,7 +69,6 @@ _AgnosticClientSession, _AgnosticConnection, _AgnosticMongoClient, - _DocumentOut, ) @@ -146,42 +145,6 @@ def _convert_client_bulk_exception(exception: Exception) -> dict[str, Any]: } -def _convert_write_result( - operation: str, command: Mapping[str, Any], result: Mapping[str, Any] -) -> dict[str, Any]: - """Convert a legacy write result to write command format.""" - # Based on _merge_legacy from bulk.py - affected = result.get("n", 0) - res = {"ok": 1, "n": affected} - errmsg = result.get("errmsg", result.get("err", "")) - if errmsg: - # The write was successful on at least the primary so don't return. - if result.get("wtimeout"): - res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} - else: - # The write failed. 
- error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} - if "errInfo" in result: - error["errInfo"] = result["errInfo"] - res["writeErrors"] = [error] - return res - if operation == "insert": - # GLE result for insert is always 0 in most MongoDB versions. - res["n"] = len(command["documents"]) - elif operation == "update": - if "upserted" in result: - res["upserted"] = [{"index": 0, "_id": result["upserted"]}] - # Versions of MongoDB before 2.6 don't return the _id for an - # upsert if _id is not an ObjectId. - elif result.get("updatedExisting") is False and affected == 1: - # If _id is in both the update document *and* the query spec - # the update document _id takes precedence. - update = command["updates"][0] - _id = update["u"].get("_id", update["q"].get("_id")) - res["upserted"] = [{"index": 0, "_id": _id}] - return res - - _OPTIONS = { "tailable": 2, "oplogReplay": 8, @@ -540,34 +503,6 @@ def max_split_size(self) -> int: """The maximum size of a BSON command before batch splitting.""" return self.max_bson_size - def _succeed(self, request_id: int, reply: _DocumentOut, duration: datetime.timedelta) -> None: - """Publish a CommandSucceededEvent.""" - self.listeners.publish_command_success( - duration, - reply, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - - def _fail(self, request_id: int, failure: _DocumentOut, duration: datetime.timedelta) -> None: - """Publish a CommandFailedEvent.""" - self.listeners.publish_command_failure( - duration, - failure, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - class _BulkWriteContext(_BulkWriteContextBase): """A wrapper around AsyncConnection/Connection for use with the collection-level bulk write API.""" @@ -607,22 +542,6 @@ def batch_command( raise InvalidOperation("cannot do an empty bulk 
write") return request_id, msg, to_send - def _start( - self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] - ) -> MutableMapping[str, Any]: - """Publish a CommandStartedEvent.""" - cmd[self.field] = docs - self.listeners.publish_command_start( - cmd, - self.db_name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - ) - return cmd - class _EncryptedBulkWriteContext(_BulkWriteContext): __slots__ = () @@ -869,27 +788,6 @@ def batch_command( raise InvalidOperation("cannot do an empty bulk write") return request_id, msg, to_send_ops, to_send_ns - def _start( - self, - cmd: MutableMapping[str, Any], - request_id: int, - op_docs: list[Mapping[str, Any]], - ns_docs: list[Mapping[str, Any]], - ) -> MutableMapping[str, Any]: - """Publish a CommandStartedEvent.""" - cmd["ops"] = op_docs - cmd["nsInfo"] = ns_docs - self.listeners.publish_command_start( - cmd, - self.db_name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - ) - return cmd - _OP_MSG_OVERHEAD = 1000 diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index b8ef95165e..3aaf6c0553 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -65,6 +65,9 @@ _randint, ) from pymongo.read_preferences import ReadPreference +from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern +from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -293,7 +296,7 @@ def unack_write( # carrying the ``docs`` field. 
published = dict(cmd) published[bwc.field] = docs - run_command( + run_unacknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -308,11 +311,8 @@ def unack_write( op_id=bwc.op_id, command_name=bwc.name, orig=published, - unacknowledged=True, use_conn_transport=True, max_doc_size=max_doc_size, - process_response=False, - decrypt_reply=False, ) return None @@ -386,7 +386,7 @@ def _execute_command( run = self.current_run # Connection.command validates the session, but we use - # Connection.write_command + # run_command/run_unacknowledged_command. conn.validate_session(client, session) last_run = False diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 0c9621f7d8..191eeb6346 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -39,7 +39,7 @@ ) from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.command_runner import run_command +from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command from pymongo.synchronous.database import Database from pymongo.synchronous.helpers import _handle_reauth @@ -289,7 +289,7 @@ def unack_write( published["nsInfo"] = ns_docs reply: Mapping[str, Any] = {"ok": 1} try: - run_command( + run_unacknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -304,11 +304,8 @@ def unack_write( op_id=bwc.op_id, command_name=bwc.name, orig=published, - unacknowledged=True, use_conn_transport=True, max_doc_size=bwc.max_bson_size, - process_response=False, - decrypt_reply=False, ) except Exception as exc: # Top-level error will be embedded in ClientBulkWriteException. @@ -409,7 +406,7 @@ def _execute_command( listeners = self.client._event_listeners # Connection.command validates the session, but we use - # Connection.write_command + # run_command/run_unacknowledged_command. 
conn.validate_session(self.client, session) bwc = self.bulk_ctx_class( diff --git a/pymongo/synchronous/command_encoder.py b/pymongo/synchronous/command_encoder.py index 3547e49075..d620f6ee90 100644 --- a/pymongo/synchronous/command_encoder.py +++ b/pymongo/synchronous/command_encoder.py @@ -37,7 +37,7 @@ from pymongo.compression_support import _NO_COMPRESSION from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate -from pymongo.synchronous.command_runner import run_command +from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command if TYPE_CHECKING: from bson import CodecOptions @@ -144,25 +144,41 @@ def command( if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - docs, _, _ = run_command( - conn, - spec, - dbname, - request_id, - msg, - client=client, - session=session, - listeners=listeners, - address=address, - start=start, - codec_options=codec_options, - user_fields=user_fields, - orig=orig, - check=check, - allowable_errors=allowable_errors, - parse_write_concern_error=parse_write_concern_error, - unacknowledged=unacknowledged, - speculative_hello=speculative_hello, - ) + if unacknowledged: + docs, _, _ = run_unacknowledged_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + speculative_hello=speculative_hello, + ) + else: + docs, _, _ = run_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + ) 
return docs[0] # type: ignore[return-value] ->>>>>>> 0d7dedb0 (PYTHON-5676 Add command_runner.run_command; route network.command() through it) diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 317a6bddff..3523d3d443 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The single code path for executing a command over a connection. - -Every database operation -- standard commands and cursor ``find``/``getMore`` -operations -- runs its network round trip through :func:`run_command`. The -function owns the entire shared skeleton: command logging, APM event -publishing, ``send``/``receive``, ``$clusterTime`` gossip, -``_process_response``, ``_check_command_response``, failure conversion, and -auto-encryption decryption. Callers supply only the parts that vary (the -encoded message and a handful of transport/output hooks). +"""The shared code path for executing a command over a connection. + +Every database operation runs its network round trip through one of three +public entry points -- :func:`run_command` (acknowledged commands and bulk +write batches), :func:`run_unacknowledged_command` (unacknowledged writes), and +:func:`run_cursor_command` (cursor ``find``/``getMore`` operations) -- each of +which wraps the private :func:`_run_command`. ``_run_command`` owns the entire +shared skeleton: command logging, APM event publishing, ``send``/``receive``, +``$clusterTime`` gossip, ``_process_response``, ``_check_command_response``, +failure conversion, and auto-encryption decryption. The three wrappers fix the +transport and response-shaping flags for their command type so call sites pass +only the parts that vary (the encoded message and a handful of hooks). 
""" from __future__ import annotations @@ -58,7 +61,7 @@ _IS_SYNC = True -def run_command( +def _run_command( conn: Connection, cmd: MutableMapping[str, Any], dbname: str, @@ -97,7 +100,13 @@ def run_command( ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. - This is the single code path for command execution. It publishes the + This is the shared implementation behind :func:`run_command`, + :func:`run_unacknowledged_command`, and :func:`run_cursor_command`. Those + three public entry points each fix the transport and response-shaping flags + for their command type; the bare kwargs here should not be set directly by + new call sites. + + It publishes the ``STARTED``/``SUCCEEDED``/``FAILED`` command log and APM events, performs the network round trip, gossips ``$clusterTime``, runs ``client._process_response`` and ``_check_command_response``, and decrypts @@ -333,3 +342,197 @@ def run_command( ) return docs, reply, duration + + +def run_command( + conn: Connection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + session: Optional[ClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + command_name: Optional[str] = None, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + speculative_hello: bool = False, + use_conn_transport: bool = False, + process_response: bool = True, + decrypt_reply: bool = True, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send an acknowledged command and return ``(docs, reply, duration)``. 
+ + This is the entry point for standard commands and bulk write batches: it + sends ``msg``, receives the reply, runs ``_process_response`` and + ``_check_command_response``, decrypts the reply when auto-encryption is + enabled, and publishes the command log/APM events. + + :param use_conn_transport: Send/receive via ``conn.send_message`` / + ``conn.receive_message`` (bulk path) instead of the raw + ``sendall`` / ``receive_message`` (standard command path). + :param process_response: Run ``client._process_response`` here; the bulk + paths pass False and process the reply at the call site to keep their + check -> APM-succeed -> process ordering. + :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; the + bulk paths pass False (their commands are encrypted up front). + + See :func:`_run_command` for the remaining parameters. + """ + return _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + op_id=op_id, + command_name=command_name, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + use_conn_transport=use_conn_transport, + process_response=process_response, + decrypt_reply=decrypt_reply, + ) + + +def run_unacknowledged_command( + conn: Connection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + session: Optional[ClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + user_fields: Optional[Mapping[str, Any]] = None, + orig: Optional[MutableMapping[str, Any]] = None, + op_id: Optional[int] = None, + command_name: Optional[str] = None, + speculative_hello: bool = False, + use_conn_transport: bool = 
False, + max_doc_size: int = 0, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send an unacknowledged command and fake an ``{"ok": 1}`` reply. + + The message is sent only -- no reply is received -- so the response + processing, command checking, and decryption steps are skipped. + + :param use_conn_transport: Send via ``conn.unack_write`` (bulk path) instead + of the raw ``sendall`` (standard command path). + :param max_doc_size: The largest document size, for ``conn.unack_write``. + + See :func:`_run_command` for the remaining parameters. + """ + return _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + op_id=op_id, + command_name=command_name, + speculative_hello=speculative_hello, + unacknowledged=True, + use_conn_transport=use_conn_transport, + max_doc_size=max_doc_size, + process_response=False, + decrypt_reply=False, + ) + + +def run_cursor_command( + conn: Connection, + cmd: MutableMapping[str, Any], + dbname: str, + request_id: int, + msg: bytes, + *, + client: Optional[MongoClient[Any]], + session: Optional[ClientSession], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + codec_options: CodecOptions[_DocumentType], + command_name: str, + user_fields: Optional[Mapping[str, Any]] = None, + pool_opts: Optional[PoolOptions] = None, + max_doc_size: int = 0, + more_to_come: bool = False, + is_command_response: bool = True, + unpack_res: Optional[Callable[..., Any]] = None, + cursor_id: Optional[int] = None, + reply_doc_builder: Optional[ + Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + ] = None, +) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Run a cursor ``find``/``getMore`` operation over ``conn``. 
+ + Uses the connection transport, leaves ``conn.more_to_come`` untouched (the + cursor path manages exhaust separately), and shapes the published reply in + the find/getMore command response format. + + :param more_to_come: Receive only, without sending (exhaust ``getMore``). + :param is_command_response: True for an OP_MSG command response; False for a + legacy OP_QUERY cursor response. + :param unpack_res: A callable decoding the wire response. + :param cursor_id: The cursor id passed to ``unpack_res``. + :param reply_doc_builder: Builds the reply document published in the + ``SUCCEEDED`` event from ``(docs, reply)``. + + See :func:`_run_command` for the remaining parameters. + """ + return _run_command( + conn, + cmd, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + command_name=command_name, + pool_opts=pool_opts, + ensure_db=True, + use_conn_transport=True, + max_doc_size=max_doc_size, + more_to_come=more_to_come, + set_conn_more_to_come=False, + is_command_response=is_command_response, + unpack_res=unpack_res, + cursor_id=cursor_id, + reply_doc_builder=reply_doc_builder, + ) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 4d95e0f489..1dd4b81a35 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -473,24 +473,6 @@ def unack_write(self, msg: bytes, max_doc_size: int) -> None: self._raise_if_not_writable(True) self.send_message(msg, max_doc_size) - def write_command( - self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]] - ) -> dict[str, Any]: - """Send "insert" etc. command, returning response as a dict. - - Can raise ConnectionFailure or OperationFailure. - - :param request_id: an int. - :param msg: bytes, the command message. 
- """ - self.send_message(msg, 0) - reply = self.receive_message(request_id) - result = reply.command_response(codec_options) - - # Raises NotPrimaryError or OperationFailure. - helpers_shared._check_command_response(result, self.max_wire_version) - return result - def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 47ae00af80..7041ee12e8 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -35,6 +35,8 @@ ) from pymongo.message import _GetMore, _OpMsg, _OpReply, _Query from pymongo.response import PinnedResponse, Response +from pymongo.synchronous.command_runner import run_cursor_command +from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: from queue import Queue @@ -187,7 +189,7 @@ def _build_reply_doc( res["cursor"]["nextBatch"] = docs return res - docs, reply, duration = run_command( + docs, reply, duration = run_cursor_command( conn, cmd, dbn, @@ -202,11 +204,8 @@ def _build_reply_doc( user_fields=user_fields, command_name=operation.name, pool_opts=conn.opts, - ensure_db=True, - use_conn_transport=True, max_doc_size=max_doc_size, more_to_come=bool(more_to_come), - set_conn_more_to_come=False, is_command_response=use_cmd, unpack_res=unpack_res, cursor_id=operation.cursor_id, From eb8ec077b227ff4b75ae8c1563ccd4f494d18fed Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Fri, 5 Jun 2026 15:17:56 -0400 Subject: [PATCH 07/15] =?UTF-8?q?rename=20run=5Fcommand=20=E2=86=92=20run?= =?UTF-8?q?=5Facknowledged=5Fcommand?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pymongo/asynchronous/bulk.py | 9 ++++++--- pymongo/asynchronous/client_bulk.py | 9 ++++++--- pymongo/asynchronous/command_encoder.py | 10 +++++++--- pymongo/asynchronous/command_runner.py | 9 +++++---- pymongo/synchronous/bulk.py | 9 ++++++--- 
pymongo/synchronous/client_bulk.py | 9 ++++++--- pymongo/synchronous/command_encoder.py | 10 +++++++--- pymongo/synchronous/command_runner.py | 9 +++++---- 8 files changed, 48 insertions(+), 26 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 7bf05f5526..f93b9ceb42 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -35,7 +35,10 @@ from bson.raw_bson import RawBSONDocument from pymongo import _csot, common from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern -from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.asynchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.asynchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, @@ -249,7 +252,7 @@ async def write_command( """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs try: - result_docs, _, _ = await run_command( + result_docs, _, _ = await run_acknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -383,7 +386,7 @@ async def _execute_command( run = self.current_run # AsyncConnection.command validates the session, but we use - # run_command/run_unacknowledged_command. + # run_acknowledged_command/run_unacknowledged_command. 
conn.validate_session(client, session) last_run = False diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index f768143a19..200e85fa45 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -39,7 +39,10 @@ ) from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.asynchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.helpers import _handle_reauth @@ -239,7 +242,7 @@ async def write_command( cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs try: - result_docs, _, _ = await run_command( + result_docs, _, _ = await run_acknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -406,7 +409,7 @@ async def _execute_command( listeners = self.client._event_listeners # AsyncConnection.command validates the session, but we use - # run_command/run_unacknowledged_command. + # run_acknowledged_command/run_unacknowledged_command. conn.validate_session(self.client, session) bwc = self.bulk_ctx_class( diff --git a/pymongo/asynchronous/command_encoder.py b/pymongo/asynchronous/command_encoder.py index 25efcd9d8a..e60aabc8e4 100644 --- a/pymongo/asynchronous/command_encoder.py +++ b/pymongo/asynchronous/command_encoder.py @@ -17,7 +17,8 @@ This builds the wire-protocol message for a single command -- applying read preference, read concern, collation, ``$clusterTime``, auto-encryption, CSOT, and OP_MSG encoding -- then hands it to -:func:`pymongo.asynchronous.command_runner.run_command` for the network round +:func:`pymongo.asynchronous.command_runner.run_acknowledged_command` for the +network round trip. The raw socket I/O lives in :mod:`pymongo.network_layer`. 
""" from __future__ import annotations @@ -34,7 +35,10 @@ ) from pymongo import _csot, message -from pymongo.asynchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.asynchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.compression_support import _NO_COMPRESSION from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate @@ -162,7 +166,7 @@ async def command( speculative_hello=speculative_hello, ) else: - docs, _, _ = await run_command( + docs, _, _ = await run_acknowledged_command( conn, spec, dbname, diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index 0630957d2a..4bf7c06434 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -15,8 +15,9 @@ """The shared code path for executing a command over a connection. Every database operation runs its network round trip through one of three -public entry points -- :func:`run_command` (acknowledged commands and bulk -write batches), :func:`run_unacknowledged_command` (unacknowledged writes), and +public entry points -- :func:`run_acknowledged_command` (acknowledged commands +and bulk write batches), :func:`run_unacknowledged_command` (unacknowledged +writes), and :func:`run_cursor_command` (cursor ``find``/``getMore`` operations) -- each of which wraps the private :func:`_run_command`. ``_run_command`` owns the entire shared skeleton: command logging, APM event publishing, ``send``/``receive``, @@ -100,7 +101,7 @@ async def _run_command( ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. - This is the shared implementation behind :func:`run_command`, + This is the shared implementation behind :func:`run_acknowledged_command`, :func:`run_unacknowledged_command`, and :func:`run_cursor_command`. 
Those three public entry points each fix the transport and response-shaping flags for their command type; the bare kwargs here should not be set directly by @@ -344,7 +345,7 @@ async def _run_command( return docs, reply, duration -async def run_command( +async def run_acknowledged_command( conn: AsyncConnection, cmd: MutableMapping[str, Any], dbname: str, diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 3aaf6c0553..0205c7bda5 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -66,7 +66,10 @@ ) from pymongo.read_preferences import ReadPreference from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern -from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.synchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern @@ -252,7 +255,7 @@ def write_command( """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs try: - result_docs, _, _ = run_command( + result_docs, _, _ = run_acknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -386,7 +389,7 @@ def _execute_command( run = self.current_run # Connection.command validates the session, but we use - # run_command/run_unacknowledged_command. + # run_acknowledged_command/run_unacknowledged_command. 
conn.validate_session(client, session) last_run = False diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 191eeb6346..c1d94f49ed 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -39,7 +39,10 @@ ) from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.synchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.synchronous.database import Database from pymongo.synchronous.helpers import _handle_reauth @@ -239,7 +242,7 @@ def write_command( cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs try: - result_docs, _, _ = run_command( + result_docs, _, _ = run_acknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -406,7 +409,7 @@ def _execute_command( listeners = self.client._event_listeners # Connection.command validates the session, but we use - # run_command/run_unacknowledged_command. + # run_acknowledged_command/run_unacknowledged_command. conn.validate_session(self.client, session) bwc = self.bulk_ctx_class( diff --git a/pymongo/synchronous/command_encoder.py b/pymongo/synchronous/command_encoder.py index d620f6ee90..d188da39da 100644 --- a/pymongo/synchronous/command_encoder.py +++ b/pymongo/synchronous/command_encoder.py @@ -17,7 +17,8 @@ This builds the wire-protocol message for a single command -- applying read preference, read concern, collation, ``$clusterTime``, auto-encryption, CSOT, and OP_MSG encoding -- then hands it to -:func:`pymongo.command_runner.run_command` for the network round +:func:`pymongo.command_runner.run_acknowledged_command` for the +network round trip. The raw socket I/O lives in :mod:`pymongo.network_layer`. 
""" from __future__ import annotations @@ -37,7 +38,10 @@ from pymongo.compression_support import _NO_COMPRESSION from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate -from pymongo.synchronous.command_runner import run_command, run_unacknowledged_command +from pymongo.synchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) if TYPE_CHECKING: from bson import CodecOptions @@ -162,7 +166,7 @@ def command( speculative_hello=speculative_hello, ) else: - docs, _, _ = run_command( + docs, _, _ = run_acknowledged_command( conn, spec, dbname, diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 3523d3d443..81b01a2b6a 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -15,8 +15,9 @@ """The shared code path for executing a command over a connection. Every database operation runs its network round trip through one of three -public entry points -- :func:`run_command` (acknowledged commands and bulk -write batches), :func:`run_unacknowledged_command` (unacknowledged writes), and +public entry points -- :func:`run_acknowledged_command` (acknowledged commands +and bulk write batches), :func:`run_unacknowledged_command` (unacknowledged +writes), and :func:`run_cursor_command` (cursor ``find``/``getMore`` operations) -- each of which wraps the private :func:`_run_command`. ``_run_command`` owns the entire shared skeleton: command logging, APM event publishing, ``send``/``receive``, @@ -100,7 +101,7 @@ def _run_command( ) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. - This is the shared implementation behind :func:`run_command`, + This is the shared implementation behind :func:`run_acknowledged_command`, :func:`run_unacknowledged_command`, and :func:`run_cursor_command`. 
Those three public entry points each fix the transport and response-shaping flags for their command type; the bare kwargs here should not be set directly by @@ -344,7 +345,7 @@ def _run_command( return docs, reply, duration -def run_command( +def run_acknowledged_command( conn: Connection, cmd: MutableMapping[str, Any], dbname: str, From 02d0ea12419ee802d1db82d075696ad890394656 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Mon, 8 Jun 2026 11:03:23 -0400 Subject: [PATCH 08/15] Steve feedback --- pymongo/asynchronous/bulk.py | 4 ---- pymongo/asynchronous/client_bulk.py | 8 ++------ pymongo/asynchronous/command_runner.py | 17 +++++++++-------- pymongo/asynchronous/pool.py | 22 ++-------------------- pymongo/synchronous/bulk.py | 4 ---- pymongo/synchronous/client_bulk.py | 8 ++------ pymongo/synchronous/command_runner.py | 17 +++++++++-------- pymongo/synchronous/pool.py | 22 ++-------------------- 8 files changed, 26 insertions(+), 76 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index f93b9ceb42..97d82818fd 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -267,14 +267,10 @@ async def write_command( op_id=bwc.op_id, command_name=bwc.name, use_conn_transport=True, - process_response=False, decrypt_reply=False, ) reply = result_docs[0] - # Process the response from the server. - await client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: - # Process the response from the server. 
if isinstance(exc, (NotPrimaryError, OperationFailure)): await client._process_response(exc.details, bwc.session) # type: ignore[arg-type] raise diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 200e85fa45..85a5ada6e3 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -257,16 +257,12 @@ async def write_command( op_id=bwc.op_id, command_name=bwc.name, use_conn_transport=True, - process_response=False, decrypt_reply=False, ) reply = result_docs[0] - # Process the response from the server. - await self.client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} - # Process the response from the server. if isinstance(exc, OperationFailure): await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] else: @@ -290,9 +286,8 @@ async def unack_write( published = dict(cmd) published["ops"] = op_docs published["nsInfo"] = ns_docs - reply: Mapping[str, Any] = {"ok": 1} try: - await run_unacknowledged_command( + result_docs, _, _ = await run_unacknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -310,6 +305,7 @@ async def unack_write( use_conn_transport=True, max_doc_size=bwc.max_bson_size, ) + reply: Mapping[str, Any] = result_docs[0] except Exception as exc: # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index 4bf7c06434..40e999da71 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -150,8 +150,7 @@ async def _run_command( :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; the bulk paths pass False (their commands are encrypted up front). 
:param use_conn_transport: Send/receive via ``conn.send_message`` / - ``conn.receive_message`` (cursor path) or ``conn.unack_write`` (bulk - unacknowledged) instead of the raw ``async_sendall`` / + ``conn.receive_message`` instead of the raw ``async_sendall`` / ``async_receive_message`` (network path). :param max_doc_size: The largest document size, for ``conn.send_message``. :param more_to_come: Receive only, without sending (exhaust ``getMore``). @@ -212,7 +211,11 @@ async def _run_command( reply = await conn.receive_message(None) elif unacknowledged: if use_conn_transport: - await conn.unack_write(msg, max_doc_size) + if not conn.is_writable: + raise NotPrimaryError( + "not primary", {"ok": 0, "errmsg": "not primary", "code": 10107} + ) + await conn.send_message(msg, max_doc_size) else: await async_sendall(conn.conn.get_conn, msg) # Unacknowledged, fake a successful command response. @@ -380,9 +383,7 @@ async def run_acknowledged_command( :param use_conn_transport: Send/receive via ``conn.send_message`` / ``conn.receive_message`` (bulk path) instead of the raw ``async_sendall`` / ``async_receive_message`` (standard command path). - :param process_response: Run ``client._process_response`` here; the bulk - paths pass False and process the reply at the call site to keep their - check -> APM-succeed -> process ordering. + :param process_response: Run ``client._process_response`` here. :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; the bulk paths pass False (their commands are encrypted up front). @@ -440,9 +441,9 @@ async def run_unacknowledged_command( The message is sent only -- no reply is received -- so the response processing, command checking, and decryption steps are skipped. - :param use_conn_transport: Send via ``conn.unack_write`` (bulk path) instead + :param use_conn_transport: Send via ``conn.send_message`` (bulk path) instead of the raw ``async_sendall`` (standard command path). 
- :param max_doc_size: The largest document size, for ``conn.unack_write``. + :param max_doc_size: The largest document size, for ``conn.send_message``. See :func:`_run_command` for the remaining parameters. """ diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index d22e300adc..60ba236c3d 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -393,7 +393,8 @@ async def command( self.send_cluster_time(spec, session, client) listeners = self.listeners if publish_events else None unacknowledged = bool(write_concern and not write_concern.acknowledged) - self._raise_if_not_writable(unacknowledged) + if unacknowledged and not self.is_writable: + raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) try: if session is not None and session._starting_transaction: session._transaction.set_in_progress() @@ -454,25 +455,6 @@ async def receive_message(self, request_id: Optional[int]) -> _OpMsg: except BaseException as error: await self._raise_connection_failure(error) - def _raise_if_not_writable(self, unacknowledged: bool) -> None: - """Raise NotPrimaryError on unacknowledged write if this socket is not - writable. - """ - if unacknowledged and not self.is_writable: - # Write won't succeed, bail as if we'd received a not primary error. - raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) - - async def unack_write(self, msg: bytes, max_doc_size: int) -> None: - """Send unack OP_MSG. - - Can raise ConnectionFailure or InvalidDocument. - - :param msg: bytes, an OP_MSG message. - :param max_doc_size: size in bytes of the largest document in `msg`. - """ - self._raise_if_not_writable(True) - await self.send_message(msg, max_doc_size) - async def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. 
diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 0205c7bda5..e79aa9437d 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -270,14 +270,10 @@ def write_command( op_id=bwc.op_id, command_name=bwc.name, use_conn_transport=True, - process_response=False, decrypt_reply=False, ) reply = result_docs[0] - # Process the response from the server. - client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: - # Process the response from the server. if isinstance(exc, (NotPrimaryError, OperationFailure)): client._process_response(exc.details, bwc.session) # type: ignore[arg-type] raise diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index c1d94f49ed..c4e133cbea 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -257,16 +257,12 @@ def write_command( op_id=bwc.op_id, command_name=bwc.name, use_conn_transport=True, - process_response=False, decrypt_reply=False, ) reply = result_docs[0] - # Process the response from the server. - self.client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} - # Process the response from the server. if isinstance(exc, OperationFailure): self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] else: @@ -290,9 +286,8 @@ def unack_write( published = dict(cmd) published["ops"] = op_docs published["nsInfo"] = ns_docs - reply: Mapping[str, Any] = {"ok": 1} try: - run_unacknowledged_command( + result_docs, _, _ = run_unacknowledged_command( bwc.conn, # type: ignore[arg-type] cmd, bwc.db_name, @@ -310,6 +305,7 @@ def unack_write( use_conn_transport=True, max_doc_size=bwc.max_bson_size, ) + reply: Mapping[str, Any] = result_docs[0] except Exception as exc: # Top-level error will be embedded in ClientBulkWriteException. 
reply = {"error": exc} diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 81b01a2b6a..4b67968591 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -150,8 +150,7 @@ def _run_command( :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; the bulk paths pass False (their commands are encrypted up front). :param use_conn_transport: Send/receive via ``conn.send_message`` / - ``conn.receive_message`` (cursor path) or ``conn.unack_write`` (bulk - unacknowledged) instead of the raw ``sendall`` / + ``conn.receive_message`` instead of the raw ``sendall`` / ``receive_message`` (network path). :param max_doc_size: The largest document size, for ``conn.send_message``. :param more_to_come: Receive only, without sending (exhaust ``getMore``). @@ -212,7 +211,11 @@ def _run_command( reply = conn.receive_message(None) elif unacknowledged: if use_conn_transport: - conn.unack_write(msg, max_doc_size) + if not conn.is_writable: + raise NotPrimaryError( + "not primary", {"ok": 0, "errmsg": "not primary", "code": 10107} + ) + conn.send_message(msg, max_doc_size) else: sendall(conn.conn.get_conn, msg) # Unacknowledged, fake a successful command response. @@ -380,9 +383,7 @@ def run_acknowledged_command( :param use_conn_transport: Send/receive via ``conn.send_message`` / ``conn.receive_message`` (bulk path) instead of the raw ``sendall`` / ``receive_message`` (standard command path). - :param process_response: Run ``client._process_response`` here; the bulk - paths pass False and process the reply at the call site to keep their - check -> APM-succeed -> process ordering. + :param process_response: Run ``client._process_response`` here. :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; the bulk paths pass False (their commands are encrypted up front). 
@@ -440,9 +441,9 @@ def run_unacknowledged_command( The message is sent only -- no reply is received -- so the response processing, command checking, and decryption steps are skipped. - :param use_conn_transport: Send via ``conn.unack_write`` (bulk path) instead + :param use_conn_transport: Send via ``conn.send_message`` (bulk path) instead of the raw ``sendall`` (standard command path). - :param max_doc_size: The largest document size, for ``conn.unack_write``. + :param max_doc_size: The largest document size, for ``conn.send_message``. See :func:`_run_command` for the remaining parameters. """ diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 1dd4b81a35..b619fcda95 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -393,7 +393,8 @@ def command( self.send_cluster_time(spec, session, client) listeners = self.listeners if publish_events else None unacknowledged = bool(write_concern and not write_concern.acknowledged) - self._raise_if_not_writable(unacknowledged) + if unacknowledged and not self.is_writable: + raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) try: if session is not None and session._starting_transaction: session._transaction.set_in_progress() @@ -454,25 +455,6 @@ def receive_message(self, request_id: Optional[int]) -> _OpMsg: except BaseException as error: self._raise_connection_failure(error) - def _raise_if_not_writable(self, unacknowledged: bool) -> None: - """Raise NotPrimaryError on unacknowledged write if this socket is not - writable. - """ - if unacknowledged and not self.is_writable: - # Write won't succeed, bail as if we'd received a not primary error. - raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) - - def unack_write(self, msg: bytes, max_doc_size: int) -> None: - """Send unack OP_MSG. - - Can raise ConnectionFailure or InvalidDocument. - - :param msg: bytes, an OP_MSG message. 
- :param max_doc_size: size in bytes of the largest document in `msg`. - """ - self._raise_if_not_writable(True) - self.send_message(msg, max_doc_size) - def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. From 9ff6d410e531c0b75b6b1a258aead61581408887 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Mon, 8 Jun 2026 20:20:15 -0400 Subject: [PATCH 09/15] Fix ImportError: remove _OpReply references dropped by PYTHON-5713 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PYTHON-5713 removed _OpReply from pymongo/message.py since OP_MSG is now used exclusively. Update command_runner.py and server.py to drop the _OpReply import and simplify Union[_OpReply, _OpMsg] → _OpMsg in all type annotations. Regenerate sync mirrors via synchro.py. --- pymongo/asynchronous/command_runner.py | 16 ++++++++-------- pymongo/asynchronous/server.py | 6 ++---- pymongo/synchronous/bulk.py | 11 ++++------- pymongo/synchronous/command_runner.py | 20 ++++++++++---------- pymongo/synchronous/server.py | 10 +++------- 5 files changed, 27 insertions(+), 36 deletions(-) diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index 40e999da71..bef6d61f93 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -54,7 +54,7 @@ from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.asynchronous.pool import AsyncConnection - from pymongo.message import _OpMsg, _OpReply + from pymongo.message import _OpMsg from pymongo.monitoring import _EventListeners from pymongo.pool_options import PoolOptions from pymongo.typings import _Address, _DocumentOut, _DocumentType @@ -96,9 +96,9 @@ async def _run_command( unpack_res: Optional[Callable[..., Any]] = None, cursor_id: Optional[int] = None, reply_doc_builder: Optional[ - Callable[[list[dict[str, 
Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + Callable[[list[dict[str, Any]], Optional[_OpMsg]], _DocumentOut] ] = None, -) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. This is the shared implementation behind :func:`run_acknowledged_command`, @@ -205,7 +205,7 @@ async def _run_command( service_id=conn.service_id, ) - reply: Optional[Union[_OpReply, _OpMsg]] + reply: Optional[_OpMsg] try: if more_to_come: reply = await conn.receive_message(None) @@ -372,7 +372,7 @@ async def run_acknowledged_command( use_conn_transport: bool = False, process_response: bool = True, decrypt_reply: bool = True, -) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: """Send an acknowledged command and return ``(docs, reply, duration)``. This is the entry point for standard commands and bulk write batches: it @@ -435,7 +435,7 @@ async def run_unacknowledged_command( speculative_hello: bool = False, use_conn_transport: bool = False, max_doc_size: int = 0, -) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: """Send an unacknowledged command and fake an ``{"ok": 1}`` reply. 
The message is sent only -- no reply is received -- so the response @@ -494,9 +494,9 @@ async def run_cursor_command( unpack_res: Optional[Callable[..., Any]] = None, cursor_id: Optional[int] = None, reply_doc_builder: Optional[ - Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + Callable[[list[dict[str, Any]], Optional[_OpMsg]], _DocumentOut] ] = None, -) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: """Run a cursor ``find``/``getMore`` operation over ``conn``. Uses the connection transport, leaves ``conn.more_to_come`` untouched (the diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 0c4fbed00f..57198621ad 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -33,7 +33,7 @@ _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _GetMore, _OpMsg, _OpReply, _Query +from pymongo.message import _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response if TYPE_CHECKING: @@ -169,9 +169,7 @@ async def run_operation( user_fields = _CURSOR_DOC_FIELDS if use_cmd else None - def _build_reply_doc( - docs: list[dict[str, Any]], reply: Optional[Union[_OpReply, _OpMsg]] - ) -> _DocumentOut: + def _build_reply_doc(docs: list[dict[str, Any]], reply: Optional[_OpMsg]) -> _DocumentOut: # Must publish in find / getMore / explain command response format. 
if use_cmd: return docs[0] diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index e79aa9437d..4cfd7f8106 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -35,7 +35,10 @@ from bson.raw_bson import RawBSONDocument from pymongo import _csot, common from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern -from pymongo.synchronous.command_runner import run_command +from pymongo.synchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) from pymongo.synchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, @@ -65,12 +68,6 @@ _randint, ) from pymongo.read_preferences import ReadPreference -from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern -from pymongo.synchronous.command_runner import ( - run_acknowledged_command, - run_unacknowledged_command, -) -from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern if TYPE_CHECKING: diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 4b67968591..e1876f00a2 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -51,12 +51,12 @@ if TYPE_CHECKING: from bson import CodecOptions + from pymongo.message import _OpMsg + from pymongo.monitoring import _EventListeners + from pymongo.pool_options import PoolOptions from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection - from pymongo.message import _OpMsg, _OpReply - from pymongo.monitoring import _EventListeners - from pymongo.pool_options import PoolOptions from pymongo.typings import _Address, _DocumentOut, _DocumentType _IS_SYNC = True @@ -96,9 +96,9 @@ def _run_command( unpack_res: Optional[Callable[..., Any]] = None, cursor_id: Optional[int] = None, 
reply_doc_builder: Optional[ - Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + Callable[[list[dict[str, Any]], Optional[_OpMsg]], _DocumentOut] ] = None, -) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: """Send ``msg`` over ``conn`` and return ``(docs, reply, duration)``. This is the shared implementation behind :func:`run_acknowledged_command`, @@ -205,7 +205,7 @@ def _run_command( service_id=conn.service_id, ) - reply: Optional[Union[_OpReply, _OpMsg]] + reply: Optional[_OpMsg] try: if more_to_come: reply = conn.receive_message(None) @@ -372,7 +372,7 @@ def run_acknowledged_command( use_conn_transport: bool = False, process_response: bool = True, decrypt_reply: bool = True, -) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: """Send an acknowledged command and return ``(docs, reply, duration)``. This is the entry point for standard commands and bulk write batches: it @@ -435,7 +435,7 @@ def run_unacknowledged_command( speculative_hello: bool = False, use_conn_transport: bool = False, max_doc_size: int = 0, -) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: """Send an unacknowledged command and fake an ``{"ok": 1}`` reply. 
The message is sent only -- no reply is received -- so the response @@ -494,9 +494,9 @@ def run_cursor_command( unpack_res: Optional[Callable[..., Any]] = None, cursor_id: Optional[int] = None, reply_doc_builder: Optional[ - Callable[[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]]], _DocumentOut] + Callable[[list[dict[str, Any]], Optional[_OpMsg]], _DocumentOut] ] = None, -) -> tuple[list[dict[str, Any]], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: +) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: """Run a cursor ``find``/``getMore`` operation over ``conn``. Uses the connection transport, leaves ``conn.more_to_come`` untouched (the diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 7041ee12e8..84c6aea30a 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -26,17 +26,15 @@ Union, ) -from pymongo.synchronous.command_runner import run_command +from pymongo.synchronous.command_runner import run_cursor_command from pymongo.synchronous.helpers import _handle_reauth from pymongo.logger import ( _SDAM_LOGGER, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _GetMore, _OpMsg, _OpReply, _Query +from pymongo.message import _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response -from pymongo.synchronous.command_runner import run_cursor_command -from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: from queue import Queue @@ -171,9 +169,7 @@ def run_operation( user_fields = _CURSOR_DOC_FIELDS if use_cmd else None - def _build_reply_doc( - docs: list[dict[str, Any]], reply: Optional[Union[_OpReply, _OpMsg]] - ) -> _DocumentOut: + def _build_reply_doc(docs: list[dict[str, Any]], reply: Optional[_OpMsg]) -> _DocumentOut: # Must publish in find / getMore / explain command response format. 
if use_cmd: return docs[0] From bef5a3530c8cfed21e71f411a911a99040e40e1d Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Mon, 8 Jun 2026 20:28:01 -0400 Subject: [PATCH 10/15] Steve feedback: deduplicate _raise_if_not_writable; drop redundant _process_response calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add _raise_if_not_writable() to AsyncConnection (and sync mirror); use it in both pool.command() and command_runner._run_command() so the NotPrimaryError logic lives in one place. - bulk.write_command(): remove the try/except that manually called _process_response on error — run_acknowledged_command already handles it via process_response=True. - client_bulk.write_command(): same — remove the _process_response calls from the except block; the {"error": exc} wrapping for ClientBulkWriteException is kept. --- pymongo/asynchronous/bulk.py | 43 +++++++++++--------------- pymongo/asynchronous/client_bulk.py | 4 --- pymongo/asynchronous/command_runner.py | 5 +-- pymongo/asynchronous/pool.py | 9 ++++-- pymongo/synchronous/bulk.py | 43 +++++++++++--------------- pymongo/synchronous/client_bulk.py | 4 --- pymongo/synchronous/command_runner.py | 5 +-- pymongo/synchronous/pool.py | 9 ++++-- 8 files changed, 52 insertions(+), 70 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 97d82818fd..13c79ecf9e 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -55,7 +55,6 @@ from pymongo.errors import ( ConfigurationError, InvalidOperation, - NotPrimaryError, OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES @@ -251,30 +250,24 @@ async def write_command( ) -> dict[str, Any]: """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs - try: - result_docs, _, _ = await run_acknowledged_command( - bwc.conn, # type: ignore[arg-type] - cmd, - bwc.db_name, - request_id, - msg, - client=client, - session=bwc.session, # type:
ignore[arg-type] - listeners=bwc.listeners, - address=bwc.conn.address, - start=bwc.start_time, - codec_options=bwc.codec, - op_id=bwc.op_id, - command_name=bwc.name, - use_conn_transport=True, - decrypt_reply=False, - ) - reply = result_docs[0] - except Exception as exc: - if isinstance(exc, (NotPrimaryError, OperationFailure)): - await client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - raise - return reply + result_docs, _, _ = await run_acknowledged_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + use_conn_transport=True, + decrypt_reply=False, + ) + return result_docs[0] async def unack_write( self, diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 85a5ada6e3..88c69a273f 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -263,10 +263,6 @@ async def write_command( except Exception as exc: # Top-level error will be embedded in ClientBulkWriteException. 
reply = {"error": exc} - if isinstance(exc, OperationFailure): - await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - else: - await self.client._process_response({}, bwc.session) # type: ignore[arg-type] return reply # type: ignore[return-value] async def unack_write( diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index bef6d61f93..80b53e1c19 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -211,10 +211,7 @@ async def _run_command( reply = await conn.receive_message(None) elif unacknowledged: if use_conn_transport: - if not conn.is_writable: - raise NotPrimaryError( - "not primary", {"ok": 0, "errmsg": "not primary", "code": 10107} - ) + conn._raise_if_not_writable() await conn.send_message(msg, max_doc_size) else: await async_sendall(conn.conn.get_conn, msg) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 60ba236c3d..f61c040fa9 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -393,8 +393,8 @@ async def command( self.send_cluster_time(spec, session, client) listeners = self.listeners if publish_events else None unacknowledged = bool(write_concern and not write_concern.acknowledged) - if unacknowledged and not self.is_writable: - raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) + if unacknowledged: + self._raise_if_not_writable() try: if session is not None and session._starting_transaction: session._transaction.set_in_progress() @@ -427,6 +427,11 @@ async def command( except BaseException as error: await self._raise_connection_failure(error) + def _raise_if_not_writable(self) -> None: + """Raise NotPrimaryError if this connection is not writable.""" + if not self.is_writable: + raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) + async def send_message(self, message: bytes, max_doc_size: int) -> 
None: """Send a raw BSON message or raise ConnectionFailure. diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 4cfd7f8106..511d226bf0 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -55,7 +55,6 @@ from pymongo.errors import ( ConfigurationError, InvalidOperation, - NotPrimaryError, OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES @@ -251,30 +250,24 @@ def write_command( ) -> dict[str, Any]: """Run a batch write command, returning the response as a dict.""" cmd[bwc.field] = docs - try: - result_docs, _, _ = run_acknowledged_command( - bwc.conn, # type: ignore[arg-type] - cmd, - bwc.db_name, - request_id, - msg, - client=client, - session=bwc.session, # type: ignore[arg-type] - listeners=bwc.listeners, - address=bwc.conn.address, - start=bwc.start_time, - codec_options=bwc.codec, - op_id=bwc.op_id, - command_name=bwc.name, - use_conn_transport=True, - decrypt_reply=False, - ) - reply = result_docs[0] - except Exception as exc: - if isinstance(exc, (NotPrimaryError, OperationFailure)): - client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - raise - return reply + result_docs, _, _ = run_acknowledged_command( + bwc.conn, # type: ignore[arg-type] + cmd, + bwc.db_name, + request_id, + msg, + client=client, + session=bwc.session, # type: ignore[arg-type] + listeners=bwc.listeners, + address=bwc.conn.address, + start=bwc.start_time, + codec_options=bwc.codec, + op_id=bwc.op_id, + command_name=bwc.name, + use_conn_transport=True, + decrypt_reply=False, + ) + return result_docs[0] def unack_write( self, diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index c4e133cbea..a472f90f99 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -263,10 +263,6 @@ def write_command( except Exception as exc: # Top-level error will be embedded in ClientBulkWriteException. 
reply = {"error": exc} - if isinstance(exc, OperationFailure): - self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - else: - self.client._process_response({}, bwc.session) # type: ignore[arg-type] return reply # type: ignore[return-value] def unack_write( diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index e1876f00a2..d49370caf2 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -211,10 +211,7 @@ def _run_command( reply = conn.receive_message(None) elif unacknowledged: if use_conn_transport: - if not conn.is_writable: - raise NotPrimaryError( - "not primary", {"ok": 0, "errmsg": "not primary", "code": 10107} - ) + conn._raise_if_not_writable() conn.send_message(msg, max_doc_size) else: sendall(conn.conn.get_conn, msg) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index b619fcda95..9270c36a69 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -393,8 +393,8 @@ def command( self.send_cluster_time(spec, session, client) listeners = self.listeners if publish_events else None unacknowledged = bool(write_concern and not write_concern.acknowledged) - if unacknowledged and not self.is_writable: - raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) + if unacknowledged: + self._raise_if_not_writable() try: if session is not None and session._starting_transaction: session._transaction.set_in_progress() @@ -427,6 +427,11 @@ def command( except BaseException as error: self._raise_connection_failure(error) + def _raise_if_not_writable(self) -> None: + """Raise NotPrimaryError if this connection is not writable.""" + if not self.is_writable: + raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) + def send_message(self, message: bytes, max_doc_size: int) -> None: """Send a raw BSON message or raise ConnectionFailure. 
From ebc74b6b4bbb480e37a71eec9f23c9a5229dd925 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Wed, 10 Jun 2026 11:02:34 -0400 Subject: [PATCH 11/15] Copilot feedback: update process_response param docs --- pymongo/asynchronous/command_runner.py | 5 ++--- pymongo/synchronous/command_runner.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index 80b53e1c19..967c3d6666 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -144,9 +144,8 @@ async def _run_command( APM redaction. :param ensure_db: Add ``$db`` to the published command if missing (cursor path), after the ``STARTED`` log has been emitted. - :param process_response: Run ``client._process_response`` on success here; - the bulk paths pass False and process the reply at the call site to - keep their check -> APM-succeed -> process ordering. + :param process_response: Run ``client._process_response`` on the response + document before ``_check_command_response`` and APM/log events. :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; the bulk paths pass False (their commands are encrypted up front). :param use_conn_transport: Send/receive via ``conn.send_message`` / diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index d49370caf2..075df0b6a9 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -144,9 +144,8 @@ def _run_command( APM redaction. :param ensure_db: Add ``$db`` to the published command if missing (cursor path), after the ``STARTED`` log has been emitted. - :param process_response: Run ``client._process_response`` on success here; - the bulk paths pass False and process the reply at the call site to - keep their check -> APM-succeed -> process ordering.
+ :param process_response: Run ``client._process_response`` on the response + document before ``_check_command_response`` and APM/log events. :param decrypt_reply: Decrypt the reply when auto-encryption is enabled; the bulk paths pass False (their commands are encrypted up front). :param use_conn_transport: Send/receive via ``conn.send_message`` / From 82ec67b004a6fda5f7eaed36012683dbf9e08a33 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 11 Jun 2026 14:48:29 -0400 Subject: [PATCH 12/15] ruff format --- pymongo/synchronous/bulk.py | 16 +++++++--------- pymongo/synchronous/client_bulk.py | 4 +--- pymongo/synchronous/server.py | 12 ++++++------ 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 511d226bf0..c8449af496 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -34,12 +34,6 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern -from pymongo.synchronous.command_runner import ( - run_acknowledged_command, - run_unacknowledged_command, -) -from pymongo.synchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( _COMMANDS, _DELETE_ALL, @@ -67,6 +61,12 @@ _randint, ) from pymongo.read_preferences import ReadPreference +from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern +from pymongo.synchronous.command_runner import ( + run_acknowledged_command, + run_unacknowledged_command, +) +from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -505,9 +505,7 @@ def retryable_bulk( _raise_bulk_write_error(full_result) return full_result - def execute_op_msg_no_results( - self, conn: Connection, generator: Iterator[Any] - ) -> None: + def execute_op_msg_no_results(self, conn: 
Connection, generator: Iterator[Any]) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name client = self.collection.database.client diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index a472f90f99..12b7a65c55 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -328,9 +328,7 @@ def _execute_batch( ) -> tuple[dict[str, Any], list[Mapping[str, Any]], list[Mapping[str, Any]]]: """Executes a batch of bulkWrite server commands (ack).""" request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces) - result = self.write_command( - bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client - ) # type: ignore[arg-type] + result = self.write_command(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type] return result, to_send_ops, to_send_ns # type: ignore[return-value] def _process_results_cursor( diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 84c6aea30a..2c524bb2c6 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -20,14 +20,12 @@ from typing import ( TYPE_CHECKING, Any, - ContextManager, Callable, + ContextManager, Optional, Union, ) -from pymongo.synchronous.command_runner import run_cursor_command -from pymongo.synchronous.helpers import _handle_reauth from pymongo.logger import ( _SDAM_LOGGER, _debug_log, @@ -35,18 +33,20 @@ ) from pymongo.message import _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response +from pymongo.synchronous.command_runner import run_cursor_command +from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: from queue import Queue from weakref import ReferenceType from bson.objectid import ObjectId - from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler - from pymongo.synchronous.monitor import Monitor - from 
pymongo.synchronous.pool import Connection, Pool from pymongo.monitoring import _EventListeners from pymongo.read_preferences import _ServerMode from pymongo.server_description import ServerDescription + from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler + from pymongo.synchronous.monitor import Monitor + from pymongo.synchronous.pool import Connection, Pool from pymongo.typings import _DocumentOut _IS_SYNC = True From 4d7fdf3cb6dc163bf9123aa226513317ace3a5f8 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 11 Jun 2026 15:43:33 -0400 Subject: [PATCH 13/15] Restore location of _raise_if_not_writable --- pymongo/asynchronous/pool.py | 10 +++++----- pymongo/synchronous/pool.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index f61c040fa9..94d54f04f5 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -427,11 +427,6 @@ async def command( except BaseException as error: await self._raise_connection_failure(error) - def _raise_if_not_writable(self) -> None: - """Raise NotPrimaryError if this connection is not writable.""" - if not self.is_writable: - raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) - async def send_message(self, message: bytes, max_doc_size: int) -> None: """Send a raw BSON message or raise ConnectionFailure. @@ -449,6 +444,11 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None: except BaseException as error: await self._raise_connection_failure(error) + def _raise_if_not_writable(self) -> None: + """Raise NotPrimaryError if this connection is not writable.""" + if not self.is_writable: + raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) + async def receive_message(self, request_id: Optional[int]) -> _OpMsg: """Receive a raw BSON message or raise ConnectionFailure. 
diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 9270c36a69..6d3ce24991 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -427,11 +427,6 @@ def command( except BaseException as error: self._raise_connection_failure(error) - def _raise_if_not_writable(self) -> None: - """Raise NotPrimaryError if this connection is not writable.""" - if not self.is_writable: - raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) - def send_message(self, message: bytes, max_doc_size: int) -> None: """Send a raw BSON message or raise ConnectionFailure. @@ -449,6 +444,11 @@ def send_message(self, message: bytes, max_doc_size: int) -> None: except BaseException as error: self._raise_connection_failure(error) + def _raise_if_not_writable(self) -> None: + """Raise NotPrimaryError if this connection is not writable.""" + if not self.is_writable: + raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) + def receive_message(self, request_id: Optional[int]) -> _OpMsg: """Receive a raw BSON message or raise ConnectionFailure. 
From 222d948f8629b1fb1a8853249a1631cb5ed618a0 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Thu, 11 Jun 2026 16:49:28 -0400 Subject: [PATCH 14/15] Restore location of _raise_if_not_writable --- pymongo/asynchronous/pool.py | 10 +++++----- pymongo/synchronous/pool.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 94d54f04f5..9e1fc87ae2 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -444,11 +444,6 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None: except BaseException as error: await self._raise_connection_failure(error) - def _raise_if_not_writable(self) -> None: - """Raise NotPrimaryError if this connection is not writable.""" - if not self.is_writable: - raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) - async def receive_message(self, request_id: Optional[int]) -> _OpMsg: """Receive a raw BSON message or raise ConnectionFailure. @@ -460,6 +455,11 @@ async def receive_message(self, request_id: Optional[int]) -> _OpMsg: except BaseException as error: await self._raise_connection_failure(error) + def _raise_if_not_writable(self) -> None: + """Raise NotPrimaryError if this connection is not writable.""" + if not self.is_writable: + raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) + async def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. 
diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 6d3ce24991..545892cc76 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -444,11 +444,6 @@ def send_message(self, message: bytes, max_doc_size: int) -> None: except BaseException as error: self._raise_connection_failure(error) - def _raise_if_not_writable(self) -> None: - """Raise NotPrimaryError if this connection is not writable.""" - if not self.is_writable: - raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) - def receive_message(self, request_id: Optional[int]) -> _OpMsg: """Receive a raw BSON message or raise ConnectionFailure. @@ -460,6 +455,11 @@ def receive_message(self, request_id: Optional[int]) -> _OpMsg: except BaseException as error: self._raise_connection_failure(error) + def _raise_if_not_writable(self) -> None: + """Raise NotPrimaryError if this connection is not writable.""" + if not self.is_writable: + raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) + def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. 
From b52b28de24c1eaa71ab9d4d8d5c7972bfc525f10 Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Fri, 12 Jun 2026 13:39:31 -0400 Subject: [PATCH 15/15] PYTHON-5676 Address Noah's review comments - Merge command_encoder.py into command_runner.py; update pool.py import - Fix _build_reply_doc regression in server.py: always return docs[0] - Remove dead code: is_command_response param from _run_command/run_cursor_command - Remove dead code: legacy_response passthrough in unpack_res call - Remove dead code: self.publish from _BulkWriteContextBase - Add explicit set_conn_more_to_come=False to bulk write_command call Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pymongo/asynchronous/bulk.py | 1 + pymongo/asynchronous/command_encoder.py | 188 ---------------------- pymongo/asynchronous/command_runner.py | 197 ++++++++++++++++++++---- pymongo/asynchronous/pool.py | 2 +- pymongo/asynchronous/server.py | 17 +- pymongo/message.py | 2 - pymongo/synchronous/bulk.py | 1 + pymongo/synchronous/command_runner.py | 197 ++++++++++++++++++++---- pymongo/synchronous/pool.py | 2 +- pymongo/synchronous/server.py | 17 +- 10 files changed, 340 insertions(+), 284 deletions(-) delete mode 100644 pymongo/asynchronous/command_encoder.py diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 13c79ecf9e..cc33d89371 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -266,6 +266,7 @@ async def write_command( command_name=bwc.name, use_conn_transport=True, decrypt_reply=False, + set_conn_more_to_come=False, ) return result_docs[0] diff --git a/pymongo/asynchronous/command_encoder.py b/pymongo/asynchronous/command_encoder.py deleted file mode 100644 index e60aabc8e4..0000000000 --- a/pymongo/asynchronous/command_encoder.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Encode a command and run it over a connection. - -This builds the wire-protocol message for a single command -- applying read -preference, read concern, collation, ``$clusterTime``, auto-encryption, CSOT, -and OP_MSG encoding -- then hands it to -:func:`pymongo.asynchronous.command_runner.run_acknowledged_command` for the -network round -trip. The raw socket I/O lives in :mod:`pymongo.network_layer`. -""" -from __future__ import annotations - -import datetime -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - MutableMapping, - Optional, - Sequence, - Union, -) - -from pymongo import _csot, message -from pymongo.asynchronous.command_runner import ( - run_acknowledged_command, - run_unacknowledged_command, -) -from pymongo.compression_support import _NO_COMPRESSION -from pymongo.message import _OpMsg -from pymongo.monitoring import _is_speculative_authenticate - -if TYPE_CHECKING: - from bson import CodecOptions - from pymongo.asynchronous.client_session import AsyncClientSession - from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.pool import AsyncConnection - from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext - from pymongo.monitoring import _EventListeners - from pymongo.read_concern import ReadConcern - from pymongo.read_preferences import _ServerMode - from pymongo.typings import _Address, _CollationIn, _DocumentType - from 
pymongo.write_concern import WriteConcern - -_IS_SYNC = False - - -async def command( - conn: AsyncConnection, - dbname: str, - spec: MutableMapping[str, Any], - is_mongos: bool, # noqa: ARG001 - read_preference: Optional[_ServerMode], - codec_options: CodecOptions[_DocumentType], - session: Optional[AsyncClientSession], - client: Optional[AsyncMongoClient[Any]], - check: bool = True, - allowable_errors: Optional[Sequence[Union[str, int]]] = None, - address: Optional[_Address] = None, - listeners: Optional[_EventListeners] = None, - max_bson_size: Optional[int] = None, - read_concern: Optional[ReadConcern] = None, - parse_write_concern_error: bool = False, - collation: Optional[_CollationIn] = None, - compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, - unacknowledged: bool = False, - user_fields: Optional[Mapping[str, Any]] = None, - exhaust_allowed: bool = False, - write_concern: Optional[WriteConcern] = None, -) -> _DocumentType: - """Execute a command over the socket, or raise socket.error. - - :param conn: a AsyncConnection instance - :param dbname: name of the database on which to run the command - :param spec: a command document as an ordered dict type, eg SON. - :param is_mongos: are we connected to a mongos? - :param read_preference: a read preference - :param codec_options: a CodecOptions instance - :param session: optional AsyncClientSession instance. - :param client: optional AsyncMongoClient instance for updating $clusterTime. - :param check: raise OperationFailure if there are errors - :param allowable_errors: errors to ignore if `check` is True - :param address: the (host, port) of `conn` - :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners` - :param max_bson_size: The maximum encoded bson size for this server - :param read_concern: The read concern for this command. - :param parse_write_concern_error: Whether to parse the ``writeConcernError`` - field in the command response. 
- :param collation: The collation for this command. - :param compression_ctx: optional compression Context. - :param unacknowledged: True if this is an unacknowledged command. - :param user_fields: Response fields that should be decoded - using the TypeDecoders from codec_options, passed to - bson._decode_all_selective. - :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed. - """ - name = next(iter(spec)) - speculative_hello = False - - # Publish the original command document, perhaps with lsid and $clusterTime. - orig = spec - if read_concern and not (session and session.in_transaction): - if read_concern.level: - spec["readConcern"] = read_concern.document - if session: - session._update_read_concern(spec, conn) - if collation is not None: - spec["collation"] = collation - - publish = listeners is not None and listeners.enabled_for_commands - start = datetime.datetime.now() - if publish: - speculative_hello = _is_speculative_authenticate(name, spec) - - if compression_ctx and name.lower() in _NO_COMPRESSION: - compression_ctx = None - - if client and client._encrypter and not client._encrypter._bypass_auto_encryption: - spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options) - - # Support CSOT - if client: - conn.apply_timeout(client, spec) - _csot.apply_write_concern(spec, write_concern) - - flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 - flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 - request_id, msg, size, max_doc_size = message._op_msg( - flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx - ) - # If this is an unacknowledged write then make sure the encoded doc(s) - # are small enough, otherwise rely on the server to return an error. 
- if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: - message._raise_document_too_large(name, size, max_bson_size) - - if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: - message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if unacknowledged: - docs, _, _ = await run_unacknowledged_command( - conn, - spec, - dbname, - request_id, - msg, - client=client, - session=session, - listeners=listeners, - address=address, - start=start, - codec_options=codec_options, - user_fields=user_fields, - orig=orig, - speculative_hello=speculative_hello, - ) - else: - docs, _, _ = await run_acknowledged_command( - conn, - spec, - dbname, - request_id, - msg, - client=client, - session=session, - listeners=listeners, - address=address, - start=start, - codec_options=codec_options, - user_fields=user_fields, - orig=orig, - check=check, - allowable_errors=allowable_errors, - parse_write_concern_error=parse_write_concern_error, - speculative_hello=speculative_hello, - ) - return docs[0] # type: ignore[return-value] diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index 967c3d6666..1779eed165 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The shared code path for executing a command over a connection. +"""Encoding and execution of commands over a connection. + +The public :func:`command` entry point applies read preference, read concern, +collation, ``$clusterTime``, auto-encryption, and CSOT to a command spec, +encodes it as an OP_MSG message, and then delegates to one of three lower-level +runners. 
Every database operation runs its network round trip through one of three public entry points -- :func:`run_acknowledged_command` (acknowledged commands @@ -43,10 +48,12 @@ ) from bson import _decode_all_selective -from pymongo import helpers_shared +from pymongo import _csot, helpers_shared, message +from pymongo.compression_support import _NO_COMPRESSION from pymongo.errors import NotPrimaryError, OperationFailure from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.message import _convert_exception +from pymongo.message import _convert_exception, _OpMsg +from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import async_receive_message, async_sendall if TYPE_CHECKING: @@ -54,10 +61,13 @@ from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.asynchronous.pool import AsyncConnection - from pymongo.message import _OpMsg + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.monitoring import _EventListeners from pymongo.pool_options import PoolOptions - from pymongo.typings import _Address, _DocumentOut, _DocumentType + from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.write_concern import WriteConcern _IS_SYNC = False @@ -92,7 +102,6 @@ async def _run_command( max_doc_size: int = 0, more_to_come: bool = False, set_conn_more_to_come: bool = True, - is_command_response: bool = True, unpack_res: Optional[Callable[..., Any]] = None, cursor_id: Optional[int] = None, reply_doc_builder: Optional[ @@ -156,9 +165,6 @@ async def _run_command( :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the network/streaming-monitor path); the cursor path manages exhaust separately and must leave ``conn.more_to_come`` 
untouched. - :param is_command_response: True if the reply is an OP_MSG command response - (``_process_response``/``_check_command_response``/decryption apply); - False for a legacy OP_QUERY cursor response. :param unpack_res: A callable decoding the wire response (cursor path); when ``None`` the reply's own ``unpack_response`` is used. :param cursor_id: The cursor id passed to ``unpack_res``. @@ -234,27 +240,25 @@ async def _run_command( reply, cursor_id, codec_options, - legacy_response=not is_command_response, user_fields=user_fields, ) else: docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) - if is_command_response: - response_doc = docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if process_response and client: - await client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - pool_opts=pool_opts, - ) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if process_response and client: + await client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + pool_opts=pool_opts, + ) except Exception as exc: duration = datetime.datetime.now() - start if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -335,7 +339,7 @@ async def _run_command( database_name=dbname, ) - if client and client._encrypter and reply and is_command_response and decrypt_reply: + if client and client._encrypter and reply and decrypt_reply: decrypted = await client._encrypter.decrypt(reply.raw_command_response()) docs = cast( "list[dict[str, Any]]", 
_decode_all_selective(decrypted, codec_options, user_fields) @@ -486,7 +490,6 @@ async def run_cursor_command( pool_opts: Optional[PoolOptions] = None, max_doc_size: int = 0, more_to_come: bool = False, - is_command_response: bool = True, unpack_res: Optional[Callable[..., Any]] = None, cursor_id: Optional[int] = None, reply_doc_builder: Optional[ @@ -500,8 +503,6 @@ async def run_cursor_command( the find/getMore command response format. :param more_to_come: Receive only, without sending (exhaust ``getMore``). - :param is_command_response: True for an OP_MSG command response; False for a - legacy OP_QUERY cursor response. :param unpack_res: A callable decoding the wire response. :param cursor_id: The cursor id passed to ``unpack_res``. :param reply_doc_builder: Builds the reply document published in the @@ -529,8 +530,142 @@ async def run_cursor_command( max_doc_size=max_doc_size, more_to_come=more_to_come, set_conn_more_to_come=False, - is_command_response=is_command_response, unpack_res=unpack_res, cursor_id=cursor_id, reply_doc_builder=reply_doc_builder, ) + + +async def command( + conn: AsyncConnection, + dbname: str, + spec: MutableMapping[str, Any], + is_mongos: bool, # noqa: ARG001 + read_preference: Optional[_ServerMode], + codec_options: CodecOptions[_DocumentType], + session: Optional[AsyncClientSession], + client: Optional[AsyncMongoClient[Any]], + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + address: Optional[_Address] = None, + listeners: Optional[_EventListeners] = None, + max_bson_size: Optional[int] = None, + read_concern: Optional[ReadConcern] = None, + parse_write_concern_error: bool = False, + collation: Optional[_CollationIn] = None, + compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, + unacknowledged: bool = False, + user_fields: Optional[Mapping[str, Any]] = None, + exhaust_allowed: bool = False, + write_concern: Optional[WriteConcern] = None, +) -> _DocumentType: + """Encode 
and execute a command over ``conn``, or raise socket.error. + + Applies read preference, read concern, collation, ``$clusterTime``, + auto-encryption, and CSOT to ``spec``, encodes it as an OP_MSG message, + and then delegates the network round trip and response processing to + :func:`run_acknowledged_command` or :func:`run_unacknowledged_command`. + + :param conn: a AsyncConnection instance + :param dbname: name of the database on which to run the command + :param spec: a command document as an ordered dict type, eg SON. + :param is_mongos: are we connected to a mongos? + :param read_preference: a read preference + :param codec_options: a CodecOptions instance + :param session: optional AsyncClientSession instance. + :param client: optional AsyncMongoClient instance for updating $clusterTime. + :param check: raise OperationFailure if there are errors + :param allowable_errors: errors to ignore if `check` is True + :param address: the (host, port) of `conn` + :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners` + :param max_bson_size: The maximum encoded bson size for this server + :param read_concern: The read concern for this command. + :param parse_write_concern_error: Whether to parse the ``writeConcernError`` + field in the command response. + :param collation: The collation for this command. + :param compression_ctx: optional compression Context. + :param unacknowledged: True if this is an unacknowledged command. + :param user_fields: Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. + :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed. + """ + name = next(iter(spec)) + speculative_hello = False + + # Publish the original command document, perhaps with lsid and $clusterTime. 
+ orig = spec + if read_concern and not (session and session.in_transaction): + if read_concern.level: + spec["readConcern"] = read_concern.document + if session: + session._update_read_concern(spec, conn) + if collation is not None: + spec["collation"] = collation + + publish = listeners is not None and listeners.enabled_for_commands + start = datetime.datetime.now() + if publish: + speculative_hello = _is_speculative_authenticate(name, spec) + + if compression_ctx and name.lower() in _NO_COMPRESSION: + compression_ctx = None + + if client and client._encrypter and not client._encrypter._bypass_auto_encryption: + spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options) + + # Support CSOT + if client: + conn.apply_timeout(client, spec) + _csot.apply_write_concern(spec, write_concern) + + flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 + flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 + request_id, msg, size, max_doc_size = message._op_msg( + flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx + ) + # If this is an unacknowledged write then make sure the encoded doc(s) + # are small enough, otherwise rely on the server to return an error. 
+ if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: + message._raise_document_too_large(name, size, max_bson_size) + + if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: + message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) + if unacknowledged: + docs, _, _ = await run_unacknowledged_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + speculative_hello=speculative_hello, + ) + else: + docs, _, _ = await run_acknowledged_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + ) + return docs[0] # type: ignore[return-value] diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 9e1fc87ae2..32cdd9c179 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -39,7 +39,7 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared from pymongo.asynchronous.client_session import _validate_session_write_concern -from pymongo.asynchronous.command_encoder import command +from pymongo.asynchronous.command_runner import command from pymongo.asynchronous.helpers import _handle_reauth from pymongo.common import ( MAX_BSON_SIZE, diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 57198621ad..faad122764 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -169,21 +169,9 @@ async def run_operation( user_fields = _CURSOR_DOC_FIELDS if use_cmd else None - def _build_reply_doc(docs: 
list[dict[str, Any]], reply: Optional[_OpMsg]) -> _DocumentOut: + def _build_reply_doc(docs: list[dict[str, Any]], reply: Optional[_OpMsg]) -> _DocumentOut: # noqa: ARG001 # Must publish in find / getMore / explain command response format. - if use_cmd: - return docs[0] - elif operation.name == "explain": - return docs[0] if docs else {} - res: dict[str, Any] = { - "cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, # type: ignore[union-attr] - "ok": 1, - } - if operation.name == "find": - res["cursor"]["firstBatch"] = docs - else: - res["cursor"]["nextBatch"] = docs - return res + return docs[0] docs, reply, duration = await run_cursor_command( conn, @@ -202,7 +190,6 @@ def _build_reply_doc(docs: list[dict[str, Any]], reply: Optional[_OpMsg]) -> _Do pool_opts=conn.opts, max_doc_size=max_doc_size, more_to_come=bool(more_to_come), - is_command_response=use_cmd, unpack_res=unpack_res, cursor_id=operation.cursor_id, reply_doc_builder=_build_reply_doc, diff --git a/pymongo/message.py b/pymongo/message.py index b6209b9df0..7fa40c35e3 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -447,7 +447,6 @@ class _BulkWriteContextBase: "op_id", "name", "field", - "publish", "start_time", "listeners", "session", @@ -471,7 +470,6 @@ def __init__( self.conn = conn self.op_id = operation_id self.listeners = listeners - self.publish = listeners.enabled_for_commands self.name = cmd_name self.field = _FIELD_MAP[self.name] self.start_time = datetime.datetime.now() diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index c8449af496..509759dac6 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -266,6 +266,7 @@ def write_command( command_name=bwc.name, use_conn_transport=True, decrypt_reply=False, + set_conn_more_to_come=False, ) return result_docs[0] diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 075df0b6a9..75090809a6 100644 --- a/pymongo/synchronous/command_runner.py 
+++ b/pymongo/synchronous/command_runner.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The shared code path for executing a command over a connection. +"""Encoding and execution of commands over a connection. + +The public :func:`command` entry point applies read preference, read concern, +collation, ``$clusterTime``, auto-encryption, and CSOT to a command spec, +encodes it as an OP_MSG message, and then delegates to one of three lower-level +runners. Every database operation runs its network round trip through one of three public entry points -- :func:`run_acknowledged_command` (acknowledged commands @@ -43,21 +48,26 @@ ) from bson import _decode_all_selective -from pymongo import helpers_shared +from pymongo import _csot, helpers_shared, message +from pymongo.compression_support import _NO_COMPRESSION from pymongo.errors import NotPrimaryError, OperationFailure from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.message import _convert_exception +from pymongo.message import _convert_exception, _OpMsg +from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import receive_message, sendall if TYPE_CHECKING: from bson import CodecOptions - from pymongo.message import _OpMsg + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.monitoring import _EventListeners from pymongo.pool_options import PoolOptions + from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection - from pymongo.typings import _Address, _DocumentOut, _DocumentType + from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.write_concern import WriteConcern _IS_SYNC = True 
@@ -92,7 +102,6 @@ def _run_command( max_doc_size: int = 0, more_to_come: bool = False, set_conn_more_to_come: bool = True, - is_command_response: bool = True, unpack_res: Optional[Callable[..., Any]] = None, cursor_id: Optional[int] = None, reply_doc_builder: Optional[ @@ -156,9 +165,6 @@ def _run_command( :param set_conn_more_to_come: Store ``reply.more_to_come`` on ``conn`` (the network/streaming-monitor path); the cursor path manages exhaust separately and must leave ``conn.more_to_come`` untouched. - :param is_command_response: True if the reply is an OP_MSG command response - (``_process_response``/``_check_command_response``/decryption apply); - False for a legacy OP_QUERY cursor response. :param unpack_res: A callable decoding the wire response (cursor path); when ``None`` the reply's own ``unpack_response`` is used. :param cursor_id: The cursor id passed to ``unpack_res``. @@ -234,27 +240,25 @@ def _run_command( reply, cursor_id, codec_options, - legacy_response=not is_command_response, user_fields=user_fields, ) else: docs = reply.unpack_response(codec_options=codec_options, user_fields=user_fields) - if is_command_response: - response_doc = docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if process_response and client: - client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - pool_opts=pool_opts, - ) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if process_response and client: + client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + 
pool_opts=pool_opts, + ) except Exception as exc: duration = datetime.datetime.now() - start if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -335,7 +339,7 @@ def _run_command( database_name=dbname, ) - if client and client._encrypter and reply and is_command_response and decrypt_reply: + if client and client._encrypter and reply and decrypt_reply: decrypted = client._encrypter.decrypt(reply.raw_command_response()) docs = cast( "list[dict[str, Any]]", _decode_all_selective(decrypted, codec_options, user_fields) @@ -486,7 +490,6 @@ def run_cursor_command( pool_opts: Optional[PoolOptions] = None, max_doc_size: int = 0, more_to_come: bool = False, - is_command_response: bool = True, unpack_res: Optional[Callable[..., Any]] = None, cursor_id: Optional[int] = None, reply_doc_builder: Optional[ @@ -500,8 +503,6 @@ def run_cursor_command( the find/getMore command response format. :param more_to_come: Receive only, without sending (exhaust ``getMore``). - :param is_command_response: True for an OP_MSG command response; False for a - legacy OP_QUERY cursor response. :param unpack_res: A callable decoding the wire response. :param cursor_id: The cursor id passed to ``unpack_res``. 
:param reply_doc_builder: Builds the reply document published in the @@ -529,8 +530,142 @@ def run_cursor_command( max_doc_size=max_doc_size, more_to_come=more_to_come, set_conn_more_to_come=False, - is_command_response=is_command_response, unpack_res=unpack_res, cursor_id=cursor_id, reply_doc_builder=reply_doc_builder, ) + + +def command( + conn: Connection, + dbname: str, + spec: MutableMapping[str, Any], + is_mongos: bool, # noqa: ARG001 + read_preference: Optional[_ServerMode], + codec_options: CodecOptions[_DocumentType], + session: Optional[ClientSession], + client: Optional[MongoClient[Any]], + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + address: Optional[_Address] = None, + listeners: Optional[_EventListeners] = None, + max_bson_size: Optional[int] = None, + read_concern: Optional[ReadConcern] = None, + parse_write_concern_error: bool = False, + collation: Optional[_CollationIn] = None, + compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, + unacknowledged: bool = False, + user_fields: Optional[Mapping[str, Any]] = None, + exhaust_allowed: bool = False, + write_concern: Optional[WriteConcern] = None, +) -> _DocumentType: + """Encode and execute a command over ``conn``, or raise socket.error. + + Applies read preference, read concern, collation, ``$clusterTime``, + auto-encryption, and CSOT to ``spec``, encodes it as an OP_MSG message, + and then delegates the network round trip and response processing to + :func:`run_acknowledged_command` or :func:`run_unacknowledged_command`. + + :param conn: a Connection instance + :param dbname: name of the database on which to run the command + :param spec: a command document as an ordered dict type, eg SON. + :param is_mongos: are we connected to a mongos? + :param read_preference: a read preference + :param codec_options: a CodecOptions instance + :param session: optional ClientSession instance. 
+ :param client: optional MongoClient instance for updating $clusterTime. + :param check: raise OperationFailure if there are errors + :param allowable_errors: errors to ignore if `check` is True + :param address: the (host, port) of `conn` + :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners` + :param max_bson_size: The maximum encoded bson size for this server + :param read_concern: The read concern for this command. + :param parse_write_concern_error: Whether to parse the ``writeConcernError`` + field in the command response. + :param collation: The collation for this command. + :param compression_ctx: optional compression Context. + :param unacknowledged: True if this is an unacknowledged command. + :param user_fields: Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. + :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed. + """ + name = next(iter(spec)) + speculative_hello = False + + # Publish the original command document, perhaps with lsid and $clusterTime. 
+ orig = spec + if read_concern and not (session and session.in_transaction): + if read_concern.level: + spec["readConcern"] = read_concern.document + if session: + session._update_read_concern(spec, conn) + if collation is not None: + spec["collation"] = collation + + publish = listeners is not None and listeners.enabled_for_commands + start = datetime.datetime.now() + if publish: + speculative_hello = _is_speculative_authenticate(name, spec) + + if compression_ctx and name.lower() in _NO_COMPRESSION: + compression_ctx = None + + if client and client._encrypter and not client._encrypter._bypass_auto_encryption: + spec = orig = client._encrypter.encrypt(dbname, spec, codec_options) + + # Support CSOT + if client: + conn.apply_timeout(client, spec) + _csot.apply_write_concern(spec, write_concern) + + flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 + flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 + request_id, msg, size, max_doc_size = message._op_msg( + flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx + ) + # If this is an unacknowledged write then make sure the encoded doc(s) + # are small enough, otherwise rely on the server to return an error. 
+ if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: + message._raise_document_too_large(name, size, max_bson_size) + + if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: + message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) + if unacknowledged: + docs, _, _ = run_unacknowledged_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + speculative_hello=speculative_hello, + ) + else: + docs, _, _ = run_acknowledged_command( + conn, + spec, + dbname, + request_id, + msg, + client=client, + session=session, + listeners=listeners, + address=address, + start=start, + codec_options=codec_options, + user_fields=user_fields, + orig=orig, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + speculative_hello=speculative_hello, + ) + return docs[0] # type: ignore[return-value] diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 545892cc76..ccb3be2fe1 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -88,7 +88,7 @@ from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker from pymongo.synchronous.client_session import _validate_session_write_concern -from pymongo.synchronous.command_encoder import command +from pymongo.synchronous.command_runner import command from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 2c524bb2c6..ec1654877b 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -169,21 +169,9 @@ def run_operation( user_fields = _CURSOR_DOC_FIELDS if use_cmd else None - def _build_reply_doc(docs: list[dict[str, Any]], reply: 
Optional[_OpMsg]) -> _DocumentOut: + def _build_reply_doc(docs: list[dict[str, Any]], reply: Optional[_OpMsg]) -> _DocumentOut: # noqa: ARG001 # Must publish in find / getMore / explain command response format. - if use_cmd: - return docs[0] - elif operation.name == "explain": - return docs[0] if docs else {} - res: dict[str, Any] = { - "cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, # type: ignore[union-attr] - "ok": 1, - } - if operation.name == "find": - res["cursor"]["firstBatch"] = docs - else: - res["cursor"]["nextBatch"] = docs - return res + return docs[0] docs, reply, duration = run_cursor_command( conn, @@ -202,7 +190,6 @@ def _build_reply_doc(docs: list[dict[str, Any]], reply: Optional[_OpMsg]) -> _Do pool_opts=conn.opts, max_doc_size=max_doc_size, more_to_come=bool(more_to_come), - is_command_response=use_cmd, unpack_res=unpack_res, cursor_id=operation.cursor_id, reply_doc_builder=_build_reply_doc,