Skip to content

Commit 11e83a8

Browse files
committed
factor patching tests into separate file
1 parent 5cd2258 commit 11e83a8

3 files changed

Lines changed: 102 additions & 12 deletions

File tree

mkl_umath/tests/__init__.py

Whitespace-only changes.

mkl_umath/tests/test_basic.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import numpy as np
2727
import pytest
2828

29-
import mkl_umath
3029
import mkl_umath._ufuncs as mu # pylint: disable=no-name-in-module
3130

3231
np.random.seed(42)
@@ -190,14 +189,3 @@ def test_reduce_complex(func, dtype):
190189
assert np.allclose(
191190
mkl_res, np_res
192191
), f"Results for '{func}[reduce]' do not match"
193-
194-
195-
def test_patch():
196-
mkl_umath.restore_numpy_umath()
197-
assert not mkl_umath.is_patched()
198-
199-
mkl_umath.patch_numpy_umath() # Enable mkl_umath in Numpy
200-
assert mkl_umath.is_patched()
201-
202-
mkl_umath.restore_numpy_umath() # Disable mkl_umath in Numpy
203-
assert not mkl_umath.is_patched()

mkl_umath/tests/test_patching.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright (c) 2019, Intel Corporation
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions are met:
5+
#
6+
# * Redistributions of source code must retain the above copyright notice,
7+
# this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of Intel Corporation nor the names of its contributors
12+
# may be used to endorse or promote products derived from this software
13+
# without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
19+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
26+
import mkl_umath
27+
28+
import contextlib
29+
import sys
30+
31+
from dataclasses import dataclass
32+
from io import StringIO
33+
34+
import pytest
35+
36+
37+
@dataclass
38+
class CapturedOutput():
39+
stdout: str
40+
41+
42+
@contextlib.contextmanager
43+
def capture_output():
44+
old_stdout = sys.stdout
45+
capturer = StringIO()
46+
sys.stdout = capturer
47+
output = CapturedOutput(stdout="")
48+
yield output
49+
sys.stdout = old_stdout
50+
output.stdout = capturer.getvalue()
51+
52+
53+
def test_patch_basic():
54+
mkl_umath.restore_numpy_umath()
55+
assert not mkl_umath.is_patched()
56+
57+
mkl_umath.patch_numpy_umath() # Enable mkl_umath in Numpy
58+
assert mkl_umath.is_patched()
59+
60+
mkl_umath.restore_numpy_umath() # Disable mkl_umath in Numpy
61+
assert not mkl_umath.is_patched()
62+
63+
64+
def test_patch_redundant_patching():
65+
assert not mkl_umath.is_patched()
66+
67+
mkl_umath.patch_numpy_umath()
68+
mkl_umath.patch_numpy_umath()
69+
70+
assert mkl_umath.is_patched()
71+
72+
mkl_umath.restore_numpy_umath()
73+
assert mkl_umath.is_patched()
74+
75+
mkl_umath.restore_numpy_umath()
76+
assert not mkl_umath.is_patched()
77+
78+
79+
def test_patch_reentrant():
80+
assert not mkl_umath.is_patched()
81+
82+
with mkl_umath.mkl_umath():
83+
assert mkl_umath.is_patched()
84+
85+
with mkl_umath.mkl_umath():
86+
assert mkl_umath.is_patched()
87+
88+
assert mkl_umath.is_patched()
89+
90+
assert not mkl_umath.is_patched()
91+
92+
93+
def test_patch_verbose():
94+
assert not mkl_umath.is_patched()
95+
96+
with capture_output() as output:
97+
mkl_umath.patch_numpy_umath(verbose=True)
98+
assert output.stdout
99+
assert mkl_umath.is_patched()
100+
101+
mkl_umath.restore_numpy_umath()
102+
assert not mkl_umath.is_patched()

0 commit comments

Comments
 (0)