From 211ed340755b70120cab1c644cba43dc6974d7e1 Mon Sep 17 00:00:00 2001 From: gngpp Date: Fri, 27 Mar 2026 22:43:00 +0800 Subject: [PATCH 1/2] feat(client): implement graceful shutdown for Client --- Cargo.toml | 3 ++- python/wreq/blocking.py | 23 +++++++++++++++++++++++ python/wreq/wreq.py | 23 +++++++++++++++++++++++ src/client.rs | 28 ++++++++++++++++++++++++---- src/client/nogil.rs | 20 ++++++++++++++++++++ src/client/resp/http.rs | 3 ++- 6 files changed, 94 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3c20ced0..9da96fd7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,8 @@ abi3-py313 = ["pyo3/abi3-py313"] abi3-py314 = ["pyo3/abi3-py314"] [dependencies] -tokio = { version = "1.49.0", features = ["sync"]} +tokio = { version = "1.50.0", features = ["sync"]} +tokio-util = "0.7.18" pyo3 = { version = "0.28.2", features = [ "indexmap", "multiple-pymethods", diff --git a/python/wreq/blocking.py b/python/wreq/blocking.py index 7d8bd06e..0563a224 100644 --- a/python/wreq/blocking.py +++ b/python/wreq/blocking.py @@ -256,6 +256,29 @@ def __init__( """ ... + def close(self) -> None: + r""" + Closes the client and any associated resources. + + After calling this method, the client should not be used to make further requests. + + Examples: + + ```python + import asyncio + import wreq + + async def main(): + client = wreq.Client() + response = await client.get('https://httpbin.io/get') + print(await response.text()) + client.close() + + asyncio.run(main()) + ``` + """ + ... + def request( self, method: Method, diff --git a/python/wreq/wreq.py b/python/wreq/wreq.py index 29fbf494..fe0240fc 100644 --- a/python/wreq/wreq.py +++ b/python/wreq/wreq.py @@ -1136,6 +1136,29 @@ async def main(): """ ... + def close(self) -> None: + r""" + Closes the client and any associated resources. + + After calling this method, the client should not be used to make further requests. + + Examples: + + ```python + import asyncio + import wreq + + async def main(): + client = wreq.Client() + response = await client.get('https://httpbin.io/get') + print(await response.text()) + client.close() + + asyncio.run(main()) + ``` + """ + ... + async def request( self, method: Method, diff --git a/src/client.rs b/src/client.rs index 71aa875c..48509bcc 100644 --- a/src/client.rs +++ b/src/client.rs @@ -15,6 +15,7 @@ use std::{ use pyo3::{IntoPyObjectExt, coroutine::CancelHandle, prelude::*, pybacked::PyBackedStr}; use req::{Request, WebSocketRequest}; +use tokio_util::sync::CancellationToken; use wreq::{Proxy, tls::CertStore}; use wreq_util::EmulationOption; @@ -234,6 +235,7 @@ impl FromPyObject<'_, '_> for Builder { #[pyclass(subclass, frozen, skip_from_py_object)] pub struct Client { inner: wreq::Client, + cancel: CancellationToken, /// Get the cookie jar of the client. #[pyo3(get)] @@ -449,12 +451,22 @@ impl Client { builder .build() - .map(|inner| Client { inner, cookie_jar }) + .map(|inner| Client { + inner, + cancel: CancellationToken::new(), + cookie_jar, + }) .map_err(Error::Library) .map_err(Into::into) }) } + /// Close the client, preventing any new requests. + #[inline] + pub fn close(&self) { + self.cancel.cancel(); + } + /// Make a GET request to the given URL. #[inline(always)] #[pyo3(signature = (url, **kwds))] @@ -561,9 +573,10 @@ impl Client { url: PyBackedStr, kwds: Option, ) -> PyResult { - NoGIL::new( + NoGIL::new_with_token( execute_request(self.inner.clone(), method, url, kwds), cancel, + self.cancel.clone(), ) .await } @@ -577,9 +590,10 @@ impl Client { url: PyBackedStr, kwds: Option, ) -> PyResult { - NoGIL::new( + NoGIL::new_with_token( execute_websocket_request(self.inner.clone(), url, kwds), cancel, + self.cancel.clone(), ) .await } @@ -594,7 +608,7 @@ impl Client { #[inline] async fn __aexit__(&self, _exc_type: Py, _exc_val: Py, _traceback: Py) { - // TODO: Implement connection closing logic if necessary. + self.cancel.cancel(); } } @@ -617,6 +631,12 @@ impl BlockingClient { self.0.cookie_jar.clone() } + /// Close the client, preventing any new requests. + #[inline] + pub fn close(&self) { + self.0.close(); + } + /// Make a GET request to the specified URL. #[inline(always)] #[pyo3(signature = (url, **kwds))] diff --git a/src/client/nogil.rs b/src/client/nogil.rs index 84477dae..51592252 100644 --- a/src/client/nogil.rs +++ b/src/client/nogil.rs @@ -11,6 +11,7 @@ use pyo3::{ prelude::*, }; use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; pin_project! { /// A future that allows Python threads to run while it is being polled or executed. @@ -38,6 +39,25 @@ where } }) } } + + /// Create [`NoGIL`] from a future and a cancellation token + #[inline] + pub fn new_with_token( + fut: Fut, + mut cancel: CancelHandle, + cancel_token: CancellationToken, + ) -> Self + where + Fut: Future> + Send + 'static, + { + Self { handle: pyo3_async_runtimes::tokio::get_runtime().spawn(async move { + tokio::select! { + result = fut => result, + _ = cancel.cancelled() => Err(CancelledError::new_err("Operation was cancelled")), + _ = cancel_token.cancelled() => Err(CancelledError::new_err("Operation was cancelled: client has been closed")), + } + }) } + } } impl Future for NoGIL diff --git a/src/client/resp/http.rs b/src/client/resp/http.rs index 90e4ba2b..4c79b1a1 100644 --- a/src/client/resp/http.rs +++ b/src/client/resp/http.rs @@ -44,7 +44,7 @@ enum Body { } /// A blocking response from a request. -#[pyclass(name = "Response", subclass, frozen, str)] +#[pyclass(name = "Response", subclass, frozen, str, skip_from_py_object)] pub struct BlockingResponse(Response); // ===== impl Response ===== @@ -115,6 +115,7 @@ impl Response { /// Forcefully destroys the response [`Body`], preventing any further reads. fn destroy(&self) { + #[allow(clippy::option_map_unit_fn)] self.body .swap(None) .and_then(Arc::into_inner) From aef12438493b88dce8850e70cb00198f6cd4f4c8 Mon Sep 17 00:00:00 2001 From: gngpp Date: Fri, 27 Mar 2026 22:46:17 +0800 Subject: [PATCH 2/2] feat(client): implement graceful shutdown for Client --- python/wreq/blocking.py | 11 +++++------ python/wreq/cookie.py | 1 - python/wreq/wreq.py | 2 ++ 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/wreq/blocking.py b/python/wreq/blocking.py index 0563a224..3ad9a7c2 100644 --- a/python/wreq/blocking.py +++ b/python/wreq/blocking.py @@ -268,13 +268,12 @@ def close(self) -> None: import asyncio import wreq - async def main(): - client = wreq.Client() - response = await client.get('https://httpbin.io/get') - print(await response.text()) - client.close() + client = wreq.blocking.Client() + + response = client.get('https://httpbin.io/get') + print(response.text()) - asyncio.run(main()) + client.close() ``` """ ... diff --git a/python/wreq/cookie.py b/python/wreq/cookie.py index 2df877de..3dd1753c 100644 --- a/python/wreq/cookie.py +++ b/python/wreq/cookie.py @@ -9,7 +9,6 @@ import datetime from enum import Enum, auto from typing import Sequence, final -from warnings import deprecated __all__ = ["SameSite", "Cookie", "Jar"] diff --git a/python/wreq/wreq.py b/python/wreq/wreq.py index fe0240fc..2be42022 100644 --- a/python/wreq/wreq.py +++ b/python/wreq/wreq.py @@ -1150,8 +1150,10 @@ def close(self) -> None: async def main(): client = wreq.Client() + response = await client.get('https://httpbin.io/get') print(await response.text()) + client.close() asyncio.run(main())