29
29
_REGISTER_MODEL_NAME_BASE = "RegisterModel"
30
30
_CREATE_MODEL_NAME_BASE = "CreateModel"
31
31
_REPACK_MODEL_NAME_BASE = "RepackModel"
32
+ _IGNORED_REPACK_PARAM_LIST = ["entry_point" , "source_dir" , "hyperparameters" , "dependencies" ]
33
+
34
+ logger = logging .getLogger (__name__ )
32
35
33
36
34
37
class ModelStep (StepCollection ):
@@ -42,6 +45,7 @@ def __init__(
42
45
retry_policies : Optional [Union [List [RetryPolicy ], Dict [str , List [RetryPolicy ]]]] = None ,
43
46
display_name : Optional [str ] = None ,
44
47
description : Optional [str ] = None ,
48
+ repack_model_step_settings : Optional [Dict [str , any ]] = None ,
45
49
):
46
50
"""Constructs a `ModelStep`.
47
51
@@ -115,6 +119,15 @@ def __init__(
115
119
display_name (str): The display name of the `ModelStep`.
116
120
The display name provides better UI readability. (default: None).
117
121
description (str): The description of the `ModelStep` (default: None).
122
+ repack_model_step_settings (Dict[str, any]): The kwargs passed to the _RepackModelStep
123
+ to customize the configuration of the underlying repack model job (default: None).
124
+ Notes:
125
+ 1. If the _RepackModelStep is unnecessary, the settings will be ignored.
126
+ 2. If the _RepackModelStep is added, the repack_model_step_settings
127
+ is honored if set.
128
+ 3. In repack_model_step_settings, the arguments with misspelled keys will be
129
+ ignored. Please refer to the expected parameters of repack model job in
130
+ :class:`~sagemaker.sklearn.estimator.SKLearn` and its base classes.
118
131
"""
119
132
from sagemaker .workflow .utilities import validate_step_args_input
120
133
@@ -148,6 +161,9 @@ def __init__(
148
161
self .display_name = display_name
149
162
self .description = description
150
163
self .steps : List [Step ] = []
164
+ self ._repack_model_step_settings = (
165
+ dict (repack_model_step_settings ) if repack_model_step_settings else {}
166
+ )
151
167
self ._model = step_args .model
152
168
self ._create_model_args = self .step_args .create_model_request
153
169
self ._register_model_args = self .step_args .create_model_package_request
@@ -157,6 +173,12 @@ def __init__(
157
173
158
174
if self ._need_runtime_repack :
159
175
self ._append_repack_model_step ()
176
+ elif self ._repack_model_step_settings :
177
+ logger .warning (
178
+ "Non-empty repack_model_step_settings is supplied but no repack model "
179
+ "step is needed. Ignoring the repack_model_step_settings."
180
+ )
181
+
160
182
if self ._register_model_args :
161
183
self ._append_register_model_step ()
162
184
else :
@@ -235,14 +257,12 @@ def _append_repack_model_step(self):
235
257
elif isinstance (self ._model , Model ):
236
258
model_list = [self ._model ]
237
259
else :
238
- logging .warning ("No models to repack" )
260
+ logger .warning ("No models to repack" )
239
261
return
240
262
241
- security_group_ids = None
242
- subnets = None
243
- if self ._model .vpc_config :
244
- security_group_ids = self ._model .vpc_config .get ("SecurityGroupIds" , None )
245
- subnets = self ._model .vpc_config .get ("Subnets" , None )
263
+ self ._pop_out_non_configurable_repack_model_step_args ()
264
+
265
+ security_group_ids , subnets = self ._resolve_repack_model_step_vpc_configs ()
246
266
247
267
for i , model in enumerate (model_list ):
248
268
runtime_repack_flg = (
@@ -252,8 +272,16 @@ def _append_repack_model_step(self):
252
272
name_base = model .name or i
253
273
repack_model_step = _RepackModelStep (
254
274
name = "{}-{}-{}" .format (self .name , _REPACK_MODEL_NAME_BASE , name_base ),
255
- sagemaker_session = self ._model .sagemaker_session or model .sagemaker_session ,
256
- role = self ._model .role or model .role ,
275
+ sagemaker_session = (
276
+ self ._repack_model_step_settings .pop ("sagemaker_session" , None )
277
+ or self ._model .sagemaker_session
278
+ or model .sagemaker_session
279
+ ),
280
+ role = (
281
+ self ._repack_model_step_settings .pop ("role" , None )
282
+ or self ._model .role
283
+ or model .role
284
+ ),
257
285
model_data = model .model_data ,
258
286
entry_point = model .entry_point ,
259
287
source_dir = model .source_dir ,
@@ -266,8 +294,15 @@ def _append_repack_model_step(self):
266
294
),
267
295
depends_on = self .depends_on ,
268
296
retry_policies = self ._repack_model_retry_policies ,
269
- output_path = self ._runtime_repack_output_prefix ,
270
- output_kms_key = model .model_kms_key ,
297
+ output_path = (
298
+ self ._repack_model_step_settings .pop ("output_path" , None )
299
+ or self ._runtime_repack_output_prefix
300
+ ),
301
+ output_kms_key = (
302
+ self ._repack_model_step_settings .pop ("output_kms_key" , None )
303
+ or model .model_kms_key
304
+ ),
305
+ ** self ._repack_model_step_settings
271
306
)
272
307
self .steps .append (repack_model_step )
273
308
@@ -282,3 +317,32 @@ def _append_repack_model_step(self):
282
317
"InferenceSpecification"
283
318
]["Containers" ][i ]
284
319
container ["ModelDataUrl" ] = repacked_model_data
320
+
321
+ def _pop_out_non_configurable_repack_model_step_args (self ):
322
+ """Pop out non-configurable args from _repack_model_step_settings"""
323
+ if not self ._repack_model_step_settings :
324
+ return
325
+ for ignored_param in _IGNORED_REPACK_PARAM_LIST :
326
+ if self ._repack_model_step_settings .pop (ignored_param , None ):
327
+ logger .warning (
328
+ "The repack model step parameter - %s is not configurable. Ignoring it." ,
329
+ ignored_param ,
330
+ )
331
+
332
+ def _resolve_repack_model_step_vpc_configs (self ):
333
+ """Resolve vpc configs for repack model step"""
334
+ # Note: the EstimatorBase constructor ensures that:
335
+ # "When setting up custom VPC, both subnets and security_group_ids must be set"
336
+ if self ._repack_model_step_settings .get (
337
+ "security_group_ids" , None
338
+ ) or self ._repack_model_step_settings .get ("subnets" , None ):
339
+ security_group_ids = self ._repack_model_step_settings .pop ("security_group_ids" , None )
340
+ subnets = self ._repack_model_step_settings .pop ("subnets" , None )
341
+ return security_group_ids , subnets
342
+
343
+ if self ._model .vpc_config :
344
+ security_group_ids = self ._model .vpc_config .get ("SecurityGroupIds" , None )
345
+ subnets = self ._model .vpc_config .get ("Subnets" , None )
346
+ return security_group_ids , subnets
347
+
348
+ return None , None
0 commit comments