diff --git a/ema_workbench/analysis/cart.py b/ema_workbench/analysis/cart.py index a9fcf6f01..7b3b81fb3 100644 --- a/ema_workbench/analysis/cart.py +++ b/ema_workbench/analysis/cart.py @@ -5,6 +5,8 @@ """ +from __future__ import annotations + import contextlib import io import math @@ -13,6 +15,7 @@ import matplotlib.image as mpimg import matplotlib.pyplot as plt import numpy as np +import numpy.typing as npt import pandas as pd from sklearn import tree @@ -71,7 +74,13 @@ class CART(sdutil.OutputFormatterMixin): sep = "!?!" - def __init__(self, x, y, mass_min=0.05, mode=sdutil.RuleInductionType.BINARY): + def __init__( + self, + x: pd.DataFrame, + y: npt.NDArray, + mass_min: float = 0.05, + mode: sdutil.RuleInductionType = sdutil.RuleInductionType.BINARY, + ) -> None: """Init.""" with contextlib.suppress(KeyError): x = x.drop(["scenario"], axis=1) @@ -97,7 +106,7 @@ def __init__(self, x, y, mass_min=0.05, mode=sdutil.RuleInductionType.BINARY): self._stats = None @property - def boxes(self): + def boxes(self) -> list[pd.DataFrame]: """Return a list with the box limits for each terminal leaf. Returns @@ -176,7 +185,7 @@ def recurse(left, right, child, lineage=None): return self._boxes @property - def stats(self): + def stats(self) -> list[dict]: """Returns list with the scenario discovery statistics for each terminal leaf. Returns @@ -197,7 +206,7 @@ def stats(self): self._stats.append(boxstats) return self._stats - def _binary_stats(self, box, box_init): + def _binary_stats(self, box: pd.DataFrame, box_init: pd.DataFrame) -> dict: indices = sdutil._in_box(self.x, box) y_in_box = self.y[indices] @@ -211,7 +220,7 @@ def _binary_stats(self, box, box_init): } return boxstats - def _regression_stats(self, box, box_init): + def _regression_stats(self, box: pd.DataFrame, box_init: pd.DataFrame) -> dict: indices = sdutil._in_box(self.x, box) y_in_box = self.y[indices] @@ -223,7 +232,7 @@ def _regression_stats(self, box, box_init): } return boxstats - def _classification_stats(self, box, box_init): + def _classification_stats(self, box: pd.DataFrame, box_init: pd.DataFrame) -> dict: indices = sdutil._in_box(self.x, box) y_in_box = self.y[indices] @@ -252,7 +261,7 @@ def _classification_stats(self, box, box_init): sdutil.RuleInductionType.CLASSIFICATION: _classification_stats, } - def build_tree(self): + def build_tree(self) -> None: """Train CART on the data.""" min_samples = int(self.mass_min * self.x.shape[0]) @@ -262,7 +271,7 @@ def build_tree(self): self.clf = tree.DecisionTreeClassifier(min_samples_leaf=min_samples) self.clf.fit(self._x, self.y) - def show_tree(self, mplfig=True, format="png"): + def show_tree(self, mplfig: bool = True, format: str = "png"): """Return a png (defaults) or svg of the tree. On Windows, graphviz needs to be installed with conda. diff --git a/ema_workbench/analysis/pairs_plotting.py b/ema_workbench/analysis/pairs_plotting.py index 0f1dc8d84..fcef55b7e 100644 --- a/ema_workbench/analysis/pairs_plotting.py +++ b/ema_workbench/analysis/pairs_plotting.py @@ -1,9 +1,13 @@ """R-style pairs plotting functionality.""" +from __future__ import annotations + +import matplotlib.axes import matplotlib.cm as cm import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import numpy as np +import pandas as pd from ..util import get_module_logger @@ -17,15 +21,15 @@ def pairs_lines( - experiments, - outcomes, - outcomes_to_show=None, - group_by=None, + experiments: pd.DataFrame, + outcomes: dict[str, np.ndarray], + outcomes_to_show: list[str] | None = None, + group_by: str | None = None, grouping_specifiers=None, - ylabels=None, - legend=True, + ylabels: dict[str, str] | None = None, + legend: bool = True, **kwargs, -): +) -> tuple[plt.Figure, dict[str, matplotlib.axes.Axes]]: """Generate a pairs lines multiplot. It shows the behavior of two outcomes over time against @@ -147,18 +151,18 @@ def simple_pairs_lines(ax, y_data, x_data, color): def pairs_density( - experiments, - outcomes, - outcomes_to_show=None, - group_by=None, + experiments: pd.DataFrame, + outcomes: dict[str, np.ndarray], + outcomes_to_show: list[str] | None = None, + group_by: str | None = None, grouping_specifiers=None, - ylabels=None, - point_in_time=-1, - log=True, - gridsize=50, - colormap="coolwarm", - filter_scalar=True, -): + ylabels: dict[str, str] | None = None, + point_in_time: int = -1, + log: bool = True, + gridsize: int = 50, + colormap: str = "coolwarm", + filter_scalar: bool = True, +) -> tuple[plt.Figure, dict[str, matplotlib.axes.Axes]]: """Generate a pairs hexbin density multiplot. In case of time-series data, the end states are used. @@ -392,17 +396,17 @@ def simple_pairs_density( def pairs_scatter( - experiments, - outcomes, - outcomes_to_show=None, - group_by=None, + experiments: pd.DataFrame, + outcomes: dict[str, np.ndarray], + outcomes_to_show: list[str] | None = None, + group_by: str | None = None, grouping_specifiers=None, - ylabels=None, - legend=True, - point_in_time=-1, - filter_scalar=False, + ylabels: dict[str, str] | None = None, + legend: bool = True, + point_in_time: int = -1, + filter_scalar: bool = False, **kwargs, -): +) -> tuple[plt.Figure, dict[str, matplotlib.axes.Axes]]: """Generate a pairs scatter multiplot. In case of time-series data, the end states are used. diff --git a/ema_workbench/analysis/parcoords.py b/ema_workbench/analysis/parcoords.py index f88cb9030..6ac886b42 100644 --- a/ema_workbench/analysis/parcoords.py +++ b/ema_workbench/analysis/parcoords.py @@ -1,5 +1,7 @@ """A general purpose matplotlib-based parallel coordinate plotting Class.""" +from __future__ import annotations + import matplotlib.pyplot as plt import matplotlib.ticker as ticker import pandas as pd @@ -14,7 +16,7 @@ __all__ = ["ParallelAxes", "get_limits"] -def setup_parallel_plot(labels, minima, maxima, formatter=None, fs=14, rot=90): +def setup_parallel_plot(labels: list[str], minima: pd.Series, maxima: pd.Series, formatter: dict[str, str] | None = None, fs: int = 14, rot: float = 90) -> tuple[plt.Figure, list[plt.Axes], dict]: """Helper function for setting up the parallel axes plot. Parameters @@ -100,7 +102,7 @@ def setup_parallel_plot(labels, minima, maxima, formatter=None, fs=14, rot=90): return fig, axes, tick_labels -def get_limits(data): +def get_limits(data: pd.DataFrame) -> pd.DataFrame: """Helper function to get limits of a FataFrame that can serve as input to ParallelAxis. Parameters @@ -165,7 +167,7 @@ class ParallelAxes: """ - def __init__(self, limits, formatter=None, fontsize=14, rot=90): + def __init__(self, limits: pd.DataFrame, formatter: dict[str, str] | None = None, fontsize: int = 14, rot: float = 90): """Init. Parameters @@ -216,7 +218,7 @@ def __init__(self, limits, formatter=None, fontsize=14, rot=90): plt.tight_layout(h_pad=0, w_pad=0) plt.subplots_adjust(wspace=0) - def plot(self, data, color=None, label=None, **kwargs): + def plot(self, data: pd.DataFrame | pd.Series, color=None, label: str | None = None, **kwargs) -> None: """Plot data on parallel axes. Parameters @@ -257,7 +259,7 @@ def plot(self, data, color=None, label=None, **kwargs): # plot the data self._plot(normalized_data, color=color, **kwargs) - def legend(self): + def legend(self) -> None: """Add a legend to the figure.""" artists = [] labels = [] @@ -299,7 +301,7 @@ def _plot(self, data, **kwargs): if label_j in self.flipped_axes: self._update_plot_data(ax, 1, lines=lines) - def invert_axis(self, axis): + def invert_axis(self, axis: str | list[str]) -> None: """Flip direction for specified axis. Parameters diff --git a/ema_workbench/analysis/plotting.py b/ema_workbench/analysis/plotting.py index b8796cc13..fb2ecd0d5 100644 --- a/ema_workbench/analysis/plotting.py +++ b/ema_workbench/analysis/plotting.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import numpy as np +import pandas as pd from matplotlib.patches import ConnectionPatch from ..util import EMAError, get_module_logger @@ -36,18 +37,18 @@ def envelopes( - experiments, - outcomes, - outcomes_to_show=None, - group_by=None, + experiments: pd.DataFrame, + outcomes: dict[str, np.ndarray], + outcomes_to_show: str | list[str] | None = None, + group_by: str | None = None, grouping_specifiers=None, - density=None, - fill=False, - legend=True, - titles=None, - ylabels=None, - log=False, -): + density: Density | None = None, + fill: bool = False, + legend: bool = True, + titles: dict[str, str] | None = None, + ylabels: dict[str, str] | None = None, + log: bool = False, +) -> tuple[plt.Figure, dict[str, plt.Axes]]: """Make envelop plots. An envelope shows over time the minimum and maximum value for a set @@ -260,19 +261,19 @@ def single_envelope(outcomes, outcome_to_plot, time, density, ax, ax_d, fill, lo def lines( - experiments, - outcomes, - outcomes_to_show=None, - group_by=None, + experiments: pd.DataFrame, + outcomes: dict[str, np.ndarray], + outcomes_to_show: str | list[str] | None = None, + group_by: str | None = None, grouping_specifiers=None, - density="", - legend=True, - titles=None, - ylabels=None, - experiments_to_show=None, - show_envelope=False, - log=False, -): + density: Density | str = "", + legend: bool = True, + titles: dict[str, str] | None = None, + ylabels: dict[str, str] | None = None, + experiments_to_show: np.ndarray | None = None, + show_envelope: bool = False, + log: bool = False, +) -> tuple[plt.Figure, dict[str, plt.Axes]]: """Visualize results from experiments as line plots. It is thus to be used in case of time @@ -601,13 +602,13 @@ def simple_lines(outcomes, outcome_to_plot, time, density, ax, ax_d, log): def kde_over_time( - experiments, - outcomes, - outcomes_to_show=None, - group_by=None, + experiments: pd.DataFrame, + outcomes: dict[str, np.ndarray], + outcomes_to_show: str | list[str] | None = None, + group_by: str | None = None, grouping_specifiers=None, - colormap="viridis", - log=True, + colormap: str = "viridis", + log: bool = True, ): """Plot a KDE over time. The KDE is visualized through a heatmap. @@ -679,21 +680,21 @@ def kde_over_time( def multiple_densities( - experiments, - outcomes, - points_in_time=None, - outcomes_to_show=None, - group_by=None, + experiments: pd.DataFrame, + outcomes: dict[str, np.ndarray], + points_in_time: list[float] | None = None, + outcomes_to_show: str | list[str] | None = None, + group_by: str | None = None, grouping_specifiers=None, - density=Density.KDE, - legend=True, - titles=None, - ylabels=None, - experiments_to_show=None, - plot_type=PlotType.ENVELOPE, - log=False, + density: Density = Density.KDE, + legend: bool = True, + titles: dict[str, str] | None = None, + ylabels: dict[str, str] | None = None, + experiments_to_show: np.ndarray | None = None, + plot_type: PlotType = PlotType.ENVELOPE, + log: bool = False, **kwargs, -): +) -> tuple[list[plt.Figure], dict[str, dict[str, plt.Axes]]]: """Make an envelope plot with multiple density plots over the run time. Parameters diff --git a/ema_workbench/analysis/plotting_util.py b/ema_workbench/analysis/plotting_util.py index b782cd310..de913f911 100644 --- a/ema_workbench/analysis/plotting_util.py +++ b/ema_workbench/analysis/plotting_util.py @@ -1,12 +1,18 @@ """Plotting utility functions.""" +from __future__ import annotations + import copy import enum +from typing import Any import matplotlib as mpl +import matplotlib.axes +import matplotlib.figure import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import numpy as np +import numpy.typing as npt import pandas as pd import scipy.stats as stats import seaborn as sns @@ -74,7 +80,7 @@ class PlotType(enum.Enum): """constant for plotting envelopes with lines.""" -def plot_envelope(ax, j, time, value, fill=False): +def plot_envelope(ax: matplotlib.axes.Axes, j: int, time: npt.NDArray, value: npt.NDArray, fill: bool = False) -> None: """Helper function, responsible for plotting an envelope. Parameters @@ -102,7 +108,7 @@ def plot_envelope(ax, j, time, value, fill=False): ax.plot(time, maximum, c=color) -def plot_histogram(ax, values, log): +def plot_histogram(ax: matplotlib.axes.Axes, values: npt.NDArray | list[npt.NDArray], log: bool) -> Any: """Helper function, responsible for plotting a histogram. Parameters @@ -131,7 +137,7 @@ def plot_histogram(ax, values, log): return a -def plot_kde(ax, values, log): +def plot_kde(ax: matplotlib.axes.Axes, values: list[npt.NDArray], log: bool) -> None: """Helper function, responsible for plotting a KDE. Parameters @@ -155,7 +161,7 @@ def plot_kde(ax, values, log): ax.set_xticklabels(labels) -def plot_boxplots(ax, values, log, group_labels=None): +def plot_boxplots(ax: matplotlib.axes.Axes, values: list[npt.NDArray], log: bool, group_labels: list[str] | None = None) -> None: """Helper function for plotting a boxplot. Parameters @@ -189,7 +195,7 @@ def plot_boxplots(ax, values, log, group_labels=None): sns.boxplot(x="id_var", y=0, data=data, order=group_labels, ax=ax) -def plot_violinplot(ax, values, log, group_labels=None): +def plot_violinplot(ax: matplotlib.axes.Axes, values: list[npt.NDArray], log: bool, group_labels: list[str] | None = None) -> None: """Helper function for plotting violin plots on axes. Parameters @@ -214,7 +220,7 @@ def plot_violinplot(ax, values, log, group_labels=None): sns.violinplot(x="variable", y="value", data=data, order=group_labels, ax=ax) -def plot_boxenplot(ax, values, log, group_labels=None): +def plot_boxenplot(ax: matplotlib.axes.Axes, values: list[npt.NDArray], log: bool, group_labels: list[str] | None = None) -> None: """Helper function for plotting boxenplot plots on axes. Parameters @@ -237,8 +243,8 @@ def plot_boxenplot(ax, values, log, group_labels=None): def group_density( - ax_d, density, outcomes, outcome_to_plot, group_labels, log=False, index=-1 -): + ax_d: matplotlib.axes.Axes, density: Density, outcomes: dict, outcome_to_plot: str, group_labels: list[str], log: bool = False, index: int = -1 +) -> None: """Helper function for plotting densities in case of grouped data. Parameters @@ -276,7 +282,7 @@ def group_density( ax_d.set_ylabel("") -def simple_density(density, value, ax_d, ax, log): +def simple_density(density: Density, value: npt.NDArray, ax_d: matplotlib.axes.Axes, ax: matplotlib.axes.Axes, log: bool) -> None: """Helper function, responsible for producing a density plot. Parameters @@ -313,7 +319,7 @@ def simple_density(density, value, ax_d, ax, log): ax_d.set_ylabel("") -def simple_kde(outcomes, outcomes_to_show, colormap, log, minima, maxima): +def simple_kde(outcomes: dict[str, npt.NDArray], outcomes_to_show: list[str], colormap: str, log: bool, minima: dict[str, float], maxima: dict[str, float]) -> tuple[matplotlib.figure.Figure, dict[str, matplotlib.axes.Axes]]: """Helper function for generating a density heatmap over time. Parameters @@ -360,7 +366,7 @@ def simple_kde(outcomes, outcomes_to_show, colormap, log, minima, maxima): return fig, axes_dict -def make_legend(categories, ax, ncol=3, legend_type=LegendEnum.LINE, alpha=1): +def make_legend(categories: list[str], ax: matplotlib.axes.Axes, ncol: int = 3, legend_type: LegendEnum = LegendEnum.LINE, alpha: float = 1) -> None: """Helper function responsible for making the legend. Parameters @@ -423,7 +429,7 @@ def make_legend(categories, ax, ncol=3, legend_type=LegendEnum.LINE, alpha=1): ) -def determine_kde(data, size_kde=1000, ymin=None, ymax=None): +def determine_kde(data: npt.NDArray, size_kde: int = 1000, ymin: float | None = None, ymax: float | None = None) -> tuple[npt.NDArray, npt.NDArray]: """Helper function responsible for performing a KDE. Parameters @@ -468,7 +474,7 @@ def determine_kde(data, size_kde=1000, ymin=None, ymax=None): return kde_x, kde_y -def filter_scalar_outcomes(outcomes): +def filter_scalar_outcomes(outcomes: dict[str, npt.NDArray]) -> dict[str, npt.NDArray]: """Helper function that removes non time series outcomes from all the utcomes. Parameters @@ -491,7 +497,7 @@ def filter_scalar_outcomes(outcomes): return temp -def determine_time_dimension(outcomes): +def determine_time_dimension(outcomes: dict[str, npt.NDArray]) -> tuple[npt.NDArray | None, dict[str, npt.NDArray]]: """Helper function for determining or creating time dimension. Parameters @@ -522,8 +528,8 @@ def determine_time_dimension(outcomes): def group_results( - experiments, outcomes, group_by, grouping_specifiers, grouping_labels -): + experiments: pd.DataFrame, outcomes: dict[str, npt.NDArray], group_by: str, grouping_specifiers, grouping_labels: list[str] +) -> dict: """Helper function that takes the experiments and results and returns a list based on groupoing. Each element in the dictionary contains the experiments @@ -596,7 +602,7 @@ def group_results( return groups -def make_continuous_grouping_specifiers(array, nr_of_groups=5): +def make_continuous_grouping_specifiers(array: npt.NDArray, nr_of_groups: int = 5) -> list[tuple[float, float]]: """Helper function for discretizing a continuous array. By default, the array is split into 5 equally wide intervals. @@ -630,13 +636,13 @@ def make_continuous_grouping_specifiers(array, nr_of_groups=5): def prepare_pairs_data( - experiments, - outcomes, - outcomes_to_show=None, - group_by=None, + experiments: pd.DataFrame, + outcomes: dict[str, npt.NDArray], + outcomes_to_show: list[str] | None = None, + group_by: str | None = None, grouping_specifiers=None, - point_in_time=-1, - filter_scalar=True, + point_in_time: int = -1, + filter_scalar: bool = True, ): """Helper function to prepare the data for pairs plotting. @@ -696,13 +702,13 @@ def filter_outcomes(outcomes, point_in_time): def prepare_data( - experiments, - experiments_to_show, - outcomes, - outcomes_to_show=None, - group_by=None, + experiments: pd.DataFrame, + experiments_to_show: npt.NDArray | None, + outcomes: dict[str, npt.NDArray], + outcomes_to_show: str | list[str] | None = None, + group_by: str | None = None, grouping_specifiers=None, - filter_scalar=True, + filter_scalar: bool = True, ): """Helper function for preparing datasets prior to plotting. @@ -787,7 +793,7 @@ def prepare_data( return experiments, outcomes, outcomes_to_show, time, grouping_labels -def do_titles(ax, titles, outcome): +def do_titles(ax: matplotlib.axes.Axes, titles: dict[str, str] | None, outcome: str) -> None: """Helper function for setting the title on an ax. Parameters @@ -812,7 +818,7 @@ def do_titles(ax, titles, outcome): ax.set_title(outcome) -def do_ylabels(ax, ylabels, outcome): +def do_ylabels(ax: matplotlib.axes.Axes, ylabels: dict[str, str] | None, outcome: str) -> None: """Helper function for setting the y labels on an ax. Parameters @@ -837,7 +843,7 @@ def do_ylabels(ax, ylabels, outcome): ax.set_ylabel(outcome) -def make_grid(outcomes_to_show, density=False): +def make_grid(outcomes_to_show: list[str], density: bool = False) -> tuple[matplotlib.figure.Figure, gridspec.GridSpec]: """Helper function for making the grid that specifies the size and location of all axes. Parameters @@ -858,7 +864,7 @@ def make_grid(outcomes_to_show, density=False): return figure, grid -def get_color(index): +def get_color(index: int): """Helper function for cycling over color list. Useful if the number of items is higher than the length of the color list. diff --git a/ema_workbench/analysis/prim.py b/ema_workbench/analysis/prim.py index 17adfadb7..5aadc068d 100644 --- a/ema_workbench/analysis/prim.py +++ b/ema_workbench/analysis/prim.py @@ -18,7 +18,7 @@ import itertools import numbers import warnings -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from operator import itemgetter from typing import Literal @@ -38,6 +38,7 @@ "altair based interactive inspection not available", ImportWarning, stacklevel=2 ) +from ..em_framework.util import NumpySeedLike, RNGLike from ..util import INFO, EMAError, get_module_logger, temporary_filter from . import scenario_discovery_util as sdutil from .prim_util import ( @@ -57,9 +58,6 @@ # # .. codeauthor:: jhkwakkel -SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence -RNGLike = np.random.Generator | np.random.BitGenerator - __all__ = [ "PRIMObjectiveFunctions", @@ -828,7 +826,7 @@ def resample( i: int | None = None, iterations: int = 10, p: float = 1 / 2, - rng: RNGLike | SeedLike | None = None, + rng: RNGLike | NumpySeedLike | None = None, ) -> pd.DataFrame: """Calculate resample statistics for candidate box i. diff --git a/ema_workbench/connectors/excel.py b/ema_workbench/connectors/excel.py index f30851b1d..c78d2c0ad 100644 --- a/ema_workbench/connectors/excel.py +++ b/ema_workbench/connectors/excel.py @@ -69,8 +69,8 @@ class BaseExcelModel(FileModel): com_warning_msg = "com error: no cell(s) named %s found" def __init__( - self, name, wd=None, model_file=None, default_sheet=None, pointers=None - ): + self, name: str, wd: str | None = None, model_file: str | None = None, default_sheet: str | None = None, pointers: dict[str, str] | None = None + ) -> None: super().__init__(name, wd=wd, model_file=model_file) #: Reference to the Excel application. This attribute is `None` until #: model_init has been invoked. @@ -91,11 +91,11 @@ def __init__( self.pointers = pointers @property - def workbook(self): + def workbook(self) -> str: return self.model_file @method_logger(__name__) - def model_init(self, policy): + def model_init(self, policy) -> None: """Method called to initialize the model. Parameters @@ -130,7 +130,7 @@ def model_init(self, policy): _logger.debug(self.working_directory) @method_logger(__name__) - def run_experiment(self, experiment): + def run_experiment(self, experiment) -> dict: """Method for running an experiment. This implementation assumes that the names of the uncertainties correspond @@ -174,7 +174,7 @@ def run_experiment(self, experiment): return results @method_logger(__name__) - def cleanup(self): + def cleanup(self) -> None: """Cleaning up prior to finishing performing experiments. This will close the workbook and close Excel @@ -199,7 +199,7 @@ def cleanup(self): self.xl = None self.wb = None - def get_sheet(self, sheetname=None): + def get_sheet(self, sheetname: str | None = None): """Get a named worksheet, or the default worksheet if set. Parameters @@ -227,7 +227,7 @@ def get_sheet(self, sheetname=None): return sheet - def get_wb_value(self, name): + def get_wb_value(self, name: str): """Extract a value from a cell of the excel workbook. Parameters @@ -267,7 +267,7 @@ def get_wb_value(self, name): return value - def set_wb_value(self, name, value): + def set_wb_value(self, name: str, value) -> None: """Inject a value into a cell of the excel workbook. Parameters @@ -304,7 +304,7 @@ def set_wb_value(self, name, value): f"com error: no cell(s) named {this_range} found on sheet {this_sheet}" ) - def get_wb_sheetnames(self): + def get_wb_sheetnames(self) -> list[str]: """Get the names of all the workbook's worksheets.""" if self.wb: try: diff --git a/ema_workbench/connectors/netlogo.py b/ema_workbench/connectors/netlogo.py index 99682bec8..17707913a 100644 --- a/ema_workbench/connectors/netlogo.py +++ b/ema_workbench/connectors/netlogo.py @@ -123,7 +123,7 @@ def __init__( self.jvm_args = jvm_args @method_logger(__name__) - def model_init(self, policy: Sample): + def model_init(self, policy: Sample) -> None: """Method called to initialize the model. Parameters @@ -148,7 +148,7 @@ def model_init(self, policy: Sample): _logger.debug("model opened") @method_logger(__name__) - def run_experiment(self, experiment: Experiment): + def run_experiment(self, experiment: Experiment) -> dict: """Method for running an experiment.. Parameters @@ -230,7 +230,7 @@ def run_experiment(self, experiment: Experiment): return results - def retrieve_output(self): + def retrieve_output(self) -> dict: """Method for retrieving output after a model run. Returns @@ -241,7 +241,7 @@ def retrieve_output(self): return self.output @method_logger(__name__) - def cleanup(self): + def cleanup(self) -> None: """Cleanup after finishing all the experiments, but just prior to returning the results. This method gives a hook for doing any cleanup, such as closing applications. @@ -259,7 +259,7 @@ def cleanup(self): # jpype.shutdownJVM() # self.netlogo = None - def _handle_outcomes(self, fns): + def _handle_outcomes(self, fns: dict[str, str]) -> dict: """Helper function for parsing outcomes.""" results = {} for key, value in fns.items(): diff --git a/ema_workbench/connectors/pysd_connector.py b/ema_workbench/connectors/pysd_connector.py index e5d4afe1a..480e6a7db 100644 --- a/ema_workbench/connectors/pysd_connector.py +++ b/ema_workbench/connectors/pysd_connector.py @@ -31,11 +31,11 @@ class BasePysdModel(AbstractModel): """ @property - def mdl_file(self): + def mdl_file(self) -> str: return self._mdl_file @mdl_file.setter - def mdl_file(self, mdl_file): + def mdl_file(self, mdl_file: str) -> None: if not mdl_file.endswith(".mdl"): raise ValueError("model file needs to be a vensim .mdl file") if not os.path.isfile(mdl_file): @@ -54,7 +54,7 @@ def __init__(self, name: str, mdl_file: str | None = None): self.model = None @method_logger(__name__) - def model_init(self, policy: Sample, **kwargs): + def model_init(self, policy: Sample, **kwargs) -> None: """Initialize the model.""" super().model_init(policy) @@ -63,7 +63,7 @@ def model_init(self, policy: Sample, **kwargs): self.model = pysd.read_vensim(self.mdl_file) @method_logger(__name__) - def run_experiment(self, experiment): + def run_experiment(self, experiment) -> dict: """Run the experiment.""" res = self.model.run(params=experiment, return_columns=self.output_variables) @@ -71,7 +71,7 @@ def run_experiment(self, experiment): output = {col: series.values for col, series in res.items()} return output - def reset_model(self): + def reset_model(self) -> None: """Method for resetting the model to its initial state. The default implementation only sets the outputs to an empty dict. diff --git a/ema_workbench/connectors/vadere.py b/ema_workbench/connectors/vadere.py index ef6ef0da2..0e3b6a38b 100644 --- a/ema_workbench/connectors/vadere.py +++ b/ema_workbench/connectors/vadere.py @@ -24,7 +24,7 @@ ] -def change_vadere_scenario(model_file, variable, value): +def change_vadere_scenario(model_file: dict, variable: str, value: float) -> None: """Change variable in vadere .scenario file structure. Note that a vadere scenario takes the format of a nested directory. @@ -47,7 +47,7 @@ def change_vadere_scenario(model_file, variable, value): reduce(operator.getitem, index[:-1], model_file)[index[-1]] = value -def update_vadere_scenario(model_file, experiment, output_file): +def update_vadere_scenario(model_file: str, experiment: dict, output_file: str) -> None: """Load a vadere .scenario file, change it depending on the passed experiment, and save it again as .scenario file. Parameters @@ -83,7 +83,7 @@ class BaseVadereModel(FileModel): """ - def __init__(self, name, vadere_jar, processor_files, wd, model_file): + def __init__(self, name: str, vadere_jar: str, processor_files: list[str], wd: str, model_file: str) -> None: """Init of class. Parameters @@ -120,7 +120,7 @@ def __init__(self, name, vadere_jar, processor_files, wd, model_file): self.processor_files = processor_files @method_logger(__name__) - def model_init(self, policy): + def model_init(self, policy) -> None: """Method called to initialize the model. Parameters @@ -133,7 +133,7 @@ def model_init(self, policy): super().model_init(policy) @method_logger(__name__) - def run_experiment(self, experiment): + def run_experiment(self, experiment) -> dict: """Run the experiment. Parameters @@ -234,7 +234,7 @@ def run_experiment(self, experiment): pass return res - def cleanup(self): + def cleanup(self) -> None: """Cleanup after performing all experiments. This method gives a hook for doing any cleanup, such as closing applications. diff --git a/ema_workbench/connectors/vensim.py b/ema_workbench/connectors/vensim.py index 027a626fa..95d0175fc 100644 --- a/ema_workbench/connectors/vensim.py +++ b/ema_workbench/connectors/vensim.py @@ -36,7 +36,7 @@ _logger = get_module_logger(__name__) -def be_quiet(): +def be_quiet() -> None: """Turn off the work in progress dialog of Vensim. Defaults to 2, suppressing all windows, for more fine-grained control, use @@ -45,7 +45,7 @@ def be_quiet(): vensim_dll_wrapper.be_quiet(2) -def load_model(file_name): +def load_model(file_name: str) -> None: """Load the model. Parameters @@ -68,7 +68,7 @@ def load_model(file_name): raise VensimError("vensim file not found") from w -def read_cin_file(file_name): +def read_cin_file(file_name: str) -> None: """Read a .cin file. Parameters @@ -89,7 +89,7 @@ def read_cin_file(file_name): raise w -def set_value(variable, value): +def set_value(variable: str, value: int | float | list) -> None: """Set the value of a variable to value. current implementation only works for lookups and normal values. In case @@ -117,7 +117,7 @@ def set_value(variable, value): _logger.warning("variable: '" + variable + "' not found") -def run_simulation(file_name): +def run_simulation(file_name: str) -> None: """Rn a model and store the results of the run in the specified .vdf file. The specified output file will be overwritten by default @@ -145,7 +145,7 @@ def run_simulation(file_name): raise VensimError(str(w)) from w -def get_data(filename, varname, step=1): +def get_data(filename: str, varname: str, step: int = 1) -> list: """Retrieve data from simulation runs or imported data sets. Parameters @@ -184,7 +184,7 @@ class VensimModel(SingleReplication, FileModel): """ @property - def result_file(self): + def result_file(self) -> str: """Return path to results file.""" return os.path.join(self.working_directory, self._result_file) @@ -245,7 +245,7 @@ def __init__( _logger.debug("vensim interface init completed") - def model_init(self, policy: Sample): + def model_init(self, policy: Sample) -> None: """Init of the model. Parameters @@ -290,7 +290,7 @@ def handle_underscores(variables: list[Variable]): raise EMAWarning(str(VensimWarning)) from w @method_logger(__name__) - def run_experiment(self, experiment: Experiment): + def run_experiment(self, experiment: Experiment) -> dict: """Run the experiment. The provided implementation assumes that the keys (i.e., the parameter names) in the @@ -358,7 +358,7 @@ def check_data(result): return results -def create_model_for_debugging(path_to_existing_model, path_to_new_model, error): +def create_model_for_debugging(path_to_existing_model: str, path_to_new_model: str, error: str) -> None: """Create a vensim mdl file parameterized according to the experiment. To be able to debug the Vensim model, a few steps are needed: diff --git a/ema_workbench/em_framework/callbacks.py b/ema_workbench/em_framework/callbacks.py index 524300569..8399385a9 100644 --- a/ema_workbench/em_framework/callbacks.py +++ b/ema_workbench/em_framework/callbacks.py @@ -10,11 +10,14 @@ """ +from __future__ import annotations + import abc import csv import math import os import shutil +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -27,9 +30,12 @@ IntegerParameter, Parameter, ) -from .points import flatten_sample +from .points import Experiment, flatten_sample from .util import ProgressTrackingMixIn +if TYPE_CHECKING: + from typing import Any + # # Created on 22 Jan 2013 # @@ -80,13 +86,13 @@ class AbstractCallback(ProgressTrackingMixIn, metaclass=abc.ABCMeta): def __init__( self, - uncertainties, - levers, - outcomes, - nr_experiments, - reporting_interval=None, - reporting_frequency=10, - log_progress=False, + uncertainties: list[Parameter], + levers: list[Parameter], + outcomes: list[Outcome], + nr_experiments: int, + reporting_interval: int | None = None, + reporting_frequency: int = 10, + log_progress: bool = False, ): """Init.""" super().__init__(nr_experiments, reporting_frequency, _logger, log_progress) @@ -102,7 +108,7 @@ def __init__( self.reporting_interval = reporting_interval @abc.abstractmethod - def __call__(self, experiment, outcomes): + def __call__(self, experiment: Experiment, outcomes: dict[str, Any]) -> None: """Method responsible for storing results. The implementation in this class only keeps track of how many runs @@ -120,7 +126,7 @@ def __call__(self, experiment, outcomes): super().__call__(1) @abc.abstractmethod - def get_results(self): + def get_results(self) -> Any: """Method for retrieving the results. Called after all experiments have been completed. Any extension of AbstractCallback needs to @@ -294,7 +300,7 @@ def _store_outcomes(self, case_id, outcomes): ) self.results[outcome_name][case_id,] = outcome_res - def __call__(self, experiment, outcomes): + def __call__(self, experiment: Experiment, outcomes: dict[str, Any]) -> None: """Method responsible for storing results. This method calls :meth:`super` first, thus utilizing the logging provided there. @@ -310,7 +316,7 @@ def __call__(self, experiment, outcomes): self._store_case(experiment) self._store_outcomes(experiment.experiment_id, outcomes) - def get_results(self): + def get_results(self) -> tuple[pd.DataFrame, dict[str, np.ndarray]]: """Return the experiments and their results.""" results = {} for k, v in self.results.items(): @@ -376,12 +382,12 @@ class FileBasedCallback(AbstractCallback): def __init__( self, - uncertainties, - levers, - outcomes, - nr_experiments, - reporting_interval=100, - reporting_frequency=10, + uncertainties: list[Parameter], + levers: list[Parameter], + outcomes: list[Outcome], + nr_experiments: int, + reporting_interval: int = 100, + reporting_frequency: int = 10, ): """Init.""" super().__init__( diff --git a/ema_workbench/em_framework/evaluators.py b/ema_workbench/em_framework/evaluators.py index add37f6ad..fb453f5b7 100644 --- a/ema_workbench/em_framework/evaluators.py +++ b/ema_workbench/em_framework/evaluators.py @@ -33,7 +33,7 @@ LHSSampler, MonteCarloSampler, ) -from .util import determine_objects +from .util import StdlibSeedLike, determine_objects # Created on 5 Mar 2017 # @@ -49,8 +49,6 @@ _logger = get_module_logger(__name__) -SeedLike = int | float | str | bytes | bytearray # seedlike for stdlib random.seed - class Samplers(enum.Enum): """Enum for different kinds of samplers.""" @@ -66,7 +64,7 @@ class Samplers(enum.Enum): MORRIS = MorrisSampler() -SamplerTypes = Literal[ +type SamplerTypes = Literal[ Samplers.MC, Samplers.LHS, Samplers.FF, @@ -209,6 +207,44 @@ def perform_experiments( **kwargs, ) + @overload + def optimize( + self, + algorithm: type[AbstractGeneticAlgorithm] = EpsNSGAII, + nfe: int = 10000, + searchover: str = "levers", + reference: Sample | None = None, + constraints: Iterable[Constraint] | None = None, + convergence_freq: int = 1000, + logging_freq: int = 5, + variator: type[Variator] | None = None, + rng: Iterable[StdlibSeedLike] | None = None, + initial_population: Iterable[Sample] | None = None, + filename: str | None = None, + directory: str | None = None, + **kwargs, + ) -> list[tuple[pd.DataFrame, pd.DataFrame]]: + ... + + @overload + def optimize( + self, + algorithm: type[AbstractGeneticAlgorithm] = EpsNSGAII, + nfe: int = 10000, + searchover: str = "levers", + reference: Sample | None = None, + constraints: Iterable[Constraint] | None = None, + convergence_freq: int = 1000, + logging_freq: int = 5, + variator: type[Variator] | None = None, + rng: StdlibSeedLike | None = None, + initial_population: Iterable[Sample] | None = None, + filename: str | None = None, + directory: str | None = None, + **kwargs, + ) -> tuple[pd.DataFrame, pd.DataFrame]: + ... + def optimize( self, algorithm: type[AbstractGeneticAlgorithm] = EpsNSGAII, @@ -219,12 +255,12 @@ def optimize( convergence_freq: int = 1000, logging_freq: int = 5, variator: type[Variator] | None = None, - rng: SeedLike | Iterable[SeedLike] | None = None, + rng: StdlibSeedLike | Iterable[StdlibSeedLike] | None = None, initial_population: Iterable[Sample] | None = None, filename: str | None = None, directory: str | None = None, **kwargs, - ) -> tuple[pd.DataFrame, pd.DataFrame]: + ): """Convenience method for outcome optimization. A call to this method is forwarded to :func:optimize, with evaluator and models @@ -263,7 +299,7 @@ def robust_optimize( nfe: int = 10000, convergence_freq: int = 1000, logging_freq: int = 5, - rng: SeedLike | Iterable[SeedLike] | None = None, + rng: StdlibSeedLike | Iterable[StdlibSeedLike] | None = None, **kwargs, ) -> tuple[pd.DataFrame, pd.DataFrame]: """Convenience method for robust optimization. @@ -578,7 +614,7 @@ def optimize( convergence_freq: int = 1000, logging_freq: int = 5, variator: Variator = None, - rng: Iterable[SeedLike] | None = None, + rng: Iterable[StdlibSeedLike] | None = None, initial_population: Iterable[Sample] | None = None, filename: str | None = None, directory: str | None = None, @@ -598,7 +634,7 @@ def optimize( convergence_freq: int = 1000, logging_freq: int = 5, variator: Variator = None, - rng: SeedLike | None = None, + rng: StdlibSeedLike | None = None, initial_population: Iterable[Sample] | None = None, filename: str | None = None, directory: str | None = None, @@ -617,7 +653,7 @@ def optimize( convergence_freq: int = 1000, logging_freq: int = 5, variator: Variator = None, - rng: SeedLike | Iterable[SeedLike] | None = None, + rng: StdlibSeedLike | Iterable[StdlibSeedLike] | None = None, initial_population: Iterable[Sample] | None = None, filename: str | None = None, directory: str | None = None, @@ -731,7 +767,7 @@ def robust_optimize( convergence_freq: int = 1000, logging_freq: int = 5, variator: Variator = None, - rng: SeedLike | None = None, + rng: StdlibSeedLike | None = None, initial_population: Iterable[Sample] | None = None, filename: str | None = None, directory: str | None = None, @@ -751,7 +787,7 @@ def robust_optimize( convergence_freq: int = 1000, logging_freq: int = 5, variator: Variator = None, - rng: Iterable[SeedLike] | None = None, + rng: Iterable[StdlibSeedLike] | None = None, initial_population: Iterable[Sample] | None = None, filename: str | None = None, directory: str | None = None, @@ -770,7 +806,7 @@ def robust_optimize( convergence_freq: int = 1000, logging_freq: int = 5, variator: Variator = None, - rng: SeedLike | Iterable[SeedLike] | None = None, + rng: StdlibSeedLike | Iterable[StdlibSeedLike] | None = None, initial_population: Iterable[Sample] | None = None, filename: str | None = None, directory: str | None = None, diff --git a/ema_workbench/em_framework/futures_util.py b/ema_workbench/em_framework/futures_util.py index 41d8fbfb8..fe75d3777 100644 --- a/ema_workbench/em_framework/futures_util.py +++ b/ema_workbench/em_framework/futures_util.py @@ -1,20 +1,27 @@ """Utilities for futures modules.""" +from __future__ import annotations + import os import random import shutil import string import time from collections import defaultdict +from collections.abc import Collection +from typing import TYPE_CHECKING from ..util import get_module_logger +if TYPE_CHECKING: + from .model import AbstractModel + __all__ = ["determine_rootdir", "finalizer", "setup_working_directories"] _logger = get_module_logger(__name__) -def determine_rootdir(msis): +def determine_rootdir(msis: Collection[AbstractModel]) -> str | None: """Determine common root directory for all models.""" for model in msis: try: @@ -32,10 +39,10 @@ def determine_rootdir(msis): return root_dir -def finalizer(experiment_runner): +def finalizer(experiment_runner: AbstractModel) -> callable: """Cleanup.""" - def finalizer(tmpdir): + def finalizer(tmpdir: str | None) -> None: _logger.info("finalizing") experiment_runner.cleanup() @@ -52,7 +59,7 @@ def finalizer(tmpdir): return finalizer -def setup_working_directories(models, root_dir): +def setup_working_directories(models: Collection[AbstractModel], root_dir: str) -> str | None: """Setup working directories when running in parallel. Copies the working directory of each model to a process specific diff --git a/ema_workbench/em_framework/model.py b/ema_workbench/em_framework/model.py index bdf252b54..1f057c21b 100644 --- a/ema_workbench/em_framework/model.py +++ b/ema_workbench/em_framework/model.py @@ -64,7 +64,7 @@ class AbstractModel(NamedObject): """ @property - def outcomes_output(self): + def outcomes_output(self) -> dict: """Getter for outcomes output.""" return self._outcomes_output @@ -76,7 +76,7 @@ def outcomes_output(self, outputs): self._outcomes_output[outcome.name] = outcome.process(data) @property - def output_variables(self): + def output_variables(self) -> list[str]: """Getter for output variables.""" if self._output_variables is None: self._output_variables = [ @@ -177,7 +177,7 @@ def _transform(self, sampled_parameters: Sample, parameters: list[Variable]): sampled_parameters.data = temp @method_logger(__name__) - def run_model(self, scenario: Sample, policy: Sample, constants: Sample): + def run_model(self, scenario: Sample, policy: Sample, constants: Sample) -> None: """Run the model for the specified scenario, policy, and constants. Parameters @@ -195,7 +195,7 @@ def run_model(self, scenario: Sample, policy: Sample, constants: Sample): self._transform(constants, self.constants) @method_logger(__name__) - def initialized(self, policy: Sample): + def initialized(self, policy: Sample) -> bool: """Check if model has been initialized. Parameters @@ -297,7 +297,7 @@ def replications(self, replications: int | list[dict]): ) @method_logger(__name__) - def run_model(self, scenario: Sample, policy: Sample, constants: Sample): + def run_model(self, scenario: Sample, policy: Sample, constants: Sample) -> None: """Run the model for the specified scenario, policy, and constants. Parameters @@ -333,7 +333,7 @@ class SingleReplication(AbstractModel): """Base class for models that require only a single replication.""" @method_logger(__name__) - def run_model(self, scenario: Sample, policy: Sample, constants: Sample): + def run_model(self, scenario: Sample, policy: Sample, constants: Sample) -> None: """Run the model for the specified scenario, policy, and constants. Parameters diff --git a/ema_workbench/em_framework/optimization.py b/ema_workbench/em_framework/optimization.py index 05e173313..f14b24bf3 100644 --- a/ema_workbench/em_framework/optimization.py +++ b/ema_workbench/em_framework/optimization.py @@ -48,7 +48,7 @@ RealParameter, ) from .points import Sample, SampleCollection -from .util import ProgressTrackingMixIn +from .util import ProgressTrackingMixIn, StdlibSeedLike # Created on 5 Jun 2017 # @@ -64,9 +64,6 @@ _logger = get_module_logger(__name__) -SeedLike = int | float | str | bytes | bytearray # seedlike for stdlib random.seed - - class Problem(PlatypusProblem): """Small extension to Platypus problem object. @@ -165,7 +162,7 @@ def to_platypus_types(decision_variables: Iterable[Parameter]) -> list[platypus. def to_dataframe( solutions: Iterable[platypus.Solution], dvnames: list[str], outcome_names: list[str] -): +) -> pd.DataFrame: """Helper function to turn a collection of platypus Solution instances into a pandas DataFrame. Parameters @@ -193,7 +190,7 @@ def to_dataframe( return results -def process_jobs(jobs: list[platypus.core.EvaluateSolution]): +def process_jobs(jobs: list[platypus.core.EvaluateSolution]) -> tuple[Sample | SampleCollection | int, list[Sample] | SampleCollection]: """Helper function to map jobs generated by platypus to Sample instances. Parameters @@ -233,7 +230,7 @@ def evaluate( experiments: pd.DataFrame, outcomes: dict[str, np.ndarray], problem: Problem, -): +) -> None: """Helper function for mapping the results from perform_experiments back to what platypus needs.""" searchover = problem.searchover outcome_names = problem.outcome_names @@ -479,7 +476,7 @@ def epsilon_nondominated( return to_dataframe(archive, problem.parameter_names, problem.outcome_names) -def rebuild_platypus_population(archive: pd.DataFrame, problem: Problem): +def rebuild_platypus_population(archive: pd.DataFrame, problem: Problem) -> list[Solution]: """Rebuild a population of platypus Solution instances. Parameters @@ -689,9 +686,9 @@ def _optimize( initial_population: Iterable[Sample] | None = None, filename: str | None = None, directory: str | None = None, - rng: None | SeedLike = None, + rng: None | StdlibSeedLike = None, **kwargs, -): +) -> tuple[pd.DataFrame, pd.DataFrame]: """Helper function for optimization.""" klass = problem.types[0].__class__ diff --git a/ema_workbench/em_framework/outcomes.py b/ema_workbench/em_framework/outcomes.py index e6f273185..b3a205785 100644 --- a/ema_workbench/em_framework/outcomes.py +++ b/ema_workbench/em_framework/outcomes.py @@ -3,9 +3,11 @@ import abc import collections import numbers +import tarfile import warnings -from collections.abc import Callable +from collections.abc import Callable, Sequence from io import BytesIO +from typing import Any import numpy as np import pandas as pd @@ -56,7 +58,7 @@ def __call__(self, outcome): else: pass # multiple instances of the same class and name is fine - def serialize(self, name, values): + def serialize(self, name: str, values: np.ndarray | pd.DataFrame) -> tuple[BytesIO, str]: """Serialize the given outcome. Parameters @@ -77,7 +79,7 @@ def serialize(self, name, values): return stream, f"{name}.{extension}" - def deserialize(self, name, filename, archive): + def deserialize(self, name: str, filename: str, archive: tarfile.TarFile) -> np.ndarray: """Serialize the given outcome.""" return self.outcomes[name].from_disk(filename, archive) @@ -127,13 +129,13 @@ class Outcome(Variable, metaclass=abc.ABCMeta): def __init__( self, - name, - kind=INFO, - variable_name=None, - function=None, - shape=None, - dtype=None, - ): + name: str, + kind: int = INFO, + variable_name: str | Sequence[str] | None = None, + function: Callable[..., Any] | None = None, + shape: tuple[int, ...] | None = None, + dtype: np.dtype | type | None = None, + ) -> None: """Init.""" super().__init__(name) @@ -165,7 +167,7 @@ def __init__( self.shape = shape self.dtype = dtype - def process(self, values): + def process(self, values: list[Any]) -> Any: """Process the values.""" if self.function: var_names = self.variable_name @@ -226,7 +228,7 @@ def __hash__(self): # noqa: D105 @classmethod @abc.abstractmethod - def to_disk(cls, values): + def to_disk(cls, values: np.ndarray | pd.DataFrame) -> tuple[BytesIO, str]: """Helper function for writing outcome to disk. Parameters @@ -243,7 +245,7 @@ def to_disk(cls, values): @classmethod @abc.abstractmethod - def from_disk(cls, filename, archive): + def from_disk(cls, filename: str, archive: tarfile.TarFile) -> np.ndarray: """Helper function for loading from disk. Parameters @@ -294,12 +296,12 @@ class ScalarOutcome(Outcome): def __init__( self, - name, - kind=Outcome.INFO, - variable_name=None, - function=None, - dtype=None, - ): + name: str, + kind: int = Outcome.INFO, + variable_name: str | Sequence[str] | None = None, + function: Callable[..., Any] | None = None, + dtype: np.dtype | type | None = None, + ) -> None: """Init.""" shape = None if dtype is not None: diff --git a/ema_workbench/em_framework/outputspace_exploration.py b/ema_workbench/em_framework/outputspace_exploration.py index dc074b46f..6ae64eed2 100644 --- a/ema_workbench/em_framework/outputspace_exploration.py +++ b/ema_workbench/em_framework/outputspace_exploration.py @@ -15,6 +15,8 @@ """ +from __future__ import annotations + import functools import math @@ -34,6 +36,7 @@ Multimethod, PlatypusConfig, RandomGenerator, + Solution, TournamentSelector, ) @@ -54,11 +57,11 @@ class Novelty(Dominance): """ - def __init__(self, algorithm): + def __init__(self, algorithm: OutputSpaceExplorationAlgorithm) -> None: super().__init__() self.algorithm = algorithm - def compare(self, winner, candidate): + def compare(self, winner: Solution, candidate: Solution) -> int: """Compare two solutions. Returns -1 if the first solution dominates the second, 1 if the @@ -101,7 +104,7 @@ class HitBox(Archive): """ - def __init__(self, grid_spec): + def __init__(self, grid_spec: list[tuple[float, float, float]]) -> None: """Init.""" super().__init__(None) self.archive = {} @@ -111,7 +114,7 @@ def __init__(self, grid_spec): self.improvements = 0 self.overall_novelty = 0 - def add(self, solution): + def add(self, solution: Solution) -> bool: """Add a solution to the archive.""" key = get_index_for_solution(solution, self.grid_spec) @@ -146,7 +149,7 @@ def add(self, solution): return True - def get_novelty_score(self, solution): + def get_novelty_score(self, solution: Solution) -> float: """Return the novelty score of the solution.""" key = get_index_for_solution(solution, self.grid_spec) return 1 / self.grid_counter[key] @@ -212,7 +215,7 @@ def __init__( self.comparator = Novelty(self) self.add_extension(AdaptiveTimeContinuationExtension()) - def step(self): + def step(self) -> None: """A single step of the algorithm.""" if self.nfe == 0: self.initialize() @@ -221,7 +224,7 @@ def step(self): self.result = self.archive - def initialize(self): + def initialize(self) -> None: """Initialize the algorithm.""" super().initialize() @@ -231,7 +234,7 @@ def initialize(self): if self.variator is None: self.variator = PlatypusConfig.default_variator(self.problem) - def iterate(self): + def iterate(self) -> None: """A signle iteration of the algorithm.""" offspring = [] @@ -249,7 +252,7 @@ def iterate(self): self.population = offspring[: self.population_size] -def get_index_for_solution(solution, grid_spec): +def get_index_for_solution(solution: Solution, grid_spec: list[tuple[float, float, float]]) -> tuple[int, ...]: """Maps the objectives to the key for the grid cell into which this solution falls. Parameters @@ -271,7 +274,7 @@ def get_index_for_solution(solution, grid_spec): return key -def get_bin_index(value, minumum_value, epsilon): +def get_bin_index(value: float, minumum_value: float, epsilon: float) -> int: """Maps the value for a single objective to the index of the grid cell along that dimension. Parameters diff --git a/ema_workbench/em_framework/parameters.py b/ema_workbench/em_framework/parameters.py index d629b53ac..3528a0f23 100644 --- a/ema_workbench/em_framework/parameters.py +++ b/ema_workbench/em_framework/parameters.py @@ -98,7 +98,7 @@ def __init__(self, name: str, value: Any): self.value = value -def create_category(cat): +def create_category(cat: Any) -> Category: """Helper function for creating a Category object.""" if isinstance(cat, Category): return cat @@ -133,7 +133,7 @@ class Parameter(Variable, metaclass=abc.ABCMeta): default = None @property - def resolution(self): + def resolution(self) -> list | None: """Getter for resolution.""" return self._resolution @@ -168,7 +168,7 @@ def __init__( self.uniform = True @classmethod - def from_dist(cls, name: str, dist, **kwargs): + def from_dist(cls, name: str, dist: sp.stats.rv_continuous | sp.stats.rv_discrete, **kwargs: Any) -> "Parameter": """Factory method for creating a Parameter from a scipy distribution. Alternative constructor for creating a parameter from a frozen @@ -444,7 +444,7 @@ def __init__( self.resolution = list(range(len(self.categories))) self.multivalue = multivalue - def index_for_cat(self, category): + def index_for_cat(self, category: str) -> int: """Return index of category. Parameters @@ -462,7 +462,7 @@ def index_for_cat(self, category): return i raise ValueError(f"Category {category} not found") - def cat_for_index(self, index: int): + def cat_for_index(self, index: int) -> Category: """Return category given index. Parameters @@ -699,7 +699,7 @@ def latent_parameters(self) -> list[Parameter]: parameters.append(latent_parameter) return parameters - def copy(self): + def copy(self) -> "ParameterMap": copy = self.__class__() copy._data = self._data.copy() diff --git a/ema_workbench/em_framework/points.py b/ema_workbench/em_framework/points.py index c1d042fd0..b53d1c4e2 100644 --- a/ema_workbench/em_framework/points.py +++ b/ema_workbench/em_framework/points.py @@ -7,7 +7,7 @@ import itertools import math -from collections.abc import Generator, Iterable, Sequence +from collections.abc import Generator, Iterable from typing import TYPE_CHECKING, Literal, overload import numpy as np @@ -21,7 +21,7 @@ Parameter, ParameterMap, ) -from .util import Counter, NamedDict, NamedObject, combine +from .util import Counter, NamedDict, NamedObject, NumpySeedLike, RNGLike, combine if TYPE_CHECKING: from .optimization import Problem @@ -37,9 +37,6 @@ ] _logger = get_module_logger(__name__) -SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence -RNGLike = np.random.Generator | np.random.BitGenerator - class Sample(NamedDict): """A point in parameter space.""" @@ -233,7 +230,7 @@ def combine( self, other: SampleCollection, how: Literal["full_factorial", "sample", "cycle"], - rng: SeedLike | RNGLike | None = None, + rng: NumpySeedLike | RNGLike | None = None, ) -> SampleCollection: """Combine two SampleCollections into a new SampleCollection. @@ -384,7 +381,7 @@ def experiment_generator( scenarios: Iterable[Sample], policies: Iterable[Sample], combine: Literal["full_factorial", "sample", "cycle"] = "full_factorial", - rng: SeedLike | RNGLike | None = None, + rng: NumpySeedLike | RNGLike | None = None, ) -> Generator[Experiment]: """Generator function which yields experiments. diff --git a/ema_workbench/em_framework/samplers.py b/ema_workbench/em_framework/samplers.py index 68f3968fd..a3c5bfae3 100644 --- a/ema_workbench/em_framework/samplers.py +++ b/ema_workbench/em_framework/samplers.py @@ -8,7 +8,7 @@ import abc import itertools -from collections.abc import Iterable, Sequence +from collections.abc import Iterable import numpy as np import scipy.stats as stats @@ -20,6 +20,7 @@ ParameterMap, ) from ema_workbench.em_framework.points import SampleCollection +from ema_workbench.em_framework.util import NumpySeedLike, RNGLike # Created on 16 aug. 2011 # @@ -33,10 +34,6 @@ ] -SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence -RNGLike = np.random.Generator | np.random.BitGenerator - - class AbstractSampler(metaclass=abc.ABCMeta): """Abstract base class from which different samplers can be derived. @@ -51,7 +48,7 @@ def generate_samples( self, parameters: ParameterMap | Iterable[Parameter], size: int, - rng: SeedLike | RNGLike | None = None, + rng: NumpySeedLike | RNGLike | None = None, **kwargs, ) -> "SampleCollection": """Generate n samples from the parameters. @@ -105,7 +102,7 @@ def generate_samples( self, parameters: ParameterMap, size: int, - rng: SeedLike | RNGLike | None = None, + rng: NumpySeedLike | RNGLike | None = None, **kwargs, ) -> "SampleCollection": """Generate samples using latin hypercube sampling. @@ -148,7 +145,7 @@ def generate_samples( self, parameters: ParameterMap | Iterable[Parameter], size: int, - rng: SeedLike | RNGLike | None = None, + rng: NumpySeedLike | RNGLike | None = None, **kwargs, ) -> "SampleCollection": """Generate samples using Monte Carlo sampling. @@ -189,7 +186,7 @@ def generate_samples( self, parameters: ParameterMap | Iterable[Parameter], size: int, - rng: SeedLike | RNGLike | None = None, + rng: NumpySeedLike | RNGLike | None = None, **kwargs, ) -> "SampleCollection": """Generate samples using full factorial sampling. diff --git a/ema_workbench/em_framework/util.py b/ema_workbench/em_framework/util.py index 951e319f8..470583d8f 100644 --- a/ema_workbench/em_framework/util.py +++ b/ema_workbench/em_framework/util.py @@ -6,29 +6,41 @@ import copy import itertools import warnings -from collections.abc import Iterable, Iterator, KeysView, Mapping, MutableMapping -from typing import Generic, Literal, TypeVar - +from collections.abc import ( + Iterable, + Iterator, + KeysView, + Mapping, + MutableMapping, + Sequence, +) +from typing import Literal + +import numpy as np import tqdm from ..util import EMAError +type NumpySeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence +type RNGLike = np.random.Generator | np.random.BitGenerator +type StdlibSeedLike = int | float | str | bytes | bytearray + __all__ = [ "Counter", "NamedDict", "NamedObject", "NamedObjectMap", "NamedObjectMapDescriptor", + "NumpySeedLike", "ProgressTrackingMixIn", + "RNGLike", + "StdlibSeedLike", "Variable", "combine", "determine_objects", "representation", ] -T = TypeVar("T") - - class NamedObject: """Base object with a name attribute.""" @@ -90,7 +102,7 @@ def __init__(self, name: str, variable_name: str | list[str] | None = None): self.variable_name = variable_name -class NamedObjectMap(MutableMapping, Generic[T]): +class NamedObjectMap[T](MutableMapping): """A named object mapping class.""" def __init__(self, kind: type[T]) -> None: @@ -228,7 +240,7 @@ def copy(self): """Return a shallow copy of this object.""" return copy.copy(self) - def __getitem__(self, key) -> T: # noqa: D105 + def __getitem__(self, key): # noqa: D105 return self.data[key] def __setitem__(self, key, value): # noqa: D105 @@ -237,7 +249,7 @@ def __setitem__(self, key, value): # noqa: D105 def __delitem__(self, key): # noqa: D105 del self.data[key] - def __iter__(self) -> Iterator[T]: # noqa: D105 + def __iter__(self) -> Iterator: # noqa: D105 return iter(self.data) def __len__(self) -> int: # noqa: D105 diff --git a/ema_workbench/util/ema_logging.py b/ema_workbench/util/ema_logging.py index 5c164a45d..293f4f280 100644 --- a/ema_workbench/util/ema_logging.py +++ b/ema_workbench/util/ema_logging.py @@ -44,7 +44,7 @@ def create_module_logger(name: str | None = None) -> logging.Logger: return logger -def get_module_logger(name) -> logging.Logger: +def get_module_logger(name: str) -> logging.Logger: """Return a module logger with the given name.""" try: logger = _module_loggers[name] @@ -64,7 +64,7 @@ def get_module_logger(name) -> logging.Logger: class TemporaryFilter(logging.Filter): """Helper class to temporarily log messages.""" - def __init__(self, *args, level: int = 0, func_name=None, **kwargs): + def __init__(self, *args, level: int = 0, func_name: str | None = None, **kwargs): super().__init__(*args, **kwargs) self.level = level self.func_name = func_name @@ -141,7 +141,7 @@ def temporary_filter( v.removeFilter(k) -def method_logger(name): +def method_logger(name: str) -> callable: """Wrap method so that every call to it is logged.""" logger = get_module_logger(name) classname = inspect.getouterframes(inspect.currentframe())[1][3] @@ -180,7 +180,7 @@ def get_rootlogger() -> logging.Logger: return _rootlogger -def log_to_stderr(level=None, pass_root_logger_level=False): +def log_to_stderr(level: int | None = None, pass_root_logger_level: bool = False) -> logging.Logger: """Turn on logging and add a handler which prints to stderr. Parameters diff --git a/ema_workbench/util/utilities.py b/ema_workbench/util/utilities.py index 9eb04a179..90cbdb3d4 100644 --- a/ema_workbench/util/utilities.py +++ b/ema_workbench/util/utilities.py @@ -4,6 +4,7 @@ import json import os import tarfile +from collections.abc import Callable from io import BytesIO import numpy as np @@ -19,7 +20,7 @@ _logger = get_module_logger(__name__) -def load_results(file_name): +def load_results(file_name: str) -> tuple[pd.DataFrame, dict[str, np.ndarray]]: """Load the specified tar.gz file. the file is assumed to be saves using save_results. @@ -170,7 +171,7 @@ def load_results_old(archive): return experiments, outcomes_new -def save_results(results, file_name): +def save_results(results: tuple[pd.DataFrame, dict[str, np.ndarray]], file_name: str) -> None: """Save the results to the specified tar.gz file. The way in which results are stored depends. Experiments are saved @@ -231,7 +232,10 @@ def add_file(tararchive, stream, filename): _logger.info(f"results saved successfully to {file_name}") -def merge_results(results1, results2): +def merge_results( + results1: tuple[pd.DataFrame, dict[str, np.ndarray]], + results2: tuple[pd.DataFrame, dict[str, np.ndarray]], +) -> tuple[pd.DataFrame, dict[str, np.ndarray]]: """Convenience function for merging results from the workbench. The function merges results2 with results1. For the experiments, @@ -285,7 +289,7 @@ def merge_results(results1, results2): return mr -def get_ema_project_home_dir(): +def get_ema_project_home_dir() -> str: try: config_file_name = "expworkbench.cfg" directory = os.path.dirname(__file__) @@ -305,7 +309,10 @@ def get_ema_project_home_dir(): return os.getcwd() -def process_replications(data, aggregation_func=np.mean): +def process_replications( + data: dict[str, np.ndarray] | tuple[pd.DataFrame, dict[str, np.ndarray]], + aggregation_func: Callable[..., np.ndarray] = np.mean, +) -> dict[str, np.ndarray] | tuple[pd.DataFrame, dict[str, np.ndarray]]: """Convenience function for processing the replications of a stochastic model outcomes. The default behavior is to take the mean of the replications. This reduces diff --git a/pyproject.toml b/pyproject.toml index 108d4a7b2..8d2decdc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ exclude = [ "build", ] -target-version = "py311" +target-version = "py312" [tool.ruff.lint] select = [