Skip to content

Commit d6d8e08

Browse files
fix: kms key does not propapate in register model step
1 parent dd76ad7 commit d6d8e08

File tree

4 files changed

+16
-5
lines changed

4 files changed

+16
-5
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,7 @@ def register(
10071007
if compile_model_family is not None:
10081008
model = self._compiled_models[compile_model_family]
10091009
else:
1010+
kwargs["model_kms_key"] = self.output_kms_key
10101011
model = self.create_model(image_uri=image_uri, **kwargs)
10111012
model.name = model_name
10121013
return model.register(

src/sagemaker/workflow/_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
source_dir: str = None,
6161
dependencies: List = None,
6262
depends_on: List[str] = None,
63+
**kwargs,
6364
):
6465
"""Constructs a TrainingStep, given an `EstimatorBase` instance.
6566
@@ -98,6 +99,7 @@ def __init__(
9899
"inference_script": self._entry_point_basename,
99100
"model_archive": self._model_archive,
100101
},
102+
**kwargs,
101103
)
102104
repacker.disable_profiler = True
103105
inputs = TrainingInput(self._model_prefix)

src/sagemaker/workflow/step_collections.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,11 @@ def __init__(
105105
repack_model = False
106106
if "entry_point" in kwargs:
107107
repack_model = True
108-
entry_point = kwargs["entry_point"]
108+
entry_point = kwargs.pop("entry_point", None)
109109
source_dir = kwargs.get("source_dir")
110110
dependencies = kwargs.get("dependencies")
111+
kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None))
112+
111113
repack_model_step = _RepackModelStep(
112114
name=f"{name}RepackModel",
113115
depends_on=depends_on,
@@ -116,6 +118,7 @@ def __init__(
116118
entry_point=entry_point,
117119
source_dir=source_dir,
118120
dependencies=dependencies,
121+
**kwargs,
119122
)
120123
steps.append(repack_model_step)
121124
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
@@ -124,6 +127,7 @@ def __init__(
124127
kwargs.pop("entry_point", None)
125128
kwargs.pop("source_dir", None)
126129
kwargs.pop("dependencies", None)
130+
kwargs.pop("output_kms_key", None)
127131

128132
register_model_step = _RegisterModelStep(
129133
name=name,

tests/integ/test_workflow.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from sagemaker.workflow.pipeline import Pipeline
6666
from sagemaker.feature_store.feature_group import FeatureGroup, FeatureDefinition, FeatureTypeEnum
6767
from tests.integ import DATA_DIR
68+
from tests.integ.kms_utils import get_or_create_kms_key
6869

6970

7071
def ordered(obj):
@@ -849,6 +850,7 @@ def test_model_registration_with_model_repack(
849850
pipeline_name,
850851
region_name,
851852
):
853+
kms_key = get_or_create_kms_key(sagemaker_session, role)
852854
base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
853855
entry_point = os.path.join(base_dir, "mnist.py")
854856
input_path = sagemaker_session.upload_data(
@@ -869,6 +871,7 @@ def test_model_registration_with_model_repack(
869871
instance_count=instance_count,
870872
instance_type=instance_type,
871873
sagemaker_session=sagemaker_session,
874+
output_kms_key=kms_key,
872875
)
873876
step_train = TrainingStep(
874877
name="pytorch-train",
@@ -880,12 +883,13 @@ def test_model_registration_with_model_repack(
880883
name="pytorch-register-model",
881884
estimator=pytorch_estimator,
882885
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
883-
content_types=["*"],
884-
response_types=["*"],
885-
inference_instances=["*"],
886-
transform_instances=["*"],
886+
content_types=["text/csv"],
887+
response_types=["text/csv"],
888+
inference_instances=["ml.t2.medium", "ml.m5.large"],
889+
transform_instances=["ml.m5.large"],
887890
description="test-description",
888891
entry_point=entry_point,
892+
model_kms_key=kms_key,
889893
)
890894

891895
model = Model(

0 commit comments

Comments
 (0)