Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions ema_workbench/analysis/cart.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

"""

from __future__ import annotations

import contextlib
import io
import math
Expand All @@ -13,6 +15,7 @@
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
from sklearn import tree

Expand Down Expand Up @@ -71,7 +74,13 @@ class CART(sdutil.OutputFormatterMixin):

sep = "!?!"

def __init__(self, x, y, mass_min=0.05, mode=sdutil.RuleInductionType.BINARY):
def __init__(
self,
x: pd.DataFrame,
y: npt.NDArray,
mass_min: float = 0.05,
mode: sdutil.RuleInductionType = sdutil.RuleInductionType.BINARY,
) -> None:
"""Init."""
with contextlib.suppress(KeyError):
x = x.drop(["scenario"], axis=1)
Expand All @@ -97,7 +106,7 @@ def __init__(self, x, y, mass_min=0.05, mode=sdutil.RuleInductionType.BINARY):
self._stats = None

@property
def boxes(self):
def boxes(self) -> list[pd.DataFrame]:
"""Return a list with the box limits for each terminal leaf.

Returns
Expand Down Expand Up @@ -176,7 +185,7 @@ def recurse(left, right, child, lineage=None):
return self._boxes

@property
def stats(self):
def stats(self) -> list[dict]:
"""Returns list with the scenario discovery statistics for each terminal leaf.

Returns
Expand All @@ -197,7 +206,7 @@ def stats(self):
self._stats.append(boxstats)
return self._stats

def _binary_stats(self, box, box_init):
def _binary_stats(self, box: pd.DataFrame, box_init: pd.DataFrame) -> dict:
indices = sdutil._in_box(self.x, box)

y_in_box = self.y[indices]
Expand All @@ -211,7 +220,7 @@ def _binary_stats(self, box, box_init):
}
return boxstats

def _regression_stats(self, box, box_init):
def _regression_stats(self, box: pd.DataFrame, box_init: pd.DataFrame) -> dict:
indices = sdutil._in_box(self.x, box)

y_in_box = self.y[indices]
Expand All @@ -223,7 +232,7 @@ def _regression_stats(self, box, box_init):
}
return boxstats

def _classification_stats(self, box, box_init):
def _classification_stats(self, box: pd.DataFrame, box_init: pd.DataFrame) -> dict:
indices = sdutil._in_box(self.x, box)

y_in_box = self.y[indices]
Expand Down Expand Up @@ -252,7 +261,7 @@ def _classification_stats(self, box, box_init):
sdutil.RuleInductionType.CLASSIFICATION: _classification_stats,
}

def build_tree(self):
def build_tree(self) -> None:
"""Train CART on the data."""
min_samples = int(self.mass_min * self.x.shape[0])

Expand All @@ -262,7 +271,7 @@ def build_tree(self):
self.clf = tree.DecisionTreeClassifier(min_samples_leaf=min_samples)
self.clf.fit(self._x, self.y)

def show_tree(self, mplfig=True, format="png"):
def show_tree(self, mplfig: bool = True, format: str = "png"):
"""Return a png (defaults) or svg of the tree.

On Windows, graphviz needs to be installed with conda.
Expand Down
58 changes: 31 additions & 27 deletions ema_workbench/analysis/pairs_plotting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""R-style pairs plotting functionality."""

from __future__ import annotations

import matplotlib.axes
import matplotlib.cm as cm
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from ..util import get_module_logger

Expand All @@ -17,15 +21,15 @@


def pairs_lines(
experiments,
outcomes,
outcomes_to_show=None,
group_by=None,
experiments: pd.DataFrame,
outcomes: dict[str, np.ndarray],
outcomes_to_show: list[str] | None = None,
group_by: str | None = None,
grouping_specifiers=None,
ylabels=None,
legend=True,
ylabels: dict[str, str] | None = None,
legend: bool = True,
**kwargs,
):
) -> tuple[plt.Figure, dict[str, matplotlib.axes.Axes]]:
"""Generate a pairs lines multiplot.

It shows the behavior of two outcomes over time against
Expand Down Expand Up @@ -147,18 +151,18 @@ def simple_pairs_lines(ax, y_data, x_data, color):


def pairs_density(
experiments,
outcomes,
outcomes_to_show=None,
group_by=None,
experiments: pd.DataFrame,
outcomes: dict[str, np.ndarray],
outcomes_to_show: list[str] | None = None,
group_by: str | None = None,
grouping_specifiers=None,
ylabels=None,
point_in_time=-1,
log=True,
gridsize=50,
colormap="coolwarm",
filter_scalar=True,
):
ylabels: dict[str, str] | None = None,
point_in_time: int = -1,
log: bool = True,
gridsize: int = 50,
colormap: str = "coolwarm",
filter_scalar: bool = True,
) -> tuple[plt.Figure, dict[str, matplotlib.axes.Axes]]:
"""Generate a pairs hexbin density multiplot.

In case of time-series data, the end states are used.
Expand Down Expand Up @@ -392,17 +396,17 @@ def simple_pairs_density(


def pairs_scatter(
experiments,
outcomes,
outcomes_to_show=None,
group_by=None,
experiments: pd.DataFrame,
outcomes: dict[str, np.ndarray],
outcomes_to_show: list[str] | None = None,
group_by: str | None = None,
grouping_specifiers=None,
ylabels=None,
legend=True,
point_in_time=-1,
filter_scalar=False,
ylabels: dict[str, str] | None = None,
legend: bool = True,
point_in_time: int = -1,
filter_scalar: bool = False,
**kwargs,
):
) -> tuple[plt.Figure, dict[str, matplotlib.axes.Axes]]:
"""Generate a pairs scatter multiplot.

In case of time-series data, the end states are used.
Expand Down
14 changes: 8 additions & 6 deletions ema_workbench/analysis/parcoords.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""A general purpose matplotlib-based parallel coordinate plotting Class."""

from __future__ import annotations

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd
Expand All @@ -14,7 +16,7 @@
__all__ = ["ParallelAxes", "get_limits"]


def setup_parallel_plot(labels, minima, maxima, formatter=None, fs=14, rot=90):
def setup_parallel_plot(labels: list[str], minima: pd.Series, maxima: pd.Series, formatter: dict[str, str] | None = None, fs: int = 14, rot: float = 90) -> tuple[plt.Figure, list[plt.Axes], dict]:
"""Helper function for setting up the parallel axes plot.

Parameters
Expand Down Expand Up @@ -100,7 +102,7 @@ def setup_parallel_plot(labels, minima, maxima, formatter=None, fs=14, rot=90):
return fig, axes, tick_labels


def get_limits(data):
def get_limits(data: pd.DataFrame) -> pd.DataFrame:
"""Helper function to get limits of a FataFrame that can serve as input to ParallelAxis.

Parameters
Expand Down Expand Up @@ -165,7 +167,7 @@ class ParallelAxes:

"""

def __init__(self, limits, formatter=None, fontsize=14, rot=90):
def __init__(self, limits: pd.DataFrame, formatter: dict[str, str] | None = None, fontsize: int = 14, rot: float = 90):
"""Init.

Parameters
Expand Down Expand Up @@ -216,7 +218,7 @@ def __init__(self, limits, formatter=None, fontsize=14, rot=90):
plt.tight_layout(h_pad=0, w_pad=0)
plt.subplots_adjust(wspace=0)

def plot(self, data, color=None, label=None, **kwargs):
def plot(self, data: pd.DataFrame | pd.Series, color=None, label: str | None = None, **kwargs) -> None:
"""Plot data on parallel axes.

Parameters
Expand Down Expand Up @@ -257,7 +259,7 @@ def plot(self, data, color=None, label=None, **kwargs):
# plot the data
self._plot(normalized_data, color=color, **kwargs)

def legend(self):
def legend(self) -> None:
"""Add a legend to the figure."""
artists = []
labels = []
Expand Down Expand Up @@ -299,7 +301,7 @@ def _plot(self, data, **kwargs):
if label_j in self.flipped_axes:
self._update_plot_data(ax, 1, lines=lines)

def invert_axis(self, axis):
def invert_axis(self, axis: str | list[str]) -> None:
"""Flip direction for specified axis.

Parameters
Expand Down
85 changes: 43 additions & 42 deletions ema_workbench/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import ConnectionPatch

from ..util import EMAError, get_module_logger
Expand Down Expand Up @@ -36,18 +37,18 @@


def envelopes(
experiments,
outcomes,
outcomes_to_show=None,
group_by=None,
experiments: pd.DataFrame,
outcomes: dict[str, np.ndarray],
outcomes_to_show: str | list[str] | None = None,
group_by: str | None = None,
grouping_specifiers=None,
density=None,
fill=False,
legend=True,
titles=None,
ylabels=None,
log=False,
):
density: Density | None = None,
fill: bool = False,
legend: bool = True,
titles: dict[str, str] | None = None,
ylabels: dict[str, str] | None = None,
log: bool = False,
) -> tuple[plt.Figure, dict[str, plt.Axes]]:
"""Make envelop plots.

An envelope shows over time the minimum and maximum value for a set
Expand Down Expand Up @@ -260,19 +261,19 @@ def single_envelope(outcomes, outcome_to_plot, time, density, ax, ax_d, fill, lo


def lines(
experiments,
outcomes,
outcomes_to_show=None,
group_by=None,
experiments: pd.DataFrame,
outcomes: dict[str, np.ndarray],
outcomes_to_show: str | list[str] | None = None,
group_by: str | None = None,
grouping_specifiers=None,
density="",
legend=True,
titles=None,
ylabels=None,
experiments_to_show=None,
show_envelope=False,
log=False,
):
density: Density | str = "",
legend: bool = True,
titles: dict[str, str] | None = None,
ylabels: dict[str, str] | None = None,
experiments_to_show: np.ndarray | None = None,
show_envelope: bool = False,
log: bool = False,
) -> tuple[plt.Figure, dict[str, plt.Axes]]:
"""Visualize results from experiments as line plots.

It is thus to be used in case of time
Expand Down Expand Up @@ -601,13 +602,13 @@ def simple_lines(outcomes, outcome_to_plot, time, density, ax, ax_d, log):


def kde_over_time(
experiments,
outcomes,
outcomes_to_show=None,
group_by=None,
experiments: pd.DataFrame,
outcomes: dict[str, np.ndarray],
outcomes_to_show: str | list[str] | None = None,
group_by: str | None = None,
grouping_specifiers=None,
colormap="viridis",
log=True,
colormap: str = "viridis",
log: bool = True,
):
"""Plot a KDE over time. The KDE is visualized through a heatmap.

Expand Down Expand Up @@ -679,21 +680,21 @@ def kde_over_time(


def multiple_densities(
experiments,
outcomes,
points_in_time=None,
outcomes_to_show=None,
group_by=None,
experiments: pd.DataFrame,
outcomes: dict[str, np.ndarray],
points_in_time: list[float] | None = None,
outcomes_to_show: str | list[str] | None = None,
group_by: str | None = None,
grouping_specifiers=None,
density=Density.KDE,
legend=True,
titles=None,
ylabels=None,
experiments_to_show=None,
plot_type=PlotType.ENVELOPE,
log=False,
density: Density = Density.KDE,
legend: bool = True,
titles: dict[str, str] | None = None,
ylabels: dict[str, str] | None = None,
experiments_to_show: np.ndarray | None = None,
plot_type: PlotType = PlotType.ENVELOPE,
log: bool = False,
**kwargs,
):
) -> tuple[list[plt.Figure], dict[str, dict[str, plt.Axes]]]:
"""Make an envelope plot with multiple density plots over the run time.

Parameters
Expand Down
Loading