diff --git a/mkdocs/docs/reference/dstack.yml/fleet.md b/mkdocs/docs/reference/dstack.yml/fleet.md index b6d0553853..76312a52b9 100644 --- a/mkdocs/docs/reference/dstack.yml/fleet.md +++ b/mkdocs/docs/reference/dstack.yml/fleet.md @@ -2,76 +2,93 @@ The `fleet` configuration type allows creating and updating fleets. -## Root reference -#SCHEMA# dstack._internal.core.models.fleets.FleetConfiguration - overrides: - show_root_heading: false - type: - required: true +=== "Backend fleet" -### `ssh_config` { data-toc-label="ssh_config" } + ## Root reference -#SCHEMA# dstack._internal.core.models.fleets.SSHParams - overrides: - show_root_heading: false - item_id_prefix: ssh_config- + #SCHEMA# dstack._internal.core.models.fleets.BackendFleetConfiguration + overrides: + show_root_heading: false + type: + required: true + nodes: + required: true -#### `ssh_config.proxy_jump` { #ssh_config-proxy_jump data-toc-label="proxy_jump" } + ### `resources` -#SCHEMA# dstack._internal.core.models.fleets.SSHProxyParams - overrides: - show_root_heading: false - item_id_prefix: proxy_jump- + #SCHEMA# dstack._internal.core.models.resources.ResourcesSpec + overrides: + show_root_heading: false + type: + required: true + item_id_prefix: resources- -#### `ssh_config.hosts[n]` { #ssh_config-hosts data-toc-label="hosts" } + #### `resources.cpu` { #resources-cpu data-toc-label="cpu" } -#SCHEMA# dstack._internal.core.models.fleets.SSHHostParams - overrides: - show_root_heading: false + #SCHEMA# dstack._internal.core.models.resources.CPUSpec + overrides: + show_root_heading: false + type: + required: true -##### `ssh_config.hosts[n].proxy_jump` { #proxy_jump data-toc-label="hosts[n].proxy_jump" } + #### `resources.gpu` { #resources-gpu data-toc-label="gpu" } -#SCHEMA# dstack._internal.core.models.fleets.SSHProxyParams - overrides: - show_root_heading: false - item_id_prefix: hosts-proxy_jump- + #SCHEMA# dstack._internal.core.models.resources.GPUSpec + overrides: + show_root_heading: false + type: + required: true -### `resources` + #### `resources.disk` { #resources-disk data-toc-label="disk" } -#SCHEMA# dstack._internal.core.models.resources.ResourcesSpec - overrides: - show_root_heading: false - type: - required: true - item_id_prefix: resources- + #SCHEMA# dstack._internal.core.models.resources.DiskSpec + overrides: + show_root_heading: false + type: + required: true -#### `resources.cpu` { #resources-cpu data-toc-label="cpu" } + ### `retry` -#SCHEMA# dstack._internal.core.models.resources.CPUSpec - overrides: - show_root_heading: false - type: - required: true + #SCHEMA# dstack._internal.core.models.profiles.ProfileRetry + overrides: + show_root_heading: false -#### `resources.gpu` { #resources-gpu data-toc-label="gpu" } +=== "SSH fleet" -#SCHEMA# dstack._internal.core.models.resources.GPUSpec - overrides: - show_root_heading: false - type: - required: true + ## Root reference -#### `resources.disk` { #resources-disk data-toc-label="disk" } + #SCHEMA# dstack._internal.core.models.fleets.SSHFleetConfiguration + overrides: + show_root_heading: false + type: + required: true + ssh_config: + required: true -#SCHEMA# dstack._internal.core.models.resources.DiskSpec - overrides: - show_root_heading: false - type: - required: true + ### `ssh_config` { data-toc-label="ssh_config" } -### `retry` + #SCHEMA# dstack._internal.core.models.fleets.SSHParams + overrides: + show_root_heading: false + item_id_prefix: ssh_config- -#SCHEMA# dstack._internal.core.models.profiles.ProfileRetry - overrides: - show_root_heading: false + #### `ssh_config.proxy_jump` { #ssh_config-proxy_jump data-toc-label="proxy_jump" } + + #SCHEMA# dstack._internal.core.models.fleets.SSHProxyParams + overrides: + show_root_heading: false + item_id_prefix: proxy_jump- + + #### `ssh_config.hosts[n]` { #ssh_config-hosts data-toc-label="hosts" } + + #SCHEMA# dstack._internal.core.models.fleets.SSHHostParams + overrides: + show_root_heading: false + + ##### `ssh_config.hosts[n].proxy_jump` { #proxy_jump data-toc-label="hosts[n].proxy_jump" } + + #SCHEMA# dstack._internal.core.models.fleets.SSHProxyParams + overrides: + show_root_heading: false + item_id_prefix: hosts-proxy_jump- diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py index af8f045664..5824d2a1f0 100644 --- a/src/dstack/_internal/core/models/fleets.py +++ b/src/dstack/_internal/core/models/fleets.py @@ -84,7 +84,7 @@ class SSHHostParams(CoreModel): "The amount of blocks to split the instance into, a number or `auto`." " `auto` means as many as possible." " The number of GPUs and CPUs must be divisible by the number of blocks." - " Defaults to the top-level `blocks` value." + " Defaults to the top-level `blocks` value" ), ge=1, ), @@ -130,7 +130,7 @@ class SSHParams(CoreModel): " If not specified, `dstack` will use IPs from the first found internal network." ) ), - ] + ] = None @validator("network") def validate_network(cls, value): @@ -206,50 +206,13 @@ def _post_validate_ranges(cls, values): return values -class InstanceGroupParamsConfig(CoreConfig): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["nodes"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["idle_duration"], - extra_types=[{"type": "string"}], - ) - - -class InstanceGroupParams(CoreModel): - env: Annotated[ - Env, - Field(description="The mapping or the list of environment variables"), - ] = Env() - ssh_config: Annotated[ - Optional[SSHParams], - Field(description="The parameters for adding instances via SSH"), - ] = None - - nodes: Annotated[ - Optional[FleetNodesSpec], Field(description="The number of instances in cloud fleet") - ] = None +class CommonFleetConfigurationProps(CoreModel): + type: Literal["fleet"] = "fleet" + name: Annotated[Optional[str], Field(description="The fleet name")] = None placement: Annotated[ Optional[InstanceGroupPlacement], Field(description="The placement of instances: `any` or `cluster`"), ] = None - reservation: Annotated[ - Optional[str], - Field( - description=( - "The existing reservation to use for instance provisioning." - " Supports AWS Capacity Reservations, AWS Capacity Blocks, and GCP reservations" - ) - ), - ] = None - resources: Annotated[ - Optional[ResourcesSpec], - Field(description="The resources requirements"), - ] = None - blocks: Annotated[ Union[Literal["auto"], int], Field( @@ -263,6 +226,22 @@ class InstanceGroupParams(CoreModel): ), ] = 1 + +class BackendFleetConfiguraionProps(CoreModel): + nodes: Annotated[Optional[FleetNodesSpec], Field(description="The number of instances")] = None + reservation: Annotated[ + Optional[str], + Field( + description=( + "The existing reservation to use for instance provisioning." + " Supports AWS Capacity Reservations, AWS Capacity Blocks, and GCP reservations" + ) + ), + ] = None + resources: Annotated[ + Optional[ResourcesSpec], + Field(description="The resources requirements"), + ] = None backends: Annotated[ Optional[List[BackendType]], Field(description="The backends to consider for provisioning (e.g., `[aws, gcp]`)"), @@ -314,6 +293,16 @@ class InstanceGroupParams(CoreModel): ) ), ] = None + tags: Annotated[ + Optional[Dict[str, str]], + Field( + description=( + "The custom tags to associate with the resource." + " The tags are also propagated to the underlying backend resources." + " If there is a conflict with backend-level tags, does not override them" + ) + ), + ] = None @validator("nodes", pre=True) def parse_nodes(cls, v: Optional[Union[dict, str]]) -> Optional[dict]: @@ -329,35 +318,61 @@ def parse_nodes(cls, v: Optional[Union[dict, str]]) -> Optional[dict]: parse_idle_duration ) - -class FleetProps(CoreModel): - type: Literal["fleet"] = "fleet" - name: Annotated[Optional[str], Field(description="The fleet name")] = None + _validate_tags = validator("tags", pre=True, allow_reuse=True)(tags_validator) -class FleetConfigurationConfig(InstanceGroupParamsConfig): +class BackendFleetConfigurationPropsConfig(CoreConfig): @staticmethod def schema_extra(schema: Dict[str, Any]): - InstanceGroupParamsConfig.schema_extra(schema) + add_extra_schema_types( + schema["properties"]["nodes"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["idle_duration"], + extra_types=[{"type": "string"}], + ) + + +class SSHFleetConfigurationProps(CoreModel): + ssh_config: Annotated[ + Optional[SSHParams], + Field(description="The parameters for adding instances via SSH"), + ] = None + env: Annotated[ + Env, + Field(description="The mapping or the list of environment variables"), + ] = Env() + + +class FleetConfigurationConfig(BackendFleetConfigurationPropsConfig): + @staticmethod + def schema_extra(schema: dict[str, Any]): + BackendFleetConfigurationPropsConfig.schema_extra(schema) class FleetConfiguration( - InstanceGroupParams, - FleetProps, + SSHFleetConfigurationProps, + BackendFleetConfiguraionProps, + CommonFleetConfigurationProps, generate_dual_core_model(FleetConfigurationConfig), ): - tags: Annotated[ - Optional[Dict[str, str]], - Field( - description=( - "The custom tags to associate with the resource." - " The tags are also propagated to the underlying backend resources." - " If there is a conflict with backend-level tags, does not override them" - ) - ), - ] = None + pass - _validate_tags = validator("tags", pre=True, allow_reuse=True)(tags_validator) + +class BackendFleetConfiguration( + BackendFleetConfiguraionProps, + CommonFleetConfigurationProps, + generate_dual_core_model(BackendFleetConfigurationPropsConfig), +): + """For the documentation only""" + + +class SSHFleetConfiguration( + SSHFleetConfigurationProps, + CommonFleetConfigurationProps, +): + """For the documentation only""" class FleetSpecConfig(CoreConfig): diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 72633d8cee..1b2612ebc0 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -20,12 +20,14 @@ from dstack._internal.core.models.envs import Env from dstack._internal.core.models.fleets import ( ApplyFleetPlanInput, + BackendFleetConfiguraionProps, Fleet, FleetConfiguration, FleetPlan, FleetSpec, FleetStatus, InstanceGroupPlacement, + SSHFleetConfigurationProps, SSHHostParams, SSHParams, ) @@ -1370,10 +1372,7 @@ def _remove_fleet_spec_sensitive_info(spec: FleetSpec): def _validate_fleet_spec_and_set_defaults(spec: FleetSpec): if spec.configuration.name is not None: validate_dstack_resource_name(spec.configuration.name) - if spec.configuration.ssh_config is None and spec.configuration.nodes is None: - raise ServerClientError("No ssh_config or nodes specified") - if spec.configuration.ssh_config is not None and spec.configuration.nodes is not None: - raise ServerClientError("ssh_config and nodes are mutually exclusive") + _validate_fleet_configuration_subtype_specific_fields(spec.configuration) if spec.configuration.ssh_config is not None: _validate_all_ssh_params_specified(spec.configuration.ssh_config) if spec.configuration.ssh_config.ssh_key is not None: @@ -1385,6 +1384,31 @@ def _validate_fleet_spec_and_set_defaults(spec: FleetSpec): _set_fleet_spec_defaults(spec) +def _validate_fleet_configuration_subtype_specific_fields(conf: FleetConfiguration): + if conf.ssh_config is None and conf.nodes is None: + raise ServerClientError("No ssh_config or nodes specified") + if conf.ssh_config is not None and conf.nodes is not None: + raise ServerClientError("ssh_config and nodes are mutually exclusive") + subtype: str + props_model: type[CoreModel] + if conf.ssh_config is not None: + subtype = "SSH" + props_model = BackendFleetConfiguraionProps + else: + subtype = "Backend" + props_model = SSHFleetConfigurationProps + non_default_fields: list[str] = [] + for field in props_model.__fields__.values(): + if getattr(conf, field.name) != field.default: + non_default_fields.append(field.name) + if non_default_fields: + raise ServerClientError( + f"{subtype} fleet configuration does not support the following fields:" + f" {non_default_fields}" + ) + return conf + + def _set_fleet_spec_defaults(spec: FleetSpec): if spec.configuration.resources is not None: set_resources_defaults(spec.configuration.resources) diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index d5f9ffd95f..7748708a1a 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timezone -from typing import Literal, Optional, Union +from typing import Any, Literal, Optional, Union from unittest.mock import Mock, patch from uuid import uuid4 @@ -12,6 +12,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import EntityReference +from dstack._internal.core.models.envs import Env from dstack._internal.core.models.fleets import ( FleetConfiguration, FleetNodesSpec, @@ -1528,6 +1529,78 @@ async def test_errors_if_ssh_key_is_bad( ) assert response.status_code == 400 + @pytest.mark.parametrize( + ["field_name", "field_value"], + [ + pytest.param("backends", [BackendType.AWS], id="backends"), + pytest.param("regions", ["eu-west-1"], id="regions"), + pytest.param("instance_types", ["p3.8xlarge"], id="instance_types"), + pytest.param("idle_duration", 60, id="idle_duration"), + pytest.param("tags", {}, id="tags"), # falsy value + ], + ) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_errors_if_ssh_fleet_uses_backend_only_field( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + field_name: str, + field_value: Any, + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + conf = get_ssh_fleet_configuration(name="test-ssh-fleet", hosts=["1.1.1.1"]) + setattr(conf, field_name, field_value) + spec = get_fleet_spec(conf=conf) + response = await client.post( + f"/api/project/{project.name}/fleets/apply", + headers=get_auth_headers(user.token), + json={"plan": {"spec": spec.dict()}, "force": False}, + ) + assert response.status_code == 400, response.json() + assert response.json()["detail"][0]["msg"] == ( + f"SSH fleet configuration does not support the following fields: ['{field_name}']" + ) + + @pytest.mark.parametrize( + ["field_name", "field_value"], + [ + pytest.param("env", Env.parse_obj({"K": "V"}), id="env"), + ], + ) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_errors_if_backend_fleet_uses_ssh_only_field( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + field_name: str, + field_value: Any, + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + conf = get_fleet_configuration() + setattr(conf, field_name, field_value) + spec = get_fleet_spec(conf=conf) + response = await client.post( + f"/api/project/{project.name}/fleets/apply", + headers=get_auth_headers(user.token), + json={"plan": {"spec": spec.dict()}, "force": False}, + ) + assert response.status_code == 400, response.json() + assert response.json()["detail"][0]["msg"] == ( + f"Backend fleet configuration does not support the following fields: ['{field_name}']" + ) + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_forbids_if_no_permission_to_manage_ssh_fleets(