Skip to content

Commit 0585e95

Browse files
authored
support per-replica-group image, docker, python, nvcc, privileged (#3832)
* support per-replica-group image, docker, python, nvcc, privileged * Merge Conflict Resolved * Resolve Review Comments --------- Co-authored-by: Bihan Rana
1 parent 9e23658 commit 0585e95

5 files changed

Lines changed: 1048 additions & 9 deletions

File tree

src/dstack/_internal/core/compatibility/runs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,16 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType:
108108
replica_group_excludes["router"] = True
109109
if all(g.scaling is None or g.scaling.window is None for g in replicas):
110110
replica_group_excludes["scaling"] = {"window": True}
111+
if all(g.image is None for g in replicas):
112+
replica_group_excludes["image"] = True
113+
if all(g.docker is None for g in replicas):
114+
replica_group_excludes["docker"] = True
115+
if all(g.python is None for g in replicas):
116+
replica_group_excludes["python"] = True
117+
if all(g.nvcc is None for g in replicas):
118+
replica_group_excludes["nvcc"] = True
119+
if all(g.privileged is None for g in replicas):
120+
replica_group_excludes["privileged"] = True
111121
if replica_group_excludes:
112122
configuration_excludes["replicas"] = {"__all__": replica_group_excludes}
113123

src/dstack/_internal/core/models/configurations.py

Lines changed: 165 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,39 @@ class ReplicaGroup(CoreModel):
840840
CommandsList,
841841
Field(description="The shell commands to run for replicas in this group"),
842842
] = []
843+
image: Annotated[
844+
Optional[str],
845+
Field(
846+
description="The name of the Docker image to run for replicas in this group. "
847+
"Mutually exclusive with group-level `docker` and `python`."
848+
),
849+
] = None
850+
python: Annotated[
851+
Optional[PythonVersion],
852+
Field(
853+
description="The major version of Python for replicas in this group. "
854+
"Mutually exclusive with group-level `image` and `docker`."
855+
),
856+
] = None
857+
nvcc: Annotated[
858+
Optional[bool],
859+
Field(
860+
description="Use the image with NVIDIA CUDA Compiler (NVCC) included for replicas in this group. "
861+
"Mutually exclusive with group-level `docker`."
862+
),
863+
] = None
864+
docker: Annotated[
865+
Optional[bool],
866+
Field(
867+
description="Use the docker-in-docker image for this group "
868+
"(injects `start-dockerd` and runs privileged). Mutually "
869+
"exclusive with group-level `image`, `python`, and `nvcc`."
870+
),
871+
] = None
872+
privileged: Annotated[
873+
Optional[bool],
874+
Field(description="Run replicas in this group in privileged mode."),
875+
] = None
843876
router: Annotated[
844877
Optional[ReplicaGroupRouterConfig],
845878
Field(
@@ -858,6 +891,42 @@ def validate_name(cls, v: Optional[str]) -> Optional[str]:
858891
def convert_count(cls, v: Range[int]) -> Range[int]:
859892
return _validate_replica_range(v)
860893

894+
@validator("python", pre=True, always=True)
895+
def convert_python(cls, v, values) -> Optional[PythonVersion]:
896+
if v is not None and values.get("image"):
897+
raise ValueError("`image` and `python` are mutually exclusive within a replica group")
898+
if isinstance(v, float):
899+
v = str(v)
900+
if v == "3.1":
901+
v = "3.10"
902+
if isinstance(v, str):
903+
return PythonVersion(v)
904+
return v
905+
906+
@validator("docker", pre=True, always=True)
907+
def _docker(cls, v, values) -> Optional[bool]:
908+
if v is True and values.get("image"):
909+
raise ValueError("`image` and `docker` are mutually exclusive within a replica group")
910+
if v is True and values.get("python"):
911+
raise ValueError("`python` and `docker` are mutually exclusive within a replica group")
912+
if v is True and values.get("nvcc"):
913+
raise ValueError("`nvcc` and `docker` are mutually exclusive within a replica group")
914+
return v
915+
916+
@validator("privileged", pre=True, always=True)
917+
def _privileged(cls, v, values) -> Optional[bool]:
918+
# Docker-in-docker requires privileged mode. The service level
919+
# cannot enforce this rule because its `privileged` field defaults
920+
# to `False` (existing backwards-compatibility constraint), so it
921+
# cannot distinguish "unset" from explicit `False`. At the group
922+
# level we keep `privileged` as `Optional[bool] = None`, so we can.
923+
if v is False and values.get("docker") is True:
924+
raise ValueError(
925+
"`privileged: false` is incompatible with `docker: true` within "
926+
"a replica group (docker-in-docker requires privileged mode)"
927+
)
928+
return v
929+
861930
@root_validator()
862931
def validate_scaling(cls, values):
863932
scaling = values.get("scaling")
@@ -1057,22 +1126,113 @@ def validate_top_level_properties_with_replica_groups(cls, values):
10571126

10581127
return values
10591128

1129+
@root_validator()
1130+
def validate_no_mixed_service_and_group_container_fields(cls, values):
1131+
"""
1132+
When `replicas` is a list, each of `image`, `docker`, `privileged`, `python`, and
1133+
`nvcc` may be set at the service level OR in replica groups, never both. Mixing is
1134+
rejected — including partial mixing, where only some groups set a
1135+
field the service also sets — because it leaves precedence ambiguous.
1136+
"""
1137+
replicas = values.get("replicas")
1138+
if not isinstance(replicas, list):
1139+
return values
1140+
1141+
checks = [
1142+
(
1143+
"image",
1144+
values.get("image") is not None,
1145+
lambda g: g.image is not None,
1146+
),
1147+
(
1148+
"docker",
1149+
values.get("docker") is True,
1150+
lambda g: g.docker is not None,
1151+
),
1152+
(
1153+
"privileged",
1154+
values.get("privileged") is True,
1155+
lambda g: g.privileged is not None,
1156+
),
1157+
(
1158+
"python",
1159+
values.get("python") is not None,
1160+
lambda g: g.python is not None,
1161+
),
1162+
(
1163+
"nvcc",
1164+
values.get("nvcc") is True,
1165+
lambda g: g.nvcc is not None,
1166+
),
1167+
]
1168+
1169+
for field, service_set, group_set in checks:
1170+
if service_set:
1171+
conflicting = [g.name for g in replicas if group_set(g)]
1172+
if conflicting:
1173+
raise ValueError(
1174+
f"`{field}` is set at both the service level and in "
1175+
f"replica group(s) {conflicting}. Set `{field}` in one "
1176+
f"place only — either at the service level (all groups "
1177+
f"inherit) or per group, but not both."
1178+
)
1179+
return values
1180+
1181+
@root_validator()
1182+
def validate_no_conflicting_image_sources_across_levels(cls, values):
1183+
"""
1184+
A service-level image-source field (`image`, `docker`, `python`, `nvcc`) cannot be
1185+
combined with a different one set in a replica group; only `python` with `nvcc` may be mixed.
1186+
"""
1187+
replicas = values.get("replicas")
1188+
if not isinstance(replicas, list):
1189+
return values
1190+
1191+
forbidden = [
1192+
("image", values.get("image") is not None, "docker", lambda g: g.docker is not None),
1193+
("image", values.get("image") is not None, "python", lambda g: g.python is not None),
1194+
("image", values.get("image") is not None, "nvcc", lambda g: g.nvcc is not None),
1195+
("docker", values.get("docker") is True, "image", lambda g: g.image is not None),
1196+
("docker", values.get("docker") is True, "python", lambda g: g.python is not None),
1197+
("docker", values.get("docker") is True, "nvcc", lambda g: g.nvcc is not None),
1198+
("python", values.get("python") is not None, "image", lambda g: g.image is not None),
1199+
("python", values.get("python") is not None, "docker", lambda g: g.docker is not None),
1200+
("nvcc", values.get("nvcc") is True, "image", lambda g: g.image is not None),
1201+
("nvcc", values.get("nvcc") is True, "docker", lambda g: g.docker is not None),
1202+
]
1203+
1204+
for s_field, s_set, g_field, g_pred in forbidden:
1205+
if s_set:
1206+
conflicting = [g.name for g in replicas if g_pred(g)]
1207+
if conflicting:
1208+
raise ValueError(
1209+
f"Service-level `{s_field}` conflicts with group-level "
1210+
f"`{g_field}` in replica group(s) {conflicting}. "
1211+
f"These image-source fields are mutually exclusive."
1212+
)
1213+
return values
1214+
10601215
@root_validator()
10611216
def validate_replica_groups_have_commands_or_image(cls, values):
10621217
"""
1063-
When replicas is a list, ensure each ReplicaGroup has commands OR service has image.
1218+
When replicas is a list, ensure each ReplicaGroup has something
1219+
to run. Mirrors the service-level rule: either explicit
1220+
`commands` or an `image` (group-level or service-level) is
1221+
required.
10641222
"""
10651223
replicas = values.get("replicas")
1066-
image = values.get("image")
10671224

10681225
if not isinstance(replicas, list):
10691226
return values
10701227

1228+
service_has_image = values.get("image") is not None
1229+
10711230
for group in replicas:
1072-
if not group.commands and not image:
1231+
if not group.commands and group.image is None and not service_has_image:
10731232
raise ValueError(
1074-
f"Replica group '{group.name}' has no commands. "
1075-
"Either set `commands` in the replica group or set `image` at the service level."
1233+
f"Replica group '{group.name}': either `commands` or "
1234+
"`image` must be set in the group, or `image` at the "
1235+
"service level."
10761236
)
10771237

10781238
return values

src/dstack/_internal/server/services/jobs/configurators/service.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,111 @@
11
from typing import List, Optional
22

3-
from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType
3+
from dstack._internal import settings
4+
from dstack._internal.core.models.configurations import (
5+
PortMapping,
6+
ReplicaGroup,
7+
RunConfigurationType,
8+
)
49
from dstack._internal.core.models.profiles import SpotPolicy
5-
from dstack._internal.server.services.jobs.configurators.base import JobConfigurator
10+
from dstack._internal.core.models.unix import UnixUser
11+
from dstack._internal.server.services.jobs.configurators.base import (
12+
JobConfigurator,
13+
get_default_image,
14+
)
615

716

817
class ServiceJobConfigurator(JobConfigurator):
918
TYPE: RunConfigurationType = RunConfigurationType.SERVICE
1019

11-
def _shell_commands(self) -> List[str]:
20+
def _current_replica_group(self) -> Optional[ReplicaGroup]:
1221
assert self.run_spec.configuration.type == "service"
1322
for group in self.run_spec.configuration.replica_groups:
1423
if group.name == self.replica_group_name:
15-
return group.commands
24+
return group
25+
return None
26+
27+
def _shell_commands(self) -> List[str]:
28+
assert self.run_spec.configuration.type == "service"
29+
group = self._current_replica_group()
30+
if group is not None:
31+
return group.commands
1632
return self.run_spec.configuration.commands
1733

34+
def _image_name(self) -> str:
35+
group = self._current_replica_group()
36+
if group is not None:
37+
if group.docker is True:
38+
return settings.DSTACK_DIND_IMAGE
39+
if group.image is not None:
40+
return group.image
41+
if group.nvcc is True:
42+
return get_default_image(nvcc=True)
43+
return super()._image_name()
44+
45+
def _privileged(self) -> bool:
46+
group = self._current_replica_group()
47+
if group is not None:
48+
if group.docker is True:
49+
return True
50+
if group.privileged is not None:
51+
return group.privileged
52+
return super()._privileged()
53+
54+
def _dstack_image_commands(self) -> List[str]:
55+
group = self._current_replica_group()
56+
if group is not None:
57+
if group.docker is True:
58+
return ["start-dockerd"]
59+
if group.image is not None:
60+
return []
61+
return super()._dstack_image_commands()
62+
63+
def _shell(self) -> str:
64+
# Shell resolution order:
65+
# 1. If `shell:` is set explicitly, the base honors it.
66+
# 2. If this group sets `docker: true`, use /bin/bash — the
67+
# DIND image ships bash, matching the service-level path.
68+
# 3. If this group sets its own `image`, force /bin/sh. The
69+
# base returns /bin/bash when service-level `image` is None,
70+
# but a group-level custom image (e.g. alpine) may not ship
71+
# bash.
72+
# 4. Otherwise defer to the base (bash for dstackai/base, sh
73+
# for a service-level custom image).
74+
if self.run_spec.configuration.shell is None:
75+
group = self._current_replica_group()
76+
if group is not None:
77+
if group.docker is True:
78+
return "/bin/bash"
79+
if group.image is not None:
80+
return "/bin/sh"
81+
return super()._shell()
82+
83+
async def _user(self) -> Optional[UnixUser]:
84+
# Base `_user()` only queries the image for a default user when
85+
# `configuration.image` is set at the service level. When the
86+
# group supplies its own `image`, perform the lookup here so the
87+
# container runs as that image's default user.
88+
#
89+
# We intentionally do NOT look up the DIND image when the group
90+
# sets `docker: true`. That matches service-level behavior: when
91+
# `configuration.docker is True`, `configuration.image` is None,
92+
# so the base skips the lookup. DIND is always privileged and
93+
# effectively root anyway.
94+
if self.run_spec.configuration.user is None:
95+
group = self._current_replica_group()
96+
if group is not None and group.image is not None:
97+
image_config = await self._get_image_config()
98+
if image_config.user is None:
99+
return None
100+
return UnixUser.parse(image_config.user)
101+
return await super()._user()
102+
103+
def _python(self) -> str:
104+
group = self._current_replica_group()
105+
if group is not None and group.python is not None:
106+
return group.python.value
107+
return super()._python()
108+
18109
def _default_single_branch(self) -> bool:
19110
return True
20111

0 commit comments

Comments
 (0)