From c5fdb8199c36a569e839e4385c7f7da04285d8e2 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 6 May 2026 21:47:30 +0200 Subject: [PATCH] Allow setting imported gateway as project default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``` $ dstack gateway update main/main-gateway --set-default NAME BACKEND HOSTNAME DOMAIN DEFAULT STATUS main/main-gateway aws (eu-west-1) 108.131.126.35 gtw.mycompany.example ✓ running ``` --- src/dstack/_internal/cli/commands/gateway.py | 51 +++++----- src/dstack/_internal/cli/utils/gateway.py | 28 ++++++ .../_internal/core/compatibility/gateways.py | 8 ++ src/dstack/_internal/server/models.py | 7 +- .../_internal/server/routers/gateways.py | 8 +- .../_internal/server/schemas/gateways.py | 3 +- .../server/services/gateways/__init__.py | 35 +++++-- src/dstack/api/server/_gateways.py | 20 +++- .../_internal/server/routers/test_gateways.py | 94 ++++++++++++++++++- .../_internal/server/routers/test_runs.py | 49 ++++++++++ 10 files changed, 259 insertions(+), 44 deletions(-) diff --git a/src/dstack/_internal/cli/commands/gateway.py b/src/dstack/_internal/cli/commands/gateway.py index 16ebfb6665..c226b3334b 100644 --- a/src/dstack/_internal/cli/commands/gateway.py +++ b/src/dstack/_internal/cli/commands/gateway.py @@ -12,11 +12,12 @@ console, ) from dstack._internal.cli.utils.gateway import ( + get_gateway_relative_to_project, get_gateways_table, print_gateways_json, print_gateways_table, ) -from dstack._internal.core.errors import CLIError, ResourceNotExistsError +from dstack._internal.core.errors import CLIError from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.gateways import GatewayStatus from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent @@ -167,30 +168,38 @@ def _delete(self, args: argparse.Namespace): return def _update(self, args: argparse.Namespace): - if args.name.project is not None: - console.print( - "The [code]/[/] format is not supported for gateway names." - " Can only update gateways owned by the current project" - ) - exit(1) - name = args.name.name with console.status("Updating gateway..."): - if args.set_default: - self.api.client.gateways.set_default(self.api.project, name) if args.domain: - self.api.client.gateways.set_wildcard_domain(self.api.project, name, args.domain) - gateway = self.api.client.gateways.get(self.api.project, name) + if args.name.project is not None: + console.print( + "The [code]/[/] format is not supported for gateway names" + " when [code]--domain[/] is passed." + " Can only update gateways owned by the current project" + ) + exit(1) + self.api.client.gateways.set_wildcard_domain( + self.api.project, args.name.name, args.domain + ) + if args.set_default: + self.api.client.gateways.set_default( + self.api.project, + gateway_name=args.name.name, + gateway_project=args.name.project, + ) + gateway = get_gateway_relative_to_project( + client=self.api.client.gateways, + project=self.api.project, + gateway_project=args.name.project or self.api.project, + gateway_name=args.name.name, + ) print_gateways_table([gateway], current_project=self.api.project) def _get(self, args: argparse.Namespace): # TODO: Implement non-json output format - try: - gateway = self.api.client.gateways.get( - project_name=args.name.project or self.api.project, - gateway_name=args.name.name, - ) - except ResourceNotExistsError: - console.print("Gateway not found") - exit(1) - + gateway = get_gateway_relative_to_project( + client=self.api.client.gateways, + project=self.api.project, + gateway_project=args.name.project or self.api.project, + gateway_name=args.name.name, + ) print(pydantic_orjson_dumps_with_indent(gateway.dict(), default=None)) diff --git a/src/dstack/_internal/cli/utils/gateway.py b/src/dstack/_internal/cli/utils/gateway.py index 0c35b3c371..0458798014 100644 --- a/src/dstack/_internal/cli/utils/gateway.py +++ b/src/dstack/_internal/cli/utils/gateway.py @@ -4,8 +4,36 @@ from dstack._internal.cli.models.gateways import GatewayCommandOutput from dstack._internal.cli.utils.common import add_row_from_dict, console, format_entity_reference +from dstack._internal.core.errors import ResourceNotExistsError +from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.gateways import Gateway from dstack._internal.utils.common import DateFormatter, pretty_date +from dstack.api.server._gateways import GatewaysAPIClient + + +def get_gateway_relative_to_project( + client: GatewaysAPIClient, project: str, gateway_project: str, gateway_name: str +) -> Gateway: + """ + Retrieves a single gateway, ensuring that `Gateway.default` is resolved relative to + `project` rather than relative to the gateway's host project. + """ + if project == gateway_project: + return client.get(project, gateway_name) + + # For imported gateways, use `list`. + # `get` would resolve `Gateway.default` relative to the gateway's host project + gateways = client.list(project, include_imported=True) + for gateway in gateways: + if gateway.name == gateway_name and ( + gateway_project == gateway.project_name + # Compatibility with pre-0.20.20 servers: + # gateway.project_name is None means the gateway is in the current `project` + or (gateway.project_name is None and gateway_project == project) + ): + return gateway + ref = EntityReference(name=gateway_name, project=gateway_project) + raise ResourceNotExistsError(msg=f"Gateway {ref.format()!r} not found in project {project!r}") def print_gateways_table(gateways: List[Gateway], current_project: str, verbose: bool = False): diff --git a/src/dstack/_internal/core/compatibility/gateways.py b/src/dstack/_internal/core/compatibility/gateways.py index 949d6515f8..a2fc6101e6 100644 --- a/src/dstack/_internal/core/compatibility/gateways.py +++ b/src/dstack/_internal/core/compatibility/gateways.py @@ -1,5 +1,6 @@ from dstack._internal.core.models.common import IncludeExcludeDictType from dstack._internal.core.models.gateways import GatewayConfiguration, GatewaySpec +from dstack._internal.server.schemas.gateways import SetDefaultGatewayRequest def get_gateway_spec_excludes(gateway_spec: GatewaySpec) -> IncludeExcludeDictType: @@ -26,6 +27,13 @@ def get_create_gateway_excludes(configuration: GatewayConfiguration) -> IncludeE return create_gateway_excludes +def get_set_default_gateway_excludes(request: SetDefaultGatewayRequest) -> IncludeExcludeDictType: + excludes: IncludeExcludeDictType = {} + if request.gateway_project is None: + excludes["gateway_project"] = True + return excludes + + def _get_gateway_configuration_excludes( configuration: GatewayConfiguration, ) -> IncludeExcludeDictType: diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index d6a3bb940b..1951368ca1 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -274,9 +274,10 @@ class ProjectModel(BaseModel): default_gateway_id: Mapped[Optional[uuid.UUID]] = mapped_column( ForeignKey("gateways.id", use_alter=True, ondelete="SET NULL"), nullable=True ) - default_gateway: Mapped[Optional["GatewayModel"]] = relationship( - foreign_keys=[default_gateway_id] - ) + """ + **NOTE**: default_gateway_id may point to a previously imported gateway that the project is no + longer authorized to use. Check access before using the gateway. + """ # TODO: drop `default_pool_id` after the release without pools. default_pool_id: Mapped[Optional[UUIDType]] = mapped_column( diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index 8056e782dc..e46697fbce 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -7,6 +7,7 @@ import dstack._internal.server.schemas.gateways as schemas import dstack._internal.server.services.gateways as gateways from dstack._internal.core.errors import ResourceNotExistsError +from dstack._internal.core.models.common import EntityReference from dstack._internal.server.db import get_session from dstack._internal.server.deps import Project from dstack._internal.server.models import ProjectModel, UserModel @@ -104,7 +105,12 @@ async def set_default_gateway( user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): user, project = user_project - await gateways.set_default_gateway(session=session, project=project, name=body.name, user=user) + await gateways.set_default_gateway( + session=session, + project=project, + ref=EntityReference(name=body.name, project=body.gateway_project), + user=user, + ) @router.post("/set_wildcard_domain", response_model=models.Gateway) diff --git a/src/dstack/_internal/server/schemas/gateways.py b/src/dstack/_internal/server/schemas/gateways.py index 71192453ee..4357c30430 100644 --- a/src/dstack/_internal/server/schemas/gateways.py +++ b/src/dstack/_internal/server/schemas/gateways.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from dstack._internal.core.models.common import CoreConfig, CoreModel, generate_dual_core_model from dstack._internal.core.models.gateways import GatewayConfiguration @@ -28,6 +28,7 @@ class DeleteGatewaysRequest(CoreModel): class SetDefaultGatewayRequest(CoreModel): name: str + gateway_project: Optional[str] = None class SetWildcardDomainRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 8abd28a128..e007b65b49 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -263,7 +263,10 @@ async def create_gateway( default_gateway = await get_project_default_gateway_model(session=session, project=project) if default_gateway is None or configuration.default: await set_default_gateway( - session=session, project=project, name=configuration.name, user=user + session=session, + project=project, + ref=EntityReference(name=configuration.name, project=None), + user=user, ) default_gateway = gateway pipeline_hinter.hint_fetch(GatewayModel.__name__) @@ -394,18 +397,18 @@ async def set_gateway_wildcard_domain( async def set_default_gateway( - session: AsyncSession, project: ProjectModel, name: str, user: Optional[UserModel] + session: AsyncSession, project: ProjectModel, ref: EntityReference, user: Optional[UserModel] ): gateway = await get_project_gateway_model_by_reference( - session=session, project=project, ref=EntityReference(name=name, project=None) + session=session, project=project, ref=ref ) if gateway is None: raise ResourceNotExistsError() if gateway.to_be_deleted: raise ServerClientError("Cannot set gateway marked for deletion as default") - if project.default_gateway_id == gateway.id: - return previous_gateway = await get_project_default_gateway_model(session, project) + if previous_gateway is not None and previous_gateway.id == gateway.id: + return await session.execute( update(ProjectModel) .where( @@ -418,15 +421,21 @@ async def set_default_gateway( if previous_gateway is not None: events.emit( session, - "Gateway unset as default", + "Gateway unset as project default", actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(), - targets=[events.Target.from_model(previous_gateway)], + targets=[ + events.Target.from_model(previous_gateway), + events.Target.from_model(project), + ], ) events.emit( session, - "Gateway set as default", + "Gateway set as project default", actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(), - targets=[events.Target.from_model(gateway)], + targets=[ + events.Target.from_model(gateway), + events.Target.from_model(project), + ], ) await session.commit() @@ -531,6 +540,14 @@ async def get_project_default_gateway_model( stmt = select(GatewayModel).where( GatewayModel.id == project.default_gateway_id, GatewayModel.to_be_deleted == False, + or_( + GatewayModel.project_id == project.id, + exists().where( + ImportModel.project_id == project.id, + ImportModel.export_id == ExportedGatewayModel.export_id, + ExportedGatewayModel.gateway_id == GatewayModel.id, + ), + ), ) if load_gateway_compute: stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) diff --git a/src/dstack/api/server/_gateways.py b/src/dstack/api/server/_gateways.py index 2952646767..6a325bdbde 100644 --- a/src/dstack/api/server/_gateways.py +++ b/src/dstack/api/server/_gateways.py @@ -1,8 +1,11 @@ -from typing import List +from typing import List, Optional from pydantic import parse_obj_as -from dstack._internal.core.compatibility.gateways import get_create_gateway_excludes +from dstack._internal.core.compatibility.gateways import ( + get_create_gateway_excludes, + get_set_default_gateway_excludes, +) from dstack._internal.core.models.gateways import Gateway, GatewayConfiguration from dstack._internal.server.schemas.gateways import ( CreateGatewayRequest, @@ -44,9 +47,16 @@ def delete(self, project_name: str, gateways_names: List[str]) -> None: body = DeleteGatewaysRequest(names=gateways_names) self._request(f"/api/project/{project_name}/gateways/delete", body=body.json()) - def set_default(self, project_name: str, gateway_name: str) -> None: - body = SetDefaultGatewayRequest(name=gateway_name) - self._request(f"/api/project/{project_name}/gateways/set_default", body=body.json()) + def set_default( + self, project_name: str, gateway_name: str, *, gateway_project: Optional[str] = None + ) -> None: + if gateway_project == project_name: + gateway_project = None # omit for compatibility with pre-0.20.20 servers + body = SetDefaultGatewayRequest(name=gateway_name, gateway_project=gateway_project) + self._request( + f"/api/project/{project_name}/gateways/set_default", + body=body.json(exclude=get_set_default_gateway_excludes(body)), + ) def set_wildcard_domain( self, project_name: str, gateway_name: str, wildcard_domain: str diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 28645788bd..a41fc6de01 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -615,7 +615,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: } events = await list_events(session) assert len(events) == 1 - assert events[0].message == "Gateway set as default" + assert events[0].message == "Gateway set as project default" second_gateway_compute = await create_gateway_compute( session=session, @@ -637,10 +637,10 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: assert response.status_code == 200 events = await list_events(session) assert len(events) == 2 - actual_events = [(e.targets[0].entity_name, e.message) for e in events] + actual_events = [({t.entity_name for t in e.targets}, e.message) for e in events] expected_events = [ - ("first_gateway", "Gateway unset as default"), - ("second_gateway", "Gateway set as default"), + ({"first_gateway", project.name}, "Gateway unset as project default"), + ({"second_gateway", project.name}, "Gateway set as project default"), ] assert ( actual_events == expected_events @@ -704,6 +704,92 @@ async def test_importer_member_cannot_set_default_imported_gateway( ) assert response.status_code == 403 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_set_imported_gateway_as_default( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + name="exported-gateway", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + response = await client.post( + f"/api/project/{importer_project.name}/gateways/set_default", + headers=get_auth_headers(importer_user.token), + json={"name": gateway.name, "gateway_project": exporter_project.name}, + ) + assert response.status_code == 200 + await session.refresh(importer_project) + assert importer_project.default_gateway_id == gateway.id + events = await list_events(session) + assert any(e.message == "Gateway set as project default" for e in events) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_cannot_set_non_imported_foreign_gateway_as_default( + self, test_db, session: AsyncSession, client: AsyncClient + ): + not_importer_user = await create_user( + session, name="not-importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + not_importer_project = await create_project( + session, name="not-importer-project", owner=not_importer_user + ) + await add_project_member( + session=session, + project=not_importer_project, + user=not_importer_user, + project_role=ProjectRole.ADMIN, + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + name="exported-gateway", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[], + exported_fleets=[], + exported_gateways=[gateway], + ) + response = await client.post( + f"/api/project/{not_importer_project.name}/gateways/set_default", + headers=get_auth_headers(not_importer_user.token), + json={"name": gateway.name, "gateway_project": exporter_project.name}, + ) + assert response.status_code == 400 + class TestDeleteGateway: @pytest.mark.asyncio diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 71d265b1d3..4e1afc24b6 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -3567,6 +3567,55 @@ async def test_submit_to_foreign_gateway_only_if_imported( ] } + @pytest.mark.asyncio + async def test_not_submits_to_default_gateway_if_not_imported( + self, test_db, session: AsyncSession, client: AsyncClient + ) -> None: + user = await create_user(session=session, global_role=GlobalRole.USER) + gateway_project = await create_project(session=session, owner=user, name="gateway-project") + backend = await create_backend(session=session, project_id=gateway_project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=gateway_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + ) + + service_project = await create_project(session=session, owner=user, name="service-project") + # The project's default_gateway_id may point to the gateway (e.g., if the gateway was + # imported previously), but that does not authorize the project to use this gateway if it + # is no longer imported. + service_project.default_gateway_id = gateway.id + await session.commit() + await add_project_member( + session=session, + project=service_project, + user=user, + project_role=ProjectRole.USER, + ) + repo = await create_repo(session=session, project_id=service_project.id) + + run_spec = get_service_run_spec( + repo_id=repo.name, + gateway=True, + ) + response = await client.post( + f"/api/project/{service_project.name}/runs/submit", + headers=get_auth_headers(user.token), + json={"run_spec": run_spec}, + ) + assert response.status_code == 400 + assert response.json() == { + "detail": [ + { + "msg": "The service requires a gateway, but there is no default gateway in the project", + "code": "resource_not_exists", + } + ] + } + @pytest.mark.asyncio async def test_unregister_dangling_service( self,