From 098efbc6248477cc567500a6deec88b5d53e17c0 Mon Sep 17 00:00:00 2001 From: Pratibha Shrivastav Date: Tue, 5 May 2026 20:09:06 +0530 Subject: [PATCH] Fix DT update round-trip for allowedEnvironmentVariableOverrides and allowedInstanceType --- .../_deployment/deployment_template.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/deployment_template.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/deployment_template.py index de17a218e4d5..e323b1c77920 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/deployment_template.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/deployment_template.py @@ -104,6 +104,7 @@ def __init__( # pylint: disable=too-many-locals ) self.allowed_instance_types = allowed_instance_types self.default_instance_type = default_instance_type + self._allowed_environment_variable_overrides = None self.scoring_port = scoring_port self.scoring_path = scoring_path self.model_mount_path = model_mount_path @@ -368,6 +369,15 @@ def get_value(source, key, default=None): allowed_instance_types = get_value(properties, "allowedInstanceTypes") or get_value( obj, "allowed_instance_types" ) + # Also check additional_properties for service fields with mismatched names + if not allowed_instance_types: + additional_props = get_value(obj, "additional_properties", {}) + if isinstance(additional_props, dict): + allowed_instance_types = additional_props.get("allowedInstanceType") or additional_props.get("allowedInstanceTypes") + allowed_environment_variable_overrides = ( + get_value(properties, "allowedEnvironmentVariableOverrides") + or get_value(obj, "allowed_environment_variable_overrides") + ) scoring_port = get_value(properties, "scoringPort") or get_value(obj, "scoring_port") scoring_path = get_value(properties, "scoringPath") or get_value(obj, "scoring_path") model_mount_path = get_value(properties, "modelMountPath") or get_value(obj, "model_mount_path") @@ -399,6 +409,13 @@ def get_value(source, key, default=None): except (ValueError, SyntaxError): allowed_instance_types = None + # Parse allowed_environment_variable_overrides if it's a string + if isinstance(allowed_environment_variable_overrides, str): + try: + allowed_environment_variable_overrides = ast.literal_eval(allowed_environment_variable_overrides) + except (ValueError, SyntaxError): + allowed_environment_variable_overrides = None + # Convert request_settings to OnlineRequestSettings object using the built-in conversion method request_settings_obj = OnlineRequestSettings._from_rest_object(request_settings) if request_settings else None @@ -451,6 +468,9 @@ def get_value(source, key, default=None): # updates template._from_service = True + # Store allowed_environment_variable_overrides as private field for round-trip + template._allowed_environment_variable_overrides = allowed_environment_variable_overrides + # Store additional fields from the REST response that may be needed template.environment_id = environment_id # type: ignore[attr-defined] # Alternative name for deployment_template_type @@ -472,6 +492,7 @@ def get_value(source, key, default=None): "app_insights_enabled": get_value(obj, "app_insights_enabled"), "deployment_template_type": deployment_template_type, "allowed_instance_types": allowed_instance_types, + "allowed_environment_variable_overrides": allowed_environment_variable_overrides, "scoring_port": scoring_port, "scoring_path": scoring_path, "model_mount_path": model_mount_path, @@ -571,6 +592,10 @@ def _to_rest_object(self) -> dict: # Handle allowed instance types if hasattr(self, "allowed_instance_types") and self.allowed_instance_types: result["allowedInstanceTypes"] = self.allowed_instance_types # type: ignore[assignment] + result["allowedInstanceType"] = self.allowed_instance_types # type: ignore[assignment] + + if hasattr(self, "_allowed_environment_variable_overrides") and self._allowed_environment_variable_overrides: + result["allowedEnvironmentVariableOverrides"] = self._allowed_environment_variable_overrides return result @@ -637,6 +662,8 @@ def _to_dict(self) -> Dict: # Add instance configuration if hasattr(self, "allowed_instance_types") and self.allowed_instance_types: result["allowedInstanceTypes"] = self.allowed_instance_types # type: ignore[assignment] + if hasattr(self, "_allowed_environment_variable_overrides") and self._allowed_environment_variable_overrides: + result["allowedEnvironmentVariableOverrides"] = self._allowed_environment_variable_overrides if self.default_instance_type: result["defaultInstanceType"] = self.default_instance_type elif self.instance_type: