Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 42 additions & 18 deletions src/dstack/_internal/core/backends/kubernetes/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@
)
from dstack._internal.core.backends.kubernetes.utils import (
call_api_method,
get_api_from_config_data,
get_api_from_kubeconfig_dict,
kubeconfig_data_to_kubeconfig_dict,
kubeconfig_dict_to_kubeconfig,
)
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
from dstack._internal.core.errors import ComputeError, ProvisioningError
Expand Down Expand Up @@ -127,7 +129,29 @@ def __init__(self, config: KubernetesConfig):
if proxy_jump is None:
proxy_jump = KubernetesProxyJumpConfig()
self.proxy_jump = proxy_jump
self.api = get_api_from_config_data(config.kubeconfig.data)
kubeconfig_dict = kubeconfig_data_to_kubeconfig_dict(config.kubeconfig.data)
self.api = get_api_from_kubeconfig_dict(kubeconfig_dict)
kubeconfig = kubeconfig_dict_to_kubeconfig(kubeconfig_dict)
current_context = kubeconfig.get_context()
if current_context.namespace != config.namespace:
logger.warning(
(
"Namespace mismatch: kubeconfig -> '%s', backend config -> '%s'."
" The current dstack version ignores kubeconfig"
" and uses deprecated namespace property from backend config."
" Future versions will use namespace from kubeconfig."
" To keep using '%s' namespace in future versions and suppress this warning,"
" set namespace to '%s' in kubeconfig context '%s'"
),
current_context.namespace,
config.namespace,
config.namespace,
config.namespace,
kubeconfig.current_context,
)
# TODO: switch to current_context.namespace
self.namespace = config.namespace
logger.debug("Using namespace '%s'", self.namespace)

def get_offers_by_requirements(
self, requirements: Requirements
Expand Down Expand Up @@ -156,7 +180,7 @@ def run_job(
jump_pod_service_name = _get_pod_service_name(jump_pod_name)
_create_jump_pod_service_if_not_exists(
api=self.api,
namespace=self.config.namespace,
namespace=self.namespace,
jump_pod_name=jump_pod_name,
jump_pod_service_name=jump_pod_service_name,
jump_pod_port=self.proxy_jump.port,
Expand All @@ -177,7 +201,7 @@ def run_job(
string_data={".dockerconfigjson": dockerconfigjson},
)
self.api.create_namespaced_secret(
namespace=self.config.namespace,
namespace=self.namespace,
body=registry_auth_secret,
)
image_pull_secrets = [client.V1LocalObjectReference(name=registry_auth_secret_name)]
Expand Down Expand Up @@ -342,11 +366,11 @@ def run_job(
),
)
self.api.create_namespaced_pod(
namespace=self.config.namespace,
namespace=self.namespace,
body=pod,
)
self.api.create_namespaced_service(
namespace=self.config.namespace,
namespace=self.namespace,
body=client.V1Service(
metadata=client.V1ObjectMeta(name=_get_pod_service_name(instance_name)),
spec=client.V1ServiceSpec(
Expand Down Expand Up @@ -395,7 +419,7 @@ def update_provisioning_data(
backend_data = KubernetesBackendData.load(provisioning_data.backend_data)
ssh_proxy = _check_and_configure_jump_pod_service(
api=self.api,
namespace=self.config.namespace,
namespace=self.namespace,
jump_pod_name=backend_data.jump_pod_name,
jump_pod_service_name=backend_data.jump_pod_service_name,
jump_pod_hostname=self.proxy_jump.hostname,
Expand All @@ -412,7 +436,7 @@ def update_provisioning_data(

pod = self.api.read_namespaced_pod(
name=provisioning_data.instance_id,
namespace=self.config.namespace,
namespace=self.namespace,
)
if pod.status is None:
return
Expand All @@ -422,7 +446,7 @@ def update_provisioning_data(
provisioning_data.internal_ip = pod_ip
service = self.api.read_namespaced_service(
name=_get_pod_service_name(provisioning_data.instance_id),
namespace=self.config.namespace,
namespace=self.namespace,
)
service_spec = get_or_error(service.spec)
provisioning_data.hostname = get_or_error(service_spec.cluster_ip)
Expand Down Expand Up @@ -450,21 +474,21 @@ def terminate_instance(
self.api.delete_namespaced_service,
expected=404,
name=_get_pod_service_name(instance_id),
namespace=self.config.namespace,
namespace=self.namespace,
body=client.V1DeleteOptions(),
)
call_api_method(
self.api.delete_namespaced_pod,
expected=404,
name=instance_id,
namespace=self.config.namespace,
namespace=self.namespace,
body=client.V1DeleteOptions(),
)
call_api_method(
self.api.delete_namespaced_secret,
expected=404,
name=_get_registry_auth_secret_name(instance_id),
namespace=self.config.namespace,
namespace=self.namespace,
body=client.V1DeleteOptions(),
)

Expand Down Expand Up @@ -520,7 +544,7 @@ def create_gateway(
),
)
self.api.create_namespaced_pod(
namespace=self.config.namespace,
namespace=self.namespace,
body=pod,
)
service = client.V1Service(
Expand Down Expand Up @@ -550,13 +574,13 @@ def create_gateway(
),
)
self.api.create_namespaced_service(
namespace=self.config.namespace,
namespace=self.namespace,
body=service,
)
# address is eiher a domain name or an IP address
address = _wait_for_load_balancer_address(
api=self.api,
namespace=self.config.namespace,
namespace=self.namespace,
service_name=_get_pod_service_name(instance_name),
)
if address is None:
Expand Down Expand Up @@ -591,7 +615,7 @@ def register_volume(self, volume: Volume) -> VolumeProvisioningData:
pvc = call_api_method(
self.api.read_namespaced_persistent_volume_claim,
expected=404,
namespace=self.config.namespace,
namespace=self.namespace,
name=pvc_name,
)
if pvc is None:
Expand Down Expand Up @@ -650,7 +674,7 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
),
)
self.api.create_namespaced_persistent_volume_claim(
namespace=self.config.namespace,
namespace=self.namespace,
body=pvc,
)
logger.debug("Created PVC %s for volume %s", pvc_name, volume.name)
Expand All @@ -671,7 +695,7 @@ def delete_volume(self, volume: Volume):
pvc = call_api_method(
self.api.delete_namespaced_persistent_volume_claim,
expected=404,
namespace=self.config.namespace,
namespace=self.namespace,
name=pvc_name,
)
if pvc is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def validate_config(
self, config: KubernetesBackendConfigWithCreds, default_creds_enabled: bool
):
try:
api = kubernetes_utils.get_api_from_config_data(config.kubeconfig.data)
api = kubernetes_utils.get_api_from_kubeconfig_data(config.kubeconfig.data)
api.list_node()
except Exception as e:
logger.debug("Invalid kubeconfig: %s", str(e))
Expand Down
14 changes: 13 additions & 1 deletion src/dstack/_internal/core/backends/kubernetes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,20 @@ class KubernetesBackendConfig(CoreModel):
Optional[KubernetesProxyJumpConfig], Field(description="The SSH proxy jump configuration")
] = None
namespace: Annotated[
str, Field(description="The namespace for resources managed by `dstack`")
str,
Field(
description=(
"The namespace for resources managed by `dstack`."
" Always overrides the namespace set in the kubeconfig, even if not set. "
" Deprecated and will be eventually removed in futute versions, but"
" in the current version must be set unless equals to `default`."
" Future versions will use the namespace from the kubeconfig instead."
" To prepare for future versions, set the same value in the kubeconfig"
)
),
] = DEFAULT_NAMESPACE
"""`namespace` is formally deprecated since 0.20.20 but still used. Future versions will switch
to namespace from kubeconfig context, which is currently ignored"""


class KubernetesBackendConfigWithCreds(KubernetesBackendConfig):
Expand Down
59 changes: 53 additions & 6 deletions src/dstack/_internal/core/backends/kubernetes/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, TypeVar, Union
from typing import Annotated, Callable, Optional, TypeVar, Union

import yaml
from kubernetes.client import CoreV1Api
Expand All @@ -7,19 +7,66 @@
# XXX: This function is missing in the stubs package
new_client_from_config_dict, # pyright: ignore[reportAttributeAccessIssue]
)
from pydantic import Field
from typing_extensions import ParamSpec

from dstack._internal.core.models.common import CoreModel

T = TypeVar("T")
P = ParamSpec("P")


def get_api_from_config_data(kubeconfig_data: str) -> CoreV1Api:
config_dict = yaml.load(kubeconfig_data, yaml.FullLoader)
return get_api_from_config_dict(config_dict)
class KubeconfigContext(CoreModel):
namespace: str = "default"


class KubeconfigNamedContext(CoreModel):
name: str
context: KubeconfigContext


class Kubeconfig(CoreModel):
"""
`Kubeconfig` model only includes fields used by `dstack`.
Reference: https://kubernetes.io/docs/reference/config-api/kubeconfig.v1/
"""

contexts: list[KubeconfigNamedContext] = []
current_context: Annotated[Optional[str], Field(alias="current-context")] = None

def get_context(self, name: Optional[str] = None) -> KubeconfigContext:
if name is None:
name = self.current_context
if name is None:
raise ValueError("current-context is not set")
for named_context in self.contexts:
if named_context.name == name:
return named_context.context
raise ValueError(f"context {name} not found")


def kubeconfig_data_to_kubeconfig_dict(kubeconfig_data: str) -> dict:
kubeconfig_dict = yaml.load(kubeconfig_data, yaml.FullLoader)
if not isinstance(kubeconfig_dict, dict):
raise TypeError(f"Unexpected kubeconfig_data type: {kubeconfig_dict.__class__.__name__}")
return kubeconfig_dict


def kubeconfig_dict_to_kubeconfig(kubeconfig_dict: dict) -> Kubeconfig:
return Kubeconfig.__response__.parse_obj(kubeconfig_dict)


def get_api_from_kubeconfig_data(
kubeconfig_data: str, *, context: Optional[str] = None
) -> CoreV1Api:
kubeconfig_dict = kubeconfig_data_to_kubeconfig_dict(kubeconfig_data)
return get_api_from_kubeconfig_dict(kubeconfig_dict, context=context)


def get_api_from_config_dict(kubeconfig: dict) -> CoreV1Api:
api_client = new_client_from_config_dict(config_dict=kubeconfig)
def get_api_from_kubeconfig_dict(
kubeconfig_dict: dict, *, context: Optional[str] = None
) -> CoreV1Api:
api_client = new_client_from_config_dict(config_dict=kubeconfig_dict, context=context)
return CoreV1Api(api_client=api_client)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_validate_config_valid(self):
proxy_jump=KubernetesProxyJumpConfig(hostname=None, port=None),
)
with patch(
"dstack._internal.core.backends.kubernetes.utils.get_api_from_config_data"
"dstack._internal.core.backends.kubernetes.utils.get_api_from_kubeconfig_data"
) as get_api_mock:
api_mock = Mock()
api_mock.list_node.return_value = Mock()
Expand All @@ -34,7 +34,7 @@ def test_validate_config_invalid_config(self):
)
with (
patch(
"dstack._internal.core.backends.kubernetes.utils.get_api_from_config_data"
"dstack._internal.core.backends.kubernetes.utils.get_api_from_kubeconfig_data"
) as get_api_mock,
pytest.raises(BackendInvalidCredentialsError) as exc_info,
):
Expand Down
Loading