Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions src/dstack/_internal/cli/commands/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
console,
)
from dstack._internal.cli.utils.gateway import (
get_gateway_relative_to_project,
get_gateways_table,
print_gateways_json,
print_gateways_table,
)
from dstack._internal.core.errors import CLIError, ResourceNotExistsError
from dstack._internal.core.errors import CLIError
from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.gateways import GatewayStatus
from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent
Expand Down Expand Up @@ -167,30 +168,38 @@ def _delete(self, args: argparse.Namespace):
return

def _update(self, args: argparse.Namespace):
if args.name.project is not None:
console.print(
"The [code]<project>/<gateway>[/] format is not supported for gateway names."
" Can only update gateways owned by the current project"
)
exit(1)
name = args.name.name
with console.status("Updating gateway..."):
if args.set_default:
self.api.client.gateways.set_default(self.api.project, name)
if args.domain:
self.api.client.gateways.set_wildcard_domain(self.api.project, name, args.domain)
gateway = self.api.client.gateways.get(self.api.project, name)
if args.name.project is not None:
console.print(
"The [code]<project>/<gateway>[/] format is not supported for gateway names"
" when [code]--domain[/] is passed."
" Can only update gateways owned by the current project"
)
exit(1)
self.api.client.gateways.set_wildcard_domain(
self.api.project, args.name.name, args.domain
)
if args.set_default:
self.api.client.gateways.set_default(
self.api.project,
gateway_name=args.name.name,
gateway_project=args.name.project,
)
gateway = get_gateway_relative_to_project(
client=self.api.client.gateways,
project=self.api.project,
gateway_project=args.name.project or self.api.project,
gateway_name=args.name.name,
)
print_gateways_table([gateway], current_project=self.api.project)

def _get(self, args: argparse.Namespace):
# TODO: Implement non-json output format
try:
gateway = self.api.client.gateways.get(
project_name=args.name.project or self.api.project,
gateway_name=args.name.name,
)
except ResourceNotExistsError:
console.print("Gateway not found")
exit(1)

gateway = get_gateway_relative_to_project(
client=self.api.client.gateways,
project=self.api.project,
gateway_project=args.name.project or self.api.project,
gateway_name=args.name.name,
)
print(pydantic_orjson_dumps_with_indent(gateway.dict(), default=None))
28 changes: 28 additions & 0 deletions src/dstack/_internal/cli/utils/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,36 @@

from dstack._internal.cli.models.gateways import GatewayCommandOutput
from dstack._internal.cli.utils.common import add_row_from_dict, console, format_entity_reference
from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.gateways import Gateway
from dstack._internal.utils.common import DateFormatter, pretty_date
from dstack.api.server._gateways import GatewaysAPIClient


def get_gateway_relative_to_project(
client: GatewaysAPIClient, project: str, gateway_project: str, gateway_name: str
) -> Gateway:
"""
Retrieves a single gateway, ensuring that `Gateway.default` is resolved relative to
`project` rather than relative to the gateway's host project.
"""
if project == gateway_project:
return client.get(project, gateway_name)

# For imported gateways, use `list`.
# `get` would resolve `Gateway.default` relative to the gateway's host project
gateways = client.list(project, include_imported=True)
for gateway in gateways:
if gateway.name == gateway_name and (
gateway_project == gateway.project_name
# Compatibility with pre-0.20.20 servers:
# gateway.project_name is None means the gateway is in the current `project`
or (gateway.project_name is None and gateway_project == project)
):
return gateway
ref = EntityReference(name=gateway_name, project=gateway_project)
raise ResourceNotExistsError(msg=f"Gateway {ref.format()!r} not found in project {project!r}")


def print_gateways_table(gateways: List[Gateway], current_project: str, verbose: bool = False):
Expand Down
8 changes: 8 additions & 0 deletions src/dstack/_internal/core/compatibility/gateways.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dstack._internal.core.models.common import IncludeExcludeDictType
from dstack._internal.core.models.gateways import GatewayConfiguration, GatewaySpec
from dstack._internal.server.schemas.gateways import SetDefaultGatewayRequest


def get_gateway_spec_excludes(gateway_spec: GatewaySpec) -> IncludeExcludeDictType:
Expand All @@ -26,6 +27,13 @@ def get_create_gateway_excludes(configuration: GatewayConfiguration) -> IncludeE
return create_gateway_excludes


def get_set_default_gateway_excludes(request: SetDefaultGatewayRequest) -> IncludeExcludeDictType:
excludes: IncludeExcludeDictType = {}
if request.gateway_project is None:
excludes["gateway_project"] = True
return excludes


def _get_gateway_configuration_excludes(
configuration: GatewayConfiguration,
) -> IncludeExcludeDictType:
Expand Down
7 changes: 4 additions & 3 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,10 @@ class ProjectModel(BaseModel):
default_gateway_id: Mapped[Optional[uuid.UUID]] = mapped_column(
ForeignKey("gateways.id", use_alter=True, ondelete="SET NULL"), nullable=True
)
default_gateway: Mapped[Optional["GatewayModel"]] = relationship(
foreign_keys=[default_gateway_id]
)
"""
**NOTE**: default_gateway_id may point to a previously imported gateway that the project is no
longer authorized to use. Check access before using the gateway.
"""

# TODO: drop `default_pool_id` after the release without pools.
default_pool_id: Mapped[Optional[UUIDType]] = mapped_column(
Expand Down
8 changes: 7 additions & 1 deletion src/dstack/_internal/server/routers/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dstack._internal.server.schemas.gateways as schemas
import dstack._internal.server.services.gateways as gateways
from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.core.models.common import EntityReference
from dstack._internal.server.db import get_session
from dstack._internal.server.deps import Project
from dstack._internal.server.models import ProjectModel, UserModel
Expand Down Expand Up @@ -104,7 +105,12 @@ async def set_default_gateway(
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
):
user, project = user_project
await gateways.set_default_gateway(session=session, project=project, name=body.name, user=user)
await gateways.set_default_gateway(
session=session,
project=project,
ref=EntityReference(name=body.name, project=body.gateway_project),
user=user,
)


@router.post("/set_wildcard_domain", response_model=models.Gateway)
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/schemas/gateways.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from dstack._internal.core.models.common import CoreConfig, CoreModel, generate_dual_core_model
from dstack._internal.core.models.gateways import GatewayConfiguration
Expand Down Expand Up @@ -28,6 +28,7 @@ class DeleteGatewaysRequest(CoreModel):

class SetDefaultGatewayRequest(CoreModel):
name: str
gateway_project: Optional[str] = None


class SetWildcardDomainRequest(CoreModel):
Expand Down
35 changes: 26 additions & 9 deletions src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,10 @@ async def create_gateway(
default_gateway = await get_project_default_gateway_model(session=session, project=project)
if default_gateway is None or configuration.default:
await set_default_gateway(
session=session, project=project, name=configuration.name, user=user
session=session,
project=project,
ref=EntityReference(name=configuration.name, project=None),
user=user,
)
default_gateway = gateway
pipeline_hinter.hint_fetch(GatewayModel.__name__)
Expand Down Expand Up @@ -394,18 +397,18 @@ async def set_gateway_wildcard_domain(


async def set_default_gateway(
session: AsyncSession, project: ProjectModel, name: str, user: Optional[UserModel]
session: AsyncSession, project: ProjectModel, ref: EntityReference, user: Optional[UserModel]
):
gateway = await get_project_gateway_model_by_reference(
session=session, project=project, ref=EntityReference(name=name, project=None)
session=session, project=project, ref=ref
)
if gateway is None:
raise ResourceNotExistsError()
if gateway.to_be_deleted:
raise ServerClientError("Cannot set gateway marked for deletion as default")
if project.default_gateway_id == gateway.id:
return
previous_gateway = await get_project_default_gateway_model(session, project)
if previous_gateway is not None and previous_gateway.id == gateway.id:
return
await session.execute(
update(ProjectModel)
.where(
Expand All @@ -418,15 +421,21 @@ async def set_default_gateway(
if previous_gateway is not None:
events.emit(
session,
"Gateway unset as default",
"Gateway unset as project default",
actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(),
targets=[events.Target.from_model(previous_gateway)],
targets=[
events.Target.from_model(previous_gateway),
events.Target.from_model(project),
],
)
events.emit(
session,
"Gateway set as default",
"Gateway set as project default",
actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(),
targets=[events.Target.from_model(gateway)],
targets=[
events.Target.from_model(gateway),
events.Target.from_model(project),
],
)
await session.commit()

Expand Down Expand Up @@ -531,6 +540,14 @@ async def get_project_default_gateway_model(
stmt = select(GatewayModel).where(
GatewayModel.id == project.default_gateway_id,
GatewayModel.to_be_deleted == False,
or_(
GatewayModel.project_id == project.id,
exists().where(
ImportModel.project_id == project.id,
ImportModel.export_id == ExportedGatewayModel.export_id,
ExportedGatewayModel.gateway_id == GatewayModel.id,
),
),
)
if load_gateway_compute:
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))
Expand Down
20 changes: 15 additions & 5 deletions src/dstack/api/server/_gateways.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List
from typing import List, Optional

from pydantic import parse_obj_as

from dstack._internal.core.compatibility.gateways import get_create_gateway_excludes
from dstack._internal.core.compatibility.gateways import (
get_create_gateway_excludes,
get_set_default_gateway_excludes,
)
from dstack._internal.core.models.gateways import Gateway, GatewayConfiguration
from dstack._internal.server.schemas.gateways import (
CreateGatewayRequest,
Expand Down Expand Up @@ -44,9 +47,16 @@ def delete(self, project_name: str, gateways_names: List[str]) -> None:
body = DeleteGatewaysRequest(names=gateways_names)
self._request(f"/api/project/{project_name}/gateways/delete", body=body.json())

def set_default(self, project_name: str, gateway_name: str) -> None:
body = SetDefaultGatewayRequest(name=gateway_name)
self._request(f"/api/project/{project_name}/gateways/set_default", body=body.json())
def set_default(
self, project_name: str, gateway_name: str, *, gateway_project: Optional[str] = None
) -> None:
if gateway_project == project_name:
gateway_project = None # omit for compatibility with pre-0.20.20 servers
body = SetDefaultGatewayRequest(name=gateway_name, gateway_project=gateway_project)
self._request(
f"/api/project/{project_name}/gateways/set_default",
body=body.json(exclude=get_set_default_gateway_excludes(body)),
)

def set_wildcard_domain(
self, project_name: str, gateway_name: str, wildcard_domain: str
Expand Down
Loading
Loading