Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [dev] - YYYY-MM-DD

### Added
* Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224)
* Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224), [gh-295](https://github.com/IntelPython/mkl_fft/pull/295)

### Changed
* In `mkl_fft.fftn` and `mkl_fft.ifftn`, improved checking of the shape argument `s` to use faster direct transforms more often. This makes performance more consistent between `mkl_fft.fftn/ifftn` and `mkl.interfaces`. [gh-283](https://github.com/IntelPython/mkl_fft/pull/283)
Expand Down
12 changes: 7 additions & 5 deletions mkl_fft/_patch_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

"""Define functions for patching NumPy with MKL-based NumPy interface."""

import warnings
from contextlib import ContextDecorator
from threading import Lock, local

Expand Down Expand Up @@ -85,11 +86,12 @@ def do_restore(self, verbose=False):
with self._lock:
local_count = getattr(self._tls, "local_count", 0)
if local_count <= 0:
if verbose:
print(
"Warning: restore_numpy_fft called more times than "
"patch_numpy_fft in this thread."
)
warnings.warn(
"restore_numpy_fft called more times than "
"patch_numpy_fft in this thread.",
RuntimeWarning,
stacklevel=2,
)
return
self._tls.local_count -= 1
self._patch_count -= 1
Expand Down
8 changes: 8 additions & 0 deletions mkl_fft/tests/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import pytest

import mkl_fft
import mkl_fft.interfaces.numpy_fft as _nfft
Expand Down Expand Up @@ -78,3 +79,10 @@ def test_patch_reentrant():

assert not mkl_fft.is_patched()
assert np.fft.fft.__module__ == old_module


def test_patch_warning():
if mkl_fft.is_patched():
pytest.skip("This test should not be run with a pre-patched NumPy.")
with pytest.warns(RuntimeWarning, match="restore_numpy_fft*"):
mkl_fft.restore_numpy_fft()
Loading