diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 3f4156e219..aae66bc8ba 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1007,6 +1007,8 @@ def register( if compile_model_family is not None: model = self._compiled_models[compile_model_family] else: + if "model_kms_key" not in kwargs: + kwargs["model_kms_key"] = self.output_kms_key model = self.create_model(image_uri=image_uri, **kwargs) model.name = model_name return model.register( diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index a2ab24e3da..bf2e87ed29 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -60,6 +60,7 @@ def __init__( source_dir: str = None, dependencies: List = None, depends_on: List[str] = None, + **kwargs, ): """Constructs a TrainingStep, given an `EstimatorBase` instance. @@ -98,6 +99,7 @@ def __init__( "inference_script": self._entry_point_basename, "model_archive": self._model_archive, }, + **kwargs, ) repacker.disable_profiler = True inputs = TrainingInput(self._model_prefix) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 6ee048c0b2..143c59395c 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -105,9 +105,11 @@ def __init__( repack_model = False if "entry_point" in kwargs: repack_model = True - entry_point = kwargs["entry_point"] + entry_point = kwargs.pop("entry_point", None) source_dir = kwargs.get("source_dir") dependencies = kwargs.get("dependencies") + kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None)) + repack_model_step = _RepackModelStep( name=f"{name}RepackModel", depends_on=depends_on, @@ -116,6 +118,7 @@ def __init__( entry_point=entry_point, source_dir=source_dir, dependencies=dependencies, + **kwargs, ) steps.append(repack_model_step) model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts @@ -124,6 +127,7 @@ def __init__( kwargs.pop("entry_point", None) kwargs.pop("source_dir", None) kwargs.pop("dependencies", None) + kwargs.pop("output_kms_key", None) register_model_step = _RegisterModelStep( name=name, diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 75803b77e5..c85e2f3aa2 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -65,6 +65,7 @@ from sagemaker.workflow.pipeline import Pipeline from sagemaker.feature_store.feature_group import FeatureGroup, FeatureDefinition, FeatureTypeEnum from tests.integ import DATA_DIR +from tests.integ.kms_utils import get_or_create_kms_key def ordered(obj): @@ -890,6 +891,7 @@ def test_model_registration_with_model_repack( pipeline_name, region_name, ): + kms_key = get_or_create_kms_key(sagemaker_session, role) base_dir = os.path.join(DATA_DIR, "pytorch_mnist") entry_point = os.path.join(base_dir, "mnist.py") input_path = sagemaker_session.upload_data( @@ -910,6 +912,7 @@ def test_model_registration_with_model_repack( instance_count=instance_count, instance_type=instance_type, sagemaker_session=sagemaker_session, + output_kms_key=kms_key, ) step_train = TrainingStep( name="pytorch-train", @@ -921,12 +924,13 @@ def test_model_registration_with_model_repack( name="pytorch-register-model", estimator=pytorch_estimator, model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, - content_types=["*"], - response_types=["*"], - inference_instances=["*"], - transform_instances=["*"], + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.large"], + transform_instances=["ml.m5.large"], description="test-description", entry_point=entry_point, + model_kms_key=kms_key, ) model = Model(