Skip to content

Commit c50a4d9

Browse files
authored
fix: Support transformer data parameterization (aws#3145)
1 parent 4acbdb0 commit c50a4d9

File tree

8 files changed: 24 additions (+24) and 15 deletions (-15).

src/sagemaker/model.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -501,13 +501,14 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
501501
if is_pipeline_variable(self.model_data):
502502
# model is not yet there, defer repacking to later during pipeline execution
503503
if not isinstance(self.sagemaker_session, PipelineSession):
504-
# TODO: link the doc in the warning once ready
505504
logging.warning(
506505
"The model_data is a Pipeline variable of type %s, "
507506
"which should be used under `PipelineSession` and "
508507
"leverage `ModelStep` to create or register model. "
509508
"Otherwise some functionalities e.g. "
510-
"runtime repack may be missing",
509+
"runtime repack may be missing. For more, see: "
510+
"https://sagemaker.readthedocs.io/en/stable/"
511+
"amazon_sagemaker_model_building_pipeline.html#model-step",
511512
type(self.model_data),
512513
)
513514
return

src/sagemaker/pipeline.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -278,6 +278,7 @@ def register(
278278
description: Optional[str] = None,
279279
drift_check_baselines: Optional[DriftCheckBaselines] = None,
280280
customer_metadata_properties: Optional[Dict[str, str]] = None,
281+
domain: Optional[str] = None,
281282
):
282283
"""Creates a model package for creating SageMaker models or listing on Marketplace.
283284
@@ -305,6 +306,8 @@ def register(
305306
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
306307
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
307308
metadata properties (default: None).
309+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
310+
"MACHINE_LEARNING" (default: None).
308311
309312
Returns:
310313
A `sagemaker.model.ModelPackage` instance.
@@ -335,6 +338,7 @@ def register(
335338
container_def_list=container_def,
336339
drift_check_baselines=drift_check_baselines,
337340
customer_metadata_properties=customer_metadata_properties,
341+
domain=domain,
338342
)
339343

340344
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)

src/sagemaker/tensorflow/model.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -362,13 +362,14 @@ def prepare_container_def(
362362
if isinstance(self.sagemaker_session, PipelineSession):
363363
self.sagemaker_session.context.need_runtime_repack.add(id(self))
364364
else:
365-
# TODO: link the doc in the warning once ready
366365
logging.warning(
367366
"The model_data is a Pipeline variable of type %s, "
368367
"which should be used under `PipelineSession` and "
369368
"leverage `ModelStep` to create or register model. "
370369
"Otherwise some functionalities e.g. "
371-
"runtime repack may be missing",
370+
"runtime repack may be missing. For more, see: "
371+
"https://sagemaker.readthedocs.io/en/stable/"
372+
"amazon_sagemaker_model_building_pipeline.html#model-step",
372373
type(self.model_data),
373374
)
374375
model_data = self.model_data

src/sagemaker/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -191,7 +191,7 @@ def transform(
191191
Only meaningful when wait is ``True`` (default: ``True``).
192192
"""
193193
local_mode = self.sagemaker_session.local_mode
194-
if not local_mode and not data.startswith("s3://"):
194+
if not local_mode and not is_pipeline_variable(data) and not data.startswith("s3://"):
195195
raise ValueError("Invalid S3 URI: {}".format(data))
196196

197197
if job_name is not None:

src/sagemaker/workflow/model_step.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,6 @@ def __init__(
111111
The display name provides better UI readability. (default: None).
112112
description (str): The description of the `ModelStep` (default: None).
113113
"""
114-
# TODO: add a doc link in error message once ready
115114
from sagemaker.workflow.utilities import validate_step_args_input
116115

117116
validate_step_args_input(
@@ -121,7 +120,8 @@ def __init__(
121120
Session.create_model_package_from_containers.__name__,
122121
},
123122
error_message="The step_args of ModelStep must be obtained from model.create() "
124-
"or model.register().",
123+
"or model.register(). For more, see: https://sagemaker.readthedocs.io/en/stable/"
124+
"amazon_sagemaker_model_building_pipeline.html#model-step",
125125
)
126126
if not (step_args.create_model_request is None) ^ (
127127
step_args.create_model_package_request is None

src/sagemaker/workflow/step_collections.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -253,12 +253,13 @@ def __init__(
253253
steps.append(register_model_step)
254254
self.steps = steps
255255

256-
# TODO: add public document link here once ready
257256
warnings.warn(
258257
(
259258
"We are deprecating the use of RegisterModel. "
260259
"Instead, please use the ModelStep, which simply takes in the step arguments "
261-
"generated by model.register()."
260+
"generated by model.register(). For more, see: "
261+
"https://sagemaker.readthedocs.io/en/stable/"
262+
"amazon_sagemaker_model_building_pipeline.html#model-step"
262263
),
263264
DeprecationWarning,
264265
)

src/sagemaker/workflow/steps.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -444,12 +444,13 @@ def __init__(
444444

445445
self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeModelOutput")
446446

447-
# TODO: add public document link here once ready
448447
warnings.warn(
449448
(
450449
"We are deprecating the use of CreateModelStep. "
451450
"Instead, please use the ModelStep, which simply takes in the step arguments "
452-
"generated by model.create()."
451+
"generated by model.create(). For more, see: "
452+
"https://sagemaker.readthedocs.io/en/stable/"
453+
"amazon_sagemaker_model_building_pipeline.html#model-step"
453454
),
454455
DeprecationWarning,
455456
)

tests/unit/sagemaker/workflow/test_transform_step.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -89,10 +89,10 @@ def test_transform_step_with_transformer(pipeline_session):
8989
output_path=f"s3://{pipeline_session.default_bucket()}/Transform",
9090
sagemaker_session=pipeline_session,
9191
)
92-
93-
transform_inputs = TransformInput(
94-
data=f"s3://{pipeline_session.default_bucket()}/batch-data",
92+
data = ParameterString(
93+
name="Data", default_value=f"s3://{pipeline_session.default_bucket()}/batch-data"
9594
)
95+
transform_inputs = TransformInput(data=data)
9696

9797
with warnings.catch_warnings(record=True) as w:
9898
step_args = transformer.transform(
@@ -120,10 +120,11 @@ def test_transform_step_with_transformer(pipeline_session):
120120
pipeline = Pipeline(
121121
name="MyPipeline",
122122
steps=[step],
123-
parameters=[model_name],
123+
parameters=[model_name, data],
124124
sagemaker_session=pipeline_session,
125125
)
126126
step_args.args["ModelName"] = model_name.expr
127+
step_args.args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = data.expr
127128
assert json.loads(pipeline.definition())["Steps"][0] == {
128129
"Name": "MyTransformStep",
129130
"Type": "Transform",

0 commit comments

Comments
 (0)