Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions pymongo/pool_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
sock.setblocking(False)
await asyncio.get_running_loop().sock_connect(sock, host)
return sock
except OSError:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
Comment thread
NoahStapp marked this conversation as resolved.

Expand Down Expand Up @@ -231,6 +232,10 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
except OSError as e:
sock.close()
err = e # type: ignore[assignment]
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise

if err is not None:
raise err
Expand Down Expand Up @@ -282,19 +287,25 @@ async def _async_configured_socket(
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
try:
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise

ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
except BaseException:
# Protect against cancellation, _CertificateError, or interruption
# where the raw socket would otherwise leak.
ssl_sock.close()
raise


async def _configured_protocol_interface(
Expand Down Expand Up @@ -337,18 +348,19 @@ async def _configured_protocol_interface(
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
try:
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise

return AsyncNetworkingInterface((transport, protocol))
return AsyncNetworkingInterface((transport, protocol))
except BaseException:
# Protect against cancellation, _CertificateError, or interruption
# where the transport would otherwise leak.
transport.abort()
raise


def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
Expand Down
103 changes: 103 additions & 0 deletions test/asynchronous/test_async_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@
from __future__ import annotations

import asyncio
import functools
import socket as _socket
import ssl as _ssl

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do _socket and _ssl need to be private here ?

import sys
from test.asynchronous.utils import async_get_pool
from test.utils_shared import delay, one
from unittest.mock import patch

sys.path[0:0] = [""]

from test.asynchronous import AsyncIntegrationTest, async_client_context, connected

from pymongo import pool_shared


class TestAsyncCancellation(AsyncIntegrationTest):
async def test_async_cancellation_closes_connection(self):
Expand Down Expand Up @@ -127,3 +133,100 @@ async def task():
await task

self.assertTrue(change_stream._closed)

async def test_cancellation_closes_socket_during_create_connection(self):
address = (await async_client_context.host, await async_client_context.port)
options = (await async_get_pool(self.client)).opts

created_sockets: list[_socket.socket] = []
real_socket_cls = _socket.socket
target_task = None

def tracking_socket(*args, **kwargs):
s = real_socket_cls(*args, **kwargs)
if asyncio.current_task() is target_task:
created_sockets.append(s)
return s

loop = asyncio.get_running_loop()
real_sock_connect = loop.sock_connect
started = asyncio.Event()
block_forever = asyncio.Event()

async def slow_sock_connect(sock, addr):
if sock in created_sockets:
started.set()
await block_forever.wait()
return None
return await real_sock_connect(sock, addr)

with (
patch.object(_socket, "socket", tracking_socket),
patch.object(loop, "sock_connect", slow_sock_connect),
):
task = asyncio.create_task(pool_shared._async_create_connection(address, options))
target_task = task
await asyncio.wait_for(started.wait(), timeout=5)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(created_sockets, "expected at least one socket to be created")
for sock in created_sockets:
self.assertEqual(
sock.fileno(),
-1,
f"socket leaked across cancellation: {sock!r}",
)

async def test_cancellation_closes_socket_during_ssl_wrap_socket(self):
address = (await async_client_context.host, await async_client_context.port)
options = (await async_get_pool(self.client)).opts
fake_ssl_context = _ssl.create_default_context()

created_sockets: list[_socket.socket] = []
real_socket_cls = _socket.socket
target_task = None

def tracking_socket(*args, **kwargs):
s = real_socket_cls(*args, **kwargs)
if asyncio.current_task() is target_task:
created_sockets.append(s)
return s

loop = asyncio.get_running_loop()
real_run_in_executor = loop.run_in_executor
started = asyncio.Event()

def slow_run_in_executor(executor, func, *args):
# Need to unwrap the SNI branch here if present
inner = func.func if isinstance(func, functools.partial) else func
# Each `ctx.wrap_socket` access returns a fresh bound-method
# object, so we check the bound instance (__self__) instead
if (
getattr(inner, "__self__", None) is fake_ssl_context
and asyncio.current_task() is target_task
):
started.set()
# Return a future that never completes for cancellation.
return asyncio.get_running_loop().create_future()
return real_run_in_executor(executor, func, *args)

with (
patch.object(_socket, "socket", tracking_socket),
patch.object(loop, "run_in_executor", slow_run_in_executor),
patch.object(options, "_PoolOptions__ssl_context", fake_ssl_context),
):
task = asyncio.create_task(pool_shared._async_configured_socket(address, options))
target_task = task
await asyncio.wait_for(started.wait(), timeout=5)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(created_sockets, "expected at least one socket to be created")
for sock in created_sockets:
self.assertEqual(
sock.fileno(),
-1,
f"socket leaked across cancellation: {sock!r}",
)
Loading