diff --git a/CHANGELOG.md b/CHANGELOG.md index 58dba66..8222e07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [dev] - YYYY-MM-DD ### Added -* Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224) +* Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224), [gh-295](https://github.com/IntelPython/mkl_fft/pull/295) ### Changed * In `mkl_fft.fftn` and `mkl_fft.ifftn`, improved checking of the shape argument `s` to use faster direct transforms more often. This makes performance more consistent between `mkl_fft.fftn/ifftn` and `mkl.interfaces`. [gh-283](https://github.com/IntelPython/mkl_fft/pull/283) diff --git a/mkl_fft/_patch_numpy.py b/mkl_fft/_patch_numpy.py index 5c6b1ca..e37cecb 100644 --- a/mkl_fft/_patch_numpy.py +++ b/mkl_fft/_patch_numpy.py @@ -25,6 +25,7 @@ """Define functions for patching NumPy with MKL-based NumPy interface.""" +import warnings from contextlib import ContextDecorator from threading import Lock, local @@ -85,11 +86,12 @@ def do_restore(self, verbose=False): with self._lock: local_count = getattr(self._tls, "local_count", 0) if local_count <= 0: - if verbose: - print( - "Warning: restore_numpy_fft called more times than " - "patch_numpy_fft in this thread." - ) + warnings.warn( + "restore_numpy_fft called more times than " + "patch_numpy_fft in this thread.", + RuntimeWarning, + stacklevel=2, + ) return self._tls.local_count -= 1 self._patch_count -= 1 diff --git a/mkl_fft/tests/test_patch.py b/mkl_fft/tests/test_patch.py index b68dc2d..4333306 100644 --- a/mkl_fft/tests/test_patch.py +++ b/mkl_fft/tests/test_patch.py @@ -24,6 +24,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import numpy as np +import pytest import mkl_fft import mkl_fft.interfaces.numpy_fft as _nfft @@ -78,3 +79,10 @@ def test_patch_reentrant(): assert not mkl_fft.is_patched() assert np.fft.fft.__module__ == old_module + + +def test_patch_warning(): + if mkl_fft.is_patched(): + pytest.skip("This test should not be run with a pre-patched NumPy.") + with pytest.warns(RuntimeWarning, match="restore_numpy_fft*"): + mkl_fft.restore_numpy_fft()