From 346fcd3196cbb18f6c4674d2275127cc40ccdbea Mon Sep 17 00:00:00 2001 From: Pravali Uppugunduri Date: Thu, 19 Mar 2026 20:54:25 +0000 Subject: [PATCH 1/3] fix: Remove hardcoded secret key from Triton ONNX export path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The ONNX export path in _prepare_for_triton() set self.secret_key to a hardcoded value 'dummy secret key for onnx backend'. This key was then passed as SAGEMAKER_SERVE_SECRET_KEY into container environment variables and exposed in plaintext via DescribeModel/DescribeEndpointConfig APIs. The ONNX path does not use pickle serialization — models are exported to .onnx format and loaded natively by Triton's ONNX Runtime backend. There is no serve.pkl, no metadata.json, and no integrity check to perform. The secret key was dead code that also constituted a hardcoded credential (CWE-798). With this change, self.secret_key remains an empty string (set by _build_for_triton), and the existing cleanup in _build_for_transformers removes empty SAGEMAKER_SERVE_SECRET_KEY from env_vars before CreateModel. Addresses: P400136088 (Bug 2 - Hardcoded secret key) --- sagemaker-serve/src/sagemaker/serve/model_builder_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 8c1fd6db1b..87048682eb 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -3075,7 +3075,8 @@ def _prepare_for_triton(self): export_path.mkdir(parents=True) if self.model: - self.secret_key = "dummy secret key for onnx backend" + # ONNX path: no pickle serialization, no serve.pkl, no integrity check needed. + # Do not set secret_key — there is nothing to sign. if self.framework == Framework.PYTORCH: self._export_pytorch_to_onnx( From fc535b47773f04886b52bf98552c50ac45d446f2 Mon Sep 17 00:00:00 2001 From: Pravali Uppugunduri Date: Thu, 19 Mar 2026 21:24:13 +0000 Subject: [PATCH 2/3] fix: Add SHA-256 integrity verification for Triton inference handler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses P400136088 Bug 1 and V2146375387 (Triton path). Four changes: 1. check_integrity.py: Switch from HMAC-SHA256 to plain SHA-256. - Remove generate_secret_key() — no longer needed - compute_hash() now uses hashlib.sha256() instead of hmac.new() - perform_integrity_check() no longer reads SAGEMAKER_SERVE_SECRET_KEY from environment 2. triton/model.py: Add integrity check in initialize() BEFORE cloudpickle deserialization. Previously the handler called cloudpickle.load() with no verification (acknowledged by a TODO comment). Now reads the file into a buffer, runs perform_integrity_check(), then deserializes with cloudpickle.loads(). 3. triton/server.py: Remove SAGEMAKER_SERVE_SECRET_KEY from container environment variables in both local and SageMaker deployment modes. The key is no longer needed since integrity checking uses plain SHA-256. 4. model_builder_utils.py: Update _hmac_signing() to use plain SHA-256 and stop generating/storing a secret key. Remove generate_secret_key import. The integrity check still detects accidental corruption of model artifacts in S3. The HMAC was providing a false sense of security since the key was exposed via DescribeModel/DescribeEndpointConfig APIs.
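For reviewers, the end-to-end flow after this change, as a minimal standalone sketch. Helper names and the JSON layout are illustrative only; the real code stores the digest via _MetaData.to_json() and splits the two halves between _compute_integrity_hash() and the Triton handler's initialize():

    import hashlib
    import json
    from pathlib import Path

    def compute_hash(buffer: bytes) -> str:
        # Plain SHA-256 digest, matching the patched compute_hash().
        return hashlib.sha256(buffer).hexdigest()

    def write_integrity_metadata(pkl_path: Path, metadata_path: Path) -> None:
        # Packaging time: record the digest of serve.pkl next to it.
        digest = compute_hash(pkl_path.read_bytes())
        metadata_path.write_text(json.dumps({"sha256_hash": digest}))  # illustrative layout

    def load_verified(pkl_path: Path, metadata_path: Path) -> bytes:
        # Load time: compare digests BEFORE the bytes ever reach cloudpickle.loads().
        buffer = pkl_path.read_bytes()
        expected = json.loads(metadata_path.read_text())["sha256_hash"]
        if compute_hash(buffer) != expected:
            raise ValueError("Integrity check for serve.pkl failed")
        return buffer

Because the digest is unkeyed, anyone who can overwrite serve.pkl can also recompute metadata.json; this guards against accidental corruption, not tampering. Given that the key was exposed through the Describe* APIs, the old HMAC offered no stronger guarantee.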
--- .../sagemaker/serve/model_builder_utils.py | 12 ++++------- .../serve/model_server/triton/model.py | 11 ++++++++--- .../serve/model_server/triton/server.py | 2 -- .../serve/validations/check_integrity.py | 21 ++++++---------------- .../unit/test_model_builder_utils_triton.py | 14 +++++++------- 5 files changed, 25 insertions(+), 35 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 87048682eb..c4495a8ffb 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -131,7 +131,6 @@ def build(self): from sagemaker.serve.detector.pickler import save_pkl from sagemaker.serve.builder.requirements_manager import RequirementsManager from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, compute_hash, ) from sagemaker.core.remote_function.core.serialization import _MetaData @@ -2884,20 +2883,17 @@ def _save_inference_spec(self) -> None: pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model") save_pkl(pkl_path, (self.inference_spec, self.schema_builder)) - def _hmac_signing(self): - """Perform HMAC signing on picke file for integrity check""" - secret_key = generate_secret_key() + def _compute_integrity_hash(self): + """Compute SHA-256 hash of serve.pkl and store in metadata.json for integrity check.""" pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model") with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f: buffer = f.read() - hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + hash_value = compute_hash(buffer=buffer) with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata: metadata.write(_MetaData(hash_value).to_json()) - self.secret_key = secret_key - def _generate_config_pbtxt(self, pkl_path: Path): """Generate Triton config.pbtxt file.""" config_path = pkl_path.joinpath("config.pbtxt") @@ -3100,7 +3096,7 @@ def _prepare_for_triton(self): self._pack_conda_env(pkl_path=pkl_path) - self._hmac_signing() + self._compute_integrity_hash() return diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py b/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py index a1c731b0d6..7d49b0723d 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py @@ -26,10 +26,15 @@ def auto_complete_config(auto_complete_model_config): def initialize(self, args: dict) -> None: """Placeholder docstring""" + from sagemaker.serve.validations.check_integrity import perform_integrity_check serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl") - with open(str(serve_path), mode="rb") as f: - inference_spec, schema_builder = cloudpickle.load(f) + metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json") - # TODO: HMAC signing for integrity check + # Integrity check BEFORE deserialization; the unkeyed SHA-256 digest detects accidental corruption of serve.pkl + with open(str(serve_path), "rb") as f: + buffer = f.read() + perform_integrity_check(buffer=buffer, metadata_path=metadata_path) + + inference_spec, schema_builder = cloudpickle.loads(buffer) self.inference_spec = inference_spec self.schema_builder = schema_builder diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py index 134f12dd42..b425f8a689 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py +++
b/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py @@ -41,7 +41,6 @@ def _start_triton_server( env_vars.update( { "TRITON_MODEL_DIR": "/models/model", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } ) @@ -133,7 +132,6 @@ def _upload_triton_artifacts( env_vars = { "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model", "TRITON_MODEL_DIR": "/opt/ml/model/model", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } return s3_upload_path, env_vars diff --git a/sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py b/sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py index 4363d8d6ed..880ca5b602 100644 --- a/sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py +++ b/sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py @@ -1,29 +1,20 @@ -"""Validates the integrity of pickled file with HMAC signing.""" +"""Validates the integrity of pickled file with SHA-256 hash.""" from __future__ import absolute_import -import secrets -import hmac import hashlib -import os from pathlib import Path from sagemaker.core.remote_function.core.serialization import _MetaData -def generate_secret_key(nbytes: int = 32) -> str: - """Generates secret key""" - return secrets.token_hex(nbytes) - - -def compute_hash(buffer: bytes, secret_key: str) -> str: - """Compute hash value using HMAC""" - return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest() +def compute_hash(buffer: bytes) -> str: + """Compute SHA-256 hash of the given buffer.""" + return hashlib.sha256(buffer).hexdigest() def perform_integrity_check(buffer: bytes, metadata_path: Path): - """Validates the integrity of bytes by comparing the hash value""" - secret_key = os.environ.get("SAGEMAKER_SERVE_SECRET_KEY") - actual_hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + """Validates the integrity of bytes by comparing the hash value.""" + actual_hash_value = compute_hash(buffer=buffer) if not Path.exists(metadata_path): raise ValueError("Path to metadata.json does not exist") diff --git a/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py b/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py index bb0d1d874c..3ac82016b6 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py +++ b/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py @@ -113,7 +113,7 @@ def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy): @patch('shutil.copy2') @patch.object(_ModelBuilderUtils, '_generate_config_pbtxt') @patch.object(_ModelBuilderUtils, '_pack_conda_env') - @patch.object(_ModelBuilderUtils, '_hmac_signing') + @patch.object(_ModelBuilderUtils, '_compute_integrity_hash') def test_prepare_for_triton_inference_spec(self, mock_hmac, mock_pack, mock_config, mock_copy): """Test preparing inference spec for Triton.""" utils = _ModelBuilderUtils() @@ -262,9 +262,9 @@ def test_save_inference_spec(self): class TestHMACSignin(unittest.TestCase): - """Test _hmac_signing method.""" + """Test _compute_integrity_hash method.""" - def test_hmac_signing(self): - """Test HMAC signing.""" + def test_compute_integrity_hash(self): + """Test SHA-256 hash computation.""" utils = _ModelBuilderUtils() @@ -276,7 +276,7 @@ def test_hmac_signing(self): # Create dummy serve.pkl (pkl_path / "serve.pkl").write_bytes(b"dummy content") - utils._hmac_signing() - # Secret key is generated, not mocked - self.assertIsNotNone(utils.secret_key) + utils._compute_integrity_hash() + # Digest is written to metadata.json next to serve.pkl; no secret key is involved + self.assertTrue((pkl_path / "metadata.json").exists())
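A quick self-contained check of the property the updated unit test relies on. compute_hash is inlined here rather than imported so the snippet runs outside the repo; the b"dummy content" payload mirrors the test above:

    import hashlib

    def compute_hash(buffer: bytes) -> str:
        # Mirrors the patched check_integrity.compute_hash()
        return hashlib.sha256(buffer).hexdigest()

    payload = b"dummy content"
    digest = compute_hash(payload)
    # Deterministic: hashing the same bytes twice yields the same digest.
    assert compute_hash(payload) == digest
    # Sensitive: corrupting even a single byte changes the digest.
    assert compute_hash(b"dummy_Content") != digest
    print("sha256 integrity property holds")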
From c3744f28e522b79a81dd7f83294eaf96c199b752 Mon Sep 17 00:00:00 2001 From: Pravali Uppugunduri Date: Thu, 19 Mar 2026 22:18:19 +0000 Subject: [PATCH 3/3] fix: Update all model server prepare.py to use plain SHA-256 Remove generate_secret_key import and usage from TorchServe, MMS, TF Serving, and SMD prepare functions. Switch compute_hash calls from HMAC-SHA256 to plain SHA-256 (no secret_key parameter). This is required because generate_secret_key was removed from check_integrity.py in the previous commit. Without this change, all model server imports fail with ImportError. Most of the diffstat churn below is mechanical reformatting (quote style, line wrapping, trailing commas) of model_builder_utils.py and the touched test files; the functional change is limited to the generate_secret_key/compute_hash call sites. --- .../sagemaker/serve/model_builder_utils.py | 902 ++++++++--------- .../multi_model_server/prepare.py | 6 +- .../model_server/multi_model_server/server.py | 4 +- .../serve/model_server/smd/prepare.py | 6 +- .../serve/model_server/smd/server.py | 1 - .../serve/model_server/tei/server.py | 2 - .../tensorflow_serving/prepare.py | 8 +- .../model_server/tensorflow_serving/server.py | 4 +- .../serve/model_server/torchserve/prepare.py | 10 +- .../serve/model_server/torchserve/server.py | 4 +- .../tests/unit/model_server/test_djl_utils.py | 2 +- .../test_in_process_model_server_app.py | 133 +-- .../test_multi_model_server_inference.py | 99 +- .../test_multi_model_server_prepare.py | 112 +-- .../test_multi_model_server_server.py | 130 +-- .../unit/model_server/test_smd_prepare.py | 92 +- .../unit/model_server/test_smd_server.py | 50 +- .../unit/model_server/test_tei_server.py | 134 +-- .../test_tensorflow_serving_inference.py | 56 +- .../test_tensorflow_serving_prepare.py | 130 ++- .../test_tensorflow_serving_server.py | 102 +- .../unit/model_server/test_tgi_prepare.py | 132 ++- .../unit/model_server/test_tgi_server.py | 118 +-- .../tests/unit/model_server/test_tgi_utils.py | 71 +- .../model_server/test_torchserve_inference.py | 72 +- .../model_server/test_torchserve_prepare.py | 132 ++- .../model_server/test_torchserve_server.py | 90 +- .../test_torchserve_xgboost_inference.py | 68 +- .../servers/test_model_builder_servers.py | 927 ++++++++++-------- .../unit/test_model_builder_utils_triton.py | 125 +-- .../unit/validations/test_check_integrity.py | 36 +- 31 files changed, 1943 insertions(+), 1815 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index c4495a8ffb..f9efe42a18 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -70,7 +70,7 @@ def build(self): _cast_to_compatible_version, _detect_framework_and_version, auto_detect_container, - _get_model_base + _get_model_base, ) from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils import task @@ -223,19 +223,20 @@ def serialize(self, data): return super().serialize(payload) +
It expects certain attributes to be available on the instance: - sagemaker_session: SageMaker session object @@ -243,14 +244,14 @@ class _ModelBuilderUtils: - instance_type: EC2 instance type - region: AWS region - env_vars: Environment variables dict - + Example: class MyModelBuilder(ModelBuilderUtils): def __init__(self): self.model = "huggingface-model-id" self.instance_type = "ml.g5.xlarge" self.sagemaker_session = None - + def build(self): self._init_sagemaker_session_if_does_not_exist() self._auto_detect_image_uri() @@ -261,7 +262,9 @@ def build(self): # Session Management # ======================================== - def _init_sagemaker_session_if_does_not_exist(self, instance_type: Optional[str] = None) -> None: + def _init_sagemaker_session_if_does_not_exist( + self, instance_type: Optional[str] = None + ) -> None: """Initialize SageMaker session if it doesn't exist. Sets self.sagemaker_session to LocalSession for local instances, @@ -274,24 +277,25 @@ def _init_sagemaker_session_if_does_not_exist(self, instance_type: Optional[str] if self.sagemaker_session: return - effective_instance_type = instance_type or getattr(self, 'instance_type', None) - + effective_instance_type = instance_type or getattr(self, "instance_type", None) + if effective_instance_type in ("local", "local_gpu"): self.sagemaker_session = LocalSession( - sagemaker_config=getattr(self, '_sagemaker_config', None) + sagemaker_config=getattr(self, "_sagemaker_config", None) ) else: # Create session with correct region - if hasattr(self, 'region') and self.region: + if hasattr(self, "region") and self.region: import boto3 + boto_session = boto3.Session(region_name=self.region) self.sagemaker_session = Session( boto_session=boto_session, - sagemaker_config=getattr(self, '_sagemaker_config', None) + sagemaker_config=getattr(self, "_sagemaker_config", None), ) else: self.sagemaker_session = Session( - sagemaker_config=getattr(self, '_sagemaker_config', None) + sagemaker_config=getattr(self, "_sagemaker_config", None) ) # ======================================== @@ -300,98 +304,100 @@ def _init_sagemaker_session_if_does_not_exist(self, instance_type: Optional[str] def _get_jumpstart_recommended_instance_type(self) -> Optional[str]: """Get recommended instance type from JumpStart metadata. - + Returns: Recommended instance type string, or None if not available. """ try: deploy_kwargs = get_deploy_kwargs( model_id=self.model, - model_version=getattr(self, 'model_version', None) or "*", + model_version=getattr(self, "model_version", None) or "*", region=self.region, - tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), - tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None) + tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), + tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), ) - + # JumpStart provides recommended instance type - if hasattr(deploy_kwargs, 'instance_type') and deploy_kwargs.instance_type: + if hasattr(deploy_kwargs, "instance_type") and deploy_kwargs.instance_type: return deploy_kwargs.instance_type - + except Exception: pass - + return None def _get_default_instance_type(self) -> str: """Get optimal default instance type based on model characteristics. 
- + Analyzes the model to determine appropriate instance type: - JumpStart models: Use recommended instance type from metadata - HuggingFace models: Analyze model size and tags for GPU requirements - Fallback: ml.m5.large for CPU workloads - + Returns: Instance type string (e.g., 'ml.g5.xlarge', 'ml.m5.large'). """ logger.debug("Auto-detecting optimal instance type for model...") - + if isinstance(self.model, str) and self._is_jumpstart_model_id(): recommended_type = self._get_jumpstart_recommended_instance_type() if recommended_type: logger.debug(f"Using JumpStart recommended instance type: {recommended_type}") return recommended_type - + # For HuggingFace models, use metadata to detect requirements elif isinstance(self.model, str): try: - env_vars = getattr(self, 'env_vars', {}) or {} + env_vars = getattr(self, "env_vars", {}) or {} hf_model_md = self.get_huggingface_model_metadata( - self.model, - env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + # Check model size from metadata model_size = hf_model_md.get("safetensors", {}).get("total", 0) model_tags = hf_model_md.get("tags", []) - + # Large models or specific tags indicate GPU need - if (model_size > 2_000_000_000 or # > 2GB - any(tag in model_tags for tag in ["7b", "13b", "70b"]) or - "7b" in self.model.lower() or "13b" in self.model.lower()): + if ( + model_size > 2_000_000_000 # > 2GB + or any(tag in model_tags for tag in ["7b", "13b", "70b"]) + or "7b" in self.model.lower() + or "13b" in self.model.lower() + ): logger.debug("Detected large model, using GPU instance type: ml.g5.xlarge") return "ml.g5.xlarge" - + except Exception as e: logger.debug(f"Could not get HF metadata for smart detection: {e}") - + # Default fallback logger.debug("Using default CPU instance type: ml.m5.large") return "ml.m5.large" - + # ======================================== # Image Detection and Container Utils # ======================================== def _auto_detect_container_default(self) -> str: """Auto-detect container image for framework-based models. - + Detects the appropriate Deep Learning Container (DLC) based on: - Model framework (PyTorch, TensorFlow) - Framework version from HuggingFace metadata - Python version compatibility - Instance type requirements - + Returns: Container image URI string. - + Raises: ValueError: If instance type not specified or no compatible image found. """ from sagemaker.core import image_uris - + logger.debug("Auto-detecting image since image_uri was not provided in ModelBuilder()") - if not getattr(self, 'instance_type', None): + if not getattr(self, "instance_type", None): raise ValueError( "Instance type is not specified. " "Unable to detect if the container needs to be GPU or CPU." @@ -402,13 +408,12 @@ def _auto_detect_container_default(self) -> str: ) py_tuple = platform.python_version_tuple() - env_vars = getattr(self, 'env_vars', {}) or {} - + env_vars = getattr(self, "env_vars", {}) or {} + torch_v, tf_v, base_hf_v, _ = self._get_hf_framework_versions( - self.model, - env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + if torch_v: fw, fw_version = "pytorch", torch_v elif tf_v: @@ -445,21 +450,20 @@ def _auto_detect_container_default(self) -> str: f"framework version {fw_version} and python version py{py_tuple[0]}{py_tuple[1]}. 
" f"Please manually provide image_uri to ModelBuilder()" ) - def _get_smd_image_uri(self, processing_unit: Optional[str] = None) -> str: """Get SageMaker Distribution (SMD) inference image URI. - + Retrieves the appropriate SMD container image for custom orchestrator deployment. Requires Python >= 3.12 for SMD inference. - + Args: processing_unit: Target processing unit ('cpu' or 'gpu'). If None, defaults to 'cpu'. - + Returns: SMD inference image URI string. - + Raises: ValueError: If Python version < 3.12 or invalid processing unit. """ @@ -467,8 +471,9 @@ def _get_smd_image_uri(self, processing_unit: Optional[str] = None) -> str: from sagemaker.core import image_uris if not self.sagemaker_session: - if hasattr(self, 'region') and self.region: + if hasattr(self, "region") and self.region: import boto3 + boto_session = boto3.Session(region_name=self.region) self.sagemaker_session = Session(boto_session=boto_session) else: @@ -483,14 +488,16 @@ def _get_smd_image_uri(self, processing_unit: Optional[str] = None) -> str: INSTANCE_TYPES = {"cpu": "ml.c5.xlarge", "gpu": "ml.g5.4xlarge"} effective_processing_unit = processing_unit or "cpu" - + if effective_processing_unit not in INSTANCE_TYPES: raise ValueError( f"Invalid processing unit '{effective_processing_unit}'. " f"Must be one of: {list(INSTANCE_TYPES.keys())}" ) - logger.debug("Finding SMD inference image URI for a %s instance.", effective_processing_unit) + logger.debug( + "Finding SMD inference image URI for a %s instance.", effective_processing_unit + ) smd_uri = image_uris.retrieve( framework="sagemaker-distribution", @@ -501,163 +508,172 @@ def _get_smd_image_uri(self, processing_unit: Optional[str] = None) -> str: logger.debug("Found compatible image: %s", smd_uri) return smd_uri - def _is_huggingface_model(self) -> bool: """Check if model is a HuggingFace model ID. - + Determines if the model string represents a HuggingFace model by: - Checking for organization/model-name format - Checking explicit model_type designation - Fallback: assume HuggingFace if not JumpStart - + Returns: True if model appears to be a HuggingFace model ID. """ if not isinstance(self.model, str): return False - + # Simple pattern matching for HuggingFace model IDs # Format: "organization/model-name" or just "model-name" - model_type = getattr(self, 'model_type', None) + model_type = getattr(self, "model_type", None) if "/" in self.model or model_type == "huggingface": return True - + # Additional check: if it's not a JumpStart model, assume HuggingFace return not self._is_jumpstart_model_id() - - def _get_supported_version(self, hf_config: Dict[str, Any], hugging_face_version: str, base_fw: str) -> str: + def _get_supported_version( + self, hf_config: Dict[str, Any], hugging_face_version: str, base_fw: str + ) -> str: """Extract supported framework version from HuggingFace config. - + Uses the HuggingFace JSON config to pick the best supported version for the given framework. - + Args: hf_config: HuggingFace configuration dictionary hugging_face_version: HuggingFace transformers version base_fw: Base framework name (e.g., 'pytorch', 'tensorflow') - + Returns: Best supported framework version string. 
""" version_config = hf_config.get("versions", {}).get(hugging_face_version, {}) versions_to_return = [] - + for key in version_config.keys(): if key.startswith(base_fw): - base_fw_version = key[len(base_fw):] + base_fw_version = key[len(base_fw) :] if len(hugging_face_version.split(".")) == 2: base_fw_version = ".".join(base_fw_version.split(".")[:-1]) versions_to_return.append(base_fw_version) - + if not versions_to_return: raise ValueError(f"No supported versions found for framework {base_fw}") - + return sorted(versions_to_return, reverse=True)[0] - def _get_hf_framework_versions(self, model_id: str, hf_token: Optional[str] = None) -> Tuple[Optional[str], Optional[str], str, str]: + def _get_hf_framework_versions( + self, model_id: str, hf_token: Optional[str] = None + ) -> Tuple[Optional[str], Optional[str], str, str]: """Get HuggingFace framework versions for image_uris.retrieve(). - + Analyzes HuggingFace model metadata to determine the appropriate framework versions for container image selection. - + Args: model_id: HuggingFace model identifier hf_token: Optional HuggingFace API token for private models - + Returns: Tuple of (pytorch_version, tensorflow_version, transformers_version, py_version). One of pytorch_version or tensorflow_version will be None. - + Raises: ValueError: If no supported framework versions found. """ from sagemaker.core import image_uris - + # Get model metadata for framework detection hf_model_md = self.get_huggingface_model_metadata(model_id, hf_token) - + # Get HuggingFace framework configuration hf_config = image_uris.config_for_framework("huggingface").get("inference") config = hf_config["versions"] base_hf_version = sorted(config.keys(), key=lambda v: Version(v), reverse=True)[0] - + model_tags = hf_model_md.get("tags", []) - + # Detect framework from model tags if "pytorch" in model_tags: pytorch_version = self._get_supported_version(hf_config, base_hf_version, "pytorch") - py_version = config[base_hf_version][f"pytorch{pytorch_version}"].get("py_versions", [])[-1] + py_version = config[base_hf_version][f"pytorch{pytorch_version}"].get( + "py_versions", [] + )[-1] return pytorch_version, None, base_hf_version, py_version - + elif "keras" in model_tags or "tensorflow" in model_tags: - tensorflow_version = self._get_supported_version(hf_config, base_hf_version, "tensorflow") - py_version = config[base_hf_version][f"tensorflow{tensorflow_version}"].get("py_versions", [])[-1] + tensorflow_version = self._get_supported_version( + hf_config, base_hf_version, "tensorflow" + ) + py_version = config[base_hf_version][f"tensorflow{tensorflow_version}"].get( + "py_versions", [] + )[-1] return None, tensorflow_version, base_hf_version, py_version - + else: # Default to PyTorch if no framework detected (matches V2 behavior) pytorch_version = self._get_supported_version(hf_config, base_hf_version, "pytorch") - py_version = config[base_hf_version][f"pytorch{pytorch_version}"].get("py_versions", [])[-1] + py_version = config[base_hf_version][f"pytorch{pytorch_version}"].get( + "py_versions", [] + )[-1] return pytorch_version, None, base_hf_version, py_version - def _detect_jumpstart_image(self) -> None: """Detect and set image URI for JumpStart models. - + Uses JumpStart metadata to determine the appropriate container image and framework information for the model. - + Raises: ValueError: If image URI cannot be determined or JumpStart lookup fails. 
""" try: init_kwargs = get_init_kwargs( model_id=self.model, - model_version=getattr(self, 'model_version', None) or "*", + model_version=getattr(self, "model_version", None) or "*", region=self.region, - instance_type=getattr(self, 'instance_type', None), - tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), - tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None) + instance_type=getattr(self, "instance_type", None), + tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), + tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), ) - + self.image_uri = init_kwargs.get("image_uri") if not self.image_uri: raise ValueError(f"Could not determine image URI for JumpStart model: {self.model}") - + logger.debug("Auto-detected JumpStart image: %s", self.image_uri) self.framework, self.framework_version = self._extract_framework_from_image_uri() - + except Exception as e: - raise ValueError(f"Failed to auto-detect image for JumpStart model {self.model}: {e}") from e + raise ValueError( + f"Failed to auto-detect image for JumpStart model {self.model}: {e}" + ) from e - def _detect_huggingface_image(self) -> None: """Detect and set image URI for HuggingFace models based on model server. - + Automatically selects the appropriate container image based on: - Explicit model_server setting - Model task type from HuggingFace metadata - Framework requirements and versions - + Raises: ValueError: If image detection fails or unsupported model server. """ from sagemaker.core import image_uris - + try: - env_vars = getattr(self, 'env_vars', {}) or {} - + env_vars = getattr(self, "env_vars", {}) or {} + # Determine which model server we're using - model_server = getattr(self, 'model_server', None) + model_server = getattr(self, "model_server", None) if not model_server: # Auto-select model server based on HF model task hf_model_md = self.get_huggingface_model_metadata( - self.model, - env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN") ) model_task = hf_model_md.get("pipeline_tag") - + if model_task == "text-generation": effective_model_server = ModelServer.TGI elif model_task in ["sentence-similarity", "feature-extraction"]: @@ -666,7 +682,7 @@ def _detect_huggingface_image(self) -> None: effective_model_server = ModelServer.MMS # Transformers else: effective_model_server = model_server - + # Choose image based on effective model server if effective_model_server == ModelServer.TGI: # TGI: Use image_uris.retrieve with "huggingface-llm" framework @@ -683,11 +699,11 @@ def _detect_huggingface_image(self) -> None: self.image_uri = image_uris.retrieve( framework="huggingface-tei", image_scope="inference", - instance_type=getattr(self, 'instance_type', None), + instance_type=getattr(self, "instance_type", None), region=self.region, ) self.framework = Framework.HUGGINGFACE - + elif effective_model_server == ModelServer.DJL_SERVING: # DJL: Use image_uris.retrieve with "djl-lmi" framework (matches DJLModel default) self.image_uri = image_uris.retrieve( @@ -695,109 +711,108 @@ def _detect_huggingface_image(self) -> None: region=self.region, version="latest", image_scope="inference", - instance_type=getattr(self, 'instance_type', None) + instance_type=getattr(self, "instance_type", None), ) self.framework = Framework.DJL - + elif effective_model_server == ModelServer.MMS: # Transformers # Transformers: Use HuggingFace framework with detected versions - pytorch_version, tensorflow_version, 
transformers_version, py_version = \ + pytorch_version, tensorflow_version, transformers_version, py_version = ( self._get_hf_framework_versions( - self.model, - env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + ) + base_framework_version = ( - f"pytorch{pytorch_version}" if pytorch_version + f"pytorch{pytorch_version}" + if pytorch_version else f"tensorflow{tensorflow_version}" ) - + self.image_uri = image_uris.retrieve( framework="huggingface", region=self.region, version=transformers_version, py_version=py_version, - instance_type=getattr(self, 'instance_type', None), + instance_type=getattr(self, "instance_type", None), image_scope="inference", base_framework_version=base_framework_version, ) self.framework = Framework.HUGGINGFACE - + elif effective_model_server == ModelServer.TORCHSERVE: # TorchServe: Use HuggingFace framework with detected versions - pytorch_version, tensorflow_version, transformers_version, py_version = \ + pytorch_version, tensorflow_version, transformers_version, py_version = ( self._get_hf_framework_versions( - self.model, - env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + ) + base_framework_version = ( - f"pytorch{pytorch_version}" if pytorch_version + f"pytorch{pytorch_version}" + if pytorch_version else f"tensorflow{tensorflow_version}" ) - + self.image_uri = image_uris.retrieve( framework="huggingface", region=self.region, version=transformers_version, py_version=py_version, - instance_type=getattr(self, 'instance_type', None), + instance_type=getattr(self, "instance_type", None), image_scope="inference", base_framework_version=base_framework_version, ) self.framework = Framework.HUGGINGFACE - + elif effective_model_server == ModelServer.TRITON: # Triton: Uses custom image construction (not image_uris.retrieve) raise ValueError( "Triton image detection for HuggingFace models requires custom implementation" ) - + elif effective_model_server == ModelServer.TENSORFLOW_SERVING: # TensorFlow Serving: V2 required explicit image_uri (no auto-detection) raise ValueError("TensorFlow Serving requires explicit image_uri specification") - + elif effective_model_server == ModelServer.SMD: # SMD: Uses _get_smd_image_uri helper cpu_or_gpu = self._get_processing_unit() self.image_uri = self._get_smd_image_uri(processing_unit=cpu_or_gpu) self.framework = Framework.SMD - + else: raise ValueError( f"Unsupported model server for HuggingFace models: {effective_model_server}" ) - + logger.debug("Auto-detected HuggingFace image: %s", self.image_uri) - + except Exception as e: raise ValueError( f"Failed to auto-detect image for HuggingFace model {self.model}: {e}" ) from e - def _detect_model_object_image(self) -> None: """Detect image for legacy object-based models. - + Handles model objects (not string IDs) by using the auto_detect_container function to determine appropriate container image. - + Raises: ValueError: If neither model nor inference_spec available for detection. 
""" - model = getattr(self, 'model', None) - inference_spec = getattr(self, 'inference_spec', None) - model_path = getattr(self, 'model_path', None) - + model = getattr(self, "model", None) + inference_spec = getattr(self, "inference_spec", None) + model_path = getattr(self, "model_path", None) + if model: logger.debug( "Auto-detecting container URL for the provided model on instance %s", - getattr(self, 'instance_type', None), + getattr(self, "instance_type", None), ) self.image_uri, fw, self.framework_version = auto_detect_container( - model, - self.region, - getattr(self, 'instance_type', None) + model, self.region, getattr(self, "instance_type", None) ) self.framework = self._normalize_framework_to_enum(fw) @@ -809,27 +824,26 @@ def _detect_model_object_image(self) -> None: self.image_uri, fw, self.framework_version = auto_detect_container( inference_spec.load(model_path), self.region, - getattr(self, 'instance_type', None), + getattr(self, "instance_type", None), ) self.framework = self._normalize_framework_to_enum(fw) else: raise ValueError("Cannot detect required model or inference spec") - def _auto_detect_image_uri(self) -> None: """Auto-detect container image URI based on model type. - + Determines the appropriate container image by: 1. Using provided image_uri if available 2. For string models: JumpStart vs HuggingFace detection 3. For object models: Legacy auto-detection - + Sets self.image_uri, self.framework, and self.framework_version. - + Raises: ValueError: If image cannot be auto-detected for the model type. """ - image_uri = getattr(self, 'image_uri', None) + image_uri = getattr(self, "image_uri", None) if image_uri: self.framework, self.framework_version = self._extract_framework_from_image_uri() logger.debug("Skipping auto-detection as image_uri is provided: %s", image_uri) @@ -839,13 +853,13 @@ def _auto_detect_image_uri(self) -> None: self._detect_inference_image_from_training() return - model = getattr(self, 'model', None) - inference_spec = getattr(self, 'inference_spec', None) + model = getattr(self, "model", None) + inference_spec = getattr(self, "inference_spec", None) if isinstance(model, str): # V3: String-based model detection - model_type = getattr(self, 'model_type', None) - + model_type = getattr(self, "model_type", None) + # First priority: Use model_type if it indicates JumpStart if model_type in ["open_weights", "proprietary"]: self._detect_jumpstart_image() @@ -857,38 +871,40 @@ def _auto_detect_image_uri(self) -> None: self._detect_huggingface_image() else: raise ValueError(f"Cannot auto-detect image for model: {model}") - elif inference_spec and hasattr(inference_spec, 'get_model'): + elif inference_spec and hasattr(inference_spec, "get_model"): try: spec_model = inference_spec.get_model() if spec_model is None: logger.warning( - "InferenceSpec.get_model() returned None. If you are using a JumpStar or HuggingFace model, you may need to implement get_model() in your InferenceSpec class") - + "InferenceSpec.get_model() returned None. 
If you are using a JumpStart or HuggingFace model, you may need to implement get_model() in your InferenceSpec class" + ) + if isinstance(spec_model, str): # Temporarily set model for detection, then restore original_model = self.model self.model = spec_model - + # Use existing detection logic if self._is_jumpstart_model_id(): self._detect_jumpstart_image() elif self._is_huggingface_model(): self._detect_huggingface_image() else: - raise ValueError(f"Cannot auto-detect image for inference_spec model: {spec_model}") - + raise ValueError( + f"Cannot auto-detect image for inference_spec model: {spec_model}" + ) + # Restore original model self.model = original_model return except Exception as e: pass - + # Fall back to existing object detection self._detect_model_object_image() else: # V2: Object-based model detection self._detect_model_object_image() - # ======================================== # HuggingFace Jumpstart Utils @@ -896,32 +912,32 @@ def _auto_detect_image_uri(self) -> None: def _use_jumpstart_equivalent(self) -> bool: """Check if HuggingFace model has JumpStart equivalent and use it. - + Replaces the HuggingFace model with its JumpStart equivalent if available. Skips replacement if image_uri or env_vars are explicitly provided. - + Returns: True if JumpStart equivalent was found and used, False otherwise. """ # Do not use the equivalent JS model if image_uri or env_vars is provided - image_uri = getattr(self, 'image_uri', None) - env_vars = getattr(self, 'env_vars', None) + image_uri = getattr(self, "image_uri", None) + env_vars = getattr(self, "env_vars", None) if image_uri or env_vars: return False - + if not hasattr(self, "_has_jumpstart_equivalent"): self._jumpstart_mapping = self._retrieve_hugging_face_model_mapping() self._has_jumpstart_equivalent = self.model in self._jumpstart_mapping - + if self._has_jumpstart_equivalent: # Use schema builder from HF model metadata - schema_builder = getattr(self, 'schema_builder', None) + schema_builder = getattr(self, "schema_builder", None) if not schema_builder: model_task = None - model_metadata = getattr(self, 'model_metadata', None) + model_metadata = getattr(self, "model_metadata", None) if model_metadata: model_task = model_metadata.get("HF_TASK") - + hf_model_md = self.get_huggingface_model_metadata(self.model) if not model_task: model_task = hf_model_md.get("pipeline_tag") @@ -932,19 +948,19 @@ def _use_jumpstart_equivalent(self) -> bool: jumpstart_model_id = self._jumpstart_mapping[huggingface_model_id]["jumpstart-model-id"] self.model = jumpstart_model_id merged_date = self._jumpstart_mapping[huggingface_model_id].get("merged-at") - + # Call _build_for_jumpstart if method exists - if hasattr(self, '_build_for_jumpstart'): + if hasattr(self, "_build_for_jumpstart"): self._build_for_jumpstart() - + compare_model_diff_message = ( "If you want to identify the differences between the two, " "please use model_uris.retrieve() to retrieve the model " "artifact S3 URI and compare them." ) - - is_gated = (hasattr(self, '_is_gated_model') and self._is_gated_model()) - + + is_gated = hasattr(self, "_is_gated_model") and self._is_gated_model() + logger.warning( "Please note that for this model we are using the JumpStart's " f'local copy "{jumpstart_model_id}" '
@@ -975,8 +990,7 @@ def _hf_schema_builder_init(self, model_task: str) -> None: sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task) except ValueError: # Samples could not be loaded locally, try to fetch remote HF schema - from sagemaker_schema_inference_artifacts.huggingface import \ - remote_schema_retriever + from sagemaker_schema_inference_artifacts.huggingface import remote_schema_retriever if model_task in ("text-to-image", "automatic-speech-recognition"): logger.warning( @@ -984,37 +998,36 @@ def _hf_schema_builder_init(self, model_task: str) -> None: "with all models at this time.", model_task, ) - + remote_hf_schema_helper = remote_schema_retriever.RemoteSchemaRetriever() ( sample_inputs, sample_outputs, ) = remote_hf_schema_helper.get_resolved_hf_schema_for_task(model_task) - + self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs) - + except ValueError as e: raise TaskNotFoundException( f"HuggingFace Schema builder samples for {model_task} could not be found " f"locally or via remote." ) from e - def _retrieve_hugging_face_model_mapping(self) -> Dict[str, Dict[str, Any]]: """Retrieve and preprocess HuggingFace/JumpStart model mapping. - + Downloads the mapping file from S3 that contains the correspondence between HuggingFace model IDs and their JumpStart equivalents. - + Returns: Dictionary mapping HuggingFace model IDs to JumpStart model metadata. Empty dict if mapping cannot be retrieved. """ converted_mapping = {} - session = getattr(self, 'sagemaker_session', None) + session = getattr(self, "sagemaker_session", None) if not session: return converted_mapping - + region = session.boto_region_name try: mapping_json_object = JumpStartS3PayloadAccessor.get_object_cached( @@ -1038,22 +1051,19 @@ def _retrieve_hugging_face_model_mapping(self) -> Dict[str, Dict[str, Any]]: def _prepare_hf_model_for_upload(self) -> None: """Download HuggingFace model metadata for upload. - + Creates a temporary directory and downloads the necessary HuggingFace model metadata files if model_path is not already set. """ - model_path = getattr(self, 'model_path', None) + model_path = getattr(self, "model_path", None) if not model_path: self.model_path = f"/tmp/sagemaker/model-builder/{self.model}" - env_vars = getattr(self, 'env_vars', {}) or {} + env_vars = getattr(self, "env_vars", {}) or {} self.download_huggingface_model_metadata( self.model, os.path.join(self.model_path, "code"), env_vars.get("HUGGING_FACE_HUB_TOKEN"), ) - - - # ======================================== # Resource and Hardware Utils @@ -1061,60 +1071,63 @@ def _prepare_hf_model_for_upload(self) -> None: def _get_processing_unit(self) -> str: """Detect if resource requirements are intended for CPU or GPU instance. - + Analyzes resource requirements to determine the target processing unit: - Checks for accelerator requirements in resource_requirements - Checks for accelerator requirements in modelbuilder_list items - Defaults to CPU if no accelerators specified - + Returns: 'gpu' if accelerators are required, 'cpu' otherwise. 
""" # Assume custom orchestrator will be deployed as an endpoint to a CPU instance - resource_requirements = getattr(self, 'resource_requirements', None) - if not resource_requirements or not getattr(resource_requirements, 'num_accelerators', None): - modelbuilder_list = getattr(self, 'modelbuilder_list', None) or [] + resource_requirements = getattr(self, "resource_requirements", None) + if not resource_requirements or not getattr( + resource_requirements, "num_accelerators", None + ): + modelbuilder_list = getattr(self, "modelbuilder_list", None) or [] for ic in modelbuilder_list: - ic_resource_req = getattr(ic, 'resource_requirements', None) - if ic_resource_req and getattr(ic_resource_req, 'num_accelerators', 0) > 0: + ic_resource_req = getattr(ic, "resource_requirements", None) + if ic_resource_req and getattr(ic_resource_req, "num_accelerators", 0) > 0: return "gpu" return "cpu" - - if getattr(resource_requirements, 'num_accelerators', 0) > 0: + + if getattr(resource_requirements, "num_accelerators", 0) > 0: return "gpu" return "cpu" - def _get_inference_component_resource_requirements(self, mb) -> None: """Fetch pre-benchmarked resource requirements from JumpStart. - + Attempts to retrieve and set resource requirements for inference components using JumpStart deployment configurations when available. - + Raises: ValueError: If no resource requirements provided and no JumpStart configs found. """ - resource_requirements = getattr(mb, 'resource_requirements', None) + resource_requirements = getattr(mb, "resource_requirements", None) if mb._is_jumpstart_model_id() and not resource_requirements: - if not hasattr(mb, 'list_deployment_configs'): + if not hasattr(mb, "list_deployment_configs"): return - + deployment_configs = mb.list_deployment_configs() if not deployment_configs: - inference_component_name = getattr(mb, 'inference_component_name', 'Unknown') + inference_component_name = getattr(mb, "inference_component_name", "Unknown") raise ValueError( f"No resource requirements were provided for Inference Component " f"{inference_component_name} and no default deployment " f"configs were found in JumpStart." ) - + compute_requirements = ( - deployment_configs[0].get("DeploymentArgs", {}).get("ComputeResourceRequirements", {}) + deployment_configs[0] + .get("DeploymentArgs", {}) + .get("ComputeResourceRequirements", {}) ) - + logger.debug("Retrieved pre-benchmarked deployment configurations from JumpStart.") - + mb.resource_requirements = ResourceRequirements( requests={ "memory": compute_requirements.get("MinMemoryRequiredInMb"), @@ -1126,9 +1139,8 @@ def _get_inference_component_resource_requirements(self, mb) -> None: }, limits={"memory": compute_requirements.get("MaxMemoryRequiredInMb", None)}, ) - + return mb - def _can_fit_on_single_gpu(self) -> bool: """Check if model can fit on a single GPU. @@ -1140,17 +1152,16 @@ def _can_fit_on_single_gpu(self) -> bool: True if model size <= single GPU memory size, False otherwise. 
""" try: - if not hasattr(self, '_try_fetch_gpu_info'): + if not hasattr(self, "_try_fetch_gpu_info"): return False - + single_gpu_size_mib = self._try_fetch_gpu_info() - env_vars = getattr(self, 'env_vars', {}) or {} - + env_vars = getattr(self, "env_vars", {}) or {} + model_size_mib = _total_inference_model_size_mib( - self.model, - env_vars.get("dtypes", "float32") + self.model, env_vars.get("dtypes", "float32") ) - + if model_size_mib <= single_gpu_size_mib: logger.debug( "Total inference model size: %s MiB, single GPU size: %s MiB", @@ -1159,56 +1170,53 @@ def _can_fit_on_single_gpu(self) -> bool: ) return True return False - + except ValueError: - instance_type = getattr(self, 'instance_type', 'Unknown') + instance_type = getattr(self, "instance_type", "Unknown") logger.debug("Unable to determine single GPU size for instance %s", instance_type) return False - - # ======================================== # Serialization Utils # ======================================== def _extract_framework_from_image_uri(self) -> Tuple[Optional[Framework], Optional[str]]: """Extract framework and version information from SageMaker image URI. - + Analyzes the container image URI to determine the ML framework and version being used. - + Returns: Tuple of (Framework enum, version string). Both can be None if not detected. """ - image_uri = getattr(self, 'image_uri', None) + image_uri = getattr(self, "image_uri", None) if not image_uri: return None, None - + if "pytorch-inference" in image_uri or "pytorch-training" in image_uri: - version_match = re.search(r'pytorch.*:(\d+\.\d+\.\d+)', image_uri) + version_match = re.search(r"pytorch.*:(\d+\.\d+\.\d+)", image_uri) return Framework.PYTORCH, version_match.group(1) if version_match else None - + elif "tensorflow-inference" in image_uri or "tensorflow-training" in image_uri: - version_match = re.search(r'tensorflow.*:(\d+\.\d+\.\d+)', image_uri) + version_match = re.search(r"tensorflow.*:(\d+\.\d+\.\d+)", image_uri) return Framework.TENSORFLOW, version_match.group(1) if version_match else None - + elif "sagemaker-xgboost" in image_uri: - version_match = re.search(r'sagemaker-xgboost:(\d+\.\d+)', image_uri) + version_match = re.search(r"sagemaker-xgboost:(\d+\.\d+)", image_uri) return Framework.XGBOOST, version_match.group(1) if version_match else None - + elif "sagemaker-scikit-learn" in image_uri: - version_match = re.search(r'scikit-learn:(\d+\.\d+)', image_uri) + version_match = re.search(r"scikit-learn:(\d+\.\d+)", image_uri) return Framework.SKLEARN, version_match.group(1) if version_match else None - + elif "huggingface" in image_uri: return Framework.HUGGINGFACE, None - + elif "mxnet" in image_uri: - version_match = re.search(r'mxnet.*:(\d+\.\d+\.\d+)', image_uri) + version_match = re.search(r"mxnet.*:(\d+\.\d+\.\d+)", image_uri) return Framework.MXNET, version_match.group(1) if version_match else None - + return None, None - def _fetch_serializer_and_deserializer_for_framework(self, framework: str) -> Tuple[Any, Any]: """Fetch default serializer and deserializer for a framework. 
@@ -1224,26 +1232,27 @@ def _fetch_serializer_and_deserializer_for_framework(self, framework: str) -> Tu if framework_enum and framework_enum in DEFAULT_SERIALIZERS_BY_FRAMEWORK: return DEFAULT_SERIALIZERS_BY_FRAMEWORK[framework_enum] return NumpySerializer(), JSONDeserializer() - - def _normalize_framework_to_enum(self, framework: Union[str, Framework, None]) -> Optional[Framework]: + def _normalize_framework_to_enum( + self, framework: Union[str, Framework, None] + ) -> Optional[Framework]: """Convert any framework input to Framework enum. - + Args: framework: Framework as string, enum, or None - + Returns: Framework enum or None if not found/None input """ if framework is None: return None - + if isinstance(framework, Framework): return framework - + if not isinstance(framework, str): return None - + framework_mapping = { "xgboost": Framework.XGBOOST, "xgb": Framework.XGBOOST, @@ -1268,9 +1277,8 @@ def _normalize_framework_to_enum(self, framework: Union[str, Framework, None]) - "smd": Framework.SMD, "sagemaker-distribution": Framework.SMD, } - - return framework_mapping.get(framework.lower()) + return framework_mapping.get(framework.lower()) # ======================================== # MLflow Utils @@ -1278,7 +1286,7 @@ def _normalize_framework_to_enum(self, framework: Union[str, Framework, None]) - def _handle_mlflow_input(self) -> None: """Check and handle MLflow model input if present. - + Detects MLflow model arguments, validates metadata existence, and initializes MLflow-specific configurations. """ @@ -1286,19 +1294,19 @@ def _handle_mlflow_input(self) -> None: if not self._is_mlflow_model: return - model_metadata = getattr(self, 'model_metadata', {}) + model_metadata = getattr(self, "model_metadata", {}) mlflow_model_path = model_metadata.get(MLFLOW_MODEL_PATH) if not mlflow_model_path: return - + artifact_path = self._get_artifact_path(mlflow_model_path) if not self._mlflow_metadata_exists(artifact_path): return self._initialize_for_mlflow(artifact_path) - - model_server = getattr(self, 'model_server', None) - env_vars = getattr(self, 'env_vars', {}) or {} + + model_server = getattr(self, "model_server", None) + env_vars = getattr(self, "env_vars", {}) or {} _validate_input_for_mlflow(model_server, env_vars.get("MLFLOW_MODEL_FLAVOR")) def _has_mlflow_arguments(self) -> bool: @@ -1307,9 +1315,9 @@ def _has_mlflow_arguments(self) -> bool: Returns: True if MLflow arguments are present and should be handled, False otherwise. """ - inference_spec = getattr(self, 'inference_spec', None) - model = getattr(self, 'model', None) - + inference_spec = getattr(self, "inference_spec", None) + model = getattr(self, "model", None) + if inference_spec or model: logger.debug( "Either inference spec or model is provided. " @@ -1317,7 +1325,7 @@ def _has_mlflow_arguments(self) -> bool: ) return False - model_metadata = getattr(self, 'model_metadata', None) + model_metadata = getattr(self, "model_metadata", None) if not model_metadata: logger.debug( "No ModelMetadata provided. ModelBuilder is not handling MLflow model input" @@ -1349,16 +1357,16 @@ def _get_artifact_path(self, mlflow_model_path: str) -> str: Returns: Path to the model artifact. - + Raises: ValueError: If tracking ARN not provided for run/registry paths. ImportError: If sagemaker_mlflow not installed. 
""" is_run_id_type = re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path) is_registry_type = re.match(MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path) - + if is_run_id_type or is_registry_type: - model_metadata = getattr(self, 'model_metadata', {}) + model_metadata = getattr(self, "model_metadata", {}) mlflow_tracking_arn = model_metadata.get(MLFLOW_TRACKING_ARN) if not mlflow_tracking_arn: raise ValueError( @@ -1374,7 +1382,7 @@ def _get_artifact_path(self, mlflow_model_path: str) -> str: import mlflow mlflow.set_tracking_uri(mlflow_tracking_arn) - + if is_run_id_type: _, run_id, model_path = mlflow_model_path.split("/", 2) artifact_uri = mlflow.get_run(run_id).info.artifact_uri @@ -1390,7 +1398,9 @@ def _get_artifact_path(self, mlflow_model_path: str) -> str: if "@" in mlflow_model_path: _, model_name_and_alias, artifact_uri = mlflow_model_path.split("/", 2) model_name, model_alias = model_name_and_alias.split("@") - model_version_info = mlflow_client.get_model_version_by_alias(model_name, model_alias) + model_version_info = mlflow_client.get_model_version_by_alias( + model_name, model_alias + ) source = mlflow_client.get_model_version_download_uri( model_name, model_version_info.version ) @@ -1404,7 +1414,7 @@ def _get_artifact_path(self, mlflow_model_path: str) -> str: # Handle model package ARN if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path): - sagemaker_session = getattr(self, 'sagemaker_session', None) + sagemaker_session = getattr(self, "sagemaker_session", None) if sagemaker_session: model_package = sagemaker_session.sagemaker_client.describe_model_package( ModelPackageName=mlflow_model_path @@ -1419,7 +1429,7 @@ def _mlflow_metadata_exists(self, path: str) -> bool: Args: path: Directory path to check (local or S3). - + Returns: True if MLmodel file exists, False otherwise. """ @@ -1428,7 +1438,7 @@ def _mlflow_metadata_exists(self, path: str) -> bool: if not path.endswith("/"): path += "/" s3_uri_to_mlmodel_file = f"{path}{MLFLOW_METADATA_FILE}" - sagemaker_session = getattr(self, 'sagemaker_session', None) + sagemaker_session = getattr(self, "sagemaker_session", None) if not sagemaker_session: return False response = s3_downloader.list(s3_uri_to_mlmodel_file, sagemaker_session) @@ -1445,51 +1455,49 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None: Args: artifact_path: Path to the MLflow artifact store. - + Raises: ValueError: If artifact path is invalid. 
""" - model_path = getattr(self, 'model_path', None) - sagemaker_session = getattr(self, 'sagemaker_session', None) - + model_path = getattr(self, "model_path", None) + sagemaker_session = getattr(self, "sagemaker_session", None) + if artifact_path.startswith("s3://"): _download_s3_artifacts(artifact_path, model_path, sagemaker_session) elif os.path.exists(artifact_path): _copy_directory_contents(artifact_path, model_path) else: raise ValueError(f"Invalid path: {artifact_path}") - + mlflow_model_metadata_path = _generate_mlflow_artifact_path( model_path, MLFLOW_METADATA_FILE ) mlflow_model_dependency_path = _generate_mlflow_artifact_path( model_path, MLFLOW_PIP_DEPENDENCY_FILE ) - + flavor_metadata = _get_all_flavor_metadata(mlflow_model_metadata_path) deployment_flavor = _get_deployment_flavor(flavor_metadata) - current_model_server = getattr(self, 'model_server', None) + current_model_server = getattr(self, "model_server", None) self.model_server = current_model_server or _get_default_model_server_for_mlflow( deployment_flavor ) - - current_image_uri = getattr(self, 'image_uri', None) + + current_image_uri = getattr(self, "image_uri", None) if not current_image_uri: self.image_uri = _select_container_for_mlflow_model( mlflow_model_src_path=model_path, deployment_flavor=deployment_flavor, region=sagemaker_session.boto_region_name if sagemaker_session else None, - instance_type=getattr(self, 'instance_type', None), + instance_type=getattr(self, "instance_type", None), ) - - env_vars = getattr(self, 'env_vars', {}) - env_vars.update({"MLFLOW_MODEL_FLAVOR": f"{deployment_flavor}"}) - - dependencies = getattr(self, 'dependencies', {}) - dependencies.update({"requirements": mlflow_model_dependency_path}) + env_vars = getattr(self, "env_vars", {}) + env_vars.update({"MLFLOW_MODEL_FLAVOR": f"{deployment_flavor}"}) + dependencies = getattr(self, "dependencies", {}) + dependencies.update({"requirements": mlflow_model_dependency_path}) # ======================================== # Optimize Utils @@ -1511,7 +1519,6 @@ def _is_inferentia_or_trainium(self, instance_type: Optional[str]) -> bool: return True return False - def _is_image_compatible_with_optimization_job(self, image_uri: Optional[str]) -> bool: """Checks whether an instance is compatible with an optimization job. @@ -1525,7 +1532,6 @@ def _is_image_compatible_with_optimization_job(self, image_uri: Optional[str]) - return True return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri) - def _deployment_config_contains_draft_model(self, deployment_config: Optional[Dict]) -> bool: """Checks whether a deployment config contains a speculative decoding draft model. @@ -1540,8 +1546,9 @@ def _deployment_config_contains_draft_model(self, deployment_config: Optional[Di deployment_args = deployment_config.get("DeploymentArgs", {}) additional_data_sources = deployment_args.get("AdditionalDataSources") - return "speculative_decoding" in additional_data_sources if additional_data_sources else False - + return ( + "speculative_decoding" in additional_data_sources if additional_data_sources else False + ) def _is_draft_model_jumpstart_provided(self, deployment_config: Optional[Dict]) -> bool: """Checks whether a deployment config's draft model is provided by JumpStart. @@ -1565,7 +1572,6 @@ def _is_draft_model_jumpstart_provided(self, deployment_config: Optional[Dict]) continue return False - def _generate_optimized_model(self, optimization_response: dict): """Generates a new optimization model. 
@@ -1590,10 +1596,12 @@ def _generate_optimized_model(self, optimization_response: dict): self.instance_type = deployment_instance_type self.add_tags( - {"Key": Tag.OPTIMIZATION_JOB_NAME, "Value": optimization_response["OptimizationJobName"]} + { + "Key": Tag.OPTIMIZATION_JOB_NAME, + "Value": optimization_response["OptimizationJobName"], + } ) - def _is_optimized(self) -> bool: """Checks whether an optimization model is optimized. @@ -1612,7 +1620,6 @@ def _is_optimized(self) -> bool: return True return False - def _generate_model_source( self, model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool] ) -> Optional[Dict[str, Any]]: @@ -1636,7 +1643,6 @@ def _generate_model_source( model_source["S3"]["ModelAccessConfig"] = {"AcceptEula": True} return model_source - def _update_environment_variables( self, env: Optional[Dict[str, str]], new_env: Optional[Dict[str, str]] ) -> Optional[Dict[str, str]]: @@ -1656,9 +1662,9 @@ def _update_environment_variables( env = new_env return env - def _extract_speculative_draft_model_provider( - self, speculative_decoding_config: Optional[Dict] = None, + self, + speculative_decoding_config: Optional[Dict] = None, ) -> Optional[str]: """Extracts speculative draft model provider from speculative decoding config. @@ -1684,9 +1690,9 @@ def _extract_speculative_draft_model_provider( return "auto" - def _extract_additional_model_data_source_s3_uri( - self, additional_model_data_source: Optional[Dict] = None, + self, + additional_model_data_source: Optional[Dict] = None, ) -> Optional[str]: """Extracts model data source s3 uri from a model data source in Pascal case. @@ -1704,9 +1710,9 @@ def _extract_additional_model_data_source_s3_uri( return additional_model_data_source.get("S3DataSource").get("S3Uri") - def _extract_deployment_config_additional_model_data_source_s3_uri( - self, additional_model_data_source: Optional[Dict] = None, + self, + additional_model_data_source: Optional[Dict] = None, ) -> Optional[str]: """Extracts model data source s3 uri from a model data source in snake case. @@ -1724,9 +1730,9 @@ def _extract_deployment_config_additional_model_data_source_s3_uri( return additional_model_data_source.get("s3_data_source").get("s3_uri", None) - def _is_draft_model_gated( - self, draft_model_config: Optional[Dict] = None, + self, + draft_model_config: Optional[Dict] = None, ) -> bool: """Extracts model gated-ness from draft model data source. @@ -1738,9 +1744,9 @@ def _is_draft_model_gated( """ return "hosting_eula_key" in draft_model_config if draft_model_config else False - def _extracts_and_validates_speculative_model_source( - self, speculative_decoding_config: Dict, + self, + speculative_decoding_config: Dict, ) -> str: """Extracts model source from speculative decoding config. @@ -1759,7 +1765,6 @@ def _extracts_and_validates_speculative_model_source( raise ValueError("ModelSource must be provided in speculative decoding config.") return model_source - def _generate_channel_name(self, additional_model_data_sources: Optional[List[Dict]]) -> str: """Generates a channel name. @@ -1775,9 +1780,8 @@ def _generate_channel_name(self, additional_model_data_sources: Optional[List[Di return channel_name - def _generate_additional_model_data_sources( - self, + self, model_source: str, channel_name: str, accept_eula: bool = False, @@ -1810,7 +1814,6 @@ def _generate_additional_model_data_sources( return [additional_model_data_source] - def _is_s3_uri(self, s3_uri: Optional[str]) -> bool: """Checks whether an S3 URI is valid. 
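The hunk above ends at _is_s3_uri, whose whole check is the single regex visible at the start of the next hunk. A minimal standalone sketch of that check, using hypothetical URIs:

    import re

    S3_URI_PATTERN = "^s3://([^/]+)/?(.*)$"  # same pattern as in _is_s3_uri

    assert re.match(S3_URI_PATTERN, "s3://my-bucket/prefix/model.tar.gz")  # bucket + key
    assert re.match(S3_URI_PATTERN, "s3://my-bucket")                      # bucket only
    assert not re.match(S3_URI_PATTERN, "https://example.com/model")       # not an S3 URI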
@@ -1825,7 +1828,6 @@ def _is_s3_uri(self, s3_uri: Optional[str]) -> bool: return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None - def _extract_optimization_config_and_env( self, quantization_config: Optional[Dict] = None, @@ -1850,7 +1852,9 @@ def _extract_optimization_config_and_env( compilation_override_env = ( compilation_config.get("OverrideEnvironment") if compilation_config else None ) - sharding_override_env = sharding_config.get("OverrideEnvironment") if sharding_config else None + sharding_override_env = ( + sharding_config.get("OverrideEnvironment") if sharding_config else None + ) if quantization_config is not None: optimization_config["ModelQuantizationConfig"] = quantization_config @@ -1872,7 +1876,6 @@ def _extract_optimization_config_and_env( return None, None, None, None - def _custom_speculative_decoding( self, speculative_decoding_config: Optional[Dict], @@ -1909,7 +1912,7 @@ def _custom_speculative_decoding( def _get_cached_model_specs(self, model_id, version, region, sagemaker_session): """Get cached JumpStart model specs to avoid repeated fetches""" - if not hasattr(self, '_cached_js_model_specs'): + if not hasattr(self, "_cached_js_model_specs"): self._cached_js_model_specs = accessors.JumpStartModelsAccessor.get_model_specs( model_id=model_id, version=version, @@ -1918,7 +1921,6 @@ def _get_cached_model_specs(self, model_id, version, region, sagemaker_session): ) return self._cached_js_model_specs - def _jumpstart_speculative_decoding( self, speculative_decoding_config: Optional[Dict[str, Any]] = None, @@ -1938,7 +1940,7 @@ def _jumpstart_speculative_decoding( "`ModelID` is a required field in `speculative_decoding_config` when " "using JumpStart as draft model provider." ) - + model_version = speculative_decoding_config.get("ModelVersion", "*") accept_eula = speculative_decoding_config.get("AcceptEula", False) channel_name = self._generate_channel_name(self.additional_model_data_sources) @@ -1948,9 +1950,8 @@ def _jumpstart_speculative_decoding( version=model_version, region=sagemaker_session.boto_region_name, sagemaker_session=sagemaker_session, - ) - + model_spec_json = model_specs.to_json() js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket(self.region) @@ -1964,10 +1965,12 @@ def _jumpstart_speculative_decoding( f"{eula_message} Set `AcceptEula`=True in " f"speculative_decoding_config once acknowledged." ) - js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket(self.region) + js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket( + self.region + ) key_prefix = model_spec_json.get("hosting_prepacked_artifact_key") - self.additional_model_data_sources = self. 
_generate_additional_model_data_sources( + self.additional_model_data_sources = self._generate_additional_model_data_sources( f"s3://{js_bucket}/{key_prefix}", channel_name, accept_eula, @@ -1980,7 +1983,6 @@ def _jumpstart_speculative_decoding( {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"}, ) - def _optimize_for_hf( self, output_path: str, @@ -2025,9 +2027,7 @@ def _optimize_for_hf( sagemaker_session=self.sagemaker_session, ) else: - self._custom_speculative_decoding( - speculative_decoding_config, False - ) + self._custom_speculative_decoding(speculative_decoding_config, False) if quantization_config or compilation_config or sharding_config: create_optimization_job_args = { @@ -2107,7 +2107,6 @@ def _optimize_prepare_for_hf(self): ) self.env_vars.update(env) - def _is_gated_model(self) -> bool: """Determine if ``this`` Model is Gated @@ -2123,7 +2122,7 @@ def _is_gated_model(self) -> bool: if s3_uri is None: return False return "private" in s3_uri - + def set_js_deployment_config(self, config_name: str, instance_type: str) -> None: """Sets the deployment config to apply to the model. @@ -2149,7 +2148,6 @@ def set_js_deployment_config(self, config_name: str, instance_type: str) -> None self.remove_tag_with_key(Tag.FINE_TUNING_MODEL_PATH) self.remove_tag_with_key(Tag.FINE_TUNING_JOB_NAME) - def _set_additional_model_source( self, speculative_decoding_config: Optional[Dict[str, Any]] = None ) -> None: @@ -2160,15 +2158,15 @@ def _set_additional_model_source( accept_eula (Optional[bool]): For models that require a Model Access Config. """ if speculative_decoding_config: - model_provider = self._extract_speculative_draft_model_provider(speculative_decoding_config) + model_provider = self._extract_speculative_draft_model_provider( + speculative_decoding_config + ) channel_name = self._generate_channel_name(self.additional_model_data_sources) if model_provider in ["sagemaker", "auto"]: additional_model_data_sources = ( - self._deployment_config.get("DeploymentArgs", {}).get( - "AdditionalDataSources" - ) + self._deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources") if self._deployment_config else None ) @@ -2177,8 +2175,9 @@ def _set_additional_model_source( speculative_decoding_config ) if deployment_config: - if model_provider == "sagemaker" and self._is_draft_model_jumpstart_provided( - deployment_config + if ( + model_provider == "sagemaker" + and self._is_draft_model_jumpstart_provided(deployment_config) ): raise ValueError( "No `Sagemaker` provided draft model was found for " @@ -2283,14 +2282,12 @@ def _get_neuron_model_env_vars( version=neuro_model_version, region=self.region, sagemaker_session=self.sagemaker_session, - ) - + model_spec_json = model_specs.to_json() return model_spec_json.get("hosting_env_vars", {}) - - return None + return None def _set_optimization_image_default( self, create_optimization_job_args: Dict[str, Any] @@ -2312,8 +2309,8 @@ def _set_optimization_image_default( region=self.region, model_version=self.model_version, hub_arn=self.hub_arn, - tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), - tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None) + tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), + tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), ) default_image = self._get_default_vllm_image(init_kwargs.image_uri) @@ -2392,33 +2389,31 @@ def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]: Returns: Tuple[int, 
int, int]: LMI version split into major, minor, patch - + Raises: ValueError: If the image format cannot be parsed """ _, dlc_tag = image.split(":") parts = dlc_tag.split("-") - + lmi_version = None for part in parts: if "." in part and part[0].isdigit(): lmi_version = part break - + if not lmi_version: raise ValueError(f"Could not find version in image: {image}") - + version_parts = lmi_version.split(".") if len(version_parts) < 3: raise ValueError(f"Invalid version format: {lmi_version} in image: {image}") - + major_version = int(version_parts[0]) minor_version = int(version_parts[1]) patch_version = int(version_parts[2]) - - return (major_version, minor_version, patch_version) - + return (major_version, minor_version, patch_version) def _optimize_for_jumpstart( self, @@ -2523,7 +2518,9 @@ def _optimize_for_jumpstart( if self._deployment_config else None ) - self.instance_type = instance_type or deployment_config_instance_type or self._get_nb_instance() + self.instance_type = ( + instance_type or deployment_config_instance_type or self._get_nb_instance() + ) create_optimization_job_args = { "OptimizationJobName": job_name, @@ -2548,9 +2545,7 @@ def _optimize_for_jumpstart( if accept_eula: self.accept_eula = accept_eula if isinstance(self.s3_upload_path, dict): - self.s3_upload_path["S3DataSource"]["ModelAccessConfig"] = { - "AcceptEula": True - } + self.s3_upload_path["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True} optimization_env_vars = self._update_environment_variables( optimization_env_vars, @@ -2578,11 +2573,12 @@ def _optimize_for_jumpstart( ) return None - def _generate_optimized_core_model(self, optimization_response: dict) -> Model: """Generate optimized CoreModel from optimization job response.""" - - recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get("RecommendedInferenceImage") + + recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get( + "RecommendedInferenceImage" + ) s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation") deployment_instance_type = optimization_response.get("DeploymentInstanceType") if recommended_image_uri: @@ -2595,15 +2591,15 @@ def _generate_optimized_core_model(self, optimization_response: dict) -> Model: if deployment_instance_type: self.instance_type = deployment_instance_type - self.add_tags({"Key": "OptimizationJobName", "Value": optimization_response["OptimizationJobName"]}) - + self.add_tags( + {"Key": "OptimizationJobName", "Value": optimization_response["OptimizationJobName"]} + ) + self._optimizing = False optimized_core_model = self._create_model() self.built_model = optimized_core_model - - return optimized_core_model - + return optimized_core_model def deployment_config_response_data( self, @@ -2633,7 +2629,7 @@ def deployment_config_response_data( configs.append(deployment_config_json) return configs - + # @_deployment_config_lru_cache def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]: """Deployment configs benchmark metrics. 
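_parse_lmi_version above locates the first dot-separated, digit-leading segment of the DLC image tag and splits it into integers. A standalone sketch of the same logic, exercised with a hypothetical image URI (the real DLC tag layout may differ):

    def parse_lmi_version(image: str):
        # Mirror of _parse_lmi_version: take the tag after ":", scan its
        # "-"-separated parts for the first version-like segment, then
        # split that segment into (major, minor, patch) integers.
        _, dlc_tag = image.split(":")
        lmi_version = None
        for part in dlc_tag.split("-"):
            if "." in part and part[0].isdigit():
                lmi_version = part
                break
        if not lmi_version:
            raise ValueError(f"Could not find version in image: {image}")
        version_parts = lmi_version.split(".")
        if len(version_parts) < 3:
            raise ValueError(f"Invalid version format: {lmi_version} in image: {image}")
        return int(version_parts[0]), int(version_parts[1]), int(version_parts[2])

    # Hypothetical URI; only the ":<version>-..." tag layout matters here.
    assert parse_lmi_version(
        "123.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi-cu124"
    ) == (0, 29, 0)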
@@ -2694,20 +2690,20 @@ def _get_deployment_configs( sagemaker_session=self.sagemaker_session, image_uri=image_uri, region=self.region, - model_version=getattr(self, 'model_version', None) or "*", + model_version=getattr(self, "model_version", None) or "*", hub_arn=self.hub_arn, - tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), - tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None) + tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), + tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), ) deploy_kwargs = get_deploy_kwargs( model_id=self.model, instance_type=instance_type_to_use, sagemaker_session=self.sagemaker_session, region=self.region, - model_version=getattr(self, 'model_version', None) or "*", + model_version=getattr(self, "model_version", None) or "*", hub_arn=self.hub_arn, - tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), - tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None) + tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), + tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), ) deployment_config_metadata = DeploymentConfigMetadata( @@ -2724,8 +2720,6 @@ def _get_deployment_configs( return deployment_configs - - # ======================================== # General Utils # ======================================== @@ -2736,7 +2730,7 @@ def add_tags(self, tags: Tags) -> None: Args: tags: Tags to add to the model. """ - current_tags = getattr(self, '_tags', None) + current_tags = getattr(self, "_tags", None) self._tags = _validate_new_tags(tags, current_tags) def remove_tag_with_key(self, key: str) -> None: @@ -2745,107 +2739,109 @@ def remove_tag_with_key(self, key: str) -> None: Args: key: The key of the tag to remove. """ - current_tags = getattr(self, '_tags', None) + current_tags = getattr(self, "_tags", None) self._tags = remove_tag_with_key(key, current_tags) def _get_model_uri(self) -> Optional[str]: """Extract model URI from s3_model_data_url. - + Returns: Model URI string, or None if not available. """ - s3_model_data_url = getattr(self, 's3_model_data_url', None) + s3_model_data_url = getattr(self, "s3_model_data_url", None) if not s3_model_data_url: return None - + if isinstance(s3_model_data_url, (str, PipelineVariable)): return s3_model_data_url elif isinstance(s3_model_data_url, dict): return s3_model_data_url.get("S3DataSource", {}).get("S3Uri", None) return None - def _ensure_base_name_if_needed(self, image_uri: str, script_uri: Optional[str], model_uri: Optional[str]) -> None: + def _ensure_base_name_if_needed( + self, image_uri: str, script_uri: Optional[str], model_uri: Optional[str] + ) -> None: """Create base name from image URI if no model name provided. Uses JumpStart base name if available, otherwise derives from image URI. 
- + Args: image_uri: Container image URI script_uri: Optional script URI for JumpStart models model_uri: Optional model URI for JumpStart models """ - model_name = getattr(self, 'model_name', None) + model_name = getattr(self, "model_name", None) if model_name is None: - base_name = getattr(self, '_base_name', None) + base_name = getattr(self, "_base_name", None) self._base_name = ( base_name or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri) or base_name_from_image(image_uri, default_base_name="ModelBuilder") ) - def _ensure_metadata_configs(self) -> None: """Lazy load JumpStart metadata configs when needed.""" - metadata_configs = getattr(self, '_metadata_configs', None) - model = getattr(self, 'model', None) - + metadata_configs = getattr(self, "_metadata_configs", None) + model = getattr(self, "model", None) + if metadata_configs is None and isinstance(model, str): from sagemaker.core.jumpstart.utils import get_jumpstart_configs - + self._metadata_configs = get_jumpstart_configs( region=self.region, model_id=model, - model_version=getattr(self, 'model_version', None) or "*", - sagemaker_session=getattr(self, 'sagemaker_session', None), + model_version=getattr(self, "model_version", None) or "*", + sagemaker_session=getattr(self, "sagemaker_session", None), ) - + def _user_agent_decorator(self, func): """Decorator to add ModelBuilder to user agent string. - + Args: func: Function to decorate - + Returns: Decorated function that appends ModelBuilder to user agent. """ + def wrapper(*args, **kwargs): # Call the original function result = func(*args, **kwargs) if "ModelBuilder" in result: return result return result + " ModelBuilder" + return wrapper def _get_serve_setting(self) -> _ServeSettings: """Get serve settings for model deployment. - + Creates or uses existing S3 model data URL and constructs serve settings with deployment configuration. - + Returns: ServeSettings object with deployment configuration. 
""" - s3_model_data_url = getattr(self, 's3_model_data_url', None) + s3_model_data_url = getattr(self, "s3_model_data_url", None) if not s3_model_data_url: - sagemaker_session = getattr(self, 'sagemaker_session', None) + sagemaker_session = getattr(self, "sagemaker_session", None) if sagemaker_session: bucket = sagemaker_session.default_bucket() - model_name = getattr(self, 'model_name', None) + model_name = getattr(self, "model_name", None) prefix = f"model-builder/{model_name or 'model'}/{uuid.uuid4().hex}" self.s3_model_data_url = f"s3://{bucket}/{prefix}/" - + return _ServeSettings( - role_arn=getattr(self, 'role_arn', None), - s3_model_data_url=getattr(self, 's3_model_data_url', None), - instance_type=getattr(self, 'instance_type', None), - env_vars=getattr(self, 'env_vars', None), - sagemaker_session=getattr(self, 'sagemaker_session', None), + role_arn=getattr(self, "role_arn", None), + s3_model_data_url=getattr(self, "s3_model_data_url", None), + instance_type=getattr(self, "instance_type", None), + env_vars=getattr(self, "env_vars", None), + sagemaker_session=getattr(self, "sagemaker_session", None), ) - def _is_jumpstart_model_id(self) -> bool: """Check if model is a JumpStart model ID.""" - if not hasattr(self, '_cached_is_jumpstart'): + if not hasattr(self, "_cached_is_jumpstart"): if self.model is None: self._cached_is_jumpstart = False return self._cached_is_jumpstart @@ -2862,7 +2858,6 @@ def _is_jumpstart_model_id(self) -> bool: return self._cached_is_jumpstart return self._cached_is_jumpstart - def _has_nvidia_gpu(self) -> bool: try: @@ -2872,7 +2867,7 @@ def _has_nvidia_gpu(self) -> bool: # for nvidia-smi to run, a cuda driver must be present logger.debug("CUDA not found, launching Triton in CPU mode.") return False - + def _is_gpu_instance(self, instance_type: str) -> bool: instance_family = instance_type.rsplit(".", 1)[0] return instance_family in GPU_INSTANCE_FAMILIES @@ -2882,7 +2877,7 @@ def _save_inference_spec(self) -> None: if self.inference_spec: pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model") save_pkl(pkl_path, (self.inference_spec, self.schema_builder)) - + def _compute_integrity_hash(self): """Compute SHA-256 hash of serve.pkl and store in metadata.json for integrity check.""" pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model") @@ -2920,6 +2915,7 @@ def _pack_conda_env(self, pkl_path: Path): """Pack conda environment for Triton deployment.""" try: import conda_pack + conda_pack.__version__ except ModuleNotFoundError: raise ImportError( @@ -2987,11 +2983,12 @@ def _export_pytorch_to_onnx( "And follow the ones that match your environment. " "Please note that you may need to restart your runtime after installation." ) - + def _validate_for_triton(self): """Validation for Triton deployment.""" try: import tritonclient.http as httpClient + httpClient.__class__ except ModuleNotFoundError: raise ImportError( @@ -3101,7 +3098,6 @@ def _prepare_for_triton(self): return raise ValueError("Either model or inference_spec should be provided to ModelBuilder.") - def _auto_detect_image_for_triton(self): """Detect image of triton given framework, version and region. @@ -3150,7 +3146,6 @@ def _auto_detect_image_for_triton(self): self.image_uri += "-cpu" logger.debug(f"Autodetected image: {self.image_uri}. 
Proceeding with the deployment.") - def _validate_djl_serving_sample_data(self): """Validate sample data format for DJL serving.""" @@ -3166,7 +3161,7 @@ def _validate_djl_serving_sample_data(self): or "generated_text" not in sample_output[0] ): raise ValueError(_INVALID_DJL_SAMPLE_DATA_EX) - + def _validate_tgi_serving_sample_data(self): """Validate sample data format for TGI serving.""" sample_input = self.schema_builder.sample_input @@ -3181,7 +3176,7 @@ def _validate_tgi_serving_sample_data(self): or "generated_text" not in sample_output[0] ): raise ValueError(_INVALID_TGI_SAMPLE_DATA_EX) - + def _create_conda_env(self): """Create conda environment by running commands.""" try: @@ -3189,13 +3184,16 @@ def _create_conda_env(self): except subprocess.CalledProcessError: logger.error("Failed to create and activate conda environment.") - - def _extract_framework_from_model_trainer(self, model_trainer: ModelTrainer) -> Optional[Framework]: + def _extract_framework_from_model_trainer( + self, model_trainer: ModelTrainer + ) -> Optional[Framework]: """Extract framework from ModelTrainer training image.""" training_image = model_trainer.training_image if not training_image: - training_image = model_trainer._latest_training_job.algorithm_specification.training_image - + training_image = ( + model_trainer._latest_training_job.algorithm_specification.training_image + ) + if "pytorch" in training_image.lower(): return Framework.PYTORCH elif "tensorflow" in training_image.lower(): @@ -3204,15 +3202,16 @@ def _extract_framework_from_model_trainer(self, model_trainer: ModelTrainer) -> return Framework.HUGGINGFACE elif "xgboost" in training_image.lower(): return Framework.XGBOOST - - return None + return None - def _infer_model_server_from_training(self, model_trainer: ModelTrainer) -> Optional[ModelServer]: + def _infer_model_server_from_training( + self, model_trainer: ModelTrainer + ) -> Optional[ModelServer]: """Infer the best model server based on training configuration.""" training_image = model_trainer.training_image framework = self._extract_framework_from_model_trainer(model_trainer) - + if "huggingface" in training_image.lower(): hyperparams = model_trainer.hyperparameters or {} if any(key in hyperparams for key in ["max_new_tokens", "do_sample", "temperature"]): @@ -3221,29 +3220,30 @@ def _infer_model_server_from_training(self, model_trainer: ModelTrainer) -> Opti else: logger.info("Auto-detected model server: MMS (HuggingFace)") return ModelServer.MMS - + if framework == Framework.PYTORCH: logger.info("Auto-detected model server: TORCHSERVE (PyTorch framework)") return ModelServer.TORCHSERVE - + if framework == Framework.TENSORFLOW: logger.info("Auto-detected model server: TENSORFLOW_SERVING (TensorFlow framework)") return ModelServer.TENSORFLOW_SERVING - + logger.warning( f"Could not auto-detect model server for framework: {framework}. " "Defaulting to TORCHSERVE. Consider explicitly setting model_server parameter." 
) return ModelServer.TORCHSERVE - - def _extract_inference_spec_from_training_code(self, model_trainer: ModelTrainer) -> Optional[str]: + def _extract_inference_spec_from_training_code( + self, model_trainer: ModelTrainer + ) -> Optional[str]: """Check if training source code contains inference.py.""" if not model_trainer.source_code or not model_trainer.source_code.source_dir: return None - + source_dir = model_trainer.source_code.source_dir - + # Check for inference.py in source directory if source_dir.startswith("s3://"): pass @@ -3251,59 +3251,65 @@ def _extract_inference_spec_from_training_code(self, model_trainer: ModelTrainer inference_path = os.path.join(source_dir, "inference.py") if os.path.exists(inference_path): return inference_path - + return None - def _inherit_training_environment(self, model_trainer: ModelTrainer) -> Dict[str, str]: """Inherit relevant environment variables from training.""" from sagemaker.core.utils.utils import Unassigned - + training_env = model_trainer.environment or {} if isinstance(training_env, Unassigned): training_env = {} - + training_job_env = model_trainer._latest_training_job.environment if isinstance(training_job_env, Unassigned) or training_job_env is None: training_job_env = {} - + inherited_env = {**training_env, **training_job_env} inference_relevant_keys = [ - "HUGGING_FACE_HUB_TOKEN", "HF_TOKEN", - "MODEL_CLASS_NAME", "TRANSFORMERS_CACHE", - "PYTORCH_TRANSFORMERS_CACHE", "HF_HOME" + "HUGGING_FACE_HUB_TOKEN", + "HF_TOKEN", + "MODEL_CLASS_NAME", + "TRANSFORMERS_CACHE", + "PYTORCH_TRANSFORMERS_CACHE", + "HF_HOME", ] - - return {k: v for k, v in inherited_env.items() - if k in inference_relevant_keys or k.startswith("SAGEMAKER_")} - + + return { + k: v + for k, v in inherited_env.items() + if k in inference_relevant_keys or k.startswith("SAGEMAKER_") + } def _extract_version_from_training_image(self, training_image: str) -> Optional[str]: """Extract framework version from training image URI.""" import re - - version_match = re.search(r':(\d+\.\d+(?:\.\d+)?)', training_image) + + version_match = re.search(r":(\d+\.\d+(?:\.\d+)?)", training_image) if version_match: return version_match.group(1) - - return None + return None def _detect_inference_image_from_training(self) -> None: """Detect inference image based on ModelTrainer's training image.""" from sagemaker.core import image_uris + training_image = self.model.training_image - + if "pytorch-training" in training_image: self.image_uri = training_image.replace("pytorch-training", "pytorch-inference") elif "tensorflow-training" in training_image: self.image_uri = training_image.replace("tensorflow-training", "tensorflow-inference") elif "huggingface-pytorch-training" in training_image: - self.image_uri = training_image.replace("huggingface-pytorch-training", "huggingface-pytorch-inference") + self.image_uri = training_image.replace( + "huggingface-pytorch-training", "huggingface-pytorch-inference" + ) else: framework = self._extract_framework_from_model_trainer(self.model) fw = framework.value.lower() if framework else "pytorch" - + fw_version = self._extract_version_from_training_image(training_image) py_tuple = platform.python_version_tuple() casted_versions = _cast_to_compatible_version(fw, fw_version) if fw_version else (None,) @@ -3322,12 +3328,13 @@ def _detect_inference_image_from_training(self) -> None: break except ValueError: pass - + if dlc: self.image_uri = dlc else: - raise ValueError(f"Could not detect inference image for training image: {training_image}") - + raise ValueError( 
+ f"Could not detect inference image for training image: {training_image}" + ) def _extract_speculative_draft_model_provider( self, @@ -3356,9 +3363,10 @@ def _extract_speculative_draft_model_provider( return "sagemaker" return "auto" - - def get_huggingface_model_metadata(self, model_id: str, hf_hub_token: Optional[str] = None) -> dict: + def get_huggingface_model_metadata( + self, model_id: str, hf_hub_token: Optional[str] = None + ) -> dict: """Retrieves the json metadata of the HuggingFace Model via HuggingFace API. Args: @@ -3402,7 +3410,6 @@ def get_huggingface_model_metadata(self, model_id: str, hf_hub_token: Optional[s ) return hf_model_metadata_json - def download_huggingface_model_metadata( self, model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None ) -> None: @@ -3417,11 +3424,12 @@ def download_huggingface_model_metadata( ImportError: If huggingface_hub is not installed. """ if not importlib.util.find_spec("huggingface_hub"): - raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed") + raise ImportError( + "Unable to import huggingface_hub, check if huggingface_hub is installed" + ) from huggingface_hub import snapshot_download os.makedirs(model_local_path, exist_ok=True) logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path) snapshot_download(repo_id=model_id, local_dir=model_local_path, token=hf_hub_token) - diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/multi_model_server/prepare.py b/sagemaker-serve/src/sagemaker/serve/model_server/multi_model_server/prepare.py index 37ca745987..3b347ee65c 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/multi_model_server/prepare.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/multi_model_server/prepare.py @@ -26,7 +26,6 @@ from sagemaker.serve.spec.inference_spec import InferenceSpec from sagemaker.serve.detector.dependency_manager import capture_dependencies from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, compute_hash, ) from sagemaker.core.remote_function.core.serialization import _MetaData @@ -119,11 +118,8 @@ def prepare_for_mms( capture_dependencies(dependencies=dependencies, work_dir=code_dir) - secret_key = generate_secret_key() with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: buffer = f.read() - hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + hash_value = compute_hash(buffer=buffer) with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: metadata.write(_MetaData(hash_value).to_json()) - - return secret_key diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/multi_model_server/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/multi_model_server/server.py index 9401dd74d9..1e02be0621 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/multi_model_server/server.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/multi_model_server/server.py @@ -35,7 +35,6 @@ def _start_serving( env = { "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", "SAGEMAKER_PROGRAM": "inference.py", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } if env_vars: @@ -47,7 +46,7 @@ def _start_serving( image, "serve", # network_mode="host", - ports={'8080/tcp': 8080}, + ports={"8080/tcp": 8080}, detach=True, auto_remove=True, volumes={ @@ -131,7 +130,6 @@ def _upload_server_artifacts( env_vars = { "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", "SAGEMAKER_PROGRAM": "inference.py", - 
"SAGEMAKER_SERVE_SECRET_KEY": secret_key, "SAGEMAKER_REGION": sagemaker_session.boto_region_name, "SAGEMAKER_CONTAINER_LOG_LEVEL": "10", "LOCAL_PYTHON": platform.python_version(), diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/smd/prepare.py b/sagemaker-serve/src/sagemaker/serve/model_server/smd/prepare.py index b66de32bf7..f29b8ebcbd 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/smd/prepare.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/smd/prepare.py @@ -12,7 +12,6 @@ from sagemaker.serve.spec.inference_spec import InferenceSpec from sagemaker.serve.detector.dependency_manager import capture_dependencies from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, compute_hash, ) from sagemaker.core.remote_function.core.serialization import _MetaData @@ -64,11 +63,8 @@ def prepare_for_smd( capture_dependencies(dependencies=dependencies, work_dir=code_dir) - secret_key = generate_secret_key() with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: buffer = f.read() - hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + hash_value = compute_hash(buffer=buffer) with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: metadata.write(_MetaData(hash_value).to_json()) - - return secret_key diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/smd/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/smd/server.py index e40dc3aa61..ecb68406c1 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/smd/server.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/smd/server.py @@ -53,7 +53,6 @@ def _upload_smd_artifacts( "SAGEMAKER_INFERENCE_CODE_DIRECTORY": "/opt/ml/model/code", "SAGEMAKER_INFERENCE_CODE": "inference.handler", "SAGEMAKER_REGION": sagemaker_session.boto_region_name, - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } return s3_upload_path, env_vars diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/tei/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/tei/server.py index 9f2f4b71b3..c23c52a513 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/tei/server.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/tei/server.py @@ -38,8 +38,6 @@ def _start_tei_serving( secret_key: Secret key to use for authentication env_vars: Environment variables to set """ - if env_vars and secret_key: - env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key self.container = client.containers.run( image, diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py b/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py index 3525cc9b4a..d56d0ec7bd 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py @@ -11,7 +11,6 @@ ) from sagemaker.serve.detector.dependency_manager import capture_dependencies from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, compute_hash, ) from sagemaker.core.remote_function.core.serialization import _MetaData @@ -56,12 +55,9 @@ def prepare_for_tf_serving( if not mlflow_saved_model_dir: raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.") _move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir) - - secret_key = generate_secret_key() + with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: buffer = f.read() - hash_value = 
compute_hash(buffer=buffer, secret_key=secret_key) + hash_value = compute_hash(buffer=buffer) with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: metadata.write(_MetaData(hash_value).to_json()) - - return secret_key diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/server.py index 2f4a959528..cbd6412d34 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/server.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/server.py @@ -37,7 +37,7 @@ def _start_tensorflow_serving( detach=True, auto_remove=False, # Temporarily disabled to see crash logs # network_mode="host", - ports={'8501/tcp': 8501}, + ports={"8501/tcp": 8501}, volumes={ Path(model_path): { "bind": "/opt/ml/model", @@ -47,7 +47,6 @@ def _start_tensorflow_serving( environment={ "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", "SAGEMAKER_PROGRAM": "inference.py", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), **env_vars, }, @@ -124,7 +123,6 @@ def _upload_tensorflow_serving_artifacts( "SAGEMAKER_PROGRAM": "inference.py", "SAGEMAKER_REGION": sagemaker_session.boto_region_name, "SAGEMAKER_CONTAINER_LOG_LEVEL": "10", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } return s3_upload_path, env_vars diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py b/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py index 988acf646d..ad053d25c9 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py @@ -13,7 +13,6 @@ from sagemaker.serve.spec.inference_spec import InferenceSpec from sagemaker.serve.detector.dependency_manager import capture_dependencies from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, compute_hash, ) from sagemaker.serve.validations.check_image_uri import is_1p_image_uri @@ -56,7 +55,9 @@ def prepare_for_torchserve( # https://github.com/aws/sagemaker-python-sdk/issues/4288 if is_1p_image_uri(image_uri=image_uri) and "xgboost" in image_uri: shutil.copy2(Path(__file__).parent.joinpath("xgboost_inference.py"), code_dir) - os.rename(str(code_dir.joinpath("xgboost_inference.py")), str(code_dir.joinpath("inference.py"))) + os.rename( + str(code_dir.joinpath("xgboost_inference.py")), str(code_dir.joinpath("inference.py")) + ) else: shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir) @@ -67,11 +68,8 @@ def prepare_for_torchserve( capture_dependencies(dependencies=dependencies, work_dir=code_dir) - secret_key = generate_secret_key() with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: buffer = f.read() - hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + hash_value = compute_hash(buffer=buffer) with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: metadata.write(_MetaData(hash_value).to_json()) - - return secret_key \ No newline at end of file diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/server.py index 0d237df987..9cc4e6196f 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/server.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/server.py @@ -29,7 +29,7 @@ def _start_torch_serve( detach=True, 
auto_remove=True, # network_mode="host", - ports={'8080/tcp': 8080}, + ports={"8080/tcp": 8080}, volumes={ Path(model_path): { "bind": "/opt/ml/model", @@ -39,7 +39,6 @@ def _start_torch_serve( environment={ "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", "SAGEMAKER_PROGRAM": "inference.py", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), **env_vars, }, @@ -103,7 +102,6 @@ def _upload_torchserve_artifacts( "SAGEMAKER_PROGRAM": "inference.py", "SAGEMAKER_REGION": sagemaker_session.boto_region_name, "SAGEMAKER_CONTAINER_LOG_LEVEL": "10", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } return s3_upload_path, env_vars diff --git a/sagemaker-serve/tests/unit/model_server/test_djl_utils.py b/sagemaker-serve/tests/unit/model_server/test_djl_utils.py index 814feb3cc4..a66f04cbf4 100644 --- a/sagemaker-serve/tests/unit/model_server/test_djl_utils.py +++ b/sagemaker-serve/tests/unit/model_server/test_djl_utils.py @@ -6,7 +6,7 @@ _get_default_batch_size, _tokens_from_chars, _tokens_from_words, - _set_tokens_to_tokens_threshold + _set_tokens_to_tokens_threshold, ) diff --git a/sagemaker-serve/tests/unit/model_server/test_in_process_model_server_app.py b/sagemaker-serve/tests/unit/model_server/test_in_process_model_server_app.py index deeeefc704..d9c32c7e2d 100644 --- a/sagemaker-serve/tests/unit/model_server/test_in_process_model_server_app.py +++ b/sagemaker-serve/tests/unit/model_server/test_in_process_model_server_app.py @@ -14,10 +14,10 @@ # Mock optional dependencies before importing mock_transformers = MagicMock() -mock_pipeline_class = type('Pipeline', (), {}) +mock_pipeline_class = type("Pipeline", (), {}) mock_transformers.Pipeline = mock_pipeline_class -sys.modules['transformers'] = mock_transformers -sys.modules['sentence_transformers'] = MagicMock() +sys.modules["transformers"] = mock_transformers +sys.modules["sentence_transformers"] = MagicMock() from sagemaker.serve.model_server.in_process_model_server.app import InProcessServer @@ -25,22 +25,22 @@ class TestInProcessServerInitialization(unittest.TestCase): """Test InProcessServer initialization.""" - @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn') - @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI') + @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn") + @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI") def test_init_with_inference_spec(self, mock_fastapi, mock_uvicorn): """Test initialization with inference_spec.""" mock_inference_spec = Mock() mock_model = Mock() mock_inference_spec.load.return_value = mock_model mock_schema_builder = Mock() - + server = InProcessServer( model="test-model", inference_spec=mock_inference_spec, schema_builder=mock_schema_builder, - task="text-generation" + task="text-generation", ) - + self.assertEqual(server.model, "test-model") self.assertEqual(server.inference_spec, mock_inference_spec) self.assertEqual(server.schema_builder, mock_schema_builder) @@ -62,19 +62,19 @@ def test_init_fallback_to_sentence_transformer(self): # This test requires sentence-transformers to be installed self.skipTest("Requires sentence-transformers package") - @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn') - @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI') + @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn") + 
@patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI") def test_init_without_model_or_inference_spec_raises_error(self, mock_fastapi, mock_uvicorn): """Test that initialization without model or inference_spec raises ValueError.""" mock_schema_builder = Mock() - + with self.assertRaises(ValueError) as context: InProcessServer(schema_builder=mock_schema_builder) - + self.assertIn("Either inference_spec or model must be provided", str(context.exception)) - @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn') - @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI') + @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn") + @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI") def test_create_server_configuration(self, mock_fastapi, mock_uvicorn): """Test that server is created with correct configuration.""" mock_inference_spec = Mock() @@ -88,24 +88,24 @@ def test_create_server_configuration(self, mock_fastapi, mock_uvicorn): mock_uvicorn.Config.return_value = mock_config mock_server = Mock() mock_uvicorn.Server.return_value = mock_server - + server = InProcessServer( model="test-model", inference_spec=mock_inference_spec, - schema_builder=mock_schema_builder + schema_builder=mock_schema_builder, ) - + # Verify FastAPI app was created mock_fastapi.assert_called_once() mock_app.include_router.assert_called_once() - + # Verify uvicorn config mock_uvicorn.Config.assert_called_once() config_call_args = mock_uvicorn.Config.call_args - self.assertEqual(config_call_args[1]['host'], "127.0.0.1") - self.assertEqual(config_call_args[1]['port'], 9007) - self.assertEqual(config_call_args[1]['log_level'], "info") - + self.assertEqual(config_call_args[1]["host"], "127.0.0.1") + self.assertEqual(config_call_args[1]["port"], 9007) + self.assertEqual(config_call_args[1]["log_level"], "info") + # Verify server attributes self.assertEqual(server.host, "127.0.0.1") self.assertEqual(server.port, 9007) @@ -115,38 +115,39 @@ def test_create_server_configuration(self, mock_fastapi, mock_uvicorn): class TestInProcessServerInvokeEndpoint(unittest.TestCase): """Test InProcessServer /invoke endpoint.""" - @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn') - @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI') + @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn") + @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI") def test_invoke_with_inference_spec(self, mock_fastapi, mock_uvicorn): """Test /invoke endpoint with inference_spec.""" mock_inference_spec = Mock() mock_model = Mock() mock_inference_spec.load.return_value = mock_model mock_inference_spec.invoke.return_value = {"predictions": [0.1, 0.9]} - + mock_schema_builder = Mock() mock_deserializer = Mock() mock_deserializer.deserialize.return_value = {"inputs": [[1, 2, 3]]} mock_schema_builder.input_deserializer = mock_deserializer - + server = InProcessServer( model="test-model", inference_spec=mock_inference_spec, - schema_builder=mock_schema_builder + schema_builder=mock_schema_builder, ) - + # Simulate request mock_request = AsyncMock() mock_request.headers = {"Content-Type": ["application/json"]} mock_request.body = AsyncMock(return_value=b'{"inputs": [[1, 2, 3]]}') - + # Get the invoke function from the router invoke_func = server._router.routes[0].endpoint - + # Run async function import asyncio + result = asyncio.run(invoke_func(mock_request)) - + 
self.assertEqual(result, {"predictions": [0.1, 0.9]}) mock_inference_spec.invoke.assert_called_once_with({"inputs": [[1, 2, 3]]}, mock_model) @@ -166,51 +167,51 @@ def test_invoke_with_sentence_transformer(self): class TestInProcessServerLifecycle(unittest.TestCase): """Test InProcessServer lifecycle methods.""" - @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn') - @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI') + @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn") + @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI") def test_start_server(self, mock_fastapi, mock_uvicorn): """Test starting the server.""" mock_inference_spec = Mock() mock_inference_spec.load.return_value = Mock() mock_schema_builder = Mock() - + server = InProcessServer( model="test-model", inference_spec=mock_inference_spec, - schema_builder=mock_schema_builder + schema_builder=mock_schema_builder, ) - - with patch.object(threading.Thread, 'start') as mock_thread_start: + + with patch.object(threading.Thread, "start") as mock_thread_start: server.start_server() mock_thread_start.assert_called_once() self.assertIsNotNone(server._thread) - @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn') - @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI') + @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn") + @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI") def test_start_server_when_already_running(self, mock_fastapi, mock_uvicorn): """Test starting server when it's already running.""" mock_inference_spec = Mock() mock_inference_spec.load.return_value = Mock() mock_schema_builder = Mock() - + server = InProcessServer( model="test-model", inference_spec=mock_inference_spec, - schema_builder=mock_schema_builder + schema_builder=mock_schema_builder, ) - + # Mock thread as already running mock_thread = Mock() mock_thread.is_alive.return_value = True server._thread = mock_thread - - with patch.object(threading.Thread, 'start') as mock_thread_start: + + with patch.object(threading.Thread, "start") as mock_thread_start: server.start_server() # Should not start a new thread mock_thread_start.assert_not_called() - @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn') - @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI') + @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn") + @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI") def test_stop_server(self, mock_fastapi, mock_uvicorn): """Test stopping the server.""" mock_inference_spec = Mock() @@ -218,26 +219,26 @@ def test_stop_server(self, mock_fastapi, mock_uvicorn): mock_schema_builder = Mock() mock_server = Mock() mock_uvicorn.Server.return_value = mock_server - + server = InProcessServer( model="test-model", inference_spec=mock_inference_spec, - schema_builder=mock_schema_builder + schema_builder=mock_schema_builder, ) - + # Mock thread as running mock_thread = Mock() mock_thread.is_alive.return_value = True server._thread = mock_thread - + server.stop_server() - + self.assertTrue(server._shutdown_event.is_set()) mock_server.handle_exit.assert_called_once_with(sig=0, frame=None) mock_thread.join.assert_called_once() - @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn') - @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI') + 
@patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn") + @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI") def test_stop_server_when_not_running(self, mock_fastapi, mock_uvicorn): """Test stopping server when it's not running.""" mock_inference_spec = Mock() @@ -245,40 +246,40 @@ def test_stop_server_when_not_running(self, mock_fastapi, mock_uvicorn): mock_schema_builder = Mock() mock_server = Mock() mock_uvicorn.Server.return_value = mock_server - + server = InProcessServer( model="test-model", inference_spec=mock_inference_spec, - schema_builder=mock_schema_builder + schema_builder=mock_schema_builder, ) - + # No thread or thread not alive server._thread = None - + # Should not raise error server.stop_server() mock_server.handle_exit.assert_not_called() - @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn') - @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI') - @patch('sagemaker.serve.model_server.in_process_model_server.app.asyncio') + @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn") + @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI") + @patch("sagemaker.serve.model_server.in_process_model_server.app.asyncio") def test_start_run_async_in_thread(self, mock_asyncio, mock_fastapi, mock_uvicorn): """Test _start_run_async_in_thread method.""" mock_inference_spec = Mock() mock_inference_spec.load.return_value = Mock() mock_schema_builder = Mock() - + server = InProcessServer( model="test-model", inference_spec=mock_inference_spec, - schema_builder=mock_schema_builder + schema_builder=mock_schema_builder, ) - + mock_loop = Mock() mock_asyncio.new_event_loop.return_value = mock_loop - + server._start_run_async_in_thread() - + mock_asyncio.new_event_loop.assert_called_once() mock_asyncio.set_event_loop.assert_called_once_with(mock_loop) mock_loop.run_until_complete.assert_called_once() @@ -300,5 +301,5 @@ def test_invoke_without_inputs_key(self): self.skipTest("Requires transformers package") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/sagemaker-serve/tests/unit/model_server/test_multi_model_server_inference.py b/sagemaker-serve/tests/unit/model_server/test_multi_model_server_inference.py index 5842bf0f8d..34c0fa671f 100644 --- a/sagemaker-serve/tests/unit/model_server/test_multi_model_server_inference.py +++ b/sagemaker-serve/tests/unit/model_server/test_multi_model_server_inference.py @@ -13,128 +13,161 @@ class TestMultiModelServerInference(unittest.TestCase): def test_predict_fn_logic(self): """Test predict_fn logic.""" + def predict_fn(input_data, predict_callable, context=None): return predict_callable(input_data) - + mock_predict_callable = Mock(return_value=[0.1, 0.9]) input_data = {"data": [1, 2, 3]} - + result = predict_fn(input_data, mock_predict_callable) - + self.assertEqual(result, [0.1, 0.9]) mock_predict_callable.assert_called_once_with(input_data) def test_input_fn_with_preprocess_logic(self): """Test input_fn with preprocess logic.""" + def input_fn(input_data, content_type, schema_builder, inference_spec, context=None): # Deserialize if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data), + ( + io.BytesIO(input_data.encode("utf-8")) + if isinstance(input_data, str) + else io.BytesIO(input_data) + ), content_type, ) else: 
deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data), + ( + io.BytesIO(input_data.encode("utf-8")) + if isinstance(input_data, str) + else io.BytesIO(input_data) + ), content_type[0], ) - + # Preprocess if available if hasattr(inference_spec, "preprocess"): preprocessed = inference_spec.preprocess(deserialized_data) if preprocessed is not None: return preprocessed - + return deserialized_data - + schema_builder = Mock() schema_builder.custom_input_translator = Mock() schema_builder.custom_input_translator.deserialize = Mock(return_value={"data": [1, 2, 3]}) - + inference_spec = Mock() inference_spec.preprocess = Mock(return_value={"preprocessed": True}) - - result = input_fn('{"data": [1, 2, 3]}', ["application/json"], schema_builder, inference_spec) - + + result = input_fn( + '{"data": [1, 2, 3]}', ["application/json"], schema_builder, inference_spec + ) + self.assertEqual(result, {"preprocessed": True}) inference_spec.preprocess.assert_called_once_with({"data": [1, 2, 3]}) def test_input_fn_with_bytes_input_logic(self): """Test input_fn with bytes input.""" + def input_fn(input_data, content_type, schema_builder, inference_spec, context=None): if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data) if isinstance(input_data, (bytes, bytearray)) else io.BytesIO(input_data.encode("utf-8")), + ( + io.BytesIO(input_data) + if isinstance(input_data, (bytes, bytearray)) + else io.BytesIO(input_data.encode("utf-8")) + ), content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data) if isinstance(input_data, (bytes, bytearray)) else io.BytesIO(input_data.encode("utf-8")), + ( + io.BytesIO(input_data) + if isinstance(input_data, (bytes, bytearray)) + else io.BytesIO(input_data.encode("utf-8")) + ), content_type[0], ) return deserialized_data - + schema_builder = Mock() schema_builder.custom_input_translator = Mock() schema_builder.custom_input_translator.deserialize = Mock(return_value={"data": [1, 2, 3]}) - + inference_spec = None - - result = input_fn(b'{"data": [1, 2, 3]}', ["application/json"], schema_builder, inference_spec) - + + result = input_fn( + b'{"data": [1, 2, 3]}', ["application/json"], schema_builder, inference_spec + ) + self.assertEqual(result, {"data": [1, 2, 3]}) def test_output_fn_with_postprocess_logic(self): """Test output_fn with postprocess logic.""" + def output_fn(predictions, accept_type, schema_builder, inference_spec, context=None): # Postprocess if available if hasattr(inference_spec, "postprocess"): postprocessed = inference_spec.postprocess(predictions) if postprocessed is not None: predictions = postprocessed - + # Serialize if hasattr(schema_builder, "custom_output_translator"): return schema_builder.custom_output_translator.serialize(predictions, accept_type) else: return schema_builder.output_serializer.serialize(predictions) - + schema_builder = Mock() schema_builder.custom_output_translator = Mock() - schema_builder.custom_output_translator.serialize = Mock(return_value=b'{"predictions": [0.1, 0.9]}') - + schema_builder.custom_output_translator.serialize = Mock( + return_value=b'{"predictions": [0.1, 0.9]}' + ) + inference_spec = Mock() inference_spec.postprocess = Mock(return_value={"postprocessed": True}) - + result = output_fn([0.1, 0.9], "application/json", schema_builder, inference_spec) - + 
inference_spec.postprocess.assert_called_once_with([0.1, 0.9]) - schema_builder.custom_output_translator.serialize.assert_called_once_with({"postprocessed": True}, "application/json") + schema_builder.custom_output_translator.serialize.assert_called_once_with( + {"postprocessed": True}, "application/json" + ) def test_output_fn_postprocess_returns_none_logic(self): """Test output_fn when postprocess returns None.""" + def output_fn(predictions, accept_type, schema_builder, inference_spec, context=None): if hasattr(inference_spec, "postprocess"): postprocessed = inference_spec.postprocess(predictions) if postprocessed is not None: predictions = postprocessed - + if hasattr(schema_builder, "custom_output_translator"): return schema_builder.custom_output_translator.serialize(predictions, accept_type) else: return schema_builder.output_serializer.serialize(predictions) - + schema_builder = Mock() schema_builder.custom_output_translator = Mock() - schema_builder.custom_output_translator.serialize = Mock(return_value=b'{"predictions": [0.1, 0.9]}') - + schema_builder.custom_output_translator.serialize = Mock( + return_value=b'{"predictions": [0.1, 0.9]}' + ) + inference_spec = Mock() inference_spec.postprocess = Mock(return_value=None) - + result = output_fn([0.1, 0.9], "application/json", schema_builder, inference_spec) - + # Should use original predictions since postprocess returned None - schema_builder.custom_output_translator.serialize.assert_called_once_with([0.1, 0.9], "application/json") + schema_builder.custom_output_translator.serialize.assert_called_once_with( + [0.1, 0.9], "application/json" + ) if __name__ == "__main__": diff --git a/sagemaker-serve/tests/unit/model_server/test_multi_model_server_prepare.py b/sagemaker-serve/tests/unit/model_server/test_multi_model_server_prepare.py index d6a571cd1a..8d8f5ec9d2 100644 --- a/sagemaker-serve/tests/unit/model_server/test_multi_model_server_prepare.py +++ b/sagemaker-serve/tests/unit/model_server/test_multi_model_server_prepare.py @@ -17,147 +17,149 @@ def tearDown(self): if Path(self.temp_dir).exists(): shutil.rmtree(self.temp_dir) - @patch('sagemaker.serve.model_server.multi_model_server.prepare._check_docker_disk_usage') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._check_disk_space') + @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_docker_disk_usage") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_disk_space") def test_create_dir_structure_creates_directories(self, mock_disk_space, mock_docker_disk): """Test _create_dir_structure creates model and code directories.""" from sagemaker.serve.model_server.multi_model_server.prepare import _create_dir_structure - + model_path = Path(self.temp_dir) / "model" model_path_obj, code_dir = _create_dir_structure(str(model_path)) - + self.assertTrue(model_path.exists()) self.assertTrue(code_dir.exists()) self.assertEqual(code_dir, model_path / "code") mock_disk_space.assert_called_once() mock_docker_disk.assert_called_once() - @patch('sagemaker.serve.model_server.multi_model_server.prepare._check_docker_disk_usage') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._check_disk_space') + @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_docker_disk_usage") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_disk_space") def test_create_dir_structure_raises_on_file(self, mock_disk_space, mock_docker_disk): """Test _create_dir_structure raises ValueError when path is a 
file.""" from sagemaker.serve.model_server.multi_model_server.prepare import _create_dir_structure - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + with self.assertRaises(ValueError) as context: _create_dir_structure(str(file_path)) self.assertIn("not a valid directory", str(context.exception)) - @patch('sagemaker.serve.model_server.multi_model_server.prepare._copy_jumpstart_artifacts') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure') + @patch("sagemaker.serve.model_server.multi_model_server.prepare._copy_jumpstart_artifacts") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure") def test_prepare_mms_js_resources(self, mock_create_dir, mock_copy_js): """Test prepare_mms_js_resources calls necessary functions.""" from sagemaker.serve.model_server.multi_model_server.prepare import prepare_mms_js_resources - + mock_model_path = Path(self.temp_dir) / "model" mock_code_dir = mock_model_path / "code" mock_create_dir.return_value = (mock_model_path, mock_code_dir) mock_copy_js.return_value = ({"config": "data"}, True) - + result = prepare_mms_js_resources( model_path=str(mock_model_path), js_id="test-js-id", - model_data="s3://bucket/model.tar.gz" + model_data="s3://bucket/model.tar.gz", ) - + mock_create_dir.assert_called_once_with(str(mock_model_path)) - mock_copy_js.assert_called_once_with("s3://bucket/model.tar.gz", "test-js-id", mock_code_dir) + mock_copy_js.assert_called_once_with( + "s3://bucket/model.tar.gz", "test-js-id", mock_code_dir + ) self.assertEqual(result, ({"config": "data"}, True)) - @patch('builtins.input', return_value='') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.compute_hash') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_mms_creates_structure(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_input): + @patch("builtins.input", return_value="") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.compute_hash") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_mms_creates_structure( + self, mock_copy, mock_capture, mock_hash, mock_input + ): """Test prepare_for_mms creates directory structure and files.""" from sagemaker.serve.model_server.multi_model_server.prepare import prepare_for_mms - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + # Create serve.pkl file serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - - mock_gen_key.return_value = "test-secret-key" + mock_hash.return_value = "test-hash" mock_session = Mock() mock_inference_spec = Mock() - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_mms( model_path=str(model_path), shared_libs=[], dependencies={}, session=mock_session, image_uri="test-image", - inference_spec=mock_inference_spec + inference_spec=mock_inference_spec, ) - - self.assertEqual(secret_key, "test-secret-key") + mock_inference_spec.prepare.assert_called_once_with(str(model_path)) mock_capture.assert_called_once() - @patch('builtins.input', return_value='') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.compute_hash') - 
@patch('sagemaker.serve.model_server.multi_model_server.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_mms_raises_on_invalid_dir(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_input): + @patch("builtins.input", return_value="") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.compute_hash") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_mms_raises_on_invalid_dir( + self, mock_copy, mock_capture, mock_hash, mock_input + ): """Test prepare_for_mms raises exception for invalid directory.""" from sagemaker.serve.model_server.multi_model_server.prepare import prepare_for_mms - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + mock_session = Mock() - + with self.assertRaises(Exception) as context: prepare_for_mms( model_path=str(file_path), shared_libs=[], dependencies={}, session=mock_session, - image_uri="test-image" + image_uri="test-image", ) self.assertIn("not a valid directory", str(context.exception)) - @patch('builtins.input', return_value='') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.compute_hash') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_mms_copies_shared_libs(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_input): + @patch("builtins.input", return_value="") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.compute_hash") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_mms_copies_shared_libs( + self, mock_copy, mock_capture, mock_hash, mock_input + ): """Test prepare_for_mms copies shared libraries.""" from sagemaker.serve.model_server.multi_model_server.prepare import prepare_for_mms - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - + shared_lib = Path(self.temp_dir) / "lib.so" shared_lib.touch() - - mock_gen_key.return_value = "test-key" + mock_hash.return_value = "test-hash" mock_session = Mock() - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): prepare_for_mms( model_path=str(model_path), shared_libs=[str(shared_lib)], dependencies={}, session=mock_session, - image_uri="test-image" + image_uri="test-image", ) - + # Verify copy2 was called for shared lib self.assertTrue(any(str(shared_lib) in str(call) for call in mock_copy.call_args_list)) diff --git a/sagemaker-serve/tests/unit/model_server/test_multi_model_server_server.py b/sagemaker-serve/tests/unit/model_server/test_multi_model_server_server.py index 02ae4dc596..a19c808264 100644 --- a/sagemaker-serve/tests/unit/model_server/test_multi_model_server_server.py +++ b/sagemaker-serve/tests/unit/model_server/test_multi_model_server_server.py @@ -8,97 +8,94 @@ class TestLocalMultiModelServer(unittest.TestCase): """Test LocalMultiModelServer class.""" - @patch('sagemaker.serve.model_server.multi_model_server.server.Path') + @patch("sagemaker.serve.model_server.multi_model_server.server.Path") def test_start_serving_creates_container(self, 
mock_path): """Test _start_serving creates and configures container.""" from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer - + server = LocalMultiModelServer() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value.joinpath.return_value = mock_path_obj - + server._start_serving( client=mock_client, image="test-image:latest", model_path="/path/to/model", secret_key="test-secret", - env_vars={"CUSTOM_VAR": "value"} + env_vars={"CUSTOM_VAR": "value"}, ) - + self.assertEqual(server.container, mock_container) mock_client.containers.run.assert_called_once() call_kwargs = mock_client.containers.run.call_args[1] - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", call_kwargs["environment"]) - self.assertEqual(call_kwargs["environment"]["SAGEMAKER_SERVE_SECRET_KEY"], "test-secret") + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", call_kwargs["environment"]) - @patch('sagemaker.serve.model_server.multi_model_server.server.Path') + @patch("sagemaker.serve.model_server.multi_model_server.server.Path") def test_start_serving_with_no_env_vars(self, mock_path): """Test _start_serving with no custom env vars.""" from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer - + server = LocalMultiModelServer() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value.joinpath.return_value = mock_path_obj - + server._start_serving( client=mock_client, image="test-image:latest", model_path="/path/to/model", secret_key="test-secret", - env_vars=None + env_vars=None, ) - + call_kwargs = mock_client.containers.run.call_args[1] self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", call_kwargs["environment"]) self.assertIn("SAGEMAKER_PROGRAM", call_kwargs["environment"]) - @patch('sagemaker.serve.model_server.multi_model_server.server.requests.post') - @patch('sagemaker.serve.model_server.multi_model_server.server.get_docker_host') + @patch("sagemaker.serve.model_server.multi_model_server.server.requests.post") + @patch("sagemaker.serve.model_server.multi_model_server.server.get_docker_host") def test_invoke_multi_model_server_serving_success(self, mock_get_host, mock_post): """Test _invoke_multi_model_server_serving successful request.""" from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer - + server = LocalMultiModelServer() mock_get_host.return_value = "localhost" mock_response = Mock() mock_response.content = b'{"result": "success"}' mock_post.return_value = mock_response - + result = server._invoke_multi_model_server_serving( - request='{"input": "data"}', - content_type="application/json", - accept="application/json" + request='{"input": "data"}', content_type="application/json", accept="application/json" ) - + self.assertEqual(result, b'{"result": "success"}') mock_post.assert_called_once() call_kwargs = mock_post.call_args[1] self.assertEqual(call_kwargs["headers"]["Content-Type"], "application/json") self.assertEqual(call_kwargs["headers"]["Accept"], "application/json") - @patch('sagemaker.serve.model_server.multi_model_server.server.requests.post') - @patch('sagemaker.serve.model_server.multi_model_server.server.get_docker_host') + @patch("sagemaker.serve.model_server.multi_model_server.server.requests.post") + @patch("sagemaker.serve.model_server.multi_model_server.server.get_docker_host") def 
test_invoke_multi_model_server_serving_failure(self, mock_get_host, mock_post): """Test _invoke_multi_model_server_serving handles errors.""" from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer - + server = LocalMultiModelServer() mock_get_host.return_value = "localhost" mock_post.side_effect = Exception("Connection error") - + with self.assertRaises(Exception) as context: server._invoke_multi_model_server_serving( request='{"input": "data"}', content_type="application/json", - accept="application/json" + accept="application/json", ) self.assertIn("Unable to send request", str(context.exception)) @@ -106,88 +103,97 @@ def test_invoke_multi_model_server_serving_failure(self, mock_get_host, mock_pos class TestSageMakerMultiModelServer(unittest.TestCase): """Test SageMakerMultiModelServer class.""" - @patch('sagemaker.serve.model_server.multi_model_server.server.S3Uploader') - @patch('sagemaker.serve.model_server.multi_model_server.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.multi_model_server.server.fw_utils') - @patch('sagemaker.serve.model_server.multi_model_server.server._is_s3_uri') - def test_upload_server_artifacts_with_s3_path(self, mock_is_s3, mock_fw_utils, mock_determine, mock_uploader): + @patch("sagemaker.serve.model_server.multi_model_server.server.S3Uploader") + @patch("sagemaker.serve.model_server.multi_model_server.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.multi_model_server.server.fw_utils") + @patch("sagemaker.serve.model_server.multi_model_server.server._is_s3_uri") + def test_upload_server_artifacts_with_s3_path( + self, mock_is_s3, mock_fw_utils, mock_determine, mock_uploader + ): """Test _upload_server_artifacts with S3 path.""" from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer - + server = SageMakerMultiModelServer() mock_is_s3.return_value = True mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + model_data, env_vars = server._upload_server_artifacts( model_path="s3://bucket/model", secret_key="test-key", sagemaker_session=mock_session, - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNotNone(model_data) self.assertEqual(model_data["S3DataSource"]["S3Uri"], "s3://bucket/model/") - @patch('sagemaker.serve.model_server.multi_model_server.server.S3Uploader') - @patch('sagemaker.serve.model_server.multi_model_server.server.s3_path_join') - @patch('sagemaker.serve.model_server.multi_model_server.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.multi_model_server.server.parse_s3_url') - @patch('sagemaker.serve.model_server.multi_model_server.server.fw_utils') - @patch('sagemaker.serve.model_server.multi_model_server.server._is_s3_uri') - @patch('sagemaker.serve.model_server.multi_model_server.server.Path') - def test_upload_server_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock_fw_utils, - mock_parse, mock_determine, mock_s3_join, mock_uploader): + @patch("sagemaker.serve.model_server.multi_model_server.server.S3Uploader") + @patch("sagemaker.serve.model_server.multi_model_server.server.s3_path_join") + @patch("sagemaker.serve.model_server.multi_model_server.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.multi_model_server.server.parse_s3_url") + @patch("sagemaker.serve.model_server.multi_model_server.server.fw_utils") + @patch("sagemaker.serve.model_server.multi_model_server.server._is_s3_uri") + 
@patch("sagemaker.serve.model_server.multi_model_server.server.Path") + def test_upload_server_artifacts_uploads_to_s3( + self, + mock_path, + mock_is_s3, + mock_fw_utils, + mock_parse, + mock_determine, + mock_s3_join, + mock_uploader, + ): """Test _upload_server_artifacts uploads artifacts to S3.""" from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer - + server = SageMakerMultiModelServer() mock_is_s3.return_value = False mock_parse.return_value = ("bucket", "prefix") mock_determine.return_value = ("bucket", "code_prefix") mock_s3_join.return_value = "s3://bucket/code_prefix/code" mock_uploader.upload.return_value = "s3://bucket/code_prefix/code" - + mock_path_obj = Mock() mock_code_dir = Mock() mock_path_obj.joinpath.return_value = mock_code_dir mock_path.return_value = mock_path_obj - + mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + model_data, env_vars = server._upload_server_artifacts( model_path="/local/model", secret_key="test-key", sagemaker_session=mock_session, s3_model_data_url="s3://bucket/prefix", image="test-image", - should_upload_artifacts=True + should_upload_artifacts=True, ) - + self.assertIsNotNone(model_data) - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) - self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-key") + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) - @patch('sagemaker.serve.model_server.multi_model_server.server._is_s3_uri') + @patch("sagemaker.serve.model_server.multi_model_server.server._is_s3_uri") def test_upload_server_artifacts_no_upload(self, mock_is_s3): """Test _upload_server_artifacts without uploading.""" from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer - + server = SageMakerMultiModelServer() mock_is_s3.return_value = False mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + model_data, env_vars = server._upload_server_artifacts( model_path="/local/model", secret_key="test-key", sagemaker_session=mock_session, - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNone(model_data) - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) class TestUpdateEnvVars(unittest.TestCase): @@ -196,17 +202,17 @@ class TestUpdateEnvVars(unittest.TestCase): def test_update_env_vars_with_none(self): """Test _update_env_vars with None input.""" from sagemaker.serve.model_server.multi_model_server.server import _update_env_vars - + result = _update_env_vars(None) self.assertIsInstance(result, dict) def test_update_env_vars_with_custom_vars(self): """Test _update_env_vars with custom variables.""" from sagemaker.serve.model_server.multi_model_server.server import _update_env_vars - + custom_vars = {"CUSTOM_KEY": "custom_value"} result = _update_env_vars(custom_vars) - + self.assertIn("CUSTOM_KEY", result) self.assertEqual(result["CUSTOM_KEY"], "custom_value") diff --git a/sagemaker-serve/tests/unit/model_server/test_smd_prepare.py b/sagemaker-serve/tests/unit/model_server/test_smd_prepare.py index 4d5a0a7de8..aa21763180 100644 --- a/sagemaker-serve/tests/unit/model_server/test_smd_prepare.py +++ b/sagemaker-serve/tests/unit/model_server/test_smd_prepare.py @@ -17,114 +17,102 @@ def tearDown(self): if Path(self.temp_dir).exists(): shutil.rmtree(self.temp_dir) - @patch('sagemaker.serve.model_server.smd.prepare.compute_hash') - @patch('sagemaker.serve.model_server.smd.prepare.generate_secret_key') - 
@patch('sagemaker.serve.model_server.smd.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_smd_with_inference_spec(self, mock_copy, mock_capture, mock_gen_key, mock_hash): + @patch("sagemaker.serve.model_server.smd.prepare.compute_hash") + @patch("sagemaker.serve.model_server.smd.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_smd_with_inference_spec(self, mock_copy, mock_capture, mock_hash): """Test prepare_for_smd with InferenceSpec.""" from sagemaker.serve.model_server.smd.prepare import prepare_for_smd from sagemaker.serve.spec.inference_spec import InferenceSpec - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - - mock_gen_key.return_value = "test-secret-key" + mock_hash.return_value = "test-hash" mock_inference_spec = Mock(spec=InferenceSpec) - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_smd( model_path=str(model_path), shared_libs=[], dependencies={}, - inference_spec=mock_inference_spec + inference_spec=mock_inference_spec, ) - - self.assertEqual(secret_key, "test-secret-key") + mock_inference_spec.prepare.assert_called_once_with(str(model_path)) - @patch('os.rename') - @patch('sagemaker.serve.model_server.smd.prepare.compute_hash') - @patch('sagemaker.serve.model_server.smd.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.smd.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_smd_with_custom_orchestrator(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_rename): + @patch("os.rename") + @patch("sagemaker.serve.model_server.smd.prepare.compute_hash") + @patch("sagemaker.serve.model_server.smd.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_smd_with_custom_orchestrator( + self, mock_copy, mock_capture, mock_hash, mock_rename + ): """Test prepare_for_smd with CustomOrchestrator.""" from sagemaker.serve.model_server.smd.prepare import prepare_for_smd from sagemaker.serve.spec.inference_base import CustomOrchestrator - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - - mock_gen_key.return_value = "test-secret-key" + mock_hash.return_value = "test-hash" mock_orchestrator = Mock(spec=CustomOrchestrator) - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_smd( model_path=str(model_path), shared_libs=[], dependencies={}, - inference_spec=mock_orchestrator + inference_spec=mock_orchestrator, ) - - self.assertEqual(secret_key, "test-secret-key") + # Verify custom_execution_inference.py was copied and renamed mock_rename.assert_called_once() - @patch('sagemaker.serve.model_server.smd.prepare.compute_hash') - @patch('sagemaker.serve.model_server.smd.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.smd.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_smd_with_shared_libs(self, mock_copy, mock_capture, mock_gen_key, mock_hash): + @patch("sagemaker.serve.model_server.smd.prepare.compute_hash") + @patch("sagemaker.serve.model_server.smd.prepare.capture_dependencies") + @patch("shutil.copy2") + def 
test_prepare_for_smd_with_shared_libs(self, mock_copy, mock_capture, mock_hash): """Test prepare_for_smd copies shared libraries.""" from sagemaker.serve.model_server.smd.prepare import prepare_for_smd - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - + shared_lib = Path(self.temp_dir) / "lib.so" shared_lib.touch() - - mock_gen_key.return_value = "test-key" + mock_hash.return_value = "test-hash" - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): prepare_for_smd( - model_path=str(model_path), - shared_libs=[str(shared_lib)], - dependencies={} + model_path=str(model_path), shared_libs=[str(shared_lib)], dependencies={} ) - + # Verify copy2 was called for shared lib self.assertTrue(any(str(shared_lib) in str(call) for call in mock_copy.call_args_list)) def test_prepare_for_smd_invalid_dir(self): """Test prepare_for_smd raises exception for invalid directory.""" from sagemaker.serve.model_server.smd.prepare import prepare_for_smd - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + with self.assertRaises(Exception) as context: - prepare_for_smd( - model_path=str(file_path), - shared_libs=[], - dependencies={} - ) + prepare_for_smd(model_path=str(file_path), shared_libs=[], dependencies={}) self.assertIn("not a valid directory", str(context.exception)) diff --git a/sagemaker-serve/tests/unit/model_server/test_smd_server.py b/sagemaker-serve/tests/unit/model_server/test_smd_server.py index 8bf7d4424e..c88331219f 100644 --- a/sagemaker-serve/tests/unit/model_server/test_smd_server.py +++ b/sagemaker-serve/tests/unit/model_server/test_smd_server.py @@ -7,80 +7,78 @@ class TestSageMakerSmdServer(unittest.TestCase): """Test SageMakerSmdServer class.""" - @patch('sagemaker.serve.model_server.smd.server._is_s3_uri') + @patch("sagemaker.serve.model_server.smd.server._is_s3_uri") def test_upload_smd_artifacts_with_s3_path(self, mock_is_s3): """Test _upload_smd_artifacts with S3 path.""" from sagemaker.serve.model_server.smd.server import SageMakerSmdServer - + server = SageMakerSmdServer() mock_is_s3.return_value = True mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_smd_artifacts( model_path="s3://bucket/model", sagemaker_session=mock_session, secret_key="test-key", - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertEqual(s3_path, "s3://bucket/model") - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) - self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-key") self.assertIn("SAGEMAKER_INFERENCE_CODE_DIRECTORY", env_vars) - @patch('sagemaker.serve.model_server.smd.server.upload') - @patch('sagemaker.serve.model_server.smd.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.smd.server.parse_s3_url') - @patch('sagemaker.serve.model_server.smd.server.fw_utils') - @patch('sagemaker.serve.model_server.smd.server._is_s3_uri') - def test_upload_smd_artifacts_uploads_to_s3(self, mock_is_s3, mock_fw_utils, - mock_parse, mock_determine, mock_upload): + @patch("sagemaker.serve.model_server.smd.server.upload") + @patch("sagemaker.serve.model_server.smd.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.smd.server.parse_s3_url") + @patch("sagemaker.serve.model_server.smd.server.fw_utils") + 
@patch("sagemaker.serve.model_server.smd.server._is_s3_uri") + def test_upload_smd_artifacts_uploads_to_s3( + self, mock_is_s3, mock_fw_utils, mock_parse, mock_determine, mock_upload + ): """Test _upload_smd_artifacts uploads to S3.""" from sagemaker.serve.model_server.smd.server import SageMakerSmdServer - + server = SageMakerSmdServer() mock_is_s3.return_value = False mock_parse.return_value = ("bucket", "prefix") mock_determine.return_value = ("bucket", "code_prefix") mock_upload.return_value = "s3://bucket/code_prefix/model.tar.gz" - + mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_smd_artifacts( model_path="/local/model", sagemaker_session=mock_session, secret_key="test-key", s3_model_data_url="s3://bucket/prefix", image="test-image", - should_upload_artifacts=True + should_upload_artifacts=True, ) - + self.assertEqual(s3_path, "s3://bucket/code_prefix/model.tar.gz") - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) self.assertIn("SAGEMAKER_INFERENCE_CODE", env_vars) mock_upload.assert_called_once() - @patch('sagemaker.serve.model_server.smd.server._is_s3_uri') + @patch("sagemaker.serve.model_server.smd.server._is_s3_uri") def test_upload_smd_artifacts_no_upload(self, mock_is_s3): """Test _upload_smd_artifacts without uploading.""" from sagemaker.serve.model_server.smd.server import SageMakerSmdServer - + server = SageMakerSmdServer() mock_is_s3.return_value = False mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_smd_artifacts( model_path="/local/model", sagemaker_session=mock_session, secret_key="test-key", - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNone(s3_path) - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + self.assertIn("SAGEMAKER_INFERENCE_CODE_DIRECTORY", env_vars) if __name__ == "__main__": diff --git a/sagemaker-serve/tests/unit/model_server/test_tei_server.py b/sagemaker-serve/tests/unit/model_server/test_tei_server.py index c280e4b546..4fff01710b 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tei_server.py +++ b/sagemaker-serve/tests/unit/model_server/test_tei_server.py @@ -8,99 +8,99 @@ class TestLocalTeiServing(unittest.TestCase): """Test LocalTeiServing class.""" - @patch('sagemaker.serve.model_server.tei.server._update_env_vars') - @patch('sagemaker.serve.model_server.tei.server.Path') - @patch('sagemaker.serve.model_server.tei.server.DeviceRequest') + @patch("sagemaker.serve.model_server.tei.server._update_env_vars") + @patch("sagemaker.serve.model_server.tei.server.Path") + @patch("sagemaker.serve.model_server.tei.server.DeviceRequest") def test_start_tei_serving(self, mock_device_req, mock_path, mock_update_env): """Test _start_tei_serving creates container.""" from sagemaker.serve.model_server.tei.server import LocalTeiServing - + server = LocalTeiServing() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value.joinpath.return_value = mock_path_obj mock_device_req.return_value = Mock() mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"} - + server._start_tei_serving( client=mock_client, image="tei:latest", model_path="/path/to/model", secret_key="test-secret", - env_vars={"CUSTOM_VAR": "value"} + env_vars={"CUSTOM_VAR": "value"}, ) - + self.assertEqual(server.container, mock_container) mock_client.containers.run.assert_called_once() - 
@patch('sagemaker.serve.model_server.tei.server._update_env_vars') - @patch('sagemaker.serve.model_server.tei.server.Path') - @patch('sagemaker.serve.model_server.tei.server.DeviceRequest') + @patch("sagemaker.serve.model_server.tei.server._update_env_vars") + @patch("sagemaker.serve.model_server.tei.server.Path") + @patch("sagemaker.serve.model_server.tei.server.DeviceRequest") def test_start_tei_serving_adds_secret_key(self, mock_device_req, mock_path, mock_update_env): - """Test _start_tei_serving adds secret key to env vars.""" + """Test _start_tei_serving no longer adds secret key to env vars.""" from sagemaker.serve.model_server.tei.server import LocalTeiServing - + server = LocalTeiServing() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value.joinpath.return_value = mock_path_obj mock_device_req.return_value = Mock() mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"} - + env_vars = {"CUSTOM_VAR": "value"} server._start_tei_serving( client=mock_client, image="tei:latest", model_path="/path/to/model", secret_key="test-secret", - env_vars=env_vars + env_vars=env_vars, ) - - # Verify secret key was added to env_vars - self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-secret") - @patch('sagemaker.serve.model_server.tei.server.requests.post') - @patch('sagemaker.serve.model_server.tei.server.get_docker_host') + # Verify secret key is NOT added to env_vars + self.assertNotIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + + @patch("sagemaker.serve.model_server.tei.server.requests.post") + @patch("sagemaker.serve.model_server.tei.server.get_docker_host") def test_invoke_tei_serving_success(self, mock_get_host, mock_post): """Test _invoke_tei_serving successful request.""" from sagemaker.serve.model_server.tei.server import LocalTeiServing - + server = LocalTeiServing() mock_get_host.return_value = "localhost" mock_response = Mock() mock_response.content = b'{"embeddings": [[0.1, 0.2]]}' mock_post.return_value = mock_response - + result = server._invoke_tei_serving( request='{"inputs": "test text"}', content_type="application/json", - accept="application/json" + accept="application/json", ) - + self.assertEqual(result, b'{"embeddings": [[0.1, 0.2]]}') mock_post.assert_called_once() - @patch('sagemaker.serve.model_server.tei.server.requests.post') - @patch('sagemaker.serve.model_server.tei.server.get_docker_host') + @patch("sagemaker.serve.model_server.tei.server.requests.post") + @patch("sagemaker.serve.model_server.tei.server.get_docker_host") def test_invoke_tei_serving_failure(self, mock_get_host, mock_post): """Test _invoke_tei_serving handles errors.""" from sagemaker.serve.model_server.tei.server import LocalTeiServing - + server = LocalTeiServing() mock_get_host.return_value = "localhost" mock_post.side_effect = Exception("Connection error") - + with self.assertRaises(Exception) as context: server._invoke_tei_serving( request='{"inputs": "test"}', content_type="application/json", - accept="application/json" + accept="application/json", ) self.assertIn("Unable to send request", str(context.exception)) @@ -108,40 +108,48 @@ def test_invoke_tei_serving_failure(self, mock_get_host, mock_post): class TestSageMakerTeiServing(unittest.TestCase): """Test SageMakerTeiServing class.""" - @patch('sagemaker.serve.model_server.tei.server._update_env_vars') - @patch('sagemaker.serve.model_server.tei.server._is_s3_uri') + 
@patch("sagemaker.serve.model_server.tei.server._update_env_vars") + @patch("sagemaker.serve.model_server.tei.server._is_s3_uri") def test_upload_tei_artifacts_with_s3_path(self, mock_is_s3, mock_update_env): """Test _upload_tei_artifacts with S3 path.""" from sagemaker.serve.model_server.tei.server import SageMakerTeiServing - + server = SageMakerTeiServing() mock_is_s3.return_value = True mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"} mock_session = Mock() - + model_data, env_vars = server._upload_tei_artifacts( model_path="s3://bucket/model", sagemaker_session=mock_session, - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNotNone(model_data) self.assertEqual(model_data["S3DataSource"]["S3Uri"], "s3://bucket/model/") - @patch('sagemaker.serve.model_server.tei.server._update_env_vars') - @patch('sagemaker.serve.model_server.tei.server.S3Uploader') - @patch('sagemaker.serve.model_server.tei.server.s3_path_join') - @patch('sagemaker.serve.model_server.tei.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.tei.server.parse_s3_url') - @patch('sagemaker.serve.model_server.tei.server.fw_utils') - @patch('sagemaker.serve.model_server.tei.server._is_s3_uri') - @patch('sagemaker.serve.model_server.tei.server.Path') - def test_upload_tei_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock_fw_utils, - mock_parse, mock_determine, mock_s3_join, - mock_uploader, mock_update_env): + @patch("sagemaker.serve.model_server.tei.server._update_env_vars") + @patch("sagemaker.serve.model_server.tei.server.S3Uploader") + @patch("sagemaker.serve.model_server.tei.server.s3_path_join") + @patch("sagemaker.serve.model_server.tei.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.tei.server.parse_s3_url") + @patch("sagemaker.serve.model_server.tei.server.fw_utils") + @patch("sagemaker.serve.model_server.tei.server._is_s3_uri") + @patch("sagemaker.serve.model_server.tei.server.Path") + def test_upload_tei_artifacts_uploads_to_s3( + self, + mock_path, + mock_is_s3, + mock_fw_utils, + mock_parse, + mock_determine, + mock_s3_join, + mock_uploader, + mock_update_env, + ): """Test _upload_tei_artifacts uploads to S3.""" from sagemaker.serve.model_server.tei.server import SageMakerTeiServing - + server = SageMakerTeiServing() mock_is_s3.return_value = False mock_parse.return_value = ("bucket", "prefix") @@ -149,43 +157,41 @@ def test_upload_tei_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock_fw mock_s3_join.return_value = "s3://bucket/code_prefix/code" mock_uploader.upload.return_value = "s3://bucket/code_prefix/code" mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"} - + mock_path_obj = Mock() mock_code_dir = Mock() mock_path_obj.joinpath.return_value = mock_code_dir mock_path.return_value = mock_path_obj - + mock_session = Mock() - + model_data, env_vars = server._upload_tei_artifacts( model_path="/local/model", sagemaker_session=mock_session, s3_model_data_url="s3://bucket/prefix", image="test-image", env_vars={"CUSTOM": "var"}, - should_upload_artifacts=True + should_upload_artifacts=True, ) - + self.assertIsNotNone(model_data) mock_uploader.upload.assert_called_once() - @patch('sagemaker.serve.model_server.tei.server._update_env_vars') - @patch('sagemaker.serve.model_server.tei.server._is_s3_uri') + @patch("sagemaker.serve.model_server.tei.server._update_env_vars") + @patch("sagemaker.serve.model_server.tei.server._is_s3_uri") def test_upload_tei_artifacts_no_upload(self, mock_is_s3, 
mock_update_env): """Test _upload_tei_artifacts without uploading.""" from sagemaker.serve.model_server.tei.server import SageMakerTeiServing - + server = SageMakerTeiServing() mock_is_s3.return_value = False mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"} mock_session = Mock() - + model_data, env_vars = server._upload_tei_artifacts( - model_path="/local/model", - sagemaker_session=mock_session, - should_upload_artifacts=False + model_path="/local/model", sagemaker_session=mock_session, should_upload_artifacts=False ) - + self.assertIsNone(model_data) @@ -195,7 +201,7 @@ class TestUpdateEnvVars(unittest.TestCase): def test_update_env_vars_with_none(self): """Test _update_env_vars with None input.""" from sagemaker.serve.model_server.tei.server import _update_env_vars - + result = _update_env_vars(None) self.assertIn("HF_HOME", result) self.assertIn("HUGGINGFACE_HUB_CACHE", result) @@ -203,10 +209,10 @@ def test_update_env_vars_with_none(self): def test_update_env_vars_with_custom_vars(self): """Test _update_env_vars with custom variables.""" from sagemaker.serve.model_server.tei.server import _update_env_vars - + custom_vars = {"CUSTOM_KEY": "custom_value"} result = _update_env_vars(custom_vars) - + self.assertIn("CUSTOM_KEY", result) self.assertIn("HF_HOME", result) self.assertEqual(result["CUSTOM_KEY"], "custom_value") diff --git a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_inference.py b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_inference.py index 14aad247c3..cc6d7b967e 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_inference.py +++ b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_inference.py @@ -16,6 +16,7 @@ class TestTensorFlowServingInference(unittest.TestCase): def test_input_handler_logic(self): """Test input_handler logic.""" + def input_handler(data, context, schema_builder): read_data = data.read() if hasattr(schema_builder, "custom_input_translator"): @@ -27,32 +28,33 @@ def input_handler(data, context, schema_builder): io.BytesIO(read_data), context.request_content_type ) return json.dumps({"instances": deserialized_data}) - + schema_builder = Mock() schema_builder.custom_input_translator = Mock() schema_builder.custom_input_translator.deserialize = Mock(return_value=[[1, 2, 3]]) - + mock_data = Mock() mock_data.read = Mock(return_value=b'{"data": [1, 2, 3]}') - + mock_context = Mock() mock_context.request_content_type = "application/json" - + result = input_handler(mock_data, mock_context, schema_builder) - + expected = json.dumps({"instances": [[1, 2, 3]]}) self.assertEqual(result, expected) def test_output_handler_logic(self): """Test output_handler logic.""" + def output_handler(data, context, schema_builder): if data.status_code != 200: raise ValueError(data.content.decode("utf-8")) - + response_content_type = context.accept_header prediction = data.content prediction_dict = json.loads(prediction.decode("utf-8")) - + if hasattr(schema_builder, "custom_output_translator"): return ( schema_builder.custom_output_translator.serialize( @@ -61,58 +63,66 @@ def output_handler(data, context, schema_builder): response_content_type, ) else: - return schema_builder.output_serializer.serialize(prediction_dict["predictions"]), response_content_type - + return ( + schema_builder.output_serializer.serialize(prediction_dict["predictions"]), + response_content_type, + ) + schema_builder = Mock() schema_builder.custom_output_translator = Mock() - 
schema_builder.custom_output_translator.serialize = Mock(return_value=b'{"predictions": [0.1, 0.9]}') - + schema_builder.custom_output_translator.serialize = Mock( + return_value=b'{"predictions": [0.1, 0.9]}' + ) + mock_data = Mock() mock_data.status_code = 200 - mock_data.content = json.dumps({"predictions": [0.1, 0.9]}).encode('utf-8') - + mock_data.content = json.dumps({"predictions": [0.1, 0.9]}).encode("utf-8") + mock_context = Mock() mock_context.accept_header = "application/json" - + result, content_type = output_handler(mock_data, mock_context, schema_builder) - + self.assertEqual(result, b'{"predictions": [0.1, 0.9]}') self.assertEqual(content_type, "application/json") def test_convert_numpy_array_logic(self): """Test conversion of numpy array.""" + def _convert_for_serialization(deserialized_data): if isinstance(deserialized_data, np.ndarray): return deserialized_data.tolist() return deserialized_data - + data = np.array([[1, 2, 3], [4, 5, 6]]) result = _convert_for_serialization(data) - + self.assertEqual(result, [[1, 2, 3], [4, 5, 6]]) def test_convert_pandas_dataframe_logic(self): """Test conversion of pandas DataFrame.""" + def _convert_for_serialization(deserialized_data): if isinstance(deserialized_data, pd.DataFrame): return deserialized_data.to_dict(orient="list") return deserialized_data - - data = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + + data = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) result = _convert_for_serialization(data) - - self.assertEqual(result, {'a': [1, 2], 'b': [3, 4]}) + + self.assertEqual(result, {"a": [1, 2], "b": [3, 4]}) def test_convert_pandas_series_logic(self): """Test conversion of pandas Series.""" + def _convert_for_serialization(deserialized_data): if isinstance(deserialized_data, pd.Series): return deserialized_data.tolist() return deserialized_data - + data = pd.Series([1, 2, 3, 4]) result = _convert_for_serialization(data) - + self.assertEqual(result, [1, 2, 3, 4]) diff --git a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_prepare.py b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_prepare.py index e6ca1161dc..c78797be04 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_prepare.py +++ b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_prepare.py @@ -17,122 +17,114 @@ def tearDown(self): if Path(self.temp_dir).exists(): shutil.rmtree(self.temp_dir) - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_tf_serving_success(self, mock_copy, mock_capture, mock_gen_key, - mock_hash, mock_get_saved, mock_move): + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents") + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_tf_serving_success( + self, mock_copy, mock_capture, mock_hash, mock_get_saved, mock_move + 
): """Test prepare_for_tf_serving creates structure successfully.""" from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - - mock_gen_key.return_value = "test-secret-key" + mock_hash.return_value = "test-hash" mock_get_saved.return_value = Path(self.temp_dir) / "saved_model" - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_tf_serving( - model_path=str(model_path), - shared_libs=[], - dependencies={} + model_path=str(model_path), shared_libs=[], dependencies={} ) - - self.assertEqual(secret_key, "test-secret-key") + mock_capture.assert_called_once() mock_move.assert_called_once() - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_tf_serving_no_saved_model(self, mock_copy, mock_capture, mock_gen_key, - mock_hash, mock_get_saved): + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_tf_serving_no_saved_model( + self, mock_copy, mock_capture, mock_hash, mock_get_saved + ): """Test prepare_for_tf_serving raises error when SavedModel not found.""" from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - - mock_gen_key.return_value = "test-secret-key" + mock_hash.return_value = "test-hash" mock_get_saved.return_value = None - + with self.assertRaises(ValueError) as context: - prepare_for_tf_serving( - model_path=str(model_path), - shared_libs=[], - dependencies={} - ) + prepare_for_tf_serving(model_path=str(model_path), shared_libs=[], dependencies={}) self.assertIn("SavedModel is not found", str(context.exception)) - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_tf_serving_with_shared_libs(self, mock_copy, mock_capture, mock_gen_key, - mock_hash, mock_get_saved, mock_move): + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents") + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + 
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_tf_serving_with_shared_libs( + self, mock_copy, mock_capture, mock_hash, mock_get_saved, mock_move + ): """Test prepare_for_tf_serving copies shared libraries.""" from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - + shared_lib = Path(self.temp_dir) / "lib.so" shared_lib.touch() - - mock_gen_key.return_value = "test-key" + mock_hash.return_value = "test-hash" mock_get_saved.return_value = Path(self.temp_dir) / "saved_model" - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): prepare_for_tf_serving( - model_path=str(model_path), - shared_libs=[str(shared_lib)], - dependencies={} + model_path=str(model_path), shared_libs=[str(shared_lib)], dependencies={} ) - + # Verify copy2 was called for shared lib self.assertTrue(any(str(shared_lib) in str(call) for call in mock_copy.call_args_list)) - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_tf_serving_invalid_dir(self, mock_copy, mock_capture, mock_gen_key, - mock_hash, mock_get_saved): + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_tf_serving_invalid_dir( + self, mock_copy, mock_capture, mock_hash, mock_get_saved + ): """Test prepare_for_tf_serving raises exception for invalid directory.""" from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + with self.assertRaises(Exception) as context: - prepare_for_tf_serving( - model_path=str(file_path), - shared_libs=[], - dependencies={} - ) + prepare_for_tf_serving(model_path=str(file_path), shared_libs=[], dependencies={}) self.assertIn("not a valid directory", str(context.exception)) diff --git a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_server.py b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_server.py index d0bac2e5dc..4013b5c11c 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_server.py +++ b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_server.py @@ -8,70 +8,68 @@ class TestLocalTensorflowServing(unittest.TestCase): """Test LocalTensorflowServing class.""" - @patch('sagemaker.serve.model_server.tensorflow_serving.server.Path') + @patch("sagemaker.serve.model_server.tensorflow_serving.server.Path") def test_start_tensorflow_serving(self, mock_path): """Test _start_tensorflow_serving creates container.""" from 
sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing - + server = LocalTensorflowServing() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value = mock_path_obj - + server._start_tensorflow_serving( client=mock_client, image="tensorflow-serving:latest", model_path="/path/to/model", secret_key="test-secret", - env_vars={"CUSTOM_VAR": "value"} + env_vars={"CUSTOM_VAR": "value"}, ) - + self.assertEqual(server.container, mock_container) mock_client.containers.run.assert_called_once() call_kwargs = mock_client.containers.run.call_args[1] - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", call_kwargs["environment"]) - self.assertEqual(call_kwargs["environment"]["SAGEMAKER_SERVE_SECRET_KEY"], "test-secret") self.assertEqual(call_kwargs["environment"]["CUSTOM_VAR"], "value") - @patch('sagemaker.serve.model_server.tensorflow_serving.server.requests.post') - @patch('sagemaker.serve.model_server.tensorflow_serving.server.get_docker_host') + @patch("sagemaker.serve.model_server.tensorflow_serving.server.requests.post") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.get_docker_host") def test_invoke_tensorflow_serving_success(self, mock_get_host, mock_post): """Test _invoke_tensorflow_serving successful request.""" from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing - + server = LocalTensorflowServing() mock_get_host.return_value = "localhost" mock_response = Mock() mock_response.content = b'{"predictions": [[0.1, 0.9]]}' mock_post.return_value = mock_response - + result = server._invoke_tensorflow_serving( request='{"instances": [[1, 2, 3]]}', content_type="application/json", - accept="application/json" + accept="application/json", ) - + self.assertEqual(result, b'{"predictions": [[0.1, 0.9]]}') mock_post.assert_called_once() - @patch('sagemaker.serve.model_server.tensorflow_serving.server.requests.post') - @patch('sagemaker.serve.model_server.tensorflow_serving.server.get_docker_host') + @patch("sagemaker.serve.model_server.tensorflow_serving.server.requests.post") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.get_docker_host") def test_invoke_tensorflow_serving_failure(self, mock_get_host, mock_post): """Test _invoke_tensorflow_serving handles errors.""" from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing - + server = LocalTensorflowServing() mock_get_host.return_value = "localhost" mock_post.side_effect = Exception("Connection error") - + with self.assertRaises(Exception) as context: server._invoke_tensorflow_serving( request='{"instances": [[1, 2, 3]]}', content_type="application/json", - accept="application/json" + accept="application/json", ) self.assertIn("Unable to send request", str(context.exception)) @@ -79,78 +77,84 @@ def test_invoke_tensorflow_serving_failure(self, mock_get_host, mock_post): class TestSageMakerTensorflowServing(unittest.TestCase): """Test SageMakerTensorflowServing class.""" - @patch('sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri') + @patch("sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri") def test_upload_tensorflow_serving_artifacts_with_s3_path(self, mock_is_s3): """Test _upload_tensorflow_serving_artifacts with S3 path.""" - from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing - + from sagemaker.serve.model_server.tensorflow_serving.server import ( + 
SageMakerTensorflowServing, + ) + server = SageMakerTensorflowServing() mock_is_s3.return_value = True mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_tensorflow_serving_artifacts( model_path="s3://bucket/model", sagemaker_session=mock_session, secret_key="test-key", - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertEqual(s3_path, "s3://bucket/model") - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) - self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-key") - - @patch('sagemaker.serve.model_server.tensorflow_serving.server.upload') - @patch('sagemaker.serve.model_server.tensorflow_serving.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.tensorflow_serving.server.parse_s3_url') - @patch('sagemaker.serve.model_server.tensorflow_serving.server.fw_utils') - @patch('sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri') - def test_upload_tensorflow_serving_artifacts_uploads_to_s3(self, mock_is_s3, mock_fw_utils, - mock_parse, mock_determine, mock_upload): + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) + + @patch("sagemaker.serve.model_server.tensorflow_serving.server.upload") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.parse_s3_url") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.fw_utils") + @patch("sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri") + def test_upload_tensorflow_serving_artifacts_uploads_to_s3( + self, mock_is_s3, mock_fw_utils, mock_parse, mock_determine, mock_upload + ): """Test _upload_tensorflow_serving_artifacts uploads to S3.""" - from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing - + from sagemaker.serve.model_server.tensorflow_serving.server import ( + SageMakerTensorflowServing, + ) + server = SageMakerTensorflowServing() mock_is_s3.return_value = False mock_parse.return_value = ("bucket", "prefix") mock_determine.return_value = ("bucket", "code_prefix") mock_upload.return_value = "s3://bucket/code_prefix/model.tar.gz" - + mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_tensorflow_serving_artifacts( model_path="/local/model", sagemaker_session=mock_session, secret_key="test-key", s3_model_data_url="s3://bucket/prefix", image="test-image", - should_upload_artifacts=True + should_upload_artifacts=True, ) - + self.assertEqual(s3_path, "s3://bucket/code_prefix/model.tar.gz") - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) mock_upload.assert_called_once() - @patch('sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri') + @patch("sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri") def test_upload_tensorflow_serving_artifacts_no_upload(self, mock_is_s3): """Test _upload_tensorflow_serving_artifacts without uploading.""" - from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing - + from sagemaker.serve.model_server.tensorflow_serving.server import ( + SageMakerTensorflowServing, + ) + server = SageMakerTensorflowServing() mock_is_s3.return_value = False mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_tensorflow_serving_artifacts( model_path="/local/model", sagemaker_session=mock_session, 
secret_key="test-key", - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNone(s3_path) - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) if __name__ == "__main__": diff --git a/sagemaker-serve/tests/unit/model_server/test_tgi_prepare.py b/sagemaker-serve/tests/unit/model_server/test_tgi_prepare.py index 79417b7b74..992b83d2be 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tgi_prepare.py +++ b/sagemaker-serve/tests/unit/model_server/test_tgi_prepare.py @@ -18,187 +18,175 @@ def tearDown(self): if Path(self.temp_dir).exists(): shutil.rmtree(self.temp_dir) - @patch('tarfile.open') - @patch('sagemaker.serve.model_server.tgi.prepare.custom_extractall_tarfile') + @patch("tarfile.open") + @patch("sagemaker.serve.model_server.tgi.prepare.custom_extractall_tarfile") def test_extract_js_resource(self, mock_extract, mock_tarfile): """Test _extract_js_resource extracts tarball.""" from sagemaker.serve.model_server.tgi.prepare import _extract_js_resource - + js_model_dir = self.temp_dir code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + # Create a dummy tar file tar_path = Path(js_model_dir) / "infer-prepack-test-id.tar.gz" tar_path.touch() - + mock_tar = Mock() mock_tarfile.return_value.__enter__.return_value = mock_tar - + _extract_js_resource(js_model_dir, code_dir, "test-id") - + mock_extract.assert_called_once_with(mock_tar, code_dir) - @patch('sagemaker.serve.model_server.tgi.prepare.S3Downloader') - @patch('sagemaker.serve.model_server.tgi.prepare._tmpdir') - @patch('sagemaker.serve.model_server.tgi.prepare._extract_js_resource') - def test_copy_jumpstart_artifacts_with_tarball(self, mock_extract, mock_tmpdir, mock_s3_downloader): + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") + @patch("sagemaker.serve.model_server.tgi.prepare._tmpdir") + @patch("sagemaker.serve.model_server.tgi.prepare._extract_js_resource") + def test_copy_jumpstart_artifacts_with_tarball( + self, mock_extract, mock_tmpdir, mock_s3_downloader + ): """Test _copy_jumpstart_artifacts with tar.gz file.""" from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts - + code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + # Create config.json config_file = code_dir / "config.json" config_data = {"model_type": "gpt2"} config_file.write_text(json.dumps(config_data)) - + mock_tmpdir.return_value.__enter__.return_value = self.temp_dir mock_downloader_instance = Mock() mock_s3_downloader.return_value = mock_downloader_instance - + result = _copy_jumpstart_artifacts( - model_data="s3://bucket/model.tar.gz", - js_id="test-id", - code_dir=code_dir + model_data="s3://bucket/model.tar.gz", js_id="test-id", code_dir=code_dir ) - + self.assertEqual(result, (config_data, True)) mock_downloader_instance.download.assert_called_once() - @patch('sagemaker.serve.model_server.tgi.prepare.S3Downloader') + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") def test_copy_jumpstart_artifacts_uncompressed(self, mock_s3_downloader): """Test _copy_jumpstart_artifacts with uncompressed data.""" from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts - + code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + config_file = code_dir / "config.json" config_data = {"model_type": "bert"} config_file.write_text(json.dumps(config_data)) - + mock_downloader_instance = Mock() mock_s3_downloader.return_value = mock_downloader_instance - + result = 
_copy_jumpstart_artifacts( - model_data="s3://bucket/model/", - js_id="test-id", - code_dir=code_dir + model_data="s3://bucket/model/", js_id="test-id", code_dir=code_dir ) - + self.assertEqual(result, (config_data, True)) - @patch('sagemaker.serve.model_server.tgi.prepare.S3Downloader') + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") def test_copy_jumpstart_artifacts_with_dict(self, mock_s3_downloader): """Test _copy_jumpstart_artifacts with dict model_data.""" from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts - + code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + config_file = code_dir / "config.json" config_file.write_text(json.dumps({"model_type": "t5"})) - + mock_downloader_instance = Mock() mock_s3_downloader.return_value = mock_downloader_instance - - model_data = { - "S3DataSource": { - "S3Uri": "s3://bucket/model/" - } - } - + + model_data = {"S3DataSource": {"S3Uri": "s3://bucket/model/"}} + result = _copy_jumpstart_artifacts( - model_data=model_data, - js_id="test-id", - code_dir=code_dir + model_data=model_data, js_id="test-id", code_dir=code_dir ) - + self.assertIsNotNone(result) mock_downloader_instance.download.assert_called_once_with("s3://bucket/model/", code_dir) - @patch('sagemaker.serve.model_server.tgi.prepare.S3Downloader') + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") def test_copy_jumpstart_artifacts_invalid_format(self, mock_s3_downloader): """Test _copy_jumpstart_artifacts raises error for invalid format.""" from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts - + code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + mock_downloader_instance = Mock() mock_s3_downloader.return_value = mock_downloader_instance - + with self.assertRaises(ValueError): _copy_jumpstart_artifacts( - model_data={"invalid": "format"}, - js_id="test-id", - code_dir=code_dir + model_data={"invalid": "format"}, js_id="test-id", code_dir=code_dir ) - @patch('sagemaker.serve.model_server.tgi.prepare.S3Downloader') + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") def test_copy_jumpstart_artifacts_no_config(self, mock_s3_downloader): """Test _copy_jumpstart_artifacts when config.json doesn't exist.""" from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts - + code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + mock_downloader_instance = Mock() mock_s3_downloader.return_value = mock_downloader_instance - + result = _copy_jumpstart_artifacts( - model_data="s3://bucket/model/", - js_id="test-id", - code_dir=code_dir + model_data="s3://bucket/model/", js_id="test-id", code_dir=code_dir ) - + self.assertEqual(result, (None, True)) - @patch('sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage') - @patch('sagemaker.serve.model_server.tgi.prepare._check_disk_space') + @patch("sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage") + @patch("sagemaker.serve.model_server.tgi.prepare._check_disk_space") def test_create_dir_structure(self, mock_disk_space, mock_docker_disk): """Test _create_dir_structure creates directories.""" from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure - + model_path = Path(self.temp_dir) / "model" model_path_obj, code_dir = _create_dir_structure(str(model_path)) - + self.assertTrue(model_path.exists()) self.assertTrue(code_dir.exists()) mock_disk_space.assert_called_once() mock_docker_disk.assert_called_once() - 
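# --- illustrative sketch, not part of the patch ---------------------------
# A minimal sketch, inferred from the assertions in test_create_dir_structure
# and test_create_dir_structure_raises_on_file above, of the behavior under
# test: _create_dir_structure creates <model_path>/code, rejects a path that
# exists but is not a directory, and runs the two disk checks the tests patch
# out. The body and error message are assumptions, not the library source.
from pathlib import Path

def _create_dir_structure_sketch(model_path: str):
    path = Path(model_path)
    if path.exists() and not path.is_dir():
        raise ValueError("model_dir is not a valid directory")  # message assumed
    code_dir = path.joinpath("code")
    code_dir.mkdir(parents=True, exist_ok=True)
    # the real helper is assumed to call _check_disk_space(path) and
    # _check_docker_disk_usage() here, which the unit tests mock out
    return path, code_dir
# ---------------------------------------------------------------------------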
@patch('sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage') - @patch('sagemaker.serve.model_server.tgi.prepare._check_disk_space') + @patch("sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage") + @patch("sagemaker.serve.model_server.tgi.prepare._check_disk_space") def test_create_dir_structure_raises_on_file(self, mock_disk_space, mock_docker_disk): """Test _create_dir_structure raises ValueError for file path.""" from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + with self.assertRaises(ValueError): _create_dir_structure(str(file_path)) - @patch('sagemaker.serve.model_server.tgi.prepare._copy_jumpstart_artifacts') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') + @patch("sagemaker.serve.model_server.tgi.prepare._copy_jumpstart_artifacts") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") def test_prepare_tgi_js_resources(self, mock_create_dir, mock_copy_js): """Test prepare_tgi_js_resources.""" from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources - + mock_model_path = Path(self.temp_dir) / "model" mock_code_dir = mock_model_path / "code" mock_create_dir.return_value = (mock_model_path, mock_code_dir) mock_copy_js.return_value = ({"config": "data"}, True) - + result = prepare_tgi_js_resources( model_path=str(mock_model_path), js_id="test-js-id", - model_data="s3://bucket/model.tar.gz" + model_data="s3://bucket/model.tar.gz", ) - + mock_create_dir.assert_called_once() mock_copy_js.assert_called_once() self.assertEqual(result, ({"config": "data"}, True)) diff --git a/sagemaker-serve/tests/unit/model_server/test_tgi_server.py b/sagemaker-serve/tests/unit/model_server/test_tgi_server.py index 12f84a747b..244ae6462c 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tgi_server.py +++ b/sagemaker-serve/tests/unit/model_server/test_tgi_server.py @@ -8,101 +8,99 @@ class TestLocalTgiServing(unittest.TestCase): """Test LocalTgiServing class.""" - @patch('sagemaker.serve.model_server.tgi.server.Path') - @patch('sagemaker.serve.model_server.tgi.server.DeviceRequest') + @patch("sagemaker.serve.model_server.tgi.server.Path") + @patch("sagemaker.serve.model_server.tgi.server.DeviceRequest") def test_start_tgi_serving_jumpstart(self, mock_device_req, mock_path): """Test _start_tgi_serving with jumpstart=True.""" from sagemaker.serve.model_server.tgi.server import LocalTgiServing - + server = LocalTgiServing() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value.joinpath.return_value = mock_path_obj mock_device_req.return_value = Mock() - + server._start_tgi_serving( client=mock_client, image="test-image:latest", model_path="/path/to/model", secret_key="test-secret", env_vars={"CUSTOM_VAR": "value"}, - jumpstart=True + jumpstart=True, ) - + self.assertEqual(server.container, mock_container) mock_client.containers.run.assert_called_once() call_args = mock_client.containers.run.call_args # Check that the command includes --model-id self.assertEqual(call_args[0][1][0], "--model-id") - @patch('sagemaker.serve.model_server.tgi.server._update_env_vars') - @patch('sagemaker.serve.model_server.tgi.server.Path') - @patch('sagemaker.serve.model_server.tgi.server.DeviceRequest') + @patch("sagemaker.serve.model_server.tgi.server._update_env_vars") + @patch("sagemaker.serve.model_server.tgi.server.Path") + 
@patch("sagemaker.serve.model_server.tgi.server.DeviceRequest") def test_start_tgi_serving_non_jumpstart(self, mock_device_req, mock_path, mock_update_env): """Test _start_tgi_serving with jumpstart=False.""" from sagemaker.serve.model_server.tgi.server import LocalTgiServing - + server = LocalTgiServing() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value.joinpath.return_value = mock_path_obj mock_device_req.return_value = Mock() mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"} - + server._start_tgi_serving( client=mock_client, image="test-image:latest", model_path="/path/to/model", secret_key="test-secret", env_vars={"CUSTOM_VAR": "value"}, - jumpstart=False + jumpstart=False, ) - + self.assertEqual(server.container, mock_container) mock_update_env.assert_called_once() - @patch('sagemaker.serve.model_server.tgi.server.requests.post') - @patch('sagemaker.serve.model_server.tgi.server.get_docker_host') + @patch("sagemaker.serve.model_server.tgi.server.requests.post") + @patch("sagemaker.serve.model_server.tgi.server.get_docker_host") def test_invoke_tgi_serving_success(self, mock_get_host, mock_post): """Test _invoke_tgi_serving successful request.""" from sagemaker.serve.model_server.tgi.server import LocalTgiServing - + server = LocalTgiServing() mock_get_host.return_value = "localhost" mock_response = Mock() mock_response.content = b'{"generated_text": "result"}' mock_post.return_value = mock_response - + result = server._invoke_tgi_serving( - request='{"inputs": "test"}', - content_type="application/json", - accept="application/json" + request='{"inputs": "test"}', content_type="application/json", accept="application/json" ) - + self.assertEqual(result, b'{"generated_text": "result"}') mock_post.assert_called_once() - @patch('sagemaker.serve.model_server.tgi.server.requests.post') - @patch('sagemaker.serve.model_server.tgi.server.get_docker_host') + @patch("sagemaker.serve.model_server.tgi.server.requests.post") + @patch("sagemaker.serve.model_server.tgi.server.get_docker_host") def test_invoke_tgi_serving_failure(self, mock_get_host, mock_post): """Test _invoke_tgi_serving handles errors.""" from sagemaker.serve.model_server.tgi.server import LocalTgiServing - + server = LocalTgiServing() mock_get_host.return_value = "localhost" mock_post.side_effect = Exception("Connection error") - + with self.assertRaises(Exception) as context: server._invoke_tgi_serving( request='{"inputs": "test"}', content_type="application/json", - accept="application/json" + accept="application/json", ) self.assertIn("Unable to send request", str(context.exception)) @@ -110,70 +108,78 @@ def test_invoke_tgi_serving_failure(self, mock_get_host, mock_post): class TestSageMakerTgiServing(unittest.TestCase): """Test SageMakerTgiServing class.""" - @patch('sagemaker.serve.model_server.tgi.server._is_s3_uri') + @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri") def test_upload_tgi_artifacts_with_s3_path(self, mock_is_s3): """Test _upload_tgi_artifacts with S3 path.""" from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing - + server = SageMakerTgiServing() mock_is_s3.return_value = True mock_session = Mock() - + model_data, env_vars = server._upload_tgi_artifacts( model_path="s3://bucket/model", sagemaker_session=mock_session, jumpstart=False, - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNotNone(model_data) 
self.assertEqual(model_data["S3DataSource"]["S3Uri"], "s3://bucket/model/") - @patch('sagemaker.serve.model_server.tgi.server._is_s3_uri') + @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri") def test_upload_tgi_artifacts_jumpstart(self, mock_is_s3): """Test _upload_tgi_artifacts with jumpstart=True.""" from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing - + server = SageMakerTgiServing() mock_is_s3.return_value = True mock_session = Mock() - + model_data, env_vars = server._upload_tgi_artifacts( model_path="s3://bucket/model", sagemaker_session=mock_session, jumpstart=True, - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNotNone(model_data) self.assertEqual(env_vars, {}) - @patch('sagemaker.serve.model_server.tgi.server.S3Uploader') - @patch('sagemaker.serve.model_server.tgi.server.s3_path_join') - @patch('sagemaker.serve.model_server.tgi.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.tgi.server.parse_s3_url') - @patch('sagemaker.serve.model_server.tgi.server.fw_utils') - @patch('sagemaker.serve.model_server.tgi.server._is_s3_uri') - @patch('sagemaker.serve.model_server.tgi.server.Path') - def test_upload_tgi_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock_fw_utils, - mock_parse, mock_determine, mock_s3_join, mock_uploader): + @patch("sagemaker.serve.model_server.tgi.server.S3Uploader") + @patch("sagemaker.serve.model_server.tgi.server.s3_path_join") + @patch("sagemaker.serve.model_server.tgi.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.tgi.server.parse_s3_url") + @patch("sagemaker.serve.model_server.tgi.server.fw_utils") + @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri") + @patch("sagemaker.serve.model_server.tgi.server.Path") + def test_upload_tgi_artifacts_uploads_to_s3( + self, + mock_path, + mock_is_s3, + mock_fw_utils, + mock_parse, + mock_determine, + mock_s3_join, + mock_uploader, + ): """Test _upload_tgi_artifacts uploads to S3.""" from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing - + server = SageMakerTgiServing() mock_is_s3.return_value = False mock_parse.return_value = ("bucket", "prefix") mock_determine.return_value = ("bucket", "code_prefix") mock_s3_join.return_value = "s3://bucket/code_prefix/code" mock_uploader.upload.return_value = "s3://bucket/code_prefix/code" - + mock_path_obj = Mock() mock_code_dir = Mock() mock_path_obj.joinpath.return_value = mock_code_dir mock_path.return_value = mock_path_obj - + mock_session = Mock() - + model_data, env_vars = server._upload_tgi_artifacts( model_path="/local/model", sagemaker_session=mock_session, @@ -181,9 +187,9 @@ def test_upload_tgi_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock_fw s3_model_data_url="s3://bucket/prefix", image="test-image", env_vars={"CUSTOM": "var"}, - should_upload_artifacts=True + should_upload_artifacts=True, ) - + self.assertIsNotNone(model_data) mock_uploader.upload.assert_called_once() @@ -194,7 +200,7 @@ class TestUpdateEnvVars(unittest.TestCase): def test_update_env_vars_with_none(self): """Test _update_env_vars with None input.""" from sagemaker.serve.model_server.tgi.server import _update_env_vars - + result = _update_env_vars(None) self.assertIn("HF_HOME", result) self.assertIn("HUGGINGFACE_HUB_CACHE", result) @@ -202,10 +208,10 @@ def test_update_env_vars_with_none(self): def test_update_env_vars_with_custom_vars(self): """Test _update_env_vars with custom variables.""" from sagemaker.serve.model_server.tgi.server 
import _update_env_vars - + custom_vars = {"CUSTOM_KEY": "custom_value"} result = _update_env_vars(custom_vars) - + self.assertIn("CUSTOM_KEY", result) self.assertIn("HF_HOME", result) self.assertEqual(result["CUSTOM_KEY"], "custom_value") diff --git a/sagemaker-serve/tests/unit/model_server/test_tgi_utils.py b/sagemaker-serve/tests/unit/model_server/test_tgi_utils.py index 23042138e4..317836ef4c 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tgi_utils.py +++ b/sagemaker-serve/tests/unit/model_server/test_tgi_utils.py @@ -1,4 +1,5 @@ """Unit tests for TGI serving utils module.""" + import unittest from unittest.mock import Mock, patch @@ -9,14 +10,14 @@ class TestTGIUtilsDataType(unittest.TestCase): def test_get_default_dtype(self): """Test _get_default_dtype returns bfloat16.""" from sagemaker.serve.model_server.tgi.utils import _get_default_dtype - + result = _get_default_dtype() self.assertEqual(result, "bfloat16") def test_get_admissible_dtypes(self): """Test _get_admissible_dtypes returns list with bfloat16.""" from sagemaker.serve.model_server.tgi.utils import _get_admissible_dtypes - + result = _get_admissible_dtypes() self.assertEqual(result, ["bfloat16"]) @@ -24,97 +25,87 @@ def test_get_admissible_dtypes(self): class TestTGIUtilsConfigurations(unittest.TestCase): """Test TGI utils configuration functions.""" - @patch('sagemaker.serve.model_server.tgi.utils._get_default_max_tokens') - @patch('sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree') + @patch("sagemaker.serve.model_server.tgi.utils._get_default_max_tokens") + @patch("sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree") def test_get_default_tgi_configurations_with_sharding(self, mock_parallel, mock_tokens): """Test TGI configurations with sharding enabled.""" from sagemaker.serve.model_server.tgi.utils import _get_default_tgi_configurations - + mock_parallel.return_value = 4 mock_tokens.return_value = (2048, 512) - + mock_schema_builder = Mock() mock_schema_builder.sample_input = {"inputs": "test"} mock_schema_builder.sample_output = [{"generated_text": "output"}] - + env, max_new_tokens = _get_default_tgi_configurations( - "model-id", - {"num_attention_heads": 32}, - mock_schema_builder + "model-id", {"num_attention_heads": 32}, mock_schema_builder ) - + self.assertEqual(env["SHARDED"], "true") self.assertEqual(env["NUM_SHARD"], "4") self.assertEqual(env["DTYPE"], "bfloat16") self.assertEqual(max_new_tokens, 512) - @patch('sagemaker.serve.model_server.tgi.utils._get_default_max_tokens') - @patch('sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree') + @patch("sagemaker.serve.model_server.tgi.utils._get_default_max_tokens") + @patch("sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree") def test_get_default_tgi_configurations_without_sharding(self, mock_parallel, mock_tokens): """Test TGI configurations with sharding disabled.""" from sagemaker.serve.model_server.tgi.utils import _get_default_tgi_configurations - + mock_parallel.return_value = 1 mock_tokens.return_value = (1024, 256) - + mock_schema_builder = Mock() mock_schema_builder.sample_input = {"inputs": "test"} mock_schema_builder.sample_output = [{"generated_text": "output"}] - + env, max_new_tokens = _get_default_tgi_configurations( - "model-id", - {"num_attention_heads": 12}, - mock_schema_builder + "model-id", {"num_attention_heads": 12}, mock_schema_builder ) - + self.assertEqual(env["SHARDED"], "false") self.assertEqual(env["NUM_SHARD"], "1") 
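# --- illustrative sketch, not part of the patch ---------------------------
# Sketch of the sharding rule these configuration tests encode, assuming
# _get_default_tgi_configurations maps the detected tensor-parallel degree
# onto TGI's env vars: degree > 1 -> sharded, degree == 1 -> single shard,
# None -> both left unset; DTYPE is always the default bfloat16.
def _sharding_env_sketch(tensor_parallel_degree):
    if tensor_parallel_degree is None:
        return {"SHARDED": None, "NUM_SHARD": None, "DTYPE": "bfloat16"}
    return {
        "SHARDED": "true" if tensor_parallel_degree > 1 else "false",
        "NUM_SHARD": str(tensor_parallel_degree),
        "DTYPE": "bfloat16",
    }
# ---------------------------------------------------------------------------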
self.assertEqual(env["DTYPE"], "bfloat16") self.assertEqual(max_new_tokens, 256) - @patch('sagemaker.serve.model_server.tgi.utils._get_default_max_tokens') - @patch('sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree') + @patch("sagemaker.serve.model_server.tgi.utils._get_default_max_tokens") + @patch("sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree") def test_get_default_tgi_configurations_no_parallel_degree(self, mock_parallel, mock_tokens): """Test TGI configurations when parallel degree is None.""" from sagemaker.serve.model_server.tgi.utils import _get_default_tgi_configurations - + mock_parallel.return_value = None mock_tokens.return_value = (1024, 256) - + mock_schema_builder = Mock() mock_schema_builder.sample_input = {"inputs": "test"} mock_schema_builder.sample_output = [{"generated_text": "output"}] - - env, max_new_tokens = _get_default_tgi_configurations( - "model-id", - {}, - mock_schema_builder - ) - + + env, max_new_tokens = _get_default_tgi_configurations("model-id", {}, mock_schema_builder) + self.assertIsNone(env["SHARDED"]) self.assertIsNone(env["NUM_SHARD"]) self.assertEqual(env["DTYPE"], "bfloat16") self.assertEqual(max_new_tokens, 256) - @patch('sagemaker.serve.model_server.tgi.utils._get_default_max_tokens') - @patch('sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree') + @patch("sagemaker.serve.model_server.tgi.utils._get_default_max_tokens") + @patch("sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree") def test_get_default_tgi_configurations_returns_tuple(self, mock_parallel, mock_tokens): """Test that function returns a tuple.""" from sagemaker.serve.model_server.tgi.utils import _get_default_tgi_configurations - + mock_parallel.return_value = 2 mock_tokens.return_value = (1024, 256) - + mock_schema_builder = Mock() mock_schema_builder.sample_input = {"inputs": "test"} mock_schema_builder.sample_output = [{"generated_text": "output"}] - + result = _get_default_tgi_configurations( - "model-id", - {"num_attention_heads": 16}, - mock_schema_builder + "model-id", {"num_attention_heads": 16}, mock_schema_builder ) - + self.assertIsInstance(result, tuple) self.assertEqual(len(result), 2) self.assertIsInstance(result[0], dict) diff --git a/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py b/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py index 9d5fc57485..e3081f8b94 100644 --- a/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py +++ b/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py @@ -14,93 +14,111 @@ class TestTorchServeInference(unittest.TestCase): def test_predict_fn_logic(self): """Test predict_fn logic.""" + def predict_fn(input_data, predict_callable): return predict_callable(input_data) - + mock_predict_callable = Mock(return_value=[0.1, 0.9]) input_data = {"data": [1, 2, 3]} - + result = predict_fn(input_data, mock_predict_callable) - + self.assertEqual(result, [0.1, 0.9]) mock_predict_callable.assert_called_once_with(input_data) def test_input_fn_with_preprocess_logic(self): """Test input_fn with preprocess logic.""" + def input_fn(input_data, content_type, schema_builder, inference_spec): # Deserialize if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data), + ( + io.BytesIO(input_data.encode("utf-8")) + if 
isinstance(input_data, str) + else io.BytesIO(input_data) + ), content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data), + ( + io.BytesIO(input_data.encode("utf-8")) + if isinstance(input_data, str) + else io.BytesIO(input_data) + ), content_type, ) - + # Preprocess if available if hasattr(inference_spec, "preprocess"): preprocessed = inference_spec.preprocess(deserialized_data) if preprocessed is not None: return preprocessed - + return deserialized_data - + schema_builder = Mock() schema_builder.custom_input_translator = Mock() schema_builder.custom_input_translator.deserialize = Mock(return_value={"data": [1, 2, 3]}) - + inference_spec = Mock() inference_spec.preprocess = Mock(return_value={"preprocessed": True}) - + result = input_fn('{"data": [1, 2, 3]}', "application/json", schema_builder, inference_spec) - + self.assertEqual(result, {"preprocessed": True}) inference_spec.preprocess.assert_called_once_with({"data": [1, 2, 3]}) def test_output_fn_with_postprocess_logic(self): """Test output_fn with postprocess logic.""" + def output_fn(predictions, accept_type, schema_builder, inference_spec): # Postprocess if available if hasattr(inference_spec, "postprocess"): postprocessed = inference_spec.postprocess(predictions) if postprocessed is not None: predictions = postprocessed - + # Serialize if hasattr(schema_builder, "custom_output_translator"): return schema_builder.custom_output_translator.serialize(predictions, accept_type) else: return schema_builder.output_serializer.serialize(predictions) - + schema_builder = Mock() schema_builder.custom_output_translator = Mock() - schema_builder.custom_output_translator.serialize = Mock(return_value=b'{"predictions": [0.1, 0.9]}') - + schema_builder.custom_output_translator.serialize = Mock( + return_value=b'{"predictions": [0.1, 0.9]}' + ) + inference_spec = Mock() inference_spec.postprocess = Mock(return_value={"postprocessed": True}) - + result = output_fn([0.1, 0.9], "application/json", schema_builder, inference_spec) - + inference_spec.postprocess.assert_called_once_with([0.1, 0.9]) - schema_builder.custom_output_translator.serialize.assert_called_once_with({"postprocessed": True}, "application/json") + schema_builder.custom_output_translator.serialize.assert_called_once_with( + {"postprocessed": True}, "application/json" + ) - @patch.dict(os.environ, {'MLFLOW_MODEL_FLAVOR': 'pytorch'}) + @patch.dict(os.environ, {"MLFLOW_MODEL_FLAVOR": "pytorch"}) def test_get_mlflow_flavor_logic(self): """Test _get_mlflow_flavor logic.""" + def _get_mlflow_flavor(): return os.getenv("MLFLOW_MODEL_FLAVOR") - + result = _get_mlflow_flavor() - self.assertEqual(result, 'pytorch') + self.assertEqual(result, "pytorch") - @patch('importlib.import_module') + @patch("importlib.import_module") def test_load_mlflow_model_logic(self, mock_import): """Test _load_mlflow_model logic.""" + def _load_mlflow_model(deployment_flavor, model_dir): import importlib + flavor_loader_map = { "pytorch": ("mlflow.pytorch", "load_model"), "tensorflow": ("mlflow.tensorflow", "load_model"), @@ -111,14 +129,14 @@ def _load_mlflow_model(deployment_flavor, model_dir): flavor_module = importlib.import_module(flavor_module_name) load_model_function = getattr(flavor_module, load_function_name) return load_model_function(model_dir) - + mock_module = Mock() mock_module.load_model = Mock(return_value=Mock()) mock_import.return_value = mock_module - - result = 
_load_mlflow_model('tensorflow', '/model/dir') - - mock_import.assert_called_once_with('mlflow.tensorflow') + + result = _load_mlflow_model("tensorflow", "/model/dir") + + mock_import.assert_called_once_with("mlflow.tensorflow") if __name__ == "__main__": diff --git a/sagemaker-serve/tests/unit/model_server/test_torchserve_prepare.py b/sagemaker-serve/tests/unit/model_server/test_torchserve_prepare.py index 1ae35eca6a..d1ca6decde 100644 --- a/sagemaker-serve/tests/unit/model_server/test_torchserve_prepare.py +++ b/sagemaker-serve/tests/unit/model_server/test_torchserve_prepare.py @@ -17,170 +17,162 @@ def tearDown(self): if Path(self.temp_dir).exists(): shutil.rmtree(self.temp_dir) - @patch('sagemaker.serve.model_server.torchserve.prepare.compute_hash') - @patch('sagemaker.serve.model_server.torchserve.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.torchserve.prepare.capture_dependencies') - @patch('sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri') - @patch('shutil.copy2') - def test_prepare_for_torchserve_standard_image(self, mock_copy, mock_is_1p, mock_capture, - mock_gen_key, mock_hash): + @patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash") + @patch("sagemaker.serve.model_server.torchserve.prepare.capture_dependencies") + @patch("sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri") + @patch("shutil.copy2") + def test_prepare_for_torchserve_standard_image( + self, mock_copy, mock_is_1p, mock_capture, mock_hash + ): """Test prepare_for_torchserve with standard image.""" from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - + mock_is_1p.return_value = True - mock_gen_key.return_value = "test-secret-key" mock_hash.return_value = "test-hash" mock_session = Mock() mock_inference_spec = Mock() - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_torchserve( model_path=str(model_path), shared_libs=[], dependencies={}, session=mock_session, image_uri="test-pytorch-image", - inference_spec=mock_inference_spec + inference_spec=mock_inference_spec, ) - - self.assertEqual(secret_key, "test-secret-key") + mock_inference_spec.prepare.assert_called_once_with(str(model_path)) mock_capture.assert_called_once() - @patch('os.rename') - @patch('sagemaker.serve.model_server.torchserve.prepare.compute_hash') - @patch('sagemaker.serve.model_server.torchserve.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.torchserve.prepare.capture_dependencies') - @patch('sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri') - @patch('shutil.copy2') - def test_prepare_for_torchserve_xgboost_image(self, mock_copy, mock_is_1p, mock_capture, - mock_gen_key, mock_hash, mock_rename): + @patch("os.rename") + @patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash") + @patch("sagemaker.serve.model_server.torchserve.prepare.capture_dependencies") + @patch("sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri") + @patch("shutil.copy2") + def test_prepare_for_torchserve_xgboost_image( + self, mock_copy, mock_is_1p, mock_capture, mock_hash, mock_rename + ): """Test prepare_for_torchserve with xgboost image.""" from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve - + 
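# --- illustrative sketch, not part of the patch ---------------------------
# These prepare tests no longer patch generate_secret_key or assert a
# returned key: prepare_for_torchserve now hashes serve.pkl with plain
# SHA-256 and persists the digest instead of HMAC-signing it. A minimal
# sketch of the helper the tests patch as compute_hash:
import hashlib

def compute_hash_sketch(buffer: bytes) -> str:
    # plain SHA-256 of the pickled payload -- no secret key involved
    return hashlib.sha256(buffer).hexdigest()

digest = compute_hash_sketch(b"test data")  # the bytes the tests write to serve.pkl
# the digest is assumed to be persisted to metadata.json beside serve.pkl
# ---------------------------------------------------------------------------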
model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - + mock_is_1p.return_value = True - mock_gen_key.return_value = "test-secret-key" mock_hash.return_value = "test-hash" mock_session = Mock() - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_torchserve( model_path=str(model_path), shared_libs=[], dependencies={}, session=mock_session, image_uri="xgboost-image:latest", - inference_spec=None + inference_spec=None, ) - - self.assertEqual(secret_key, "test-secret-key") + # Verify xgboost_inference.py was copied and renamed mock_rename.assert_called_once() - @patch('sagemaker.serve.model_server.torchserve.prepare.compute_hash') - @patch('sagemaker.serve.model_server.torchserve.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.torchserve.prepare.capture_dependencies') - @patch('sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri') - @patch('shutil.copy2') - def test_prepare_for_torchserve_with_shared_libs(self, mock_copy, mock_is_1p, mock_capture, - mock_gen_key, mock_hash): + @patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash") + @patch("sagemaker.serve.model_server.torchserve.prepare.capture_dependencies") + @patch("sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri") + @patch("shutil.copy2") + def test_prepare_for_torchserve_with_shared_libs( + self, mock_copy, mock_is_1p, mock_capture, mock_hash + ): """Test prepare_for_torchserve copies shared libraries.""" from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - + shared_lib = Path(self.temp_dir) / "lib.so" shared_lib.touch() - + mock_is_1p.return_value = False - mock_gen_key.return_value = "test-key" mock_hash.return_value = "test-hash" mock_session = Mock() - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): prepare_for_torchserve( model_path=str(model_path), shared_libs=[str(shared_lib)], dependencies={}, session=mock_session, - image_uri="test-image" + image_uri="test-image", ) - + # Verify copy2 was called for shared lib self.assertTrue(any(str(shared_lib) in str(call) for call in mock_copy.call_args_list)) - @patch('sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri') + @patch("sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri") def test_prepare_for_torchserve_invalid_dir(self, mock_is_1p): """Test prepare_for_torchserve raises exception for invalid directory.""" from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + mock_session = Mock() - + with self.assertRaises(Exception) as context: prepare_for_torchserve( model_path=str(file_path), shared_libs=[], dependencies={}, session=mock_session, - image_uri="test-image" + image_uri="test-image", ) self.assertIn("not a valid directory", str(context.exception)) - @patch('sagemaker.serve.model_server.torchserve.prepare.compute_hash') - @patch('sagemaker.serve.model_server.torchserve.prepare.generate_secret_key') - 
@patch('sagemaker.serve.model_server.torchserve.prepare.capture_dependencies') - @patch('sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri') - @patch('shutil.copy2') - def test_prepare_for_torchserve_no_inference_spec(self, mock_copy, mock_is_1p, mock_capture, - mock_gen_key, mock_hash): + @patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash") + @patch("sagemaker.serve.model_server.torchserve.prepare.capture_dependencies") + @patch("sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri") + @patch("shutil.copy2") + def test_prepare_for_torchserve_no_inference_spec( + self, mock_copy, mock_is_1p, mock_capture, mock_hash + ): """Test prepare_for_torchserve without inference_spec.""" from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - + mock_is_1p.return_value = False - mock_gen_key.return_value = "test-key" mock_hash.return_value = "test-hash" mock_session = Mock() - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_torchserve( model_path=str(model_path), shared_libs=[], dependencies={}, session=mock_session, image_uri="test-image", - inference_spec=None + inference_spec=None, ) - - self.assertEqual(secret_key, "test-key") if __name__ == "__main__": diff --git a/sagemaker-serve/tests/unit/model_server/test_torchserve_server.py b/sagemaker-serve/tests/unit/model_server/test_torchserve_server.py index 95b0645076..ccc4368841 100644 --- a/sagemaker-serve/tests/unit/model_server/test_torchserve_server.py +++ b/sagemaker-serve/tests/unit/model_server/test_torchserve_server.py @@ -8,70 +8,68 @@ class TestLocalTorchServe(unittest.TestCase): """Test LocalTorchServe class.""" - @patch('sagemaker.serve.model_server.torchserve.server.Path') + @patch("sagemaker.serve.model_server.torchserve.server.Path") def test_start_torch_serve(self, mock_path): """Test _start_torch_serve creates container.""" from sagemaker.serve.model_server.torchserve.server import LocalTorchServe - + server = LocalTorchServe() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value = mock_path_obj - + server._start_torch_serve( client=mock_client, image="torchserve:latest", model_path="/path/to/model", secret_key="test-secret", - env_vars={"CUSTOM_VAR": "value"} + env_vars={"CUSTOM_VAR": "value"}, ) - + self.assertEqual(server.container, mock_container) mock_client.containers.run.assert_called_once() call_kwargs = mock_client.containers.run.call_args[1] - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", call_kwargs["environment"]) - self.assertEqual(call_kwargs["environment"]["SAGEMAKER_SERVE_SECRET_KEY"], "test-secret") self.assertEqual(call_kwargs["environment"]["CUSTOM_VAR"], "value") - @patch('sagemaker.serve.model_server.torchserve.server.requests.post') - @patch('sagemaker.serve.model_server.torchserve.server.get_docker_host') + @patch("sagemaker.serve.model_server.torchserve.server.requests.post") + @patch("sagemaker.serve.model_server.torchserve.server.get_docker_host") def test_invoke_torch_serve_success(self, mock_get_host, mock_post): """Test _invoke_torch_serve successful request.""" from sagemaker.serve.model_server.torchserve.server import LocalTorchServe - + server = 
LocalTorchServe() mock_get_host.return_value = "localhost" mock_response = Mock() mock_response.content = b'{"predictions": [0.1, 0.9]}' mock_post.return_value = mock_response - + result = server._invoke_torch_serve( request='{"data": [1, 2, 3]}', content_type="application/json", - accept="application/json" + accept="application/json", ) - + self.assertEqual(result, b'{"predictions": [0.1, 0.9]}') mock_post.assert_called_once() - @patch('sagemaker.serve.model_server.torchserve.server.requests.post') - @patch('sagemaker.serve.model_server.torchserve.server.get_docker_host') + @patch("sagemaker.serve.model_server.torchserve.server.requests.post") + @patch("sagemaker.serve.model_server.torchserve.server.get_docker_host") def test_invoke_torch_serve_failure(self, mock_get_host, mock_post): """Test _invoke_torch_serve handles errors.""" from sagemaker.serve.model_server.torchserve.server import LocalTorchServe - + server = LocalTorchServe() mock_get_host.return_value = "localhost" mock_post.side_effect = Exception("Connection error") - + with self.assertRaises(Exception) as context: server._invoke_torch_serve( request='{"data": [1, 2, 3]}', content_type="application/json", - accept="application/json" + accept="application/json", ) self.assertIn("Unable to send request", str(context.exception)) @@ -79,78 +77,78 @@ def test_invoke_torch_serve_failure(self, mock_get_host, mock_post): class TestSageMakerTorchServe(unittest.TestCase): """Test SageMakerTorchServe class.""" - @patch('sagemaker.serve.model_server.torchserve.server._is_s3_uri') + @patch("sagemaker.serve.model_server.torchserve.server._is_s3_uri") def test_upload_torchserve_artifacts_with_s3_path(self, mock_is_s3): """Test _upload_torchserve_artifacts with S3 path.""" from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe - + server = SageMakerTorchServe() mock_is_s3.return_value = True mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_torchserve_artifacts( model_path="s3://bucket/model", sagemaker_session=mock_session, secret_key="test-key", - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertEqual(s3_path, "s3://bucket/model") - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) - self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-key") - - @patch('sagemaker.serve.model_server.torchserve.server.upload') - @patch('sagemaker.serve.model_server.torchserve.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.torchserve.server.parse_s3_url') - @patch('sagemaker.serve.model_server.torchserve.server.fw_utils') - @patch('sagemaker.serve.model_server.torchserve.server._is_s3_uri') - def test_upload_torchserve_artifacts_uploads_to_s3(self, mock_is_s3, mock_fw_utils, - mock_parse, mock_determine, mock_upload): + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) + + @patch("sagemaker.serve.model_server.torchserve.server.upload") + @patch("sagemaker.serve.model_server.torchserve.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.torchserve.server.parse_s3_url") + @patch("sagemaker.serve.model_server.torchserve.server.fw_utils") + @patch("sagemaker.serve.model_server.torchserve.server._is_s3_uri") + def test_upload_torchserve_artifacts_uploads_to_s3( + self, mock_is_s3, mock_fw_utils, mock_parse, mock_determine, mock_upload + ): """Test _upload_torchserve_artifacts uploads to S3.""" from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe - + server = 
SageMakerTorchServe() mock_is_s3.return_value = False mock_parse.return_value = ("bucket", "prefix") mock_determine.return_value = ("bucket", "code_prefix") mock_upload.return_value = "s3://bucket/code_prefix/model.tar.gz" - + mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_torchserve_artifacts( model_path="/local/model", sagemaker_session=mock_session, secret_key="test-key", s3_model_data_url="s3://bucket/prefix", image="test-image", - should_upload_artifacts=True + should_upload_artifacts=True, ) - + self.assertEqual(s3_path, "s3://bucket/code_prefix/model.tar.gz") - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) mock_upload.assert_called_once() - @patch('sagemaker.serve.model_server.torchserve.server._is_s3_uri') + @patch("sagemaker.serve.model_server.torchserve.server._is_s3_uri") def test_upload_torchserve_artifacts_no_upload(self, mock_is_s3): """Test _upload_torchserve_artifacts without uploading.""" from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe - + server = SageMakerTorchServe() mock_is_s3.return_value = False mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_torchserve_artifacts( model_path="/local/model", sagemaker_session=mock_session, secret_key="test-key", - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNone(s3_path) - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) if __name__ == "__main__": diff --git a/sagemaker-serve/tests/unit/model_server/test_torchserve_xgboost_inference.py b/sagemaker-serve/tests/unit/model_server/test_torchserve_xgboost_inference.py index 8d065707f6..9b2f81febd 100644 --- a/sagemaker-serve/tests/unit/model_server/test_torchserve_xgboost_inference.py +++ b/sagemaker-serve/tests/unit/model_server/test_torchserve_xgboost_inference.py @@ -13,43 +13,48 @@ class TestXGBoostInferenceSimple(unittest.TestCase): def test_predict_fn_logic(self): """Test predict_fn logic.""" + # Simulate the predict_fn behavior def predict_fn(input_data, predict_callable): return predict_callable(input_data) - + mock_predict_callable = Mock(return_value=[0.1, 0.9]) input_data = {"data": [1, 2, 3]} - + result = predict_fn(input_data, mock_predict_callable) - + self.assertEqual(result, [0.1, 0.9]) mock_predict_callable.assert_called_once_with(input_data) - @patch.dict(os.environ, {'MLFLOW_MODEL_FLAVOR': 'sklearn'}) + @patch.dict(os.environ, {"MLFLOW_MODEL_FLAVOR": "sklearn"}) def test_get_mlflow_flavor_logic(self): """Test _get_mlflow_flavor logic.""" + # Simulate the _get_mlflow_flavor behavior def _get_mlflow_flavor(): return os.getenv("MLFLOW_MODEL_FLAVOR") - + result = _get_mlflow_flavor() - self.assertEqual(result, 'sklearn') + self.assertEqual(result, "sklearn") @patch.dict(os.environ, {}, clear=True) def test_get_mlflow_flavor_none_logic(self): """Test _get_mlflow_flavor with no env var.""" + def _get_mlflow_flavor(): return os.getenv("MLFLOW_MODEL_FLAVOR") - + result = _get_mlflow_flavor() self.assertIsNone(result) - @patch('importlib.import_module') + @patch("importlib.import_module") def test_load_mlflow_model_logic(self, mock_import): """Test _load_mlflow_model logic.""" + # Simulate the _load_mlflow_model behavior def _load_mlflow_model(deployment_flavor, model_dir): import importlib + flavor_loader_map = { "sklearn": ("mlflow.sklearn", "load_model"), "pytorch": 
("mlflow.pytorch", "load_model"), @@ -60,70 +65,81 @@ def _load_mlflow_model(deployment_flavor, model_dir): flavor_module = importlib.import_module(flavor_module_name) load_model_function = getattr(flavor_module, load_function_name) return load_model_function(model_dir) - + mock_module = Mock() mock_module.load_model = Mock(return_value=Mock()) mock_import.return_value = mock_module - - result = _load_mlflow_model('sklearn', '/model/dir') - - mock_import.assert_called_once_with('mlflow.sklearn') + + result = _load_mlflow_model("sklearn", "/model/dir") + + mock_import.assert_called_once_with("mlflow.sklearn") def test_input_fn_custom_translator_logic(self): """Test input_fn with custom translator logic.""" import io - + # Simulate input_fn behavior def input_fn(input_data, content_type, schema_builder): if hasattr(schema_builder, "custom_input_translator"): return schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data), + ( + io.BytesIO(input_data.encode("utf-8")) + if isinstance(input_data, str) + else io.BytesIO(input_data) + ), content_type, ) else: return schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data), + ( + io.BytesIO(input_data.encode("utf-8")) + if isinstance(input_data, str) + else io.BytesIO(input_data) + ), content_type[0], ) - + schema_builder = Mock() schema_builder.custom_input_translator = Mock() schema_builder.custom_input_translator.deserialize = Mock(return_value={"data": [1, 2, 3]}) - + result = input_fn('{"data": [1, 2, 3]}', ["application/json"], schema_builder) - + self.assertEqual(result, {"data": [1, 2, 3]}) def test_output_fn_custom_translator_logic(self): """Test output_fn with custom translator logic.""" + # Simulate output_fn behavior def output_fn(predictions, accept_type, schema_builder): if hasattr(schema_builder, "custom_output_translator"): return schema_builder.custom_output_translator.serialize(predictions, accept_type) else: return schema_builder.output_serializer.serialize(predictions) - + schema_builder = Mock() schema_builder.custom_output_translator = Mock() - schema_builder.custom_output_translator.serialize = Mock(return_value=b'{"predictions": [0.1, 0.9]}') - + schema_builder.custom_output_translator.serialize = Mock( + return_value=b'{"predictions": [0.1, 0.9]}' + ) + result = output_fn([0.1, 0.9], "application/json", schema_builder) - + self.assertEqual(result, b'{"predictions": [0.1, 0.9]}') def test_python_version_check_logic(self): """Test Python version parity check logic.""" import platform - + # Simulate _py_vs_parity_check behavior def _py_vs_parity_check(local_py_vs): container_py_vs = platform.python_version() if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]: return False # Would log warning return True - + # Test matching versions - result = _py_vs_parity_check('3.9.0') + result = _py_vs_parity_check("3.9.0") # Result depends on actual Python version, just verify it runs self.assertIsInstance(result, bool) diff --git a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py index b15e77a0b0..4355474c3d 100644 --- a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py +++ b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py @@ -8,15 +8,16 @@ import unittest # Prevent JumpStart from loading region config during 
import -os.environ['SAGEMAKER_INTERNAL_SKIP_REGION_CONFIG'] = '1' +os.environ["SAGEMAKER_INTERNAL_SKIP_REGION_CONFIG"] = "1" from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.model_builder_servers import _ModelBuilderServers + class MockModelBuilderServers(_ModelBuilderServers): """Mock class that inherits _ModelBuilderServers behavior.""" - + def __init__(self): self.model_server = ModelServer.TORCHSERVE self.model = None @@ -46,75 +47,75 @@ def __init__(self): self.framework_version = None self._is_mlflow_model = False self.config_name = None - + def _deploy_local_endpoint(self, **kwargs): return Mock() - + def _deploy_core_endpoint(self, *args, **kwargs): return Mock() - + def _save_model_inference_spec(self): pass - + def _is_jumpstart_model_id(self): return False - + def _auto_detect_image_uri(self): pass - + def _prepare_for_mode(self, should_upload_artifacts=False): return ("s3://bucket/model.tar.gz", None) - + def _create_model(self): return Mock() - + def _validate_tgi_serving_sample_data(self): pass - + def _validate_djl_serving_sample_data(self): pass - + def _validate_for_triton(self): pass - + def _auto_detect_image_for_triton(self): pass - + def _save_inference_spec(self): pass - + def _prepare_for_triton(self): pass - + def get_huggingface_model_metadata(self, model_id, token=None): return {} - + def _normalize_framework_to_enum(self, framework): return framework - + def _get_processing_unit(self): return "cpu" - + def _get_smd_image_uri(self, processing_unit): return "smd-image-uri" - + def _create_conda_env(self): pass class TestBuildForModelServer(unittest.TestCase): """Test _build_for_model_server method.""" - + def setUp(self): self.builder = MockModelBuilderServers() - + def test_unsupported_model_server(self): """Test error for unsupported model server.""" self.builder.model_server = "INVALID_SERVER" with self.assertRaises(ValueError) as ctx: self.builder._build_for_model_server() self.assertIn("not supported", str(ctx.exception)) - + def test_missing_required_parameters(self): """Test error when model, MLflow path, and inference_spec are all missing.""" self.builder.model = None @@ -123,8 +124,8 @@ def test_missing_required_parameters(self): with self.assertRaises(ValueError) as ctx: self.builder._build_for_model_server() self.assertIn("Missing required parameter", str(ctx.exception)) - - @patch.object(MockModelBuilderServers, '_build_for_torchserve') + + @patch.object(MockModelBuilderServers, "_build_for_torchserve") def test_route_to_torchserve(self, mock_build): """Test routing to TorchServe builder.""" self.builder.model_server = ModelServer.TORCHSERVE @@ -132,8 +133,8 @@ def test_route_to_torchserve(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_triton') + + @patch.object(MockModelBuilderServers, "_build_for_triton") def test_route_to_triton(self, mock_build): """Test routing to Triton builder.""" self.builder.model_server = ModelServer.TRITON @@ -141,8 +142,8 @@ def test_route_to_triton(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_tensorflow_serving') + + @patch.object(MockModelBuilderServers, "_build_for_tensorflow_serving") def test_route_to_tensorflow_serving(self, mock_build): """Test routing to TensorFlow Serving 
builder.""" self.builder.model_server = ModelServer.TENSORFLOW_SERVING @@ -150,8 +151,8 @@ def test_route_to_tensorflow_serving(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_djl') + + @patch.object(MockModelBuilderServers, "_build_for_djl") def test_route_to_djl(self, mock_build): """Test routing to DJL builder.""" self.builder.model_server = ModelServer.DJL_SERVING @@ -159,8 +160,8 @@ def test_route_to_djl(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_tei') + + @patch.object(MockModelBuilderServers, "_build_for_tei") def test_route_to_tei(self, mock_build): """Test routing to TEI builder.""" self.builder.model_server = ModelServer.TEI @@ -168,8 +169,8 @@ def test_route_to_tei(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_tgi') + + @patch.object(MockModelBuilderServers, "_build_for_tgi") def test_route_to_tgi(self, mock_build): """Test routing to TGI builder.""" self.builder.model_server = ModelServer.TGI @@ -177,8 +178,8 @@ def test_route_to_tgi(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_transformers') + + @patch.object(MockModelBuilderServers, "_build_for_transformers") def test_route_to_mms(self, mock_build): """Test routing to MMS builder.""" self.builder.model_server = ModelServer.MMS @@ -186,8 +187,8 @@ def test_route_to_mms(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_smd') + + @patch.object(MockModelBuilderServers, "_build_for_smd") def test_route_to_smd(self, mock_build): """Test routing to SMD builder.""" self.builder.model_server = ModelServer.SMD @@ -199,107 +200,124 @@ def test_route_to_smd(self, mock_build): class TestBuildForTorchServe(unittest.TestCase): """Test _build_for_torchserve method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.TORCHSERVE - - @patch.object(MockModelBuilderServers, '_save_model_inference_spec') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model_id(self, mock_create, mock_prepare, mock_detect, mock_js, mock_save): + + @patch.object(MockModelBuilderServers, "_save_model_inference_spec") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_hf_model_id( + self, mock_create, mock_prepare, mock_detect, mock_js, mock_save + ): """Test building with HuggingFace model ID.""" mock_js.return_value = False mock_create.return_value = Mock() self.builder.mode = Mode.IN_PROCESS self.builder.model = "bert-base-uncased" self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "test-token"} - + result = 
self.builder._build_for_torchserve() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "bert-base-uncased") self.assertEqual(self.builder.env_vars["HF_TOKEN"], "test-token") self.assertIsNone(self.builder.s3_upload_path) mock_save.assert_called_once() mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers.prepare_for_torchserve') - @patch.object(MockModelBuilderServers, '_save_model_inference_spec') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_local_container_mode(self, mock_create, mock_prepare, mock_detect, mock_save, mock_ts_prepare): + + @patch("sagemaker.serve.model_builder_servers.prepare_for_torchserve") + @patch.object(MockModelBuilderServers, "_save_model_inference_spec") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_local_container_mode( + self, mock_create, mock_prepare, mock_detect, mock_save, mock_ts_prepare + ): """Test building for LOCAL_CONTAINER mode.""" self.builder.mode = Mode.LOCAL_CONTAINER self.builder.model = Mock() - mock_ts_prepare.return_value = "secret123" + mock_ts_prepare.return_value = "" mock_create.return_value = Mock() - + result = self.builder._build_for_torchserve() - + mock_ts_prepare.assert_called_once() - self.assertEqual(self.builder.secret_key, "secret123") + self.assertEqual(self.builder.secret_key, "") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers.prepare_for_torchserve') - @patch.object(MockModelBuilderServers, '_save_model_inference_spec') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_sagemaker_endpoint_mode(self, mock_create, mock_prepare, mock_detect, mock_save, mock_ts_prepare): + + @patch("sagemaker.serve.model_builder_servers.prepare_for_torchserve") + @patch.object(MockModelBuilderServers, "_save_model_inference_spec") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_mode( + self, mock_create, mock_prepare, mock_detect, mock_save, mock_ts_prepare + ): """Test building for SAGEMAKER_ENDPOINT mode.""" self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.model = Mock() - mock_ts_prepare.return_value = "secret456" + mock_ts_prepare.return_value = "" mock_create.return_value = Mock() mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) - + result = self.builder._build_for_torchserve() - + mock_ts_prepare.assert_called_once() - self.assertEqual(self.builder.secret_key, "secret456") + self.assertEqual(self.builder.secret_key, "") mock_prepare.assert_called_with(should_upload_artifacts=True) class TestBuildForTGI(unittest.TestCase): """Test _build_for_tgi method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.TGI - - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_tgi_serving_sample_data') - 
@patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_notebook_instance(self, mock_create, mock_prepare, mock_detect, - mock_validate, mock_dir, mock_nb): + + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_tgi_serving_sample_data") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_notebook_instance( + self, mock_create, mock_prepare, mock_detect, mock_validate, mock_dir, mock_nb + ): """Test building with notebook instance detection.""" mock_nb.return_value = "ml.g4dn.xlarge" mock_create.return_value = Mock() mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = Mock() - + result = self.builder._build_for_tgi() - + self.assertEqual(self.builder.instance_type, "ml.g4dn.xlarge") mock_create.assert_called_once() - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_default_tgi_configurations') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_tgi_serving_sample_data') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_js, - mock_validate, mock_dir, mock_nb, mock_tgi_config, mock_hf_config): + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_tgi_serving_sample_data") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_hf_model( + self, + mock_create, + mock_prepare, + mock_detect, + mock_js, + mock_validate, + mock_dir, + mock_nb, + mock_tgi_config, + mock_hf_config, + ): """Test building with HuggingFace model.""" mock_js.return_value = False mock_nb.return_value = None @@ -310,25 +328,34 @@ def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_ self.builder.model = "gpt2" self.builder.mode = Mode.LOCAL_CONTAINER self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "token"} - + result = self.builder._build_for_tgi() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "gpt2") self.assertEqual(self.builder.env_vars["HF_TOKEN"], "token") self.assertEqual(self.builder.env_vars["SHARDED"], "false") self.assertEqual(self.builder.env_vars["NUM_SHARD"], "1") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - 
@patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_tgi_serving_sample_data') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_sagemaker_endpoint_with_gpu(self, mock_create, mock_prepare, mock_detect, - mock_validate, mock_dir, mock_nb, mock_tp, mock_gpu): + + @patch("sagemaker.serve.model_builder_servers._get_gpu_info") + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_tgi_serving_sample_data") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_with_gpu( + self, + mock_create, + mock_prepare, + mock_detect, + mock_validate, + mock_dir, + mock_nb, + mock_tp, + mock_gpu, + ): """Test building for SAGEMAKER_ENDPOINT with GPU sharding.""" mock_nb.return_value = None mock_gpu.return_value = 4 @@ -338,24 +365,34 @@ def test_build_sagemaker_endpoint_with_gpu(self, mock_create, mock_prepare, mock self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.model = Mock() self.builder.hf_model_config = {"model_type": "gpt2"} - + result = self.builder._build_for_tgi() - + self.assertEqual(self.builder.env_vars["NUM_SHARD"], "2") self.assertEqual(self.builder.env_vars["SHARDED"], "true") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_gpu_info_fallback') - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_tgi_serving_sample_data') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_gpu_fallback(self, mock_create, mock_prepare, mock_detect, mock_validate, - mock_dir, mock_nb, mock_tp, mock_gpu, mock_fallback): + + @patch("sagemaker.serve.model_builder_servers._get_gpu_info_fallback") + @patch("sagemaker.serve.model_builder_servers._get_gpu_info") + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_tgi_serving_sample_data") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_gpu_fallback( + self, + mock_create, + mock_prepare, + mock_detect, + mock_validate, + mock_dir, + mock_nb, + mock_tp, + mock_gpu, + mock_fallback, + ): """Test GPU info fallback when primary method 
fails.""" mock_nb.return_value = None mock_gpu.side_effect = Exception("GPU info failed") @@ -365,28 +402,29 @@ def test_build_gpu_fallback(self, mock_create, mock_prepare, mock_detect, mock_v mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.model = Mock() - + result = self.builder._build_for_tgi() - + mock_fallback.assert_called_once() mock_create.assert_called_once() class TestBuildForDJL(unittest.TestCase): """Test _build_for_djl method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.DJL_SERVING - - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_djl_serving_sample_data') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_timeout(self, mock_create, mock_prepare, mock_detect, - mock_validate, mock_dir, mock_nb): + + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_djl_serving_sample_data") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_timeout( + self, mock_create, mock_prepare, mock_detect, mock_validate, mock_dir, mock_nb + ): """Test building with model_data_download_timeout.""" mock_nb.return_value = None mock_create.return_value = Mock() @@ -394,23 +432,33 @@ def test_build_with_timeout(self, mock_create, mock_prepare, mock_detect, self.builder.model = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.model_data_download_timeout = 600 - + result = self.builder._build_for_djl() - + self.assertEqual(self.builder.env_vars["MODEL_LOADING_TIMEOUT"], "600") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_djl_serving_sample_data') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_js, - mock_validate, mock_dir, mock_nb, mock_djl_config, mock_hf_config): + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_djl_serving_sample_data") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + 
@patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_hf_model( + self, + mock_create, + mock_prepare, + mock_detect, + mock_js, + mock_validate, + mock_dir, + mock_nb, + mock_djl_config, + mock_hf_config, + ): """Test building with HuggingFace model.""" mock_js.return_value = False mock_nb.return_value = None @@ -421,24 +469,33 @@ def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_ self.builder.model = "gpt2" self.builder.mode = Mode.LOCAL_CONTAINER self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "token"} - + result = self.builder._build_for_djl() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "gpt2") self.assertEqual(self.builder.env_vars["HF_TOKEN"], "token") self.assertEqual(self.builder.env_vars["OPTION_ENGINE"], "Python") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_djl_serving_sample_data') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_sagemaker_endpoint_tensor_parallel(self, mock_create, mock_prepare, mock_detect, - mock_validate, mock_dir, mock_nb, mock_tp, mock_gpu): + + @patch("sagemaker.serve.model_builder_servers._get_gpu_info") + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_djl_serving_sample_data") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_tensor_parallel( + self, + mock_create, + mock_prepare, + mock_detect, + mock_validate, + mock_dir, + mock_nb, + mock_tp, + mock_gpu, + ): """Test building for SAGEMAKER_ENDPOINT with tensor parallelism.""" mock_nb.return_value = None mock_gpu.return_value = 4 @@ -448,29 +505,37 @@ def test_build_sagemaker_endpoint_tensor_parallel(self, mock_create, mock_prepar self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.model = Mock() self.builder.hf_model_config = {"model_type": "gpt2"} - + result = self.builder._build_for_djl() - + self.assertEqual(self.builder.env_vars["TENSOR_PARALLEL_DEGREE"], "4") mock_create.assert_called_once() class TestBuildForTriton(unittest.TestCase): """Test _build_for_triton method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.TRITON - - @patch.object(MockModelBuilderServers, 'get_huggingface_model_metadata') - @patch.object(MockModelBuilderServers, '_validate_for_triton') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_save_inference_spec') - @patch.object(MockModelBuilderServers, '_prepare_for_triton') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def 
test_build_with_hf_model_string(self, mock_create, mock_prepare_mode, mock_prepare_triton, - mock_save, mock_js, mock_validate, mock_hf_meta): + + @patch.object(MockModelBuilderServers, "get_huggingface_model_metadata") + @patch.object(MockModelBuilderServers, "_validate_for_triton") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_save_inference_spec") + @patch.object(MockModelBuilderServers, "_prepare_for_triton") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_hf_model_string( + self, + mock_create, + mock_prepare_mode, + mock_prepare_triton, + mock_save, + mock_js, + mock_validate, + mock_hf_meta, + ): """Test building with HuggingFace model string.""" mock_js.return_value = False mock_hf_meta.return_value = {"pipeline_tag": "text-generation"} @@ -478,26 +543,35 @@ def test_build_with_hf_model_string(self, mock_create, mock_prepare_mode, mock_p mock_prepare_mode.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = "gpt2" self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "token"} - + result = self.builder._build_for_triton() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "gpt2") self.assertEqual(self.builder.env_vars["HF_TASK"], "text-generation") self.assertEqual(self.builder.env_vars["HF_TOKEN"], "token") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._detect_framework_and_version') - @patch('sagemaker.serve.model_builder_servers._get_model_base') - @patch.object(MockModelBuilderServers, '_normalize_framework_to_enum') - @patch.object(MockModelBuilderServers, '_validate_for_triton') - @patch.object(MockModelBuilderServers, '_auto_detect_image_for_triton') - @patch.object(MockModelBuilderServers, '_save_inference_spec') - @patch.object(MockModelBuilderServers, '_prepare_for_triton') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_model_object(self, mock_create, mock_prepare_mode, mock_prepare_triton, - mock_save, mock_detect_img, mock_validate, mock_normalize, - mock_base, mock_detect_fw): + + @patch("sagemaker.serve.model_builder_servers._detect_framework_and_version") + @patch("sagemaker.serve.model_builder_servers._get_model_base") + @patch.object(MockModelBuilderServers, "_normalize_framework_to_enum") + @patch.object(MockModelBuilderServers, "_validate_for_triton") + @patch.object(MockModelBuilderServers, "_auto_detect_image_for_triton") + @patch.object(MockModelBuilderServers, "_save_inference_spec") + @patch.object(MockModelBuilderServers, "_prepare_for_triton") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_model_object( + self, + mock_create, + mock_prepare_mode, + mock_prepare_triton, + mock_save, + mock_detect_img, + mock_validate, + mock_normalize, + mock_base, + mock_detect_fw, + ): """Test building with model object.""" mock_base.return_value = "pytorch_model" mock_detect_fw.return_value = ("pytorch", "1.8.0") @@ -506,9 +580,9 @@ def test_build_with_model_object(self, mock_create, mock_prepare_mode, mock_prep mock_prepare_mode.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = Mock() self.builder.image_uri = None - + result = self.builder._build_for_triton() - + self.assertEqual(self.builder.framework_version, "1.8.0") 
mock_detect_img.assert_called_once() mock_create.assert_called_once() @@ -516,40 +590,40 @@ def test_build_with_model_object(self, mock_create, mock_prepare_mode, mock_prep class TestBuildForTensorFlowServing(unittest.TestCase): """Test _build_for_tensorflow_serving method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.TENSORFLOW_SERVING self.builder._is_mlflow_model = True - - @patch('sagemaker.serve.model_builder_servers.save_pkl') - @patch('sagemaker.serve.model_builder_servers.prepare_for_tf_serving') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') + + @patch("sagemaker.serve.model_builder_servers.save_pkl") + @patch("sagemaker.serve.model_builder_servers.prepare_for_tf_serving") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") def test_build_mlflow_model(self, mock_create, mock_prepare_mode, mock_tf_prepare, mock_save): """Test building MLflow model for TensorFlow Serving.""" - mock_tf_prepare.return_value = "secret789" + mock_tf_prepare.return_value = "" mock_create.return_value = Mock() mock_prepare_mode.return_value = ("s3://bucket/model.tar.gz", None) - + result = self.builder._build_for_tensorflow_serving() - - self.assertEqual(self.builder.secret_key, "secret789") + + self.assertEqual(self.builder.secret_key, "") mock_save.assert_called_once() mock_create.assert_called_once() - + def test_build_non_mlflow_model_error(self): """Test error when building non-MLflow model.""" self.builder._is_mlflow_model = False - + with self.assertRaises(ValueError) as ctx: self.builder._build_for_tensorflow_serving() self.assertIn("mlflow", str(ctx.exception).lower()) - + def test_build_missing_image_uri_error(self): """Test error when image_uri is missing.""" self.builder.image_uri = None - + with self.assertRaises(ValueError) as ctx: self.builder._build_for_tensorflow_serving() self.assertIn("image_uri", str(ctx.exception)) @@ -557,20 +631,21 @@ def test_build_missing_image_uri_error(self): class TestBuildForTEI(unittest.TestCase): """Test _build_for_tei method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.TEI - - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_js, - mock_dir, mock_nb, mock_hf_config): + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_hf_model( + self, mock_create, mock_prepare, mock_detect, mock_js, mock_dir, mock_nb, mock_hf_config + ): """Test building with 
HuggingFace model.""" mock_js.return_value = False mock_nb.return_value = None @@ -579,27 +654,28 @@ def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_ mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = "bert-base-uncased" self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "token"} - + result = self.builder._build_for_tei() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "bert-base-uncased") self.assertEqual(self.builder.env_vars["HF_TOKEN"], "token") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_sagemaker_endpoint_missing_instance_type(self, mock_create, mock_prepare, - mock_detect, mock_dir, mock_nb): + + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_missing_instance_type( + self, mock_create, mock_prepare, mock_detect, mock_dir, mock_nb + ): """Test error when instance_type is missing for SAGEMAKER_ENDPOINT.""" mock_nb.return_value = None self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.instance_type = None self.builder.model = Mock() mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) - + with self.assertRaises(ValueError) as ctx: self.builder._build_for_tei() self.assertIn("Instance type", str(ctx.exception)) @@ -607,76 +683,92 @@ def test_build_sagemaker_endpoint_missing_instance_type(self, mock_create, mock_ class TestBuildForSMD(unittest.TestCase): """Test _build_for_smd method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.SMD - - @patch('sagemaker.serve.model_builder_servers.prepare_for_smd') - @patch.object(MockModelBuilderServers, '_save_model_inference_spec') - @patch.object(MockModelBuilderServers, '_get_processing_unit') - @patch.object(MockModelBuilderServers, '_get_smd_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_auto_image(self, mock_create, mock_prepare_mode, mock_get_img, - mock_get_unit, mock_save, mock_smd_prepare): + + @patch("sagemaker.serve.model_builder_servers.prepare_for_smd") + @patch.object(MockModelBuilderServers, "_save_model_inference_spec") + @patch.object(MockModelBuilderServers, "_get_processing_unit") + @patch.object(MockModelBuilderServers, "_get_smd_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_auto_image( + self, + mock_create, + mock_prepare_mode, + mock_get_img, + mock_get_unit, + mock_save, + mock_smd_prepare, + ): """Test building with auto-detected image.""" mock_get_unit.return_value = "gpu" mock_get_img.return_value = "smd-image-uri" - mock_smd_prepare.return_value = "secret999" + mock_smd_prepare.return_value = "" mock_create.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.image_uri = None self.builder.model = 
Mock() - + result = self.builder._build_for_smd() - + self.assertEqual(self.builder.image_uri, "smd-image-uri") - self.assertEqual(self.builder.secret_key, "secret999") + self.assertEqual(self.builder.secret_key, "") mock_create.assert_called_once() class TestBuildForTransformers(unittest.TestCase): """Test _build_for_transformers method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.MMS - - @patch('sagemaker.serve.model_builder_servers.save_pkl') - @patch('sagemaker.serve.model_builder_servers.prepare_for_mms') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_create_conda_env') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_inference_spec_local_container(self, mock_create, mock_prepare_mode, - mock_conda, mock_detect, mock_dir, - mock_nb, mock_mms_prepare, mock_save): + + @patch("sagemaker.serve.model_builder_servers.save_pkl") + @patch("sagemaker.serve.model_builder_servers.prepare_for_mms") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_create_conda_env") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_inference_spec_local_container( + self, + mock_create, + mock_prepare_mode, + mock_conda, + mock_detect, + mock_dir, + mock_nb, + mock_mms_prepare, + mock_save, + ): """Test building with inference_spec for LOCAL_CONTAINER.""" mock_nb.return_value = None - mock_mms_prepare.return_value = "secret111" + mock_mms_prepare.return_value = "" mock_create.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.inference_spec = Mock() - + result = self.builder._build_for_transformers() - + mock_save.assert_called_once() mock_mms_prepare.assert_called_once() - self.assertEqual(self.builder.secret_key, "secret111") + self.assertEqual(self.builder.secret_key, "") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model_string(self, mock_create, mock_prepare, mock_detect, mock_js, - mock_dir, mock_nb, mock_hf_config): + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + 
def test_build_with_hf_model_string( + self, mock_create, mock_prepare, mock_detect, mock_js, mock_dir, mock_nb, mock_hf_config + ): """Test building with HuggingFace model string.""" mock_js.return_value = False mock_nb.return_value = None @@ -685,62 +777,66 @@ def test_build_with_hf_model_string(self, mock_create, mock_prepare, mock_detect mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = "gpt2" self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "token"} - + result = self.builder._build_for_transformers() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "gpt2") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_sagemaker_endpoint_missing_instance_type(self, mock_create, mock_prepare, - mock_detect, mock_dir, mock_nb): + + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_missing_instance_type( + self, mock_create, mock_prepare, mock_detect, mock_dir, mock_nb + ): """Test error when instance_type is missing for SAGEMAKER_ENDPOINT.""" mock_nb.return_value = None self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.instance_type = None self.builder.model = Mock() mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) - + with self.assertRaises(ValueError) as ctx: self.builder._build_for_transformers() self.assertIn("Instance type", str(ctx.exception)) - - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_clean_empty_secret_key(self, mock_create, mock_prepare, mock_detect, - mock_dir, mock_nb): + + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_clean_empty_secret_key( + self, mock_create, mock_prepare, mock_detect, mock_dir, mock_nb + ): """Test cleaning empty secret key from env_vars.""" mock_nb.return_value = None mock_create.return_value = Mock() mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = Mock() self.builder.env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = "" - + result = self.builder._build_for_transformers() - + self.assertNotIn("SAGEMAKER_SERVE_SECRET_KEY", self.builder.env_vars) mock_create.assert_called_once() class TestBuildForJumpStart(unittest.TestCase): """Test _build_for_jumpstart and related methods.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model = "huggingface-llm-falcon-7b" - - 
@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_djl_local_container(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init): + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder_servers.prepare_djl_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_djl_local_container( + self, mock_create, mock_prepare_mode, mock_djl_res, mock_init + ): """Test building DJL JumpStart model for LOCAL_CONTAINER.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "djl-inference:0.21.0" @@ -751,18 +847,20 @@ def test_build_djl_local_container(self, mock_create, mock_prepare_mode, mock_dj mock_create.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.image_uri = None - + result = self.builder._build_for_jumpstart() - + self.assertEqual(self.builder.model_server, ModelServer.DJL_SERVING) self.assertTrue(self.builder.prepared_for_djl) mock_create.assert_called_once() - - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder_servers.prepare_tgi_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_tgi_local_container(self, mock_create, mock_prepare_mode, mock_tgi_res, mock_init): + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder_servers.prepare_tgi_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_tgi_local_container( + self, mock_create, mock_prepare_mode, mock_tgi_res, mock_init + ): """Test building TGI JumpStart model for LOCAL_CONTAINER.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "tgi-inference:1.0.0" @@ -773,18 +871,20 @@ def test_build_tgi_local_container(self, mock_create, mock_prepare_mode, mock_tg mock_create.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.image_uri = None - + result = self.builder._build_for_jumpstart() - + self.assertEqual(self.builder.model_server, ModelServer.TGI) self.assertTrue(self.builder.prepared_for_tgi) mock_create.assert_called_once() - - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder_servers.prepare_mms_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_mms_local_container(self, mock_create, mock_prepare_mode, mock_mms_res, mock_init): + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder_servers.prepare_mms_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_mms_local_container( + self, mock_create, mock_prepare_mode, mock_mms_res, mock_init + ): """Test building MMS JumpStart model for LOCAL_CONTAINER.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "huggingface-pytorch-inference:1.10.0" @@ -795,14 +895,14 @@ def test_build_mms_local_container(self, mock_create, mock_prepare_mode, mock_mm 
mock_create.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.image_uri = None - + result = self.builder._build_for_jumpstart() - + self.assertEqual(self.builder.model_server, ModelServer.MMS) self.assertTrue(self.builder.prepared_for_mms) mock_create.assert_called_once() - - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") def test_build_unsupported_image_uri(self, mock_init): """Test error for unsupported JumpStart image URI.""" mock_init_kwargs = Mock() @@ -812,16 +912,18 @@ def test_build_unsupported_image_uri(self, mock_init): mock_init.return_value = mock_init_kwargs self.builder.mode = Mode.LOCAL_CONTAINER self.builder.image_uri = None - + with self.assertRaises(ValueError) as ctx: self.builder._build_for_jumpstart() self.assertIn("Unsupported", str(ctx.exception)) - - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_passes_config_name_to_get_init_kwargs(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init): + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder_servers.prepare_djl_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_passes_config_name_to_get_init_kwargs( + self, mock_create, mock_prepare_mode, mock_djl_res, mock_init + ): """Test that config_name is forwarded to get_init_kwargs.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "djl-inference:0.21.0" @@ -846,11 +948,13 @@ def test_build_passes_config_name_to_get_init_kwargs(self, mock_create, mock_pre config_name="lmi-optimized", ) - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_passes_none_config_name_when_not_set(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init): + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder_servers.prepare_djl_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_passes_none_config_name_when_not_set( + self, mock_create, mock_prepare_mode, mock_djl_res, mock_init + ): """Test that config_name defaults to None when not set.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "djl-inference:0.21.0" @@ -875,9 +979,9 @@ def test_build_passes_none_config_name_when_not_set(self, mock_create, mock_prep config_name=None, ) - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_build_for_djl_jumpstart') + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_build_for_djl_jumpstart") def test_build_sagemaker_endpoint_djl(self, mock_djl_build, mock_prepare, mock_init): """Test building DJL JumpStart for SAGEMAKER_ENDPOINT.""" mock_init_kwargs = Mock() @@ -888,157 +992,154 @@ 
def test_build_sagemaker_endpoint_djl(self, mock_djl_build, mock_prepare, mock_i mock_djl_build.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.image_uri = None - + result = self.builder._build_for_jumpstart() - + mock_djl_build.assert_called_once() class TestDeployWrappers(unittest.TestCase): """Test deploy wrapper methods.""" - + def setUp(self): self.builder = MockModelBuilderServers() - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_djl_deploy_in_process(self, mock_deploy): """Test DJL deploy wrapper for IN_PROCESS mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.IN_PROCESS - + result = self.builder._djl_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_djl_deploy_local_container(self, mock_deploy): """Test DJL deploy wrapper for LOCAL_CONTAINER mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER - + result = self.builder._djl_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_djl_deploy_sagemaker_endpoint(self, mock_deploy): """Test DJL deploy wrapper for SAGEMAKER_ENDPOINT mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - - result = self.builder._djl_model_builder_deploy_wrapper( - model_data_download_timeout=600 - ) - + + result = self.builder._djl_model_builder_deploy_wrapper(model_data_download_timeout=600) + self.assertEqual(self.builder.env_vars["MODEL_LOADING_TIMEOUT"], "600") mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_djl_deploy_with_defaults(self, mock_deploy): """Test DJL deploy wrapper sets default values.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - + result = self.builder._djl_model_builder_deploy_wrapper() - + call_kwargs = mock_deploy.call_args[1] self.assertEqual(call_kwargs["endpoint_logging"], True) self.assertEqual(call_kwargs["initial_instance_count"], 1) - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_tgi_deploy_local_container(self, mock_deploy): """Test TGI deploy wrapper for LOCAL_CONTAINER mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER - + result = self.builder._tgi_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_tgi_deploy_sagemaker_endpoint(self, mock_deploy): """Test TGI deploy wrapper for SAGEMAKER_ENDPOINT mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - + result = self.builder._tgi_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_tei_deploy_in_process(self, mock_deploy): """Test TEI deploy wrapper for IN_PROCESS mode.""" mock_deploy.return_value = 
Mock() self.builder.mode = Mode.IN_PROCESS - + result = self.builder._tei_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_tei_deploy_sagemaker_endpoint(self, mock_deploy): """Test TEI deploy wrapper for SAGEMAKER_ENDPOINT mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - + result = self.builder._tei_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_js_deploy_local_container(self, mock_deploy): """Test JumpStart deploy wrapper for LOCAL_CONTAINER mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER - + result = self.builder._js_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_js_deploy_sagemaker_endpoint(self, mock_deploy): """Test JumpStart deploy wrapper for SAGEMAKER_ENDPOINT mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.instance_type = "ml.g5.xlarge" - + result = self.builder._js_builder_deploy_wrapper() - + call_kwargs = mock_deploy.call_args[1] self.assertEqual(call_kwargs["instance_type"], "ml.g5.xlarge") mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_transformers_deploy_local_container(self, mock_deploy): """Test Transformers deploy wrapper for LOCAL_CONTAINER mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER - + result = self.builder._transformers_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_transformers_deploy_sagemaker_endpoint(self, mock_deploy): """Test Transformers deploy wrapper for SAGEMAKER_ENDPOINT mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - + result = self.builder._transformers_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_deploy_wrapper_removes_mode_and_role(self, mock_deploy): """Test deploy wrapper removes mode and role from kwargs.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - + result = self.builder._djl_model_builder_deploy_wrapper( - mode=Mode.LOCAL_CONTAINER, - role="arn:aws:iam::123456789012:role/test" + mode=Mode.LOCAL_CONTAINER, role="arn:aws:iam::123456789012:role/test" ) - + call_kwargs = mock_deploy.call_args[1] self.assertNotIn("mode", call_kwargs) self.assertNotIn("role", call_kwargs) @@ -1047,13 +1148,13 @@ def test_deploy_wrapper_removes_mode_and_role(self, mock_deploy): class TestJumpStartBuilders(unittest.TestCase): """Test JumpStart-specific builder methods.""" - + def setUp(self): self.builder = MockModelBuilderServers() - - @patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - 
@patch.object(MockModelBuilderServers, '_create_model') + + @patch("sagemaker.serve.model_builder_servers.prepare_djl_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") def test_build_for_djl_jumpstart_local(self, mock_create, mock_prepare, mock_djl_res): """Test _build_for_djl_jumpstart for local mode.""" mock_init_kwargs = Mock() @@ -1063,15 +1164,15 @@ def test_build_for_djl_jumpstart_local(self, mock_create, mock_prepare, mock_djl self.builder.mode = Mode.LOCAL_CONTAINER self.builder.model = "jumpstart-model-id" self.builder.s3_model_data_url = "s3://bucket/model.tar.gz" - + result = self.builder._build_for_djl_jumpstart(mock_init_kwargs) - + self.assertEqual(self.builder.model_server, ModelServer.DJL_SERVING) self.assertTrue(self.builder.prepared_for_djl) mock_djl_res.assert_called_once() mock_create.assert_called_once() - - @patch.object(MockModelBuilderServers, '_create_model') + + @patch.object(MockModelBuilderServers, "_create_model") def test_build_for_djl_jumpstart_sagemaker(self, mock_create): """Test _build_for_djl_jumpstart for SAGEMAKER_ENDPOINT mode.""" mock_init_kwargs = Mock() @@ -1079,16 +1180,16 @@ def test_build_for_djl_jumpstart_sagemaker(self, mock_create): mock_create.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.model = "jumpstart-model-id" - + result = self.builder._build_for_djl_jumpstart(mock_init_kwargs) - + self.assertEqual(self.builder.s3_upload_path, "s3://bucket/model.tar.gz") self.assertTrue(self.builder.prepared_for_djl) mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers.prepare_tgi_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') + + @patch("sagemaker.serve.model_builder_servers.prepare_tgi_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") def test_build_for_tgi_jumpstart_local(self, mock_create, mock_prepare, mock_tgi_res): """Test _build_for_tgi_jumpstart for local mode.""" mock_init_kwargs = Mock() @@ -1098,17 +1199,17 @@ def test_build_for_tgi_jumpstart_local(self, mock_create, mock_prepare, mock_tgi self.builder.mode = Mode.LOCAL_CONTAINER self.builder.model = "jumpstart-model-id" self.builder.s3_model_data_url = "s3://bucket/model.tar.gz" - + result = self.builder._build_for_tgi_jumpstart(mock_init_kwargs) - + self.assertEqual(self.builder.model_server, ModelServer.TGI) self.assertTrue(self.builder.prepared_for_tgi) mock_tgi_res.assert_called_once() mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers.prepare_mms_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') + + @patch("sagemaker.serve.model_builder_servers.prepare_mms_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") def test_build_for_mms_jumpstart_local(self, mock_create, mock_prepare, mock_mms_res): """Test _build_for_mms_jumpstart for local mode.""" mock_init_kwargs = Mock() @@ -1118,9 +1219,9 @@ def test_build_for_mms_jumpstart_local(self, mock_create, mock_prepare, mock_mms self.builder.mode = Mode.LOCAL_CONTAINER self.builder.model = "jumpstart-model-id" self.builder.s3_model_data_url = "s3://bucket/model.tar.gz" - + result = 
self.builder._build_for_mms_jumpstart(mock_init_kwargs) - + self.assertEqual(self.builder.model_server, ModelServer.MMS) self.assertTrue(self.builder.prepared_for_mms) mock_mms_res.assert_called_once() diff --git a/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py b/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py index 3ac82016b6..85672a8d50 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py +++ b/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py @@ -21,20 +21,21 @@ def test_triton_serializer_init(self): """Test TritonSerializer initialization.""" mock_serializer = Mock() serializer = TritonSerializer(mock_serializer, "FP32") - + self.assertEqual(serializer.dtype, "FP32") self.assertEqual(serializer.input_serializer, mock_serializer) def test_triton_serializer_serialize(self): """Test TritonSerializer serialize method.""" import numpy as np + mock_serializer = Mock() mock_array = np.array([[1, 2, 3]]) mock_serializer.serialize.return_value = mock_array - + serializer = TritonSerializer(mock_serializer, "FP32") result = serializer.serialize(mock_array) - + self.assertIsNotNone(result) @@ -45,8 +46,8 @@ def test_validate_for_triton_missing_tritonclient(self): """Test validation fails without tritonclient - skipped as tritonclient is installed.""" pass - @patch('importlib.util.find_spec') - @patch.object(_ModelBuilderUtils, '_has_nvidia_gpu') + @patch("importlib.util.find_spec") + @patch.object(_ModelBuilderUtils, "_has_nvidia_gpu") def test_validate_for_triton_no_gpu_local(self, mock_has_gpu, mock_find_spec): """Test validation fails for GPU mode without GPU.""" utils = _ModelBuilderUtils() @@ -56,23 +57,23 @@ def test_validate_for_triton_no_gpu_local(self, mock_has_gpu, mock_find_spec): utils.schema_builder = Mock() utils.schema_builder._update_serializer_deserializer_for_triton = Mock() utils.schema_builder._detect_dtype_for_triton = Mock() - + mock_find_spec.return_value = Mock() mock_has_gpu.return_value = False - + with self.assertRaises(ValueError): utils._validate_for_triton() - @patch('importlib.util.find_spec') + @patch("importlib.util.find_spec") def test_validate_for_triton_unsupported_mode(self, mock_find_spec): """Test validation fails for unsupported mode.""" utils = _ModelBuilderUtils() utils.mode = "UNSUPPORTED_MODE" utils.model_path = "/tmp/model" utils.schema_builder = Mock() - + mock_find_spec.return_value = Mock() - + with self.assertRaises(ValueError): utils._validate_for_triton() @@ -80,51 +81,51 @@ def test_validate_for_triton_unsupported_mode(self, mock_find_spec): class TestPrepareForTriton(unittest.TestCase): """Test _prepare_for_triton method.""" - @patch('shutil.copy2') - @patch.object(_ModelBuilderUtils, '_export_pytorch_to_onnx') + @patch("shutil.copy2") + @patch.object(_ModelBuilderUtils, "_export_pytorch_to_onnx") def test_prepare_for_triton_pytorch(self, mock_export, mock_copy): """Test preparing PyTorch model for Triton.""" utils = _ModelBuilderUtils() utils.framework = Framework.PYTORCH utils.model = Mock() utils.schema_builder = Mock() - + with tempfile.TemporaryDirectory() as tmpdir: utils.model_path = tmpdir utils._prepare_for_triton() - + mock_export.assert_called_once() - @patch('shutil.copy2') - @patch.object(_ModelBuilderUtils, '_export_tf_to_onnx') + @patch("shutil.copy2") + @patch.object(_ModelBuilderUtils, "_export_tf_to_onnx") def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy): """Test preparing TensorFlow model for Triton.""" utils = _ModelBuilderUtils() 
utils.framework = Framework.TENSORFLOW utils.model = Mock() utils.schema_builder = Mock() - + with tempfile.TemporaryDirectory() as tmpdir: utils.model_path = tmpdir utils._prepare_for_triton() - + mock_export.assert_called_once() - @patch('shutil.copy2') - @patch.object(_ModelBuilderUtils, '_generate_config_pbtxt') - @patch.object(_ModelBuilderUtils, '_pack_conda_env') - @patch.object(_ModelBuilderUtils, '_compute_integrity_hash') + @patch("shutil.copy2") + @patch.object(_ModelBuilderUtils, "_generate_config_pbtxt") + @patch.object(_ModelBuilderUtils, "_pack_conda_env") + @patch.object(_ModelBuilderUtils, "_compute_integrity_hash") def test_prepare_for_triton_inference_spec(self, mock_hmac, mock_pack, mock_config, mock_copy): """Test preparing inference spec for Triton.""" utils = _ModelBuilderUtils() utils.inference_spec = Mock() utils.model = None utils.schema_builder = Mock() - + with tempfile.TemporaryDirectory() as tmpdir: utils.model_path = tmpdir utils._prepare_for_triton() - + mock_config.assert_called_once() mock_pack.assert_called_once() mock_hmac.assert_called_once() @@ -133,26 +134,27 @@ def test_prepare_for_triton_inference_spec(self, mock_hmac, mock_pack, mock_conf class TestExportPytorchToOnnx(unittest.TestCase): """Test _export_pytorch_to_onnx method.""" - @patch('torch.onnx.export') + @patch("torch.onnx.export") def test_export_pytorch_to_onnx_success(self, mock_export): """Test successful PyTorch to ONNX export.""" try: import ml_dtypes + # Skip test if ml_dtypes doesn't have required attribute - if not hasattr(ml_dtypes, 'float4_e2m1fn'): + if not hasattr(ml_dtypes, "float4_e2m1fn"): self.skipTest("ml_dtypes version incompatible with current numpy/onnx") except ImportError: pass - + utils = _ModelBuilderUtils() mock_model = Mock() mock_schema = Mock() mock_schema.sample_input = Mock() - + with tempfile.TemporaryDirectory() as tmpdir: export_path = Path(tmpdir) utils._export_pytorch_to_onnx(mock_model, export_path, mock_schema) - + mock_export.assert_called_once() def test_export_pytorch_to_onnx_no_torch(self): @@ -167,7 +169,7 @@ class TestExportTFToOnnx(unittest.TestCase): def test_export_tf_to_onnx_no_tf2onnx(self): """Test TensorFlow export without tf2onnx installed.""" utils = _ModelBuilderUtils() - + # tf2onnx not installed in test environment with tempfile.TemporaryDirectory() as tmpdir: with self.assertRaises(ImportError): @@ -188,11 +190,11 @@ def test_generate_config_pbtxt_cpu(self): utils.schema_builder._sample_output_ndarray.shape = [1, 5] utils.schema_builder._input_triton_dtype = "FP32" utils.schema_builder._output_triton_dtype = "FP32" - + with tempfile.TemporaryDirectory() as tmpdir: pkl_path = Path(tmpdir) utils._generate_config_pbtxt(pkl_path) - + config_path = pkl_path / "config.pbtxt" self.assertTrue(config_path.exists()) content = config_path.read_text() @@ -209,11 +211,11 @@ def test_generate_config_pbtxt_gpu(self): utils.schema_builder._sample_output_ndarray.shape = [1, 5] utils.schema_builder._input_triton_dtype = "FP32" utils.schema_builder._output_triton_dtype = "FP32" - + with tempfile.TemporaryDirectory() as tmpdir: pkl_path = Path(tmpdir) utils._generate_config_pbtxt(pkl_path) - + config_path = pkl_path / "config.pbtxt" self.assertTrue(config_path.exists()) content = config_path.read_text() @@ -226,8 +228,8 @@ class TestPackCondaEnv(unittest.TestCase): def test_pack_conda_env_no_conda_pack(self): """Test packing conda env without conda_pack.""" utils = _ModelBuilderUtils() - - with patch('importlib.util.find_spec', return_value=None): + + 
with patch("importlib.util.find_spec", return_value=None): with tempfile.TemporaryDirectory() as tmpdir: with self.assertRaises(ImportError): utils._pack_conda_env(Path(tmpdir)) @@ -235,7 +237,7 @@ def test_pack_conda_env_no_conda_pack(self): def test_pack_conda_env_no_conda_pack_real(self): """Test packing conda env without conda_pack - real check.""" utils = _ModelBuilderUtils() - + with tempfile.TemporaryDirectory() as tmpdir: with self.assertRaises(ImportError): utils._pack_conda_env(Path(tmpdir)) @@ -249,14 +251,14 @@ def test_save_inference_spec(self): utils = _ModelBuilderUtils() utils.inference_spec = Mock() utils.schema_builder = Mock() - + with tempfile.TemporaryDirectory() as tmpdir: utils.model_path = tmpdir pkl_path = os.path.join(tmpdir, "model_repository", "model") os.makedirs(pkl_path, exist_ok=True) - + utils._save_inference_spec() - + # Check that serve.pkl was created self.assertTrue(os.path.exists(os.path.join(pkl_path, "serve.pkl"))) @@ -265,21 +267,20 @@ class TestHMACSignin(unittest.TestCase): """Test _compute_integrity_hash method.""" def test_compute_integrity_hash(self): - """Test HMAC signing.""" + """Test SHA-256 integrity hash computation.""" utils = _ModelBuilderUtils() - + with tempfile.TemporaryDirectory() as tmpdir: utils.model_path = tmpdir pkl_path = Path(tmpdir) / "model_repository" / "model" pkl_path.mkdir(parents=True) - + # Create dummy serve.pkl (pkl_path / "serve.pkl").write_bytes(b"dummy content") - + utils._compute_integrity_hash() - - # Secret key is generated, not mocked - self.assertIsNotNone(utils.secret_key) + + # metadata.json should be created with the SHA-256 hash self.assertTrue((pkl_path / "metadata.json").exists()) @@ -291,9 +292,9 @@ def test_auto_detect_image_skip_if_provided(self): utils = _ModelBuilderUtils() utils.image_uri = "custom-triton-image" utils.sagemaker_session = Mock() - + utils._auto_detect_image_for_triton() - + self.assertEqual(utils.image_uri, "custom-triton-image") def test_auto_detect_image_cpu_instance(self): @@ -306,9 +307,9 @@ def test_auto_detect_image_cpu_instance(self): utils.inference_spec = None utils.framework = "pytorch" utils.version = "1.13" - + utils._auto_detect_image_for_triton() - + self.assertIsNotNone(utils.image_uri) self.assertIn("-cpu", utils.image_uri) @@ -322,9 +323,9 @@ def test_auto_detect_image_gpu_instance(self): utils.inference_spec = None utils.framework = "pytorch" utils.version = "1.13" - + utils._auto_detect_image_for_triton() - + self.assertIsNotNone(utils.image_uri) self.assertNotIn("-cpu", utils.image_uri) @@ -335,7 +336,7 @@ def test_auto_detect_image_unsupported_region(self): utils.instance_type = "ml.g5.xlarge" utils.sagemaker_session = Mock() utils.sagemaker_session.boto_region_name = "unsupported-region" - + with self.assertRaises(ValueError): utils._auto_detect_image_for_triton() @@ -349,7 +350,7 @@ def test_validate_djl_valid_data(self): utils.schema_builder = Mock() utils.schema_builder.sample_input = {"inputs": "test", "parameters": {}} utils.schema_builder.sample_output = [{"generated_text": "output"}] - + # Should not raise utils._validate_djl_serving_sample_data() @@ -359,7 +360,7 @@ def test_validate_djl_invalid_input(self): utils.schema_builder = Mock() utils.schema_builder.sample_input = {"wrong_key": "test"} utils.schema_builder.sample_output = [{"generated_text": "output"}] - + with self.assertRaises(ValueError): utils._validate_djl_serving_sample_data() @@ -369,7 +370,7 @@ def test_validate_djl_invalid_output(self): utils.schema_builder = Mock() 
utils.schema_builder.sample_input = {"inputs": "test", "parameters": {}} utils.schema_builder.sample_output = [{"wrong_key": "output"}] - + with self.assertRaises(ValueError): utils._validate_djl_serving_sample_data() @@ -383,7 +384,7 @@ def test_validate_tgi_valid_data(self): utils.schema_builder = Mock() utils.schema_builder.sample_input = {"inputs": "test", "parameters": {}} utils.schema_builder.sample_output = [{"generated_text": "output"}] - + # Should not raise utils._validate_tgi_serving_sample_data() @@ -393,7 +394,7 @@ def test_validate_tgi_invalid_input(self): utils.schema_builder = Mock() utils.schema_builder.sample_input = "invalid" utils.schema_builder.sample_output = [{"generated_text": "output"}] - + with self.assertRaises(ValueError): utils._validate_tgi_serving_sample_data() @@ -401,15 +402,15 @@ def test_validate_tgi_invalid_input(self): class TestCreateCondaEnv(unittest.TestCase): """Test _create_conda_env method.""" - @patch('sagemaker.serve.builder.requirements_manager.RequirementsManager') + @patch("sagemaker.serve.builder.requirements_manager.RequirementsManager") def test_create_conda_env_success(self, mock_req_manager): """Test successful conda env creation.""" utils = _ModelBuilderUtils() mock_manager = Mock() mock_req_manager.return_value = mock_manager - + utils._create_conda_env() - + # Should not raise diff --git a/sagemaker-serve/tests/unit/validations/test_check_integrity.py b/sagemaker-serve/tests/unit/validations/test_check_integrity.py index 11e66eb716..cc05c460bb 100644 --- a/sagemaker-serve/tests/unit/validations/test_check_integrity.py +++ b/sagemaker-serve/tests/unit/validations/test_check_integrity.py @@ -1,39 +1,27 @@ import unittest -import tempfile from pathlib import Path from unittest.mock import patch, mock_open -from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, - compute_hash, - perform_integrity_check -) +from sagemaker.serve.validations.check_integrity import compute_hash, perform_integrity_check class TestCheckIntegrity(unittest.TestCase): - def test_generate_secret_key(self): - key = generate_secret_key() - self.assertIsInstance(key, str) - self.assertEqual(len(key), 64) - - def test_generate_secret_key_custom_bytes(self): - key = generate_secret_key(nbytes=16) - self.assertEqual(len(key), 32) - def test_compute_hash(self): buffer = b"test data" - secret_key = "test_secret" - hash_value = compute_hash(buffer, secret_key) + hash_value = compute_hash(buffer) self.assertIsInstance(hash_value, str) self.assertEqual(len(hash_value), 64) def test_compute_hash_consistency(self): buffer = b"test data" - secret_key = "test_secret" - hash1 = compute_hash(buffer, secret_key) - hash2 = compute_hash(buffer, secret_key) + hash1 = compute_hash(buffer) + hash2 = compute_hash(buffer) self.assertEqual(hash1, hash2) - @patch.dict("os.environ", {"SAGEMAKER_SERVE_SECRET_KEY": "test_key"}) + def test_compute_hash_different_data(self): + hash1 = compute_hash(b"data1") + hash2 = compute_hash(b"data2") + self.assertNotEqual(hash1, hash2) + @patch("pathlib.Path.exists") @patch("builtins.open", new_callable=mock_open, read_data=b'{"sha256_hash": "test_hash"}') @patch("sagemaker.serve.validations.check_integrity._MetaData.from_json") @@ -41,10 +29,14 @@ def test_perform_integrity_check_failure(self, mock_metadata, mock_file, mock_ex mock_exists.return_value = True mock_meta = type("obj", (object,), {"sha256_hash": "wrong_hash"})() mock_metadata.return_value = mock_meta - + with self.assertRaises(ValueError): 
perform_integrity_check(b"test", Path("/tmp/metadata.json")) + def test_perform_integrity_check_missing_metadata(self): + with self.assertRaises(ValueError, msg="Path to metadata.json does not exist"): + perform_integrity_check(b"test", Path("/nonexistent/metadata.json")) + if __name__ == "__main__": unittest.main()