Skip to content
31 changes: 31 additions & 0 deletions src/dstack/_internal/cli/commands/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def _register(self):
help="Fleet name to export (can be specified multiple times)",
default=[],
)
create_parser.add_argument(
"--gateway",
action="append",
dest="gateways",
help="Gateway name to export (can be specified multiple times)",
default=[],
)
create_parser.set_defaults(subfunc=self._create)

update_parser = subparsers.add_parser(
Expand Down Expand Up @@ -80,6 +87,20 @@ def _register(self):
help="Fleet name to remove (can be specified multiple times)",
default=[],
)
update_parser.add_argument(
"--add-gateway",
action="append",
dest="add_gateways",
help="Gateway name to add (can be specified multiple times)",
default=[],
)
update_parser.add_argument(
"--remove-gateway",
action="append",
dest="remove_gateways",
help="Gateway name to remove (can be specified multiple times)",
default=[],
)
update_parser.set_defaults(subfunc=self._update)

delete_parser = subparsers.add_parser(
Expand Down Expand Up @@ -109,6 +130,7 @@ def _create(self, args: argparse.Namespace):
name=args.name,
importer_projects=args.importers,
exported_fleets=args.fleets,
exported_gateways=args.gateways,
)
print_exports_table([export])

Expand All @@ -121,6 +143,8 @@ def _update(self, args: argparse.Namespace):
remove_importer_projects=args.remove_importers,
add_exported_fleets=args.add_fleets,
remove_exported_fleets=args.remove_fleets,
add_exported_gateways=args.add_gateways,
remove_exported_gateways=args.remove_gateways,
)
print_exports_table([export])

Expand All @@ -139,17 +163,24 @@ def print_exports_table(exports: list[Export]):
table = Table(box=None)
table.add_column("NAME", no_wrap=True)
table.add_column("FLEETS")
table.add_column("GATEWAYS")
table.add_column("IMPORTERS")

for export in exports:
fleets = (
", ".join([f.name for f in export.exported_fleets]) if export.exported_fleets else "-"
)
gateways = (
", ".join([g.name for g in export.exported_gateways])
if export.exported_gateways
else "-"
)
importers = ", ".join([i.project_name for i in export.imports]) if export.imports else "-"

row = {
"NAME": export.name,
"FLEETS": fleets,
"GATEWAYS": gateways,
"IMPORTERS": importers,
}
add_row_from_dict(table, row)
Expand Down
56 changes: 39 additions & 17 deletions src/dstack/_internal/cli/commands/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
print_gateways_table,
)
from dstack._internal.core.errors import CLIError, ResourceNotExistsError
from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.gateways import GatewayStatus
from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent
from dstack._internal.utils.logging import get_logger
Expand Down Expand Up @@ -67,7 +68,7 @@ def _register(self):
)
delete_parser.set_defaults(subfunc=self._delete)
delete_parser.add_argument(
"name", help="The name of the gateway"
"name", type=EntityReference.parse, help="The name of the gateway"
).completer = GatewayNameCompleter() # type: ignore[attr-defined]
delete_parser.add_argument(
"-y", "--yes", action="store_true", help="Don't ask for confirmation"
Expand All @@ -78,7 +79,7 @@ def _register(self):
)
update_parser.set_defaults(subfunc=self._update)
update_parser.add_argument(
"name", help="The name of the gateway"
"name", type=EntityReference.parse, help="The name of the gateway"
).completer = GatewayNameCompleter() # type: ignore[attr-defined]
update_parser.add_argument(
"--set-default", action="store_true", help="Set it the default gateway for the project"
Expand All @@ -89,7 +90,7 @@ def _register(self):
"get", help="Get a gateway", formatter_class=self._parser.formatter_class
)
get_parser.add_argument(
"name", metavar="NAME", help="The name of the gateway"
"name", metavar="NAME", type=EntityReference.parse, help="The name of the gateway"
).completer = GatewayNameCompleter() # type: ignore[attr-defined]
get_parser.add_argument(
"--json",
Expand All @@ -108,7 +109,7 @@ def _list(self, args: argparse.Namespace):
if args.watch and args.format == "json":
raise CLIError("JSON output is not supported together with --watch")

gateways = self.api.client.gateways.list(self.api.project)
gateways = self.api.client.gateways.list(self.api.project, include_imported=True)
deprecated_router_gateways = [
g.name
for g in gateways
Expand All @@ -127,45 +128,66 @@ def _list(self, args: argparse.Namespace):
if args.format == "json":
print_gateways_json(gateways, project=self.api.project)
else:
print_gateways_table(gateways, verbose=args.verbose)
print_gateways_table(
gateways, current_project=self.api.project, verbose=args.verbose
)
return

try:
with Live(console=console, refresh_per_second=LIVE_TABLE_REFRESH_RATE_PER_SEC) as live:
while True:
live.update(get_gateways_table(gateways, verbose=args.verbose))
live.update(
get_gateways_table(
gateways, current_project=self.api.project, verbose=args.verbose
)
)
time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
gateways = self.api.client.gateways.list(self.api.project)
gateways = self.api.client.gateways.list(
self.api.project, include_imported=True
)
except KeyboardInterrupt:
pass

def _delete(self, args: argparse.Namespace):
gateway = self.api.client.gateways.get(self.api.project, args.name)
print_gateways_table([gateway])
if args.name.project is not None:
console.print(
"The [code]<project>/<gateway>[/] format is not supported for gateway names."
" Can only delete gateways owned by the current project"
)
exit(1)
name = args.name.name
gateway = self.api.client.gateways.get(self.api.project, name)
print_gateways_table([gateway], current_project=self.api.project)
if args.yes or confirm_ask("Do you want to delete the gateway?"):
with console.status("Deleting gateway..."):
self.api.client.gateways.delete(self.api.project, [args.name])
self.api.client.gateways.delete(self.api.project, [name])
console.print("Gateway deleted")
else:
console.print("Exiting...")
return

def _update(self, args: argparse.Namespace):
if args.name.project is not None:
console.print(
"The [code]<project>/<gateway>[/] format is not supported for gateway names."
" Can only update gateways owned by the current project"
)
exit(1)
name = args.name.name
with console.status("Updating gateway..."):
if args.set_default:
self.api.client.gateways.set_default(self.api.project, args.name)
self.api.client.gateways.set_default(self.api.project, name)
if args.domain:
self.api.client.gateways.set_wildcard_domain(
self.api.project, args.name, args.domain
)
gateway = self.api.client.gateways.get(self.api.project, args.name)
print_gateways_table([gateway])
self.api.client.gateways.set_wildcard_domain(self.api.project, name, args.domain)
gateway = self.api.client.gateways.get(self.api.project, name)
print_gateways_table([gateway], current_project=self.api.project)

def _get(self, args: argparse.Namespace):
# TODO: Implement non-json output format
try:
gateway = self.api.client.gateways.get(
project_name=self.api.project, gateway_name=args.name
project_name=args.name.project or self.api.project,
gateway_name=args.name.name,
)
except ResourceNotExistsError:
console.print("Gateway not found")
Expand Down
7 changes: 7 additions & 0 deletions src/dstack/_internal/cli/commands/import_.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def print_imports_table(imports: list[Import]):
table = Table(box=None)
table.add_column("NAME", no_wrap=True)
table.add_column("FLEETS")
table.add_column("GATEWAYS")

for imp in imports:
name = f"{imp.export.project_name}/{imp.export.name}"
Expand All @@ -76,10 +77,16 @@ def print_imports_table(imports: list[Import]):
if imp.export.exported_fleets
else "-"
)
gateways = (
", ".join([g.name for g in imp.export.exported_gateways])
if imp.export.exported_gateways
else "-"
)

row = {
"NAME": name,
"FLEETS": fleets,
"GATEWAYS": gateways,
}
add_row_from_dict(table, row)

Expand Down
5 changes: 4 additions & 1 deletion src/dstack/_internal/cli/services/configurators/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ def apply_configuration(
f"Provisioning [code]{gateway.name}[/]...", console=console
) as live:
while not _finished_provisioning(gateway):
table = get_gateways_table([gateway], include_created=True)
table = get_gateways_table(
[gateway], current_project=self.api.project, include_created=True
)
live.update(table)
time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
gateway = self.api.client.gateways.get(self.api.project, gateway.name)
Expand All @@ -139,6 +141,7 @@ def apply_configuration(
console.print(
get_gateways_table(
[gateway],
current_project=self.api.project,
verbose=gateway.status == GatewayStatus.FAILED,
include_created=True,
format_date=local_time,
Expand Down
7 changes: 7 additions & 0 deletions src/dstack/_internal/cli/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ def resolve_url(url: str, timeout: float = 5.0) -> str:
return response.url


def format_entity_reference(name: str, project: str, current_project: str) -> str:
if current_project == project:
return name
else:
return f"{project}/{name}"


def format_instance_availability(v: InstanceAvailability) -> str:
if v in (InstanceAvailability.UNKNOWN, InstanceAvailability.AVAILABLE):
return ""
Expand Down
8 changes: 2 additions & 6 deletions src/dstack/_internal/cli/utils/fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from rich.table import Table

from dstack._internal.cli.utils.common import add_row_from_dict, console
from dstack._internal.cli.utils.common import add_row_from_dict, console, format_entity_reference
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.fleets import Fleet, FleetNodesSpec, FleetStatus
from dstack._internal.core.models.instances import Instance, InstanceStatus
Expand Down Expand Up @@ -43,10 +43,6 @@ def get_fleets_table(
config = fleet.spec.configuration
merged_profile = fleet.spec.merged_profile

name = fleet.name
if fleet.project_name != current_project:
name = f"{fleet.project_name}/{fleet.name}"

# Detect SSH fleet vs backend fleet
if config.ssh_config is not None:
# SSH fleet: fixed number of hosts, no cloud billing
Expand Down Expand Up @@ -76,7 +72,7 @@ def get_fleets_table(
nodes = f"{nodes} (cluster)"

fleet_row = {
"NAME": name,
"NAME": format_entity_reference(fleet.name, fleet.project_name, current_project),
"NODES": nodes,
"RESOURCES": resources,
"GPU": gpu,
Expand Down
16 changes: 12 additions & 4 deletions src/dstack/_internal/cli/utils/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from rich.table import Table

from dstack._internal.cli.models.gateways import GatewayCommandOutput
from dstack._internal.cli.utils.common import add_row_from_dict, console
from dstack._internal.cli.utils.common import add_row_from_dict, console, format_entity_reference
from dstack._internal.core.models.gateways import Gateway
from dstack._internal.utils.common import DateFormatter, pretty_date


def print_gateways_table(gateways: List[Gateway], verbose: bool = False):
table = get_gateways_table(gateways, verbose=verbose)
def print_gateways_table(gateways: List[Gateway], current_project: str, verbose: bool = False):
table = get_gateways_table(gateways, current_project, verbose=verbose)
console.print(table)
console.print()

Expand All @@ -25,6 +25,7 @@ def print_gateways_json(gateways: List[Gateway], project: str) -> None:

def get_gateways_table(
gateways: List[Gateway],
current_project: str,
verbose: bool = False,
include_created: bool = False,
format_date: DateFormatter = pretty_date,
Expand All @@ -42,8 +43,15 @@ def get_gateways_table(
table.add_column("ERROR")

for gateway in gateways:
name = format_entity_reference(
gateway.name,
# project_name == None means pre-0.20.20 server, which means no gateway exports support,
# which means the gateway is from the current project
gateway.project_name if gateway.project_name is not None else current_project,
current_project,
)
row = {
"NAME": gateway.name,
"NAME": name,
"BACKEND": f"{gateway.configuration.backend.value} ({gateway.configuration.region})",
"HOSTNAME": gateway.hostname,
"DOMAIN": gateway.wildcard_domain,
Expand Down
18 changes: 18 additions & 0 deletions src/dstack/_internal/core/compatibility/exports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dstack._internal.core.models.common import IncludeExcludeDictType
from dstack._internal.server.schemas.exports import CreateExportRequest, UpdateExportRequest


def get_create_export_excludes(request: CreateExportRequest) -> IncludeExcludeDictType:
excludes: IncludeExcludeDictType = {}
if not request.exported_gateways:
excludes["exported_gateways"] = True
return excludes


def get_update_export_excludes(request: UpdateExportRequest) -> IncludeExcludeDictType:
excludes: IncludeExcludeDictType = {}
if not request.add_exported_gateways:
excludes["add_exported_gateways"] = True
if not request.remove_exported_gateways:
excludes["remove_exported_gateways"] = True
return excludes
5 changes: 5 additions & 0 deletions src/dstack/_internal/core/compatibility/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dstack._internal.core.compatibility.common import patch_profile_params
from dstack._internal.core.models.common import (
EntityReference,
IncludeExcludeDictType,
IncludeExcludeSetType,
)
Expand Down Expand Up @@ -177,3 +178,7 @@ def patch_run_spec(run_spec: RunSpec) -> None:
patch_profile_params(run_spec.configuration)
if run_spec.profile is not None:
patch_profile_params(run_spec.profile)
if isinstance(run_spec.configuration, ServiceConfiguration):
if isinstance(run_spec.configuration.gateway, EntityReference):
# Pre-0.20.20 servers do not support `EntityReference` in `gateway`
run_spec.configuration.gateway = run_spec.configuration.gateway.format()
Loading
Loading