-
Notifications
You must be signed in to change notification settings - Fork 75
Expand file tree
/
Copy pathtest_all_objects.py
More file actions
301 lines (241 loc) · 12.2 KB
/
test_all_objects.py
File metadata and controls
301 lines (241 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
"""Automated tests based on the skbase test suite template."""
from inspect import isclass
import shutil
from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator
from skbase.testing import QuickTester as _QuickTester
from skbase.testing import TestAllObjects as _TestAllObjects
from hyperactive.registry import all_objects
from hyperactive.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
from hyperactive.tests._doctest import run_doctest
# Whether to test only estimators from modules that are changed w.r.t. main.
# Default is False; can be set to True by the pytest flag
# ``--only_changed_modules True``.
# NOTE(review): no active code in the visible part of this file reads this
# flag; the filtering step described in the comments of
# ``BaseFixtureGenerator._all_objects`` is currently disabled - confirm
# whether the flag is still needed.
ONLY_CHANGED_MODULES = False
class PackageConfig:
    """Package-level configuration shared by the test classes of this module.

    Every setting is a plain class attribute, so descendants can override
    any of them individually.
    """

    # Package that is searched for objects.
    # Expected type: str - package/module name, relative to the python
    # environment root.
    package_name = "hyperactive"

    # Object types (class names) that are left out of retrieval.
    # Expected type: list of str, each str being a class name.
    exclude_objects = EXCLUDE_ESTIMATORS

    # Tests that are skipped for individual objects.
    # Expected type: dict of lists, key: str, value: List[str].
    # Keys are class names of estimators, values list the names of the tests
    # to exclude for that class.
    excluded_tests = EXCLUDED_TESTS

    # Tags that objects are allowed to carry.
    # Expected type: list of str, each str being a tag name.
    valid_tags = [
        # --- general tags ---
        "object_type",
        "python_dependencies",
        "authors",
        "maintainers",
        # --- experiments ---
        "property:randomness",
        # --- optimizers ---
        "info:name",  # str
        "info:local_vs_global",  # "local", "mixed", "global"
        "info:explore_vs_exploit",  # "explore", "exploit", "mixed"
        "info:compute",  # "low", "middle", "high"
    ]
class BaseFixtureGenerator(PackageConfig, _BaseFixtureGenerator):
    """Fixture generator for base testing functionality in hyperactive.

    Test classes inheriting from this and not overriding pytest_generate_tests
    will have object class and object instance fixtures parametrized out of
    the box.

    Descendants can override:
        object_type_filter: str, class variable; None or object type string
            e.g., "experiment", "optimizer"
            which objects are being retrieved and tested
        exclude_objects : str or list of str, or None, default=None
            names of object classes to exclude in retrieval;
            None = no objects are excluded
        excluded_tests : dict with str keys and list of str values, or None,
            default=None
            str keys must be object names, values must be lists of test names
            names of tests (values) to exclude for object with name as key
            None = no tests are excluded
        valid_tags : list of str or None, default = None
            list of valid tags, None = all tags are valid
        valid_base_types : list of str or None, default = None
            list of valid base types (strings), None = all base types are valid
        fixture_sequence: list of str
            sequence of fixture variable names in conditional fixture generation
        _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list
            generating list of fixtures for fixture variable with name [variable]
                to be used in test with name test_name
            can optionally use values for fixtures earlier in fixture_sequence,
                these must be input as kwargs in a call
        is_excluded: static method (test_name: str, est: class) -> bool
            whether test with name test_name should be excluded for object est
                should be used only for encoding general rules, not individual skips
                individual skips should go on the excluded_tests list
            requires _generate_object_class and _generate_object_instance as is

    Fixtures parametrized
    ---------------------
    object_class: class inheriting from BaseObject
        ranges over classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
    object_instance: object instances inheriting from BaseObject
        ranges over classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
        instances are generated by create_test_instance class method of object_class
    """

    # overrides object retrieval in scikit-base
    def _all_objects(self):
        """Retrieve list of all object classes of type self.object_type_filter.

        If self.object_type_filter is None (or not set), retrieve all objects.
        If it is a class, look up the objects matching that class'
        "object_type" tag and keep only subclasses of it.
        Otherwise (assumed str or list of str), retrieve all classes with tag
        object_type in self.object_type_filter.

        Returns
        -------
        list of classes
            result of ``all_objects`` with ``self.exclude_objects`` excluded,
            narrowed to subclasses of the filter if the filter is a class
        """
        # named ``type_filter`` rather than ``filter`` to avoid shadowing the builtin
        type_filter = getattr(self, "object_type_filter", None)
        filter_is_class = isclass(type_filter)

        if filter_is_class:
            object_types = type_filter.get_class_tag("object_type", None)
        else:
            object_types = type_filter

        obj_list = all_objects(
            object_types=object_types,
            return_names=False,
            exclude_objects=self.exclude_objects,
        )

        # the tag lookup can match classes outside the filter's hierarchy,
        # so narrow down to genuine subclasses
        if filter_is_class:
            obj_list = [obj for obj in obj_list if issubclass(obj, type_filter)]

        # NOTE: an additional filtering step via ``run_test_for_class`` (selecting
        # objects based on whether they have changed, per ONLY_CHANGED_MODULES,
        # and whether the python env satisfies their python_dependencies tag)
        # is currently disabled; all retrieved objects are returned.
        return obj_list

    # which sequence the conditional fixtures are generated in
    fixture_sequence = ["object_class", "object_instance"]
class TestAllObjects(BaseFixtureGenerator, _TestAllObjects):
    """Generic tests that apply to all objects in the package."""

    def test_doctest_examples(self, object_class):
        """Run the doctests found on the given object class."""
        doctest_label = f"class {object_class.__name__}"
        run_doctest(object_class, name=doctest_label)
class ExperimentFixtureGenerator(BaseFixtureGenerator):
    """Fixture generator restricted to objects of type "experiment".

    Fixtures parameterized
    ----------------------
    object_class: class inheriting from BaseObject
        ranges over classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
    object_instance: object instances inheriting from BaseObject
        ranges over classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
        instances are generated by create_test_instance class method of object_class
    """

    # only objects whose "object_type" tag matches this value are retrieved
    object_type_filter = "experiment"
class TestAllExperiments(ExperimentFixtureGenerator, _QuickTester):
    """Module level tests for all experiment classes."""

    def test_paramnames(self, object_class):
        """Check that the score parameters are among the experiment's paramnames.

        Pairs each test parameter set from ``get_test_params`` with the
        corresponding entry of ``_get_score_params`` and checks that the keys
        of the latter are a subset of ``paramnames()`` of the constructed
        experiment.
        """
        init_kwargs_list = object_class.get_test_params()
        score_kwargs_list = object_class._get_score_params()

        for init_kwargs, score_kwargs in zip(init_kwargs_list, score_kwargs_list):
            experiment = object_class(**init_kwargs)
            paramnames = experiment.paramnames()

            score_keys = set(score_kwargs.keys())
            assert score_keys <= set(paramnames), (
                f"Parameter names do not match: {paramnames} != {score_kwargs}"
            )

    def test_score_function(self, object_class):
        """Check the return contract of ``score`` and of calling the experiment.

        ``score`` must return a ``(float, dict)`` tuple, calling the
        experiment must return a float, and for experiments tagged as
        deterministic both scores must be equal.
        """
        init_kwargs_list = object_class.get_test_params()
        score_kwargs_list = object_class._get_score_params()

        for init_kwargs, score_kwargs in zip(init_kwargs_list, score_kwargs_list):
            experiment = object_class(**init_kwargs)

            # contract of score: a (score, metadata) pair
            res = experiment.score(score_kwargs)
            msg = f"Score function did not return a length two tuple: {res}"
            assert isinstance(res, tuple) and len(res) == 2, msg

            score, metadata = res
            assert isinstance(score, float), f"Score is not a float: {score}"
            assert isinstance(metadata, dict), f"Metadata is not a dict: {metadata}"

            # contract of __call__: parameters as keyword arguments, float result
            call_sc = experiment(**score_kwargs)
            assert isinstance(call_sc, float), f"Score is not a float: {call_sc}"

            # both routes can only be compared if the experiment is deterministic
            randomness = experiment.get_tag("property:randomness")
            if randomness == "deterministic":
                assert score == call_sc, f"Score does not match: {score} != {call_sc}"
class OptimizerFixtureGenerator(BaseFixtureGenerator):
    """Fixture generator restricted to objects of type "optimizer".

    Fixtures parameterized
    ----------------------
    object_class: class inheriting from BaseObject
        ranges over classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
    object_instance: object instances inheriting from BaseObject
        ranges over classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
        instances are generated by create_test_instance class method of object_class
    """

    # only objects whose "object_type" tag matches this value are retrieved
    object_type_filter = "optimizer"
class TestAllOptimizers(OptimizerFixtureGenerator, _QuickTester):
    """Module level tests for all optimizer classes."""

    def test_opt_run(self, object_instance):
        """Test that run returns the expected result.

        Checks, in order:

        * the optimizer has an ``experiment`` parameter, and it is the last
          argument of ``__init__``
        * the test instance carries an ``experiment`` attribute that has an
          ``object_type`` tag equal to ``"experiment"``
        * ``run`` returns a dict whose keys are a subset of the experiment's
          ``paramnames()``
        * ``run`` writes ``best_params_`` to the optimizer, equal to the
          returned dict

        Raises
        ------
        ValueError
            if one of the structural requirements above is violated
        AssertionError
            if the return of ``run`` is not a dict or has unexpected keys
        """
        # function-scope import, same style as the imports in test_gfo_integration
        from inspect import signature

        paramnames = object_instance.get_params().keys()
        if "experiment" not in paramnames:
            raise ValueError("Optimizer must have an 'experiment' parameter.")

        # check that experiment occurs last in __init__ signature
        # the signature is inspected instead of __code__.co_varnames, because
        # co_varnames also lists local variables of __init__ and therefore
        # reports a wrong "last argument" as soon as __init__ defines a local
        init_argnames = list(signature(object_instance.__init__).parameters)
        if not init_argnames or init_argnames[-1] != "experiment":
            raise ValueError(
                "'experiment' parameter in optimizer __init__ must be last argument."
            )

        if not hasattr(object_instance, "experiment"):
            raise ValueError(
                "Optimizer test cases must have 'experiment' parameter defined."
            )
        experiment = object_instance.experiment

        # duck-typed check: must expose get_tag and be tagged as experiment
        msg = "experiment must be an instance of BaseExperiment."
        if not hasattr(experiment, "get_tag"):
            raise ValueError(msg)
        if experiment.get_tag("object_type") != "experiment":
            raise ValueError(msg)

        best_params = object_instance.run()

        assert isinstance(best_params, dict), "return of run is not a dict"

        paramnames = list(best_params.keys())
        expected_paramnames = experiment.paramnames()

        assert set(paramnames) <= set(expected_paramnames), (
            f"Optimizer run must return a dict with keys being paramnames "
            f"from the experiment, but found: {paramnames} !<= {expected_paramnames}"
        )

        msg = "Optimizer run must write best_params_ to self in run."
        if not hasattr(object_instance, "best_params_"):
            raise ValueError(msg)

        msg = "Optimizer best_params_ must equal the best_params returned by run."
        if object_instance.best_params_ != best_params:
            raise ValueError(msg)

    def test_gfo_integration(self, object_instance):
        """Integration test for optimizer end-to-end, for GFO optimizers only.

        Runs the optimizer on the sklearn tuning experiment.
        Instances that are not ``_BaseGFOadapter`` objects are not tested,
        the test returns immediately for them.
        """
        from hyperactive.opt._adapters._gfo import _BaseGFOadapter

        if not isinstance(object_instance, _BaseGFOadapter):
            return

        optimizer = object_instance

        # 1. define the experiment
        from hyperactive.experiment.integrations import SklearnCvExperiment
        from sklearn.datasets import load_iris
        from sklearn.svm import SVC
        from sklearn.metrics import accuracy_score
        from sklearn.model_selection import KFold

        X, y = load_iris(return_X_y=True)

        sklearn_exp = SklearnCvExperiment(
            estimator=SVC(),
            scoring=accuracy_score,
            cv=KFold(n_splits=3, shuffle=True),
            X=X,
            y=y,
        )

        # 2. set up the optimizer
        import numpy as np

        _config = {
            "search_space": {
                "C": np.array([0.01, 0.1, 1, 10]),
                "gamma": np.array([0.0001, 0.01, 0.1, 1, 10]),
            },
            "n_iter": 100,
            "experiment": sklearn_exp,
        }
        # configure a clone, so the fixture instance itself is left untouched
        optimizer = optimizer.clone().set_params(**_config)

        # 3. run the optimizer under test
        optimizer.run()

        best_params = optimizer.best_params_
        assert best_params is not None, "Best parameters should not be None"
        assert isinstance(best_params, dict), "Best parameters should be a dictionary"
        assert "C" in best_params, "Best parameters should contain 'C'"
        assert "gamma" in best_params, "Best parameters should contain 'gamma'"