diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 586828934..cff01f26f 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "2.14.0" + ".": "2.15.0" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 498815517..e8f2d8264 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ -configured_endpoints: 75 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai/togetherai-5f05c9669c67c3f4b0ebfe2317d2768cd96317424965ebb2acf06a7757a7d0ca.yml -openapi_spec_hash: 84f45151f4d0eed68551b5ffda61595a -config_hash: ec427df08d61d8888138f15cd53c6454 +configured_endpoints: 81 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai/togetherai-ce108a2095d36552bb556506de04475674f512a13bc5aa099e9750993405be14.yml +openapi_spec_hash: 4763dd426dd805306bbb38a314158cd3 +config_hash: b35d5968fb07cce1c1be735f874898b1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 418ba50f1..c6d0c9f72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,40 @@ # Changelog +## 2.15.0 (2026-05-20) + +Full Changelog: [v2.14.0...v2.15.0](https://github.com/togethercomputer/together-py/compare/v2.14.0...v2.15.0) + +### Features + +* **api:** add cluster config/OIDC/add-ons params, project filtering, update storage types ([9a8c60e](https://github.com/togethercomputer/together-py/commit/9a8c60eb51daba174c0a4761612b3dd51fb5bee5)) +* **api:** add disable_position_bias_correction, remove num_samples from eval compare results ([27e6c2d](https://github.com/togethercomputer/together-py/commit/27e6c2db1e2549d8b2352f73265067f0eac9b44c)) +* **api:** add h200-140gb gpu_type to jig deploy/update methods ([0f34ea4](https://github.com/togethercomputer/together-py/commit/0f34ea4e1441a08b014f19d99588902b17eda1be)) +* **api:** add instance_name field to remediation model ([4c7fc66](https://github.com/togethercomputer/together-py/commit/4c7fc662363054f47d5e7a01353e3f0da98d8b6a)) +* **api:** Add node remediation APIs to clusters sdks ([029c3fd](https://github.com/togethercomputer/together-py/commit/029c3fd79c22f130dab4a46bc62a6e1410908da4)) +* **api:** add trigger param, support multiple modes in remediations list ([997deea](https://github.com/togethercomputer/together-py/commit/997deeae7c514b0ce2dc65394d262abe9bd35766)) +* **api:** manual updates ([f4de411](https://github.com/togethercomputer/together-py/commit/f4de41192250b3c609e44da6bee18309db209f35)) +* **api:** manual updates ([b5e42a0](https://github.com/togethercomputer/together-py/commit/b5e42a042c367dbe14021be3c6612155ac8f6fac)) +* **cli:** add eval compare bias correction flag ([#375](https://github.com/togethercomputer/together-py/issues/375)) ([ac8482e](https://github.com/togethercomputer/together-py/commit/ac8482ebbf9fdf7e67973ff36a0178ce774963cb)) +* **cli:** add get as alias for retrieve subcommands ([#367](https://github.com/togethercomputer/together-py/issues/367)) ([d283d11](https://github.com/togethercomputer/together-py/commit/d283d1192b1c75ad676420be4ea30137883d55a6)) +* **cli:** add remediation list filters ([#372](https://github.com/togethercomputer/together-py/issues/372)) ([1656759](https://github.com/togethercomputer/together-py/commit/16567597382f3345aff0450bcf1a257976a97139)) +* **jig:** copy and use uv.lock if exists on autogenerated dockerfile ([#370](https://github.com/togethercomputer/together-py/issues/370)) ([47e5c89](https://github.com/togethercomputer/together-py/commit/47e5c891ac2272b0d20c1c266d0e3a9527448019)) +* Sync deployments OpenAPI spec ([1caa5fa](https://github.com/togethercomputer/together-py/commit/1caa5fa4c41dff79164d8edd7436e66747eab712)) + + +### Bug Fixes + +* **api:** make duration_days optional in clusters create, size_tib optional in storage update ([899752d](https://github.com/togethercomputer/together-py/commit/899752dbebed9a75433b9ab95245c3bf15237eb3)) +* **api:** remove error field, make request_id required in jig queue submit response ([5ae0fbc](https://github.com/togethercomputer/together-py/commit/5ae0fbca3e592dafedc4638642582585f29098df)) +* **api:** remove trigger parameter from remediations list method ([d6310d8](https://github.com/togethercomputer/together-py/commit/d6310d881cc12bb67132c9446d301d1663fd9f48)) +* **jig:** honor uv default groups in autogenerated dockerfile ([#301](https://github.com/togethercomputer/together-py/issues/301)) ([85cf77b](https://github.com/togethercomputer/together-py/commit/85cf77b6dc8df8a4f5859deb543123195c484b5b)) +* **types:** correct status field to enum in cluster_storage model ([2109f0a](https://github.com/togethercomputer/together-py/commit/2109f0a0c897a7d5659f700042ec97b0843b3228)) +* **types:** remove node_name from ControlPlaneNode and GPUWorkerNode ([7a1a7c2](https://github.com/togethercomputer/together-py/commit/7a1a7c21f1cb0e898cf9c9cf12746964c5d1b978)) + + +### Documentation + +* **api:** add parameter descriptions to storage methods and types ([8c35457](https://github.com/togethercomputer/together-py/commit/8c35457b06dc6a58f0c1343accb8db09ad91b845)) + ## 2.14.0 (2026-05-12) Full Changelog: [v2.13.0...v2.14.0](https://github.com/togethercomputer/together-py/compare/v2.13.0...v2.14.0) diff --git a/api.md b/api.md index 1ea19fb05..156264d90 100644 --- a/api.md +++ b/api.md @@ -48,7 +48,7 @@ from together.types.beta.jig import Volume, VolumeListResponse Methods: - client.beta.jig.volumes.create(\*\*params) -> Volume -- client.beta.jig.volumes.retrieve(id) -> Volume +- client.beta.jig.volumes.retrieve(id, \*\*params) -> Volume - client.beta.jig.volumes.update(id, \*\*params) -> Volume - client.beta.jig.volumes.list() -> VolumeListResponse - client.beta.jig.volumes.delete(id) -> object @@ -87,10 +87,27 @@ Methods: - client.beta.clusters.create(\*\*params) -> Cluster - client.beta.clusters.retrieve(cluster_id) -> Cluster - client.beta.clusters.update(cluster_id, \*\*params) -> Cluster -- client.beta.clusters.list() -> ClusterListResponse +- client.beta.clusters.list(\*\*params) -> ClusterListResponse - client.beta.clusters.delete(cluster_id) -> ClusterDeleteResponse - client.beta.clusters.list_regions() -> ClusterListRegionsResponse +### Remediations + +Types: + +```python +from together.types.beta.clusters import Remediation, RemediationListResponse +``` + +Methods: + +- client.beta.clusters.remediations.create(instance_id, \*, cluster_id, \*\*params) -> Remediation +- client.beta.clusters.remediations.retrieve(remediation_id, \*, cluster_id, instance_id) -> Remediation +- client.beta.clusters.remediations.list(instance_id, \*, cluster_id, \*\*params) -> RemediationListResponse +- client.beta.clusters.remediations.approve(remediation_id, \*, cluster_id, instance_id, \*\*params) -> Remediation +- client.beta.clusters.remediations.cancel(remediation_id, \*, cluster_id, instance_id) -> Remediation +- client.beta.clusters.remediations.reject(remediation_id, \*, cluster_id, instance_id, \*\*params) -> Remediation + ### Storage Types: @@ -104,7 +121,7 @@ Methods: - client.beta.clusters.storage.create(\*\*params) -> ClusterStorage - client.beta.clusters.storage.retrieve(volume_id) -> ClusterStorage - client.beta.clusters.storage.update(\*\*params) -> ClusterStorage -- client.beta.clusters.storage.list() -> StorageListResponse +- client.beta.clusters.storage.list(\*\*params) -> StorageListResponse - client.beta.clusters.storage.delete(volume_id) -> StorageDeleteResponse # Chat diff --git a/pyproject.toml b/pyproject.toml index 5495bee53..5bf286440 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "together" -version = "2.14.0" +version = "2.15.0" description = "The official Python library for the together API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/together/_version.py b/src/together/_version.py index 605d05197..6656537cd 100644 --- a/src/together/_version.py +++ b/src/together/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "together" -__version__ = "2.14.0" # x-release-please-version +__version__ = "2.15.0" # x-release-please-version diff --git a/src/together/lib/cli/__init__.py b/src/together/lib/cli/__init__.py index 445871f45..dbde778e7 100644 --- a/src/together/lib/cli/__init__.py +++ b/src/together/lib/cli/__init__.py @@ -61,9 +61,12 @@ FINE_TUNING_DOWNLOAD_HELP_EXAMPLES, BETA_CLUSTERS_STORAGE_HELP_EXAMPLES, FILES_RETRIEVE_CONTENT_HELP_EXAMPLES, + FINE_TUNING_LIST_METRICS_HELP_EXAMPLES, + BETA_CLUSTERS_REMEDIATIONS_HELP_EXAMPLES, BETA_CLUSTERS_STORAGE_CREATE_HELP_EXAMPLES, BETA_CLUSTERS_STORAGE_UPDATE_HELP_EXAMPLES, BETA_CLUSTERS_GET_CREDENTIALS_HELP_EXAMPLES, + BETA_CLUSTERS_REMEDIATIONS_CREATE_HELP_EXAMPLES, ) from together.lib.cli.utils._help_formatter import help_formatter from together.lib.cli.utils._preparse_tokens import preparse_tokens @@ -380,6 +383,11 @@ async def run_command() -> None: help_epilogue=FINE_TUNING_DOWNLOAD_HELP_EXAMPLES, ) fine_tuning_app.command((f"{_CLI}.fine_tuning.delete:delete"), alias="-d", help="Delete a fine-tuning job") +fine_tuning_app.command( + (f"{_CLI}.fine_tuning.list_metrics:list_metrics"), + help="Retrieve training metrics for a fine-tuning job", + help_epilogue=FINE_TUNING_LIST_METRICS_HELP_EXAMPLES, +) ## Models API commands models_app = app.command(App(name="models", help="List and upload models", help_epilogue=MODELS_HELP_EXAMPLES)) @@ -486,6 +494,44 @@ async def run_command() -> None: ) storage_app.command((f"{_CLI}.beta.clusters.storage.delete:delete"), help="Delete a storage volume", alias="-d") +### Clusters > Remediations API commands +remediations_app = clusters_app.command( + App( + name="remediations", + help="Manage node remediations", + group="Subcommands", + help_epilogue=BETA_CLUSTERS_REMEDIATIONS_HELP_EXAMPLES, + ) +) +remediations_app.command( + (f"{_CLI}.beta.clusters.remediations.create:create"), + alias="-c", + help="Create a node remediation", + help_epilogue=BETA_CLUSTERS_REMEDIATIONS_CREATE_HELP_EXAMPLES, +) +remediations_app.command( + (f"{_CLI}.beta.clusters.remediations.list:list"), + alias="ls", + help="List node remediations", +) +remediations_app.command( + (f"{_CLI}.beta.clusters.remediations.retrieve:retrieve"), + alias="get", + help="Get remediation details", +) +remediations_app.command( + (f"{_CLI}.beta.clusters.remediations.approve:approve"), + help="Approve a pending remediation", +) +remediations_app.command( + (f"{_CLI}.beta.clusters.remediations.cancel:cancel"), + help="Cancel a pending remediation", +) +remediations_app.command( + (f"{_CLI}.beta.clusters.remediations.reject:reject"), + help="Reject a pending remediation", +) + ### Jig commands jig_app = beta_app.command( App(name="jig", help="Build, deploy, and manage custom containers", help_epilogue=JIG_HELP_EXAMPLES) diff --git a/src/together/lib/cli/api/beta/clusters/remediations/_resolve_remediation.py b/src/together/lib/cli/api/beta/clusters/remediations/_resolve_remediation.py new file mode 100644 index 000000000..55a49e7ba --- /dev/null +++ b/src/together/lib/cli/api/beta/clusters/remediations/_resolve_remediation.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import sys + +from together import omit +from together._types import Omit +from together.lib.cli.utils.config import CLIConfigParameter +from together.lib.cli.utils._console import console +from together.types.beta.clusters.remediation import Remediation + + +async def resolve_remediation(config: CLIConfigParameter, remediation_id: str) -> Remediation: + clusters = await config.client.beta.clusters.list() + + for cluster in clusters.clusters: + page_token: str | Omit = omit + while True: + response = await config.client.beta.clusters.remediations.list( + "-", + cluster_id=cluster.cluster_id, + page_size=100, + page_token=page_token, + ) + for remediation in response.remediations: + if remediation.id == remediation_id: + return remediation + + if not response.has_next or not response.next_page_token: + break + page_token = response.next_page_token + + console.print(f"[red]Error:[/red] Remediation not found: {remediation_id}") + sys.exit(1) diff --git a/src/together/lib/cli/api/beta/clusters/remediations/approve.py b/src/together/lib/cli/api/beta/clusters/remediations/approve.py new file mode 100644 index 000000000..eb81a5346 --- /dev/null +++ b/src/together/lib/cli/api/beta/clusters/remediations/approve.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Optional, Annotated + +from cyclopts import Parameter + +from together import omit +from together._utils._json import openapi_dumps +from together.lib.cli.utils.config import CLIConfigParameter +from together.lib.cli.utils._console import console +from together.lib.cli.components.loader import show_loading_status +from together.lib.cli.api.beta.clusters.remediations._resolve_remediation import resolve_remediation + + +async def approve( + remediation_id: str, + comment: Annotated[Optional[str], Parameter(help="Comment explaining the approval")] = None, + *, + config: CLIConfigParameter, +) -> None: + """Approve a pending remediation.""" + remediation = await show_loading_status("Finding remediation...", resolve_remediation(config, remediation_id)) + response = await show_loading_status( + "Approving remediation...", + config.client.beta.clusters.remediations.approve( + remediation_id, + cluster_id=remediation.cluster_id, + instance_id=remediation.instance_id, + comment=comment or omit, + ), + ) + + if config.json: + console.print_json(openapi_dumps(response).decode("utf-8")) + return + + console.print(f"[blue]Remediation approved.[/blue] ({response.id})") diff --git a/src/together/lib/cli/api/beta/clusters/remediations/cancel.py b/src/together/lib/cli/api/beta/clusters/remediations/cancel.py new file mode 100644 index 000000000..99adc6bff --- /dev/null +++ b/src/together/lib/cli/api/beta/clusters/remediations/cancel.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from together._utils._json import openapi_dumps +from together.lib.cli.utils.config import CLIConfigParameter +from together.lib.cli.utils._console import console +from together.lib.cli.components.loader import show_loading_status +from together.lib.cli.api.beta.clusters.remediations._resolve_remediation import resolve_remediation + + +async def cancel( + remediation_id: str, + *, + config: CLIConfigParameter, +) -> None: + """Cancel a pending remediation.""" + remediation = await show_loading_status("Finding remediation...", resolve_remediation(config, remediation_id)) + response = await show_loading_status( + "Cancelling remediation...", + config.client.beta.clusters.remediations.cancel( + remediation_id, + cluster_id=remediation.cluster_id, + instance_id=remediation.instance_id, + ), + ) + + if config.json: + console.print_json(openapi_dumps(response).decode("utf-8")) + return + + console.print(f"[blue]Remediation cancelled.[/blue] ({response.id})") diff --git a/src/together/lib/cli/api/beta/clusters/remediations/create.py b/src/together/lib/cli/api/beta/clusters/remediations/create.py new file mode 100644 index 000000000..8a6f2db49 --- /dev/null +++ b/src/together/lib/cli/api/beta/clusters/remediations/create.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import Literal, Optional, Annotated, cast + +from cyclopts import Parameter + +from together import omit +from together._utils._json import openapi_dumps +from together.lib.cli.utils.config import CLIConfigParameter +from together.lib.cli.utils._console import console +from together.lib.cli.components.loader import show_loading_status + +RemediationModeParameter = Annotated[ + Literal[ + "VM_ONLY", + "HOST_AWARE", + "EVICT_WITHOUT_REPLACEMENT", + "REBOOT_VM", + ], + Parameter(help="The type of remediation to perform"), +] + + +async def create( + cluster_id: Annotated[str, Parameter(help="The ID of the cluster")], + instance_id: Annotated[str, Parameter(help="The ID of the node within the cluster to remediate")], + *, + mode: RemediationModeParameter, + remediation_id: Annotated[Optional[str], Parameter(help="Client-specified ID for idempotency")] = None, + reason: Annotated[Optional[str], Parameter(help="Reason for the remediation")] = None, + config: CLIConfigParameter, +) -> None: + """Create a node remediation for an instance.""" + safe_mode = cast( + Literal[ + "REMEDIATION_MODE_VM_ONLY", + "REMEDIATION_MODE_HOST_AWARE", + "REMEDIATION_MODE_EVICT_WITHOUT_REPLACEMENT", + "REMEDIATION_MODE_REBOOT_VM", + ], + f"REMEDIATION_MODE_{mode}", + ) + + response = await show_loading_status( + "Creating remediation...", + config.client.beta.clusters.remediations.create( + instance_id, + cluster_id=cluster_id, + mode=safe_mode, + remediation_id=remediation_id or omit, + reason=reason or omit, + ), + ) + + if config.json: + console.print_json(openapi_dumps(response).decode("utf-8")) + return + + console.print(f"[green]√ Remediation created[/green] [dim]({response.id})[/dim]") + console.print(f" Remediations may take some time to complete.\n") + console.print(f" To retrieve the status:") + console.print(f" [dim]-[/dim] [primary]tg beta clusters remediations {response.id}[/primary]") diff --git a/src/together/lib/cli/api/beta/clusters/remediations/list.py b/src/together/lib/cli/api/beta/clusters/remediations/list.py new file mode 100644 index 000000000..a25618b76 --- /dev/null +++ b/src/together/lib/cli/api/beta/clusters/remediations/list.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from typing import List, Literal, Optional, Annotated, cast + +from cyclopts import Parameter + +from together import omit +from together._utils._json import openapi_dumps +from together.lib.utils.tools import format_datetime +from together.lib.cli.utils.config import CLIConfigParameter +from together.lib.cli.utils._console import console +from together.lib.cli.components.list import ListTable +from together.lib.cli.components.loader import show_loading_status + +RemediationModeParameter = Annotated[ + Optional[ + list[ + Literal[ + "VM_ONLY", + "HOST_AWARE", + "EVICT_WITHOUT_REPLACEMENT", + "REBOOT_VM", + ] + ] + ], + Parameter(help="Filter by remediation mode. Can be used multiple times."), +] +RemediationStateParameter = Annotated[ + Optional[ + list[Literal["PENDING_APPROVAL", "PENDING", "RUNNING", "SUCCEEDED", "FAILED", "CANCELLED", "AUTO_RESOLVED"]] + ], + Parameter(help="Filter by remediation state. Can be used multiple times."), +] +RemediationTriggerParameter = Annotated[ + Optional[list[Literal["MANUAL", "AUTOMATED"]]], + Parameter(help="Filter by remediation trigger. Can be used multiple times."), +] + + +async def list( + cluster_id: str, + instance_id: Annotated[Optional[str], Parameter(help="Instance ID to list remediations for")] = None, + after: Annotated[Optional[str], Parameter(help="Pagination token from a previous request")] = None, + mode: RemediationModeParameter = None, + state: RemediationStateParameter = None, + trigger: RemediationTriggerParameter = None, + *, + config: CLIConfigParameter, +) -> None: + """List node remediations for a cluster or instance.""" + safe_modes = cast( + List[ + Literal[ + "REMEDIATION_MODE_VM_ONLY", + "REMEDIATION_MODE_HOST_AWARE", + "REMEDIATION_MODE_EVICT_WITHOUT_REPLACEMENT", + "REMEDIATION_MODE_REBOOT_VM", + ] + ], + [f"REMEDIATION_MODE_{value}" for value in mode] if mode else [], + ) + safe_triggers = cast( + List[Literal["REMEDIATION_TRIGGER_MANUAL", "REMEDIATION_TRIGGER_AUTOMATED"]], + [f"REMEDIATION_TRIGGER_{value}" for value in trigger] if trigger else [], + ) + response = await show_loading_status( + "Loading remediations...", + config.client.beta.clusters.remediations.list( + instance_id or "-", + cluster_id=cluster_id, + mode=safe_modes or omit, + page_token=after or omit, + state=state or omit, + trigger=safe_triggers or omit, + ), + ) + + if config.json: + console.print_json(openapi_dumps(response).decode("utf-8")) + return + + table = ListTable(title="Cluster Remediations", empty_message="No remediations found for this cluster.") + table.add_column("Created") + table.add_primary_column("Instance", ratio=3) + table.add_column("Mode") + table.add_column("State") + table.add_column("Trigger") + table.add_column("Remediation ID", ratio=3) + + for remediation in response.remediations: + table.add_row( + format_datetime(remediation.create_time) if remediation.create_time else "-", + _format_instance(remediation.instance_id, remediation.instance_name), + remediation.mode.replace("REMEDIATION_MODE_", ""), + _colorize(remediation.state), + remediation.trigger.replace("REMEDIATION_TRIGGER_", ""), + remediation.id, + ) + + console.print(table) + if response.has_next and response.next_page_token: + command = f"tg beta clusters remediations ls {cluster_id}" + if instance_id: + command += f" {instance_id}" + console.print("\n[blue dim]To display the next page, run:[/blue dim]") + console.print(f" [dim]-[/dim] [white]{command} --after {response.next_page_token}[/white]") + + +def _colorize(state: str) -> str: + state_colors = { + "PENDING_APPROVAL": "yellow", + "PENDING": "yellow", + "RUNNING": "yellow", + "SUCCEEDED": "green", + "FAILED": "red", + "CANCELLED": "dim", + "AUTO_RESOLVED": "green", + } + color = state_colors[state] if state in state_colors else "white" + return f"[{color}]{state}[/{color}]" + + +def _format_instance(instance_id: str, instance_name: str | None) -> str: + if not instance_name: + return instance_id + return f"{instance_name} ({instance_id})" diff --git a/src/together/lib/cli/api/beta/clusters/remediations/reject.py b/src/together/lib/cli/api/beta/clusters/remediations/reject.py new file mode 100644 index 000000000..6fab890c4 --- /dev/null +++ b/src/together/lib/cli/api/beta/clusters/remediations/reject.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Optional, Annotated + +from cyclopts import Parameter + +from together import omit +from together._utils._json import openapi_dumps +from together.lib.cli.utils.config import CLIConfigParameter +from together.lib.cli.utils._console import console +from together.lib.cli.components.loader import show_loading_status +from together.lib.cli.api.beta.clusters.remediations._resolve_remediation import resolve_remediation + + +async def reject( + remediation_id: str, + comment: Annotated[Optional[str], Parameter(help="Comment explaining the rejection")] = None, + *, + config: CLIConfigParameter, +) -> None: + """Reject a pending remediation.""" + remediation = await show_loading_status("Finding remediation...", resolve_remediation(config, remediation_id)) + response = await show_loading_status( + "Rejecting remediation...", + config.client.beta.clusters.remediations.reject( + remediation_id, + cluster_id=remediation.cluster_id, + instance_id=remediation.instance_id, + comment=comment or omit, + ), + ) + + if config.json: + console.print_json(openapi_dumps(response).decode("utf-8")) + return + + console.print(f"[blue]Remediation rejected.[/blue] ({response.id})") diff --git a/src/together/lib/cli/api/beta/clusters/remediations/retrieve.py b/src/together/lib/cli/api/beta/clusters/remediations/retrieve.py new file mode 100644 index 000000000..8143dc450 --- /dev/null +++ b/src/together/lib/cli/api/beta/clusters/remediations/retrieve.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from together._utils._json import openapi_dumps +from together.lib.cli.utils.config import CLIConfigParameter +from together.lib.cli.utils._console import console +from together.lib.cli.components.loader import show_loading_status +from together.lib.cli.components.model_dump import print_model_dump +from together.lib.cli.api.beta.clusters.remediations._resolve_remediation import resolve_remediation + + +async def retrieve( + remediation_id: str, + *, + config: CLIConfigParameter, +) -> None: + """Retrieve remediation details.""" + remediation = await show_loading_status("Finding remediation...", resolve_remediation(config, remediation_id)) + response = await show_loading_status( + "Retrieving remediation...", + config.client.beta.clusters.remediations.retrieve( + remediation_id, + cluster_id=remediation.cluster_id, + instance_id=remediation.instance_id, + ), + ) + + if config.json: + console.print_json(openapi_dumps(response).decode("utf-8")) + return + + print_model_dump(response, show_nulls=False, only_set_fields=True) diff --git a/src/together/lib/cli/api/beta/jig/jig.py b/src/together/lib/cli/api/beta/jig/jig.py index 08bfc40b7..002ccefdb 100644 --- a/src/together/lib/cli/api/beta/jig/jig.py +++ b/src/together/lib/cli/api/beta/jig/jig.py @@ -29,6 +29,7 @@ from cyclopts import Parameter from together import Together +from together._types import Omit, omit from together._exceptions import APIError, NotFoundError, AuthenticationError from together._utils._json import openapi_dumps from together.lib.cli.utils.config import CLIConfig, CLIConfigParameter @@ -363,9 +364,14 @@ def _generate_dockerfile(config: JigConfig) -> str: pip = "" if Path("pyproject.toml").exists(): - pip = """COPY pyproject.toml . + pip = "COPY pyproject.toml .\n" + sync_flags = "--inexact --no-dev --no-install-project --compile-bytecode" + if Path("uv.lock").exists(): + pip += "COPY uv.lock .\n" + sync_flags = f"--frozen {sync_flags}" + pip += f"""ENV UV_PROJECT_ENVIRONMENT=/usr/local RUN --mount=type=cache,target=/root/.cache/uv \\ - uv pip install --system --compile-bytecode . && \\ + uv sync {sync_flags} && \\ (python -c "import sprocket" 2>/dev/null || (echo "sprocket not found in pyproject.toml, installing from pypi.together.ai..." && uv pip install --system --extra-index-url https://pypi.together.ai/ sprocket)) """ @@ -615,6 +621,11 @@ def delete_secret(self, name: str) -> None: # == Build / Push / Deploy / Track == def build(self, tag: str = "latest", warmup: bool = False, docker_args: str | None = None) -> None: + if self.config.deploy.image: + raise JigError( + f"Invalid command: deploy.image is set to '{self.config.deploy.image}'. " + "Use 'jig deploy' to deploy the configured image, or remove deploy.image to build from source." + ) image = self.image(tag) if not _dockerfile(self.config): @@ -909,15 +920,36 @@ def once(msg: str, detail: str | None = None) -> None: # == Query == - def logs(self, rid: str | None = None) -> str: - if not rid: - return "\n".join(self.api.retrieve_logs(self.name).lines or []) or "No logs available" - body = "\n".join(self.api.retrieve_logs(self.name, replica_id=rid).lines or []) - return f"\n--- Logs for {rid} ---\n{body or 'No logs available'}\n--- End of logs ---\n" + def logs( + self, + replica_id: str | None = None, + revision: str | None = None, + version: str | None = None, + ) -> str: + response = self.api.retrieve_logs( + self.name, + replica_id=replica_id or omit, + revision=revision or omit, + version=version or omit, + ) + body = "\n".join(response.lines or []) or "No logs available" + if replica_id: + return f"\n--- Logs for {replica_id} ---\n{body}\n--- End of logs ---\n" + return body - def follow_logs(self) -> None: + def follow_logs( + self, + replica_id: str | None = None, + revision: str | None = None, + version: str | None = None, + ) -> None: try: - with self.api.with_streaming_response.retrieve_logs(self.name) as stream: + with self.api.with_streaming_response.retrieve_logs( + self.name, + replica_id=replica_id or omit, + revision=revision or omit, + version=version or omit, + ) as stream: for line in stream.iter_lines(): if line: log_lines = json.loads(line).get("lines", []) @@ -1181,9 +1213,15 @@ def endpoint(jig: Jig) -> str: return f"https://api.together.ai/v1/deployment-request/{jig.name}" -def logs(jig: Jig, follow: bool) -> str | None: +def logs( + jig: Jig, + follow: bool, + replica_id: str | None, + revision: str | None, + version: str | None, +) -> str | None: """Get deployment logs""" - return jig.follow_logs() if follow else jig.logs() + return jig.follow_logs(replica_id, revision, version) if follow else jig.logs(replica_id, revision, version) def destroy(jig: Jig) -> str: @@ -1320,10 +1358,14 @@ def volumes_delete(jig: Jig, name: str) -> None: console.print(f"\N{CHECK MARK} Deleted volume {name}") -def volumes_describe(jig: Jig, name: str) -> Any: +def _optional_int(value: int | None) -> int | Omit: + return value if value is not None else omit + + +def volumes_describe(jig: Jig, name: str, version: int | None = None) -> Any: """Describe a volume""" try: - return jig.api.volumes.with_raw_response.retrieve(name) + return jig.api.volumes.with_raw_response.retrieve(name, version=_optional_int(version)) except NotFoundError: raise JigError(f"Volume {name} not found") from None @@ -1361,12 +1403,16 @@ async def jig_volumes_list( async def jig_volumes_describe( name: Annotated[str, Parameter(name="--name", help="Volume name")], + volume_version: Annotated[ + Optional[int], + Parameter(name="--volume-version", help="Volume version to describe"), + ] = None, *, config: CLIConfigParameter, ) -> None: """Describe a volume.""" try: - vol = await config.client.beta.jig.volumes.retrieve(name) + vol = await config.client.beta.jig.volumes.retrieve(name, version=_optional_int(volume_version)) except NotFoundError: _jig_fail(f"Volume {name} not found") else: @@ -1456,12 +1502,24 @@ def endpoint_cli( def logs_cli( follow: Annotated[bool, Parameter(help="Follow log output", negative=())] = False, + replica_id: Annotated[Optional[str], Parameter(name="--replica-id", help="Replica ID to filter logs")] = None, + revision: Annotated[ + Optional[str], + Parameter(name="--revision", help="Deployment revision UUID to filter logs"), + ] = None, + image_version: Annotated[ + Optional[str], + Parameter( + name="--image-version", + help="Deployment image version (tag or last 4 characters of image digest) to filter logs", + ), + ] = None, *, config: CLIConfigParameter, toml_config: TomlConfigParameter = None, ) -> None: """Get deployment logs.""" - _run_jig_cmd(config, toml_config, lambda jig: logs(jig, follow)) + _run_jig_cmd(config, toml_config, lambda jig: logs(jig, follow, replica_id, revision, image_version)) def destroy_cli( diff --git a/src/together/lib/cli/api/evals/create.py b/src/together/lib/cli/api/evals/create.py index 400397744..345794160 100644 --- a/src/together/lib/cli/api/evals/create.py +++ b/src/together/lib/cli/api/evals/create.py @@ -77,6 +77,13 @@ async def create( pass_threshold: Annotated[ Optional[float], Parameter(help="Threshold for passing (required for score type)") ] = None, + disable_position_bias_correction: Annotated[ + bool, + Parameter( + negative=(), + help="For compare evals, run only the original-order judge pass without position-bias correction", + ), + ] = False, model_a_field: Annotated[ Optional[str], Parameter( @@ -274,6 +281,7 @@ async def create( parameters=ParametersEvaluationCompareParameters( input_data_file_path=training_file, judge=judge_config, + disable_position_bias_correction=disable_position_bias_correction, model_a=cast(ParametersEvaluationCompareParametersModelAEvaluationModelRequest, model_a_final), model_b=cast(ParametersEvaluationCompareParametersModelBEvaluationModelRequest, model_b_final), ), diff --git a/src/together/lib/cli/api/fine_tuning/list_metrics.py b/src/together/lib/cli/api/fine_tuning/list_metrics.py new file mode 100644 index 000000000..c30728564 --- /dev/null +++ b/src/together/lib/cli/api/fine_tuning/list_metrics.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import Optional, Annotated +from datetime import datetime + +from cyclopts import Parameter + +from together import omit +from together._utils._json import openapi_dumps +from together.lib.cli.utils.config import CLIConfigParameter +from together.lib.cli.utils._console import console +from together.lib.cli.components.loader import show_loading_status +from together.lib.cli.components.plot_finetune_metrics import METRICS_WIDTH_PADDING, metrics_ascii_charts + + +async def list_metrics( + fine_tune_id: Annotated[str, Parameter(help="The ID of the fine-tuning job")], + *, + config: CLIConfigParameter, + global_step_from: Annotated[ + Optional[int], Parameter(help="Filter metrics from this global step (inclusive).") + ] = None, + global_step_to: Annotated[Optional[int], Parameter(help="Filter metrics to this global step (inclusive).")] = None, + logged_at_from: Annotated[ + Optional[datetime], Parameter(help="Filter metrics logged at or after this time.") + ] = None, + logged_at_to: Annotated[Optional[datetime], Parameter(help="Filter metrics logged at or before this time.")] = None, + resolution: Annotated[ + Optional[int], + Parameter( + help="Number of uniformly sampled training metric points to return. Does not limit the number of eval metric points." + ), + ] = None, +) -> None: + """Retrieve training metrics for a fine-tuning job.""" + response = await show_loading_status( + "Fetching metrics...", + config.client.fine_tuning.list_metrics( + fine_tune_id, + global_step_from=global_step_from if global_step_from is not None else omit, + global_step_to=global_step_to if global_step_to is not None else omit, + logged_at_from=logged_at_from if logged_at_from is not None else omit, + logged_at_to=logged_at_to if logged_at_to is not None else omit, + resolution=resolution if resolution is not None else omit, + ), + ) + + metrics = response.metrics or [] + + if config.json: + json_bytes = openapi_dumps(metrics) + console.print_json(json_bytes.decode("utf-8")) + return + + if len(metrics) == 0: + console.print(f"[muted]No metrics found for job {fine_tune_id}[/muted]") + return + + console.print(metrics_ascii_charts(metrics, width=console.width - METRICS_WIDTH_PADDING)) diff --git a/src/together/lib/cli/api/fine_tuning/retrieve.py b/src/together/lib/cli/api/fine_tuning/retrieve.py index 64bc68378..0b77997e0 100644 --- a/src/together/lib/cli/api/fine_tuning/retrieve.py +++ b/src/together/lib/cli/api/fine_tuning/retrieve.py @@ -1,7 +1,10 @@ from __future__ import annotations +from typing import Annotated from datetime import datetime +from cyclopts import Parameter + from together._utils._json import openapi_dumps from together.lib.cli.api._utils import generate_progress_bar from together.lib.cli.utils.config import CLIConfigParameter @@ -9,6 +12,7 @@ from together.lib.cli.utils._console import console from together.lib.cli.components.loader import show_loading_status from together.lib.cli.components.model_dump import print_model_dump +from together.lib.cli.components.plot_finetune_metrics import METRICS_WIDTH_PADDING, metrics_block_sparklines _NEST_INDENT = 4 @@ -17,6 +21,7 @@ async def retrieve( fine_tune_id: str, *, config: CLIConfigParameter, + no_plots: Annotated[bool, Parameter(help="Print training metric sparklines.", negative=())] = False, ) -> None: """Retrieve fine-tuning job details.""" response = await show_loading_status( @@ -35,6 +40,18 @@ async def retrieve( console.print(progress_text) print_model_dump(response, show_nulls=False) + + if not no_plots: + metrics_response = await show_loading_status( + "Fetching metrics...", + config.client.fine_tuning.list_metrics(fine_tune_id, resolution=console.width - METRICS_WIDTH_PADDING), + ) + metrics = metrics_response.metrics or [] + + if metrics: + console.print("\n[muted]Training metrics:[/muted]") + console.print(metrics_block_sparklines(metrics, width=console.width - METRICS_WIDTH_PADDING)) + if event_count > 0: console.print("\n[dim]FT Events:[/dim]") console.print(f" [dim]Total events:[/dim] {event_count}") diff --git a/src/together/lib/cli/components/model_dump.py b/src/together/lib/cli/components/model_dump.py index fe2e2d196..4bef66807 100644 --- a/src/together/lib/cli/components/model_dump.py +++ b/src/together/lib/cli/components/model_dump.py @@ -12,9 +12,25 @@ def print_model_dump( - model: BaseModel, show_nulls: bool = True, expand: bool = True, padding: PaddingDimensions = (0, 1, 0, 0) + model: BaseModel, + show_nulls: bool = True, + expand: bool = True, + padding: PaddingDimensions = (0, 1, 0, 0), + *, + only_set_fields: bool = False, ) -> None: - """Print an entire model with __decent__ formatting.""" + """Print an entire model with __decent__ formatting. + + Args: + model: The response model to render. + show_nulls: When True, include fields whose value is None or empty, displayed as + "n/a". When False, omit those fields entirely. + expand: Passed to the Rich table; when True, the table stretches to the terminal width. + padding: Rich table cell padding as (top, right, bottom, left). + only_set_fields: When True, only include fields present in the API response + (model.model_fields_set). Use this to avoid showing optional fields that were never + sent and still carry a default value. When False, all model fields are shown. + """ def _pretty_print_results( results: Any, show_nulls: bool = True, expand: bool = False, padding: PaddingDimensions = (0, 1, 0, 0) @@ -37,6 +53,8 @@ def _pretty_print_results( table.add_row("-", _pretty_print_results(item)) elif isinstance(results, BaseModel): table.add_row("", _pretty_print_results(results.model_dump(), show_nulls=show_nulls)) + elif isinstance(results, datetime): + table.add_row("", _colorize_value(format_datetime(results))) else: table.add_row("", _colorize_value(results)) return table @@ -70,18 +88,25 @@ def _dump_sorted_model(model: BaseModel) -> dict[str, Any]: - Lists last """ + model_dump = model.model_dump() + + if only_set_fields: + model_dump = {k: v for k, v in model_dump.items() if k in model.model_fields_set} + def _sort_items(key: str, value: Any) -> int: - # Returns a sort key: 0 for ID fields, 1 for primitives, 2 for dicts/objects, 3 for lists + # Returns a sort key: 0 for ID fields, 1 for dates, 2 for primitives, 3 for dicts/objects, 4 for lists if key.endswith("_id"): return 0 + elif isinstance(value, datetime): + return 1 elif isinstance(value, dict) or isinstance(value, BaseModel): - return 2 - elif isinstance(value, list): return 3 + elif isinstance(value, list): + return 4 else: - return 1 + return 2 - return dict(sorted(model.model_dump().items(), key=lambda kv: _sort_items(kv[0], kv[1]))) + return dict(sorted(model_dump.items(), key=lambda kv: _sort_items(kv[0], kv[1]))) console.print( _pretty_print_results(_dump_sorted_model(model), show_nulls=show_nulls, expand=expand, padding=padding) diff --git a/src/together/lib/cli/components/plot_finetune_metrics.py b/src/together/lib/cli/components/plot_finetune_metrics.py new file mode 100644 index 000000000..8a926bc84 --- /dev/null +++ b/src/together/lib/cli/components/plot_finetune_metrics.py @@ -0,0 +1,144 @@ +"""Fine-tuning metrics plotting utilities. + +Public API +---------- +``metrics_block_sparklines(metrics)`` + One ▁▂▃▄▅▆▇█ sparkline line per metric — used in ``retrieve``. + +``metrics_ascii_charts(metrics, height=6)`` + One full ASCII line chart per metric — used in ``list-metrics``. +""" + +from __future__ import annotations + +import math +from typing import Any +from collections import defaultdict + +from rich.text import Text + +from together.lib.cli.components.plots import should_log, render_line_chart, render_sparklines + +# Columns reserved for the y-axis label area, ┼ connector, leading indent, and +# surrounding margin in the ASCII chart layout. This must be >= label_width + 1 +# (the default label_width used in metrics_ascii_charts is 8, so the minimum is +# 9). Callers subtract this from the terminal width to get the usable plot width. +METRICS_WIDTH_PADDING = 48 + +_SKIP_KEYS: frozenset[str] = frozenset({"timestamp", "step", "global_step", "epoch"}) + + +def _is_skip(k: str) -> bool: + base = k.rsplit("/", 1)[-1] + return base in _SKIP_KEYS or base.endswith("_step") or base.endswith("_epoch") + + +def _step_label(x: float) -> str: + return str(int(x)) + + +def _collect_series( + metrics: list[dict[str, Any]], +) -> dict[str, tuple[list[float], list[float]]]: + """Collect plottable numeric series from a list of metric dicts. + + Returns a mapping of name → (xs, ys). Keys are discovered in insertion + order; step/epoch/timestamp fields are skipped. NaN values are converted + to ``-inf`` so the rendering engine plots them at the very bottom of the + chart rather than silently dropping them. + """ + series: dict[str, tuple[list[float], list[float]]] = defaultdict(lambda: ([], [])) + for row in metrics: + step = float(row["train/global_step"]) + for k, v in row.items(): + if _is_skip(k) or isinstance(v, bool) or not isinstance(v, (int, float)): + continue + val = float(v) + # NaN is rendered as a dip to the bottom (-inf sentinel). + if math.isnan(val): + val = float("-inf") + series[k][0].append(step) + series[k][1].append(val) + return series + + +def _no_data() -> Text: + t = Text() + t.append("No plottable metrics found.", style="muted") + return t + + +def metrics_block_sparklines( + metrics: list[dict[str, Any]], + *, + width: int = 60, +) -> Text: + """One block-sparkline line per metric, coloured with the CLI theme. + + Args: + metrics: List of flat metric dicts (one per training step). + width: Sparkline character width (default 60). + + Returns: + A ``rich.text.Text`` ready for ``console.print()``. + """ + series = _collect_series(metrics) + if not series: + return _no_data() + label_w = max(len(k) for k in series) + text = Text() + for key, (xs, ys) in series.items(): + text.append_text( + render_sparklines( + key, + xs, + ys, + width=width, + y_log=should_log(ys), + label_width=label_w, + ) + ) + return text + + +def metrics_ascii_charts( + metrics: list[dict[str, Any]], + *, + height: int = 6, + width: int = 60, + label_width: int = 8, +) -> Text: + """One ASCII line chart per metric, with a global-step x-axis. + + Args: + metrics: List of flat metric dicts (one per training step). + height: Chart body height in rows (default 6). + width: Plot character width (default 60). + + Returns: + A ``rich.text.Text`` ready for ``console.print()``. + """ + series = _collect_series(metrics) + text = Text() + for key, (xs, ys) in series.items(): + if text: + text.append("\n") + text.append_text( + render_line_chart( + xs, + {key: ys}, + x_label=_step_label, + y_log=should_log(ys), + height=height, + width=width, + label_width=label_width, + ) + ) + return text if text else _no_data() + + +__all__ = [ + "metrics_block_sparklines", + "metrics_ascii_charts", + "METRICS_WIDTH_PADDING", +] diff --git a/src/together/lib/cli/components/plots/__init__.py b/src/together/lib/cli/components/plots/__init__.py new file mode 100644 index 000000000..c2d08bf68 --- /dev/null +++ b/src/together/lib/cli/components/plots/__init__.py @@ -0,0 +1,9 @@ +"""Generic CLI plot utilities.""" + +from together.lib.cli.components.plots._engine import should_log, render_line_chart, render_sparklines + +__all__ = [ + "render_line_chart", + "render_sparklines", + "should_log", +] diff --git a/src/together/lib/cli/components/plots/_engine.py b/src/together/lib/cli/components/plots/_engine.py new file mode 100644 index 000000000..62c2f6738 --- /dev/null +++ b/src/together/lib/cli/components/plots/_engine.py @@ -0,0 +1,591 @@ +"""ASCII sparkline and chart engine for time-series data. + +Designed for scalar time-series (loss, accuracy, …); not a general-purpose +plotting library. + +Internal pipeline (``_plot``, ``_interpolate``, …) uses a shared x-grid with +named y series: ``xs: list[float]`` + ``ys: dict[str, list[float]]``. + +Public API +---------- +``render_line_chart(xs, ys, ...)`` + One or more named series plotted on a shared ASCII line chart. All series + share the same x-axis and y-scale. + +``render_sparklines(name, xs, ys, ...)`` + A single block-sparkline row (▁▂▃▄▅▆▇█). Call once per series and pass a + shared ``label_width`` across calls for consistent label alignment. Names + are right-justified; those that exceed ``label_width`` are truncated with + ``...``. +""" + +from __future__ import annotations + +import math +import bisect +from typing import Callable + +from rich.text import Text + +_SPARK_BLOCKS = " ▁▂▃▄▅▆▇█" + +# Styles cycled across series in insertion order. +_SERIES_STYLES = ["white", "green", "yellow", "cyan", "magenta"] + +# UI style tokens used throughout the rendering pipeline. +_STYLE_PRIMARY = "primary" # default plot body text +_STYLE_SECONDARY = "secondary" # axis labels and tick text +_STYLE_ACCENT = "accent" # axis border characters (┼ └ ┬ …) +_STYLE_MUTED = "muted" # series name labels and empty-state messages +_STYLE_SPARK = "white" # sparkline bar characters + +# Sentinels used in quantized_ys to signal out-of-range non-finite values. +# Both are outside the valid slot range [0, height-1]. +_NEG_INF_SENTINEL = -1 # -inf: line descends to the x-axis border +_POS_INF_SENTINEL = -2 # +inf: line ascends to the top data row +_NAN_SENTINEL = -3 # NaN: no line at the place + + +def should_log(vals: list[float]) -> bool: + """Return True when values span more than 100×, suggesting log scale.""" + positive_val = [v for v in vals if v > 0] + return len(positive_val) > 1 and (max(positive_val) / min(positive_val)) > 100 + + +def _uniform_grid(vals: list[float], n: int) -> list[float]: + """Return n evenly-spaced points spanning [min(vals), max(vals)]. + + Non-finite values (e.g. the -inf sentinel used for NaN data points) are + excluded from the range computation so they don't corrupt the grid. + """ + finite_val = [v for v in vals if math.isfinite(v)] + min_val, max_val = min(finite_val), max(finite_val) + if n <= 1: + return [min_val] + return [min_val + (max_val - min_val) * idx / (n - 1) for idx in range(n)] + + +def _interpolate( + xs: list[float], + ys: dict[str, list[float]], + x_grid: list[float], +) -> dict[str, list[float]]: + """Linearly interpolate each named y series onto x_grid; clamp at the edges. + + For each grid point: + - If it falls before the first data point, use the first y value. + - If it falls after the last data point, use the last y value. + - Otherwise, linearly interpolate between the two bracketing data points. + """ + results: dict[str, list[float]] = {} + for name, yvals in ys.items(): + # Sort by x, using insertion order as a tiebreaker so that duplicate + # steps are resolved deterministically (first occurrence wins). + pairs = sorted(enumerate(zip(xs, yvals)), key=lambda t: (t[1][0], t[0])) + xs_s = [x for _, (x, _y) in pairs] + ys_s = [y for _, (_x, y) in pairs] + + interpolated: list[float] = [] + for x_point in x_grid: + pos = bisect.bisect_left(xs_s, x_point) + if pos == 0: + interpolated.append(ys_s[0]) + elif pos == len(xs_s): + interpolated.append(ys_s[-1]) + elif xs_s[pos] == x_point: + interpolated.append(ys_s[pos]) + else: + left_x, left_y = xs_s[pos - 1], ys_s[pos - 1] + right_x, right_y = xs_s[pos], ys_s[pos] + # When either bracket endpoint is a non-finite sentinel + # (-inf/NaN or +inf) we cannot compute a meaningful slope. + # Instead, assign this grid point to whichever bracket is + # closer: if that bracket is non-finite the spike/dip extends + # to this column; if it is finite we use its value so the + # spike/dip stays as narrow as the grid resolution allows. + if not math.isfinite(left_y) or not math.isfinite(right_y): + closer_y = left_y if (x_point - left_x) <= (right_x - x_point) else right_y + interpolated.append(closer_y) + else: + slope = (right_y - left_y) / (right_x - left_x) + interpolated.append(left_y + slope * (x_point - left_x)) + + results[name] = interpolated + return results + + +def _log_transform( + named_values: dict[str, list[float]], +) -> dict[str, list[float]]: + """Return new traces with ys replaced by their log10 values.""" + result: dict[str, list[float]] = {} + for name, values in named_values.items(): + nz = [value for value in values if value > 0] + eps = min(nz) * 0.01 if nz else 1e-10 + result[name] = [math.log10(max(value, eps)) for value in values] + return result + + +def _quantize_ys( + interpolated_ys: dict[str, list[float]], + y_grid: list[float], +) -> list[list[int]]: + """Snap each interpolated y value to the index of the nearest y_grid slot. + + Non-finite values are mapped to out-of-band sentinels: + + * ``_NEG_INF_SENTINEL`` (``-1``) for ``-inf`` — the line descends to the + x-axis border row. + * ``_POS_INF_SENTINEL`` (``-2``) for ``+inf`` — the line spikes to the top + data row. + * ``_NAN_SENTINEL`` (``-3``) for ``NaN`` — no line is drawn at that point. + """ + quantized_ys: list[list[int]] = [] + for ys in interpolated_ys.values(): + row: list[int] = [] + for y in ys: + if math.isfinite(y): + row.append(min(range(len(y_grid)), key=lambda i: abs(y_grid[i] - y))) + elif y > 0: # +inf + row.append(_POS_INF_SENTINEL) + elif math.isinf(y): + row.append(_NEG_INF_SENTINEL) + else: # -inf or NaN (NaN > 0 is False) + row.append(_NAN_SENTINEL) + quantized_ys.append(row) + return quantized_ys + + +def _fit_spark_label(name: str, label_width: int) -> str: + """Right-justify *name* in *label_width* chars, truncating with '...' if needed.""" + if len(name) <= label_width: + return name.rjust(label_width) + return name[: max(0, label_width - 3)] + "..." + + +def _y_labels( + y_grid: list[float], + y_log: bool, + y_label: Callable[[float], str], +) -> list[str]: + """Build y-axis tick label strings from the y grid.""" + labels = [y_label(10**y) if y_log else y_label(y) for y in y_grid[::-1]] + return labels + + +def _x_labels( + x_grid: list[float], + n_xticks: int, + x_label: Callable[[float], str], +) -> list[tuple[int, str]]: + """Return (column_index, label_string) pairs for each x-axis tick.""" + width = len(x_grid) + x_min = x_grid[0] + # Extend by one grid step beyond the last point so the rightmost tick + # label shows the true data maximum. round() suppresses floating-point + # noise that would otherwise accumulate in the tick value calculations. + x_max = round(x_grid[-1] + ((x_grid[-1] - x_grid[0]) / (width - 1) if width > 1 else 0.0), 10) + if n_xticks < 2 or width <= 1: + return [(0, x_label(x_min))] + tick_cols = [round(i * (width - 1) / (n_xticks - 1)) for i in range(n_xticks)] + tick_vals = [x_min + (x_max - x_min) * i / (n_xticks - 1) for i in range(n_xticks)] + return [(col, x_label(val)) for col, val in zip(tick_cols, tick_vals)] + + +def _draw_y_axis( + grid: list[list[str]], + style_grid: list[list[str]], + labels: list[str], + label_w: int, +) -> None: + """Fill y-axis labels and ┼ connectors into the grid.""" + for label, grid_row, style_row in zip(labels, grid, style_grid): + if len(label) > label_w: + label = label[: max(0, label_w - 3)] + "..." + label = label.rjust(label_w) + for ci, ch in enumerate(label): + grid_row[ci] = ch + style_row[ci] = _STYLE_SECONDARY + grid_row[label_w] = "┼" + style_row[label_w] = _STYLE_ACCENT + + +def _draw_lines( + grid: list[list[str]], + style_grid: list[list[str]], + quantized_ys: list[list[int]], + styles: list[str], + label_w: int, +) -> frozenset[int]: + """Draw all series into the shared grid (last writer wins on collision). + + Coordinate system: y_grid index 0 is the *bottom* of the data range, but + grid row 0 is the *top* of the terminal output. The conversion is: + screen_row = len(grid) - y_grid_index - 1 + So a higher y_grid index means a higher data value and a *lower* screen row. + + Out-of-band sentinels (``_NEG_INF_SENTINEL``, ``_POS_INF_SENTINEL``) signal + non-finite source values: + + * ``_NEG_INF_SENTINEL`` (-inf / NaN): line descends to the x-axis border. + The set of affected plot-body column indices is returned so + ``_draw_x_axis`` can mark them with ``┴``. + * ``_POS_INF_SENTINEL`` (+inf): line spikes to the top data row (row 0). + """ + height = len(grid) + border_cols: set[int] = set() + offset = label_w + 1 + width = len(grid[0]) + for style, pv in zip(styles, quantized_ys): + # We look one column ahead (pv[col+1]), so stop one short of the end. + for col_idx in range(width - label_w - 2): + cur = pv[col_idx] + nxt = pv[col_idx + 1] + col = col_idx + offset + + cur_is_neg_inf = cur == _NEG_INF_SENTINEL + nxt_is_neg_inf = nxt == _NEG_INF_SENTINEL + cur_is_pos_inf = cur == _POS_INF_SENTINEL + nxt_is_pos_inf = nxt == _POS_INF_SENTINEL + cur_is_nan = cur == _NAN_SENTINEL + nxt_is_nan = nxt == _NAN_SENTINEL + + # Two consecutive non-finite points of the same kind: nothing to draw. + if ( + (cur_is_neg_inf and nxt_is_neg_inf) + or (cur_is_pos_inf and nxt_is_pos_inf) + or (cur_is_nan and nxt_is_nan) + ): + continue + + screen_row = height - cur - 1 + next_screen_row = height - nxt - 1 + + # Recovering from border: │ up from bottom data row to nxt. + if cur_is_neg_inf: + border_cols.add(col_idx) + grid[next_screen_row][col] = "╭" + style_grid[next_screen_row][col] = style + for mid_row in range(next_screen_row + 1, height): + grid[mid_row][col] = "│" + style_grid[mid_row][col] = style + continue + + # Descending to border: │ down from cur to bottom data row. + if nxt_is_neg_inf: + border_cols.add(col_idx) + grid[screen_row][col] = "╮" + style_grid[screen_row][col] = style + for mid_row in range(screen_row + 1, height): + grid[mid_row][col] = "│" + style_grid[mid_row][col] = style + continue + + # Descending from top: │ down from row 0 to nxt. + if cur_is_pos_inf: + grid[0][col] = "│" + style_grid[0][col] = style + for mid_row in range(1, next_screen_row): + grid[mid_row][col] = "│" + style_grid[mid_row][col] = style + grid[next_screen_row][col] = "╰" + style_grid[next_screen_row][col] = style + continue + + # Ascending to top: │ up from cur to row 0. + if nxt_is_pos_inf: + grid[screen_row][col] = "╯" + style_grid[screen_row][col] = style + for mid_row in range(1, screen_row): + grid[mid_row][col] = "│" + style_grid[mid_row][col] = style + grid[0][col] = "│" + style_grid[0][col] = style + continue + + # Continue previous line if the next one is NaN + if not cur_is_nan and nxt_is_nan: + grid[screen_row][col] = "─" + continue + + # Start a new line if the current one is nan, but the previous one is not + if cur_is_nan and not nxt_is_nan: + grid[next_screen_row][col] = "─" + continue + + # If everything is finite and good, compare the values and add horizontal line or increasing/decreasing line + if screen_row == next_screen_row: + grid[screen_row][col] = "─" + style_grid[screen_row][col] = style + continue + + going_down = cur > nxt # value decreases → line goes down on screen + grid[screen_row][col] = "╮" if going_down else "╯" + style_grid[screen_row][col] = style + grid[next_screen_row][col] = "╰" if going_down else "╭" + style_grid[next_screen_row][col] = style + for mid_row in range(min(screen_row, next_screen_row) + 1, max(screen_row, next_screen_row)): + grid[mid_row][col] = "│" + style_grid[mid_row][col] = style + + return frozenset(border_cols) + + +def _draw_x_axis( + grid: list[list[str]], + style_grid: list[list[str]], + label_w: int, + x_labels: list[tuple[int, str]], + nan_cols: frozenset[int] = frozenset(), +) -> None: + """Append the └───┬─── border row and tick label row to the grid. + + ``nan_cols`` is a set of plot-body column indices (0-based within the plot + body, i.e. not including the y-axis label area) where a NaN line descends + to the border. Those positions get ``┴`` instead of ``─``, or ``┼`` when + they coincide with an x-tick ``┬``. + """ + row_len = len(grid[0]) + width = row_len - label_w - 1 + + # Border row: spaces | └ | ─ … ┬ … ─ + tick_cols = {col for col, _ in x_labels} + border_chars = list("─" * width) + for col in tick_cols: + border_chars[col] = "┬" + + # Adding hitting lines to -inf to the border + for col in nan_cols: + if 0 <= col < width: + border_chars[col] = "┼" if col in tick_cols else "┴" + border_row = [" "] * label_w + ["└"] + border_chars + border_styles = [_STYLE_SECONDARY] * label_w + [_STYLE_ACCENT] + [_STYLE_ACCENT] * width + grid.append(border_row) + style_grid.append(border_styles) + + # Label row: tick strings centred under their tick column + label_row = [" "] * row_len + for col, lbl in x_labels: + start = label_w + 1 + col - len(lbl) // 2 + start = max(0, min(start, row_len - len(lbl))) + for i, ch in enumerate(lbl): + label_row[start + i] = ch + grid.append(label_row) + style_grid.append([_STYLE_SECONDARY] * row_len) + + +def _render_data_row( + row: list[str], + style_row: list[str], +) -> Text: + """Colorize one grid row, appending each character with its style.""" + text = Text() + for ch, style in zip(row, style_row): + text.append(ch, style=style) + text.append("\n") + return text + + +def _render_body( + grid: list[list[str]], + style_grid: list[list[str]], +) -> Text: + """Convert the finished grid into a Rich Text object.""" + text = Text() + for row, style_row in zip(grid, style_grid): + text.append_text(_render_data_row(row, style_row)) + return text + + +def _plot( + xs: list[float], + ys: dict[str, list[float]], + *, + width: int = 60, + height: int = 6, + x_label: Callable[[float], str] = str, + y_label: Callable[[float], str] = str, + y_log: bool = False, + n_xticks: int = 3, + label_width: int = 8, +) -> Text: + """Render one or more named y series against a shared x-axis as an ASCII chart. + + Args: + xs: Shared x values for all series. + ys: Mapping of name → y values (must be same length as xs). + width: Number of character columns in the plot body. + height: Number of character rows in the chart body. + x_label: Callable that formats an x value into a tick-label string. + y_label: Callable that formats a y value into a tick-label string. + y_log: When True, values are plotted on a log10 axis. + n_xticks: Number of tick marks and labels on the x-axis (default 3). + label_width: Cap on the y-axis label column width (default 8). + Labels longer than this are truncated with ``...``. + + Returns: + A ``rich.text.Text`` ready for ``console.print()``. + """ + if not ys: + t = Text() + t.append("No data.", style=_STYLE_MUTED) + return t + + ordered_styles = [_SERIES_STYLES[i % len(_SERIES_STYLES)] for i in range(len(ys))] + + x_grid = _uniform_grid(xs, width) + interpolated_ys = _interpolate(xs, ys, x_grid) + if y_log: + interpolated_ys = _log_transform(interpolated_ys) + flat_ys = [v for ys_list in interpolated_ys.values() for v in ys_list] + y_grid = _uniform_grid(flat_ys, height) + + quantized_ys = _quantize_ys(interpolated_ys, y_grid) + y_labels = _y_labels(y_grid, y_log, y_label) + x_labels = _x_labels(x_grid, n_xticks, x_label) + + grid: list[list[str]] = [[" "] * (width + label_width + 1) for _ in range(height)] + style_grid: list[list[str]] = [[_STYLE_PRIMARY] * (width + label_width + 1) for _ in range(height)] + + _draw_y_axis(grid, style_grid, y_labels, label_width) + nan_cols = _draw_lines(grid, style_grid, quantized_ys, ordered_styles, label_width) + _draw_x_axis(grid, style_grid, label_width, x_labels, nan_cols) + + text = _render_body(grid, style_grid) + return text + + +def render_sparklines( + name: str, + xs: list[float], + ys: list[float], + *, + width: int = 60, + y_log: bool = False, + label_width: int = 8, +) -> Text: + """Render a single block-sparkline row for one series. + + Call once per series, passing a shared ``label_width`` across all calls to + keep label columns aligned. The name is right-justified within the column; + names longer than ``label_width`` are truncated with ``...``. + + Args: + name: Series name, used as the row label. + xs: X values (e.g. training steps). + ys: Y values. + width: Sparkline character width (default 60). + y_log: When True, plot on a log10 scale (default False). + label_width: Exact label column width (default 8). Pass the same + value to every call in a group to get consistent + alignment. + + Returns: + A ``rich.text.Text`` ready for ``console.print()``. + """ + if not xs: + t = Text() + t.append("No plottable data.", style=_STYLE_MUTED) + return t + + x_grid = _uniform_grid(xs, width) + interpolated = _interpolate(xs, {name: ys}, x_grid) + if y_log: + interpolated = _log_transform(interpolated) + + series_vals = interpolated[name] + y_grid = _uniform_grid(series_vals, len(_SPARK_BLOCKS)) + quantized = _quantize_ys({name: series_vals}, y_grid)[0] + + label = _fit_spark_label(name, label_width) + + # The sentinel value (len(y_grid)) indicates a NaN data point; render it + # as a space (the lowest sparkline block) since sparklines have no border row. + # Map out-of-band sentinels to the extreme sparkline blocks: + # _NEG_INF_SENTINEL (-inf) or _NAN_SENTINEL (NaN) → space (lowest block, index 0) + # _POS_INF_SENTINEL (+inf) → █ (highest block, last index) + def _spark_block(idx: int) -> str: + if idx == _NEG_INF_SENTINEL or idx == _NAN_SENTINEL: + return _SPARK_BLOCKS[0] + if idx == _POS_INF_SENTINEL: + return _SPARK_BLOCKS[-1] + return _SPARK_BLOCKS[idx] + + spark = "".join(_spark_block(idx) for idx in quantized).ljust(width) + + text = Text() + text.append(f" {label} ", style=_STYLE_MUTED) + text.append(spark, style=_STYLE_SPARK) + text.append(f" {ys[0]:.4g} → {ys[-1]:.4g}", style=_STYLE_SECONDARY) + text.append("\n") + return text + + +def render_line_chart( + xs: list[float], + ys: dict[str, list[float]], + *, + x_label: Callable[[float], str] = str, + y_log: bool = False, + y_label: Callable[[float], str] | None = None, + width: int = 60, + height: int = 6, + n_xticks: int = 3, + label_width: int = 8, +) -> Text: + """Render one or more named series as a shared ASCII line chart with a legend header. + + All series share the same x-axis (``xs``); each has its own named y values:: + + console.print( + render_line_chart( + steps, + {"train_loss": train_losses, "val_loss": val_losses}, + x_label=lambda s: f"step {s:.0f}", + ) + ) + + Args: + xs: Shared x values for all series. + ys: Mapping of name → y values. + x_label: Callable that formats an x value into a tick-label string. + y_log: When True, plot on a log10 y-axis (default False). + y_label: Callable that formats a y value into a tick-label string. + width: Plot width in terminal characters (default 60). + height: Plot height in terminal rows (default 6). + n_xticks: Number of x-axis tick marks and labels (default 3). + label_width: Cap on the y-axis label column width. + + Returns: + A ``rich.text.Text`` ready for ``console.print()``. + """ + if not ys: + t = Text() + t.append("No plottable data.", style=_STYLE_MUTED) + return t + + styles = {key: _SERIES_STYLES[i % len(_SERIES_STYLES)] for i, key in enumerate(ys)} + + text = Text() + x_from = x_label(xs[0]) + x_to = x_label(xs[-1]) + for key in ys: + text.append( + f" {key} ({x_from} – {x_to}) {ys[key][0]:.4g} → {ys[key][-1]:.4g}\n", + style=styles[key], + ) + + text.append_text( + _plot( + xs, + ys, + width=width, + height=height, + x_label=x_label, + y_label=y_label or (lambda v: f"{v:.3g}"), + y_log=y_log, + n_xticks=n_xticks, + label_width=label_width, + ) + ) + return text diff --git a/src/together/lib/cli/utils/_help_examples.py b/src/together/lib/cli/utils/_help_examples.py index 224eee8a4..2bbe64ea8 100644 --- a/src/together/lib/cli/utils/_help_examples.py +++ b/src/together/lib/cli/utils/_help_examples.py @@ -100,6 +100,26 @@ [primary]tg ft create --n-checkpoints 3 -M Qwen/Qwen2-1.5B --training-file ./my-dataset.jsonl[/primary] """ +FINE_TUNING_LIST_METRICS_HELP_EXAMPLES = """[dim]Examples:[/dim] +[dim]-[/dim] Retrieve metrics for a fine-tuning job: + [primary]tg ft list-metrics [/primary] + +[dim]-[/dim] Retrieve metrics from a specific global step range: + [primary]tg ft list-metrics --global-step-from 100 --global-step-to 500[/primary] + +[dim]-[/dim] Retrieve metrics logged within a time range: + [primary]tg ft list-metrics --logged-at-from 2024-01-01T00:00:00 --logged-at-to 2024-01-02T00:00:00[/primary] + +[dim]-[/dim] Retrieve a fixed number of data points as JSON: + [primary]tg ft list-metrics --resolution 50 --json[/primary] + +[dim]-[/dim] Save raw metrics to a file: + [primary]tg ft list-metrics --json > metrics.json[/primary] + +[dim]-[/dim] Save ASCII plots to a file: + [primary]tg ft list-metrics > plots.txt[/primary] +""" + FINE_TUNING_DOWNLOAD_HELP_EXAMPLES = """[dim]Examples:[/dim] [dim]-[/dim] Download a fine-tuned model's weights: [primary]tg ft download --output-dir ./my-model[/primary] @@ -226,7 +246,8 @@ --model-b deepseek-ai/DeepSeek-V3.1 \\ --model-b-source serverless \\ --model-b-system-template "You are a concise assistant." \\ - --model-b-input-template $'Answer the following:\\n\\n{{prompt}}'[/primary] + --model-b-input-template $'Answer the following:\\n\\n{{prompt}}' \\ + --disable-position-bias-correction[/primary] """ ## Beta clusters API commands @@ -251,6 +272,10 @@ [dim]-[/dim] Update or delete a cluster: [primary]tg beta clusters update --num-gpus 16 --cluster-type KUBERNETES[/primary] [primary]tg beta clusters delete [/primary] + +[dim]-[/dim] Manage node remediations: + [primary]tg beta clusters remediations ls [/primary] + [primary]tg beta clusters remediations create --mode VM_ONLY[/primary] """ BETA_CLUSTERS_CREATE_HELP_EXAMPLES = """[dim]Examples:[/dim] @@ -329,6 +354,42 @@ [primary]tg beta clusters storage update --size-tib 4[/primary] """ +BETA_CLUSTERS_REMEDIATIONS_HELP_EXAMPLES = """[dim]Examples:[/dim] +[dim]-[/dim] List all remediations for a cluster: + [primary]tg beta clusters remediations ls [/primary] + +[dim]-[/dim] List remediations for one instance: + [primary]tg beta clusters remediations ls [/primary] + +[dim]-[/dim] List automated remediations by mode: + [primary]tg beta clusters remediations ls --mode VM_ONLY --mode REBOOT_VM --trigger AUTOMATED[/primary] + +[dim]-[/dim] Create a remediation: + [primary]tg beta clusters remediations create --mode VM_ONLY --reason "node unhealthy"[/primary] + +[dim]-[/dim] Get remediation details: + [primary]tg beta clusters remediations [/primary] + +[dim]-[/dim] Review automated remediations: + [primary]tg beta clusters remediations approve [/primary] + [primary]tg beta clusters remediations reject --comment "already handled"[/primary] + [primary]tg beta clusters remediations cancel [/primary] +""" + +BETA_CLUSTERS_REMEDIATIONS_CREATE_HELP_EXAMPLES = """[dim]Examples:[/dim] +[dim]-[/dim] Create a VM-only remediation: + [primary]tg beta clusters remediations create --mode VM_ONLY[/primary] + +[dim]-[/dim] Create a host-aware remediation: + [primary]tg beta clusters remediations create --mode HOST_AWARE[/primary] + +[dim]-[/dim] Create a eviction-without-replacement remediation: + [primary]tg beta clusters remediations create --mode EVICT_WITHOUT_REPLACEMENT[/primary] + +[dim]-[/dim] Create a reboot-vm remediation: + [primary]tg beta clusters remediations create --mode REBOOT_VM[/primary] +""" + ## Beta > Jig commands JIG_HELP_EXAMPLES = """[dim]Examples:[/dim] @@ -411,6 +472,9 @@ [dim]-[/dim] Stream logs ([primary]Ctrl+C[/primary] to stop): [primary]tg beta jig logs --follow[/primary] + +[dim]-[/dim] Filter logs by replica and deployment revision: + [primary]tg beta jig logs --replica-id --revision --image-version [/primary] """ JIG_SUBMIT_HELP_EXAMPLES = """[dim]Examples:[/dim] diff --git a/src/together/lib/cli/utils/_preparse_tokens.py b/src/together/lib/cli/utils/_preparse_tokens.py index cfc7ba9a7..2b302303a 100644 --- a/src/together/lib/cli/utils/_preparse_tokens.py +++ b/src/together/lib/cli/utils/_preparse_tokens.py @@ -15,6 +15,7 @@ "endpoints": re.compile(r"^endpoint-"), "beta clusters": _UUID_RE, "beta clusters storage": _UUID_RE, + "beta clusters remediations": _UUID_RE, "beta jig volumes": _UUID_RE, } diff --git a/src/together/resources/beta/clusters/__init__.py b/src/together/resources/beta/clusters/__init__.py index a428a2a51..58a0f2342 100644 --- a/src/together/resources/beta/clusters/__init__.py +++ b/src/together/resources/beta/clusters/__init__.py @@ -16,8 +16,22 @@ ClustersResourceWithStreamingResponse, AsyncClustersResourceWithStreamingResponse, ) +from .remediations import ( + RemediationsResource, + AsyncRemediationsResource, + RemediationsResourceWithRawResponse, + AsyncRemediationsResourceWithRawResponse, + RemediationsResourceWithStreamingResponse, + AsyncRemediationsResourceWithStreamingResponse, +) __all__ = [ + "RemediationsResource", + "AsyncRemediationsResource", + "RemediationsResourceWithRawResponse", + "AsyncRemediationsResourceWithRawResponse", + "RemediationsResourceWithStreamingResponse", + "AsyncRemediationsResourceWithStreamingResponse", "StorageResource", "AsyncStorageResource", "StorageResourceWithRawResponse", diff --git a/src/together/resources/beta/clusters/clusters.py b/src/together/resources/beta/clusters/clusters.py index 5014087ac..75efca2fe 100644 --- a/src/together/resources/beta/clusters/clusters.py +++ b/src/together/resources/beta/clusters/clusters.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Union +from typing import Union, Iterable from datetime import datetime from typing_extensions import Literal @@ -26,7 +26,15 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ....types.beta import cluster_create_params, cluster_update_params +from .remediations import ( + RemediationsResource, + AsyncRemediationsResource, + RemediationsResourceWithRawResponse, + AsyncRemediationsResourceWithRawResponse, + RemediationsResourceWithStreamingResponse, + AsyncRemediationsResourceWithStreamingResponse, +) +from ....types.beta import cluster_list_params, cluster_create_params, cluster_update_params from ...._base_client import make_request_options from ....types.beta.cluster import Cluster from ....types.beta.cluster_list_response import ClusterListResponse @@ -37,6 +45,10 @@ class ClustersResource(SyncAPIResource): + @cached_property + def remediations(self) -> RemediationsResource: + return RemediationsResource(self._client) + @cached_property def storage(self) -> StorageResource: return StorageResource(self._client) @@ -70,13 +82,22 @@ def create( num_gpus: int, nvidia_driver_version: str, region: str, + acceptance_tests_params: cluster_create_params.AcceptanceTestsParams | Omit = omit, + add_ons: Iterable[cluster_create_params.AddOn] | Omit = omit, + auto_scale: bool | Omit = omit, auto_scale_max_gpus: int | Omit = omit, auto_scaled: bool | Omit = omit, capacity_pool_id: str | Omit = omit, + cluster_config: cluster_create_params.ClusterConfig | Omit = omit, cluster_type: Literal["KUBERNETES", "SLURM"] | Omit = omit, duration_days: int | Omit = omit, gpu_node_failover_enabled: bool | Omit = omit, install_traefik: bool | Omit = omit, + num_capacity_pool_gpus: int | Omit = omit, + num_preemptible_gpus: int | Omit = omit, + num_reserved_gpus: int | Omit = omit, + oidc_config: cluster_create_params.OidcConfig | Omit = omit, + project_id: str | Omit = omit, reservation_end_time: Union[str, datetime] | Omit = omit, reservation_start_time: Union[str, datetime] | Omit = omit, shared_volume: cluster_create_params.SharedVolume | Omit = omit, @@ -120,6 +141,15 @@ def create( region: Region to create the GPU cluster in. Usable regions can be found from `client.clusters.list_regions()` + acceptance_tests_params: AcceptanceTestsParams groups all GPU acceptance test options when enabled is + true. + + add_ons: Add-ons to enable on the cluster at creation time. + + auto_scale: Whether to enable auto-scaling for the cluster. If true, the cluster will + automatically scale the number of GPU worker nodes between num_gpus and + auto_scale_max_gpus based on the workload. + auto_scale_max_gpus: Maximum number of GPUs to which the cluster can be auto-scaled up. This field is required if auto_scaled is true. @@ -139,6 +169,20 @@ def create( install_traefik: Whether to install Traefik ingress controller in the cluster. This field is only applicable for Kubernetes clusters and is false by default. + num_capacity_pool_gpus: Number of GPUs to allocate from the capacity pool. Must be a multiple of 8 and + not exceed num_gpus. + + num_preemptible_gpus: Number of preemptible GPUs to request alongside on-demand capacity. Must be a + multiple of 8. Preemptible nodes are cheaper but may be reclaimed when on-demand + capacity is needed elsewhere; the system fulfills this asynchronously and + surfaces the actual count in allocated_preemptible_gpus. + + num_reserved_gpus: Number of prepaid (PLG) reserved GPUs for this cluster. When omitted for + RESERVED billing on create, the server defaults this to num_gpus. + + project_id: Project ID for the cluster. If not set, the project from the request context is + used. + reservation_end_time: Reservation end time of the cluster. This field is required for SCHEDULED billing to specify the reservation end time for the cluster. @@ -174,13 +218,22 @@ def create( "num_gpus": num_gpus, "nvidia_driver_version": nvidia_driver_version, "region": region, + "acceptance_tests_params": acceptance_tests_params, + "add_ons": add_ons, + "auto_scale": auto_scale, "auto_scale_max_gpus": auto_scale_max_gpus, "auto_scaled": auto_scaled, "capacity_pool_id": capacity_pool_id, + "cluster_config": cluster_config, "cluster_type": cluster_type, "duration_days": duration_days, "gpu_node_failover_enabled": gpu_node_failover_enabled, "install_traefik": install_traefik, + "num_capacity_pool_gpus": num_capacity_pool_gpus, + "num_preemptible_gpus": num_preemptible_gpus, + "num_reserved_gpus": num_reserved_gpus, + "oidc_config": oidc_config, + "project_id": project_id, "reservation_end_time": reservation_end_time, "reservation_start_time": reservation_start_time, "shared_volume": shared_volume, @@ -235,8 +288,12 @@ def update( self, cluster_id: str, *, + add_ons: Iterable[cluster_update_params.AddOn] | Omit = omit, + cluster_config: cluster_update_params.ClusterConfig | Omit = omit, cluster_type: Literal["KUBERNETES", "SLURM"] | Omit = omit, num_gpus: int | Omit = omit, + num_preemptible_gpus: int | Omit = omit, + num_reserved_gpus: int | Omit = omit, reservation_end_time: Union[str, datetime] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -251,10 +308,20 @@ def update( Args: cluster_id: The ID of the cluster to update + add_ons: Add-ons to update on the cluster. Each entry identifies an existing add-on by + name and provides the new external config to merge. + cluster_type: Type of cluster to update. - num_gpus: Number of GPUs to allocate in the cluster. This must be multiple of 8. For - example, 8, 16 or 24 + num_gpus: Target GPU count for the cluster. When omitted, the server keeps the current GPU + count from cluster metadata (use for config-only or decommission-time-only + updates). + + num_preemptible_gpus: Updated desired number of preemptible GPUs for the cluster. When omitted, the + current value is preserved. Must be a multiple of 8. + + num_reserved_gpus: Number of reserved GPUs to update to. This field is only applicable for clusters + with RESERVED billing type. reservation_end_time: Timestamp at which the cluster should be decommissioned. Only accepted for prepaid clusters. @@ -273,8 +340,12 @@ def update( path_template("/compute/clusters/{cluster_id}", cluster_id=cluster_id), body=maybe_transform( { + "add_ons": add_ons, + "cluster_config": cluster_config, "cluster_type": cluster_type, "num_gpus": num_gpus, + "num_preemptible_gpus": num_preemptible_gpus, + "num_reserved_gpus": num_reserved_gpus, "reservation_end_time": reservation_end_time, }, cluster_update_params.ClusterUpdateParams, @@ -288,6 +359,7 @@ def update( def list( self, *, + project_id: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -295,11 +367,30 @@ def list( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterListResponse: - """List all GPU clusters.""" + """ + List all GPU clusters. + + Args: + project_id: Optional UMS project ID to filter clusters by. When set, only clusters belonging + to this project are returned. The caller must be a member of the project; + otherwise the result set will be empty. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ return self._get( "/compute/clusters", options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"project_id": project_id}, cluster_list_params.ClusterListParams), ), cast_to=ClusterListResponse, ) @@ -360,6 +451,10 @@ def list_regions( class AsyncClustersResource(AsyncAPIResource): + @cached_property + def remediations(self) -> AsyncRemediationsResource: + return AsyncRemediationsResource(self._client) + @cached_property def storage(self) -> AsyncStorageResource: return AsyncStorageResource(self._client) @@ -393,13 +488,22 @@ async def create( num_gpus: int, nvidia_driver_version: str, region: str, + acceptance_tests_params: cluster_create_params.AcceptanceTestsParams | Omit = omit, + add_ons: Iterable[cluster_create_params.AddOn] | Omit = omit, + auto_scale: bool | Omit = omit, auto_scale_max_gpus: int | Omit = omit, auto_scaled: bool | Omit = omit, capacity_pool_id: str | Omit = omit, + cluster_config: cluster_create_params.ClusterConfig | Omit = omit, cluster_type: Literal["KUBERNETES", "SLURM"] | Omit = omit, duration_days: int | Omit = omit, gpu_node_failover_enabled: bool | Omit = omit, install_traefik: bool | Omit = omit, + num_capacity_pool_gpus: int | Omit = omit, + num_preemptible_gpus: int | Omit = omit, + num_reserved_gpus: int | Omit = omit, + oidc_config: cluster_create_params.OidcConfig | Omit = omit, + project_id: str | Omit = omit, reservation_end_time: Union[str, datetime] | Omit = omit, reservation_start_time: Union[str, datetime] | Omit = omit, shared_volume: cluster_create_params.SharedVolume | Omit = omit, @@ -443,6 +547,15 @@ async def create( region: Region to create the GPU cluster in. Usable regions can be found from `client.clusters.list_regions()` + acceptance_tests_params: AcceptanceTestsParams groups all GPU acceptance test options when enabled is + true. + + add_ons: Add-ons to enable on the cluster at creation time. + + auto_scale: Whether to enable auto-scaling for the cluster. If true, the cluster will + automatically scale the number of GPU worker nodes between num_gpus and + auto_scale_max_gpus based on the workload. + auto_scale_max_gpus: Maximum number of GPUs to which the cluster can be auto-scaled up. This field is required if auto_scaled is true. @@ -462,6 +575,20 @@ async def create( install_traefik: Whether to install Traefik ingress controller in the cluster. This field is only applicable for Kubernetes clusters and is false by default. + num_capacity_pool_gpus: Number of GPUs to allocate from the capacity pool. Must be a multiple of 8 and + not exceed num_gpus. + + num_preemptible_gpus: Number of preemptible GPUs to request alongside on-demand capacity. Must be a + multiple of 8. Preemptible nodes are cheaper but may be reclaimed when on-demand + capacity is needed elsewhere; the system fulfills this asynchronously and + surfaces the actual count in allocated_preemptible_gpus. + + num_reserved_gpus: Number of prepaid (PLG) reserved GPUs for this cluster. When omitted for + RESERVED billing on create, the server defaults this to num_gpus. + + project_id: Project ID for the cluster. If not set, the project from the request context is + used. + reservation_end_time: Reservation end time of the cluster. This field is required for SCHEDULED billing to specify the reservation end time for the cluster. @@ -497,13 +624,22 @@ async def create( "num_gpus": num_gpus, "nvidia_driver_version": nvidia_driver_version, "region": region, + "acceptance_tests_params": acceptance_tests_params, + "add_ons": add_ons, + "auto_scale": auto_scale, "auto_scale_max_gpus": auto_scale_max_gpus, "auto_scaled": auto_scaled, "capacity_pool_id": capacity_pool_id, + "cluster_config": cluster_config, "cluster_type": cluster_type, "duration_days": duration_days, "gpu_node_failover_enabled": gpu_node_failover_enabled, "install_traefik": install_traefik, + "num_capacity_pool_gpus": num_capacity_pool_gpus, + "num_preemptible_gpus": num_preemptible_gpus, + "num_reserved_gpus": num_reserved_gpus, + "oidc_config": oidc_config, + "project_id": project_id, "reservation_end_time": reservation_end_time, "reservation_start_time": reservation_start_time, "shared_volume": shared_volume, @@ -558,8 +694,12 @@ async def update( self, cluster_id: str, *, + add_ons: Iterable[cluster_update_params.AddOn] | Omit = omit, + cluster_config: cluster_update_params.ClusterConfig | Omit = omit, cluster_type: Literal["KUBERNETES", "SLURM"] | Omit = omit, num_gpus: int | Omit = omit, + num_preemptible_gpus: int | Omit = omit, + num_reserved_gpus: int | Omit = omit, reservation_end_time: Union[str, datetime] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -574,10 +714,20 @@ async def update( Args: cluster_id: The ID of the cluster to update + add_ons: Add-ons to update on the cluster. Each entry identifies an existing add-on by + name and provides the new external config to merge. + cluster_type: Type of cluster to update. - num_gpus: Number of GPUs to allocate in the cluster. This must be multiple of 8. For - example, 8, 16 or 24 + num_gpus: Target GPU count for the cluster. When omitted, the server keeps the current GPU + count from cluster metadata (use for config-only or decommission-time-only + updates). + + num_preemptible_gpus: Updated desired number of preemptible GPUs for the cluster. When omitted, the + current value is preserved. Must be a multiple of 8. + + num_reserved_gpus: Number of reserved GPUs to update to. This field is only applicable for clusters + with RESERVED billing type. reservation_end_time: Timestamp at which the cluster should be decommissioned. Only accepted for prepaid clusters. @@ -596,8 +746,12 @@ async def update( path_template("/compute/clusters/{cluster_id}", cluster_id=cluster_id), body=await async_maybe_transform( { + "add_ons": add_ons, + "cluster_config": cluster_config, "cluster_type": cluster_type, "num_gpus": num_gpus, + "num_preemptible_gpus": num_preemptible_gpus, + "num_reserved_gpus": num_reserved_gpus, "reservation_end_time": reservation_end_time, }, cluster_update_params.ClusterUpdateParams, @@ -611,6 +765,7 @@ async def update( async def list( self, *, + project_id: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -618,11 +773,30 @@ async def list( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ClusterListResponse: - """List all GPU clusters.""" + """ + List all GPU clusters. + + Args: + project_id: Optional UMS project ID to filter clusters by. When set, only clusters belonging + to this project are returned. The caller must be a member of the project; + otherwise the result set will be empty. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ return await self._get( "/compute/clusters", options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"project_id": project_id}, cluster_list_params.ClusterListParams), ), cast_to=ClusterListResponse, ) @@ -705,6 +879,10 @@ def __init__(self, clusters: ClustersResource) -> None: clusters.list_regions, ) + @cached_property + def remediations(self) -> RemediationsResourceWithRawResponse: + return RemediationsResourceWithRawResponse(self._clusters.remediations) + @cached_property def storage(self) -> StorageResourceWithRawResponse: return StorageResourceWithRawResponse(self._clusters.storage) @@ -733,6 +911,10 @@ def __init__(self, clusters: AsyncClustersResource) -> None: clusters.list_regions, ) + @cached_property + def remediations(self) -> AsyncRemediationsResourceWithRawResponse: + return AsyncRemediationsResourceWithRawResponse(self._clusters.remediations) + @cached_property def storage(self) -> AsyncStorageResourceWithRawResponse: return AsyncStorageResourceWithRawResponse(self._clusters.storage) @@ -761,6 +943,10 @@ def __init__(self, clusters: ClustersResource) -> None: clusters.list_regions, ) + @cached_property + def remediations(self) -> RemediationsResourceWithStreamingResponse: + return RemediationsResourceWithStreamingResponse(self._clusters.remediations) + @cached_property def storage(self) -> StorageResourceWithStreamingResponse: return StorageResourceWithStreamingResponse(self._clusters.storage) @@ -789,6 +975,10 @@ def __init__(self, clusters: AsyncClustersResource) -> None: clusters.list_regions, ) + @cached_property + def remediations(self) -> AsyncRemediationsResourceWithStreamingResponse: + return AsyncRemediationsResourceWithStreamingResponse(self._clusters.remediations) + @cached_property def storage(self) -> AsyncStorageResourceWithStreamingResponse: return AsyncStorageResourceWithStreamingResponse(self._clusters.storage) diff --git a/src/together/resources/beta/clusters/remediations.py b/src/together/resources/beta/clusters/remediations.py new file mode 100644 index 000000000..087b3a2e1 --- /dev/null +++ b/src/together/resources/beta/clusters/remediations.py @@ -0,0 +1,930 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import List +from typing_extensions import Literal + +import httpx + +from ...._types import Body, Omit, Query, Headers, NotGiven, omit, not_given +from ...._utils import path_template, maybe_transform, async_maybe_transform +from ...._compat import cached_property +from ...._resource import SyncAPIResource, AsyncAPIResource +from ...._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from ...._base_client import make_request_options +from ....types.beta.clusters import ( + remediation_list_params, + remediation_create_params, + remediation_reject_params, + remediation_approve_params, +) +from ....types.beta.clusters.remediation import Remediation +from ....types.beta.clusters.remediation_list_response import RemediationListResponse + +__all__ = ["RemediationsResource", "AsyncRemediationsResource"] + + +class RemediationsResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> RemediationsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/togethercomputer/together-py#accessing-raw-response-data-eg-headers + """ + return RemediationsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> RemediationsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/togethercomputer/together-py#with_streaming_response + """ + return RemediationsResourceWithStreamingResponse(self) + + def create( + self, + instance_id: str, + *, + cluster_id: str, + mode: Literal[ + "REMEDIATION_MODE_VM_ONLY", + "REMEDIATION_MODE_HOST_AWARE", + "REMEDIATION_MODE_EVICT_WITHOUT_REPLACEMENT", + "REMEDIATION_MODE_REBOOT_VM", + ], + remediation_id: str | Omit = omit, + reason: str | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Remediation: + """ + Creates a new remediation for an instance. + + Remediations created via the API goes directly to PENDING state. + + Our system may trigger automated remediations that require approval. These + remediations are created with PENDING_APPROVAL state. The user must call + /approve to start the actual remediation process. These operations can also be + rejected by calling /reject. + + Args: + mode: Remediation mode specifies how the remediation should be performed. + + - `REMEDIATION_MODE_VM_ONLY`: Deletes the VM and provisions a new one on any + available host. + - `REMEDIATION_MODE_HOST_AWARE`: Cordons the host, deletes the VM, and + provisions a new one on a different host. + + remediation_id: Client-specified ID for idempotency. + + reason: User-provided reason for the remediation. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + return self._post( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations", + cluster_id=cluster_id, + instance_id=instance_id, + ), + body=maybe_transform( + { + "mode": mode, + "reason": reason, + }, + remediation_create_params.RemediationCreateParams, + ), + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + {"remediation_id": remediation_id}, remediation_create_params.RemediationCreateParams + ), + ), + cast_to=Remediation, + ) + + def retrieve( + self, + remediation_id: str, + *, + cluster_id: str, + instance_id: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Remediation: + """ + Retrieve the status of a specific remdiation on a specific instance in a + specific cluster. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + if not remediation_id: + raise ValueError(f"Expected a non-empty value for `remediation_id` but received {remediation_id!r}") + return self._get( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations/{remediation_id}", + cluster_id=cluster_id, + instance_id=instance_id, + remediation_id=remediation_id, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Remediation, + ) + + def list( + self, + instance_id: str, + *, + cluster_id: str, + mode: List[ + Literal[ + "REMEDIATION_MODE_VM_ONLY", + "REMEDIATION_MODE_HOST_AWARE", + "REMEDIATION_MODE_EVICT_WITHOUT_REPLACEMENT", + "REMEDIATION_MODE_REBOOT_VM", + ] + ] + | Omit = omit, + order_by: str | Omit = omit, + page_size: int | Omit = omit, + page_token: str | Omit = omit, + state: List[ + Literal["PENDING_APPROVAL", "PENDING", "RUNNING", "SUCCEEDED", "FAILED", "CANCELLED", "AUTO_RESOLVED"] + ] + | Omit = omit, + trigger: List[Literal["REMEDIATION_TRIGGER_MANUAL", "REMEDIATION_TRIGGER_AUTOMATED"]] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> RemediationListResponse: + """ + Lists remediations for an instance or cluster. + + Args: + instance_id: To list remediations on a specific node, pass the node's instance ID. To list + remediations for all nodes in a cluster, pass `-` as a wildcard for the instance + ID. + + mode: Filter by remediation mode(s). Returns remediations matching any of the + specified modes. + + order_by: Order by expression. + + page_size: Maximum results to return. + + page_token: Pagination token from previous request. + + state: Filter by state(s). Returns remediations matching any of the specified states. + + - `PENDING_APPROVAL`: Awaiting approval before processing can begin. + - `PENDING`: Approved and queued for processing. + - `RUNNING`: Actively being processed. + - `SUCCEEDED`: Successfully completed. + - `FAILED`: Failed with an error. + - `CANCELLED`: Cancelled by user or system. + - `AUTO_RESOLVED`: The underlying issue was automatically resolved before + processing. + + trigger: Filter by trigger type(s). Returns remediations matching any of the specified + triggers. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + return self._get( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations", + cluster_id=cluster_id, + instance_id=instance_id, + ), + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "mode": mode, + "order_by": order_by, + "page_size": page_size, + "page_token": page_token, + "state": state, + "trigger": trigger, + }, + remediation_list_params.RemediationListParams, + ), + ), + cast_to=RemediationListResponse, + ) + + def approve( + self, + remediation_id: str, + *, + cluster_id: str, + instance_id: str, + comment: str | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Remediation: + """ + Approves a pending remediation. + + Only remediations with state PENDING_APPROVAL can be approved. + + On APPROVE: state changes to PENDING and the remediation process begins. The + reviewed_by, review_time, and review_comment fields are populated on the + remediation after approval. + + Args: + comment: Comment explaining the action. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + if not remediation_id: + raise ValueError(f"Expected a non-empty value for `remediation_id` but received {remediation_id!r}") + return self._post( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations/{remediation_id}/approve", + cluster_id=cluster_id, + instance_id=instance_id, + remediation_id=remediation_id, + ), + body=maybe_transform({"comment": comment}, remediation_approve_params.RemediationApproveParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Remediation, + ) + + def cancel( + self, + remediation_id: str, + *, + cluster_id: str, + instance_id: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Remediation: + """ + Cancels a pending remediation. + + Only remediations in PENDING_APPROVAL or PENDING state can be cancelled. + + Args: + cluster_id: The cluster ID. + + instance_id: The instance ID. + + remediation_id: The remediation ID. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + if not remediation_id: + raise ValueError(f"Expected a non-empty value for `remediation_id` but received {remediation_id!r}") + return self._post( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations/{remediation_id}/cancel", + cluster_id=cluster_id, + instance_id=instance_id, + remediation_id=remediation_id, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Remediation, + ) + + def reject( + self, + remediation_id: str, + *, + cluster_id: str, + instance_id: str, + comment: str | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Remediation: + """ + Rejects a pending remediation. + + Only remediations with state PENDING_APPROVAL can be rejected. + + On REJECT: state changes to CANCELLED. The reviewed_by, review_time, and + review_comment fields are populated on the remediation after rejection. + + Args: + comment: Comment explaining the action. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + if not remediation_id: + raise ValueError(f"Expected a non-empty value for `remediation_id` but received {remediation_id!r}") + return self._post( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations/{remediation_id}/reject", + cluster_id=cluster_id, + instance_id=instance_id, + remediation_id=remediation_id, + ), + body=maybe_transform({"comment": comment}, remediation_reject_params.RemediationRejectParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Remediation, + ) + + +class AsyncRemediationsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncRemediationsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/togethercomputer/together-py#accessing-raw-response-data-eg-headers + """ + return AsyncRemediationsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncRemediationsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/togethercomputer/together-py#with_streaming_response + """ + return AsyncRemediationsResourceWithStreamingResponse(self) + + async def create( + self, + instance_id: str, + *, + cluster_id: str, + mode: Literal[ + "REMEDIATION_MODE_VM_ONLY", + "REMEDIATION_MODE_HOST_AWARE", + "REMEDIATION_MODE_EVICT_WITHOUT_REPLACEMENT", + "REMEDIATION_MODE_REBOOT_VM", + ], + remediation_id: str | Omit = omit, + reason: str | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Remediation: + """ + Creates a new remediation for an instance. + + Remediations created via the API goes directly to PENDING state. + + Our system may trigger automated remediations that require approval. These + remediations are created with PENDING_APPROVAL state. The user must call + /approve to start the actual remediation process. These operations can also be + rejected by calling /reject. + + Args: + mode: Remediation mode specifies how the remediation should be performed. + + - `REMEDIATION_MODE_VM_ONLY`: Deletes the VM and provisions a new one on any + available host. + - `REMEDIATION_MODE_HOST_AWARE`: Cordons the host, deletes the VM, and + provisions a new one on a different host. + + remediation_id: Client-specified ID for idempotency. + + reason: User-provided reason for the remediation. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + return await self._post( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations", + cluster_id=cluster_id, + instance_id=instance_id, + ), + body=await async_maybe_transform( + { + "mode": mode, + "reason": reason, + }, + remediation_create_params.RemediationCreateParams, + ), + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + {"remediation_id": remediation_id}, remediation_create_params.RemediationCreateParams + ), + ), + cast_to=Remediation, + ) + + async def retrieve( + self, + remediation_id: str, + *, + cluster_id: str, + instance_id: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Remediation: + """ + Retrieve the status of a specific remdiation on a specific instance in a + specific cluster. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + if not remediation_id: + raise ValueError(f"Expected a non-empty value for `remediation_id` but received {remediation_id!r}") + return await self._get( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations/{remediation_id}", + cluster_id=cluster_id, + instance_id=instance_id, + remediation_id=remediation_id, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Remediation, + ) + + async def list( + self, + instance_id: str, + *, + cluster_id: str, + mode: List[ + Literal[ + "REMEDIATION_MODE_VM_ONLY", + "REMEDIATION_MODE_HOST_AWARE", + "REMEDIATION_MODE_EVICT_WITHOUT_REPLACEMENT", + "REMEDIATION_MODE_REBOOT_VM", + ] + ] + | Omit = omit, + order_by: str | Omit = omit, + page_size: int | Omit = omit, + page_token: str | Omit = omit, + state: List[ + Literal["PENDING_APPROVAL", "PENDING", "RUNNING", "SUCCEEDED", "FAILED", "CANCELLED", "AUTO_RESOLVED"] + ] + | Omit = omit, + trigger: List[Literal["REMEDIATION_TRIGGER_MANUAL", "REMEDIATION_TRIGGER_AUTOMATED"]] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> RemediationListResponse: + """ + Lists remediations for an instance or cluster. + + Args: + instance_id: To list remediations on a specific node, pass the node's instance ID. To list + remediations for all nodes in a cluster, pass `-` as a wildcard for the instance + ID. + + mode: Filter by remediation mode(s). Returns remediations matching any of the + specified modes. + + order_by: Order by expression. + + page_size: Maximum results to return. + + page_token: Pagination token from previous request. + + state: Filter by state(s). Returns remediations matching any of the specified states. + + - `PENDING_APPROVAL`: Awaiting approval before processing can begin. + - `PENDING`: Approved and queued for processing. + - `RUNNING`: Actively being processed. + - `SUCCEEDED`: Successfully completed. + - `FAILED`: Failed with an error. + - `CANCELLED`: Cancelled by user or system. + - `AUTO_RESOLVED`: The underlying issue was automatically resolved before + processing. + + trigger: Filter by trigger type(s). Returns remediations matching any of the specified + triggers. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + return await self._get( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations", + cluster_id=cluster_id, + instance_id=instance_id, + ), + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + { + "mode": mode, + "order_by": order_by, + "page_size": page_size, + "page_token": page_token, + "state": state, + "trigger": trigger, + }, + remediation_list_params.RemediationListParams, + ), + ), + cast_to=RemediationListResponse, + ) + + async def approve( + self, + remediation_id: str, + *, + cluster_id: str, + instance_id: str, + comment: str | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Remediation: + """ + Approves a pending remediation. + + Only remediations with state PENDING_APPROVAL can be approved. + + On APPROVE: state changes to PENDING and the remediation process begins. The + reviewed_by, review_time, and review_comment fields are populated on the + remediation after approval. + + Args: + comment: Comment explaining the action. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + if not remediation_id: + raise ValueError(f"Expected a non-empty value for `remediation_id` but received {remediation_id!r}") + return await self._post( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations/{remediation_id}/approve", + cluster_id=cluster_id, + instance_id=instance_id, + remediation_id=remediation_id, + ), + body=await async_maybe_transform({"comment": comment}, remediation_approve_params.RemediationApproveParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Remediation, + ) + + async def cancel( + self, + remediation_id: str, + *, + cluster_id: str, + instance_id: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Remediation: + """ + Cancels a pending remediation. + + Only remediations in PENDING_APPROVAL or PENDING state can be cancelled. + + Args: + cluster_id: The cluster ID. + + instance_id: The instance ID. + + remediation_id: The remediation ID. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + if not remediation_id: + raise ValueError(f"Expected a non-empty value for `remediation_id` but received {remediation_id!r}") + return await self._post( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations/{remediation_id}/cancel", + cluster_id=cluster_id, + instance_id=instance_id, + remediation_id=remediation_id, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Remediation, + ) + + async def reject( + self, + remediation_id: str, + *, + cluster_id: str, + instance_id: str, + comment: str | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Remediation: + """ + Rejects a pending remediation. + + Only remediations with state PENDING_APPROVAL can be rejected. + + On REJECT: state changes to CANCELLED. The reviewed_by, review_time, and + review_comment fields are populated on the remediation after rejection. + + Args: + comment: Comment explaining the action. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not cluster_id: + raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") + if not instance_id: + raise ValueError(f"Expected a non-empty value for `instance_id` but received {instance_id!r}") + if not remediation_id: + raise ValueError(f"Expected a non-empty value for `remediation_id` but received {remediation_id!r}") + return await self._post( + path_template( + "/compute/clusters/{cluster_id}/instances/{instance_id}/remediations/{remediation_id}/reject", + cluster_id=cluster_id, + instance_id=instance_id, + remediation_id=remediation_id, + ), + body=await async_maybe_transform({"comment": comment}, remediation_reject_params.RemediationRejectParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Remediation, + ) + + +class RemediationsResourceWithRawResponse: + def __init__(self, remediations: RemediationsResource) -> None: + self._remediations = remediations + + self.create = to_raw_response_wrapper( + remediations.create, + ) + self.retrieve = to_raw_response_wrapper( + remediations.retrieve, + ) + self.list = to_raw_response_wrapper( + remediations.list, + ) + self.approve = to_raw_response_wrapper( + remediations.approve, + ) + self.cancel = to_raw_response_wrapper( + remediations.cancel, + ) + self.reject = to_raw_response_wrapper( + remediations.reject, + ) + + +class AsyncRemediationsResourceWithRawResponse: + def __init__(self, remediations: AsyncRemediationsResource) -> None: + self._remediations = remediations + + self.create = async_to_raw_response_wrapper( + remediations.create, + ) + self.retrieve = async_to_raw_response_wrapper( + remediations.retrieve, + ) + self.list = async_to_raw_response_wrapper( + remediations.list, + ) + self.approve = async_to_raw_response_wrapper( + remediations.approve, + ) + self.cancel = async_to_raw_response_wrapper( + remediations.cancel, + ) + self.reject = async_to_raw_response_wrapper( + remediations.reject, + ) + + +class RemediationsResourceWithStreamingResponse: + def __init__(self, remediations: RemediationsResource) -> None: + self._remediations = remediations + + self.create = to_streamed_response_wrapper( + remediations.create, + ) + self.retrieve = to_streamed_response_wrapper( + remediations.retrieve, + ) + self.list = to_streamed_response_wrapper( + remediations.list, + ) + self.approve = to_streamed_response_wrapper( + remediations.approve, + ) + self.cancel = to_streamed_response_wrapper( + remediations.cancel, + ) + self.reject = to_streamed_response_wrapper( + remediations.reject, + ) + + +class AsyncRemediationsResourceWithStreamingResponse: + def __init__(self, remediations: AsyncRemediationsResource) -> None: + self._remediations = remediations + + self.create = async_to_streamed_response_wrapper( + remediations.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + remediations.retrieve, + ) + self.list = async_to_streamed_response_wrapper( + remediations.list, + ) + self.approve = async_to_streamed_response_wrapper( + remediations.approve, + ) + self.cancel = async_to_streamed_response_wrapper( + remediations.cancel, + ) + self.reject = async_to_streamed_response_wrapper( + remediations.reject, + ) diff --git a/src/together/resources/beta/clusters/storage.py b/src/together/resources/beta/clusters/storage.py index c6abf44e8..2493f96cf 100644 --- a/src/together/resources/beta/clusters/storage.py +++ b/src/together/resources/beta/clusters/storage.py @@ -15,7 +15,7 @@ async_to_streamed_response_wrapper, ) from ...._base_client import make_request_options -from ....types.beta.clusters import storage_create_params, storage_update_params +from ....types.beta.clusters import storage_list_params, storage_create_params, storage_update_params from ....types.beta.clusters.cluster_storage import ClusterStorage from ....types.beta.clusters.storage_list_response import StorageListResponse from ....types.beta.clusters.storage_delete_response import StorageDeleteResponse @@ -49,6 +49,7 @@ def create( region: str, size_tib: int, volume_name: str, + is_lifecycle_independent: bool | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -64,11 +65,13 @@ def create( performance for shared storage. Args: - region: Region name. Usable regions can be found from `client.clusters.list_regions()` + region: Region name. Usable regions can be found from `clusters.list_regions()` size_tib: Volume size in whole tebibytes (TiB). - volume_name: Customizable name of the volume to create. + volume_name: User provided name of the volume. + + is_lifecycle_independent: When true, the shared volume is not deleted when the cluster is decommissioned. extra_headers: Send extra headers @@ -85,6 +88,7 @@ def create( "region": region, "size_tib": size_tib, "volume_name": volume_name, + "is_lifecycle_independent": is_lifecycle_independent, }, storage_create_params.StorageCreateParams, ), @@ -132,8 +136,8 @@ def retrieve( def update( self, *, + volume_id: str, size_tib: int | Omit = omit, - volume_id: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -145,9 +149,9 @@ def update( Update the configuration of an existing shared volume. Args: - size_tib: Size of the volume in whole tebibytes (TiB). + volume_id: ID of the volume. - volume_id: ID of the volume to update. + size_tib: Size of the volume in TiB. extra_headers: Send extra headers @@ -161,8 +165,8 @@ def update( "/compute/clusters/storage/volumes", body=maybe_transform( { - "size_tib": size_tib, "volume_id": volume_id, + "size_tib": size_tib, }, storage_update_params.StorageUpdateParams, ), @@ -175,6 +179,7 @@ def update( def list( self, *, + project_id: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -182,11 +187,30 @@ def list( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> StorageListResponse: - """List all shared volumes.""" + """ + List all shared volumes. + + Args: + project_id: Optional UMS project ID to filter volumes by. When set, only volumes belonging + to this project are returned. The caller must be a member of the project; + otherwise the result set will be empty. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ return self._get( "/compute/clusters/storage/volumes", options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"project_id": project_id}, storage_list_params.StorageListParams), ), cast_to=StorageListResponse, ) @@ -255,6 +279,7 @@ async def create( region: str, size_tib: int, volume_name: str, + is_lifecycle_independent: bool | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -270,11 +295,13 @@ async def create( performance for shared storage. Args: - region: Region name. Usable regions can be found from `client.clusters.list_regions()` + region: Region name. Usable regions can be found from `clusters.list_regions()` size_tib: Volume size in whole tebibytes (TiB). - volume_name: Customizable name of the volume to create. + volume_name: User provided name of the volume. + + is_lifecycle_independent: When true, the shared volume is not deleted when the cluster is decommissioned. extra_headers: Send extra headers @@ -291,6 +318,7 @@ async def create( "region": region, "size_tib": size_tib, "volume_name": volume_name, + "is_lifecycle_independent": is_lifecycle_independent, }, storage_create_params.StorageCreateParams, ), @@ -338,8 +366,8 @@ async def retrieve( async def update( self, *, + volume_id: str, size_tib: int | Omit = omit, - volume_id: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -351,9 +379,9 @@ async def update( Update the configuration of an existing shared volume. Args: - size_tib: Size of the volume in whole tebibytes (TiB). + volume_id: ID of the volume. - volume_id: ID of the volume to update. + size_tib: Size of the volume in TiB. extra_headers: Send extra headers @@ -367,8 +395,8 @@ async def update( "/compute/clusters/storage/volumes", body=await async_maybe_transform( { - "size_tib": size_tib, "volume_id": volume_id, + "size_tib": size_tib, }, storage_update_params.StorageUpdateParams, ), @@ -381,6 +409,7 @@ async def update( async def list( self, *, + project_id: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -388,11 +417,30 @@ async def list( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> StorageListResponse: - """List all shared volumes.""" + """ + List all shared volumes. + + Args: + project_id: Optional UMS project ID to filter volumes by. When set, only volumes belonging + to this project are returned. The caller must be a member of the project; + otherwise the result set will be empty. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ return await self._get( "/compute/clusters/storage/volumes", options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"project_id": project_id}, storage_list_params.StorageListParams), ), cast_to=StorageListResponse, ) diff --git a/src/together/resources/beta/jig/jig.py b/src/together/resources/beta/jig/jig.py index 64ba1862f..8beda0e29 100644 --- a/src/together/resources/beta/jig/jig.py +++ b/src/together/resources/beta/jig/jig.py @@ -128,7 +128,7 @@ def update( description: str | Omit = omit, environment_variables: Iterable[jig_update_params.EnvironmentVariable] | Omit = omit, gpu_count: int | Omit = omit, - gpu_type: Literal["h100-80gb", "h100-40gb-mig", "b200-192gb"] | Omit = omit, + gpu_type: Literal["h100-80gb", "h100-40gb-mig", "h200-140gb", "b200-192gb"] | Omit = omit, health_check_path: str | Omit = omit, image: str | Omit = omit, max_replicas: int | Omit = omit, @@ -262,7 +262,7 @@ def list( def deploy( self, *, - gpu_type: Literal["h100-80gb", "h100-40gb-mig", "b200-192gb"], + gpu_type: Literal["h100-80gb", "h100-40gb-mig", "h200-140gb", "b200-192gb"], image: str, name: str, args: SequenceNotStr[str] | Omit = omit, @@ -422,6 +422,8 @@ def retrieve_logs( id: str, *, replica_id: str | Omit = omit, + revision: str | Omit = omit, + version: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -437,6 +439,11 @@ def retrieve_logs( replica_id: Replica ID to filter logs + revision: Deployment revision (UUID) to filter logs + + version: Deployment image version (tag or last 4 characters of image digest) to filter + logs + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -454,7 +461,14 @@ def retrieve_logs( extra_query=extra_query, extra_body=extra_body, timeout=timeout, - query=maybe_transform({"replica_id": replica_id}, jig_retrieve_logs_params.JigRetrieveLogsParams), + query=maybe_transform( + { + "replica_id": replica_id, + "revision": revision, + "version": version, + }, + jig_retrieve_logs_params.JigRetrieveLogsParams, + ), ), cast_to=DeploymentLogs, ) @@ -538,7 +552,7 @@ async def update( description: str | Omit = omit, environment_variables: Iterable[jig_update_params.EnvironmentVariable] | Omit = omit, gpu_count: int | Omit = omit, - gpu_type: Literal["h100-80gb", "h100-40gb-mig", "b200-192gb"] | Omit = omit, + gpu_type: Literal["h100-80gb", "h100-40gb-mig", "h200-140gb", "b200-192gb"] | Omit = omit, health_check_path: str | Omit = omit, image: str | Omit = omit, max_replicas: int | Omit = omit, @@ -672,7 +686,7 @@ async def list( async def deploy( self, *, - gpu_type: Literal["h100-80gb", "h100-40gb-mig", "b200-192gb"], + gpu_type: Literal["h100-80gb", "h100-40gb-mig", "h200-140gb", "b200-192gb"], image: str, name: str, args: SequenceNotStr[str] | Omit = omit, @@ -832,6 +846,8 @@ async def retrieve_logs( id: str, *, replica_id: str | Omit = omit, + revision: str | Omit = omit, + version: str | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -847,6 +863,11 @@ async def retrieve_logs( replica_id: Replica ID to filter logs + revision: Deployment revision (UUID) to filter logs + + version: Deployment image version (tag or last 4 characters of image digest) to filter + logs + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -865,7 +886,12 @@ async def retrieve_logs( extra_body=extra_body, timeout=timeout, query=await async_maybe_transform( - {"replica_id": replica_id}, jig_retrieve_logs_params.JigRetrieveLogsParams + { + "replica_id": replica_id, + "revision": revision, + "version": version, + }, + jig_retrieve_logs_params.JigRetrieveLogsParams, ), ), cast_to=DeploymentLogs, diff --git a/src/together/resources/beta/jig/queue.py b/src/together/resources/beta/jig/queue.py index a48837dbf..3c5eaaf10 100644 --- a/src/together/resources/beta/jig/queue.py +++ b/src/together/resources/beta/jig/queue.py @@ -204,8 +204,8 @@ def submit( payload: Freeform model input. Passed unchanged to the model. Contents are model-specific. - info: Arbitrary JSON metadata stored with the job and returned in status responses. - The model and system may add or update keys during processing. + info: Arbitrary JSON metadata stored with the job. Returned in status responses, where + the model and system may have added or modified keys (e.g. progress). priority: Job priority. Higher values are processed first (strict priority ordering). Jobs with equal priority are processed in submission order (FIFO). @@ -414,8 +414,8 @@ async def submit( payload: Freeform model input. Passed unchanged to the model. Contents are model-specific. - info: Arbitrary JSON metadata stored with the job and returned in status responses. - The model and system may add or update keys during processing. + info: Arbitrary JSON metadata stored with the job. Returned in status responses, where + the model and system may have added or modified keys (e.g. progress). priority: Job priority. Higher values are processed first (strict priority ordering). Jobs with equal priority are processed in submission order (FIFO). diff --git a/src/together/resources/beta/jig/volumes.py b/src/together/resources/beta/jig/volumes.py index 3b94490b7..c70d3ae85 100644 --- a/src/together/resources/beta/jig/volumes.py +++ b/src/together/resources/beta/jig/volumes.py @@ -17,7 +17,7 @@ async_to_streamed_response_wrapper, ) from ...._base_client import make_request_options -from ....types.beta.jig import volume_create_params, volume_update_params +from ....types.beta.jig import volume_create_params, volume_update_params, volume_retrieve_params from ....types.beta.jig.volume import Volume from ....types.beta.jig.volume_list_response import VolumeListResponse @@ -95,6 +95,7 @@ def retrieve( self, id: str, *, + version: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -108,6 +109,8 @@ def retrieve( Args: id: Volume ID or name + version: Volume version to describe (defaults to current version) + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -121,7 +124,11 @@ def retrieve( return self._get( path_template("/deployments/storage/volumes/{id}", id=id), options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"version": version}, volume_retrieve_params.VolumeRetrieveParams), ), cast_to=Volume, ) @@ -304,6 +311,7 @@ async def retrieve( self, id: str, *, + version: int | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -317,6 +325,8 @@ async def retrieve( Args: id: Volume ID or name + version: Volume version to describe (defaults to current version) + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -330,7 +340,11 @@ async def retrieve( return await self._get( path_template("/deployments/storage/volumes/{id}", id=id), options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"version": version}, volume_retrieve_params.VolumeRetrieveParams), ), cast_to=Volume, ) diff --git a/src/together/resources/models/models.py b/src/together/resources/models/models.py index b14a0a630..a3486129c 100644 --- a/src/together/resources/models/models.py +++ b/src/together/resources/models/models.py @@ -68,7 +68,8 @@ def list( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ModelListResponse: """ - Lists all of Together's open-source models + Lists all of Together's open-source models and metadata including pricing, chat + template, and context. Args: dedicated: Filter models to only return dedicated models @@ -195,7 +196,8 @@ async def list( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> ModelListResponse: """ - Lists all of Together's open-source models + Lists all of Together's open-source models and metadata including pricing, chat + template, and context. Args: dedicated: Filter models to only return dedicated models diff --git a/src/together/types/beta/__init__.py b/src/together/types/beta/__init__.py index ab66b141e..e7c1d8d94 100644 --- a/src/together/types/beta/__init__.py +++ b/src/together/types/beta/__init__.py @@ -8,6 +8,7 @@ from .jig_deploy_params import JigDeployParams as JigDeployParams from .jig_list_response import JigListResponse as JigListResponse from .jig_update_params import JigUpdateParams as JigUpdateParams +from .cluster_list_params import ClusterListParams as ClusterListParams from .cluster_create_params import ClusterCreateParams as ClusterCreateParams from .cluster_list_response import ClusterListResponse as ClusterListResponse from .cluster_update_params import ClusterUpdateParams as ClusterUpdateParams diff --git a/src/together/types/beta/cluster.py b/src/together/types/beta/cluster.py index 9d5f254c7..cf130458b 100644 --- a/src/together/types/beta/cluster.py +++ b/src/together/types/beta/cluster.py @@ -5,8 +5,86 @@ from typing_extensions import Literal from ..._models import BaseModel +from .clusters.remediation import Remediation -__all__ = ["Cluster", "ControlPlaneNode", "GPUWorkerNode", "Volume"] +__all__ = [ + "Cluster", + "AddOn", + "AddOnConfig", + "AddOnConfigDashboard", + "AddOnConfigIngress", + "AddOnState", + "AddOnStateDashboard", + "AddOnStateIngress", + "ControlPlaneNode", + "ControlPlaneNodePhaseTransition", + "GPUWorkerNode", + "GPUWorkerNodePhaseTransition", + "PhaseTransition", + "Volume", + "ClusterConfig", + "ClusterConfigIngress", + "ClusterConfigObservability", + "ClusterConfigSlurmStartupScripts", + "OidcConfig", +] + + +class AddOnConfigDashboard(BaseModel): + enabled: Optional[bool] = None + + +class AddOnConfigIngress(BaseModel): + enabled: Optional[bool] = None + + +class AddOnConfig(BaseModel): + dashboard: Optional[AddOnConfigDashboard] = None + + ingress: Optional[AddOnConfigIngress] = None + + +class AddOnStateDashboard(BaseModel): + pass + + +class AddOnStateIngress(BaseModel): + pass + + +class AddOnState(BaseModel): + dashboard: Optional[AddOnStateDashboard] = None + + ingress: Optional[AddOnStateIngress] = None + + +class AddOn(BaseModel): + """AddOnInfo is returned in cluster responses and add-on CRUD operations.""" + + add_on_type: str + + config: AddOnConfig + + name: str + + state: AddOnState + + +class ControlPlaneNodePhaseTransition(BaseModel): + phase: Literal[ + "NODE_PHASE_PENDING", + "NODE_PHASE_SCHEDULING", + "NODE_PHASE_BOOTING", + "NODE_PHASE_BOOTSTRAPPING", + "NODE_PHASE_RUNNING", + "NODE_PHASE_SUCCEEDED", + "NODE_PHASE_FAILED", + "NODE_PHASE_PAUSED", + ] + """Node phase.""" + + transition_time: datetime + """Timestamp when the phase transition occurred.""" class ControlPlaneNode(BaseModel): @@ -18,13 +96,31 @@ class ControlPlaneNode(BaseModel): node_id: str - node_name: str - num_cpu_cores: int + phase_transitions: List[ControlPlaneNodePhaseTransition] + """Phase transition history for this control plane node.""" + status: str +class GPUWorkerNodePhaseTransition(BaseModel): + phase: Literal[ + "NODE_PHASE_PENDING", + "NODE_PHASE_SCHEDULING", + "NODE_PHASE_BOOTING", + "NODE_PHASE_BOOTSTRAPPING", + "NODE_PHASE_RUNNING", + "NODE_PHASE_SUCCEEDED", + "NODE_PHASE_FAILED", + "NODE_PHASE_PAUSED", + ] + """Node phase.""" + + transition_time: datetime + """Timestamp when the phase transition occurred.""" + + class GPUWorkerNode(BaseModel): host_name: str @@ -34,28 +130,176 @@ class GPUWorkerNode(BaseModel): node_id: str - node_name: str - num_cpu_cores: int num_gpus: int + phase_transitions: List[GPUWorkerNodePhaseTransition] + """Phase transition history for this GPU worker node.""" + status: str instance_id: Optional[str] = None + latest_remediation: Optional[Remediation] = None + """ + Remediation represents a node remediation request for an instance. An instance + can have multiple remediations over time (e.g., failed attempts followed by + retries). + """ + + slurm_worker_hostname: Optional[str] = None + + +class PhaseTransition(BaseModel): + phase: Literal[ + "CLUSTER_PHASE_QUEUED", + "CLUSTER_PHASE_SCHEDULED", + "CLUSTER_PHASE_WAITING_FOR_CONTROL_PLANE_NODES", + "CLUSTER_PHASE_WAITING_FOR_DATA_PLANE_NODES", + "CLUSTER_PHASE_WAITING_FOR_SUBNET", + "CLUSTER_PHASE_WAITING_FOR_SHARED_VOLUME", + "CLUSTER_PHASE_WAITING_FOR_AUTO_SCALER", + "CLUSTER_PHASE_INSTALLING_DRIVERS", + "CLUSTER_PHASE_RUNNING_ACCEPTANCE_TESTS", + "CLUSTER_PHASE_ACCEPTANCE_TESTS_FAILED", + "CLUSTER_PHASE_RUNNING_NCCL_TESTS", + "CLUSTER_PHASE_NCCL_TESTS_FAILED", + "CLUSTER_PHASE_READY", + "CLUSTER_PHASE_PAUSED", + "CLUSTER_PHASE_ON_DEMAND_COMPUTE_PAUSED", + "CLUSTER_PHASE_DEGRADED", + "CLUSTER_PHASE_DELETING", + ] + """Cluster phase.""" + + transition_time: datetime + """Timestamp when the phase transition occurred.""" + class Volume(BaseModel): size_tib: int + """Size of the volume in TiB.""" status: str + """Current status of the volume.""" volume_id: str + """ID of the volume.""" volume_name: str + """User provided name of the volume.""" + + +class ClusterConfigIngress(BaseModel): + enabled: Optional[bool] = None + + +class ClusterConfigObservability(BaseModel): + enabled: Optional[bool] = None + + +class ClusterConfigSlurmStartupScripts(BaseModel): + """ + SlurmStartupScripts carries optional Slurm lifecycle scripts (prolog/epilog, init, extra conf). + """ + + controller_epilog: Optional[str] = None + """Slurm controller epilog script.""" + + controller_prolog: Optional[str] = None + """Slurm controller prolog script.""" + + extra_slurm_conf: Optional[str] = None + """Additional slurm.conf fragments.""" + + login_init_script: Optional[str] = None + """Script run on Slurm login node init.""" + + nodeset_init_script: Optional[str] = None + """Script run on Slurm nodeset init.""" + + worker_epilog: Optional[str] = None + """Slurm worker node epilog script.""" + + worker_prolog: Optional[str] = None + """Slurm worker node prolog script.""" + + +class ClusterConfig(BaseModel): + load_balancer: Literal["NONE", "TRAEFIK", "NGINX", "ISTIO"] + + gpu_operator_version: Optional[str] = None + """NVIDIA GPU Operator chart/version for the tenant cluster (e.g. + + v24.6.2). When omitted, a service default is applied. + """ + + ingress: Optional[ClusterConfigIngress] = None + + jumphost_enabled: Optional[bool] = None + + kubernetes_dashboard_enabled: Optional[bool] = None + + observability: Optional[ClusterConfigObservability] = None + + slurm_startup_scripts: Optional[ClusterConfigSlurmStartupScripts] = None + """ + SlurmStartupScripts carries optional Slurm lifecycle scripts (prolog/epilog, + init, extra conf). + """ + + +class OidcConfig(BaseModel): + client_id: str + """OIDC client ID for authentication.""" + + group_claim: str + """JWT claim to use for user groups. For example, 'groups'""" + + group_prefix: str + """Prefix to add to the group claim to form the final group name. + + For example, 'oidc:' + """ + + issuer_url: str + """OIDC issuer URL for authentication. For example, https://accounts.google.com""" + + username_claim: str + """JWT claim to use as the username. For example, 'sub' or 'email'""" + + username_prefix: str + """Prefix to add to the username claim to form the final username. + + For example, 'oidc:' + """ + + ca_cert: Optional[str] = None + """CA certificate in PEM format to validate the OIDC issuer's TLS certificate. + + This field is optional but recommended if the issuer uses a private CA or + self-signed certificate. + """ class Cluster(BaseModel): + add_ons: List[AddOn] + """Enabled add-ons on this cluster. + + Only add-ons with enabled=true in their config are returned. + """ + + allocated_preemptible_gpus: int + """Actual number of preemptible GPUs currently allocated to the cluster. + + Updated asynchronously by the fulfillment and reclamation workers; may be less + than desired_preemptible_gpus when capacity is constrained. + """ + + billing_type: Literal["RESERVED", "ON_DEMAND", "SCHEDULED_CAPACITY"] + """Billing type for the cluster (RESERVED, ON_DEMAND, or SCHEDULED_CAPACITY).""" + cluster_id: str cluster_name: str @@ -67,16 +311,30 @@ class Cluster(BaseModel): cuda_version: str + desired_preemptible_gpus: int + """Customer's requested number of preemptible GPUs. + + Set on cluster create or update; persists until changed. + """ + gpu_type: Literal["H100_SXM", "H200_SXM", "RTX_6000_PCI", "L40_PCIE", "B200_SXM", "H100_SXM_INF"] gpu_worker_nodes: List[GPUWorkerNode] kube_config: str + num_cpu_workers: int + """Number of CPU-only worker nodes in the cluster.""" + num_gpus: int nvidia_driver_version: str + phase_transitions: List[PhaseTransition] + """Cluster-level phase transition history.""" + + project_id: str + region: str status: Literal[ @@ -98,12 +356,16 @@ class Cluster(BaseModel): capacity_pool_id: Optional[str] = None + cluster_config: Optional[ClusterConfig] = None + created_at: Optional[datetime] = None duration_hours: Optional[int] = None install_traefik: Optional[bool] = None + oidc_config: Optional[OidcConfig] = None + reservation_end_time: Optional[datetime] = None reservation_start_time: Optional[datetime] = None diff --git a/src/together/types/beta/cluster_create_params.py b/src/together/types/beta/cluster_create_params.py index e961052b7..9ac607436 100644 --- a/src/together/types/beta/cluster_create_params.py +++ b/src/together/types/beta/cluster_create_params.py @@ -2,13 +2,26 @@ from __future__ import annotations -from typing import Union +from typing import Union, Iterable from datetime import datetime from typing_extensions import Literal, Required, Annotated, TypedDict from ..._utils import PropertyInfo -__all__ = ["ClusterCreateParams", "SharedVolume"] +__all__ = [ + "ClusterCreateParams", + "AcceptanceTestsParams", + "AddOn", + "AddOnConfig", + "AddOnConfigDashboard", + "AddOnConfigIngress", + "ClusterConfig", + "ClusterConfigIngress", + "ClusterConfigObservability", + "ClusterConfigSlurmStartupScripts", + "OidcConfig", + "SharedVolume", +] class ClusterCreateParams(TypedDict, total=False): @@ -49,6 +62,22 @@ class ClusterCreateParams(TypedDict, total=False): Usable regions can be found from `client.clusters.list_regions()` """ + acceptance_tests_params: AcceptanceTestsParams + """ + AcceptanceTestsParams groups all GPU acceptance test options when enabled is + true. + """ + + add_ons: Iterable[AddOn] + """Add-ons to enable on the cluster at creation time.""" + + auto_scale: bool + """Whether to enable auto-scaling for the cluster. + + If true, the cluster will automatically scale the number of GPU worker nodes + between num_gpus and auto_scale_max_gpus based on the workload. + """ + auto_scale_max_gpus: int """Maximum number of GPUs to which the cluster can be auto-scaled up. @@ -68,6 +97,8 @@ class ClusterCreateParams(TypedDict, total=False): capacity pool. """ + cluster_config: ClusterConfig + cluster_type: Literal["KUBERNETES", "SLURM"] """Type of cluster to create.""" @@ -86,6 +117,35 @@ class ClusterCreateParams(TypedDict, total=False): This field is only applicable for Kubernetes clusters and is false by default. """ + num_capacity_pool_gpus: int + """Number of GPUs to allocate from the capacity pool. + + Must be a multiple of 8 and not exceed num_gpus. + """ + + num_preemptible_gpus: int + """Number of preemptible GPUs to request alongside on-demand capacity. + + Must be a multiple of 8. Preemptible nodes are cheaper but may be reclaimed when + on-demand capacity is needed elsewhere; the system fulfills this asynchronously + and surfaces the actual count in allocated_preemptible_gpus. + """ + + num_reserved_gpus: int + """Number of prepaid (PLG) reserved GPUs for this cluster. + + When omitted for RESERVED billing on create, the server defaults this to + num_gpus. + """ + + oidc_config: OidcConfig + + project_id: str + """Project ID for the cluster. + + If not set, the project from the request context is used. + """ + reservation_end_time: Annotated[Union[str, datetime], PropertyInfo(format="iso8601")] """Reservation end time of the cluster. @@ -116,14 +176,166 @@ class ClusterCreateParams(TypedDict, total=False): """ID of an existing volume to use with the cluster creation.""" +class AcceptanceTestsParams(TypedDict, total=False): + """ + AcceptanceTestsParams groups all GPU acceptance test options when enabled is true. + """ + + dcgm_diag_level: Literal[ + "DCGM_DIAG_LEVEL_SHORT", "DCGM_DIAG_LEVEL_MEDIUM", "DCGM_DIAG_LEVEL_LONG", "DCGM_DIAG_LEVEL_EXTENDED" + ] + """DCGM diagnostic depth. + + SHORT = readiness; MEDIUM = default; LONG = system validation; EXTENDED = + memtest. An omitted value selects MEDIUM when enabled. + """ + + dcgm_diag_skipped: bool + """Skip DCGM diagnostics acceptance test.""" + + enabled: bool + """Whether to run GPU acceptance tests during cluster bring-up.""" + + gpu_burn_duration: int + """GPU burn duration in seconds; 0 means use the default when enabled.""" + + gpu_burn_skipped: bool + """Skip GPU burn acceptance test.""" + + nccl_multi_node_skipped: bool + """Skip NCCL multi-node acceptance test.""" + + nccl_single_node_skipped: bool + """Skip NCCL single-node acceptance test.""" + + +class AddOnConfigDashboard(TypedDict, total=False): + enabled: bool + + +class AddOnConfigIngress(TypedDict, total=False): + enabled: bool + + +class AddOnConfig(TypedDict, total=False): + dashboard: AddOnConfigDashboard + + ingress: AddOnConfigIngress + + +class AddOn(TypedDict, total=False): + add_on_type: Required[str] + """Type of add-on. Valid values: 'dashboard', 'ingress'.""" + + name: Required[str] + """Human-readable name for this add-on instance.""" + + config: AddOnConfig + + +class ClusterConfigIngress(TypedDict, total=False): + enabled: bool + + +class ClusterConfigObservability(TypedDict, total=False): + enabled: bool + + +class ClusterConfigSlurmStartupScripts(TypedDict, total=False): + """ + SlurmStartupScripts carries optional Slurm lifecycle scripts (prolog/epilog, init, extra conf). + """ + + controller_epilog: str + """Slurm controller epilog script.""" + + controller_prolog: str + """Slurm controller prolog script.""" + + extra_slurm_conf: str + """Additional slurm.conf fragments.""" + + login_init_script: str + """Script run on Slurm login node init.""" + + nodeset_init_script: str + """Script run on Slurm nodeset init.""" + + worker_epilog: str + """Slurm worker node epilog script.""" + + worker_prolog: str + """Slurm worker node prolog script.""" + + +class ClusterConfig(TypedDict, total=False): + load_balancer: Required[Literal["NONE", "TRAEFIK", "NGINX", "ISTIO"]] + + gpu_operator_version: str + """NVIDIA GPU Operator chart/version for the tenant cluster (e.g. + + v24.6.2). When omitted, a service default is applied. + """ + + ingress: ClusterConfigIngress + + jumphost_enabled: bool + + kubernetes_dashboard_enabled: bool + + observability: ClusterConfigObservability + + slurm_startup_scripts: ClusterConfigSlurmStartupScripts + """ + SlurmStartupScripts carries optional Slurm lifecycle scripts (prolog/epilog, + init, extra conf). + """ + + +class OidcConfig(TypedDict, total=False): + client_id: Required[str] + """OIDC client ID for authentication.""" + + group_claim: Required[str] + """JWT claim to use for user groups. For example, 'groups'""" + + group_prefix: Required[str] + """Prefix to add to the group claim to form the final group name. + + For example, 'oidc:' + """ + + issuer_url: Required[str] + """OIDC issuer URL for authentication. For example, https://accounts.google.com""" + + username_claim: Required[str] + """JWT claim to use as the username. For example, 'sub' or 'email'""" + + username_prefix: Required[str] + """Prefix to add to the username claim to form the final username. + + For example, 'oidc:' + """ + + ca_cert: str + """CA certificate in PEM format to validate the OIDC issuer's TLS certificate. + + This field is optional but recommended if the issuer uses a private CA or + self-signed certificate. + """ + + class SharedVolume(TypedDict, total=False): """Inline configuration to create a shared volume with the cluster creation.""" region: Required[str] - """Region name. Usable regions can be found from `client.clusters.list_regions()`""" + """Region name. Usable regions can be found from `clusters.list_regions()`""" size_tib: Required[int] """Volume size in whole tebibytes (TiB).""" volume_name: Required[str] - """Customizable name of the volume to create.""" + """User provided name of the volume.""" + + is_lifecycle_independent: bool + """When true, the shared volume is not deleted when the cluster is decommissioned.""" diff --git a/src/together/types/beta/cluster_list_params.py b/src/together/types/beta/cluster_list_params.py new file mode 100644 index 000000000..69fa8e9e4 --- /dev/null +++ b/src/together/types/beta/cluster_list_params.py @@ -0,0 +1,16 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import TypedDict + +__all__ = ["ClusterListParams"] + + +class ClusterListParams(TypedDict, total=False): + project_id: str + """Optional UMS project ID to filter clusters by. + + When set, only clusters belonging to this project are returned. The caller must + be a member of the project; otherwise the result set will be empty. + """ diff --git a/src/together/types/beta/cluster_update_params.py b/src/together/types/beta/cluster_update_params.py index a98f77653..46cf388aa 100644 --- a/src/together/types/beta/cluster_update_params.py +++ b/src/together/types/beta/cluster_update_params.py @@ -2,23 +2,55 @@ from __future__ import annotations -from typing import Union +from typing import Union, Iterable from datetime import datetime -from typing_extensions import Literal, Annotated, TypedDict +from typing_extensions import Literal, Required, Annotated, TypedDict from ..._utils import PropertyInfo -__all__ = ["ClusterUpdateParams"] +__all__ = [ + "ClusterUpdateParams", + "AddOn", + "AddOnConfig", + "AddOnConfigDashboard", + "AddOnConfigIngress", + "ClusterConfig", + "ClusterConfigIngress", + "ClusterConfigObservability", + "ClusterConfigSlurmStartupScripts", +] class ClusterUpdateParams(TypedDict, total=False): + add_ons: Iterable[AddOn] + """Add-ons to update on the cluster. + + Each entry identifies an existing add-on by name and provides the new external + config to merge. + """ + + cluster_config: ClusterConfig + cluster_type: Literal["KUBERNETES", "SLURM"] """Type of cluster to update.""" num_gpus: int - """Number of GPUs to allocate in the cluster. + """Target GPU count for the cluster. + + When omitted, the server keeps the current GPU count from cluster metadata (use + for config-only or decommission-time-only updates). + """ - This must be multiple of 8. For example, 8, 16 or 24 + num_preemptible_gpus: int + """Updated desired number of preemptible GPUs for the cluster. + + When omitted, the current value is preserved. Must be a multiple of 8. + """ + + num_reserved_gpus: int + """Number of reserved GPUs to update to. + + This field is only applicable for clusters with RESERVED billing type. """ reservation_end_time: Annotated[Union[str, datetime], PropertyInfo(format="iso8601")] @@ -26,3 +58,83 @@ class ClusterUpdateParams(TypedDict, total=False): Only accepted for prepaid clusters. """ + + +class AddOnConfigDashboard(TypedDict, total=False): + enabled: bool + + +class AddOnConfigIngress(TypedDict, total=False): + enabled: bool + + +class AddOnConfig(TypedDict, total=False): + dashboard: AddOnConfigDashboard + + ingress: AddOnConfigIngress + + +class AddOn(TypedDict, total=False): + name: Required[str] + """Name of the add-on to update. Must match an existing add-on on the cluster.""" + + config: AddOnConfig + + +class ClusterConfigIngress(TypedDict, total=False): + enabled: bool + + +class ClusterConfigObservability(TypedDict, total=False): + enabled: bool + + +class ClusterConfigSlurmStartupScripts(TypedDict, total=False): + """ + SlurmStartupScripts carries optional Slurm lifecycle scripts (prolog/epilog, init, extra conf). + """ + + controller_epilog: str + """Slurm controller epilog script.""" + + controller_prolog: str + """Slurm controller prolog script.""" + + extra_slurm_conf: str + """Additional slurm.conf fragments.""" + + login_init_script: str + """Script run on Slurm login node init.""" + + nodeset_init_script: str + """Script run on Slurm nodeset init.""" + + worker_epilog: str + """Slurm worker node epilog script.""" + + worker_prolog: str + """Slurm worker node prolog script.""" + + +class ClusterConfig(TypedDict, total=False): + load_balancer: Required[Literal["NONE", "TRAEFIK", "NGINX", "ISTIO"]] + + gpu_operator_version: str + """NVIDIA GPU Operator chart/version for the tenant cluster (e.g. + + v24.6.2). When omitted, a service default is applied. + """ + + ingress: ClusterConfigIngress + + jumphost_enabled: bool + + kubernetes_dashboard_enabled: bool + + observability: ClusterConfigObservability + + slurm_startup_scripts: ClusterConfigSlurmStartupScripts + """ + SlurmStartupScripts carries optional Slurm lifecycle scripts (prolog/epilog, + init, extra conf). + """ diff --git a/src/together/types/beta/clusters/__init__.py b/src/together/types/beta/clusters/__init__.py index a85f4c49a..89ca43ab3 100644 --- a/src/together/types/beta/clusters/__init__.py +++ b/src/together/types/beta/clusters/__init__.py @@ -2,8 +2,15 @@ from __future__ import annotations +from .remediation import Remediation as Remediation from .cluster_storage import ClusterStorage as ClusterStorage +from .storage_list_params import StorageListParams as StorageListParams from .storage_create_params import StorageCreateParams as StorageCreateParams from .storage_list_response import StorageListResponse as StorageListResponse from .storage_update_params import StorageUpdateParams as StorageUpdateParams +from .remediation_list_params import RemediationListParams as RemediationListParams from .storage_delete_response import StorageDeleteResponse as StorageDeleteResponse +from .remediation_create_params import RemediationCreateParams as RemediationCreateParams +from .remediation_list_response import RemediationListResponse as RemediationListResponse +from .remediation_reject_params import RemediationRejectParams as RemediationRejectParams +from .remediation_approve_params import RemediationApproveParams as RemediationApproveParams diff --git a/src/together/types/beta/clusters/cluster_storage.py b/src/together/types/beta/clusters/cluster_storage.py index 6d7a0bfe4..8daaefc14 100644 --- a/src/together/types/beta/clusters/cluster_storage.py +++ b/src/together/types/beta/clusters/cluster_storage.py @@ -9,13 +9,15 @@ class ClusterStorage(BaseModel): size_tib: int - """Size of the volume in whole tebibytes (TiB).""" + """Size of the volume in TiB.""" - status: Literal["available", "bound", "provisioning"] - """Deployment status of the volume.""" + status: Literal[ + "scheduled", "available", "bound", "provisioning", "deleting", "failed", "access_revoked", "unknown" + ] + """Current status of the shared volume.""" volume_id: str """ID of the volume.""" volume_name: str - """Provided name of the volume.""" + """User provided name of the volume.""" diff --git a/src/together/types/beta/clusters/remediation.py b/src/together/types/beta/clusters/remediation.py new file mode 100644 index 000000000..d5606b57f --- /dev/null +++ b/src/together/types/beta/clusters/remediation.py @@ -0,0 +1,97 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional +from datetime import datetime +from typing_extensions import Literal + +from ...._models import BaseModel + +__all__ = ["Remediation"] + + +class Remediation(BaseModel): + """ + Remediation represents a node remediation request for an instance. + An instance can have multiple remediations over time (e.g., failed attempts followed by retries). + """ + + id: str + + cluster_id: str + + instance_id: str + + mode: Literal[ + "REMEDIATION_MODE_VM_ONLY", + "REMEDIATION_MODE_HOST_AWARE", + "REMEDIATION_MODE_EVICT_WITHOUT_REPLACEMENT", + "REMEDIATION_MODE_REBOOT_VM", + ] + """Remediation mode specifies how the remediation should be performed. + + - `REMEDIATION_MODE_VM_ONLY`: Deletes the VM and provisions a new one on any + available host. + - `REMEDIATION_MODE_HOST_AWARE`: Cordons the host, deletes the VM, and + provisions a new one on a different host. + """ + + state: Literal["PENDING_APPROVAL", "PENDING", "RUNNING", "SUCCEEDED", "FAILED", "CANCELLED", "AUTO_RESOLVED"] + """RemediationState represents the lifecycle state of a remediation. + + - `PENDING_APPROVAL`: Awaiting approval before processing can begin. + - `PENDING`: Approved and queued for processing. + - `RUNNING`: Actively being processed. + - `SUCCEEDED`: Successfully completed. + - `FAILED`: Failed with an error. + - `CANCELLED`: Cancelled by user or system. + - `AUTO_RESOLVED`: The underlying issue was automatically resolved before + processing. + """ + + trigger: Literal["REMEDIATION_TRIGGER_MANUAL", "REMEDIATION_TRIGGER_AUTOMATED"] + """RemediationTrigger specifies how the remediation was triggered. + + - `REMEDIATION_TRIGGER_MANUAL`: A user-initiated remediation (either via web UI + or API call). + - `REMEDIATION_TRIGGER_AUTOMATED`: A system-initiated remediation that requires + approval. + """ + + active_health_check_run_id: Optional[str] = None + """Active health check run ID (UUID) that triggered this remediation.""" + + create_time: Optional[datetime] = None + """When the remediation was created.""" + + end_time: Optional[datetime] = None + """When the remediation completed.""" + + error_message: Optional[str] = None + """Error message if the remediation failed.""" + + instance_name: Optional[str] = None + """Display name of the targeted instance.""" + + passive_health_check_event_id: Optional[str] = None + """Passive health check event ID that triggered this remediation.""" + + reason: Optional[str] = None + """User-provided reason for the remediation.""" + + requested_by: Optional[str] = None + """Who requested the remediation.""" + + review_comment: Optional[str] = None + """Review comment.""" + + review_time: Optional[datetime] = None + """When the remediation was reviewed.""" + + reviewed_by: Optional[str] = None + """Who reviewed the remediation.""" + + start_time: Optional[datetime] = None + """When processing started.""" + + update_time: Optional[datetime] = None + """When the remediation was last updated.""" diff --git a/src/together/types/beta/clusters/remediation_approve_params.py b/src/together/types/beta/clusters/remediation_approve_params.py new file mode 100644 index 000000000..721f96e08 --- /dev/null +++ b/src/together/types/beta/clusters/remediation_approve_params.py @@ -0,0 +1,16 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, TypedDict + +__all__ = ["RemediationApproveParams"] + + +class RemediationApproveParams(TypedDict, total=False): + cluster_id: Required[str] + + instance_id: Required[str] + + comment: str + """Comment explaining the action.""" diff --git a/src/together/types/beta/clusters/remediation_create_params.py b/src/together/types/beta/clusters/remediation_create_params.py new file mode 100644 index 000000000..c32ade02f --- /dev/null +++ b/src/together/types/beta/clusters/remediation_create_params.py @@ -0,0 +1,33 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["RemediationCreateParams"] + + +class RemediationCreateParams(TypedDict, total=False): + cluster_id: Required[str] + + mode: Required[ + Literal[ + "REMEDIATION_MODE_VM_ONLY", + "REMEDIATION_MODE_HOST_AWARE", + "REMEDIATION_MODE_EVICT_WITHOUT_REPLACEMENT", + "REMEDIATION_MODE_REBOOT_VM", + ] + ] + """Remediation mode specifies how the remediation should be performed. + + - `REMEDIATION_MODE_VM_ONLY`: Deletes the VM and provisions a new one on any + available host. + - `REMEDIATION_MODE_HOST_AWARE`: Cordons the host, deletes the VM, and + provisions a new one on a different host. + """ + + remediation_id: str + """Client-specified ID for idempotency.""" + + reason: str + """User-provided reason for the remediation.""" diff --git a/src/together/types/beta/clusters/remediation_list_params.py b/src/together/types/beta/clusters/remediation_list_params.py new file mode 100644 index 000000000..39054359d --- /dev/null +++ b/src/together/types/beta/clusters/remediation_list_params.py @@ -0,0 +1,53 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import List +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["RemediationListParams"] + + +class RemediationListParams(TypedDict, total=False): + cluster_id: Required[str] + + mode: List[ + Literal[ + "REMEDIATION_MODE_VM_ONLY", + "REMEDIATION_MODE_HOST_AWARE", + "REMEDIATION_MODE_EVICT_WITHOUT_REPLACEMENT", + "REMEDIATION_MODE_REBOOT_VM", + ] + ] + """Filter by remediation mode(s). + + Returns remediations matching any of the specified modes. + """ + + order_by: str + """Order by expression.""" + + page_size: int + """Maximum results to return.""" + + page_token: str + """Pagination token from previous request.""" + + state: List[Literal["PENDING_APPROVAL", "PENDING", "RUNNING", "SUCCEEDED", "FAILED", "CANCELLED", "AUTO_RESOLVED"]] + """Filter by state(s). Returns remediations matching any of the specified states. + + - `PENDING_APPROVAL`: Awaiting approval before processing can begin. + - `PENDING`: Approved and queued for processing. + - `RUNNING`: Actively being processed. + - `SUCCEEDED`: Successfully completed. + - `FAILED`: Failed with an error. + - `CANCELLED`: Cancelled by user or system. + - `AUTO_RESOLVED`: The underlying issue was automatically resolved before + processing. + """ + + trigger: List[Literal["REMEDIATION_TRIGGER_MANUAL", "REMEDIATION_TRIGGER_AUTOMATED"]] + """Filter by trigger type(s). + + Returns remediations matching any of the specified triggers. + """ diff --git a/src/together/types/beta/clusters/remediation_list_response.py b/src/together/types/beta/clusters/remediation_list_response.py new file mode 100644 index 000000000..174674960 --- /dev/null +++ b/src/together/types/beta/clusters/remediation_list_response.py @@ -0,0 +1,21 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List + +from ...._models import BaseModel +from .remediation import Remediation + +__all__ = ["RemediationListResponse"] + + +class RemediationListResponse(BaseModel): + """ListRemediationsResponse is the response for ListRemediations.""" + + has_next: bool + """Indicates if there are more results available.""" + + next_page_token: str + """Token for the next page.""" + + remediations: List[Remediation] + """The list of remediations.""" diff --git a/src/together/types/beta/clusters/remediation_reject_params.py b/src/together/types/beta/clusters/remediation_reject_params.py new file mode 100644 index 000000000..cfe532aa3 --- /dev/null +++ b/src/together/types/beta/clusters/remediation_reject_params.py @@ -0,0 +1,16 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, TypedDict + +__all__ = ["RemediationRejectParams"] + + +class RemediationRejectParams(TypedDict, total=False): + cluster_id: Required[str] + + instance_id: Required[str] + + comment: str + """Comment explaining the action.""" diff --git a/src/together/types/beta/clusters/storage_create_params.py b/src/together/types/beta/clusters/storage_create_params.py index 5629cb113..f368b0e8e 100644 --- a/src/together/types/beta/clusters/storage_create_params.py +++ b/src/together/types/beta/clusters/storage_create_params.py @@ -9,10 +9,13 @@ class StorageCreateParams(TypedDict, total=False): region: Required[str] - """Region name. Usable regions can be found from `client.clusters.list_regions()`""" + """Region name. Usable regions can be found from `clusters.list_regions()`""" size_tib: Required[int] """Volume size in whole tebibytes (TiB).""" volume_name: Required[str] - """Customizable name of the volume to create.""" + """User provided name of the volume.""" + + is_lifecycle_independent: bool + """When true, the shared volume is not deleted when the cluster is decommissioned.""" diff --git a/src/together/types/beta/clusters/storage_list_params.py b/src/together/types/beta/clusters/storage_list_params.py new file mode 100644 index 000000000..2dddfca79 --- /dev/null +++ b/src/together/types/beta/clusters/storage_list_params.py @@ -0,0 +1,16 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import TypedDict + +__all__ = ["StorageListParams"] + + +class StorageListParams(TypedDict, total=False): + project_id: str + """Optional UMS project ID to filter volumes by. + + When set, only volumes belonging to this project are returned. The caller must + be a member of the project; otherwise the result set will be empty. + """ diff --git a/src/together/types/beta/clusters/storage_update_params.py b/src/together/types/beta/clusters/storage_update_params.py index 449a62661..6e6a0162a 100644 --- a/src/together/types/beta/clusters/storage_update_params.py +++ b/src/together/types/beta/clusters/storage_update_params.py @@ -2,14 +2,14 @@ from __future__ import annotations -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict __all__ = ["StorageUpdateParams"] class StorageUpdateParams(TypedDict, total=False): - size_tib: int - """Size of the volume in whole tebibytes (TiB).""" + volume_id: Required[str] + """ID of the volume.""" - volume_id: str - """ID of the volume to update.""" + size_tib: int + """Size of the volume in TiB.""" diff --git a/src/together/types/beta/deployment.py b/src/together/types/beta/deployment.py index f7e36e0af..3ce2f599a 100644 --- a/src/together/types/beta/deployment.py +++ b/src/together/types/beta/deployment.py @@ -196,7 +196,7 @@ class Deployment(BaseModel): gpu_count: Optional[int] = None """GPUCount is the number of GPUs allocated to each replica in this deployment""" - gpu_type: Optional[Literal["h100-80gb", "h100-40gb-mig", "b200-192gb"]] = None + gpu_type: Optional[Literal["h100-80gb", "h100-40gb-mig", "h200-140gb", "b200-192gb"]] = None """GPUType specifies the type of GPU requested (if any) for this deployment""" health_check_path: Optional[str] = None @@ -244,6 +244,12 @@ class Deployment(BaseModel): allocated to each replica """ + termination_grace_period_seconds: Optional[int] = None + """ + TerminationGracePeriodSeconds is the time in seconds to wait for graceful + shutdown before forcefully terminating the replica + """ + updated_at: Optional[datetime] = None """UpdatedAt is the ISO8601 timestamp when this deployment was last updated""" diff --git a/src/together/types/beta/jig/__init__.py b/src/together/types/beta/jig/__init__.py index ef03b5652..f8c024918 100644 --- a/src/together/types/beta/jig/__init__.py +++ b/src/together/types/beta/jig/__init__.py @@ -17,4 +17,5 @@ from .queue_retrieve_params import QueueRetrieveParams as QueueRetrieveParams from .queue_submit_response import QueueSubmitResponse as QueueSubmitResponse from .queue_metrics_response import QueueMetricsResponse as QueueMetricsResponse +from .volume_retrieve_params import VolumeRetrieveParams as VolumeRetrieveParams from .queue_retrieve_response import QueueRetrieveResponse as QueueRetrieveResponse diff --git a/src/together/types/beta/jig/queue_cancel_response.py b/src/together/types/beta/jig/queue_cancel_response.py index eeb6e820d..13a0fec93 100644 --- a/src/together/types/beta/jig/queue_cancel_response.py +++ b/src/together/types/beta/jig/queue_cancel_response.py @@ -8,6 +8,8 @@ class QueueCancelResponse(BaseModel): + """Status returned after a cancel attempt.""" + status: Literal["canceled", "running", "done", "failed"] """Job status after the cancel attempt. diff --git a/src/together/types/beta/jig/queue_metrics_response.py b/src/together/types/beta/jig/queue_metrics_response.py index 3387b0553..0b0c803e5 100644 --- a/src/together/types/beta/jig/queue_metrics_response.py +++ b/src/together/types/beta/jig/queue_metrics_response.py @@ -6,6 +6,8 @@ class QueueMetricsResponse(BaseModel): + """Queue job counts for a model.""" + messages_running: int """Number of jobs currently being processed""" diff --git a/src/together/types/beta/jig/queue_retrieve_response.py b/src/together/types/beta/jig/queue_retrieve_response.py index d93654a12..df9a1f2e3 100644 --- a/src/together/types/beta/jig/queue_retrieve_response.py +++ b/src/together/types/beta/jig/queue_retrieve_response.py @@ -10,6 +10,8 @@ class QueueRetrieveResponse(BaseModel): + """Current status and metadata for a queued job.""" + model: str """Model identifier the job was submitted to""" diff --git a/src/together/types/beta/jig/queue_submit_params.py b/src/together/types/beta/jig/queue_submit_params.py index 3c80dbd8c..113a16945 100644 --- a/src/together/types/beta/jig/queue_submit_params.py +++ b/src/together/types/beta/jig/queue_submit_params.py @@ -19,9 +19,10 @@ class QueueSubmitParams(TypedDict, total=False): """ info: Dict[str, object] - """Arbitrary JSON metadata stored with the job and returned in status responses. + """Arbitrary JSON metadata stored with the job. - The model and system may add or update keys during processing. + Returned in status responses, where the model and system may have added or + modified keys (e.g. progress). """ priority: int diff --git a/src/together/types/beta/jig/queue_submit_response.py b/src/together/types/beta/jig/queue_submit_response.py index a6bb15890..67a94db67 100644 --- a/src/together/types/beta/jig/queue_submit_response.py +++ b/src/together/types/beta/jig/queue_submit_response.py @@ -1,30 +1,14 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import Optional - from pydantic import Field as FieldInfo from ...._models import BaseModel -__all__ = ["QueueSubmitResponse", "Error"] - - -class Error(BaseModel): - code: Optional[str] = None - """Machine-readable error code""" - - message: Optional[str] = None - """Human-readable error message""" - - param: Optional[str] = None - """The parameter that caused the error, if applicable""" - - type: Optional[str] = None - """Error category (e.g. "invalid_request_error", "not_found_error")""" +__all__ = ["QueueSubmitResponse"] class QueueSubmitResponse(BaseModel): - error: Optional[Error] = None + """Response returned after queueing a job.""" - request_id: Optional[str] = FieldInfo(alias="requestId", default=None) + request_id: str = FieldInfo(alias="requestId") """Unique identifier for the submitted job. Use this to poll status or cancel.""" diff --git a/src/together/types/beta/jig/volume_retrieve_params.py b/src/together/types/beta/jig/volume_retrieve_params.py new file mode 100644 index 000000000..7b194db4d --- /dev/null +++ b/src/together/types/beta/jig/volume_retrieve_params.py @@ -0,0 +1,12 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import TypedDict + +__all__ = ["VolumeRetrieveParams"] + + +class VolumeRetrieveParams(TypedDict, total=False): + version: int + """Volume version to describe (defaults to current version)""" diff --git a/src/together/types/beta/jig_deploy_params.py b/src/together/types/beta/jig_deploy_params.py index c0bdc9c54..9682a7a3a 100644 --- a/src/together/types/beta/jig_deploy_params.py +++ b/src/together/types/beta/jig_deploy_params.py @@ -19,7 +19,7 @@ class JigDeployParams(TypedDict, total=False): - gpu_type: Required[Literal["h100-80gb", "h100-40gb-mig", "b200-192gb"]] + gpu_type: Required[Literal["h100-80gb", "h100-40gb-mig", "h200-140gb", "b200-192gb"]] """GPUType specifies the GPU hardware to use (e.g., "h100-80gb").""" image: Required[str] diff --git a/src/together/types/beta/jig_retrieve_logs_params.py b/src/together/types/beta/jig_retrieve_logs_params.py index 8afbea2d8..59e0ca3cb 100644 --- a/src/together/types/beta/jig_retrieve_logs_params.py +++ b/src/together/types/beta/jig_retrieve_logs_params.py @@ -10,3 +10,12 @@ class JigRetrieveLogsParams(TypedDict, total=False): replica_id: str """Replica ID to filter logs""" + + revision: str + """Deployment revision (UUID) to filter logs""" + + version: str + """ + Deployment image version (tag or last 4 characters of image digest) to filter + logs + """ diff --git a/src/together/types/beta/jig_update_params.py b/src/together/types/beta/jig_update_params.py index c91e46d92..56ad5d58c 100644 --- a/src/together/types/beta/jig_update_params.py +++ b/src/together/types/beta/jig_update_params.py @@ -52,7 +52,7 @@ class JigUpdateParams(TypedDict, total=False): gpu_count: int """GPUCount is the number of GPUs to allocate per container instance""" - gpu_type: Literal["h100-80gb", "h100-40gb-mig", "b200-192gb"] + gpu_type: Literal["h100-80gb", "h100-40gb-mig", "h200-140gb", "b200-192gb"] """GPUType specifies the GPU hardware to use (e.g., "h100-80gb")""" health_check_path: str diff --git a/src/together/types/eval_create_params.py b/src/together/types/eval_create_params.py index 081698c6c..a574a6bba 100644 --- a/src/together/types/eval_create_params.py +++ b/src/together/types/eval_create_params.py @@ -40,65 +40,101 @@ class ParametersEvaluationClassifyParametersJudge(TypedDict, total=False): """Name of the judge model""" model_source: Required[Literal["serverless", "dedicated", "external"]] - """Source of the judge model.""" + """ + Source of the judge model inference: - `serverless`: Together's shared + serverless inference API. Default concurrency: 25 workers. - `dedicated`: A + Together dedicated deployment endpoint. Default concurrency: 5 workers (minimum + enforced even if num_workers is set lower). + + - `external`: An external inference API (e.g. OpenAI, Anthropic, Google, + OpenRouter). Requires `external_api_token` and `external_base_url`. Default + concurrency: 2 workers for first-party APIs, 20 for proxy/aggregator + endpoints. + """ system_template: Required[str] """System prompt template for the judge""" external_api_token: str - """Bearer/API token for external judge models.""" + """Bearer/API token for the external judge model provider. + + Required when model_source is 'external'. + """ external_base_url: str - """Base URL for external judge models. Must be OpenAI-compatible base URL.""" + """Base URL of the external inference API for the judge. + + Must be OpenAI-compatible. Required when model_source is 'external'. + """ max_tokens: int - """Maximum number of tokens the judge model can generate. + """Maximum number of tokens the judge model may generate. - Defaults to 32768. Increase for reasoning models (e.g. Gemini, o-series) that - consume output token budget for chain-of-thought. + Defaults to 32768 if omitted. Set higher for reasoning judges (e.g. o-series, + Gemini) that spend tokens on internal chain-of-thought before emitting the + verdict JSON. """ num_workers: int - """Number of concurrent workers for inference requests. + """Number of concurrent inference workers for the judge. - Overrides the default concurrency for this model. Useful for tuning throughput - when using proxy endpoints (e.g. OpenRouter) or rate-limited external APIs. + Overrides the source-specific default (serverless: 25, dedicated: 5, external: + 2–20). For dedicated endpoints the value is clamped to a minimum of 5 regardless + of what is set here. """ temperature: float - """Sampling temperature for the judge model. Defaults to 0.05.""" + """Sampling temperature for the judge model. Defaults to 0.05 if omitted.""" class ParametersEvaluationClassifyParametersModelToEvaluateEvaluationModelRequest(TypedDict, total=False): input_template: Required[str] - """Input prompt template""" + """User message template. Supports Jinja2 variables referencing dataset columns.""" max_tokens: Required[int] - """Maximum number of tokens to generate""" + """Maximum number of tokens to generate.""" model: Required[str] """Name of the model to evaluate""" model_source: Required[Literal["serverless", "dedicated", "external"]] - """Source of the model.""" + """ + Source of the model inference: - `serverless`: Together's shared serverless + inference API. Default concurrency: 25 workers. - `dedicated`: A Together + dedicated deployment endpoint. Default concurrency: 5 workers (minimum enforced + even if num_workers is set lower). Authentication uses the requesting user's + Together API token automatically. + + - `external`: An external inference API (e.g. OpenAI, Anthropic, Google, + OpenRouter). Requires `external_api_token` and `external_base_url`. Default + concurrency: 2 workers for first-party APIs (OpenAI, Anthropic, Google), 20 + for proxy/aggregator endpoints. + """ system_template: Required[str] - """System prompt template""" + """System prompt template. Supports Jinja2 variables referencing dataset columns.""" temperature: Required[float] - """Sampling temperature""" + """Sampling temperature for generation.""" external_api_token: str - """Bearer/API token for external models.""" + """Bearer/API token for the external model provider. + + Required when model_source is 'external'. + """ external_base_url: str - """Base URL for external models. Must be OpenAI-compatible base URL""" + """Base URL of the external inference API. + + Must be OpenAI-compatible. Required when model_source is 'external'. + """ num_workers: int - """Number of concurrent workers for inference requests. + """Number of concurrent inference workers. - Overrides the default concurrency for this model. Useful for tuning throughput - when using proxy endpoints (e.g. OpenRouter) or rate-limited external APIs. + Overrides the source-specific default (serverless: 25, dedicated: 5, external: + 2–20). For dedicated endpoints the value is clamped to a minimum of 5 regardless + of what is set here. """ @@ -120,7 +156,7 @@ class ParametersEvaluationClassifyParameters(TypedDict, total=False): """List of labels that are considered passing""" model_to_evaluate: ParametersEvaluationClassifyParametersModelToEvaluate - """Field name in the input data""" + """Column name in the input dataset containing pre-generated responses""" class ParametersEvaluationScoreParametersJudge(TypedDict, total=False): @@ -128,65 +164,101 @@ class ParametersEvaluationScoreParametersJudge(TypedDict, total=False): """Name of the judge model""" model_source: Required[Literal["serverless", "dedicated", "external"]] - """Source of the judge model.""" + """ + Source of the judge model inference: - `serverless`: Together's shared + serverless inference API. Default concurrency: 25 workers. - `dedicated`: A + Together dedicated deployment endpoint. Default concurrency: 5 workers (minimum + enforced even if num_workers is set lower). + + - `external`: An external inference API (e.g. OpenAI, Anthropic, Google, + OpenRouter). Requires `external_api_token` and `external_base_url`. Default + concurrency: 2 workers for first-party APIs, 20 for proxy/aggregator + endpoints. + """ system_template: Required[str] """System prompt template for the judge""" external_api_token: str - """Bearer/API token for external judge models.""" + """Bearer/API token for the external judge model provider. + + Required when model_source is 'external'. + """ external_base_url: str - """Base URL for external judge models. Must be OpenAI-compatible base URL.""" + """Base URL of the external inference API for the judge. + + Must be OpenAI-compatible. Required when model_source is 'external'. + """ max_tokens: int - """Maximum number of tokens the judge model can generate. + """Maximum number of tokens the judge model may generate. - Defaults to 32768. Increase for reasoning models (e.g. Gemini, o-series) that - consume output token budget for chain-of-thought. + Defaults to 32768 if omitted. Set higher for reasoning judges (e.g. o-series, + Gemini) that spend tokens on internal chain-of-thought before emitting the + verdict JSON. """ num_workers: int - """Number of concurrent workers for inference requests. + """Number of concurrent inference workers for the judge. - Overrides the default concurrency for this model. Useful for tuning throughput - when using proxy endpoints (e.g. OpenRouter) or rate-limited external APIs. + Overrides the source-specific default (serverless: 25, dedicated: 5, external: + 2–20). For dedicated endpoints the value is clamped to a minimum of 5 regardless + of what is set here. """ temperature: float - """Sampling temperature for the judge model. Defaults to 0.05.""" + """Sampling temperature for the judge model. Defaults to 0.05 if omitted.""" class ParametersEvaluationScoreParametersModelToEvaluateEvaluationModelRequest(TypedDict, total=False): input_template: Required[str] - """Input prompt template""" + """User message template. Supports Jinja2 variables referencing dataset columns.""" max_tokens: Required[int] - """Maximum number of tokens to generate""" + """Maximum number of tokens to generate.""" model: Required[str] """Name of the model to evaluate""" model_source: Required[Literal["serverless", "dedicated", "external"]] - """Source of the model.""" + """ + Source of the model inference: - `serverless`: Together's shared serverless + inference API. Default concurrency: 25 workers. - `dedicated`: A Together + dedicated deployment endpoint. Default concurrency: 5 workers (minimum enforced + even if num_workers is set lower). Authentication uses the requesting user's + Together API token automatically. + + - `external`: An external inference API (e.g. OpenAI, Anthropic, Google, + OpenRouter). Requires `external_api_token` and `external_base_url`. Default + concurrency: 2 workers for first-party APIs (OpenAI, Anthropic, Google), 20 + for proxy/aggregator endpoints. + """ system_template: Required[str] - """System prompt template""" + """System prompt template. Supports Jinja2 variables referencing dataset columns.""" temperature: Required[float] - """Sampling temperature""" + """Sampling temperature for generation.""" external_api_token: str - """Bearer/API token for external models.""" + """Bearer/API token for the external model provider. + + Required when model_source is 'external'. + """ external_base_url: str - """Base URL for external models. Must be OpenAI-compatible base URL""" + """Base URL of the external inference API. + + Must be OpenAI-compatible. Required when model_source is 'external'. + """ num_workers: int - """Number of concurrent workers for inference requests. + """Number of concurrent inference workers. - Overrides the default concurrency for this model. Useful for tuning throughput - when using proxy endpoints (e.g. OpenRouter) or rate-limited external APIs. + Overrides the source-specific default (serverless: 25, dedicated: 5, external: + 2–20). For dedicated endpoints the value is clamped to a minimum of 5 regardless + of what is set here. """ @@ -211,7 +283,7 @@ class ParametersEvaluationScoreParameters(TypedDict, total=False): """Score threshold for passing""" model_to_evaluate: ParametersEvaluationScoreParametersModelToEvaluate - """Field name in the input data""" + """Column name in the input dataset containing pre-generated responses""" class ParametersEvaluationCompareParametersJudge(TypedDict, total=False): @@ -219,122 +291,196 @@ class ParametersEvaluationCompareParametersJudge(TypedDict, total=False): """Name of the judge model""" model_source: Required[Literal["serverless", "dedicated", "external"]] - """Source of the judge model.""" + """ + Source of the judge model inference: - `serverless`: Together's shared + serverless inference API. Default concurrency: 25 workers. - `dedicated`: A + Together dedicated deployment endpoint. Default concurrency: 5 workers (minimum + enforced even if num_workers is set lower). + + - `external`: An external inference API (e.g. OpenAI, Anthropic, Google, + OpenRouter). Requires `external_api_token` and `external_base_url`. Default + concurrency: 2 workers for first-party APIs, 20 for proxy/aggregator + endpoints. + """ system_template: Required[str] """System prompt template for the judge""" external_api_token: str - """Bearer/API token for external judge models.""" + """Bearer/API token for the external judge model provider. + + Required when model_source is 'external'. + """ external_base_url: str - """Base URL for external judge models. Must be OpenAI-compatible base URL.""" + """Base URL of the external inference API for the judge. + + Must be OpenAI-compatible. Required when model_source is 'external'. + """ max_tokens: int - """Maximum number of tokens the judge model can generate. + """Maximum number of tokens the judge model may generate. - Defaults to 32768. Increase for reasoning models (e.g. Gemini, o-series) that - consume output token budget for chain-of-thought. + Defaults to 32768 if omitted. Set higher for reasoning judges (e.g. o-series, + Gemini) that spend tokens on internal chain-of-thought before emitting the + verdict JSON. """ num_workers: int - """Number of concurrent workers for inference requests. + """Number of concurrent inference workers for the judge. - Overrides the default concurrency for this model. Useful for tuning throughput - when using proxy endpoints (e.g. OpenRouter) or rate-limited external APIs. + Overrides the source-specific default (serverless: 25, dedicated: 5, external: + 2–20). For dedicated endpoints the value is clamped to a minimum of 5 regardless + of what is set here. """ temperature: float - """Sampling temperature for the judge model. Defaults to 0.05.""" + """Sampling temperature for the judge model. Defaults to 0.05 if omitted.""" class ParametersEvaluationCompareParametersModelAEvaluationModelRequest(TypedDict, total=False): input_template: Required[str] - """Input prompt template""" + """User message template. Supports Jinja2 variables referencing dataset columns.""" max_tokens: Required[int] - """Maximum number of tokens to generate""" + """Maximum number of tokens to generate.""" model: Required[str] """Name of the model to evaluate""" model_source: Required[Literal["serverless", "dedicated", "external"]] - """Source of the model.""" + """ + Source of the model inference: - `serverless`: Together's shared serverless + inference API. Default concurrency: 25 workers. - `dedicated`: A Together + dedicated deployment endpoint. Default concurrency: 5 workers (minimum enforced + even if num_workers is set lower). Authentication uses the requesting user's + Together API token automatically. + + - `external`: An external inference API (e.g. OpenAI, Anthropic, Google, + OpenRouter). Requires `external_api_token` and `external_base_url`. Default + concurrency: 2 workers for first-party APIs (OpenAI, Anthropic, Google), 20 + for proxy/aggregator endpoints. + """ system_template: Required[str] - """System prompt template""" + """System prompt template. Supports Jinja2 variables referencing dataset columns.""" temperature: Required[float] - """Sampling temperature""" + """Sampling temperature for generation.""" external_api_token: str - """Bearer/API token for external models.""" + """Bearer/API token for the external model provider. + + Required when model_source is 'external'. + """ external_base_url: str - """Base URL for external models. Must be OpenAI-compatible base URL""" + """Base URL of the external inference API. + + Must be OpenAI-compatible. Required when model_source is 'external'. + """ num_workers: int - """Number of concurrent workers for inference requests. + """Number of concurrent inference workers. - Overrides the default concurrency for this model. Useful for tuning throughput - when using proxy endpoints (e.g. OpenRouter) or rate-limited external APIs. + Overrides the source-specific default (serverless: 25, dedicated: 5, external: + 2–20). For dedicated endpoints the value is clamped to a minimum of 5 regardless + of what is set here. """ ParametersEvaluationCompareParametersModelA: TypeAlias = Union[ - str, ParametersEvaluationCompareParametersModelAEvaluationModelRequest + ParametersEvaluationCompareParametersModelAEvaluationModelRequest, str ] class ParametersEvaluationCompareParametersModelBEvaluationModelRequest(TypedDict, total=False): input_template: Required[str] - """Input prompt template""" + """User message template. Supports Jinja2 variables referencing dataset columns.""" max_tokens: Required[int] - """Maximum number of tokens to generate""" + """Maximum number of tokens to generate.""" model: Required[str] """Name of the model to evaluate""" model_source: Required[Literal["serverless", "dedicated", "external"]] - """Source of the model.""" + """ + Source of the model inference: - `serverless`: Together's shared serverless + inference API. Default concurrency: 25 workers. - `dedicated`: A Together + dedicated deployment endpoint. Default concurrency: 5 workers (minimum enforced + even if num_workers is set lower). Authentication uses the requesting user's + Together API token automatically. + + - `external`: An external inference API (e.g. OpenAI, Anthropic, Google, + OpenRouter). Requires `external_api_token` and `external_base_url`. Default + concurrency: 2 workers for first-party APIs (OpenAI, Anthropic, Google), 20 + for proxy/aggregator endpoints. + """ system_template: Required[str] - """System prompt template""" + """System prompt template. Supports Jinja2 variables referencing dataset columns.""" temperature: Required[float] - """Sampling temperature""" + """Sampling temperature for generation.""" external_api_token: str - """Bearer/API token for external models.""" + """Bearer/API token for the external model provider. + + Required when model_source is 'external'. + """ external_base_url: str - """Base URL for external models. Must be OpenAI-compatible base URL""" + """Base URL of the external inference API. + + Must be OpenAI-compatible. Required when model_source is 'external'. + """ num_workers: int - """Number of concurrent workers for inference requests. + """Number of concurrent inference workers. - Overrides the default concurrency for this model. Useful for tuning throughput - when using proxy endpoints (e.g. OpenRouter) or rate-limited external APIs. + Overrides the source-specific default (serverless: 25, dedicated: 5, external: + 2–20). For dedicated endpoints the value is clamped to a minimum of 5 regardless + of what is set here. """ ParametersEvaluationCompareParametersModelB: TypeAlias = Union[ - str, ParametersEvaluationCompareParametersModelBEvaluationModelRequest + ParametersEvaluationCompareParametersModelBEvaluationModelRequest, str ] class ParametersEvaluationCompareParameters(TypedDict, total=False): input_data_file_path: Required[str] - """Data file name""" + """Data file ID""" judge: Required[ParametersEvaluationCompareParametersJudge] + disable_position_bias_correction: bool + """ + When false (default), the judge runs twice per sample: once with model A's + response first (original order) and once with model B's response first (flipped + order). The two verdicts are reconciled to cancel out position bias. When true, + only the original-order pass is run, halving judge cost and latency at the + expense of position-bias correction. The result file will not contain + flipped-order judge fields when this is true. + """ + model_a: ParametersEvaluationCompareParametersModelA - """Field name in the input data""" + """ + Either an EvaluationModelRequest for generation or a string column name from the + dataset (when responses are pre-generated). When both model_a and model_b are + EvaluationModelRequest objects, their inference runs execute in parallel to + reduce total wall-clock time. + """ model_b: ParametersEvaluationCompareParametersModelB - """Field name in the input data""" + """ + Either an EvaluationModelRequest for generation or a string column name from the + dataset (when responses are pre-generated). When both model_a and model_b are + EvaluationModelRequest objects, their inference runs execute in parallel to + reduce total wall-clock time. + """ Parameters: TypeAlias = Union[ diff --git a/src/together/types/eval_status_response.py b/src/together/types/eval_status_response.py index ceaa864ea..0caf5327a 100644 --- a/src/together/types/eval_status_response.py +++ b/src/together/types/eval_status_response.py @@ -66,25 +66,33 @@ class ResultsEvaluationScoreResults(BaseModel): class ResultsEvaluationCompareResults(BaseModel): a_wins: Optional[int] = FieldInfo(alias="A_wins", default=None) - """Number of times model A won""" + """Number of samples where model A was judged the winner""" b_wins: Optional[int] = FieldInfo(alias="B_wins", default=None) - """Number of times model B won""" + """Number of samples where model B was judged the winner""" generation_fail_count: Optional[float] = None - """Number of failed generations.""" + """Number of generation failures across model A and model B.""" judge_fail_count: Optional[float] = None - """Number of failed judge generations""" + """Number of judge inference failures. - num_samples: Optional[int] = None - """Total number of samples compared""" + In the default two-pass mode (disable_position_bias_correction=false) this is + the combined failure count from both the original-order and flipped-order judge + passes. + """ result_file_id: Optional[str] = None - """Data File ID""" + """File ID of the detailed output file. + + Each row contains the original input fields plus judge outputs. In two-pass mode + the file includes both original-order and flipped-order judge fields; in + single-pass mode (disable_position_bias_correction=true) only original-order + fields are present. + """ ties: Optional[int] = FieldInfo(alias="Ties", default=None) - """Number of ties""" + """Number of samples that resulted in a tie""" Results: TypeAlias = Union[ diff --git a/src/together/types/evaluation_job.py b/src/together/types/evaluation_job.py index 5f9aba7ff..e822f854d 100644 --- a/src/together/types/evaluation_job.py +++ b/src/together/types/evaluation_job.py @@ -69,25 +69,33 @@ class ResultsEvaluationScoreResults(BaseModel): class ResultsEvaluationCompareResults(BaseModel): a_wins: Optional[int] = FieldInfo(alias="A_wins", default=None) - """Number of times model A won""" + """Number of samples where model A was judged the winner""" b_wins: Optional[int] = FieldInfo(alias="B_wins", default=None) - """Number of times model B won""" + """Number of samples where model B was judged the winner""" generation_fail_count: Optional[float] = None - """Number of failed generations.""" + """Number of generation failures across model A and model B.""" judge_fail_count: Optional[float] = None - """Number of failed judge generations""" + """Number of judge inference failures. - num_samples: Optional[int] = None - """Total number of samples compared""" + In the default two-pass mode (disable_position_bias_correction=false) this is + the combined failure count from both the original-order and flipped-order judge + passes. + """ result_file_id: Optional[str] = None - """Data File ID""" + """File ID of the detailed output file. + + Each row contains the original input fields plus judge outputs. In two-pass mode + the file includes both original-order and flipped-order judge fields; in + single-pass mode (disable_position_bias_correction=true) only original-order + fields are present. + """ ties: Optional[int] = FieldInfo(alias="Ties", default=None) - """Number of ties""" + """Number of samples that resulted in a tie""" class ResultsError(BaseModel): diff --git a/tests/api_resources/beta/clusters/test_remediations.py b/tests/api_resources/beta/clusters/test_remediations.py new file mode 100644 index 000000000..bb10b17f9 --- /dev/null +++ b/tests/api_resources/beta/clusters/test_remediations.py @@ -0,0 +1,799 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from together import Together, AsyncTogether +from tests.utils import assert_matches_type +from together.types.beta.clusters import ( + Remediation, + RemediationListResponse, +) + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestRemediations: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_create(self, client: Together) -> None: + remediation = client.beta.clusters.remediations.create( + instance_id="instance_id", + cluster_id="cluster_id", + mode="REMEDIATION_MODE_VM_ONLY", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_method_create_with_all_params(self, client: Together) -> None: + remediation = client.beta.clusters.remediations.create( + instance_id="instance_id", + cluster_id="cluster_id", + mode="REMEDIATION_MODE_VM_ONLY", + remediation_id="remediation_id", + reason="reason", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_raw_response_create(self, client: Together) -> None: + response = client.beta.clusters.remediations.with_raw_response.create( + instance_id="instance_id", + cluster_id="cluster_id", + mode="REMEDIATION_MODE_VM_ONLY", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_streaming_response_create(self, client: Together) -> None: + with client.beta.clusters.remediations.with_streaming_response.create( + instance_id="instance_id", + cluster_id="cluster_id", + mode="REMEDIATION_MODE_VM_ONLY", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_create(self, client: Together) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.create( + instance_id="instance_id", + cluster_id="", + mode="REMEDIATION_MODE_VM_ONLY", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.create( + instance_id="", + cluster_id="cluster_id", + mode="REMEDIATION_MODE_VM_ONLY", + ) + + @parametrize + def test_method_retrieve(self, client: Together) -> None: + remediation = client.beta.clusters.remediations.retrieve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: Together) -> None: + response = client.beta.clusters.remediations.with_raw_response.retrieve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: Together) -> None: + with client.beta.clusters.remediations.with_streaming_response.retrieve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Together) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.retrieve( + remediation_id="remediation_id", + cluster_id="", + instance_id="instance_id", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.retrieve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `remediation_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.retrieve( + remediation_id="", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + @parametrize + def test_method_list(self, client: Together) -> None: + remediation = client.beta.clusters.remediations.list( + instance_id="instance_id", + cluster_id="cluster_id", + ) + assert_matches_type(RemediationListResponse, remediation, path=["response"]) + + @parametrize + def test_method_list_with_all_params(self, client: Together) -> None: + remediation = client.beta.clusters.remediations.list( + instance_id="instance_id", + cluster_id="cluster_id", + mode=["REMEDIATION_MODE_VM_ONLY"], + order_by="order_by", + page_size=0, + page_token="page_token", + state=["PENDING_APPROVAL"], + trigger=["REMEDIATION_TRIGGER_MANUAL"], + ) + assert_matches_type(RemediationListResponse, remediation, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: Together) -> None: + response = client.beta.clusters.remediations.with_raw_response.list( + instance_id="instance_id", + cluster_id="cluster_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = response.parse() + assert_matches_type(RemediationListResponse, remediation, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: Together) -> None: + with client.beta.clusters.remediations.with_streaming_response.list( + instance_id="instance_id", + cluster_id="cluster_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = response.parse() + assert_matches_type(RemediationListResponse, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_list(self, client: Together) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.list( + instance_id="instance_id", + cluster_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.list( + instance_id="", + cluster_id="cluster_id", + ) + + @parametrize + def test_method_approve(self, client: Together) -> None: + remediation = client.beta.clusters.remediations.approve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_method_approve_with_all_params(self, client: Together) -> None: + remediation = client.beta.clusters.remediations.approve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + comment="comment", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_raw_response_approve(self, client: Together) -> None: + response = client.beta.clusters.remediations.with_raw_response.approve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_streaming_response_approve(self, client: Together) -> None: + with client.beta.clusters.remediations.with_streaming_response.approve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_approve(self, client: Together) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.approve( + remediation_id="remediation_id", + cluster_id="", + instance_id="instance_id", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.approve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `remediation_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.approve( + remediation_id="", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + @parametrize + def test_method_cancel(self, client: Together) -> None: + remediation = client.beta.clusters.remediations.cancel( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_raw_response_cancel(self, client: Together) -> None: + response = client.beta.clusters.remediations.with_raw_response.cancel( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_streaming_response_cancel(self, client: Together) -> None: + with client.beta.clusters.remediations.with_streaming_response.cancel( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_cancel(self, client: Together) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.cancel( + remediation_id="remediation_id", + cluster_id="", + instance_id="instance_id", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.cancel( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `remediation_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.cancel( + remediation_id="", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + @parametrize + def test_method_reject(self, client: Together) -> None: + remediation = client.beta.clusters.remediations.reject( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_method_reject_with_all_params(self, client: Together) -> None: + remediation = client.beta.clusters.remediations.reject( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + comment="comment", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_raw_response_reject(self, client: Together) -> None: + response = client.beta.clusters.remediations.with_raw_response.reject( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + def test_streaming_response_reject(self, client: Together) -> None: + with client.beta.clusters.remediations.with_streaming_response.reject( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_reject(self, client: Together) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.reject( + remediation_id="remediation_id", + cluster_id="", + instance_id="instance_id", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.reject( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `remediation_id` but received ''"): + client.beta.clusters.remediations.with_raw_response.reject( + remediation_id="", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + +class TestAsyncRemediations: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_create(self, async_client: AsyncTogether) -> None: + remediation = await async_client.beta.clusters.remediations.create( + instance_id="instance_id", + cluster_id="cluster_id", + mode="REMEDIATION_MODE_VM_ONLY", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_method_create_with_all_params(self, async_client: AsyncTogether) -> None: + remediation = await async_client.beta.clusters.remediations.create( + instance_id="instance_id", + cluster_id="cluster_id", + mode="REMEDIATION_MODE_VM_ONLY", + remediation_id="remediation_id", + reason="reason", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_raw_response_create(self, async_client: AsyncTogether) -> None: + response = await async_client.beta.clusters.remediations.with_raw_response.create( + instance_id="instance_id", + cluster_id="cluster_id", + mode="REMEDIATION_MODE_VM_ONLY", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = await response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_streaming_response_create(self, async_client: AsyncTogether) -> None: + async with async_client.beta.clusters.remediations.with_streaming_response.create( + instance_id="instance_id", + cluster_id="cluster_id", + mode="REMEDIATION_MODE_VM_ONLY", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = await response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_create(self, async_client: AsyncTogether) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.create( + instance_id="instance_id", + cluster_id="", + mode="REMEDIATION_MODE_VM_ONLY", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.create( + instance_id="", + cluster_id="cluster_id", + mode="REMEDIATION_MODE_VM_ONLY", + ) + + @parametrize + async def test_method_retrieve(self, async_client: AsyncTogether) -> None: + remediation = await async_client.beta.clusters.remediations.retrieve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncTogether) -> None: + response = await async_client.beta.clusters.remediations.with_raw_response.retrieve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = await response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncTogether) -> None: + async with async_client.beta.clusters.remediations.with_streaming_response.retrieve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = await response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, async_client: AsyncTogether) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.retrieve( + remediation_id="remediation_id", + cluster_id="", + instance_id="instance_id", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.retrieve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `remediation_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.retrieve( + remediation_id="", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + @parametrize + async def test_method_list(self, async_client: AsyncTogether) -> None: + remediation = await async_client.beta.clusters.remediations.list( + instance_id="instance_id", + cluster_id="cluster_id", + ) + assert_matches_type(RemediationListResponse, remediation, path=["response"]) + + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncTogether) -> None: + remediation = await async_client.beta.clusters.remediations.list( + instance_id="instance_id", + cluster_id="cluster_id", + mode=["REMEDIATION_MODE_VM_ONLY"], + order_by="order_by", + page_size=0, + page_token="page_token", + state=["PENDING_APPROVAL"], + trigger=["REMEDIATION_TRIGGER_MANUAL"], + ) + assert_matches_type(RemediationListResponse, remediation, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncTogether) -> None: + response = await async_client.beta.clusters.remediations.with_raw_response.list( + instance_id="instance_id", + cluster_id="cluster_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = await response.parse() + assert_matches_type(RemediationListResponse, remediation, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncTogether) -> None: + async with async_client.beta.clusters.remediations.with_streaming_response.list( + instance_id="instance_id", + cluster_id="cluster_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = await response.parse() + assert_matches_type(RemediationListResponse, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_list(self, async_client: AsyncTogether) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.list( + instance_id="instance_id", + cluster_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.list( + instance_id="", + cluster_id="cluster_id", + ) + + @parametrize + async def test_method_approve(self, async_client: AsyncTogether) -> None: + remediation = await async_client.beta.clusters.remediations.approve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_method_approve_with_all_params(self, async_client: AsyncTogether) -> None: + remediation = await async_client.beta.clusters.remediations.approve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + comment="comment", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_raw_response_approve(self, async_client: AsyncTogether) -> None: + response = await async_client.beta.clusters.remediations.with_raw_response.approve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = await response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_streaming_response_approve(self, async_client: AsyncTogether) -> None: + async with async_client.beta.clusters.remediations.with_streaming_response.approve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = await response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_approve(self, async_client: AsyncTogether) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.approve( + remediation_id="remediation_id", + cluster_id="", + instance_id="instance_id", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.approve( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `remediation_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.approve( + remediation_id="", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + @parametrize + async def test_method_cancel(self, async_client: AsyncTogether) -> None: + remediation = await async_client.beta.clusters.remediations.cancel( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_raw_response_cancel(self, async_client: AsyncTogether) -> None: + response = await async_client.beta.clusters.remediations.with_raw_response.cancel( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = await response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_streaming_response_cancel(self, async_client: AsyncTogether) -> None: + async with async_client.beta.clusters.remediations.with_streaming_response.cancel( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = await response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_cancel(self, async_client: AsyncTogether) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.cancel( + remediation_id="remediation_id", + cluster_id="", + instance_id="instance_id", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.cancel( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `remediation_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.cancel( + remediation_id="", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + @parametrize + async def test_method_reject(self, async_client: AsyncTogether) -> None: + remediation = await async_client.beta.clusters.remediations.reject( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_method_reject_with_all_params(self, async_client: AsyncTogether) -> None: + remediation = await async_client.beta.clusters.remediations.reject( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + comment="comment", + ) + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_raw_response_reject(self, async_client: AsyncTogether) -> None: + response = await async_client.beta.clusters.remediations.with_raw_response.reject( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + remediation = await response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + @parametrize + async def test_streaming_response_reject(self, async_client: AsyncTogether) -> None: + async with async_client.beta.clusters.remediations.with_streaming_response.reject( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="instance_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + remediation = await response.parse() + assert_matches_type(Remediation, remediation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_reject(self, async_client: AsyncTogether) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `cluster_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.reject( + remediation_id="remediation_id", + cluster_id="", + instance_id="instance_id", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `instance_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.reject( + remediation_id="remediation_id", + cluster_id="cluster_id", + instance_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `remediation_id` but received ''"): + await async_client.beta.clusters.remediations.with_raw_response.reject( + remediation_id="", + cluster_id="cluster_id", + instance_id="instance_id", + ) diff --git a/tests/api_resources/beta/clusters/test_storage.py b/tests/api_resources/beta/clusters/test_storage.py index 7a78f0d7a..de120cc40 100644 --- a/tests/api_resources/beta/clusters/test_storage.py +++ b/tests/api_resources/beta/clusters/test_storage.py @@ -30,6 +30,16 @@ def test_method_create(self, client: Together) -> None: ) assert_matches_type(ClusterStorage, storage, path=["response"]) + @parametrize + def test_method_create_with_all_params(self, client: Together) -> None: + storage = client.beta.clusters.storage.create( + region="region", + size_tib=0, + volume_name="volume_name", + is_lifecycle_independent=True, + ) + assert_matches_type(ClusterStorage, storage, path=["response"]) + @parametrize def test_raw_response_create(self, client: Together) -> None: response = client.beta.clusters.storage.with_raw_response.create( @@ -98,20 +108,24 @@ def test_path_params_retrieve(self, client: Together) -> None: @parametrize def test_method_update(self, client: Together) -> None: - storage = client.beta.clusters.storage.update() + storage = client.beta.clusters.storage.update( + volume_id="volume_id", + ) assert_matches_type(ClusterStorage, storage, path=["response"]) @parametrize def test_method_update_with_all_params(self, client: Together) -> None: storage = client.beta.clusters.storage.update( - size_tib=0, volume_id="volume_id", + size_tib=0, ) assert_matches_type(ClusterStorage, storage, path=["response"]) @parametrize def test_raw_response_update(self, client: Together) -> None: - response = client.beta.clusters.storage.with_raw_response.update() + response = client.beta.clusters.storage.with_raw_response.update( + volume_id="volume_id", + ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -120,7 +134,9 @@ def test_raw_response_update(self, client: Together) -> None: @parametrize def test_streaming_response_update(self, client: Together) -> None: - with client.beta.clusters.storage.with_streaming_response.update() as response: + with client.beta.clusters.storage.with_streaming_response.update( + volume_id="volume_id", + ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -134,6 +150,13 @@ def test_method_list(self, client: Together) -> None: storage = client.beta.clusters.storage.list() assert_matches_type(StorageListResponse, storage, path=["response"]) + @parametrize + def test_method_list_with_all_params(self, client: Together) -> None: + storage = client.beta.clusters.storage.list( + project_id="project_id", + ) + assert_matches_type(StorageListResponse, storage, path=["response"]) + @parametrize def test_raw_response_list(self, client: Together) -> None: response = client.beta.clusters.storage.with_raw_response.list() @@ -207,6 +230,16 @@ async def test_method_create(self, async_client: AsyncTogether) -> None: ) assert_matches_type(ClusterStorage, storage, path=["response"]) + @parametrize + async def test_method_create_with_all_params(self, async_client: AsyncTogether) -> None: + storage = await async_client.beta.clusters.storage.create( + region="region", + size_tib=0, + volume_name="volume_name", + is_lifecycle_independent=True, + ) + assert_matches_type(ClusterStorage, storage, path=["response"]) + @parametrize async def test_raw_response_create(self, async_client: AsyncTogether) -> None: response = await async_client.beta.clusters.storage.with_raw_response.create( @@ -275,20 +308,24 @@ async def test_path_params_retrieve(self, async_client: AsyncTogether) -> None: @parametrize async def test_method_update(self, async_client: AsyncTogether) -> None: - storage = await async_client.beta.clusters.storage.update() + storage = await async_client.beta.clusters.storage.update( + volume_id="volume_id", + ) assert_matches_type(ClusterStorage, storage, path=["response"]) @parametrize async def test_method_update_with_all_params(self, async_client: AsyncTogether) -> None: storage = await async_client.beta.clusters.storage.update( - size_tib=0, volume_id="volume_id", + size_tib=0, ) assert_matches_type(ClusterStorage, storage, path=["response"]) @parametrize async def test_raw_response_update(self, async_client: AsyncTogether) -> None: - response = await async_client.beta.clusters.storage.with_raw_response.update() + response = await async_client.beta.clusters.storage.with_raw_response.update( + volume_id="volume_id", + ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -297,7 +334,9 @@ async def test_raw_response_update(self, async_client: AsyncTogether) -> None: @parametrize async def test_streaming_response_update(self, async_client: AsyncTogether) -> None: - async with async_client.beta.clusters.storage.with_streaming_response.update() as response: + async with async_client.beta.clusters.storage.with_streaming_response.update( + volume_id="volume_id", + ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -311,6 +350,13 @@ async def test_method_list(self, async_client: AsyncTogether) -> None: storage = await async_client.beta.clusters.storage.list() assert_matches_type(StorageListResponse, storage, path=["response"]) + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncTogether) -> None: + storage = await async_client.beta.clusters.storage.list( + project_id="project_id", + ) + assert_matches_type(StorageListResponse, storage, path=["response"]) + @parametrize async def test_raw_response_list(self, async_client: AsyncTogether) -> None: response = await async_client.beta.clusters.storage.with_raw_response.list() diff --git a/tests/api_resources/beta/jig/test_queue.py b/tests/api_resources/beta/jig/test_queue.py index 3058edbc8..e3d779f96 100644 --- a/tests/api_resources/beta/jig/test_queue.py +++ b/tests/api_resources/beta/jig/test_queue.py @@ -124,7 +124,7 @@ def test_streaming_response_metrics(self, client: Together) -> None: @parametrize def test_method_submit(self, client: Together) -> None: queue = client.beta.jig.queue.submit( - model="model", + model="my-queue-model", payload={"foo": "bar"}, ) assert_matches_type(QueueSubmitResponse, queue, path=["response"]) @@ -132,7 +132,7 @@ def test_method_submit(self, client: Together) -> None: @parametrize def test_method_submit_with_all_params(self, client: Together) -> None: queue = client.beta.jig.queue.submit( - model="model", + model="my-queue-model", payload={"foo": "bar"}, info={"foo": "bar"}, priority=0, @@ -142,7 +142,7 @@ def test_method_submit_with_all_params(self, client: Together) -> None: @parametrize def test_raw_response_submit(self, client: Together) -> None: response = client.beta.jig.queue.with_raw_response.submit( - model="model", + model="my-queue-model", payload={"foo": "bar"}, ) @@ -154,7 +154,7 @@ def test_raw_response_submit(self, client: Together) -> None: @parametrize def test_streaming_response_submit(self, client: Together) -> None: with client.beta.jig.queue.with_streaming_response.submit( - model="model", + model="my-queue-model", payload={"foo": "bar"}, ) as response: assert not response.is_closed @@ -273,7 +273,7 @@ async def test_streaming_response_metrics(self, async_client: AsyncTogether) -> @parametrize async def test_method_submit(self, async_client: AsyncTogether) -> None: queue = await async_client.beta.jig.queue.submit( - model="model", + model="my-queue-model", payload={"foo": "bar"}, ) assert_matches_type(QueueSubmitResponse, queue, path=["response"]) @@ -281,7 +281,7 @@ async def test_method_submit(self, async_client: AsyncTogether) -> None: @parametrize async def test_method_submit_with_all_params(self, async_client: AsyncTogether) -> None: queue = await async_client.beta.jig.queue.submit( - model="model", + model="my-queue-model", payload={"foo": "bar"}, info={"foo": "bar"}, priority=0, @@ -291,7 +291,7 @@ async def test_method_submit_with_all_params(self, async_client: AsyncTogether) @parametrize async def test_raw_response_submit(self, async_client: AsyncTogether) -> None: response = await async_client.beta.jig.queue.with_raw_response.submit( - model="model", + model="my-queue-model", payload={"foo": "bar"}, ) @@ -303,7 +303,7 @@ async def test_raw_response_submit(self, async_client: AsyncTogether) -> None: @parametrize async def test_streaming_response_submit(self, async_client: AsyncTogether) -> None: async with async_client.beta.jig.queue.with_streaming_response.submit( - model="model", + model="my-queue-model", payload={"foo": "bar"}, ) as response: assert not response.is_closed diff --git a/tests/api_resources/beta/jig/test_volumes.py b/tests/api_resources/beta/jig/test_volumes.py index bc547d755..a630f5a83 100644 --- a/tests/api_resources/beta/jig/test_volumes.py +++ b/tests/api_resources/beta/jig/test_volumes.py @@ -9,7 +9,10 @@ from together import Together, AsyncTogether from tests.utils import assert_matches_type -from together.types.beta.jig import Volume, VolumeListResponse +from together.types.beta.jig import ( + Volume, + VolumeListResponse, +) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -21,7 +24,7 @@ class TestVolumes: def test_method_create(self, client: Together) -> None: volume = client.beta.jig.volumes.create( content={}, - name="name", + name="x", type="readOnly", ) assert_matches_type(Volume, volume, path=["response"]) @@ -33,7 +36,7 @@ def test_method_create_with_all_params(self, client: Together) -> None: "source_prefix": "models/", "type": "files", }, - name="name", + name="x", type="readOnly", ) assert_matches_type(Volume, volume, path=["response"]) @@ -42,7 +45,7 @@ def test_method_create_with_all_params(self, client: Together) -> None: def test_raw_response_create(self, client: Together) -> None: response = client.beta.jig.volumes.with_raw_response.create( content={}, - name="name", + name="x", type="readOnly", ) @@ -55,7 +58,7 @@ def test_raw_response_create(self, client: Together) -> None: def test_streaming_response_create(self, client: Together) -> None: with client.beta.jig.volumes.with_streaming_response.create( content={}, - name="name", + name="x", type="readOnly", ) as response: assert not response.is_closed @@ -69,14 +72,22 @@ def test_streaming_response_create(self, client: Together) -> None: @parametrize def test_method_retrieve(self, client: Together) -> None: volume = client.beta.jig.volumes.retrieve( - "id", + id="id", + ) + assert_matches_type(Volume, volume, path=["response"]) + + @parametrize + def test_method_retrieve_with_all_params(self, client: Together) -> None: + volume = client.beta.jig.volumes.retrieve( + id="id", + version=0, ) assert_matches_type(Volume, volume, path=["response"]) @parametrize def test_raw_response_retrieve(self, client: Together) -> None: response = client.beta.jig.volumes.with_raw_response.retrieve( - "id", + id="id", ) assert response.is_closed is True @@ -87,7 +98,7 @@ def test_raw_response_retrieve(self, client: Together) -> None: @parametrize def test_streaming_response_retrieve(self, client: Together) -> None: with client.beta.jig.volumes.with_streaming_response.retrieve( - "id", + id="id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -101,7 +112,7 @@ def test_streaming_response_retrieve(self, client: Together) -> None: def test_path_params_retrieve(self, client: Together) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): client.beta.jig.volumes.with_raw_response.retrieve( - "", + id="", ) @parametrize @@ -228,7 +239,7 @@ class TestAsyncVolumes: async def test_method_create(self, async_client: AsyncTogether) -> None: volume = await async_client.beta.jig.volumes.create( content={}, - name="name", + name="x", type="readOnly", ) assert_matches_type(Volume, volume, path=["response"]) @@ -240,7 +251,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncTogether) "source_prefix": "models/", "type": "files", }, - name="name", + name="x", type="readOnly", ) assert_matches_type(Volume, volume, path=["response"]) @@ -249,7 +260,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncTogether) async def test_raw_response_create(self, async_client: AsyncTogether) -> None: response = await async_client.beta.jig.volumes.with_raw_response.create( content={}, - name="name", + name="x", type="readOnly", ) @@ -262,7 +273,7 @@ async def test_raw_response_create(self, async_client: AsyncTogether) -> None: async def test_streaming_response_create(self, async_client: AsyncTogether) -> None: async with async_client.beta.jig.volumes.with_streaming_response.create( content={}, - name="name", + name="x", type="readOnly", ) as response: assert not response.is_closed @@ -276,14 +287,22 @@ async def test_streaming_response_create(self, async_client: AsyncTogether) -> N @parametrize async def test_method_retrieve(self, async_client: AsyncTogether) -> None: volume = await async_client.beta.jig.volumes.retrieve( - "id", + id="id", + ) + assert_matches_type(Volume, volume, path=["response"]) + + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncTogether) -> None: + volume = await async_client.beta.jig.volumes.retrieve( + id="id", + version=0, ) assert_matches_type(Volume, volume, path=["response"]) @parametrize async def test_raw_response_retrieve(self, async_client: AsyncTogether) -> None: response = await async_client.beta.jig.volumes.with_raw_response.retrieve( - "id", + id="id", ) assert response.is_closed is True @@ -294,7 +313,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncTogether) -> None: @parametrize async def test_streaming_response_retrieve(self, async_client: AsyncTogether) -> None: async with async_client.beta.jig.volumes.with_streaming_response.retrieve( - "id", + id="id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -308,7 +327,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncTogether) -> async def test_path_params_retrieve(self, async_client: AsyncTogether) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): await async_client.beta.jig.volumes.with_raw_response.retrieve( - "", + id="", ) @parametrize diff --git a/tests/api_resources/beta/test_clusters.py b/tests/api_resources/beta/test_clusters.py index e1f452619..4d84e3b2b 100644 --- a/tests/api_resources/beta/test_clusters.py +++ b/tests/api_resources/beta/test_clusters.py @@ -46,19 +46,70 @@ def test_method_create_with_all_params(self, client: Together) -> None: num_gpus=0, nvidia_driver_version="nvidia_driver_version", region="region", + acceptance_tests_params={ + "dcgm_diag_level": "DCGM_DIAG_LEVEL_SHORT", + "dcgm_diag_skipped": True, + "enabled": True, + "gpu_burn_duration": 0, + "gpu_burn_skipped": True, + "nccl_multi_node_skipped": True, + "nccl_single_node_skipped": True, + }, + add_ons=[ + { + "add_on_type": "add_on_type", + "name": "name", + "config": { + "dashboard": {"enabled": True}, + "ingress": {"enabled": True}, + }, + } + ], + auto_scale=True, auto_scale_max_gpus=0, auto_scaled=True, capacity_pool_id="capacity_pool_id", + cluster_config={ + "load_balancer": "NONE", + "gpu_operator_version": "gpu_operator_version", + "ingress": {"enabled": True}, + "jumphost_enabled": True, + "kubernetes_dashboard_enabled": True, + "observability": {"enabled": True}, + "slurm_startup_scripts": { + "controller_epilog": "controller_epilog", + "controller_prolog": "controller_prolog", + "extra_slurm_conf": "extra_slurm_conf", + "login_init_script": "login_init_script", + "nodeset_init_script": "nodeset_init_script", + "worker_epilog": "worker_epilog", + "worker_prolog": "worker_prolog", + }, + }, cluster_type="KUBERNETES", duration_days=0, gpu_node_failover_enabled=True, install_traefik=True, + num_capacity_pool_gpus=0, + num_preemptible_gpus=0, + num_reserved_gpus=0, + oidc_config={ + "client_id": "client_id", + "group_claim": "group_claim", + "group_prefix": "group_prefix", + "issuer_url": "issuer_url", + "username_claim": "username_claim", + "username_prefix": "username_prefix", + "ca_cert": "ca_cert", + }, + project_id="project_id", reservation_end_time=parse_datetime("2019-12-27T18:11:19.117Z"), reservation_start_time=parse_datetime("2019-12-27T18:11:19.117Z"), shared_volume={ "region": "region", "size_tib": 0, "volume_name": "volume_name", + "is_lifecycle_independent": True, }, slurm_image="slurm_image", slurm_shm_size_gib=0, @@ -151,8 +202,36 @@ def test_method_update(self, client: Together) -> None: def test_method_update_with_all_params(self, client: Together) -> None: cluster = client.beta.clusters.update( cluster_id="cluster_id", + add_ons=[ + { + "name": "name", + "config": { + "dashboard": {"enabled": True}, + "ingress": {"enabled": True}, + }, + } + ], + cluster_config={ + "load_balancer": "NONE", + "gpu_operator_version": "gpu_operator_version", + "ingress": {"enabled": True}, + "jumphost_enabled": True, + "kubernetes_dashboard_enabled": True, + "observability": {"enabled": True}, + "slurm_startup_scripts": { + "controller_epilog": "controller_epilog", + "controller_prolog": "controller_prolog", + "extra_slurm_conf": "extra_slurm_conf", + "login_init_script": "login_init_script", + "nodeset_init_script": "nodeset_init_script", + "worker_epilog": "worker_epilog", + "worker_prolog": "worker_prolog", + }, + }, cluster_type="KUBERNETES", num_gpus=0, + num_preemptible_gpus=0, + num_reserved_gpus=0, reservation_end_time=parse_datetime("2019-12-27T18:11:19.117Z"), ) assert_matches_type(Cluster, cluster, path=["response"]) @@ -193,6 +272,13 @@ def test_method_list(self, client: Together) -> None: cluster = client.beta.clusters.list() assert_matches_type(ClusterListResponse, cluster, path=["response"]) + @parametrize + def test_method_list_with_all_params(self, client: Together) -> None: + cluster = client.beta.clusters.list( + project_id="project_id", + ) + assert_matches_type(ClusterListResponse, cluster, path=["response"]) + @parametrize def test_raw_response_list(self, client: Together) -> None: response = client.beta.clusters.with_raw_response.list() @@ -305,19 +391,70 @@ async def test_method_create_with_all_params(self, async_client: AsyncTogether) num_gpus=0, nvidia_driver_version="nvidia_driver_version", region="region", + acceptance_tests_params={ + "dcgm_diag_level": "DCGM_DIAG_LEVEL_SHORT", + "dcgm_diag_skipped": True, + "enabled": True, + "gpu_burn_duration": 0, + "gpu_burn_skipped": True, + "nccl_multi_node_skipped": True, + "nccl_single_node_skipped": True, + }, + add_ons=[ + { + "add_on_type": "add_on_type", + "name": "name", + "config": { + "dashboard": {"enabled": True}, + "ingress": {"enabled": True}, + }, + } + ], + auto_scale=True, auto_scale_max_gpus=0, auto_scaled=True, capacity_pool_id="capacity_pool_id", + cluster_config={ + "load_balancer": "NONE", + "gpu_operator_version": "gpu_operator_version", + "ingress": {"enabled": True}, + "jumphost_enabled": True, + "kubernetes_dashboard_enabled": True, + "observability": {"enabled": True}, + "slurm_startup_scripts": { + "controller_epilog": "controller_epilog", + "controller_prolog": "controller_prolog", + "extra_slurm_conf": "extra_slurm_conf", + "login_init_script": "login_init_script", + "nodeset_init_script": "nodeset_init_script", + "worker_epilog": "worker_epilog", + "worker_prolog": "worker_prolog", + }, + }, cluster_type="KUBERNETES", duration_days=0, gpu_node_failover_enabled=True, install_traefik=True, + num_capacity_pool_gpus=0, + num_preemptible_gpus=0, + num_reserved_gpus=0, + oidc_config={ + "client_id": "client_id", + "group_claim": "group_claim", + "group_prefix": "group_prefix", + "issuer_url": "issuer_url", + "username_claim": "username_claim", + "username_prefix": "username_prefix", + "ca_cert": "ca_cert", + }, + project_id="project_id", reservation_end_time=parse_datetime("2019-12-27T18:11:19.117Z"), reservation_start_time=parse_datetime("2019-12-27T18:11:19.117Z"), shared_volume={ "region": "region", "size_tib": 0, "volume_name": "volume_name", + "is_lifecycle_independent": True, }, slurm_image="slurm_image", slurm_shm_size_gib=0, @@ -410,8 +547,36 @@ async def test_method_update(self, async_client: AsyncTogether) -> None: async def test_method_update_with_all_params(self, async_client: AsyncTogether) -> None: cluster = await async_client.beta.clusters.update( cluster_id="cluster_id", + add_ons=[ + { + "name": "name", + "config": { + "dashboard": {"enabled": True}, + "ingress": {"enabled": True}, + }, + } + ], + cluster_config={ + "load_balancer": "NONE", + "gpu_operator_version": "gpu_operator_version", + "ingress": {"enabled": True}, + "jumphost_enabled": True, + "kubernetes_dashboard_enabled": True, + "observability": {"enabled": True}, + "slurm_startup_scripts": { + "controller_epilog": "controller_epilog", + "controller_prolog": "controller_prolog", + "extra_slurm_conf": "extra_slurm_conf", + "login_init_script": "login_init_script", + "nodeset_init_script": "nodeset_init_script", + "worker_epilog": "worker_epilog", + "worker_prolog": "worker_prolog", + }, + }, cluster_type="KUBERNETES", num_gpus=0, + num_preemptible_gpus=0, + num_reserved_gpus=0, reservation_end_time=parse_datetime("2019-12-27T18:11:19.117Z"), ) assert_matches_type(Cluster, cluster, path=["response"]) @@ -452,6 +617,13 @@ async def test_method_list(self, async_client: AsyncTogether) -> None: cluster = await async_client.beta.clusters.list() assert_matches_type(ClusterListResponse, cluster, path=["response"]) + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncTogether) -> None: + cluster = await async_client.beta.clusters.list( + project_id="project_id", + ) + assert_matches_type(ClusterListResponse, cluster, path=["response"]) + @parametrize async def test_raw_response_list(self, async_client: AsyncTogether) -> None: response = await async_client.beta.clusters.with_raw_response.list() diff --git a/tests/api_resources/beta/test_jig.py b/tests/api_resources/beta/test_jig.py index 30f5e8303..dabf106ca 100644 --- a/tests/api_resources/beta/test_jig.py +++ b/tests/api_resources/beta/test_jig.py @@ -290,6 +290,8 @@ def test_method_retrieve_logs_with_all_params(self, client: Together) -> None: jig = client.beta.jig.retrieve_logs( id="id", replica_id="replica_id", + revision="revision", + version="version", ) assert_matches_type(DeploymentLogs, jig, path=["response"]) @@ -599,6 +601,8 @@ async def test_method_retrieve_logs_with_all_params(self, async_client: AsyncTog jig = await async_client.beta.jig.retrieve_logs( id="id", replica_id="replica_id", + revision="revision", + version="version", ) assert_matches_type(DeploymentLogs, jig, path=["response"]) diff --git a/tests/cli/test_beta_clusters.py b/tests/cli/test_beta_clusters.py index 5f2bf86ac..ead923348 100644 --- a/tests/cli/test_beta_clusters.py +++ b/tests/cli/test_beta_clusters.py @@ -53,6 +53,28 @@ def _cluster_body(cluster_id: str = "cluster-1", name: str = "my-cluster", **ove } +def _remediation_body(remediation_id: str = "rem-1", **overrides: Any) -> dict[str, Any]: + body: dict[str, Any] = { + "id": remediation_id, + "cluster_id": "c1", + "instance_id": "i1", + "mode": "REMEDIATION_MODE_VM_ONLY", + "state": "PENDING_APPROVAL", + "trigger": "REMEDIATION_TRIGGER_AUTOMATED", + "reason": "health check failed", + } + body.update(overrides) + return body + + +def _remediation_list_body(*remediations: dict[str, Any]) -> dict[str, Any]: + return { + "has_next": False, + "next_page_token": "", + "remediations": list(remediations), + } + + class TestBetaClustersList: @pytest.mark.respx(base_url=base_url) def test_list_table(self, respx_mock: MockRouter, cli_runner: CliRunner) -> None: @@ -238,3 +260,208 @@ def test_storage_delete_json(self, respx_mock: MockRouter, cli_runner: CliRunner result = cli_runner.invoke(["beta", "clusters", "storage", "delete", "vol-1", "--json"]) assert json.loads(result.output) == {"success": True} assert result.exit_code == 0 + + +class TestBetaClustersRemediations: + @pytest.mark.respx(base_url=base_url) + def test_remediations_create_json(self, respx_mock: MockRouter, cli_runner: CliRunner) -> None: + route = respx_mock.post("/compute/clusters/c1/instances/i1/remediations").mock( + return_value=httpx.Response(200, json=_remediation_body("rem-created", state="PENDING")) + ) + result = cli_runner.invoke( + [ + "beta", + "clusters", + "remediations", + "create", + "c1", + "i1", + "--mode", + "VM_ONLY", + "--reason", + "node unhealthy", + "--remediation-id", + "rem-created", + "--json", + ], + ) + + assert json.loads(result.output)["id"] == "rem-created" + request = cast(Call, route.calls[0]).request + assert request.url.params["remediation_id"] == "rem-created" + assert json.loads(request.content.decode()) == { + "mode": "REMEDIATION_MODE_VM_ONLY", + "reason": "node unhealthy", + } + assert result.exit_code == 0 + + @pytest.mark.respx(base_url=base_url) + def test_remediations_list_uses_wildcard_when_instance_id_omitted( + self, respx_mock: MockRouter, cli_runner: CliRunner + ) -> None: + payload = _remediation_list_body(_remediation_body()) + route = respx_mock.get("/compute/clusters/c1/instances/-/remediations").mock( + return_value=httpx.Response(200, json=payload) + ) + result = cli_runner.invoke(["beta", "clusters", "remediations", "list", "c1", "--json"]) + + assert json.loads(result.output) == payload + assert cast(Call, route.calls[0]).request.url.path == "/compute/clusters/c1/instances/-/remediations" + assert result.exit_code == 0 + + @pytest.mark.respx(base_url=base_url) + def test_remediations_list_accepts_instance_id(self, respx_mock: MockRouter, cli_runner: CliRunner) -> None: + payload = _remediation_list_body(_remediation_body()) + route = respx_mock.get("/compute/clusters/c1/instances/i1/remediations").mock( + return_value=httpx.Response(200, json=payload) + ) + result = cli_runner.invoke(["beta", "clusters", "remediations", "list", "c1", "i1", "--json"]) + + assert json.loads(result.output) == payload + assert cast(Call, route.calls[0]).request.url.path == "/compute/clusters/c1/instances/i1/remediations" + assert result.exit_code == 0 + + @pytest.mark.respx(base_url=base_url) + def test_remediations_list_table_uses_instance_name(self, respx_mock: MockRouter, cli_runner: CliRunner) -> None: + payload = _remediation_list_body(_remediation_body(instance_name="gpu-node-a")) + respx_mock.get("/compute/clusters/c1/instances/-/remediations").mock( + return_value=httpx.Response(200, json=payload) + ) + + result = cli_runner.invoke(["beta", "clusters", "remediations", "list", "c1"]) + + assert "gpu-node-a (i1)" in result.output + assert result.exit_code == 0 + + @pytest.mark.respx(base_url=base_url) + def test_remediations_list_table_falls_back_to_instance_id( + self, respx_mock: MockRouter, cli_runner: CliRunner + ) -> None: + payload = _remediation_list_body(_remediation_body()) + respx_mock.get("/compute/clusters/c1/instances/-/remediations").mock( + return_value=httpx.Response(200, json=payload) + ) + + result = cli_runner.invoke(["beta", "clusters", "remediations", "list", "c1"]) + + assert "i1" in result.output + assert result.exit_code == 0 + + @pytest.mark.respx(base_url=base_url) + def test_remediations_list_accepts_filters(self, respx_mock: MockRouter, cli_runner: CliRunner) -> None: + payload = _remediation_list_body(_remediation_body()) + route = respx_mock.get("/compute/clusters/c1/instances/-/remediations").mock( + return_value=httpx.Response(200, json=payload) + ) + result = cli_runner.invoke( + [ + "beta", + "clusters", + "remediations", + "list", + "c1", + "--mode", + "VM_ONLY", + "--mode", + "REBOOT_VM", + "--state", + "PENDING_APPROVAL", + "--trigger", + "AUTOMATED", + "--after", + "next-token", + "--json", + ] + ) + + params = cast(Call, route.calls[0]).request.url.params + assert params["mode"] == "REMEDIATION_MODE_VM_ONLY,REMEDIATION_MODE_REBOOT_VM" + assert params["state"] == "PENDING_APPROVAL" + assert params["trigger"] == "REMEDIATION_TRIGGER_AUTOMATED" + assert params["page_token"] == "next-token" + assert result.exit_code == 0 + + @pytest.mark.respx(base_url=base_url) + def test_remediations_retrieve_resolves_cluster_and_instance( + self, respx_mock: MockRouter, cli_runner: CliRunner + ) -> None: + body = _remediation_body("rem-get", state="RUNNING") + respx_mock.get("/compute/clusters").mock( + return_value=httpx.Response(200, json={"clusters": [_cluster_body("c1")]}) + ) + respx_mock.get("/compute/clusters/c1/instances/-/remediations").mock( + return_value=httpx.Response(200, json=_remediation_list_body(_remediation_body("rem-get"))) + ) + route = respx_mock.get("/compute/clusters/c1/instances/i1/remediations/rem-get").mock( + return_value=httpx.Response(200, json=body) + ) + + result = cli_runner.invoke(["beta", "clusters", "remediations", "get", "rem-get", "--json"]) + + assert json.loads(result.output) == body + assert cast(Call, route.calls[0]).request.url.path == "/compute/clusters/c1/instances/i1/remediations/rem-get" + assert result.exit_code == 0 + + @pytest.mark.respx(base_url=base_url) + def test_remediations_approve_resolves_cluster_and_instance( + self, respx_mock: MockRouter, cli_runner: CliRunner + ) -> None: + respx_mock.get("/compute/clusters").mock( + return_value=httpx.Response(200, json={"clusters": [_cluster_body("c1")]}) + ) + respx_mock.get("/compute/clusters/c1/instances/-/remediations").mock( + return_value=httpx.Response(200, json=_remediation_list_body(_remediation_body("rem-approve"))) + ) + route = respx_mock.post("/compute/clusters/c1/instances/i1/remediations/rem-approve/approve").mock( + return_value=httpx.Response(200, json=_remediation_body("rem-approve", state="PENDING")) + ) + + result = cli_runner.invoke( + ["beta", "clusters", "remediations", "approve", "rem-approve", "--comment", "go", "--json"] + ) + + assert json.loads(result.output)["state"] == "PENDING" + assert json.loads(cast(Call, route.calls[0]).request.content.decode()) == {"comment": "go"} + assert result.exit_code == 0 + + @pytest.mark.respx(base_url=base_url) + def test_remediations_cancel_resolves_cluster_and_instance( + self, respx_mock: MockRouter, cli_runner: CliRunner + ) -> None: + respx_mock.get("/compute/clusters").mock( + return_value=httpx.Response(200, json={"clusters": [_cluster_body("c1")]}) + ) + respx_mock.get("/compute/clusters/c1/instances/-/remediations").mock( + return_value=httpx.Response(200, json=_remediation_list_body(_remediation_body("rem-cancel"))) + ) + route = respx_mock.post("/compute/clusters/c1/instances/i1/remediations/rem-cancel/cancel").mock( + return_value=httpx.Response(200, json=_remediation_body("rem-cancel", state="CANCELLED")) + ) + + result = cli_runner.invoke(["beta", "clusters", "remediations", "cancel", "rem-cancel", "--json"]) + + assert json.loads(result.output)["state"] == "CANCELLED" + assert route.calls + assert result.exit_code == 0 + + @pytest.mark.respx(base_url=base_url) + def test_remediations_reject_resolves_cluster_and_instance( + self, respx_mock: MockRouter, cli_runner: CliRunner + ) -> None: + respx_mock.get("/compute/clusters").mock( + return_value=httpx.Response(200, json={"clusters": [_cluster_body("c1")]}) + ) + respx_mock.get("/compute/clusters/c1/instances/-/remediations").mock( + return_value=httpx.Response(200, json=_remediation_list_body(_remediation_body("rem-reject"))) + ) + route = respx_mock.post("/compute/clusters/c1/instances/i1/remediations/rem-reject/reject").mock( + return_value=httpx.Response(200, json=_remediation_body("rem-reject", state="CANCELLED")) + ) + + result = cli_runner.invoke( + ["beta", "clusters", "remediations", "reject", "rem-reject", "--comment", "skip", "--json"] + ) + + assert json.loads(result.output)["state"] == "CANCELLED" + assert json.loads(cast(Call, route.calls[0]).request.content.decode()) == {"comment": "skip"} + assert result.exit_code == 0 diff --git a/tests/cli/test_beta_jig.py b/tests/cli/test_beta_jig.py index 446de3d1b..35a89402a 100644 --- a/tests/cli/test_beta_jig.py +++ b/tests/cli/test_beta_jig.py @@ -219,6 +219,59 @@ def test_unset_missing_secret_message(self, tmp_path: Path, cli_runner: CliRunne assert result.exit_code == 0 +class TestBetaJigBuild: + def test_build_blocked_when_deploy_image_set(self, tmp_path: Path, cli_runner: CliRunner) -> None: + with patch.object(_jig_mod.Config, "__post_init__", _noop_config_post_init): + cfg = _jig_mod.Config( + model_name=_DEPLOY_NAME, + image=_jig_mod.ImageConfig(), + deploy=_jig_mod.DeployConfig(image="ghcr.io/org/prebuilt:latest"), + _path=tmp_path / "pyproject.toml", + _unique_name_hint="h", + ) + + def _find(*_args: Any): + return cfg + + with patch.object(_jig_mod.Config, "find", classmethod(_find)): + with _chdir(tmp_path): + result = cli_runner.invoke(["beta", "jig", "build"]) + assert result.exit_code == 1 + assert "deploy.image is set" in result.output + + +class TestBetaJigLogs: + @pytest.mark.respx(base_url=base_url) + def test_logs_forwards_sdk_filters(self, respx_mock: MockRouter, tmp_path: Path, cli_runner: CliRunner) -> None: + _write_jig_project(tmp_path) + route = respx_mock.get(f"/deployments/{_DEPLOY_NAME}/logs").mock( + return_value=httpx.Response(200, json={"lines": ["line 1", "line 2"]}) + ) + + with _chdir(tmp_path): + result = cli_runner.invoke( + [ + "beta", + "jig", + "logs", + "--replica-id", + "replica-1", + "--revision", + "revision-1", + "--image-version", + "v2", + ] + ) + + assert "line 1" in result.output + assert "line 2" in result.output + request = cast(Call, route.calls[0]).request + assert request.url.params["replica_id"] == "replica-1" + assert request.url.params["revision"] == "revision-1" + assert request.url.params["version"] == "v2" + assert result.exit_code == 0 + + class TestBetaJigVolumes: @pytest.mark.respx(base_url=base_url) def test_delete(self, respx_mock: MockRouter, tmp_path: Path, cli_runner: CliRunner) -> None: @@ -253,6 +306,21 @@ def test_describe_json(self, respx_mock: MockRouter, tmp_path: Path, cli_runner: assert json.loads(result.output) == payload assert result.exit_code == 0 + @pytest.mark.respx(base_url=base_url) + def test_describe_forwards_version(self, respx_mock: MockRouter, tmp_path: Path, cli_runner: CliRunner) -> None: + _write_jig_project(tmp_path) + route = respx_mock.get("/deployments/storage/volumes/v1").mock( + return_value=httpx.Response(200, json=_volume_api_body("v1", current_version=1)) + ) + + with _chdir(tmp_path): + result = cli_runner.invoke(["beta", "jig", "volumes", "describe", "--name", "v1", "--volume-version", "1"]) + + assert "Version" in result.output + request = cast(Call, route.calls[0]).request + assert request.url.params["version"] == "1" + assert result.exit_code == 0 + @pytest.mark.respx(base_url=base_url) def test_list_json(self, respx_mock: MockRouter, tmp_path: Path, cli_runner: CliRunner) -> None: _write_jig_project(tmp_path) diff --git a/tests/cli/test_evals.py b/tests/cli/test_evals.py index 319a95a61..0bc318f17 100644 --- a/tests/cli/test_evals.py +++ b/tests/cli/test_evals.py @@ -59,3 +59,41 @@ def test_status(self, respx_mock: MockRouter, cli_runner: CliRunner) -> None: result = cli_runner.invoke(["evals", "status", "eval-wf-1"]) assert result.exit_code == 0 assert "Status: completed" in result.output + + +class TestEvalsCreate: + @pytest.mark.respx(base_url=base_url) + def test_compare_passes_disable_position_bias_correction( + self, respx_mock: MockRouter, cli_runner: CliRunner + ) -> None: + route = respx_mock.post("/evaluation").mock( + return_value=httpx.Response(200, json={"workflow_id": "eval-wf-1", "status": "pending"}) + ) + + result = cli_runner.invoke( + [ + "evals", + "create", + "--type", + "compare", + "--judge-model", + "Qwen/Qwen3.5-9B", + "--judge-model-source", + "serverless", + "--judge-system-template", + "Choose the better response.", + "--input-data-file-path", + "file-123", + "--model-a-field", + "response_a", + "--model-b-field", + "response_b", + "--disable-position-bias-correction", + ] + ) + + assert result.exit_code == 0 + req = cast(Call, route.calls[0]).request + payload = json.loads(req.content) + assert payload["type"] == "compare" + assert payload["parameters"]["disable_position_bias_correction"] is True diff --git a/tests/cli/test_fine_tuning.py b/tests/cli/test_fine_tuning.py index 08d1b3d1f..f10c22692 100644 --- a/tests/cli/test_fine_tuning.py +++ b/tests/cli/test_fine_tuning.py @@ -3,12 +3,14 @@ import os import json import importlib +from typing import cast from pathlib import Path from unittest.mock import patch import httpx import pytest from respx import MockRouter +from respx.models import Call from tests.cli.utils import CliRunner @@ -70,6 +72,16 @@ "step": 5, } +_FT_METRICS_BODY = { + "metrics": [ + { + "global_step": 0, + "train_loss": 1.25, + "logged_at": "2024-01-01T00:00:00Z", + } + ] +} + class TestFineTuningList: @pytest.mark.respx(base_url=base_url) @@ -197,6 +209,40 @@ def test_list_checkpoints_empty_message(self, respx_mock: MockRouter, cli_runner assert "No checkpoints found" in result.output +class TestFineTuningListMetrics: + @pytest.mark.respx(base_url=base_url) + def test_list_metrics_json_includes_zero_step_filters(self, respx_mock: MockRouter, cli_runner: CliRunner) -> None: + route = respx_mock.get("/fine-tunes/ft-1/metrics").mock(return_value=httpx.Response(200, json=_FT_METRICS_BODY)) + + result = cli_runner.invoke( + [ + "fine-tuning", + "list-metrics", + "ft-1", + "--global-step-from", + "0", + "--global-step-to", + "0", + "--logged-at-from", + "2024-01-01T00:00:00+00:00", + "--logged-at-to", + "2024-01-02T00:00:00+00:00", + "--resolution", + "50", + "--json", + ] + ) + + assert result.exit_code == 0 + params = cast(Call, route.calls[0]).request.url.params + assert params["global_step_from"] == "0" + assert params["global_step_to"] == "0" + assert params["logged_at_from"] == "2024-01-01T00:00:00+00:00" + assert params["logged_at_to"] == "2024-01-02T00:00:00+00:00" + assert params["resolution"] == "50" + assert json.loads(result.output) == _FT_METRICS_BODY["metrics"] + + class TestFineTuningDownload: @pytest.mark.respx(base_url=base_url) def test_download_invokes_download_manager( diff --git a/tests/test_plots_engine.py b/tests/test_plots_engine.py new file mode 100644 index 000000000..f612cffc2 --- /dev/null +++ b/tests/test_plots_engine.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +import pytest + +from together.lib.cli.components.plots._engine import ( + _interpolate, + _uniform_grid, + render_line_chart, + render_sparklines, +) +from together.lib.cli.components.plot_finetune_metrics import _step_label + + +def constant_series(n: int = 5, value: float = 1.0) -> list[tuple[float, float]]: + return [(float(i), value) for i in range(n)] + + +# Shared deterministic series used by golden-output tests +_LOSS = [(float(i), 1.0 - i * 0.1) for i in range(10)] # 1.0 → 0.1 +_ACCURACY = [(float(i), 0.5 + i * 0.05) for i in range(10)] # 0.5 → 0.95 +_WIDE = [(float(i), 10.0**i) for i in range(5)] # 1, 10, 100, 1000, 10000 + +_LOSS_XS = [p[0] for p in _LOSS] +_LOSS_YS = [p[1] for p in _LOSS] +_ACCURACY_XS = [p[0] for p in _ACCURACY] +_ACCURACY_YS = [p[1] for p in _ACCURACY] +_WIDE_XS = [p[0] for p in _WIDE] +_WIDE_YS = [p[1] for p in _WIDE] + + +def _interp(xs: list[float], ys: list[float], x_grid: list[float]) -> list[float]: + """Helper: interpolate a single series onto x_grid.""" + return _interpolate(xs, {"s": ys}, x_grid)["s"] + + +class TestInterpolate: + def test_output_length_equals_grid(self) -> None: + xs = [float(i) for i in range(10)] + ys = [float(i) for i in range(10)] + x_grid = _uniform_grid(xs, 5) + result = _interp(xs, ys, x_grid) + assert len(result) == 5 + + def test_linear_data_interpolates_exactly(self) -> None: + xs = [0.0, 9.0] + ys = [0.0, 9.0] + x_grid = _uniform_grid(xs, 10) + result = _interp(xs, ys, x_grid) + # grid points are 0.0, 0.9, 1.8, ..., 8.1 — y=x so values match + assert result == pytest.approx(x_grid, abs=1e-9) # type: ignore[misc] + + def test_constant_series_stays_constant(self) -> None: + xs = [float(i) for i in range(20)] + ys = [7.0] * 20 + x_grid = _uniform_grid(xs, 10) + result = _interp(xs, ys, x_grid) + assert result == pytest.approx([7.0] * 10, abs=1e-9) # type: ignore[misc] + + def test_left_clamp(self) -> None: + xs = [5.0, 9.0] + ys = [99.0, 99.0] + x_grid = _uniform_grid([0.0, 9.0], 10) + result = _interp(xs, ys, x_grid) + assert result == [99.0] * 10 + + def test_right_clamp(self) -> None: + xs = [0.0, 2.0] + ys = [42.0, 42.0] + x_grid = _uniform_grid([0.0, 9.0], 10) + result = _interp(xs, ys, x_grid) + assert result == [42.0] * 10 + + def test_single_point_fills_all(self) -> None: + xs = [5.0] + ys = [3.14] + x_grid = _uniform_grid([0.0, 9.0], 8) + result = _interp(xs, ys, x_grid) + assert result == [3.14] * 8 + + def test_uniform_grid_length(self) -> None: + assert len(_uniform_grid([0.0, 10.0], 5)) == 5 + + def test_uniform_grid_endpoints(self) -> None: + grid = _uniform_grid([0.0, 9.0], 10) + assert grid[0] == pytest.approx(0.0) # type: ignore[misc] + assert grid[-1] == pytest.approx(9.0) # type: ignore[misc] + + +class TestRenderSparklines: + def test_empty_series_returns_no_data_message(self) -> None: + result = render_sparklines("loss", [], [], width=20) + assert result.plain == "No plottable data." + + def test_single_series_golden(self) -> None: + result = render_sparklines("loss", _LOSS_XS, _LOSS_YS, width=20) + assert result.plain == " loss ██▇▇▆▆▅▅▅▄▄▃▃▃▂▂▁▁ 1 → 0.1\n" + + def test_constant_series_golden(self) -> None: + _flat = constant_series(10, 5.0) + result = render_sparklines("flat", [p[0] for p in _flat], [p[1] for p in _flat], width=20) + assert result.plain == " flat 5 → 5\n" + + def test_single_point_golden(self) -> None: + result = render_sparklines("single", [0.0], [1.0], width=20) + assert result.plain == " single 1 → 1\n" + + def test_log_scale_golden(self) -> None: + result = render_sparklines("wide", _WIDE_XS, _WIDE_YS, width=20, y_log=True) + assert result.plain == " wide ▁▁▂▂▂▃▃▄▄▅▅▆▆▆▇▇███ 1 → 1e+04\n" # leading space = first sparkline block + + def test_label_width_truncates_with_ellipsis(self) -> None: + result = render_sparklines("verylongname", _LOSS_XS, _LOSS_YS, width=20, label_width=6) + # "verylongname" (12 chars) truncated to label_width=6: "ver..." + assert result.plain.startswith(" ver... ") + + def test_label_width_truncates_long_name_aligned(self) -> None: + # A name longer than label_width is truncated with ..., staying aligned + r1 = render_sparklines("loss", _LOSS_XS, _LOSS_YS, width=20, label_width=8) + r2 = render_sparklines("averylongmetricname", _LOSS_XS, _LOSS_YS, width=20, label_width=8) + assert r1.plain == " loss ██▇▇▆▆▅▅▅▄▄▃▃▃▂▂▁▁ 1 → 0.1\n" # right-justified + assert r2.plain == " avery... ██▇▇▆▆▅▅▅▄▄▃▃▃▂▂▁▁ 1 → 0.1\n" # truncated to 8 + + def test_aligned_across_calls(self) -> None: + # Pass the same label_width to both calls → sparklines start at the same column + shared_w = 8 + r1 = render_sparklines("loss", _LOSS_XS, _LOSS_YS, width=20, label_width=shared_w) + r2 = render_sparklines("accuracy", _ACCURACY_XS, _ACCURACY_YS, width=20, label_width=shared_w) + assert r1.plain == " loss ██▇▇▆▆▅▅▅▄▄▃▃▃▂▂▁▁ 1 → 0.1\n" # "loss" right-justified in 8 + assert r2.plain == " accuracy ▁▁▂▂▃▃▃▄▄▅▅▅▆▆▇▇██ 0.5 → 0.95\n" # "accuracy" fills 8 exactly + + @pytest.mark.parametrize( + "bad_value, expected", + [ + (float("-inf"), " loss ██▇▇▆▆▅▅▅▄ ▃▃▂▂▁▁ 1 → 0.1\n"), + (float("nan"), " loss ██▇▇▆▆▅▅▅▄ ▃▃▂▂▁▁ 1 → 0.1\n"), + (float("inf"), " loss ██▇▇▆▆▅▅▅▄██▃▃▂▂▁▁ 1 → 0.1\n"), + ], + ids=["neg_inf", "nan", "pos_inf"], + ) + def test_non_finite_rendered_as_extreme_block_golden(self, bad_value: float, expected: str) -> None: + # -inf/NaN → blank (bottom) block; +inf → █ (top) block. + xs = [float(i) for i in range(10)] + ys = [(1.0 - i * 0.1) if i != 5 else bad_value for i in range(10)] + result = render_sparklines("loss", xs, ys, width=20) + assert result.plain == expected + + +class TestRenderLineChart: + def test_empty_series_returns_no_data_message(self) -> None: + result = render_line_chart([], {}) + assert result.plain == "No plottable data." + + def test_single_series_golden(self) -> None: + result = render_line_chart( + _LOSS_XS, + {"loss": _LOSS_YS}, + width=20, + height=4, + n_xticks=3, + x_label=_step_label, + ) + assert result.plain == ( + " loss (0 – 9) 1 → 0.1\n" + " 1┼───╮ \n" + " 0.7┼ ╰─────╮ \n" + " 0.4┼ ╰─────╮ \n" + " 0.1┼ ╰─── \n" + " └┬─────────┬────────┬\n" + " 0 4 9\n" + ) + + def test_multi_series_golden(self) -> None: + # loss and accuracy share the same x-axis (steps 0–9) + result = render_line_chart( + _LOSS_XS, + {"loss": _LOSS_YS, "accuracy": _ACCURACY_YS}, + width=20, + height=4, + n_xticks=3, + x_label=_step_label, + ) + assert result.plain == ( + " loss (0 – 9) 1 → 0.1\n" + " accuracy (0 – 9) 0.5 → 0.95\n" + " 1┼───╮ ╭──── \n" + " 0.7┼ ╭───────────╯ \n" + " 0.4┼──╯ ╰─────╮ \n" + " 0.1┼ ╰─── \n" + " └┬─────────┬────────┬\n" + " 0 4 9\n" + ) + + def test_log_scale_golden(self) -> None: + result = render_line_chart( + _WIDE_XS, + {"metric": _WIDE_YS}, + width=20, + height=4, + n_xticks=3, + x_label=_step_label, + y_log=True, + ) + assert result.plain == ( + " metric (0 – 4) 1 → 1e+04\n" + " 1e+04┼ ╭──── \n" + " 464┼ ╭────╯ \n" + " 21.5┼ ╭───────╯ \n" + " 1┼─╯ \n" + " └┬─────────┬────────┬\n" + " 0 2 4\n" + ) + + def test_constant_series_golden(self) -> None: + _flat = constant_series(10, 42.0) + result = render_line_chart( + [p[0] for p in _flat], + {"flat": [p[1] for p in _flat]}, + width=20, + height=4, + x_label=_step_label, + ) + assert result.plain == ( + " flat (0 – 9) 42 → 42\n" + " 42┼ \n" + " 42┼ \n" + " 42┼ \n" + " 42┼─────────────────── \n" + " └┬─────────┬────────┬\n" + " 0 4 9\n" + ) + + def test_custom_x_label_golden(self) -> None: + result = render_line_chart( + _LOSS_XS, + {"m": _LOSS_YS}, + width=20, + height=4, + n_xticks=3, + x_label=lambda x: f"step{int(x)}", + ) + assert result.plain == ( + " m (step0 – step9) 1 → 0.1\n" + " 1┼───╮ \n" + " 0.7┼ ╰─────╮ \n" + " 0.4┼ ╰─────╮ \n" + " 0.1┼ ╰─── \n" + " └┬─────────┬────────┬\n" + " step0 step4 step9\n" + ) + + @pytest.mark.parametrize( + "bad_value, expected", + [ + ( + float("-inf"), + ( + " loss (0 – 9) 1 → 0.1\n" + " 1┼───╮ \n" + " 0.7┼ ╰─────╮ \n" + " 0.4┼ │ ╭───╮ \n" + " 0.1┼ │ │ ╰─── \n" + " └┬────────┴┬┴───────┬\n" + " 0 4 9\n" + ), + ), + ( + float("nan"), + ( + " loss (0 – 9) 1 → 0.1\n" + " 1┼───╮ \n" + " 0.7┼ ╰────── \n" + " 0.4┼ ────╮ \n" + " 0.1┼ ╰─── \n" + " └┬─────────┬────────┬\n" + " 0 4 9\n" + ), + ), + ( + float("inf"), + ( + " loss (0 – 9) 1 → 0.1\n" + " 1┼───╮ │ │ \n" + " 0.7┼ ╰─────╯ │ \n" + " 0.4┼ ╰───╮ \n" + " 0.1┼ ╰─── \n" + " └┬─────────┬────────┬\n" + " 0 4 9\n" + ), + ), + ], + ids=["neg_inf", "nan", "pos_inf"], + ) + def test_non_finite_rendered_as_extreme_golden(self, bad_value: float, expected: str) -> None: + # -inf/NaN → dip to x-axis border; +inf → spike to top data row. + xs = [float(i) for i in range(10)] + ys = [(1.0 - i * 0.1) if i != 5 else bad_value for i in range(10)] + result = render_line_chart(xs, {"loss": ys}, width=20, height=4, n_xticks=3, x_label=_step_label) + assert result.plain == expected + + def test_label_width_caps_y_axis(self) -> None: + # "1e+04" is exactly 5 chars; label_width=5 fits it without truncation + result = render_line_chart( + _WIDE_XS, + {"metric": _WIDE_YS}, + width=20, + height=4, + x_label=_step_label, + y_log=True, + label_width=5, + ) + assert result.plain == ( + " metric (0 – 4) 1 → 1e+04\n" + "1e+04┼ ╭──── \n" + " 464┼ ╭────╯ \n" + " 21.5┼ ╭───────╯ \n" + " 1┼─╯ \n" + " └┬─────────┬────────┬\n" + " 0 2 4\n" + ) diff --git a/uv.lock b/uv.lock index a10f2d912..abe93dbc1 100644 --- a/uv.lock +++ b/uv.lock @@ -1559,7 +1559,7 @@ wheels = [ [[package]] name = "together" -version = "2.12.0" +version = "2.14.0" source = { editable = "." } dependencies = [ { name = "anyio" },