diff --git a/scorecards.ipynb b/scorecards.ipynb new file mode 100644 index 0000000..ff2f050 --- /dev/null +++ b/scorecards.ipynb @@ -0,0 +1,724 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e7ab252b", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl\n", + "import seaborn as sns\n", + "\n", + "from climatebenchpress.compressor.plotting.plot_metrics import (\n", + " _rename_compressors,\n", + " _get_legend_name,\n", + " _COMPRESSOR_ORDER,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9d54a4d9-b587-4ad7-a2c6-1cf1003262bd", + "metadata": {}, + "outputs": [], + "source": [ + "RESULTS_FILE = Path(\"metrics\") / \"all_results.csv\"\n", + "OUTPUT_DIR = Path(\"scorecards\")\n", + "REF_COMPRESSOR = \"bitround\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b29dcce5-687d-41a6-926c-07a6634bd2f7", + "metadata": {}, + "outputs": [], + "source": [ + "METRICS = (\n", + " \"DSSIM\",\n", + " \"MAE\",\n", + " \"Max Absolute Error\",\n", + " \"Spectral Error\",\n", + " \"Compression Ratio [raw B / enc B]\",\n", + " \"Satisfies Bound (Value)\",\n", + ")\n", + "\n", + "METRICS2NAME = {\n", + " \"DSSIM\": \"dSSIM\",\n", + " \"MAE\": \"Mean Absolute Error\",\n", + " \"Compression Ratio [raw B / enc B]\": \"Compression Ratio\",\n", + " \"Satisfies Bound (Value)\": r\"% of Data Points Violating the Error Bound\",\n", + "}\n", + "\n", + "VARIABLE2NAME = {\n", + " \"10m_u_component_of_wind\": \"10u\",\n", + " \"10m_v_component_of_wind\": \"10v\",\n", + " \"mean_sea_level_pressure\": \"msl\",\n", + "}\n", + "\n", + "HIGHER_BETTER_METRICS = (\"DSSIM\", \"Compression Ratio [raw B / enc B]\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ecd4fa63-3b2d-4171-accc-00a93c5d51d0", + "metadata": {}, + "outputs": [], + "source": [ + "# DSSIM and Spectral Error are unreliable for variables with large NaN regions.\n", + "UNRELIABLE_NAN_METRICS = {\"DSSIM\", \"Spectral Error\"}\n", + "UNRELIABLE_NAN_VARIABLES = {\"ta\", \"tos\"}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ff76b593-a02c-4989-8abb-f36e2f8212d9", + "metadata": {}, + "outputs": [], + "source": [ + "# (compressor, variable) pairs where a relative error constraint had to be\n", + "# converted to an absolute error bound. These cells are drawn with a hatched\n", + "# overlay to flag the conversion.\n", + "CONVERTED_REL_TO_ABS_COMPRESSORS = {\n", + " \"ebcc-abs\",\n", + " \"jpeg2000\",\n", + " \"sperr\",\n", + " \"stochround\",\n", + " \"stochround-pco\",\n", + " \"sz3-abs\",\n", + " \"zfp\",\n", + " \"zfp-round\",\n", + "}\n", + "CONVERTED_REL_TO_ABS_VARIABLES = {\"agb\", \"no2\", \"pr\", \"q\"}\n", + "CONVERTED_ABS_TO_REL_COMPRESSORS = {\"bitround\", \"bitround-pco\"}\n", + "CONVERTED_ABS_TO_REL_VARIABLES = {\"10u\", \"10v\", \"msl\", \"rlut\", \"ta\", \"tos\"}\n", + "\n", + "CONVERTED_REL_TO_ABS_EDGECOLOR = \"#F6AE2D\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ec124c29", + "metadata": {}, + "outputs": [], + "source": [ + "def create_data_matrix(\n", + " df: pd.DataFrame,\n", + " error_bound: str,\n", + " metrics: tuple[str] = METRICS,\n", + "):\n", + " df_filtered = df[df[\"Error Bound Name\"] == error_bound].copy()\n", + " df_filtered[\"Satisfies Bound (Value)\"] = (\n", + " df_filtered[\"Satisfies Bound (Value)\"] * 100\n", + " ) # Convert to percentage\n", + "\n", + " # Get unique variables and compressors\n", + " variables = sorted(df_filtered[\"Variable\"].unique())\n", + " compressors = sorted(\n", + " df_filtered[\"Compressor\"].unique(),\n", + " key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)),\n", + " )\n", + "\n", + " column_labels = [f\"{v}\\n{m}\" for m in metrics for v in variables]\n", + " data_matrix = np.full((len(compressors), len(column_labels)), np.nan)\n", + "\n", + " # Fill the matrix with data\n", + " for i, compressor in enumerate(compressors):\n", + " for j, metric in enumerate(metrics):\n", + " for k, variable in enumerate(variables):\n", + " subset = df_filtered[\n", + " (df_filtered[\"Compressor\"] == compressor)\n", + " & (df_filtered[\"Variable\"] == variable)\n", + " ]\n", + " if subset.empty:\n", + " print(f\"No data for Compressor: {compressor}, Variable: {variable}\")\n", + " continue\n", + "\n", + " if (\n", + " metric in UNRELIABLE_NAN_METRICS\n", + " and variable in UNRELIABLE_NAN_VARIABLES\n", + " ):\n", + " continue\n", + "\n", + " col_idx = j * len(variables) + k\n", + " if metric in subset.columns:\n", + " values = subset[metric]\n", + " if len(values) == 1:\n", + " data_matrix[i, col_idx] = values.iloc[0]\n", + "\n", + " return data_matrix, compressors, variables" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "871ae766", + "metadata": {}, + "outputs": [], + "source": [ + "def create_compression_scorecard(\n", + " data_matrix: np.ndarray,\n", + " compressors: list[str],\n", + " variables: list[str],\n", + " metrics: list[str],\n", + " cbar: bool = True,\n", + " ref_compressor: str = REF_COMPRESSOR,\n", + " higher_better_metrics: tuple[str] = HIGHER_BETTER_METRICS,\n", + " save_fn: str | Path | None = None,\n", + "):\n", + " \"\"\"\n", + " Create a scorecard plot of relative metric differences vs a reference.\n", + "\n", + " Parameters:\n", + " - data_matrix: 2D array with compressors as rows, metric-variable combinations as columns\n", + " - compressors: list of compressor names\n", + " - variables: list of variable names\n", + " - metrics: list of metric names\n", + " - ref_compressor: reference compressor for relative calculations\n", + " - save_fn: filename to save plot (optional)\n", + " \"\"\"\n", + "\n", + " ref_idx = compressors.index(ref_compressor)\n", + " ref_values = data_matrix[ref_idx, :]\n", + "\n", + " relative_matrix = np.full_like(data_matrix, np.nan)\n", + " for i in range(len(compressors)):\n", + " for j in range(data_matrix.shape[1]):\n", + " if np.isnan(data_matrix[i, j]) or np.isnan(ref_values[j]):\n", + " continue\n", + " ref_val = np.abs(ref_values[j])\n", + " if ref_val == 0.0:\n", + " ref_val = 1e-10 # Avoid division by zero\n", + " metric = metrics[j // len(variables)]\n", + " if metric in higher_better_metrics:\n", + " relative_matrix[i, j] = (\n", + " (ref_values[j] - data_matrix[i, j]) / ref_val * 100\n", + " )\n", + " elif metric == \"Satisfies Bound (Value)\":\n", + " relative_matrix[i, j] = 100 if data_matrix[i, j] != 0 else 0\n", + " else:\n", + " relative_matrix[i, j] = (\n", + " (data_matrix[i, j] - ref_values[j]) / ref_val * 100\n", + " )\n", + "\n", + " reds = sns.color_palette(\"Reds\", 6)\n", + " blues = sns.color_palette(\"Blues_r\", 6)\n", + " cmap = mpl.colors.ListedColormap(blues + [(0.95, 0.95, 0.95)] + reds)\n", + " cb_levels = [-100, -75, -50, -25, -10, -1, 1, 10, 25, 50, 75, 100]\n", + " norm = mpl.colors.BoundaryNorm(cb_levels, cmap.N, extend=\"both\")\n", + "\n", + " ncompressors = len(compressors)\n", + " nvariables = len(variables)\n", + " nmetrics = len(metrics)\n", + "\n", + " panel_width = (2.5 / 5) * nvariables\n", + " label_width = 1.5 * panel_width\n", + " padding_right = 0.1\n", + " panel_height = panel_width / nvariables\n", + "\n", + " title_height = panel_height * 1.25\n", + " cbar_height = panel_height * 2\n", + " spacing_height = panel_height * 0.1\n", + " spacing_width = panel_height * 0.2\n", + "\n", + " total_width = (\n", + " label_width\n", + " + nmetrics * panel_width\n", + " + (nmetrics - 1) * spacing_width\n", + " + padding_right\n", + " )\n", + " total_height = (\n", + " title_height\n", + " + cbar_height\n", + " + ncompressors * panel_height\n", + " + (ncompressors - 1) * spacing_height\n", + " )\n", + "\n", + " fig = plt.figure(figsize=(total_width, total_height))\n", + " gs = mpl.gridspec.GridSpec(\n", + " ncompressors,\n", + " nmetrics,\n", + " figure=fig,\n", + " left=label_width / total_width,\n", + " right=1 - (padding_right / total_width),\n", + " top=1 - (title_height / total_height),\n", + " bottom=cbar_height / total_height,\n", + " hspace=spacing_height / panel_height,\n", + " wspace=spacing_width / panel_width,\n", + " )\n", + "\n", + " img = None\n", + " border_targets: list[tuple[mpl.axes.Axes, int]] = []\n", + " for row, compressor in enumerate(compressors):\n", + " for col, metric in enumerate(metrics):\n", + " ax = fig.add_subplot(gs[row, col])\n", + "\n", + " start_col = col * nvariables\n", + " end_col = start_col + nvariables\n", + " rel_values = relative_matrix[row, start_col:end_col].reshape(1, -1)\n", + " abs_values = data_matrix[row, start_col:end_col]\n", + "\n", + " img = ax.imshow(rel_values, aspect=\"auto\", cmap=cmap, norm=norm)\n", + "\n", + " ax.set_xticks([])\n", + " ax.set_xticklabels([])\n", + " ax.set_yticks([])\n", + " ax.set_yticklabels([])\n", + "\n", + " for i in range(1, nvariables):\n", + " rect = mpl.patches.Rectangle(\n", + " (i - 0.5, -0.5),\n", + " 1,\n", + " 1,\n", + " linewidth=1,\n", + " edgecolor=\"lightgrey\"\n", + " if np.isnan(abs_values[i]) and np.isnan(abs_values[i - 1])\n", + " else \"white\",\n", + " facecolor=\"none\",\n", + " )\n", + " ax.add_patch(rect)\n", + "\n", + " rel_to_abs_converted = (\n", + " compressor in CONVERTED_REL_TO_ABS_COMPRESSORS\n", + " and variables[i] in CONVERTED_REL_TO_ABS_VARIABLES\n", + " )\n", + " abs_to_rel_converted = (\n", + " compressor in CONVERTED_ABS_TO_REL_COMPRESSORS\n", + " and variables[i] in CONVERTED_ABS_TO_REL_VARIABLES\n", + " )\n", + " if rel_to_abs_converted or abs_to_rel_converted:\n", + " border_targets.append((ax, i))\n", + "\n", + " # fontweight = \"bold\" if compressor.startswith(\"safeguarded-\") else \"normal\"\n", + "\n", + " for i, val in enumerate(abs_values):\n", + " color = \"black\" if abs(rel_values[0, i]) < 75 else \"white\"\n", + " fontsize = 10\n", + " if (\n", + " metric in UNRELIABLE_NAN_METRICS\n", + " and variables[i] in UNRELIABLE_NAN_VARIABLES\n", + " ):\n", + " text = \"N/A\"\n", + " color = \"black\"\n", + " elif np.isnan(val):\n", + " text = \"Crash\"\n", + " color = \"black\"\n", + " elif abs(val) > 10_000:\n", + " text = f\"{val:.1e}\"\n", + " fontsize = 8\n", + " elif abs(val) > 10:\n", + " text = f\"{val:.0f}\"\n", + " elif abs(val) > 1:\n", + " text = f\"{val:.1f}\"\n", + " elif val == 1 and metric == \"DSSIM\":\n", + " text = \"1\"\n", + " elif val == 0:\n", + " text = \"0\"\n", + " elif abs(val) < 0.01:\n", + " text = f\"{val:.1e}\"\n", + " fontsize = 8\n", + " else:\n", + " text = f\"{val:.2f}\"\n", + " ax.text(\n", + " i,\n", + " 0,\n", + " text,\n", + " ha=\"center\",\n", + " va=\"center\",\n", + " fontsize=fontsize,\n", + " color=color,\n", + " # fontweight=fontweight,\n", + " )\n", + "\n", + " if (\n", + " row > 0\n", + " and np.isnan(val)\n", + " and np.isnan(data_matrix[row - 1, col * nvariables + i])\n", + " and compressor == f\"safeguarded-{compressors[row - 1]}\"\n", + " and not (\n", + " metric in UNRELIABLE_NAN_METRICS\n", + " and variables[i] in UNRELIABLE_NAN_VARIABLES\n", + " )\n", + " ):\n", + " ax.annotate(\n", + " \"\",\n", + " xy=(i, -0.15),\n", + " xytext=(i, -0.9),\n", + " arrowprops=dict(arrowstyle=\"->\", lw=2, color=\"lightgrey\"),\n", + " )\n", + "\n", + " if col == 0:\n", + " ax.set_ylabel(\n", + " _get_legend_name(compressor),\n", + " rotation=0,\n", + " ha=\"right\",\n", + " va=\"center\",\n", + " labelpad=10,\n", + " fontsize=14,\n", + " # fontweight=fontweight,\n", + " )\n", + "\n", + " if row == 0:\n", + " ax.set_title(METRICS2NAME.get(metric, metric), fontsize=16, pad=10)\n", + " ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", + " ax.set_xticks(range(nvariables))\n", + " ax.set_xticklabels(\n", + " [VARIABLE2NAME.get(v, v) for v in variables],\n", + " rotation=45,\n", + " ha=\"left\",\n", + " fontsize=12,\n", + " )\n", + "\n", + " for spine in ax.spines.values():\n", + " if compressor.startswith(\"safeguarded-\"):\n", + " spine.set_color(\"black\")\n", + " spine.set_linewidth(2)\n", + " else:\n", + " spine.set_color(\"0.7\")\n", + "\n", + " # Mark rel-to-abs converted cells with a small black triangle in the upper\n", + " # right corner. Drawn last so they sit on top of the white grid rectangles\n", + " # and the axes spines.\n", + " triangle_size = 0.3\n", + " for ax, i in border_targets:\n", + " x_right = i + 0.5\n", + " y_top = -0.5\n", + " triangle = mpl.patches.Polygon(\n", + " [\n", + " (x_right - triangle_size, y_top),\n", + " (x_right, y_top),\n", + " (x_right, y_top + triangle_size),\n", + " ],\n", + " closed=True,\n", + " facecolor=\"black\",\n", + " edgecolor=\"none\",\n", + " zorder=10,\n", + " clip_on=False,\n", + " )\n", + " ax.add_patch(triangle)\n", + "\n", + " if cbar and img is not None:\n", + " rel_cbar_height = cbar_height / total_height\n", + " cax = fig.add_axes((0.4, rel_cbar_height * 0.3, 0.5, rel_cbar_height * 0.2))\n", + " cb = fig.colorbar(img, cax=cax, orientation=\"horizontal\")\n", + " cb.ax.set_xticks(cb_levels)\n", + " cb.ax.set_xlabel(\n", + " f\"Better ← % difference vs {_get_legend_name(ref_compressor)} → Worse\",\n", + " fontsize=16,\n", + " )\n", + "\n", + " # plt.tight_layout()\n", + "\n", + " if save_fn:\n", + " plt.savefig(save_fn, dpi=300, bbox_inches=\"tight\")\n", + " plt.close()\n", + " else:\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ee0600cb-6682-443e-a399-fb0b686502a5", + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess(df: pd.DataFrame) -> pd.DataFrame:\n", + " df = df[~df[\"Dataset\"].str.contains(\"-tiny\")]\n", + " df = df[~df[\"Dataset\"].str.contains(\"-chunked\")]\n", + " df = df[~df[\"Dataset\"].str.contains(\"cloud-ice\")]\n", + " df = df[\n", + " ~df[\"Compressor\"].isin(\n", + " [\n", + " \"bitround\",\n", + " \"jpeg2000-conservative-abs\",\n", + " \"stochround-conservative-abs\",\n", + " \"stochround-pco-conservative-abs\",\n", + " \"zfp-conservative-abs\",\n", + " \"bitround-conservative-rel\",\n", + " \"stochround-pco\",\n", + " \"stochround\",\n", + " \"zfp\",\n", + " \"jpeg2000\",\n", + " \"sz3-abs\",\n", + " \"sz3-abs-conservative-abs\",\n", + " \"ebcc-abs\",\n", + " \"ebcc-abs-conservative-abs\",\n", + " ]\n", + " )\n", + " ]\n", + " df = df[~df[\"Compressor\"].str.contains(\"rp\")]\n", + " df = _rename_compressors(df)\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "dfd542e0-a0d7-4513-9184-da7be46b27e9", + "metadata": {}, + "outputs": [], + "source": [ + "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2d242e7c", + "metadata": {}, + "outputs": [], + "source": [ + "df = preprocess(pd.read_csv(RESULTS_FILE))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2d019b8d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n", + "No data for Compressor: ebcc, Variable: ta\n", + "No data for Compressor: ebcc, Variable: tos\n" + ] + } + ], + "source": [ + "scorecard_data = {}\n", + "for bound in [\"low\", \"mid\", \"high\"]:\n", + " scorecard_data[bound] = create_data_matrix(df, bound, METRICS)" + ] + }, + { + "cell_type": "markdown", + "id": "ae80d757", + "metadata": {}, + "source": [ + "# Scorecard" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "678c927b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for low bound...\n", + "Creating scorecard for mid bound...\n", + "Creating scorecard for high bound...\n" + ] + } + ], + "source": [ + "for bound, (data_matrix, compressors, variables) in scorecard_data.items():\n", + " print(f\"Creating scorecard for {bound} bound...\")\n", + " nvars = len(variables)\n", + " create_compression_scorecard(\n", + " data_matrix[:, : 3 * nvars],\n", + " compressors,\n", + " variables,\n", + " METRICS[:3],\n", + " ref_compressor=\"bitround-pco\",\n", + " cbar=False,\n", + " save_fn=OUTPUT_DIR / f\"{bound}_scorecard_row1.pdf\",\n", + " )\n", + " create_compression_scorecard(\n", + " data_matrix[:, 3 * nvars :],\n", + " compressors,\n", + " variables,\n", + " METRICS[3:],\n", + " ref_compressor=\"bitround-pco\",\n", + " cbar=True,\n", + " save_fn=OUTPUT_DIR / f\"{bound}_scorecard_row2.pdf\",\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "23236392-85e3-41f5-98cd-1471b90106ba", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating reduced scorecard for mid bound...\n" + ] + } + ], + "source": [ + "bound = \"mid\"\n", + "data_matrix, compressors, variables = scorecard_data[bound]\n", + "nvars = len(variables)\n", + "\n", + "print(f\"Creating reduced scorecard for {bound} bound...\")\n", + "\n", + "# extract only MAE, CR, and V\n", + "create_compression_scorecard(\n", + " data_matrix[\n", + " :, np.concat([np.arange(nvars, nvars * 2), np.arange(nvars * 4, nvars * 6)])\n", + " ],\n", + " compressors,\n", + " variables,\n", + " METRICS[1:2] + METRICS[-2:],\n", + " ref_compressor=\"bitround-pco\",\n", + " save_fn=OUTPUT_DIR / f\"{bound}_scorecard_reduced.pdf\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3859154-7882-4825-82e8-05efdf594099", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 408b098..12171c1 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -12,45 +12,76 @@ from .variable_plotters import PLOTTERS _COMPRESSOR2LINEINFO = [ - ("jpeg2000", ("#EE7733", "-")), - ("sperr", ("#117733", ":")), - ("zfp-round", ("#DDAA33", "--")), - ("zfp", ("#EE3377", "--")), - ("sz3", ("#CC3311", "-.")), - ("bitround-pco", ("#0077BB", ":")), - ("bitround", ("#33BBEE", "-")), - ("stochround-pco", ("#BBBBBB", "--")), - ("stochround", ("#009988", "--")), - ("tthresh", ("#882255", "-.")), + ("jpeg2000", ("#EE7733", "-", "o")), + ("sperr", ("#117733", "-", "s")), + ("zfp-round", ("#DDAA33", "-", "D")), + ("zfp", ("#EE3377", "--", "^")), + ("sz3", ("#CC3311", "-", "v")), + ("bitround-pco", ("#0077BB", "-", "P")), + ("bitround", ("#33BBEE", "-", "X")), + ("stochround-pco", ("#BBBBBB", "--", "d")), + ("stochround", ("#009988", "--", "h")), + ("tthresh", ("#882255", "-.", "<")), + ("safeguarded-sperr", ("#117733", ":", "s")), + ("safeguarded-zfp-round", ("#DDAA33", ":", "D")), + ("safeguarded-sz3", ("#CC3311", ":", "v")), + ("safeguarded-zero-dssim", ("#9467BD", "--", "*")), + ("safeguarded-zero", ("#9467BD", ":", "H")), + ("safeguarded-bitround-pco", ("#0077BB", ":", "P")), + ("ebcc", ("#AA4444", "-", "8")), + ("safeguarded-ebcc", ("#AA4444", ":", "8")), ] -def _get_lineinfo(compressor: str) -> tuple[str, str]: - """Get the line color and style for a given compressor.""" - for comp, (color, linestyle) in _COMPRESSOR2LINEINFO: +def _get_lineinfo(compressor: str) -> tuple[str, str, str]: + """Get the line color, style, and marker for a given compressor.""" + for comp, (color, linestyle, marker) in _COMPRESSOR2LINEINFO: if compressor.startswith(comp): - return color, linestyle + return color, linestyle, marker raise ValueError(f"Unknown compressor: {compressor}") _COMPRESSOR2LEGEND_NAME = [ ("jpeg2000", "JPEG2000"), ("sperr", "SPERR"), - ("zfp-round", "ZFP-ROUND"), + ("zfp-round", "ZFP"), ("zfp", "ZFP"), - ("sz3", "SZ3"), - ("bitround-pco", "BitRound + PCO"), + ("sz3", "SZ3[v3.2]"), + ("bitround-pco", "BitRound"), ("bitround", "BitRound + Zstd"), ("stochround-pco", "StochRound + PCO"), ("stochround", "StochRound + Zstd"), ("tthresh", "TTHRESH"), + ("safeguarded-sperr", "Safeguarded(SPERR)"), + ("safeguarded-zfp-round", "Safeguarded(ZFP)"), + ("safeguarded-sz3", "Safeguarded(SZ3[v3.2])"), + ("safeguarded-zero-dssim", "Safeguarded(0, dSSIM)"), + ("safeguarded-zero", "Safeguarded(0)"), + ("safeguarded-bitround-pco", "Safeguarded(BitRound)"), + ("ebcc", "EBCC"), + ("safeguarded-ebcc", "Safeguarded(EBCC)"), +] + +_COMPRESSOR_ORDER = [ + "BitRound", + "Safeguarded(BitRound)", + "ZFP", + "Safeguarded(ZFP)", + "SZ3[v3.2]", + "Safeguarded(SZ3[v3.2])", + "SPERR", + "Safeguarded(SPERR)", + "EBCC", + "Safeguarded(EBCC)", + "Safeguarded(0)", + "Safeguarded(0, dSSIM)", ] DISTORTION2LEGEND_NAME = { - "Relative MAE": "Mean Absolute Error", - "Relative DSSIM": "DSSIM", - "Relative MaxAbsError": "Max Absolute Error", - "Spectral Error": "Spectral Error", + "Relative MAE": "Normalised Mean Absolute Error", + "Relative dSSIM": "dSSIM", + "Relative MaxAbsError": "Normalised Max Absolute Error", + "Relative SpectralError": "Normalised Spectral Error", } @@ -102,6 +133,7 @@ def plot_metrics( df = pd.read_csv(metrics_path / "all_results.csv") # Filter out excluded datasets and compressors + # bitround jpeg2000-conservative-abs stochround-conservative-abs stochround-pco-conservative-abs zfp-conservative-abs bitround-conservative-rel stochround-pco stochround zfp jpeg2000 sz3-abs sz3-abs-conservative-abs ebcc-abs ebcc-abs-conservative-abs rp rp-conservative-abs rp-10.0 rp-10.0-conservative-abs rp-100.0 rp-100.0-conservative-abs rp-2.0 rp-2.0-conservative-abs rp-5.0 rp-5.0-conservative-abs rp-50.0 rp-50.0-conservative-abs rp-dct rp-dct-conservative-abs rp-dct-10.0 rp-dct-10.0-conservative-abs rp-dct-100.0 rp-dct-100.0-conservative-abs rp-dct-2.0 rp-dct-2.0-conservative-abs rp-dct-5.0 rp-dct-5.0-conservative-abs rp-dct-50.0 rp-dct-50.0-conservative-abs safeguarded-rp safeguarded-rp-dct df = df[~df["Compressor"].isin(exclude_compressor)] df = df[~df["Dataset"].isin(exclude_dataset)] is_tiny = df["Dataset"].str.endswith("-tiny") @@ -111,16 +143,16 @@ def plot_metrics( filter_chunked = is_chunked if chunked_datasets else ~is_chunked df = df[filter_chunked] - _plot_per_variable_metrics( - datasets=datasets, - compressed_datasets=compressed_datasets, - plots_path=plots_path, - all_results=df, - rd_curves_metrics=["Max Absolute Error", "MAE", "DSSIM", "Spectral Error"], - ) + # _plot_per_variable_metrics( + # datasets=datasets, + # compressed_datasets=compressed_datasets, + # plots_path=plots_path, + # all_results=df, + # rd_curves_metrics=["Max Absolute Error", "MAE", "DSSIM", "Spectral Error"], + # ) df = _rename_compressors(df) - normalized_df = _normalize(df) + normalized_df, normalized_mean_std = _normalize(df) _plot_bound_violations( normalized_df, bound_names, plots_path / "bound_violations.pdf" ) @@ -129,7 +161,7 @@ def plot_metrics( for metric in [ "Relative MAE", - "Relative DSSIM", + "Relative dSSIM", "Relative MaxAbsError", "Relative SpectralError", ]: @@ -138,6 +170,7 @@ def plot_metrics( normalized_df, compression_metric="Relative CR", distortion_metric=metric, + mean_std=normalized_mean_std[metric], outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", agg="mean", bound_names=bound_names, @@ -147,6 +180,7 @@ def plot_metrics( normalized_df, compression_metric="Relative CR", distortion_metric=metric, + mean_std=normalized_mean_std[metric], outfile=plots_path / f"full_rd_curve_{metric.lower().replace(' ', '_')}.pdf", agg="mean", @@ -188,7 +222,7 @@ def _normalize(data): normalize_vars = [ ("Compression Ratio [raw B / enc B]", "Relative CR"), ("MAE", "Relative MAE"), - ("DSSIM", "Relative DSSIM"), + ("DSSIM", "Relative dSSIM"), ("Max Absolute Error", "Relative MaxAbsError"), ("Spectral Error", "Relative SpectralError"), ] @@ -198,11 +232,15 @@ def _normalize(data): dssim_unreliable = normalized["Variable"].isin(["ta", "tos"]) normalized.loc[dssim_unreliable, "DSSIM"] = np.nan + normalize_mean_std = dict() for col, new_col in normalize_vars: mean_std = dict() for var in variables: - mean = normalized[normalized["Variable"] == var][col].mean() - std = normalized[normalized["Variable"] == var][col].std() + if col in ["DSSIM"]: + mean, std = 0.0, 1.0 + else: + mean = normalized[normalized["Variable"] == var][col].mean() + std = normalized[normalized["Variable"] == var][col].std() mean_std[var] = (mean, std) # Normalize each variable by its mean and std @@ -213,7 +251,9 @@ def _normalize(data): axis=1, ) - return normalized + normalize_mean_std[new_col] = mean_std + + return normalized, normalize_mean_std def _plot_per_variable_metrics( @@ -356,12 +396,12 @@ def _plot_variable_rd_curve( for i in bound_ixs ] distortion = [compressor_data[distortion_metric].loc[i] for i in bound_ixs] - color, linestyle = _get_lineinfo(comp) + color, linestyle, marker = _get_lineinfo(comp) plt.plot( compr_ratio, distortion, label=_get_legend_name(comp), - marker="s", + marker=marker, color=color, linestyle=linestyle, linewidth=4, @@ -408,6 +448,7 @@ def _plot_aggregated_rd_curve( normalized_df, compression_metric, distortion_metric, + mean_std, outfile: None | Path = None, agg="median", bound_names=["low", "mid", "high"], @@ -419,7 +460,10 @@ def _plot_aggregated_rd_curve( # Exclude variables that are not relevant for the distortion metric. normalized_df = normalized_df[~normalized_df["Variable"].isin(exclude_vars)] - compressors = normalized_df["Compressor"].unique() + compressors = sorted( + normalized_df["Compressor"].unique(), + key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)), + ) agg_distortion = normalized_df.groupby(["Error Bound Name", "Compressor"])[ [compression_metric, distortion_metric] ].agg(agg) @@ -433,12 +477,12 @@ def _plot_aggregated_rd_curve( agg_distortion.loc[(bound, comp), distortion_metric] for bound in bound_names ] - color, linestyle = _get_lineinfo(comp) + color, linestyle, marker = _get_lineinfo(comp) plt.plot( compr_ratio, distortion, label=_get_legend_name(comp), - marker="s", + marker=marker, color=color, linestyle=linestyle, linewidth=4, @@ -447,7 +491,13 @@ def _plot_aggregated_rd_curve( if remove_outliers: # SZ3 and JPEG2000 often give outlier values and violate the bounds. - exclude_compressors = ["sz3", "jpeg2000"] + exclude_compressors = [ + "sz3", + "jpeg2000", + "safeguarded-zero-dssim", + "safeguarded-zero", + "safeguarded-sz3", + ] filtered_agg = agg_distortion[ ~agg_distortion.index.get_level_values("Compressor").isin( exclude_compressors @@ -493,28 +543,34 @@ def _plot_aggregated_rd_curve( right=True, ) plt.xlabel( - r"Mean Normalized Compression Ratio ($\uparrow$)", + r"Mean Normalised Compression Ratio ($\uparrow$)", fontsize=16, ) metric_name = DISTORTION2LEGEND_NAME.get(distortion_metric, distortion_metric) plt.ylabel( - rf"Mean Normalized {metric_name} ($\downarrow$)", + rf"Mean {metric_name} ($\downarrow$)", fontsize=16, ) plt.legend( title="Compressor", - loc="upper right", - bbox_to_anchor=(0.8, 0.99), + loc="upper left", + # bbox_to_anchor=(0.8, 0.99), fontsize=12, title_fontsize=14, ) arrow_color = "black" - if "DSSIM" in distortion_metric: + if "dSSIM" in distortion_metric: + # Annotate dSSIM = 1, accounting for the normalization + dssim_one = getattr(np, f"nan{agg}")( + [(1 - ms[0]) / ms[1] for ms in mean_std.values()] + ) + plt.axhline(dssim_one, c="k", ls="--") + # Add an arrow pointing into the top right corner plt.annotate( "", - xy=(0.95, 0.95), + xy=(0.95, 0.875 if remove_outliers else 0.9), xycoords="axes fraction", xytext=(-60, -50), textcoords="offset points", @@ -527,7 +583,7 @@ def _plot_aggregated_rd_curve( # Attach the text to the lower left of the arrow plt.text( 0.83, - 0.92, + 0.845 if remove_outliers else 0.87, "Better", transform=plt.gca().transAxes, fontsize=16, @@ -538,7 +594,7 @@ def _plot_aggregated_rd_curve( ) # Correct the y-label to point upwards plt.ylabel( - rf"Mean Normalized {metric_name} ($\uparrow$)", + rf"Mean {metric_name} ($\uparrow$)", fontsize=16, ) else: @@ -566,7 +622,7 @@ def _plot_aggregated_rd_curve( ha="center", ) if ( - "DSSIM" in distortion_metric + "dSSIM" in distortion_metric or "MaxAbsError" in distortion_metric or "SpectralError" in distortion_metric ): @@ -579,24 +635,23 @@ def _plot_aggregated_rd_curve( def _plot_throughput(df, outfile: None | Path = None): - # Transform throughput measurements from raw B/s to s/MB for better comparison - # with instruction count measurements. encode_col = "Encode Throughput [raw B / s]" decode_col = "Decode Throughput [raw B / s]" new_df = df[["Compressor", "Error Bound Name", encode_col, decode_col]].copy() - transformed_encode_col = "Encode Throughput [s / MB]" - transformed_decode_col = "Decode Throughput [s / MB]" - new_df[transformed_encode_col] = 1e6 / new_df[encode_col] - new_df[transformed_decode_col] = 1e6 / new_df[decode_col] + transformed_encode_col = "Encode Throughput [MiB / s]" + transformed_decode_col = "Decode Throughput [MiB / s]" + new_df[transformed_encode_col] = new_df[encode_col] / (2**20) + new_df[transformed_decode_col] = new_df[decode_col] / (2**20) encode_col, decode_col = transformed_encode_col, transformed_decode_col grouped_df = _get_median_and_quantiles(new_df, encode_col, decode_col) _plot_grouped_df( grouped_df, title="", - ylabel="Throughput [s / MB]", + ylabel="Throughput [MiB / s]", logy=True, outfile=outfile, + up=True, ) @@ -610,105 +665,162 @@ def _plot_instruction_count(df, outfile: None | Path = None): ylabel="Instructions [# / raw B]", logy=True, outfile=outfile, + up=False, ) def _get_median_and_quantiles(df, encode_column, decode_column): - return df.groupby(["Compressor", "Error Bound Name"])[ - [encode_column, decode_column] - ].agg( - encode_median=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.5) - ), - encode_lower_quantile=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.25) - ), - encode_upper_quantile=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.75) - ), - decode_median=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.5) - ), - decode_lower_quantile=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.25) - ), - decode_upper_quantile=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.75) - ), + return ( + df.groupby(["Compressor", "Error Bound Name"])[[encode_column, decode_column]] + .agg( + encode_median=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.5) + ), + encode_lower_quantile=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.25) + ), + encode_upper_quantile=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.75) + ), + decode_median=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.5) + ), + decode_lower_quantile=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.25) + ), + decode_upper_quantile=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.75) + ), + ) + .sort_index( + level=[0, 1], + key=lambda ks: [ + ( + ["low", "mid", "high"].index(k) + if k in ["low", "mid", "high"] + else _COMPRESSOR_ORDER.index(_get_legend_name(k)) + ) + for k in ks + ], + ) ) def _plot_grouped_df( - grouped_df, title, ylabel, outfile: None | Path = None, logy=False + grouped_df, + title, + ylabel, + outfile: None | Path = None, + logy=False, + up=False, ): - fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True) + fig, (ax1, ax2) = plt.subplots( + 1, 2, figsize=(18, 4), sharex=True, sharey=True, gridspec_kw=dict(wspace=0.1) + ) # Bar width - bar_width = 0.35 - compressors = grouped_df.index.levels[0].tolist() + bar_width = 0.25 + compressors = sorted( + grouped_df.index.levels[0].tolist(), + key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)), + ) x_labels = [_get_legend_name(c) for c in compressors] - x_positions = range(len(x_labels)) error_bounds = ["low", "mid", "high"] - for i, error_bound in enumerate(error_bounds): - ax = axes[i] - bound_data = grouped_df.xs(error_bound, level="Error Bound Name") - - # Plot encode throughput - ax.bar( - x_positions, - bound_data["encode_median"], - bar_width, - yerr=[ - bound_data["encode_lower_quantile"], - bound_data["encode_upper_quantile"], + ax1.bar( + [ + (x // len(error_bounds)) + + ((x % len(error_bounds)) - ((len(error_bounds) - 1) / 2)) + * bar_width + * 1.2 + for x in range(len(grouped_df["encode_median"])) + ], + grouped_df["encode_median"], + bar_width, + yerr=[ + grouped_df["encode_lower_quantile"], + grouped_df["encode_upper_quantile"], + ], + edgecolor="white", + linewidth=0, + color=np.repeat( + [_get_lineinfo(comp)[0] for comp in compressors], len(error_bounds) + ), + hatch=np.repeat( + ["O" if comp.startswith("safeguarded-") else "" for comp in compressors], + len(error_bounds), + ), + label=np.array( + [ + ["Safeguarded"] + [None] * (len(error_bounds) - 1) + if comp == "safeguarded-bitround-pco" + else [None] * len(error_bounds) + for comp in compressors ], - label="Encoding", - color=[_get_lineinfo(comp)[0] for comp in compressors], - ) + ).flatten(), + ) - # Plot decode throughput - ax.bar( - [p + bar_width for p in x_positions], - bound_data["decode_median"], - bar_width, - yerr=[ - bound_data["decode_lower_quantile"], - bound_data["decode_upper_quantile"], - ], - label="Decoding", - edgecolor=[_get_lineinfo(comp)[0] for comp in compressors], - fill=False, - linewidth=4, - ) + # Add labels and title + ax1.set_xticks([p for p in range(len(x_labels))]) + ax1.set_xticklabels(x_labels, rotation=45, ha="right", fontsize=14) + ax1.set_xlim(-0.5, len(x_labels) - 0.5) + ax1.set_yscale("log" if logy else "linear") + ax1.set_title("Compression", fontsize=14) + ax1.grid(axis="y", linestyle="--", alpha=0.7) + ax1.legend(fontsize=14, loc="upper right" if up else "upper left", framealpha=0.9) + ax1.set_ylabel(ylabel, fontsize=14) + ax1.annotate( + "Better", + xy=(0.5, 0.75), + xycoords="axes fraction", + xytext=(0.5, 0.92), + textcoords="axes fraction", + arrowprops=dict(arrowstyle="<-" if up else "->", lw=3, color="black"), + fontsize=14, + ha="center", + va="bottom", + ) - # Add labels and title - ax.set_xticks([p + bar_width / 2 for p in x_positions]) - ax.set_xticklabels(x_labels, rotation=45, ha="right", fontsize=14) - ax.set_yscale("log" if logy else "linear") - ax.set_title(f"{error_bound.capitalize()} Error Bound", fontsize=14) - ax.grid(axis="y", linestyle="--", alpha=0.7) - if i == 0: - ax.legend(fontsize=14) - ax.set_ylabel(ylabel, fontsize=14) - ax.annotate( - "Better", - xy=(0.1, 0.8), - xycoords="axes fraction", - xytext=(0.1, 0.95), - textcoords="axes fraction", - arrowprops=dict(arrowstyle="->", lw=3, color="black"), - fontsize=14, - ha="center", - va="bottom", - ) + ax2.bar( + [ + (x // len(error_bounds)) + + ((x % len(error_bounds)) - ((len(error_bounds) - 1) / 2)) + * bar_width + * 1.2 + for x in range(len(grouped_df["decode_median"])) + ], + grouped_df["decode_median"], + bar_width, + yerr=[ + grouped_df["decode_lower_quantile"], + grouped_df["decode_upper_quantile"], + ], + edgecolor="white", + linewidth=0, + color=np.repeat( + [_get_lineinfo(comp)[0] for comp in compressors], len(error_bounds) + ), + hatch=np.repeat( + ["O" if comp.startswith("safeguarded-") else "" for comp in compressors], + len(error_bounds), + ), + ) + + # Add labels and title + ax2.set_xticks([p for p in range(len(x_labels))]) + ax2.set_xticklabels(x_labels, rotation=45, ha="right", fontsize=14) + ax2.set_yscale("log" if logy else "linear") + ax2.set_title("Decompression", fontsize=14) + ax2.grid(axis="y", linestyle="--", alpha=0.7) + ax2.yaxis.tick_right() + ax2.tick_params("y", labelright=True) fig.suptitle(title) - fig.tight_layout() + # fig.tight_layout() if outfile is not None: - _savefig(outfile) + _savefig(outfile, bbox_inches="tight") plt.close() @@ -720,11 +832,11 @@ def _plot_bound_violations(df, bound_names, outfile: None | Path = None): df_bound["Compressor"] = df_bound["Compressor"].map(_get_legend_name) pass_fail = df_bound.pivot( index="Compressor", columns="Variable", values="Satisfies Bound (Passed)" - ) + ).sort_index(key=lambda ks: [_COMPRESSOR_ORDER.index(k) for k in ks]) pass_fail = pass_fail.astype(np.float32) fraction_fail = df_bound.pivot( index="Compressor", columns="Variable", values="Satisfies Bound (Value)" - ) + ).sort_index(key=lambda ks: [_COMPRESSOR_ORDER.index(k) for k in ks]) annotations = fraction_fail.map( lambda x: "{:.2f}".format(x * 100) if x * 100 >= 0.01 else "<0.01" ) @@ -754,17 +866,17 @@ def _plot_bound_violations(df, bound_names, outfile: None | Path = None): plt.close() -def _savefig(outfile: Path, fig=None): +def _savefig(outfile: Path, fig=None, **kwargs): ispdf = outfile.suffix == ".pdf" fig = fig if fig is not None else plt.gcf() if ispdf: # Saving a PDF with the alternative code below leads to a corrupted file. # Hence, we use the default savefig method. # NOTE: This means passing a virtual UPath is only supported for non-PDF files. - fig.savefig(outfile, dpi=300) + fig.savefig(outfile, dpi=300, **kwargs) else: with outfile.open("wb") as f: - fig.savefig(f, dpi=300) + fig.savefig(f, dpi=300, **kwargs) if __name__ == "__main__":