diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/deployment_template.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/deployment_template.py index de17a218e4d5..e323b1c77920 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/deployment_template.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/deployment_template.py @@ -104,6 +104,7 @@ def __init__( # pylint: disable=too-many-locals ) self.allowed_instance_types = allowed_instance_types self.default_instance_type = default_instance_type + self._allowed_environment_variable_overrides = None self.scoring_port = scoring_port self.scoring_path = scoring_path self.model_mount_path = model_mount_path @@ -368,6 +369,15 @@ def get_value(source, key, default=None): allowed_instance_types = get_value(properties, "allowedInstanceTypes") or get_value( obj, "allowed_instance_types" ) + # Also check additional_properties for service fields with mismatched names + if not allowed_instance_types: + additional_props = get_value(obj, "additional_properties", {}) + if isinstance(additional_props, dict): + allowed_instance_types = additional_props.get("allowedInstanceType") or additional_props.get("allowedInstanceTypes") + allowed_environment_variable_overrides = ( + get_value(properties, "allowedEnvironmentVariableOverrides") + or get_value(obj, "allowed_environment_variable_overrides") + ) scoring_port = get_value(properties, "scoringPort") or get_value(obj, "scoring_port") scoring_path = get_value(properties, "scoringPath") or get_value(obj, "scoring_path") model_mount_path = get_value(properties, "modelMountPath") or get_value(obj, "model_mount_path") @@ -399,6 +409,13 @@ def get_value(source, key, default=None): except (ValueError, SyntaxError): allowed_instance_types = None + # Parse allowed_environment_variable_overrides if it's a string + if isinstance(allowed_environment_variable_overrides, str): + try: + allowed_environment_variable_overrides = ast.literal_eval(allowed_environment_variable_overrides) + except (ValueError, SyntaxError): + allowed_environment_variable_overrides = None + # Convert request_settings to OnlineRequestSettings object using the built-in conversion method request_settings_obj = OnlineRequestSettings._from_rest_object(request_settings) if request_settings else None @@ -451,6 +468,9 @@ def get_value(source, key, default=None): # updates template._from_service = True + # Store allowed_environment_variable_overrides as private field for round-trip + template._allowed_environment_variable_overrides = allowed_environment_variable_overrides + # Store additional fields from the REST response that may be needed template.environment_id = environment_id # type: ignore[attr-defined] # Alternative name for deployment_template_type @@ -472,6 +492,7 @@ def get_value(source, key, default=None): "app_insights_enabled": get_value(obj, "app_insights_enabled"), "deployment_template_type": deployment_template_type, "allowed_instance_types": allowed_instance_types, + "allowed_environment_variable_overrides": allowed_environment_variable_overrides, "scoring_port": scoring_port, "scoring_path": scoring_path, "model_mount_path": model_mount_path, @@ -571,6 +592,10 @@ def _to_rest_object(self) -> dict: # Handle allowed instance types if hasattr(self, "allowed_instance_types") and self.allowed_instance_types: result["allowedInstanceTypes"] = self.allowed_instance_types # type: ignore[assignment] + result["allowedInstanceType"] = self.allowed_instance_types # type: ignore[assignment] + + if hasattr(self, "_allowed_environment_variable_overrides") and self._allowed_environment_variable_overrides: + result["allowedEnvironmentVariableOverrides"] = self._allowed_environment_variable_overrides return result @@ -637,6 +662,8 @@ def _to_dict(self) -> Dict: # Add instance configuration if hasattr(self, "allowed_instance_types") and self.allowed_instance_types: result["allowedInstanceTypes"] = self.allowed_instance_types # type: ignore[assignment] + if hasattr(self, "_allowed_environment_variable_overrides") and self._allowed_environment_variable_overrides: + result["allowedEnvironmentVariableOverrides"] = self._allowed_environment_variable_overrides if self.default_instance_type: result["defaultInstanceType"] = self.default_instance_type elif self.instance_type: diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_deployment_template_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_deployment_template_operations.py index 8ae77f2a99d2..1041cbce9421 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_deployment_template_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_deployment_template_operations.py @@ -45,15 +45,21 @@ def __init__( def _get_registry_endpoint(self) -> str: """Dynamically determine the registry endpoint based on registry region. + Uses the registry discovery API (which does not require ARM access to the + registry's subscription) to resolve the primary region, then constructs the + appropriate dataplane endpoint. + :return: The API endpoint URL for the registry :rtype: str """ try: - # Import here to avoid circular dependencies - from azure.ai.ml._restclient.v2022_10_01_preview import ( - AzureMachineLearningWorkspaces as ServiceClient102022, + from azure.ai.ml._azure_environments import ( + _get_default_cloud_name, + _get_registry_discovery_endpoint_from_metadata, + ) + from azure.ai.ml._restclient.registry_discovery import ( + RegistryDiscoveryClient as ServiceClientRegistryDiscovery, ) - from azure.ai.ml.operations import RegistryOperations # Try to get credential from service client or operation config credential = None @@ -63,31 +69,19 @@ def _get_registry_endpoint(self) -> str: credential = self._operation_config.credential if credential and self._operation_scope.registry_name: - # Get registry information to determine the region - registry_operations = RegistryOperations( - operation_scope=self._operation_scope, - service_client=ServiceClient102022( - credential=credential, - subscription_id=self._operation_scope.subscription_id, - resource_group_name=self._operation_scope.resource_group_name, - ), - all_operations=None, # type: ignore[arg-type] - credentials=credential, + # Use registry discovery API to get the primary region + discovery_base_url = _get_registry_discovery_endpoint_from_metadata(_get_default_cloud_name()) + discovery_client = ServiceClientRegistryDiscovery( + credential=credential, base_url=discovery_base_url + ) + response = ( + discovery_client.registry_management_non_workspace.get_registry_management_non_workspace( + self._operation_scope.registry_name + ) ) - registry = registry_operations.get(self._operation_scope.registry_name) - - # Extract region from registry location or replication locations - region = None - if registry.location: - region = registry.location - elif registry.replication_locations and len(registry.replication_locations) > 0: - region = registry.replication_locations[0].location - - if region: - # Format the endpoint using the detected region - # return f"https://int.experiments.azureml-test.net" - return f"https://{region}.api.azureml.ms" + if response.primary_region: + return f"https://{response.primary_region}.api.azureml.ms" except Exception as e: module_logger.debug("Could not determine registry region dynamically: %s. Using default.", e)