Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__( # pylint: disable=too-many-locals
)
self.allowed_instance_types = allowed_instance_types
self.default_instance_type = default_instance_type
self._allowed_environment_variable_overrides = None
self.scoring_port = scoring_port
self.scoring_path = scoring_path
self.model_mount_path = model_mount_path
Expand Down Expand Up @@ -368,6 +369,15 @@ def get_value(source, key, default=None):
allowed_instance_types = get_value(properties, "allowedInstanceTypes") or get_value(
obj, "allowed_instance_types"
)
# Also check additional_properties for service fields with mismatched names
if not allowed_instance_types:
additional_props = get_value(obj, "additional_properties", {})
if isinstance(additional_props, dict):
allowed_instance_types = additional_props.get("allowedInstanceType") or additional_props.get("allowedInstanceTypes")
Comment on lines 369 to +376
allowed_environment_variable_overrides = (
get_value(properties, "allowedEnvironmentVariableOverrides")
or get_value(obj, "allowed_environment_variable_overrides")
)
scoring_port = get_value(properties, "scoringPort") or get_value(obj, "scoring_port")
scoring_path = get_value(properties, "scoringPath") or get_value(obj, "scoring_path")
model_mount_path = get_value(properties, "modelMountPath") or get_value(obj, "model_mount_path")
Expand Down Expand Up @@ -399,6 +409,13 @@ def get_value(source, key, default=None):
except (ValueError, SyntaxError):
allowed_instance_types = None

# Parse allowed_environment_variable_overrides if it's a string
if isinstance(allowed_environment_variable_overrides, str):
try:
allowed_environment_variable_overrides = ast.literal_eval(allowed_environment_variable_overrides)
except (ValueError, SyntaxError):
allowed_environment_variable_overrides = None

# Convert request_settings to OnlineRequestSettings object using the built-in conversion method
Comment on lines 372 to 419
request_settings_obj = OnlineRequestSettings._from_rest_object(request_settings) if request_settings else None

Expand Down Expand Up @@ -451,6 +468,9 @@ def get_value(source, key, default=None):
# updates
template._from_service = True

# Store allowed_environment_variable_overrides as private field for round-trip
template._allowed_environment_variable_overrides = allowed_environment_variable_overrides

# Store additional fields from the REST response that may be needed
template.environment_id = environment_id # type: ignore[attr-defined]
# Alternative name for deployment_template_type
Expand All @@ -472,6 +492,7 @@ def get_value(source, key, default=None):
"app_insights_enabled": get_value(obj, "app_insights_enabled"),
"deployment_template_type": deployment_template_type,
"allowed_instance_types": allowed_instance_types,
"allowed_environment_variable_overrides": allowed_environment_variable_overrides,
"scoring_port": scoring_port,
"scoring_path": scoring_path,
"model_mount_path": model_mount_path,
Expand Down Expand Up @@ -571,6 +592,10 @@ def _to_rest_object(self) -> dict:
# Handle allowed instance types
if hasattr(self, "allowed_instance_types") and self.allowed_instance_types:
result["allowedInstanceTypes"] = self.allowed_instance_types # type: ignore[assignment]
result["allowedInstanceType"] = self.allowed_instance_types # type: ignore[assignment]

if hasattr(self, "_allowed_environment_variable_overrides") and self._allowed_environment_variable_overrides:
result["allowedEnvironmentVariableOverrides"] = self._allowed_environment_variable_overrides

return result

Expand Down Expand Up @@ -637,6 +662,8 @@ def _to_dict(self) -> Dict:
# Add instance configuration
if hasattr(self, "allowed_instance_types") and self.allowed_instance_types:
result["allowedInstanceTypes"] = self.allowed_instance_types # type: ignore[assignment]
if hasattr(self, "_allowed_environment_variable_overrides") and self._allowed_environment_variable_overrides:
result["allowedEnvironmentVariableOverrides"] = self._allowed_environment_variable_overrides
if self.default_instance_type:
result["defaultInstanceType"] = self.default_instance_type
elif self.instance_type:
Expand Down
Loading