Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
917 changes: 461 additions & 456 deletions sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.core.remote_function.core.serialization import _MetaData
Expand Down Expand Up @@ -119,11 +118,8 @@ def prepare_for_mms(

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def _start_serving(
env = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
if env_vars:
Expand All @@ -47,7 +46,7 @@ def _start_serving(
image,
"serve",
# network_mode="host",
ports={'8080/tcp': 8080},
ports={"8080/tcp": 8080},
detach=True,
auto_remove=True,
volumes={
Expand Down Expand Up @@ -131,7 +130,6 @@ def _upload_server_artifacts(
env_vars = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"LOCAL_PYTHON": platform.python_version(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.core.remote_function.core.serialization import _MetaData
Expand Down Expand Up @@ -64,11 +63,8 @@ def prepare_for_smd(

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def _upload_smd_artifacts(
"SAGEMAKER_INFERENCE_CODE_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_INFERENCE_CODE": "inference.handler",
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def _start_tei_serving(
secret_key: Secret key to use for authentication
env_vars: Environment variables to set
"""
if env_vars and secret_key:
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

self.container = client.containers.run(
image,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.core.remote_function.core.serialization import _MetaData
Expand Down Expand Up @@ -56,12 +55,9 @@ def prepare_for_tf_serving(
if not mlflow_saved_model_dir:
raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.")
_move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir)

secret_key = generate_secret_key()

with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _start_tensorflow_serving(
detach=True,
auto_remove=False, # Temporarily disabled to see crash logs
# network_mode="host",
ports={'8501/tcp': 8501},
ports={"8501/tcp": 8501},
volumes={
Path(model_path): {
"bind": "/opt/ml/model",
Expand All @@ -47,7 +47,6 @@ def _start_tensorflow_serving(
environment={
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
**env_vars,
},
Expand Down Expand Up @@ -124,7 +123,6 @@ def _upload_tensorflow_serving_artifacts(
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
Expand Down Expand Up @@ -56,7 +55,9 @@ def prepare_for_torchserve(
# https://github.com/aws/sagemaker-python-sdk/issues/4288
if is_1p_image_uri(image_uri=image_uri) and "xgboost" in image_uri:
shutil.copy2(Path(__file__).parent.joinpath("xgboost_inference.py"), code_dir)
os.rename(str(code_dir.joinpath("xgboost_inference.py")), str(code_dir.joinpath("inference.py")))
os.rename(
str(code_dir.joinpath("xgboost_inference.py")), str(code_dir.joinpath("inference.py"))
)
else:
shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)

Expand All @@ -67,11 +68,8 @@ def prepare_for_torchserve(

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
hash_value = compute_hash(buffer=buffer)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _start_torch_serve(
detach=True,
auto_remove=True,
# network_mode="host",
ports={'8080/tcp': 8080},
ports={"8080/tcp": 8080},
volumes={
Path(model_path): {
"bind": "/opt/ml/model",
Expand All @@ -39,7 +39,6 @@ def _start_torch_serve(
environment={
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
**env_vars,
},
Expand Down Expand Up @@ -103,7 +102,6 @@ def _upload_torchserve_artifacts(
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
10 changes: 7 additions & 3 deletions sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@ def auto_complete_config(auto_complete_model_config):
def initialize(self, args: dict) -> None:
"""Placeholder docstring"""
serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
with open(str(serve_path), mode="rb") as f:
inference_spec, schema_builder = cloudpickle.load(f)
metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")

# TODO: HMAC signing for integrity check
# Integrity check BEFORE deserialization to prevent RCE via malicious pickle
with open(str(serve_path), "rb") as f:
buffer = f.read()
perform_integrity_check(buffer=buffer, metadata_path=metadata_path)

inference_spec, schema_builder = cloudpickle.loads(buffer)

self.inference_spec = inference_spec
self.schema_builder = schema_builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def _start_triton_server(
env_vars.update(
{
"TRITON_MODEL_DIR": "/models/model",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
)
Expand Down Expand Up @@ -133,7 +132,6 @@ def _upload_triton_artifacts(
env_vars = {
"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
"TRITON_MODEL_DIR": "/opt/ml/model/model",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
return s3_upload_path, env_vars
20 changes: 6 additions & 14 deletions sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
"""Validates the integrity of pickled file with HMAC signing."""
"""Validates the integrity of pickled file with SHA-256 hash."""

from __future__ import absolute_import
import secrets
import hmac
import hashlib
import os
from pathlib import Path

from sagemaker.core.remote_function.core.serialization import _MetaData


def generate_secret_key(nbytes: int = 32) -> str:
"""Generates secret key"""
return secrets.token_hex(nbytes)


def compute_hash(buffer: bytes, secret_key: str) -> str:
"""Compute hash value using HMAC"""
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
def compute_hash(buffer: bytes) -> str:
"""Compute SHA-256 hash of the given buffer."""
return hashlib.sha256(buffer).hexdigest()


def perform_integrity_check(buffer: bytes, metadata_path: Path):
"""Validates the integrity of bytes by comparing the hash value"""
secret_key = os.environ.get("SAGEMAKER_SERVE_SECRET_KEY")
actual_hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
"""Validates the integrity of bytes by comparing the hash value."""
actual_hash_value = compute_hash(buffer=buffer)

if not Path.exists(metadata_path):
raise ValueError("Path to metadata.json does not exist")
Expand Down
2 changes: 1 addition & 1 deletion sagemaker-serve/tests/unit/model_server/test_djl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
_get_default_batch_size,
_tokens_from_chars,
_tokens_from_words,
_set_tokens_to_tokens_threshold
_set_tokens_to_tokens_threshold,
)


Expand Down
Loading
Loading