Skip to content

Commit 086c946

Browse files
authored
feat: Support custom repack model settings (aws#4328)
1 parent 668e65d commit 086c946

File tree

3 files changed

+353
-132
lines changed

3 files changed

+353
-132
lines changed

src/sagemaker/workflow/_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ def __init__(
172172

173173
# the real estimator and inputs
174174
repacker = SKLearn(
175-
framework_version=FRAMEWORK_VERSION,
176-
instance_type=INSTANCE_TYPE,
175+
framework_version=kwargs.pop("framework_version", None) or FRAMEWORK_VERSION,
176+
instance_type=kwargs.pop("instance_type", None) or INSTANCE_TYPE,
177177
entry_point=REPACK_SCRIPT_LAUNCHER,
178178
source_dir=self._source_dir,
179179
dependencies=self._dependencies,

src/sagemaker/workflow/model_step.py

+74-10
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
_REGISTER_MODEL_NAME_BASE = "RegisterModel"
3030
_CREATE_MODEL_NAME_BASE = "CreateModel"
3131
_REPACK_MODEL_NAME_BASE = "RepackModel"
32+
_IGNORED_REPACK_PARAM_LIST = ["entry_point", "source_dir", "hyperparameters", "dependencies"]
33+
34+
logger = logging.getLogger(__name__)
3235

3336

3437
class ModelStep(StepCollection):
@@ -42,6 +45,7 @@ def __init__(
4245
retry_policies: Optional[Union[List[RetryPolicy], Dict[str, List[RetryPolicy]]]] = None,
4346
display_name: Optional[str] = None,
4447
description: Optional[str] = None,
48+
repack_model_step_settings: Optional[Dict[str, any]] = None,
4549
):
4650
"""Constructs a `ModelStep`.
4751
@@ -115,6 +119,15 @@ def __init__(
115119
display_name (str): The display name of the `ModelStep`.
116120
The display name provides better UI readability. (default: None).
117121
description (str): The description of the `ModelStep` (default: None).
122+
repack_model_step_settings (Dict[str, any]): The kwargs passed to the _RepackModelStep
123+
to customize the configuration of the underlying repack model job (default: None).
124+
Notes:
125+
1. If the _RepackModelStep is unnecessary, the settings will be ignored.
126+
2. If the _RepackModelStep is added, the repack_model_step_settings
127+
is honored if set.
128+
3. In repack_model_step_settings, the arguments with misspelled keys will be
129+
ignored. Please refer to the expected parameters of repack model job in
130+
:class:`~sagemaker.sklearn.estimator.SKLearn` and its base classes.
118131
"""
119132
from sagemaker.workflow.utilities import validate_step_args_input
120133

@@ -148,6 +161,9 @@ def __init__(
148161
self.display_name = display_name
149162
self.description = description
150163
self.steps: List[Step] = []
164+
self._repack_model_step_settings = (
165+
dict(repack_model_step_settings) if repack_model_step_settings else {}
166+
)
151167
self._model = step_args.model
152168
self._create_model_args = self.step_args.create_model_request
153169
self._register_model_args = self.step_args.create_model_package_request
@@ -157,6 +173,12 @@ def __init__(
157173

158174
if self._need_runtime_repack:
159175
self._append_repack_model_step()
176+
elif self._repack_model_step_settings:
177+
logger.warning(
178+
"Non-empty repack_model_step_settings is supplied but no repack model "
179+
"step is needed. Ignoring the repack_model_step_settings."
180+
)
181+
160182
if self._register_model_args:
161183
self._append_register_model_step()
162184
else:
@@ -235,14 +257,12 @@ def _append_repack_model_step(self):
235257
elif isinstance(self._model, Model):
236258
model_list = [self._model]
237259
else:
238-
logging.warning("No models to repack")
260+
logger.warning("No models to repack")
239261
return
240262

241-
security_group_ids = None
242-
subnets = None
243-
if self._model.vpc_config:
244-
security_group_ids = self._model.vpc_config.get("SecurityGroupIds", None)
245-
subnets = self._model.vpc_config.get("Subnets", None)
263+
self._pop_out_non_configurable_repack_model_step_args()
264+
265+
security_group_ids, subnets = self._resolve_repack_model_step_vpc_configs()
246266

247267
for i, model in enumerate(model_list):
248268
runtime_repack_flg = (
@@ -252,8 +272,16 @@ def _append_repack_model_step(self):
252272
name_base = model.name or i
253273
repack_model_step = _RepackModelStep(
254274
name="{}-{}-{}".format(self.name, _REPACK_MODEL_NAME_BASE, name_base),
255-
sagemaker_session=self._model.sagemaker_session or model.sagemaker_session,
256-
role=self._model.role or model.role,
275+
sagemaker_session=(
276+
self._repack_model_step_settings.pop("sagemaker_session", None)
277+
or self._model.sagemaker_session
278+
or model.sagemaker_session
279+
),
280+
role=(
281+
self._repack_model_step_settings.pop("role", None)
282+
or self._model.role
283+
or model.role
284+
),
257285
model_data=model.model_data,
258286
entry_point=model.entry_point,
259287
source_dir=model.source_dir,
@@ -266,8 +294,15 @@ def _append_repack_model_step(self):
266294
),
267295
depends_on=self.depends_on,
268296
retry_policies=self._repack_model_retry_policies,
269-
output_path=self._runtime_repack_output_prefix,
270-
output_kms_key=model.model_kms_key,
297+
output_path=(
298+
self._repack_model_step_settings.pop("output_path", None)
299+
or self._runtime_repack_output_prefix
300+
),
301+
output_kms_key=(
302+
self._repack_model_step_settings.pop("output_kms_key", None)
303+
or model.model_kms_key
304+
),
305+
**self._repack_model_step_settings
271306
)
272307
self.steps.append(repack_model_step)
273308

@@ -282,3 +317,32 @@ def _append_repack_model_step(self):
282317
"InferenceSpecification"
283318
]["Containers"][i]
284319
container["ModelDataUrl"] = repacked_model_data
320+
321+
def _pop_out_non_configurable_repack_model_step_args(self):
322+
"""Pop out non-configurable args from _repack_model_step_settings"""
323+
if not self._repack_model_step_settings:
324+
return
325+
for ignored_param in _IGNORED_REPACK_PARAM_LIST:
326+
if self._repack_model_step_settings.pop(ignored_param, None):
327+
logger.warning(
328+
"The repack model step parameter - %s is not configurable. Ignoring it.",
329+
ignored_param,
330+
)
331+
332+
def _resolve_repack_model_step_vpc_configs(self):
333+
"""Resolve vpc configs for repack model step"""
334+
# Note: the EstimatorBase constructor ensures that:
335+
# "When setting up custom VPC, both subnets and security_group_ids must be set"
336+
if self._repack_model_step_settings.get(
337+
"security_group_ids", None
338+
) or self._repack_model_step_settings.get("subnets", None):
339+
security_group_ids = self._repack_model_step_settings.pop("security_group_ids", None)
340+
subnets = self._repack_model_step_settings.pop("subnets", None)
341+
return security_group_ids, subnets
342+
343+
if self._model.vpc_config:
344+
security_group_ids = self._model.vpc_config.get("SecurityGroupIds", None)
345+
subnets = self._model.vpc_config.get("Subnets", None)
346+
return security_group_ids, subnets
347+
348+
return None, None

0 commit comments

Comments
 (0)