-
Notifications
You must be signed in to change notification settings - Fork 1
Add fancy Metrics plot to the together-py #344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7aa34f9
950baf7
37259db
d753f70
ecd5c7b
ad7f85c
6af5596
54583a7
0cee80a
528956b
bc64358
2fb9a36
5dbf05c
8b95caa
a418fa0
6817537
d6e28c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import sys | ||
| from typing import Literal, Optional, Annotated | ||
| from datetime import datetime | ||
|
|
||
| from cyclopts import Parameter | ||
|
|
||
| from together import omit | ||
| from together._utils._json import openapi_dumps | ||
| from together.lib.cli.utils.config import CLIConfigParameter | ||
| from together.lib.cli.utils._console import console | ||
| from together.lib.cli.components.loader import show_loading_status | ||
| from together.lib.cli.components.plot_finetune_metrics import METRICS_WIDTH_PADDING, metrics_ascii_charts | ||
|
|
||
|
|
||
| async def list_metrics( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still not sure if we want to keep it as one command. just as an example, lets say I want to only save metrics to the file but using this command I would get a bunch of plots I can be not interested in Wouldn't it be more convenient for us to have save_metrics and plot_metrics? What do you think? |
||
| fine_tune_id: Annotated[str, Parameter(help="The ID of the fine-tuning job")], | ||
| *, | ||
| config: CLIConfigParameter, | ||
| global_step_from: Annotated[ | ||
| Optional[int], Parameter(help="Filter metrics from this global step (inclusive).") | ||
| ] = None, | ||
| global_step_to: Annotated[Optional[int], Parameter(help="Filter metrics to this global step (inclusive).")] = None, | ||
| logged_at_from: Annotated[ | ||
| Optional[datetime], Parameter(help="Filter metrics logged at or after this time.") | ||
| ] = None, | ||
| logged_at_to: Annotated[Optional[datetime], Parameter(help="Filter metrics logged at or before this time.")] = None, | ||
| resolution: Annotated[ | ||
| Optional[int], | ||
| Parameter( | ||
| help="Number of uniformly sampled training metric points to return. Does not limit the number of eval metric points." | ||
| ), | ||
| ] = None, | ||
| ) -> None: | ||
| """Retrieve training metrics for a fine-tuning job.""" | ||
| response = await show_loading_status( | ||
| "Fetching metrics...", | ||
| config.client.fine_tuning.list_metrics( | ||
| fine_tune_id, | ||
| global_step_from=global_step_from or omit, | ||
| global_step_to=global_step_to or omit, | ||
| logged_at_from=logged_at_from or omit, | ||
| logged_at_to=logged_at_to or omit, | ||
| resolution=resolution or omit, | ||
| ), | ||
| ) | ||
|
|
||
| metrics = response.metrics or [] | ||
|
|
||
| if config.json: | ||
| json_bytes = openapi_dumps(metrics) | ||
| console.print_json(json_bytes.decode("utf-8")) | ||
| return | ||
|
|
||
| if len(metrics) == 0: | ||
| console.print(f"[muted]No metrics found for job {fine_tune_id}[/muted]") | ||
| return | ||
|
|
||
| console.print(metrics_ascii_charts(metrics, width=console.width - METRICS_WIDTH_PADDING)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| """Fine-tuning metrics plotting utilities. | ||
|
|
||
| Public API | ||
| ---------- | ||
| ``metrics_block_sparklines(metrics)`` | ||
| One ▁▂▃▄▅▆▇█ sparkline line per metric — used in ``retrieve``. | ||
|
|
||
| ``metrics_ascii_charts(metrics, height=6)`` | ||
| One full ASCII line chart per metric — used in ``list-metrics``. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import math | ||
| from typing import Any | ||
| from collections import defaultdict | ||
|
|
||
| from rich.text import Text | ||
|
|
||
| from together.lib.cli.components.plots import should_log, render_line_chart, render_sparklines | ||
|
|
||
| # Columns reserved for the y-axis label area, ┼ connector, leading indent, and | ||
| # surrounding margin in the ASCII chart layout. This must be >= label_width + 1 | ||
| # (the default label_width used in metrics_ascii_charts is 8, so the minimum is | ||
| # 9). Callers subtract this from the terminal width to get the usable plot width. | ||
| METRICS_WIDTH_PADDING = 48 | ||
|
|
||
| _SKIP_KEYS: frozenset[str] = frozenset({"timestamp", "step", "global_step", "epoch"}) | ||
|
|
||
|
|
||
| def _is_skip(k: str) -> bool: | ||
| base = k.rsplit("/", 1)[-1] | ||
| return base in _SKIP_KEYS or base.endswith("_step") or base.endswith("_epoch") | ||
|
|
||
|
|
||
| def _step_label(x: float) -> str: | ||
| return str(int(x)) | ||
|
|
||
|
|
||
| def _collect_series( | ||
| metrics: list[dict[str, Any]], | ||
| ) -> dict[str, tuple[list[float], list[float]]]: | ||
| """Collect plottable numeric series from a list of metric dicts. | ||
|
|
||
| Returns a mapping of name → (xs, ys). Keys are discovered in insertion | ||
| order; step/epoch/timestamp fields are skipped. NaN values are converted | ||
| to ``-inf`` so the rendering engine plots them at the very bottom of the | ||
| chart rather than silently dropping them. | ||
| """ | ||
| series: dict[str, tuple[list[float], list[float]]] = defaultdict(lambda: ([], [])) | ||
| for row in metrics: | ||
| step = float(row["train/global_step"]) | ||
| for k, v in row.items(): | ||
| if _is_skip(k) or isinstance(v, bool) or not isinstance(v, (int, float)): | ||
| continue | ||
| val = float(v) | ||
| # NaN is rendered as a dip to the bottom (-inf sentinel). | ||
| if math.isnan(val): | ||
| val = float("-inf") | ||
| series[k][0].append(step) | ||
| series[k][1].append(val) | ||
| return series | ||
|
|
||
|
|
||
| def _no_data() -> Text: | ||
| t = Text() | ||
| t.append("No plottable metrics found.", style="muted") | ||
| return t | ||
|
|
||
|
|
||
| def metrics_block_sparklines( | ||
| metrics: list[dict[str, Any]], | ||
| *, | ||
| width: int = 60, | ||
| ) -> Text: | ||
| """One block-sparkline line per metric, coloured with the CLI theme. | ||
|
|
||
| Args: | ||
| metrics: List of flat metric dicts (one per training step). | ||
| width: Sparkline character width (default 60). | ||
|
|
||
| Returns: | ||
| A ``rich.text.Text`` ready for ``console.print()``. | ||
| """ | ||
| series = _collect_series(metrics) | ||
| if not series: | ||
| return _no_data() | ||
| label_w = max(len(k) for k in series) | ||
| text = Text() | ||
| for key, (xs, ys) in series.items(): | ||
| text.append_text( | ||
| render_sparklines( | ||
| key, | ||
| xs, | ||
| ys, | ||
| width=width, | ||
| y_log=should_log(ys), | ||
| label_width=label_w, | ||
| ) | ||
| ) | ||
| return text | ||
|
|
||
|
|
||
| def metrics_ascii_charts( | ||
| metrics: list[dict[str, Any]], | ||
| *, | ||
| height: int = 6, | ||
| width: int = 60, | ||
| label_width: int = 8, | ||
| ) -> Text: | ||
| """One ASCII line chart per metric, with a global-step x-axis. | ||
|
|
||
| Args: | ||
| metrics: List of flat metric dicts (one per training step). | ||
| height: Chart body height in rows (default 6). | ||
| width: Plot character width (default 60). | ||
|
|
||
| Returns: | ||
| A ``rich.text.Text`` ready for ``console.print()``. | ||
| """ | ||
| series = _collect_series(metrics) | ||
| text = Text() | ||
| for key, (xs, ys) in series.items(): | ||
| if text: | ||
| text.append("\n") | ||
| text.append_text( | ||
| render_line_chart( | ||
| xs, | ||
| {key: ys}, | ||
| x_label=_step_label, | ||
| y_log=should_log(ys), | ||
| height=height, | ||
| width=width, | ||
| label_width=label_width, | ||
| ) | ||
| ) | ||
| return text if text else _no_data() | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "metrics_block_sparklines", | ||
| "metrics_ascii_charts", | ||
| "METRICS_WIDTH_PADDING", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| """Generic CLI plot utilities.""" | ||
|
|
||
| from together.lib.cli.components.plots._engine import should_log, render_line_chart, render_sparklines | ||
|
|
||
| __all__ = [ | ||
| "render_line_chart", | ||
| "render_sparklines", | ||
| "should_log", | ||
| ] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we move this from
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd suggest adding some examples to the usage like we do with other commands that have different parameters
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added help