Skip to content

Commit aa7592d

Browse files
feature: Add output path parameter for _RepackModelStep
1 parent 17fe93e commit aa7592d

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
display_name: str = None,
6060
description: str = None,
6161
source_dir: str = None,
62+
repack_output_path=None,
6263
dependencies: List = None,
6364
depends_on: Union[List[str], List[Step]] = None,
6465
retry_policies: List[RetryPolicy] = None,
@@ -101,6 +102,9 @@ def __init__(
101102
or model hosting source code dependencies aside from the entry point
102103
file in the Git repo (default: None). Structure within this
103104
directory are preserved when training on Amazon SageMaker.
105+
repack_output_path (str): The S3 prefix URI where the repacked model will be
106+
uploaded (default: None) - don't include a trailing slash.
107+
If not specified, the default location is s3://default-bucket/job-name.
104108
dependencies (list[str]): A list of paths to directories (absolute
105109
or relative) with any additional libraries that will be exported
106110
to the container (default: []). The library folders will be
@@ -170,6 +174,8 @@ def __init__(
170174
},
171175
subnets=subnets,
172176
security_group_ids=security_group_ids,
177+
output_path=repack_output_path,
178+
code_location=repack_output_path,
173179
**kwargs,
174180
)
175181
repacker.disable_profiler = True

src/sagemaker/workflow/step_collections.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
estimator: EstimatorBase = None,
6464
model_data=None,
6565
depends_on: Union[List[str], List[Step]] = None,
66+
repack_output_path=None,
6667
repack_model_step_retry_policies: List[RetryPolicy] = None,
6768
register_model_step_retry_policies: List[RetryPolicy] = None,
6869
model_package_group_name=None,
@@ -91,6 +92,9 @@ def __init__(
9192
job can be run or on which an endpoint can be deployed (default: None).
9293
depends_on (List[str] or List[Step]): The list of step names or step instances
9394
the first step in the collection depends on
95+
repack_output_path (str): The S3 prefix URI where the repacked model will be
96+
uploaded (default: None) - don't include a trailing slash.
97+
If not specified, the default location is s3://default-bucket/job-name.
9498
repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
9599
for the repack model step
96100
register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
@@ -151,6 +155,7 @@ def __init__(
151155
security_group_ids=security_group_ids,
152156
description=description,
153157
display_name=display_name,
158+
repack_output_path=repack_output_path,
154159
**kwargs,
155160
)
156161
steps.append(repack_model_step)
@@ -195,6 +200,7 @@ def __init__(
195200
security_group_ids=security_group_ids,
196201
description=description,
197202
display_name=display_name,
203+
repack_output_path=repack_output_path,
198204
**kwargs,
199205
)
200206
steps.append(repack_model_step)
@@ -256,6 +262,7 @@ def __init__(
256262
image_uri=None,
257263
predictor_cls=None,
258264
env=None,
265+
repack_output_path=None,
259266
# transformer arguments
260267
strategy=None,
261268
assemble_with=None,
@@ -307,6 +314,9 @@ def __init__(
307314
it will be the format of the batch transform output.
308315
env (dict): The Environment variables to be set for use during the
309316
transform job (default: None).
317+
repack_output_path (str): The S3 prefix URI where the repacked model will be
318+
uploaded (default: None) - don't include a trailing slash.
319+
If not specified, the default location is s3://default-bucket/job-name.
310320
depends_on (List[str] or List[Step]): The list of step names or step instances
311321
the first step in the collection depends on
312322
repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
@@ -336,6 +346,7 @@ def __init__(
336346
security_group_ids=estimator.security_group_ids,
337347
description=description,
338348
display_name=display_name,
349+
repack_output_path=repack_output_path,
339350
)
340351
steps.append(repack_model_step)
341352
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts

0 commit comments

Comments
 (0)