diff --git a/source/http_connection.c b/source/http_connection.c index 6a000fe40..bf6340e66 100644 --- a/source/http_connection.c +++ b/source/http_connection.c @@ -7,6 +7,7 @@ #include "io.h" #include +#include #include #include #include @@ -25,8 +26,12 @@ struct http_connection_binding { /* Reference to python object that reference to other related python object to keep it alive */ PyObject *py_core; - bool release_called; - bool shutdown_called; + /** + * Ref-count starting at 2 (one for release, one for shutdown). + * Each path decrements atomically; whoever decrements to 0 calls s_connection_destroy(). + * This is safe under Py_GIL_DISABLED where PyGILState_Ensure() provides no mutual exclusion. + */ + struct aws_atomic_var ref_count; }; static void s_connection_destroy(struct http_connection_binding *connection) { @@ -41,14 +46,11 @@ struct aws_http_connection *aws_py_get_http_connection(PyObject *connection) { } static void s_connection_release(struct http_connection_binding *connection) { - AWS_FATAL_ASSERT(!connection->release_called); - connection->release_called = true; - - bool destroy_after_release = connection->shutdown_called; - aws_http_connection_release(connection->native); - if (destroy_after_release) { + size_t prev = aws_atomic_fetch_sub(&connection->ref_count, 1); + AWS_FATAL_ASSERT(prev != 0); + if (prev == 1) { s_connection_destroy(connection); } } @@ -61,17 +63,12 @@ static void s_connection_capsule_destructor(PyObject *capsule) { static void s_on_connection_shutdown(struct aws_http_connection *native_connection, int error_code, void *user_data) { (void)native_connection; struct http_connection_binding *connection = user_data; - AWS_FATAL_ASSERT(!connection->shutdown_called); PyGILState_STATE state; if (aws_py_gilstate_ensure(&state)) { return; /* Python has shut down. Nothing matters anymore, but don't crash */ } - connection->shutdown_called = true; - - bool destroy_after_shutdown = connection->release_called; - /* Invoke on_shutdown, then clear our reference to it */ PyObject *result = PyObject_CallMethod(connection->py_core, "_on_shutdown", "(i)", error_code); @@ -82,7 +79,9 @@ static void s_on_connection_shutdown(struct aws_http_connection *native_connecti PyErr_WriteUnraisable(PyErr_Occurred()); } - if (destroy_after_shutdown) { + size_t prev = aws_atomic_fetch_sub(&connection->ref_count, 1); + AWS_FATAL_ASSERT(prev != 0); + if (prev == 1) { s_connection_destroy(connection); } @@ -281,6 +280,7 @@ PyObject *aws_py_http_client_connection_new(PyObject *self, PyObject *args) { } struct http_connection_binding *connection = aws_mem_calloc(allocator, 1, sizeof(struct http_connection_binding)); + aws_atomic_init_int(&connection->ref_count, 2); /* From hereon, we need to clean up if errors occur */ struct aws_http2_setting *http2_settings = NULL; size_t http2_settings_count = 0; diff --git a/test/test_http_connection_lifetime.py b/test/test_http_connection_lifetime.py new file mode 100644 index 000000000..4418bcc49 --- /dev/null +++ b/test/test_http_connection_lifetime.py @@ -0,0 +1,160 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0. + +import gc +import threading +import unittest +from test import NativeResourceTest +from http.server import HTTPServer, SimpleHTTPRequestHandler +from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup +from awscrt.http import HttpClientConnection, HttpRequest + + +class SilentHandler(SimpleHTTPRequestHandler): + def log_message(self, format, *args): + pass + + def do_GET(self): + self.send_response(200, 'OK') + self.send_header('Content-Length', '5') + self.end_headers() + self.wfile.write(b'hello') + + +class TestConnectionLifetime(NativeResourceTest): + """Tests for http_connection_binding ref-count based lifetime management. + + These tests exercise the two-path destruction handshake (capsule destructor + on application thread + on_connection_shutdown on event-loop thread) to + verify no double-free occurs after the atomic ref-count fix. + """ + hostname = 'localhost' + timeout = 10 + + def _start_server(self): + self.server = HTTPServer((self.hostname, 0), SilentHandler) + self.port = self.server.server_address[1] + self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.server_thread.start() + + def _stop_server(self): + self.server.shutdown() + self.server.server_close() + self.server_thread.join() + + def _new_connection(self): + event_loop_group = EventLoopGroup() + host_resolver = DefaultHostResolver(event_loop_group) + bootstrap = ClientBootstrap(event_loop_group, host_resolver) + future = HttpClientConnection.new( + host_name=self.hostname, + port=self.port, + bootstrap=bootstrap) + return future.result(self.timeout) + + def test_release_before_shutdown(self): + """Drop all Python references so capsule destructor fires first, + then shutdown callback fires. Binding must be destroyed exactly once.""" + self._start_server() + try: + connection = self._new_connection() + shutdown_future = connection.shutdown_future + + # Drop Python reference -> capsule destructor -> s_connection_release + del connection + gc.collect() + + # Shutdown callback fires on event-loop thread -> completes the ref pair + shutdown_future.result(self.timeout) + finally: + self._stop_server() + + def test_shutdown_before_release(self): + """Force shutdown callback to fire first via close(), then drop Python + reference so capsule destructor fires second.""" + self._start_server() + try: + connection = self._new_connection() + shutdown_future = connection.shutdown_future + + # Trigger shutdown on event-loop thread + connection.close() + shutdown_future.result(self.timeout) + + # Now drop Python reference -> capsule destructor fires second + del connection + gc.collect() + finally: + self._stop_server() + + def test_concurrent_release_and_shutdown_stress(self): + """Stress test: create many connections and race release against shutdown. + + Creates connections, immediately closes them (triggering shutdown on the + event-loop thread) and simultaneously drops the Python reference from + another thread. Under the old bool-flag approach with Py_GIL_DISABLED, + this would produce double-frees. With atomic ref-counting, exactly one + path destroys the binding. + """ + self._start_server() + try: + iterations = 50 + errors = [] + + def release_connection(conn): + try: + del conn + gc.collect() + except Exception as e: + errors.append(e) + + for _ in range(iterations): + connection = self._new_connection() + shutdown_future = connection.shutdown_future + + # Start close (fires shutdown on event-loop thread) + connection.close() + + # Concurrently drop the Python reference from another thread + t = threading.Thread(target=release_connection, args=(connection,)) + del connection + t.start() + + # Wait for both paths to complete + shutdown_future.result(self.timeout) + t.join(self.timeout) + + self.assertEqual([], errors) + finally: + self._stop_server() + + def test_multiple_connections_sequential_lifecycle(self): + """Create and destroy multiple connections sequentially to verify + no corruption from one connection's destruction affects the next.""" + self._start_server() + try: + for _ in range(20): + connection = self._new_connection() + self.assertTrue(connection.is_open()) + + request = HttpRequest('GET', '/') + stream = connection.request(request) + stream.activate() + stream.completion_future.result(self.timeout) + del stream + del request + + shutdown_future = connection.shutdown_future + del connection + gc.collect() + + try: + shutdown_future.result(self.timeout) + except Exception: + pass + finally: + self._stop_server() + + +if __name__ == '__main__': + unittest.main()