Skip to content

Commit 30da6fc

Browse files
authored
Allow setting imported gateway as project default (#3860)
``` $ dstack gateway update main/main-gateway --set-default NAME BACKEND HOSTNAME DOMAIN DEFAULT STATUS main/main-gateway aws (eu-west-1) 108.131.126.35 gtw.mycompany.example ✓ running ```
1 parent 985764b commit 30da6fc

10 files changed

Lines changed: 259 additions & 44 deletions

File tree

src/dstack/_internal/cli/commands/gateway.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
console,
1313
)
1414
from dstack._internal.cli.utils.gateway import (
15+
get_gateway_relative_to_project,
1516
get_gateways_table,
1617
print_gateways_json,
1718
print_gateways_table,
1819
)
19-
from dstack._internal.core.errors import CLIError, ResourceNotExistsError
20+
from dstack._internal.core.errors import CLIError
2021
from dstack._internal.core.models.common import EntityReference
2122
from dstack._internal.core.models.gateways import GatewayStatus
2223
from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent
@@ -167,30 +168,38 @@ def _delete(self, args: argparse.Namespace):
167168
return
168169

169170
def _update(self, args: argparse.Namespace):
170-
if args.name.project is not None:
171-
console.print(
172-
"The [code]<project>/<gateway>[/] format is not supported for gateway names."
173-
" Can only update gateways owned by the current project"
174-
)
175-
exit(1)
176-
name = args.name.name
177171
with console.status("Updating gateway..."):
178-
if args.set_default:
179-
self.api.client.gateways.set_default(self.api.project, name)
180172
if args.domain:
181-
self.api.client.gateways.set_wildcard_domain(self.api.project, name, args.domain)
182-
gateway = self.api.client.gateways.get(self.api.project, name)
173+
if args.name.project is not None:
174+
console.print(
175+
"The [code]<project>/<gateway>[/] format is not supported for gateway names"
176+
" when [code]--domain[/] is passed."
177+
" Can only update gateways owned by the current project"
178+
)
179+
exit(1)
180+
self.api.client.gateways.set_wildcard_domain(
181+
self.api.project, args.name.name, args.domain
182+
)
183+
if args.set_default:
184+
self.api.client.gateways.set_default(
185+
self.api.project,
186+
gateway_name=args.name.name,
187+
gateway_project=args.name.project,
188+
)
189+
gateway = get_gateway_relative_to_project(
190+
client=self.api.client.gateways,
191+
project=self.api.project,
192+
gateway_project=args.name.project or self.api.project,
193+
gateway_name=args.name.name,
194+
)
183195
print_gateways_table([gateway], current_project=self.api.project)
184196

185197
def _get(self, args: argparse.Namespace):
186198
# TODO: Implement non-json output format
187-
try:
188-
gateway = self.api.client.gateways.get(
189-
project_name=args.name.project or self.api.project,
190-
gateway_name=args.name.name,
191-
)
192-
except ResourceNotExistsError:
193-
console.print("Gateway not found")
194-
exit(1)
195-
199+
gateway = get_gateway_relative_to_project(
200+
client=self.api.client.gateways,
201+
project=self.api.project,
202+
gateway_project=args.name.project or self.api.project,
203+
gateway_name=args.name.name,
204+
)
196205
print(pydantic_orjson_dumps_with_indent(gateway.dict(), default=None))

src/dstack/_internal/cli/utils/gateway.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,36 @@
44

55
from dstack._internal.cli.models.gateways import GatewayCommandOutput
66
from dstack._internal.cli.utils.common import add_row_from_dict, console, format_entity_reference
7+
from dstack._internal.core.errors import ResourceNotExistsError
8+
from dstack._internal.core.models.common import EntityReference
79
from dstack._internal.core.models.gateways import Gateway
810
from dstack._internal.utils.common import DateFormatter, pretty_date
11+
from dstack.api.server._gateways import GatewaysAPIClient
12+
13+
14+
def get_gateway_relative_to_project(
15+
client: GatewaysAPIClient, project: str, gateway_project: str, gateway_name: str
16+
) -> Gateway:
17+
"""
18+
Retrieves a single gateway, ensuring that `Gateway.default` is resolved relative to
19+
`project` rather than relative to the gateway's host project.
20+
"""
21+
if project == gateway_project:
22+
return client.get(project, gateway_name)
23+
24+
# For imported gateways, use `list`.
25+
# `get` would resolve `Gateway.default` relative to the gateway's host project
26+
gateways = client.list(project, include_imported=True)
27+
for gateway in gateways:
28+
if gateway.name == gateway_name and (
29+
gateway_project == gateway.project_name
30+
# Compatibility with pre-0.20.20 servers:
31+
# gateway.project_name is None means the gateway is in the current `project`
32+
or (gateway.project_name is None and gateway_project == project)
33+
):
34+
return gateway
35+
ref = EntityReference(name=gateway_name, project=gateway_project)
36+
raise ResourceNotExistsError(msg=f"Gateway {ref.format()!r} not found in project {project!r}")
937

1038

1139
def print_gateways_table(gateways: List[Gateway], current_project: str, verbose: bool = False):

src/dstack/_internal/core/compatibility/gateways.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dstack._internal.core.models.common import IncludeExcludeDictType
22
from dstack._internal.core.models.gateways import GatewayConfiguration, GatewaySpec
3+
from dstack._internal.server.schemas.gateways import SetDefaultGatewayRequest
34

45

56
def get_gateway_spec_excludes(gateway_spec: GatewaySpec) -> IncludeExcludeDictType:
@@ -26,6 +27,13 @@ def get_create_gateway_excludes(configuration: GatewayConfiguration) -> IncludeE
2627
return create_gateway_excludes
2728

2829

30+
def get_set_default_gateway_excludes(request: SetDefaultGatewayRequest) -> IncludeExcludeDictType:
31+
excludes: IncludeExcludeDictType = {}
32+
if request.gateway_project is None:
33+
excludes["gateway_project"] = True
34+
return excludes
35+
36+
2937
def _get_gateway_configuration_excludes(
3038
configuration: GatewayConfiguration,
3139
) -> IncludeExcludeDictType:

src/dstack/_internal/server/models.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,10 @@ class ProjectModel(BaseModel):
274274
default_gateway_id: Mapped[Optional[uuid.UUID]] = mapped_column(
275275
ForeignKey("gateways.id", use_alter=True, ondelete="SET NULL"), nullable=True
276276
)
277-
default_gateway: Mapped[Optional["GatewayModel"]] = relationship(
278-
foreign_keys=[default_gateway_id]
279-
)
277+
"""
278+
**NOTE**: default_gateway_id may point to a previously imported gateway that the project is no
279+
longer authorized to use. Check access before using the gateway.
280+
"""
280281

281282
# TODO: drop `default_pool_id` after the release without pools.
282283
default_pool_id: Mapped[Optional[UUIDType]] = mapped_column(

src/dstack/_internal/server/routers/gateways.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import dstack._internal.server.schemas.gateways as schemas
88
import dstack._internal.server.services.gateways as gateways
99
from dstack._internal.core.errors import ResourceNotExistsError
10+
from dstack._internal.core.models.common import EntityReference
1011
from dstack._internal.server.db import get_session
1112
from dstack._internal.server.deps import Project
1213
from dstack._internal.server.models import ProjectModel, UserModel
@@ -104,7 +105,12 @@ async def set_default_gateway(
104105
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
105106
):
106107
user, project = user_project
107-
await gateways.set_default_gateway(session=session, project=project, name=body.name, user=user)
108+
await gateways.set_default_gateway(
109+
session=session,
110+
project=project,
111+
ref=EntityReference(name=body.name, project=body.gateway_project),
112+
user=user,
113+
)
108114

109115

110116
@router.post("/set_wildcard_domain", response_model=models.Gateway)

src/dstack/_internal/server/schemas/gateways.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List
1+
from typing import Any, Dict, List, Optional
22

33
from dstack._internal.core.models.common import CoreConfig, CoreModel, generate_dual_core_model
44
from dstack._internal.core.models.gateways import GatewayConfiguration
@@ -28,6 +28,7 @@ class DeleteGatewaysRequest(CoreModel):
2828

2929
class SetDefaultGatewayRequest(CoreModel):
3030
name: str
31+
gateway_project: Optional[str] = None
3132

3233

3334
class SetWildcardDomainRequest(CoreModel):

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,10 @@ async def create_gateway(
263263
default_gateway = await get_project_default_gateway_model(session=session, project=project)
264264
if default_gateway is None or configuration.default:
265265
await set_default_gateway(
266-
session=session, project=project, name=configuration.name, user=user
266+
session=session,
267+
project=project,
268+
ref=EntityReference(name=configuration.name, project=None),
269+
user=user,
267270
)
268271
default_gateway = gateway
269272
pipeline_hinter.hint_fetch(GatewayModel.__name__)
@@ -394,18 +397,18 @@ async def set_gateway_wildcard_domain(
394397

395398

396399
async def set_default_gateway(
397-
session: AsyncSession, project: ProjectModel, name: str, user: Optional[UserModel]
400+
session: AsyncSession, project: ProjectModel, ref: EntityReference, user: Optional[UserModel]
398401
):
399402
gateway = await get_project_gateway_model_by_reference(
400-
session=session, project=project, ref=EntityReference(name=name, project=None)
403+
session=session, project=project, ref=ref
401404
)
402405
if gateway is None:
403406
raise ResourceNotExistsError()
404407
if gateway.to_be_deleted:
405408
raise ServerClientError("Cannot set gateway marked for deletion as default")
406-
if project.default_gateway_id == gateway.id:
407-
return
408409
previous_gateway = await get_project_default_gateway_model(session, project)
410+
if previous_gateway is not None and previous_gateway.id == gateway.id:
411+
return
409412
await session.execute(
410413
update(ProjectModel)
411414
.where(
@@ -418,15 +421,21 @@ async def set_default_gateway(
418421
if previous_gateway is not None:
419422
events.emit(
420423
session,
421-
"Gateway unset as default",
424+
"Gateway unset as project default",
422425
actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(),
423-
targets=[events.Target.from_model(previous_gateway)],
426+
targets=[
427+
events.Target.from_model(previous_gateway),
428+
events.Target.from_model(project),
429+
],
424430
)
425431
events.emit(
426432
session,
427-
"Gateway set as default",
433+
"Gateway set as project default",
428434
actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(),
429-
targets=[events.Target.from_model(gateway)],
435+
targets=[
436+
events.Target.from_model(gateway),
437+
events.Target.from_model(project),
438+
],
430439
)
431440
await session.commit()
432441

@@ -531,6 +540,14 @@ async def get_project_default_gateway_model(
531540
stmt = select(GatewayModel).where(
532541
GatewayModel.id == project.default_gateway_id,
533542
GatewayModel.to_be_deleted == False,
543+
or_(
544+
GatewayModel.project_id == project.id,
545+
exists().where(
546+
ImportModel.project_id == project.id,
547+
ImportModel.export_id == ExportedGatewayModel.export_id,
548+
ExportedGatewayModel.gateway_id == GatewayModel.id,
549+
),
550+
),
534551
)
535552
if load_gateway_compute:
536553
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))

src/dstack/api/server/_gateways.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from typing import List
1+
from typing import List, Optional
22

33
from pydantic import parse_obj_as
44

5-
from dstack._internal.core.compatibility.gateways import get_create_gateway_excludes
5+
from dstack._internal.core.compatibility.gateways import (
6+
get_create_gateway_excludes,
7+
get_set_default_gateway_excludes,
8+
)
69
from dstack._internal.core.models.gateways import Gateway, GatewayConfiguration
710
from dstack._internal.server.schemas.gateways import (
811
CreateGatewayRequest,
@@ -44,9 +47,16 @@ def delete(self, project_name: str, gateways_names: List[str]) -> None:
4447
body = DeleteGatewaysRequest(names=gateways_names)
4548
self._request(f"/api/project/{project_name}/gateways/delete", body=body.json())
4649

47-
def set_default(self, project_name: str, gateway_name: str) -> None:
48-
body = SetDefaultGatewayRequest(name=gateway_name)
49-
self._request(f"/api/project/{project_name}/gateways/set_default", body=body.json())
50+
def set_default(
51+
self, project_name: str, gateway_name: str, *, gateway_project: Optional[str] = None
52+
) -> None:
53+
if gateway_project == project_name:
54+
gateway_project = None # omit for compatibility with pre-0.20.20 servers
55+
body = SetDefaultGatewayRequest(name=gateway_name, gateway_project=gateway_project)
56+
self._request(
57+
f"/api/project/{project_name}/gateways/set_default",
58+
body=body.json(exclude=get_set_default_gateway_excludes(body)),
59+
)
5060

5161
def set_wildcard_domain(
5262
self, project_name: str, gateway_name: str, wildcard_domain: str

0 commit comments

Comments
 (0)