From ee332c051374ecce925820a2a193e9c12759bcb2 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 29 Apr 2026 19:40:05 +0200 Subject: [PATCH 1/9] Update Exports & Imports API --- src/dstack/_internal/core/models/exports.py | 6 + src/dstack/_internal/core/models/imports.py | 6 + ...0_db3679abd063_add_exportedgatewaymodel.py | 64 ++++++++ src/dstack/_internal/server/models.py | 25 +++ .../_internal/server/routers/exports.py | 3 + .../_internal/server/schemas/exports.py | 3 + .../_internal/server/services/exports.py | 73 ++++++++- .../_internal/server/services/imports.py | 19 ++- src/dstack/_internal/server/testing/common.py | 5 + .../_internal/server/routers/test_exports.py | 143 +++++++++++++++++- .../_internal/server/routers/test_imports.py | 22 +++ 11 files changed, 365 insertions(+), 4 deletions(-) create mode 100644 src/dstack/_internal/server/migrations/versions/2026/04_29_1700_db3679abd063_add_exportedgatewaymodel.py diff --git a/src/dstack/_internal/core/models/exports.py b/src/dstack/_internal/core/models/exports.py index 52cb4a65a0..ae215f6ec1 100644 --- a/src/dstack/_internal/core/models/exports.py +++ b/src/dstack/_internal/core/models/exports.py @@ -12,8 +12,14 @@ class ExportedFleet(CoreModel): name: str +class ExportedGateway(CoreModel): + id: uuid.UUID + name: str + + class Export(CoreModel): id: uuid.UUID name: str imports: list[ExportImport] exported_fleets: list[ExportedFleet] + exported_gateways: list[ExportedGateway] diff --git a/src/dstack/_internal/core/models/imports.py b/src/dstack/_internal/core/models/imports.py index 7a79bde7bc..d3c297a44e 100644 --- a/src/dstack/_internal/core/models/imports.py +++ b/src/dstack/_internal/core/models/imports.py @@ -8,11 +8,17 @@ class ImportExportedFleet(CoreModel): name: str +class ImportExportedGateway(CoreModel): + id: uuid.UUID + name: str + + class ImportExport(CoreModel): id: uuid.UUID name: str project_name: str exported_fleets: list[ImportExportedFleet] + exported_gateways: list[ImportExportedGateway] class Import(CoreModel): diff --git a/src/dstack/_internal/server/migrations/versions/2026/04_29_1700_db3679abd063_add_exportedgatewaymodel.py b/src/dstack/_internal/server/migrations/versions/2026/04_29_1700_db3679abd063_add_exportedgatewaymodel.py new file mode 100644 index 0000000000..85fa9c94b0 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/04_29_1700_db3679abd063_add_exportedgatewaymodel.py @@ -0,0 +1,64 @@ +"""Add ExportedGatewayModel + +Revision ID: db3679abd063 +Revises: 05c351d08f6b +Create Date: 2026-04-29 17:00:29.551669+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. +revision = "db3679abd063" +down_revision = "05c351d08f6b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "exported_gateways", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("export_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column( + "gateway_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.ForeignKeyConstraint( + ["export_id"], + ["exports.id"], + name=op.f("fk_exported_gateways_export_id_exports"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["gateway_id"], + ["gateways.id"], + name=op.f("fk_exported_gateways_gateway_id_gateways"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_exported_gateways")), + sa.UniqueConstraint( + "export_id", "gateway_id", name="uq_exported_gateways_export_id_gateway_id" + ), + ) + with op.batch_alter_table("exported_gateways", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("ix_exported_gateways_export_id"), ["export_id"], unique=False + ) + batch_op.create_index( + batch_op.f("ix_exported_gateways_gateway_id"), ["gateway_id"], unique=False + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("exported_gateways", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_exported_gateways_gateway_id")) + batch_op.drop_index(batch_op.f("ix_exported_gateways_export_id")) + + op.drop_table("exported_gateways") + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 565aec51f8..d6a3bb940b 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -1142,6 +1142,10 @@ class ExportModel(BaseModel): back_populates="export", cascade=CASCADE_DEFAULT_WITH_DELETE_ORPHAN, ) + exported_gateways: Mapped[List["ExportedGatewayModel"]] = relationship( + back_populates="export", + cascade=CASCADE_DEFAULT_WITH_DELETE_ORPHAN, + ) class ImportModel(BaseModel): @@ -1187,6 +1191,27 @@ class ExportedFleetModel(BaseModel): fleet: Mapped["FleetModel"] = relationship() +class ExportedGatewayModel(BaseModel): + __tablename__ = "exported_gateways" + __table_args__ = ( + UniqueConstraint( + "export_id", "gateway_id", name="uq_exported_gateways_export_id_gateway_id" + ), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + export_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("exports.id", ondelete="CASCADE"), index=True + ) + export: Mapped["ExportModel"] = relationship() + gateway_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("gateways.id", ondelete="CASCADE"), index=True + ) + gateway: Mapped["GatewayModel"] = relationship() + + class UserPublicKeyModel(BaseModel): __tablename__ = "user_public_keys" __table_args__ = ( diff --git a/src/dstack/_internal/server/routers/exports.py b/src/dstack/_internal/server/routers/exports.py index bc30ad822d..3d3f7a5055 100644 --- a/src/dstack/_internal/server/routers/exports.py +++ b/src/dstack/_internal/server/routers/exports.py @@ -36,6 +36,7 @@ async def create_export( name=body.name, importer_project_names=body.importer_projects, exported_fleet_names=body.exported_fleets, + exported_gateway_names=body.exported_gateways, ) @@ -55,6 +56,8 @@ async def update_export( remove_importer_project_names=body.remove_importer_projects, add_exported_fleet_names=body.add_exported_fleets, remove_exported_fleet_names=body.remove_exported_fleets, + add_exported_gateway_names=body.add_exported_gateways, + remove_exported_gateway_names=body.remove_exported_gateways, ) diff --git a/src/dstack/_internal/server/schemas/exports.py b/src/dstack/_internal/server/schemas/exports.py index 240b6364af..7f013c92ea 100644 --- a/src/dstack/_internal/server/schemas/exports.py +++ b/src/dstack/_internal/server/schemas/exports.py @@ -5,6 +5,7 @@ class CreateExportRequest(CoreModel): name: str importer_projects: list[str] = [] exported_fleets: list[str] = [] + exported_gateways: list[str] = [] class UpdateExportRequest(CoreModel): @@ -13,6 +14,8 @@ class UpdateExportRequest(CoreModel): remove_importer_projects: list[str] = [] add_exported_fleets: list[str] = [] remove_exported_fleets: list[str] = [] + add_exported_gateways: list[str] = [] + remove_exported_gateways: list[str] = [] class DeleteExportRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/exports.py b/src/dstack/_internal/server/services/exports.py index 4c47187dd7..e6e6ad0ba1 100644 --- a/src/dstack/_internal/server/services/exports.py +++ b/src/dstack/_internal/server/services/exports.py @@ -11,20 +11,28 @@ ResourceNotExistsError, ServerClientError, ) -from dstack._internal.core.models.exports import Export, ExportedFleet, ExportImport +from dstack._internal.core.models.exports import ( + Export, + ExportedFleet, + ExportedGateway, + ExportImport, +) from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.services import validate_dstack_resource_name from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( ExportedFleetModel, + ExportedGatewayModel, ExportModel, FleetModel, + GatewayModel, ImportModel, ProjectModel, ProjectRole, UserModel, ) from dstack._internal.server.services.fleets import get_fleet_spec, list_project_fleet_models +from dstack._internal.server.services.gateways import list_project_gateway_models from dstack._internal.server.services.locking import get_locker, string_to_lock_id from dstack._internal.server.services.projects import ( get_user_project_role, @@ -73,6 +81,9 @@ async def get_export_model_by_name_for_update( ) .joinedload(ExportedFleetModel.fleet) .load_only(FleetModel.name), + selectinload(ExportModel.exported_gateways) + .joinedload(ExportedGatewayModel.gateway) + .load_only(GatewayModel.name), ) .with_for_update(key_share=True) ) @@ -95,6 +106,7 @@ async def create_export( name: str, importer_project_names: list[str], exported_fleet_names: list[str], + exported_gateway_names: list[str], ) -> Export: validate_dstack_resource_name(name) @@ -118,9 +130,11 @@ async def create_export( project=project, imports=[], exported_fleets=[], + exported_gateways=[], ) await add_importer_projects(session, user, export, importer_project_names) await add_exported_fleets(session, export, exported_fleet_names) + await add_exported_gateways(session, export, exported_gateway_names) session.add(export) await session.commit() return export_model_to_export(export) @@ -135,6 +149,8 @@ async def update_export( remove_importer_project_names: list[str], add_exported_fleet_names: list[str], remove_exported_fleet_names: list[str], + add_exported_gateway_names: list[str], + remove_exported_gateway_names: list[str], ) -> Export: async with get_export_model_by_name_for_update(session, project, name) as export: if export is None: @@ -145,6 +161,8 @@ async def update_export( and not remove_importer_project_names and not add_exported_fleet_names and not remove_exported_fleet_names + and not add_exported_gateway_names + and not remove_exported_gateway_names ): raise ServerClientError("No changes specified") @@ -167,11 +185,21 @@ async def update_export( f"Fleets {add_remove_conflict_fleets} are listed for both addition and removal." " Cannot add and remove at the same time" ) + add_remove_conflict_gateways = set(add_exported_gateway_names) & set( + remove_exported_gateway_names + ) + if add_remove_conflict_gateways: + raise ServerClientError( + f"Gateways {add_remove_conflict_gateways} are listed for both addition and removal." + " Cannot add and remove at the same time" + ) await add_importer_projects(session, user, export, add_importer_project_names) await add_exported_fleets(session, export, add_exported_fleet_names) + await add_exported_gateways(session, export, add_exported_gateway_names) await remove_importer_projects(export, remove_importer_project_names) await remove_exported_fleets(export, remove_exported_fleet_names) + await remove_exported_gateways(export, remove_exported_gateway_names) await session.commit() return export_model_to_export(export) @@ -259,6 +287,39 @@ async def remove_exported_fleets(export: ExportModel, names: list[str]) -> None: export.exported_fleets = [ef for ef in export.exported_fleets if ef.fleet.name not in names] +async def add_exported_gateways( + session: AsyncSession, export: ExportModel, names: list[str] +) -> None: + if not names: + return + if len(names) != len(set(names)): + raise ServerClientError("Some gateways are listed for addition more than once") + already_exported = {eg.gateway.name for eg in export.exported_gateways} & set(names) + if already_exported: + raise ServerClientError( + f"Gateways {already_exported} are already exported by export {export.name!r}" + ) + gateways = await list_project_gateway_models(session=session, project=export.project) + gateways = [g for g in gateways if g.name in names] + if missing := set(names) - {g.name for g in gateways}: + raise ResourceNotExistsError( + f"Gateways {missing} not found in project {export.project.name!r}" + ) + for gateway in gateways: + export.exported_gateways.append(ExportedGatewayModel(gateway=gateway)) + + +async def remove_exported_gateways(export: ExportModel, names: list[str]) -> None: + if len(names) != len(set(names)): + raise ServerClientError("Some gateways are listed for removal more than once") + existing = {eg.gateway.name for eg in export.exported_gateways} + if missing := set(names) - existing: + raise ServerClientError(f"Gateways {missing} are not exported by export {export.name!r}") + export.exported_gateways = [ + eg for eg in export.exported_gateways if eg.gateway.name not in names + ] + + async def delete_export(session: AsyncSession, project: ProjectModel, name: str) -> None: async with get_export_model_by_name_for_update(session, project, name) as export: if export is None: @@ -284,6 +345,9 @@ async def list_exports(session: AsyncSession, project: ProjectModel) -> list[Exp ) .joinedload(ExportedFleetModel.fleet) .load_only(FleetModel.name), + selectinload(ExportModel.exported_gateways) + .joinedload(ExportedGatewayModel.gateway) + .load_only(GatewayModel.name), ) .order_by(ExportModel.created_at.desc()) ) @@ -308,4 +372,11 @@ def export_model_to_export(export_model: ExportModel) -> Export: ) for exported_fleet_model in export_model.exported_fleets ], + exported_gateways=[ + ExportedGateway( + id=exported_gateway_model.gateway.id, + name=exported_gateway_model.gateway.name, + ) + for exported_gateway_model in export_model.exported_gateways + ], ) diff --git a/src/dstack/_internal/server/services/imports.py b/src/dstack/_internal/server/services/imports.py index f6d8b6c7a8..cee432764a 100644 --- a/src/dstack/_internal/server/services/imports.py +++ b/src/dstack/_internal/server/services/imports.py @@ -3,11 +3,18 @@ from sqlalchemy.orm import joinedload, selectinload from dstack._internal.core.errors import ResourceNotExistsError -from dstack._internal.core.models.imports import Import, ImportExport, ImportExportedFleet +from dstack._internal.core.models.imports import ( + Import, + ImportExport, + ImportExportedFleet, + ImportExportedGateway, +) from dstack._internal.server.models import ( ExportedFleetModel, + ExportedGatewayModel, ExportModel, FleetModel, + GatewayModel, ImportModel, ProjectModel, ) @@ -31,6 +38,9 @@ async def list_imports(session: AsyncSession, project: ProjectModel) -> list[Imp ) .joinedload(ExportedFleetModel.fleet) .load_only(FleetModel.id, FleetModel.name), + selectinload(ExportModel.exported_gateways) + .joinedload(ExportedGatewayModel.gateway) + .load_only(GatewayModel.id, GatewayModel.name), ) ) .order_by(ImportModel.created_at.desc()) @@ -80,5 +90,12 @@ def import_model_to_import(import_model: ImportModel) -> Import: ) for ef in import_model.export.exported_fleets ], + exported_gateways=[ + ImportExportedGateway( + id=eg.gateway.id, + name=eg.gateway.name, + ) + for eg in import_model.export.exported_gateways + ], ), ) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 8775c043c0..65d81946aa 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -100,6 +100,7 @@ DecryptedString, EventModel, ExportedFleetModel, + ExportedGatewayModel, ExportModel, FileArchiveModel, FleetModel, @@ -591,6 +592,7 @@ async def create_export( exporter_project: ProjectModel, importer_projects: list[ProjectModel], exported_fleets: list[FleetModel], + exported_gateways: Optional[list[GatewayModel]] = None, name: str = "test-export", ) -> ExportModel: export = ExportModel( @@ -598,6 +600,9 @@ async def create_export( project=exporter_project, imports=[ImportModel(project=project) for project in importer_projects], exported_fleets=[ExportedFleetModel(fleet=fleet) for fleet in exported_fleets], + exported_gateways=[ + ExportedGatewayModel(gateway=gateway) for gateway in (exported_gateways or []) + ], ) session.add(export) await session.commit() diff --git a/src/tests/_internal/server/routers/test_exports.py b/src/tests/_internal/server/routers/test_exports.py index dd41419fd3..6c6d0da65a 100644 --- a/src/tests/_internal/server/routers/test_exports.py +++ b/src/tests/_internal/server/routers/test_exports.py @@ -9,8 +9,10 @@ from dstack._internal.server.models import ExportModel from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( + create_backend, create_export, create_fleet, + create_gateway, create_project, create_user, get_auth_headers, @@ -87,6 +89,10 @@ async def test_creates_export( name="fleet1", spec=get_fleet_spec(get_ssh_fleet_configuration()), ) + backend = await create_backend(session=session, project_id=project.id) + await create_gateway( + session=session, project_id=project.id, backend_id=backend.id, name="gateway1" + ) response = await client.post( f"/api/project/{project.name}/exports/create", @@ -95,6 +101,7 @@ async def test_creates_export( "name": "test-export", "importer_projects": ["ImporterProject"], "exported_fleets": ["fleet1"], + "exported_gateways": ["gateway1"], }, ) assert response.status_code == 200 @@ -104,6 +111,8 @@ async def test_creates_export( assert export_response["imports"][0]["project_name"] == "ImporterProject" assert len(export_response["exported_fleets"]) == 1 assert export_response["exported_fleets"][0]["name"] == "fleet1" + assert len(export_response["exported_gateways"]) == 1 + assert export_response["exported_gateways"][0]["name"] == "gateway1" res = await session.execute(select(ExportModel).where(ExportModel.name == "test-export")) assert res.scalar() is not None @@ -158,6 +167,14 @@ async def test_creates_empty_export(self, session: AsyncSession, client: AsyncCl "Fleets {'nonexistent-fleet'} not found in project 'ExporterProject'", id="nonexistent-fleet", ), + pytest.param( + { + "name": "test-export", + "exported_gateways": ["nonexistent-gateway"], + }, + "Gateways {'nonexistent-gateway'} not found in project 'ExporterProject'", + id="nonexistent-gateway", + ), pytest.param( { "name": "test-export", @@ -177,6 +194,14 @@ async def test_creates_empty_export(self, session: AsyncSession, client: AsyncCl "Some fleets are listed for addition more than once", id="duplicate-fleet", ), + pytest.param( + { + "name": "test-export", + "exported_gateways": ["exported-gateway", "exported-gateway"], + }, + "Some gateways are listed for addition more than once", + id="duplicate-gateway", + ), pytest.param( { "name": "test-export", @@ -237,6 +262,13 @@ async def test_rejects_invalid_export( spec=get_fleet_spec(get_ssh_fleet_configuration()), ) await create_fleet(session=session, project=project, name="cloud-fleet") + backend = await create_backend(session=session, project_id=project.id) + await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="exported-gateway", + ) not_permitted_project = await create_project( session=session, name="NotPermittedProject", owner=user ) @@ -344,11 +376,19 @@ async def test_updates_export( name="fleet2", spec=get_fleet_spec(get_ssh_fleet_configuration()), ) + backend = await create_backend(session=session, project_id=project.id) + gateway1 = await create_gateway( + session=session, project_id=project.id, backend_id=backend.id, name="gateway1" + ) + gateway2 = await create_gateway( + session=session, project_id=project.id, backend_id=backend.id, name="gateway2" + ) export = await create_export( session=session, exporter_project=project, importer_projects=[other_project, another_project], exported_fleets=[fleet1, fleet2], + exported_gateways=[gateway1, gateway2], name="test-export", ) @@ -366,6 +406,9 @@ async def test_updates_export( name="fleet4", spec=get_fleet_spec(get_ssh_fleet_configuration()), ) + await create_gateway( + session=session, project_id=project.id, backend_id=backend.id, name="gateway3" + ) if importer_project_role is not None: await add_project_member( session=session, project=new_project1, user=user, project_role=ProjectRole.ADMIN @@ -383,6 +426,8 @@ async def test_updates_export( "remove_importer_projects": ["AnotherProject"], "add_exported_fleets": ["fleet3", "fleet4"], "remove_exported_fleets": ["fleet2"], + "add_exported_gateways": ["gateway3"], + "remove_exported_gateways": ["gateway2"], }, ) assert response.status_code == 200 @@ -401,10 +446,16 @@ async def test_updates_export( "fleet3", "fleet4", } + assert len(export_response["exported_gateways"]) == 2 + assert {g["name"] for g in export_response["exported_gateways"]} == { + "gateway1", + "gateway3", + } - await session.refresh(export, ["imports", "exported_fleets"]) + await session.refresh(export, ["imports", "exported_fleets", "exported_gateways"]) assert len(export.imports) == 3 assert len(export.exported_fleets) == 3 + assert len(export.exported_gateways) == 2 response = await client.post( f"/api/project/{project.name}/exports/list", headers=get_auth_headers(user.token) @@ -416,6 +467,8 @@ async def test_updates_export( export_list[0]["imports"].sort(key=lambda i: i["project_name"]) export_response["exported_fleets"].sort(key=lambda f: f["name"]) export_list[0]["exported_fleets"].sort(key=lambda f: f["name"]) + export_response["exported_gateways"].sort(key=lambda g: g["name"]) + export_list[0]["exported_gateways"].sort(key=lambda g: g["name"]) assert export_list[0] == export_response async def test_can_add_same_entities_as_existing_deleted_ones( @@ -530,6 +583,14 @@ async def test_can_add_same_entities_as_existing_deleted_ones( "Fleets {'nonexistent-fleet'} not found in project 'ExporterProject'", id="add-nonexistent-fleet", ), + pytest.param( + { + "name": "test-export", + "add_exported_gateways": ["nonexistent-gateway"], + }, + "Gateways {'nonexistent-gateway'} not found in project 'ExporterProject'", + id="add-nonexistent-gateway", + ), pytest.param( { "name": "test-export", @@ -565,6 +626,22 @@ async def test_can_add_same_entities_as_existing_deleted_ones( "Some fleets are listed for addition more than once", id="add-duplicate-fleet", ), + pytest.param( + { + "name": "test-export", + "add_exported_gateways": ["exported-gateway"], + }, + "Gateways {'exported-gateway'} are already exported by export 'test-export'", + id="add-already-added-gateway", + ), + pytest.param( + { + "name": "test-export", + "add_exported_gateways": ["not-exported-gateway", "not-exported-gateway"], + }, + "Some gateways are listed for addition more than once", + id="add-duplicate-gateway", + ), pytest.param( { "name": "test-export", @@ -613,6 +690,22 @@ async def test_can_add_same_entities_as_existing_deleted_ones( "Fleets {'nonexistent-fleet'} are not exported by export 'test-export'", id="remove-nonexistent-fleet", ), + pytest.param( + { + "name": "test-export", + "remove_exported_gateways": ["not-exported-gateway"], + }, + "Gateways {'not-exported-gateway'} are not exported by export 'test-export'", + id="remove-not-exported-gateway", + ), + pytest.param( + { + "name": "test-export", + "remove_exported_gateways": ["nonexistent-gateway"], + }, + "Gateways {'nonexistent-gateway'} are not exported by export 'test-export'", + id="remove-nonexistent-gateway", + ), pytest.param( { "name": "test-export", @@ -632,6 +725,14 @@ async def test_can_add_same_entities_as_existing_deleted_ones( "Some fleets are listed for removal more than once", id="remove-duplicate-fleet", ), + pytest.param( + { + "name": "test-export", + "remove_exported_gateways": ["exported-gateway", "exported-gateway"], + }, + "Some gateways are listed for removal more than once", + id="remove-duplicate-gateway", + ), pytest.param( { "name": "test-export", @@ -650,6 +751,15 @@ async def test_can_add_same_entities_as_existing_deleted_ones( "Fleets {'not-exported-fleet'} are listed for both addition and removal. Cannot add and remove at the same time", id="add-remove-same-fleet", ), + pytest.param( + { + "name": "test-export", + "add_exported_gateways": ["not-exported-gateway"], + "remove_exported_gateways": ["not-exported-gateway"], + }, + "Gateways {'not-exported-gateway'} are listed for both addition and removal. Cannot add and remove at the same time", + id="add-remove-same-gateway", + ), ], ) async def test_rejects_invalid_update( @@ -672,11 +782,25 @@ async def test_rejects_invalid_update( name="exported-fleet", spec=get_fleet_spec(get_ssh_fleet_configuration()), ) + backend = await create_backend(session=session, project_id=project.id) + exported_gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="exported-gateway", + ) + await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="not-exported-gateway", + ) await create_export( session=session, exporter_project=project, importer_projects=[importer_project], exported_fleets=[exported_fleet], + exported_gateways=[exported_gateway], name="test-export", ) await create_fleet(session=session, project=project, name="cloud-fleet") @@ -847,12 +971,23 @@ async def test_lists_exports( name="fleet2", spec=get_fleet_spec(get_ssh_fleet_configuration()), ) - for name, fleet in (("export1", fleet1), ("export2", fleet2)): + backend = await create_backend(session=session, project_id=project.id) + gateway1 = await create_gateway( + session=session, project_id=project.id, backend_id=backend.id, name="gateway1" + ) + gateway2 = await create_gateway( + session=session, project_id=project.id, backend_id=backend.id, name="gateway2" + ) + for name, fleet, gateway in ( + ("export1", fleet1, gateway1), + ("export2", fleet2, gateway2), + ): await create_export( session=session, exporter_project=project, importer_projects=[other_project], exported_fleets=[fleet], + exported_gateways=[gateway], name=name, ) @@ -870,12 +1005,16 @@ async def test_lists_exports( assert exports[0]["imports"][0]["project_name"] == "OtherProject" assert len(exports[0]["exported_fleets"]) == 1 assert exports[0]["exported_fleets"][0]["name"] == "fleet1" + assert len(exports[0]["exported_gateways"]) == 1 + assert exports[0]["exported_gateways"][0]["name"] == "gateway1" assert exports[1]["name"] == "export2" assert len(exports[1]["imports"]) == 1 assert exports[1]["imports"][0]["project_name"] == "OtherProject" assert len(exports[1]["exported_fleets"]) == 1 assert exports[1]["exported_fleets"][0]["name"] == "fleet2" + assert len(exports[1]["exported_gateways"]) == 1 + assert exports[1]["exported_gateways"][0]["name"] == "gateway2" @pytest.mark.parametrize( "global_role, project_role", diff --git a/src/tests/_internal/server/routers/test_imports.py b/src/tests/_internal/server/routers/test_imports.py index 1551c50317..d4ef4a931f 100644 --- a/src/tests/_internal/server/routers/test_imports.py +++ b/src/tests/_internal/server/routers/test_imports.py @@ -9,8 +9,10 @@ from dstack._internal.server.models import ExportModel, ImportModel from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( + create_backend, create_export, create_fleet, + create_gateway, create_project, create_user, get_auth_headers, @@ -195,11 +197,26 @@ async def test_lists_imports( name="fleet2", spec=get_fleet_spec(get_ssh_fleet_configuration()), ) + backend1 = await create_backend(session=session, project_id=exporter_project1.id) + gateway1 = await create_gateway( + session=session, + project_id=exporter_project1.id, + backend_id=backend1.id, + name="gateway1", + ) + backend2 = await create_backend(session=session, project_id=exporter_project2.id) + gateway2 = await create_gateway( + session=session, + project_id=exporter_project2.id, + backend_id=backend2.id, + name="gateway2", + ) await create_export( session=session, exporter_project=exporter_project1, importer_projects=[importer_project], exported_fleets=[fleet1], + exported_gateways=[gateway1], name="export1", ) await create_export( @@ -207,6 +224,7 @@ async def test_lists_imports( exporter_project=exporter_project2, importer_projects=[importer_project], exported_fleets=[fleet2], + exported_gateways=[gateway2], name="export2", ) @@ -223,11 +241,15 @@ async def test_lists_imports( assert imports[0]["export"]["project_name"] == "ExporterProject1" assert len(imports[0]["export"]["exported_fleets"]) == 1 assert imports[0]["export"]["exported_fleets"][0]["name"] == "fleet1" + assert len(imports[0]["export"]["exported_gateways"]) == 1 + assert imports[0]["export"]["exported_gateways"][0]["name"] == "gateway1" assert imports[1]["export"]["name"] == "export2" assert imports[1]["export"]["project_name"] == "ExporterProject2" assert len(imports[1]["export"]["exported_fleets"]) == 1 assert imports[1]["export"]["exported_fleets"][0]["name"] == "fleet2" + assert len(imports[1]["export"]["exported_gateways"]) == 1 + assert imports[1]["export"]["exported_gateways"][0]["name"] == "gateway2" @pytest.mark.parametrize( "global_role, project_role", From 1a08301c15e1542f5821ae0c6e6502ab3bbaf5ed Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 29 Apr 2026 22:12:56 +0200 Subject: [PATCH 2/9] Update Gateways API --- src/dstack/_internal/core/models/gateways.py | 2 + .../_internal/server/routers/gateways.py | 21 +- .../_internal/server/schemas/gateways.py | 4 + .../_internal/server/security/permissions.py | 28 ++ .../server/services/gateways/__init__.py | 57 +++- .../_internal/server/routers/test_gateways.py | 299 ++++++++++++++++++ 6 files changed, 396 insertions(+), 15 deletions(-) diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index b3fbadb844..ce28c62054 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -106,6 +106,8 @@ class Gateway(CoreModel): id: Optional[uuid.UUID] = None """`id` is only optional on the client side for compatibility with pre-0.20.7 servers.""" name: str + project_name: Optional[str] = None + """`project_name` is only optional on the client side for compatibility with pre-0.20.20 servers.""" configuration: GatewayConfiguration created_at: datetime.datetime status: GatewayStatus diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index af4557a449..8056e782dc 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession @@ -8,10 +8,13 @@ import dstack._internal.server.services.gateways as gateways from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.server.db import get_session +from dstack._internal.server.deps import Project from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.security.permissions import ( + Authenticated, ProjectAdmin, ProjectMemberOrPublicAccess, + check_can_access_gateway, ) from dstack._internal.server.services.pipelines import PipelineHinterProtocol, get_pipeline_hinter from dstack._internal.server.utils.routers import ( @@ -28,12 +31,19 @@ @router.post("/list", response_model=List[models.Gateway]) async def list_gateways( + body: Optional[schemas.ListGatewaysRequest] = None, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), ): _, project = user_project + if body is None: + body = schemas.ListGatewaysRequest() return CustomORJSONResponse( - await gateways.list_project_gateways(session=session, project=project) + await gateways.list_project_gateways( + session=session, + project=project, + include_imported=body.include_imported, + ) ) @@ -41,9 +51,12 @@ async def list_gateways( async def get_gateway( body: schemas.GetGatewayRequest, session: AsyncSession = Depends(get_session), - user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), + user: UserModel = Depends(Authenticated()), + project: ProjectModel = Depends(Project()), ): - _, project = user_project + await check_can_access_gateway( + session=session, user=user, gateway_project=project, gateway_name=body.name + ) gateway = await gateways.get_gateway_by_name(session=session, project=project, name=body.name) if gateway is None: raise ResourceNotExistsError() diff --git a/src/dstack/_internal/server/schemas/gateways.py b/src/dstack/_internal/server/schemas/gateways.py index 9c00caa2e3..71192453ee 100644 --- a/src/dstack/_internal/server/schemas/gateways.py +++ b/src/dstack/_internal/server/schemas/gateways.py @@ -14,6 +14,10 @@ class CreateGatewayRequest(generate_dual_core_model(CreateGatewayRequestConfig)) configuration: GatewayConfiguration +class ListGatewaysRequest(CoreModel): + include_imported: bool = False + + class GetGatewayRequest(CoreModel): name: str diff --git a/src/dstack/_internal/server/security/permissions.py b/src/dstack/_internal/server/security/permissions.py index a343152e6e..7d63d97560 100644 --- a/src/dstack/_internal/server/security/permissions.py +++ b/src/dstack/_internal/server/security/permissions.py @@ -12,7 +12,9 @@ from dstack._internal.server.db import get_session from dstack._internal.server.models import ( ExportedFleetModel, + ExportedGatewayModel, FleetModel, + GatewayModel, ImportModel, InstanceModel, MemberModel, @@ -310,6 +312,32 @@ async def check_can_access_fleet( raise error_forbidden() +async def check_can_access_gateway( + session: AsyncSession, + user: UserModel, + gateway_project: ProjectModel, + gateway_name: str, +) -> None: + if ( + user.global_role == GlobalRole.ADMIN + or get_user_project_role(user=user, project=gateway_project) is not None + ): + return + filters = [ + GatewayModel.project_id == gateway_project.id, + GatewayModel.name == gateway_name, + exists().where( + MemberModel.user_id == user.id, + MemberModel.project_id == ImportModel.project_id, + ImportModel.export_id == ExportedGatewayModel.export_id, + ExportedGatewayModel.gateway_id == GatewayModel.id, + ), + ] + res = await session.execute(select(func.count()).select_from(GatewayModel).where(*filters)) + if res.scalar_one() == 0: + raise error_forbidden() + + async def check_can_access_instance( session: AsyncSession, user: UserModel, diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 1298e452df..694eef7656 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -8,7 +8,7 @@ from typing import List, Optional, Sequence import httpx -from sqlalchemy import func, select, update +from sqlalchemy import exists, func, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -44,8 +44,10 @@ from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( BackendModel, + ExportedGatewayModel, GatewayComputeModel, GatewayModel, + ImportModel, ProjectModel, UserModel, ) @@ -125,14 +127,22 @@ def get_gateway_status_change_message( GATEWAY_CONFIGURE_DELAY = 3 -async def list_project_gateways(session: AsyncSession, project: ProjectModel) -> List[Gateway]: +async def list_project_gateways( + session: AsyncSession, + project: ProjectModel, + include_imported: bool = False, +) -> List[Gateway]: gateways = await list_project_gateway_models( session=session, project=project, + include_imported=include_imported, load_gateway_compute=True, load_backend_type=True, ) - return [gateway_model_to_gateway(g) for g in gateways] + return [ + gateway_model_to_gateway(g, default_gateway_id=project.default_gateway_id) + for g in gateways + ] async def get_gateway_by_name( @@ -147,7 +157,7 @@ async def get_gateway_by_name( ) if gateway is None: return None - return gateway_model_to_gateway(gateway) + return gateway_model_to_gateway(gateway, default_gateway_id=project.default_gateway_id) async def create_gateway_compute( @@ -254,6 +264,7 @@ async def create_gateway( await set_default_gateway( session=session, project=project, name=configuration.name, user=user ) + default_gateway = gateway pipeline_hinter.hint_fetch(GatewayModel.__name__) gateway = await get_project_gateway_model_by_name( session=session, @@ -263,7 +274,7 @@ async def create_gateway( load_backend_type=True, ) assert gateway is not None - return gateway_model_to_gateway(gateway) + return gateway_model_to_gateway(gateway, default_gateway_id=default_gateway.id) # NOTE: dstack Sky imports and uses this function @@ -378,7 +389,7 @@ async def set_gateway_wildcard_domain( targets=[events.Target.from_model(gateway)], ) await session.commit() - return gateway_model_to_gateway(gateway) + return gateway_model_to_gateway(gateway, default_gateway_id=project.default_gateway_id) async def set_default_gateway( @@ -420,16 +431,30 @@ async def set_default_gateway( async def list_project_gateway_models( session: AsyncSession, project: ProjectModel, + include_imported: bool = False, load_gateway_compute: bool = False, load_backend_type: bool = False, ) -> Sequence[GatewayModel]: - stmt = select(GatewayModel).where(GatewayModel.project_id == project.id) + stmt = select(GatewayModel) + if include_imported: + stmt = stmt.where( + or_( + GatewayModel.project_id == project.id, + exists().where( + ImportModel.project_id == project.id, + ImportModel.export_id == ExportedGatewayModel.export_id, + ExportedGatewayModel.gateway_id == GatewayModel.id, + ), + ) + ).options(joinedload(GatewayModel.project).load_only(ProjectModel.id, ProjectModel.name)) + else: + stmt = stmt.where(GatewayModel.project_id == project.id) if load_gateway_compute: stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) if load_backend_type: stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) res = await session.execute(stmt) - return res.scalars().all() + return res.unique().scalars().all() async def get_project_gateway_model_by_name( @@ -708,7 +733,15 @@ def get_gateway_compute_configuration( ) -def gateway_model_to_gateway(gateway_model: GatewayModel) -> Gateway: +def gateway_model_to_gateway( + gateway_model: GatewayModel, default_gateway_id: Optional[uuid.UUID] +) -> Gateway: + """ + Args: + gateway_model: Gateway model to convert + default_gateway_id: ID of the default gateway in the project where `gateway_model` is being + viewed. Can be different from `gateway_model.project` if the gateway is imported. + """ ip_address = "" instance_id = "" hostname = "" @@ -721,18 +754,20 @@ def gateway_model_to_gateway(gateway_model: GatewayModel) -> Gateway: backend_type = gateway_model.backend.type if gateway_model.backend.type == BackendType.DSTACK: backend_type = BackendType.AWS + is_default = default_gateway_id == gateway_model.id configuration = get_gateway_configuration(gateway_model) - configuration.default = gateway_model.project.default_gateway_id == gateway_model.id + configuration.default = is_default return Gateway( id=gateway_model.id, name=gateway_model.name, + project_name=gateway_model.project.name, ip_address=ip_address, instance_id=instance_id, hostname=hostname, backend=backend_type, region=gateway_model.region, wildcard_domain=gateway_model.wildcard_domain, - default=gateway_model.project.default_gateway_id == gateway_model.id, + default=is_default, created_at=gateway_model.created_at, status=gateway_model.status, status_message=gateway_model.status_message, diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 3577f9af30..304f06854c 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -10,6 +10,7 @@ from dstack._internal.server.testing.common import ( clear_events, create_backend, + create_export, create_gateway, create_gateway_compute, create_project, @@ -53,6 +54,7 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): assert response.json() == [ { "id": SomeUUID4Str(), + "project_name": project.name, "backend": backend.type.value, "created_at": response.json()[0]["created_at"], "default": False, @@ -107,6 +109,7 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): assert response.status_code == 200 assert response.json() == { "id": SomeUUID4Str(), + "project_name": project.name, "backend": backend.type.value, "created_at": response.json()["created_at"], "default": False, @@ -148,6 +151,181 @@ async def test_get_missing(self, test_db, session: AsyncSession, client: AsyncCl ) assert response.status_code == 400 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_list_returns_imported_gateway_with_include_imported( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + name="exported-gateway", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + response = await client.post( + f"/api/project/{importer_project.name}/gateways/list", + headers=get_auth_headers(importer_user.token), + json={"include_imported": True}, + ) + assert response.status_code == 200 + response_json = response.json() + assert len(response_json) == 1 + assert response_json[0]["name"] == "exported-gateway" + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_list_not_returns_imported_gateway_without_include_imported( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + name="exported-gateway", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + response = await client.post( + f"/api/project/{importer_project.name}/gateways/list", + headers=get_auth_headers(importer_user.token), + json={}, + ) + assert response.status_code == 200 + assert response.json() == [] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_get_returns_imported_gateway( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + name="exported-gateway", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + response = await client.post( + f"/api/project/{exporter_project.name}/gateways/get", + headers=get_auth_headers(importer_user.token), + json={"name": "exported-gateway"}, + ) + assert response.status_code == 200 + assert response.json()["name"] == "exported-gateway" + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_get_returns_403_on_foreign_gateway_if_not_imported( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + not_importer_user = await create_user( + session, name="not-importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + not_importer_project = await create_project( + session, name="not-importer-project", owner=not_importer_user + ) + await add_project_member( + session=session, + project=not_importer_project, + user=not_importer_user, + project_role=ProjectRole.USER, + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + name="exported-gateway", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + response = await client.post( + f"/api/project/{exporter_project.name}/gateways/get", + headers=get_auth_headers(not_importer_user.token), + json={"name": "exported-gateway"}, + ) + assert response.status_code == 403 + class TestCreateGateway: @pytest.mark.asyncio @@ -190,6 +368,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn assert response.status_code == 200 assert response.json() == { "id": SomeUUID4Str(), + "project_name": project.name, "name": "test", "backend": "aws", "region": "us", @@ -247,6 +426,7 @@ async def test_create_gateway_without_name( assert response.status_code == 200 assert response.json() == { "id": SomeUUID4Str(), + "project_name": project.name, "name": "random-name", "backend": "aws", "region": "us", @@ -355,6 +535,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: assert response.status_code == 200 assert response.json() == { "id": SomeUUID4Str(), + "project_name": project.name, "backend": backend.type.value, "created_at": response.json()["created_at"], "default": True, @@ -432,6 +613,45 @@ async def test_set_default_gateway_missing( ) assert response.status_code == 400 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_importer_member_cannot_set_default_imported_gateway( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + name="exported-gateway", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + response = await client.post( + f"/api/project/{exporter_project.name}/gateways/set_default", + headers=get_auth_headers(importer_user.token), + json={"name": gateway.name}, + ) + assert response.status_code == 403 + class TestDeleteGateway: @pytest.mark.asyncio @@ -515,6 +735,45 @@ async def test_marks_gateways_to_be_deleted( assert {e.targets[0].entity_name for e in events} == {"gateway-aws", "gateway-gcp"} assert all(e.actor_user_id == user.id for e in events) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_importer_member_cannot_delete_imported_gateway( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + name="exported-gateway", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + response = await client.post( + f"/api/project/{exporter_project.name}/gateways/delete", + headers=get_auth_headers(importer_user.token), + json={"names": [gateway.name]}, + ) + assert response.status_code == 403 + class TestUpdateGateway: @pytest.mark.asyncio @@ -561,6 +820,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: assert response.status_code == 200 assert response.json() == { "id": SomeUUID4Str(), + "project_name": project.name, "backend": backend.type.value, "created_at": response.json()["created_at"], "status": "submitted", @@ -608,3 +868,42 @@ async def test_set_wildcard_domain_missing( headers=get_auth_headers(user.token), ) assert response.status_code == 400 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_importer_member_cannot_set_wildcard_domain_on_imported_gateway( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + name="exported-gateway", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + response = await client.post( + f"/api/project/{exporter_project.name}/gateways/set_wildcard_domain", + headers=get_auth_headers(importer_user.token), + json={"name": gateway.name, "wildcard_domain": "new.example"}, + ) + assert response.status_code == 403 From 5e2932d98ae63c10b46004881c472f3bd1ba8c08 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Fri, 1 May 2026 14:02:17 +0200 Subject: [PATCH 3/9] Update Exports, Imports, and Gateways CLI --- src/dstack/_internal/cli/commands/export.py | 31 +++++++++++ src/dstack/_internal/cli/commands/gateway.py | 52 +++++++++++++------ src/dstack/_internal/cli/commands/import_.py | 7 +++ .../cli/services/configurators/gateway.py | 5 +- src/dstack/_internal/cli/utils/common.py | 7 +++ src/dstack/_internal/cli/utils/fleet.py | 8 +-- src/dstack/_internal/cli/utils/gateway.py | 16 ++++-- .../_internal/core/compatibility/exports.py | 18 +++++++ .../_internal/core/compatibility/gateways.py | 8 +++ src/dstack/api/server/_exports.py | 20 ++++++- src/dstack/api/server/_gateways.py | 16 ++++-- 11 files changed, 156 insertions(+), 32 deletions(-) create mode 100644 src/dstack/_internal/core/compatibility/exports.py diff --git a/src/dstack/_internal/cli/commands/export.py b/src/dstack/_internal/cli/commands/export.py index fc7920edf7..e8bf5db9ef 100644 --- a/src/dstack/_internal/cli/commands/export.py +++ b/src/dstack/_internal/cli/commands/export.py @@ -43,6 +43,13 @@ def _register(self): help="Fleet name to export (can be specified multiple times)", default=[], ) + create_parser.add_argument( + "--gateway", + action="append", + dest="gateways", + help="Gateway name to export (can be specified multiple times)", + default=[], + ) create_parser.set_defaults(subfunc=self._create) update_parser = subparsers.add_parser( @@ -80,6 +87,20 @@ def _register(self): help="Fleet name to remove (can be specified multiple times)", default=[], ) + update_parser.add_argument( + "--add-gateway", + action="append", + dest="add_gateways", + help="Gateway name to add (can be specified multiple times)", + default=[], + ) + update_parser.add_argument( + "--remove-gateway", + action="append", + dest="remove_gateways", + help="Gateway name to remove (can be specified multiple times)", + default=[], + ) update_parser.set_defaults(subfunc=self._update) delete_parser = subparsers.add_parser( @@ -109,6 +130,7 @@ def _create(self, args: argparse.Namespace): name=args.name, importer_projects=args.importers, exported_fleets=args.fleets, + exported_gateways=args.gateways, ) print_exports_table([export]) @@ -121,6 +143,8 @@ def _update(self, args: argparse.Namespace): remove_importer_projects=args.remove_importers, add_exported_fleets=args.add_fleets, remove_exported_fleets=args.remove_fleets, + add_exported_gateways=args.add_gateways, + remove_exported_gateways=args.remove_gateways, ) print_exports_table([export]) @@ -139,17 +163,24 @@ def print_exports_table(exports: list[Export]): table = Table(box=None) table.add_column("NAME", no_wrap=True) table.add_column("FLEETS") + table.add_column("GATEWAYS") table.add_column("IMPORTERS") for export in exports: fleets = ( ", ".join([f.name for f in export.exported_fleets]) if export.exported_fleets else "-" ) + gateways = ( + ", ".join([g.name for g in export.exported_gateways]) + if export.exported_gateways + else "-" + ) importers = ", ".join([i.project_name for i in export.imports]) if export.imports else "-" row = { "NAME": export.name, "FLEETS": fleets, + "GATEWAYS": gateways, "IMPORTERS": importers, } add_row_from_dict(table, row) diff --git a/src/dstack/_internal/cli/commands/gateway.py b/src/dstack/_internal/cli/commands/gateway.py index 9455f3d5fc..22dc42d5a4 100644 --- a/src/dstack/_internal/cli/commands/gateway.py +++ b/src/dstack/_internal/cli/commands/gateway.py @@ -17,6 +17,7 @@ print_gateways_table, ) from dstack._internal.core.errors import CLIError, ResourceNotExistsError +from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.gateways import GatewayStatus from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent from dstack._internal.utils.logging import get_logger @@ -67,7 +68,7 @@ def _register(self): ) delete_parser.set_defaults(subfunc=self._delete) delete_parser.add_argument( - "name", help="The name of the gateway" + "name", type=EntityReference.parse, help="The name of the gateway" ).completer = GatewayNameCompleter() # type: ignore[attr-defined] delete_parser.add_argument( "-y", "--yes", action="store_true", help="Don't ask for confirmation" @@ -78,7 +79,7 @@ def _register(self): ) update_parser.set_defaults(subfunc=self._update) update_parser.add_argument( - "name", help="The name of the gateway" + "name", type=EntityReference.parse, help="The name of the gateway" ).completer = GatewayNameCompleter() # type: ignore[attr-defined] update_parser.add_argument( "--set-default", action="store_true", help="Set it the default gateway for the project" @@ -89,7 +90,7 @@ def _register(self): "get", help="Get a gateway", formatter_class=self._parser.formatter_class ) get_parser.add_argument( - "name", metavar="NAME", help="The name of the gateway" + "name", metavar="NAME", type=EntityReference.parse, help="The name of the gateway" ).completer = GatewayNameCompleter() # type: ignore[attr-defined] get_parser.add_argument( "--json", @@ -108,7 +109,7 @@ def _list(self, args: argparse.Namespace): if args.watch and args.format == "json": raise CLIError("JSON output is not supported together with --watch") - gateways = self.api.client.gateways.list(self.api.project) + gateways = self.api.client.gateways.list(self.api.project, include_imported=True) deprecated_router_gateways = [ g.name for g in gateways @@ -127,45 +128,64 @@ def _list(self, args: argparse.Namespace): if args.format == "json": print_gateways_json(gateways, project=self.api.project) else: - print_gateways_table(gateways, verbose=args.verbose) + print_gateways_table( + gateways, current_project=self.api.project, verbose=args.verbose + ) return try: with Live(console=console, refresh_per_second=LIVE_TABLE_REFRESH_RATE_PER_SEC) as live: while True: - live.update(get_gateways_table(gateways, verbose=args.verbose)) + live.update( + get_gateways_table( + gateways, current_project=self.api.project, verbose=args.verbose + ) + ) time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS) gateways = self.api.client.gateways.list(self.api.project) except KeyboardInterrupt: pass def _delete(self, args: argparse.Namespace): - gateway = self.api.client.gateways.get(self.api.project, args.name) - print_gateways_table([gateway]) + if args.name.project is not None: + console.print( + "The [code]/[/] format is not supported for gateway names." + " Can only delete gateways owned by the current project" + ) + exit(1) + name = args.name.name + gateway = self.api.client.gateways.get(self.api.project, name) + print_gateways_table([gateway], current_project=self.api.project) if args.yes or confirm_ask("Do you want to delete the gateway?"): with console.status("Deleting gateway..."): - self.api.client.gateways.delete(self.api.project, [args.name]) + self.api.client.gateways.delete(self.api.project, [name]) console.print("Gateway deleted") else: console.print("Exiting...") return def _update(self, args: argparse.Namespace): + if args.name.project is not None: + console.print( + "The [code]/[/] format is not supported for gateway names." + " Can only update gateways owned by the current project" + ) + exit(1) + name = args.name.name with console.status("Updating gateway..."): if args.set_default: - self.api.client.gateways.set_default(self.api.project, args.name) + self.api.client.gateways.set_default(self.api.project, name) if args.domain: - self.api.client.gateways.set_wildcard_domain( - self.api.project, args.name, args.domain - ) - gateway = self.api.client.gateways.get(self.api.project, args.name) - print_gateways_table([gateway]) + self.api.client.gateways.set_wildcard_domain(self.api.project, name, args.domain) + gateway = self.api.client.gateways.get(self.api.project, name) + print_gateways_table([gateway], current_project=self.api.project) def _get(self, args: argparse.Namespace): # TODO: Implement non-json output format try: gateway = self.api.client.gateways.get( - project_name=self.api.project, gateway_name=args.name + project_name=args.name.project or self.api.project, + gateway_name=args.name.name, ) except ResourceNotExistsError: console.print("Gateway not found") diff --git a/src/dstack/_internal/cli/commands/import_.py b/src/dstack/_internal/cli/commands/import_.py index 2d8c4a1032..155a42a789 100644 --- a/src/dstack/_internal/cli/commands/import_.py +++ b/src/dstack/_internal/cli/commands/import_.py @@ -68,6 +68,7 @@ def print_imports_table(imports: list[Import]): table = Table(box=None) table.add_column("NAME", no_wrap=True) table.add_column("FLEETS") + table.add_column("GATEWAYS") for imp in imports: name = f"{imp.export.project_name}/{imp.export.name}" @@ -76,10 +77,16 @@ def print_imports_table(imports: list[Import]): if imp.export.exported_fleets else "-" ) + gateways = ( + ", ".join([g.name for g in imp.export.exported_gateways]) + if imp.export.exported_gateways + else "-" + ) row = { "NAME": name, "FLEETS": fleets, + "GATEWAYS": gateways, } add_row_from_dict(table, row) diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py index 1ddee68606..0b6993e18b 100644 --- a/src/dstack/_internal/cli/services/configurators/gateway.py +++ b/src/dstack/_internal/cli/services/configurators/gateway.py @@ -122,7 +122,9 @@ def apply_configuration( f"Provisioning [code]{gateway.name}[/]...", console=console ) as live: while not _finished_provisioning(gateway): - table = get_gateways_table([gateway], include_created=True) + table = get_gateways_table( + [gateway], current_project=self.api.project, include_created=True + ) live.update(table) time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS) gateway = self.api.client.gateways.get(self.api.project, gateway.name) @@ -139,6 +141,7 @@ def apply_configuration( console.print( get_gateways_table( [gateway], + current_project=self.api.project, verbose=gateway.status == GatewayStatus.FAILED, include_created=True, format_date=local_time, diff --git a/src/dstack/_internal/cli/utils/common.py b/src/dstack/_internal/cli/utils/common.py index 89d29032ec..bf837e161b 100644 --- a/src/dstack/_internal/cli/utils/common.py +++ b/src/dstack/_internal/cli/utils/common.py @@ -151,6 +151,13 @@ def resolve_url(url: str, timeout: float = 5.0) -> str: return response.url +def format_entity_reference(name: str, project: str, current_project: str) -> str: + if current_project == project: + return name + else: + return f"{project}/{name}" + + def format_instance_availability(v: InstanceAvailability) -> str: if v in (InstanceAvailability.UNKNOWN, InstanceAvailability.AVAILABLE): return "" diff --git a/src/dstack/_internal/cli/utils/fleet.py b/src/dstack/_internal/cli/utils/fleet.py index cf65b6fa35..75f6b50162 100644 --- a/src/dstack/_internal/cli/utils/fleet.py +++ b/src/dstack/_internal/cli/utils/fleet.py @@ -2,7 +2,7 @@ from rich.table import Table -from dstack._internal.cli.utils.common import add_row_from_dict, console +from dstack._internal.cli.utils.common import add_row_from_dict, console, format_entity_reference from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.fleets import Fleet, FleetNodesSpec, FleetStatus from dstack._internal.core.models.instances import Instance, InstanceStatus @@ -43,10 +43,6 @@ def get_fleets_table( config = fleet.spec.configuration merged_profile = fleet.spec.merged_profile - name = fleet.name - if fleet.project_name != current_project: - name = f"{fleet.project_name}/{fleet.name}" - # Detect SSH fleet vs backend fleet if config.ssh_config is not None: # SSH fleet: fixed number of hosts, no cloud billing @@ -76,7 +72,7 @@ def get_fleets_table( nodes = f"{nodes} (cluster)" fleet_row = { - "NAME": name, + "NAME": format_entity_reference(fleet.name, fleet.project_name, current_project), "NODES": nodes, "RESOURCES": resources, "GPU": gpu, diff --git a/src/dstack/_internal/cli/utils/gateway.py b/src/dstack/_internal/cli/utils/gateway.py index 9d9aae0f9a..0c35b3c371 100644 --- a/src/dstack/_internal/cli/utils/gateway.py +++ b/src/dstack/_internal/cli/utils/gateway.py @@ -3,13 +3,13 @@ from rich.table import Table from dstack._internal.cli.models.gateways import GatewayCommandOutput -from dstack._internal.cli.utils.common import add_row_from_dict, console +from dstack._internal.cli.utils.common import add_row_from_dict, console, format_entity_reference from dstack._internal.core.models.gateways import Gateway from dstack._internal.utils.common import DateFormatter, pretty_date -def print_gateways_table(gateways: List[Gateway], verbose: bool = False): - table = get_gateways_table(gateways, verbose=verbose) +def print_gateways_table(gateways: List[Gateway], current_project: str, verbose: bool = False): + table = get_gateways_table(gateways, current_project, verbose=verbose) console.print(table) console.print() @@ -25,6 +25,7 @@ def print_gateways_json(gateways: List[Gateway], project: str) -> None: def get_gateways_table( gateways: List[Gateway], + current_project: str, verbose: bool = False, include_created: bool = False, format_date: DateFormatter = pretty_date, @@ -42,8 +43,15 @@ def get_gateways_table( table.add_column("ERROR") for gateway in gateways: + name = format_entity_reference( + gateway.name, + # project_name == None means pre-0.20.20 server, which means no gateway exports support, + # which means the gateway is from the current project + gateway.project_name if gateway.project_name is not None else current_project, + current_project, + ) row = { - "NAME": gateway.name, + "NAME": name, "BACKEND": f"{gateway.configuration.backend.value} ({gateway.configuration.region})", "HOSTNAME": gateway.hostname, "DOMAIN": gateway.wildcard_domain, diff --git a/src/dstack/_internal/core/compatibility/exports.py b/src/dstack/_internal/core/compatibility/exports.py new file mode 100644 index 0000000000..2b9e2c85ba --- /dev/null +++ b/src/dstack/_internal/core/compatibility/exports.py @@ -0,0 +1,18 @@ +from dstack._internal.core.models.common import IncludeExcludeDictType +from dstack._internal.server.schemas.exports import CreateExportRequest, UpdateExportRequest + + +def get_create_export_excludes(request: CreateExportRequest) -> IncludeExcludeDictType: + excludes: IncludeExcludeDictType = {} + if not request.exported_gateways: + excludes["exported_gateways"] = True + return excludes + + +def get_update_export_excludes(request: UpdateExportRequest) -> IncludeExcludeDictType: + excludes: IncludeExcludeDictType = {} + if not request.add_exported_gateways: + excludes["add_exported_gateways"] = True + if not request.remove_exported_gateways: + excludes["remove_exported_gateways"] = True + return excludes diff --git a/src/dstack/_internal/core/compatibility/gateways.py b/src/dstack/_internal/core/compatibility/gateways.py index 949d6515f8..163cde8f3c 100644 --- a/src/dstack/_internal/core/compatibility/gateways.py +++ b/src/dstack/_internal/core/compatibility/gateways.py @@ -1,5 +1,13 @@ from dstack._internal.core.models.common import IncludeExcludeDictType from dstack._internal.core.models.gateways import GatewayConfiguration, GatewaySpec +from dstack._internal.server.schemas.gateways import ListGatewaysRequest + + +def get_list_gateways_excludes(request: ListGatewaysRequest) -> IncludeExcludeDictType: + excludes: IncludeExcludeDictType = {} + if not request.include_imported: + excludes["include_imported"] = True + return excludes def get_gateway_spec_excludes(gateway_spec: GatewaySpec) -> IncludeExcludeDictType: diff --git a/src/dstack/api/server/_exports.py b/src/dstack/api/server/_exports.py index 419a4179ba..ad18bbd7bb 100644 --- a/src/dstack/api/server/_exports.py +++ b/src/dstack/api/server/_exports.py @@ -2,6 +2,10 @@ from pydantic import parse_obj_as +from dstack._internal.core.compatibility.exports import ( + get_create_export_excludes, + get_update_export_excludes, +) from dstack._internal.core.models.exports import Export from dstack._internal.server.schemas.exports import ( CreateExportRequest, @@ -23,13 +27,18 @@ def create( *, importer_projects: List[str] = [], exported_fleets: List[str] = [], + exported_gateways: List[str] = [], ) -> Export: body = CreateExportRequest( name=name, importer_projects=importer_projects, exported_fleets=exported_fleets, + exported_gateways=exported_gateways, + ) + resp = self._request( + f"/api/project/{project_name}/exports/create", + body=body.json(exclude=get_create_export_excludes(body)), ) - resp = self._request(f"/api/project/{project_name}/exports/create", body=body.json()) return parse_obj_as(Export.__response__, resp.json()) def update( @@ -41,6 +50,8 @@ def update( remove_importer_projects: List[str] = [], add_exported_fleets: List[str] = [], remove_exported_fleets: List[str] = [], + add_exported_gateways: List[str] = [], + remove_exported_gateways: List[str] = [], ) -> Export: body = UpdateExportRequest( name=name, @@ -48,8 +59,13 @@ def update( remove_importer_projects=remove_importer_projects, add_exported_fleets=add_exported_fleets, remove_exported_fleets=remove_exported_fleets, + add_exported_gateways=add_exported_gateways, + remove_exported_gateways=remove_exported_gateways, + ) + resp = self._request( + f"/api/project/{project_name}/exports/update", + body=body.json(exclude=get_update_export_excludes(body)), ) - resp = self._request(f"/api/project/{project_name}/exports/update", body=body.json()) return parse_obj_as(Export.__response__, resp.json()) def delete(self, project_name: str, name: str) -> None: diff --git a/src/dstack/api/server/_gateways.py b/src/dstack/api/server/_gateways.py index fe998d735a..87b7029f78 100644 --- a/src/dstack/api/server/_gateways.py +++ b/src/dstack/api/server/_gateways.py @@ -2,12 +2,16 @@ from pydantic import parse_obj_as -from dstack._internal.core.compatibility.gateways import get_create_gateway_excludes +from dstack._internal.core.compatibility.gateways import ( + get_create_gateway_excludes, + get_list_gateways_excludes, +) from dstack._internal.core.models.gateways import Gateway, GatewayConfiguration from dstack._internal.server.schemas.gateways import ( CreateGatewayRequest, DeleteGatewaysRequest, GetGatewayRequest, + ListGatewaysRequest, SetDefaultGatewayRequest, SetWildcardDomainRequest, ) @@ -15,8 +19,14 @@ class GatewaysAPIClient(APIClientGroup): - def list(self, project_name: str) -> List[Gateway]: - resp = self._request(f"/api/project/{project_name}/gateways/list") + def list(self, project_name: str, *, include_imported: bool = False) -> List[Gateway]: + body = ListGatewaysRequest( + include_imported=include_imported, + ) + resp = self._request( + f"/api/project/{project_name}/gateways/list", + body=body.json(exclude=get_list_gateways_excludes(body)), + ) return parse_obj_as(List[Gateway.__response__], resp.json()) def get(self, project_name: str, gateway_name: str) -> Gateway: From c569c4821887d76e2abdb4ac9b540edf4e0eb675 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Mon, 4 May 2026 09:02:13 +0200 Subject: [PATCH 4/9] `/gateway` syntax in the `gateway` prop --- .../_internal/core/compatibility/runs.py | 5 + .../_internal/core/models/configurations.py | 17 +- .../_internal/server/compatibility/runs.py | 8 + .../server/services/gateways/__init__.py | 34 ++-- .../server/services/services/__init__.py | 17 +- .../_internal/server/routers/test_runs.py | 149 +++++++++++++++++- 6 files changed, 211 insertions(+), 19 deletions(-) diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index f9dbaf4e28..27c80d952b 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -2,6 +2,7 @@ from dstack._internal.core.compatibility.common import patch_profile_params from dstack._internal.core.models.common import ( + EntityReference, IncludeExcludeDictType, IncludeExcludeSetType, ) @@ -167,3 +168,7 @@ def patch_run_spec(run_spec: RunSpec) -> None: patch_profile_params(run_spec.configuration) if run_spec.profile is not None: patch_profile_params(run_spec.profile) + if isinstance(run_spec.configuration, ServiceConfiguration): + if isinstance(run_spec.configuration.gateway, EntityReference): + # Pre-0.20.20 servers do not support `EntityReference` in `gateway` + run_spec.configuration.gateway = run_spec.configuration.gateway.format() diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 15da190b86..99da0553fe 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -14,6 +14,7 @@ CoreConfig, CoreModel, Duration, + EntityReference, RegistryAuth, generate_dual_core_model, ) @@ -876,7 +877,13 @@ class ServiceConfigurationParams(CoreModel): Field(description="The port the application listens on"), ] gateway: Annotated[ - Optional[Union[bool, str]], + Optional[ + Union[ + bool, + EntityReference, + str, # For server response compatibility with pre-0.20.20 clients + ] + ], Field( description=( "The name of the gateway. Specify boolean `false` to run without a gateway." @@ -989,6 +996,14 @@ def validate_probes(cls, v: Optional[list[ProbeConfig]]) -> Optional[list[ProbeC raise ValueError("Probes must be unique") return v + @validator("gateway") + def validate_gateway( + cls, v: Optional[Union[bool, EntityReference, str]] + ) -> Optional[Union[bool, EntityReference]]: + if isinstance(v, str): + return EntityReference.parse(v) + return v + @validator("replicas") def validate_replicas( cls, v: Optional[Union[Range[int], List[ReplicaGroup]]] diff --git a/src/dstack/_internal/server/compatibility/runs.py b/src/dstack/_internal/server/compatibility/runs.py index 9a5d0d99b8..752f5f784b 100644 --- a/src/dstack/_internal/server/compatibility/runs.py +++ b/src/dstack/_internal/server/compatibility/runs.py @@ -2,6 +2,7 @@ from packaging.version import Version +from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT, ServiceConfiguration from dstack._internal.core.models.runs import Run, RunPlan, RunSpec from dstack._internal.server.compatibility.common import patch_offers_list, patch_profile_params @@ -44,3 +45,10 @@ def patch_run_spec(run_spec: RunSpec, client_version: Optional[Version]) -> None patch_profile_params(run_spec.configuration, client_version) if run_spec.profile is not None: patch_profile_params(run_spec.profile, client_version) + # Clients prior to 0.20.20 do not support `EntityReference` in `gateway` + if ( + client_version < Version("0.20.20") + and isinstance(run_spec.configuration, ServiceConfiguration) + and isinstance(run_spec.configuration.gateway, EntityReference) + ): + run_spec.configuration.gateway = run_spec.configuration.gateway.format() diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 694eef7656..8abd28a128 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -30,6 +30,7 @@ SSHError, ) from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.gateways import ( AnyGatewayRouterConfig, Gateway, @@ -148,10 +149,10 @@ async def list_project_gateways( async def get_gateway_by_name( session: AsyncSession, project: ProjectModel, name: str ) -> Optional[Gateway]: - gateway = await get_project_gateway_model_by_name( + gateway = await get_project_gateway_model_by_reference( session=session, project=project, - name=name, + ref=EntityReference(name=name, project=None), load_gateway_compute=True, load_backend_type=True, ) @@ -266,10 +267,10 @@ async def create_gateway( ) default_gateway = gateway pipeline_hinter.hint_fetch(GatewayModel.__name__) - gateway = await get_project_gateway_model_by_name( + gateway = await get_project_gateway_model_by_reference( session=session, project=project, - name=configuration.name, + ref=EntityReference(name=configuration.name, project=None), load_gateway_compute=True, load_backend_type=True, ) @@ -395,7 +396,9 @@ async def set_gateway_wildcard_domain( async def set_default_gateway( session: AsyncSession, project: ProjectModel, name: str, user: Optional[UserModel] ): - gateway = await get_project_gateway_model_by_name(session=session, project=project, name=name) + gateway = await get_project_gateway_model_by_reference( + session=session, project=project, ref=EntityReference(name=name, project=None) + ) if gateway is None: raise ResourceNotExistsError() if gateway.to_be_deleted: @@ -457,17 +460,26 @@ async def list_project_gateway_models( return res.unique().scalars().all() -async def get_project_gateway_model_by_name( +async def get_project_gateway_model_by_reference( session: AsyncSession, project: ProjectModel, - name: str, + ref: EntityReference, load_gateway_compute: bool = False, load_backend_type: bool = False, ) -> Optional[GatewayModel]: - stmt = select(GatewayModel).where( - GatewayModel.project_id == project.id, - GatewayModel.name == name, - ) + stmt = select(GatewayModel).where(GatewayModel.name == ref.name) + if ref.project is None or ref.project == project.name: + stmt = stmt.where(GatewayModel.project_id == project.id) + else: + stmt = stmt.where( + exists().where( + ImportModel.project_id == project.id, + ImportModel.export_id == ExportedGatewayModel.export_id, + ExportedGatewayModel.gateway_id == GatewayModel.id, + GatewayModel.project_id == ProjectModel.id, + ProjectModel.name == ref.project, + ) + ) if load_gateway_compute: stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) if load_backend_type: diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index c244b045de..c5c0867bec 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -16,6 +16,7 @@ ) from dstack._internal.core.models.configurations import ( SERVICE_HTTPS_DEFAULT, + EntityReference, ServiceConfiguration, ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus @@ -34,7 +35,7 @@ get_gateway_configuration, get_or_add_gateway_connection, get_project_default_gateway_model, - get_project_gateway_model_by_name, + get_project_gateway_model_by_reference, ) from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.services.options import get_service_options @@ -46,21 +47,25 @@ async def register_service(session: AsyncSession, run_model: RunModel, run_spec: RunSpec): assert isinstance(run_spec.configuration, ServiceConfiguration) - if isinstance(run_spec.configuration.gateway, str): - gateway = await get_project_gateway_model_by_name( + if isinstance(run_spec.configuration.gateway, EntityReference) or isinstance( + run_spec.configuration.gateway, str + ): + gateway_reference = EntityReference.parse(run_spec.configuration.gateway) + gateway = await get_project_gateway_model_by_reference( session=session, project=run_model.project, - name=run_spec.configuration.gateway, + ref=gateway_reference, load_gateway_compute=True, load_backend_type=True, ) if gateway is None: raise ResourceNotExistsError( - f"Gateway {run_spec.configuration.gateway} does not exist" + f"Gateway {gateway_reference.format()} does not exist" + f" in project {run_model.project.name}" ) if gateway.to_be_deleted: raise ResourceNotExistsError( - f"Gateway {run_spec.configuration.gateway} was marked for deletion" + f"Gateway {gateway_reference.format()} was marked for deletion" ) elif run_spec.configuration.gateway == False: gateway = None diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 7dbc567cac..71d265b1d3 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -1338,6 +1338,61 @@ async def test_patches_fleets_for_old_clients( assert response.json()["run_spec"]["configuration"]["fleets"] == expected_fleets assert response.json()["run_spec"]["profile"]["fleets"] == expected_fleets + @pytest.mark.asyncio + @pytest.mark.parametrize( + "client_version,expected_gateway", + [ + ( + "0.20.19", + "other-project/my-gateway", + ), + ( + "0.20.20", + {"project": "other-project", "name": "my-gateway"}, + ), + ( + None, + {"project": "other-project", "name": "my-gateway"}, + ), + ], + ) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_patches_service_gateway_for_old_clients( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + client_version: Optional[str], + expected_gateway, + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + + run_spec = get_run_spec( + configuration=ServiceConfiguration( + commands=["echo hello"], + port=80, + gateway=EntityReference(project="other-project", name="my-gateway"), + ), + repo_id=repo.name, + ) + run = await create_run( + session=session, project=project, repo=repo, user=user, run_spec=run_spec + ) + + headers = get_auth_headers(user.token) + if client_version is not None: + headers["X-API-Version"] = client_version + response = await client.post( + f"/api/project/{project.name}/runs/get", + headers=headers, + json={"run_name": run.run_name}, + ) + + assert response.status_code == 200 + assert response.json()["run_spec"]["configuration"]["gateway"] == expected_gateway + class TestGetRunPlan: @pytest.mark.asyncio @@ -3390,7 +3445,10 @@ async def test_return_error_if_specified_gateway_not_exists( assert response.status_code == 400 assert response.json() == { "detail": [ - {"msg": "Gateway nonexistent does not exist", "code": "resource_not_exists"} + { + "msg": f"Gateway nonexistent does not exist in project {project.name}", + "code": "resource_not_exists", + } ] } @@ -3420,6 +3478,95 @@ async def test_return_error_if_specified_gateway_is_true_and_no_gateway_exists( ] } + @pytest.mark.asyncio + async def test_submit_to_foreign_gateway_only_if_imported( + self, test_db, session: AsyncSession, client: AsyncClient + ) -> None: + exporter_user = await create_user( + session=session, global_role=GlobalRole.USER, name="exporter_user" + ) + exporter_project = await create_project( + session=session, owner=exporter_user, name="exporter-project" + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + name="exported-gateway", + wildcard_domain="exported-gateway.example", + ) + + importer_user = await create_user( + session=session, global_role=GlobalRole.USER, name="importer_user" + ) + importer_project = await create_project( + session=session, owner=importer_user, name="importer-project" + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.USER, + ) + importer_repo = await create_repo(session=session, project_id=importer_project.id) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + + not_importer_user = await create_user( + session=session, global_role=GlobalRole.USER, name="not_importer_user" + ) + not_importer_project = await create_project( + session=session, owner=not_importer_user, name="not-importer-project" + ) + await add_project_member( + session=session, + project=not_importer_project, + user=not_importer_user, + project_role=ProjectRole.USER, + ) + not_importer_repo = await create_repo(session=session, project_id=not_importer_project.id) + + importer_run_spec = get_service_run_spec( + repo_id=importer_repo.name, + run_name="test-service", + gateway="exporter-project/exported-gateway", + ) + response = await client.post( + f"/api/project/{importer_project.name}/runs/submit", + headers=get_auth_headers(importer_user.token), + json={"run_spec": importer_run_spec}, + ) + assert response.status_code == 200 + assert response.json()["service"]["url"] == "https://test-service.exported-gateway.example" + + not_importer_run_spec = get_service_run_spec( + repo_id=not_importer_repo.name, + gateway="exporter-project/exported-gateway", + ) + response = await client.post( + f"/api/project/{not_importer_project.name}/runs/submit", + headers=get_auth_headers(not_importer_user.token), + json={"run_spec": not_importer_run_spec}, + ) + assert response.status_code == 400 + assert response.json() == { + "detail": [ + { + "msg": "Gateway exporter-project/exported-gateway does not exist in project not-importer-project", + "code": "resource_not_exists", + } + ] + } + @pytest.mark.asyncio async def test_unregister_dangling_service( self, From 1e15f1eb01e2a114d3150bdcf92d8bf31af163ed Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Mon, 4 May 2026 09:18:20 +0200 Subject: [PATCH 5/9] Check for domain name conflicts --- .../_internal/proxy/gateway/repo/repo.py | 8 ++++++++ .../proxy/gateway/services/registry.py | 2 ++ .../proxy/gateway/routers/test_registry.py | 18 ++++++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/src/dstack/_internal/proxy/gateway/repo/repo.py b/src/dstack/_internal/proxy/gateway/repo/repo.py index eb74522fd1..0956faf04c 100644 --- a/src/dstack/_internal/proxy/gateway/repo/repo.py +++ b/src/dstack/_internal/proxy/gateway/repo/repo.py @@ -41,6 +41,14 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic async with self.reader(): return self._state.services.get(project_name, {}).get(run_name) + async def get_service_by_domain(self, domain: str) -> Optional[Service]: + async with self.reader(): + for project_services in self._state.services.values(): + for service in project_services.values(): + if service.domain == domain: + return service + return None + async def set_service(self, service: Service) -> None: async with self.writer(): self._state.services.setdefault(service.project_name, {})[service.run_name] = service diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 3f96b44563..84fbce8711 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -68,6 +68,8 @@ async def register_service( async with lock: if await repo.get_service(project_name, run_name) is not None: raise ProxyError(SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format(ref=service.fmt())) + if await repo.get_service_by_domain(domain) is not None: + raise ProxyError(f"Domain name {domain!r} is already taken by another service") old_project = await repo.get_project(project_name) new_project = models.Project(name=project_name, ssh_private_key=ssh_private_key) diff --git a/src/tests/_internal/proxy/gateway/routers/test_registry.py b/src/tests/_internal/proxy/gateway/routers/test_registry.py index ede195e487..239413cfda 100644 --- a/src/tests/_internal/proxy/gateway/routers/test_registry.py +++ b/src/tests/_internal/proxy/gateway/routers/test_registry.py @@ -174,6 +174,24 @@ async def test_register_same_name_in_different_projects( assert (tmp_path / "443-test-run.proj-1.gtw.test.conf").exists() assert (tmp_path / "443-test-run.proj-2.gtw.test.conf").exists() + async def test_register_same_domain_error(self, tmp_path: Path, system_mocks: Mocks) -> None: + client = make_client(tmp_path) + resp = await client.post( + "/api/registry/test-proj-1/services/register", + json=register_service_payload(run_name="test-run", domain="test-run.gtw.test"), + ) + assert resp.status_code == 200 + resp = await client.post( + "/api/registry/test-proj/services/register", + json=register_service_payload(run_name="test-run", domain="test-run.gtw.test"), + ) + assert resp.status_code == 400 + assert resp.json() == { + "detail": "Domain name 'test-run.gtw.test' is already taken by another service" + } + assert (tmp_path / "443-test-run.gtw.test.conf").exists() + assert system_mocks.reload_nginx.call_count == 1 + @freeze_time(datetime(2024, 12, 12, 0, 30)) async def test_register_with_model(self, tmp_path: Path, system_mocks: Mocks) -> None: repo = GatewayProxyRepo() From 87afd010d38eca6c9937ffad57bb4ad7e3c77e8e Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 5 May 2026 12:03:23 +0200 Subject: [PATCH 6/9] Fix client compatibility with pre-0.20.20 servers --- src/dstack/_internal/core/models/exports.py | 2 +- src/dstack/_internal/core/models/imports.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/core/models/exports.py b/src/dstack/_internal/core/models/exports.py index ae215f6ec1..d1f9ed6c61 100644 --- a/src/dstack/_internal/core/models/exports.py +++ b/src/dstack/_internal/core/models/exports.py @@ -22,4 +22,4 @@ class Export(CoreModel): name: str imports: list[ExportImport] exported_fleets: list[ExportedFleet] - exported_gateways: list[ExportedGateway] + exported_gateways: list[ExportedGateway] = [] diff --git a/src/dstack/_internal/core/models/imports.py b/src/dstack/_internal/core/models/imports.py index d3c297a44e..3329b3a753 100644 --- a/src/dstack/_internal/core/models/imports.py +++ b/src/dstack/_internal/core/models/imports.py @@ -18,7 +18,7 @@ class ImportExport(CoreModel): name: str project_name: str exported_fleets: list[ImportExportedFleet] - exported_gateways: list[ImportExportedGateway] + exported_gateways: list[ImportExportedGateway] = [] class Import(CoreModel): From 73c882e863b02be7d8c3946e16a34a8415cb1ce2 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 5 May 2026 12:29:17 +0200 Subject: [PATCH 7/9] Fix imported gateways in `dstack gateway -w` --- src/dstack/_internal/cli/commands/gateway.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/cli/commands/gateway.py b/src/dstack/_internal/cli/commands/gateway.py index 22dc42d5a4..16ebfb6665 100644 --- a/src/dstack/_internal/cli/commands/gateway.py +++ b/src/dstack/_internal/cli/commands/gateway.py @@ -142,7 +142,9 @@ def _list(self, args: argparse.Namespace): ) ) time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS) - gateways = self.api.client.gateways.list(self.api.project) + gateways = self.api.client.gateways.list( + self.api.project, include_imported=True + ) except KeyboardInterrupt: pass From f2f3dbfe8f02073405572282e9c9193b3394a56e Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 5 May 2026 12:35:01 +0200 Subject: [PATCH 8/9] Remove redundant list gateways excludes Older servers ignore the request body anyway. --- src/dstack/_internal/core/compatibility/gateways.py | 8 -------- src/dstack/api/server/_gateways.py | 10 ++-------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/dstack/_internal/core/compatibility/gateways.py b/src/dstack/_internal/core/compatibility/gateways.py index 163cde8f3c..949d6515f8 100644 --- a/src/dstack/_internal/core/compatibility/gateways.py +++ b/src/dstack/_internal/core/compatibility/gateways.py @@ -1,13 +1,5 @@ from dstack._internal.core.models.common import IncludeExcludeDictType from dstack._internal.core.models.gateways import GatewayConfiguration, GatewaySpec -from dstack._internal.server.schemas.gateways import ListGatewaysRequest - - -def get_list_gateways_excludes(request: ListGatewaysRequest) -> IncludeExcludeDictType: - excludes: IncludeExcludeDictType = {} - if not request.include_imported: - excludes["include_imported"] = True - return excludes def get_gateway_spec_excludes(gateway_spec: GatewaySpec) -> IncludeExcludeDictType: diff --git a/src/dstack/api/server/_gateways.py b/src/dstack/api/server/_gateways.py index 87b7029f78..2952646767 100644 --- a/src/dstack/api/server/_gateways.py +++ b/src/dstack/api/server/_gateways.py @@ -2,10 +2,7 @@ from pydantic import parse_obj_as -from dstack._internal.core.compatibility.gateways import ( - get_create_gateway_excludes, - get_list_gateways_excludes, -) +from dstack._internal.core.compatibility.gateways import get_create_gateway_excludes from dstack._internal.core.models.gateways import Gateway, GatewayConfiguration from dstack._internal.server.schemas.gateways import ( CreateGatewayRequest, @@ -23,10 +20,7 @@ def list(self, project_name: str, *, include_imported: bool = False) -> List[Gat body = ListGatewaysRequest( include_imported=include_imported, ) - resp = self._request( - f"/api/project/{project_name}/gateways/list", - body=body.json(exclude=get_list_gateways_excludes(body)), - ) + resp = self._request(f"/api/project/{project_name}/gateways/list", body=body.json()) return parse_obj_as(List[Gateway.__response__], resp.json()) def get(self, project_name: str, gateway_name: str) -> Gateway: From 7f5b6c82b0db6d90bf21bbef207b1fb3fcf672f4 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 6 May 2026 13:35:58 +0200 Subject: [PATCH 9/9] Fix public project gateway access --- .../_internal/server/security/permissions.py | 1 + .../_internal/server/routers/test_gateways.py | 52 +++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/src/dstack/_internal/server/security/permissions.py b/src/dstack/_internal/server/security/permissions.py index 7d63d97560..6a4269a256 100644 --- a/src/dstack/_internal/server/security/permissions.py +++ b/src/dstack/_internal/server/security/permissions.py @@ -320,6 +320,7 @@ async def check_can_access_gateway( ) -> None: if ( user.global_role == GlobalRole.ADMIN + or gateway_project.is_public or get_user_project_role(user=user, project=gateway_project) is not None ): return diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 304f06854c..28645788bd 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -136,6 +136,58 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): }, } + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_list_non_member_public_project( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session, is_public=True) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + ) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + ) + response = await client.post( + f"/api/project/{project.name}/gateways/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + assert len(response.json()) == 1 + assert response.json()[0]["name"] == gateway.name + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_get_non_member_public_project( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session, is_public=True) + backend = await create_backend(session, project.id) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + ) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + ) + response = await client.post( + f"/api/project/{project.name}/gateways/get", + json={"name": gateway.name}, + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + assert response.json()["name"] == gateway.name + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_get_missing(self, test_db, session: AsyncSession, client: AsyncClient):