@@ -112,8 +112,8 @@ def __init__(
112
112
if "entry_point" in kwargs :
113
113
repack_model = True
114
114
entry_point = kwargs .pop ("entry_point" , None )
115
- source_dir = kwargs .get ("source_dir" )
116
- dependencies = kwargs .get ("dependencies" )
115
+ source_dir = kwargs .pop ("source_dir" , None )
116
+ dependencies = kwargs .pop ("dependencies" , None )
117
117
kwargs = dict (** kwargs , output_kms_key = kwargs .pop ("model_kms_key" , None ))
118
118
119
119
repack_model_step = _RepackModelStep (
@@ -130,13 +130,10 @@ def __init__(
130
130
steps .append (repack_model_step )
131
131
model_data = repack_model_step .properties .ModelArtifacts .S3ModelArtifacts
132
132
133
- # remove kwargs consumed by model repacking step
134
- kwargs .pop ("entry_point" , None )
135
- kwargs .pop ("source_dir" , None )
136
- kwargs .pop ("dependencies" , None )
137
- kwargs .pop ("output_kms_key" , None )
133
+ # remove kwargs consumed by model repacking step
134
+ kwargs .pop ("output_kms_key" , None )
138
135
139
- if model is not None :
136
+ elif model is not None :
140
137
if isinstance (model , PipelineModel ):
141
138
self .model_list = model .models
142
139
self .container_def_list = model .pipeline_container_def (inference_instances [0 ])
@@ -156,7 +153,9 @@ def __init__(
156
153
entry_point = model_entity .entry_point
157
154
source_dir = model_entity .source_dir
158
155
dependencies = model_entity .dependencies
156
+ kwargs = dict (** kwargs , output_kms_key = model_entity .model_kms_key )
159
157
name = model_entity .name or model_entity ._framework_name
158
+
160
159
repack_model_step = _RepackModelStep (
161
160
name = f"{ name } RepackModel" ,
162
161
depends_on = depends_on ,
@@ -166,12 +165,16 @@ def __init__(
166
165
entry_point = entry_point ,
167
166
source_dir = source_dir ,
168
167
dependencies = dependencies ,
168
+ ** kwargs ,
169
169
)
170
170
steps .append (repack_model_step )
171
171
model_entity .model_data = (
172
172
repack_model_step .properties .ModelArtifacts .S3ModelArtifacts
173
173
)
174
174
175
+ # remove kwargs consumed by model repacking step
176
+ kwargs .pop ("output_kms_key" , None )
177
+
175
178
register_model_step = _RegisterModelStep (
176
179
name = name ,
177
180
estimator = estimator ,
0 commit comments