Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 50 additions & 163 deletions pymongo/asynchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from __future__ import annotations

import copy
import datetime
import logging
from collections.abc import MutableMapping
from itertools import islice
from typing import (
Expand All @@ -36,9 +34,10 @@
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.asynchronous.client_session import (
AsyncClientSession,
_validate_session_write_concern,
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
from pymongo.asynchronous.command_runner import (
run_acknowledged_command,
run_unacknowledged_command,
)
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.bulk_shared import (
Expand All @@ -56,18 +55,14 @@
from pymongo.errors import (
ConfigurationError,
InvalidOperation,
NotPrimaryError,
OperationFailure,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import (
_DELETE,
_INSERT,
_UPDATE,
_BulkWriteContext,
_convert_exception,
_convert_write_result,
_EncryptedBulkWriteContext,
_randint,
)
Expand Down Expand Up @@ -253,83 +248,27 @@ async def write_command(
docs: list[Mapping[str, Any]],
client: AsyncMongoClient[Any],
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
"""Run a batch write command, returning the response as a dict."""
cmd[bwc.field] = docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.STARTED,
clientId=client._topology_settings._topology_id,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._start(cmd, request_id, docs)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_BulkWriteContext._start, _BulkWriteContext._succeed, _BulkWriteContext._fail, and _ClientBulkWriteContext._start are all dead code.

try:
if bwc.session is not None and bwc.session._starting_transaction:
bwc.session._transaction.set_in_progress()
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.SUCCEEDED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
await client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.FAILED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)

if bwc.publish:
bwc._fail(request_id, failure, duration)
# Process the response from the server.
if isinstance(exc, (NotPrimaryError, OperationFailure)):
await client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
raise
return reply # type: ignore[return-value]
result_docs, _, _ = await run_acknowledged_command(
bwc.conn, # type: ignore[arg-type]
cmd,
bwc.db_name,
Comment on lines 252 to +256
request_id,
msg,
client=client,
session=bwc.session, # type: ignore[arg-type]
listeners=bwc.listeners,
address=bwc.conn.address,
start=bwc.start_time,
codec_options=bwc.codec,
op_id=bwc.op_id,
command_name=bwc.name,
use_conn_transport=True,
decrypt_reply=False,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have set_conn_more_to_come=False for clarity. The default value for run_acknowledged_command is True.

set_conn_more_to_come=False,
)
Comment on lines +266 to +270
return result_docs[0]

async def unack_write(
self,
Expand All @@ -341,83 +280,31 @@ async def unack_write(
docs: list[Mapping[str, Any]],
client: AsyncMongoClient[Any],
) -> Optional[Mapping[str, Any]]:
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.STARTED,
clientId=client._topology_settings._topology_id,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
cmd = bwc._start(cmd, request_id, docs)
try:
result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override]
duration = datetime.datetime.now() - bwc.start_time
if result is not None:
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_convert_write_result is also dead code now.

else:
# Comply with APM spec.
reply = {"ok": 1}
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.SUCCEEDED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration)
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, OperationFailure):
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
elif isinstance(exc, NotPrimaryError):
failure = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
message=_CommandStatusMessage.FAILED,
clientId=client._topology_settings._topology_id,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
assert bwc.start_time is not None
bwc._fail(request_id, failure, duration)
raise
return result # type: ignore[return-value]
"""Send an unacknowledged batch write command."""
# Historically the STARTED log omits the documents while the published
# CommandStartedEvent includes them, so log ``cmd`` but publish a copy
# carrying the ``docs`` field.
published = dict(cmd)
published[bwc.field] = docs
await run_unacknowledged_command(
bwc.conn, # type: ignore[arg-type]
cmd,
bwc.db_name,
request_id,
msg,
client=client,
session=bwc.session, # type: ignore[arg-type]
listeners=bwc.listeners,
address=bwc.conn.address,
start=bwc.start_time,
codec_options=bwc.codec,
op_id=bwc.op_id,
command_name=bwc.name,
orig=published,
use_conn_transport=True,
max_doc_size=max_doc_size,
)
return None

async def _execute_batch_unack(
self,
Expand Down Expand Up @@ -489,7 +376,7 @@ async def _execute_command(
run = self.current_run

# AsyncConnection.command validates the session, but we use
# AsyncConnection.write_command
# run_acknowledged_command/run_unacknowledged_command.
conn.validate_session(client, session)
last_run = False

Expand Down
Loading
Loading