Skip to content

Commit 530d21b

Browse files
fix: kms key does not propapate in register model step (#2481)
* fix: kms key does not propapate in register model step * fix: cr comments Co-authored-by: icywang86rui <[email protected]>
1 parent a05b10b commit 530d21b

File tree

4 files changed

+17
-5
lines changed

4 files changed

+17
-5
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,8 @@ def register(
10071007
if compile_model_family is not None:
10081008
model = self._compiled_models[compile_model_family]
10091009
else:
1010+
if "model_kms_key" not in kwargs:
1011+
kwargs["model_kms_key"] = self.output_kms_key
10101012
model = self.create_model(image_uri=image_uri, **kwargs)
10111013
model.name = model_name
10121014
return model.register(

src/sagemaker/workflow/_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
source_dir: str = None,
6161
dependencies: List = None,
6262
depends_on: List[str] = None,
63+
**kwargs,
6364
):
6465
"""Constructs a TrainingStep, given an `EstimatorBase` instance.
6566
@@ -98,6 +99,7 @@ def __init__(
9899
"inference_script": self._entry_point_basename,
99100
"model_archive": self._model_archive,
100101
},
102+
**kwargs,
101103
)
102104
repacker.disable_profiler = True
103105
inputs = TrainingInput(self._model_prefix)

src/sagemaker/workflow/step_collections.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,11 @@ def __init__(
105105
repack_model = False
106106
if "entry_point" in kwargs:
107107
repack_model = True
108-
entry_point = kwargs["entry_point"]
108+
entry_point = kwargs.pop("entry_point", None)
109109
source_dir = kwargs.get("source_dir")
110110
dependencies = kwargs.get("dependencies")
111+
kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None))
112+
111113
repack_model_step = _RepackModelStep(
112114
name=f"{name}RepackModel",
113115
depends_on=depends_on,
@@ -116,6 +118,7 @@ def __init__(
116118
entry_point=entry_point,
117119
source_dir=source_dir,
118120
dependencies=dependencies,
121+
**kwargs,
119122
)
120123
steps.append(repack_model_step)
121124
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
@@ -124,6 +127,7 @@ def __init__(
124127
kwargs.pop("entry_point", None)
125128
kwargs.pop("source_dir", None)
126129
kwargs.pop("dependencies", None)
130+
kwargs.pop("output_kms_key", None)
127131

128132
register_model_step = _RegisterModelStep(
129133
name=name,

tests/integ/test_workflow.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from sagemaker.workflow.pipeline import Pipeline
6666
from sagemaker.feature_store.feature_group import FeatureGroup, FeatureDefinition, FeatureTypeEnum
6767
from tests.integ import DATA_DIR
68+
from tests.integ.kms_utils import get_or_create_kms_key
6869

6970

7071
def ordered(obj):
@@ -890,6 +891,7 @@ def test_model_registration_with_model_repack(
890891
pipeline_name,
891892
region_name,
892893
):
894+
kms_key = get_or_create_kms_key(sagemaker_session, role)
893895
base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
894896
entry_point = os.path.join(base_dir, "mnist.py")
895897
input_path = sagemaker_session.upload_data(
@@ -910,6 +912,7 @@ def test_model_registration_with_model_repack(
910912
instance_count=instance_count,
911913
instance_type=instance_type,
912914
sagemaker_session=sagemaker_session,
915+
output_kms_key=kms_key,
913916
)
914917
step_train = TrainingStep(
915918
name="pytorch-train",
@@ -921,12 +924,13 @@ def test_model_registration_with_model_repack(
921924
name="pytorch-register-model",
922925
estimator=pytorch_estimator,
923926
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
924-
content_types=["*"],
925-
response_types=["*"],
926-
inference_instances=["*"],
927-
transform_instances=["*"],
927+
content_types=["text/csv"],
928+
response_types=["text/csv"],
929+
inference_instances=["ml.t2.medium", "ml.m5.large"],
930+
transform_instances=["ml.m5.large"],
928931
description="test-description",
929932
entry_point=entry_point,
933+
model_kms_key=kms_key,
930934
)
931935

932936
model = Model(

0 commit comments

Comments
 (0)