Skip to content

Commit deeb7c1

Browse files
committed
Implement subclass compatibility for workflow pipeline job steps
1 parent a3eea2b commit deeb7c1

20 files changed

+1439
-161
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from sagemaker.estimator import EstimatorBase, _TrainingJob
2828
from sagemaker.inputs import FileSystemInput, TrainingInput
2929
from sagemaker.utils import sagemaker_timestamp
30+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3031

3132
logger = logging.getLogger(__name__)
3233

@@ -192,6 +193,7 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
192193
self.feature_dim = feature_dim
193194
self.mini_batch_size = mini_batch_size
194195

196+
@runnable_by_pipeline
195197
def fit(
196198
self,
197199
records,

src/sagemaker/clarify.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,8 @@ def _run(
803803
output_name="analysis_result",
804804
s3_upload_mode="EndOfJob",
805805
)
806-
super().run(
806+
807+
return super().run(
807808
inputs=[data_input, config_input],
808809
outputs=[result_output],
809810
wait=wait,
@@ -871,7 +872,7 @@ def run_pre_training_bias(
871872
job_name = utils.name_from_base(self.job_name_prefix)
872873
else:
873874
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
874-
self._run(
875+
return self._run(
875876
data_config,
876877
analysis_config,
877878
wait,
@@ -957,7 +958,7 @@ def run_post_training_bias(
957958
job_name = utils.name_from_base(self.job_name_prefix)
958959
else:
959960
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
960-
self._run(
961+
return self._run(
961962
data_config,
962963
analysis_config,
963964
wait,
@@ -1060,7 +1061,7 @@ def run_bias(
10601061
job_name = utils.name_from_base(self.job_name_prefix)
10611062
else:
10621063
job_name = utils.name_from_base("Clarify-Bias")
1063-
self._run(
1064+
return self._run(
10641065
data_config,
10651066
analysis_config,
10661067
wait,
@@ -1167,7 +1168,7 @@ def run_explainability(
11671168
job_name = utils.name_from_base(self.job_name_prefix)
11681169
else:
11691170
job_name = utils.name_from_base("Clarify-Explainability")
1170-
self._run(
1171+
return self._run(
11711172
data_config,
11721173
analysis_config,
11731174
wait,

src/sagemaker/estimator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,12 @@
7474
get_config_value,
7575
name_from_base,
7676
)
77+
7778
from sagemaker.workflow.entities import PipelineVariable
79+
from sagemaker.workflow.pipeline_context import (
80+
PipelineSession,
81+
runnable_by_pipeline,
82+
)
7883

7984
logger = logging.getLogger(__name__)
8085

@@ -896,6 +901,7 @@ def latest_job_profiler_artifacts_path(self):
896901
)
897902
return None
898903

904+
@runnable_by_pipeline
899905
def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None):
900906
"""Train a model using the input training dataset.
901907
@@ -1341,7 +1347,9 @@ def register(
13411347
@property
13421348
def model_data(self):
13431349
"""str: The model location in S3. Only set if Estimator has been ``fit()``."""
1344-
if self.latest_training_job is not None:
1350+
if self.latest_training_job is not None and not isinstance(
1351+
self.sagemaker_session, PipelineSession
1352+
):
13451353
model_uri = self.sagemaker_session.sagemaker_client.describe_training_job(
13461354
TrainingJobName=self.latest_training_job.name
13471355
)["ModelArtifacts"]["S3ModelArtifacts"]
@@ -1767,6 +1775,7 @@ def start_new(cls, estimator, inputs, experiment_config):
17671775
all information about the started training job.
17681776
"""
17691777
train_args = cls._get_train_args(estimator, inputs, experiment_config)
1778+
17701779
estimator.sagemaker_session.train(**train_args)
17711780

17721781
return cls(estimator.sagemaker_session, estimator._current_job_name)

src/sagemaker/jumpstart/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,3 @@
126126
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
127127

128128
JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"
129-

src/sagemaker/processing.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,19 @@
2828

2929
from six.moves.urllib.parse import urlparse
3030
from six.moves.urllib.request import url2pathname
31-
3231
from sagemaker import s3
3332
from sagemaker.job import _Job
3433
from sagemaker.local import LocalSession
3534
from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
3635
from sagemaker.session import Session
37-
from sagemaker.workflow.properties import Properties
38-
from sagemaker.workflow.parameters import Parameter
39-
from sagemaker.workflow.entities import Expression
36+
from sagemaker.workflow.pipeline_context import (
37+
runnable_by_pipeline,
38+
is_pipeline_entities,
39+
)
4040
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
4141
from sagemaker.apiutils._base_types import ApiObject
4242
from sagemaker.s3 import S3Uploader
4343

44-
4544
logger = logging.getLogger(__name__)
4645

4746

@@ -133,6 +132,7 @@ def __init__(
133132

134133
self.sagemaker_session = sagemaker_session or Session()
135134

135+
@runnable_by_pipeline
136136
def run(
137137
self,
138138
inputs=None,
@@ -308,10 +308,10 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
308308
if file_input.input_name is None:
309309
file_input.input_name = "input-{}".format(count)
310310

311-
if isinstance(file_input.source, Properties) or file_input.dataset_definition:
311+
if is_pipeline_entities(file_input.source) or file_input.dataset_definition:
312312
normalized_inputs.append(file_input)
313313
continue
314-
if isinstance(file_input.s3_input.s3_uri, (Parameter, Expression, Properties)):
314+
if is_pipeline_entities(file_input.s3_input.s3_uri):
315315
normalized_inputs.append(file_input)
316316
continue
317317
# If the source is a local path, upload it to S3
@@ -361,7 +361,7 @@ def _normalize_outputs(self, outputs=None):
361361
# Generate a name for the ProcessingOutput if it doesn't have one.
362362
if output.output_name is None:
363363
output.output_name = "output-{}".format(count)
364-
if isinstance(output.destination, (Parameter, Expression, Properties)):
364+
if is_pipeline_entities(output.destination):
365365
normalized_outputs.append(output)
366366
continue
367367
# If the output's destination is not an s3_uri, create one.
@@ -491,6 +491,7 @@ def get_run_args(
491491
"""
492492
return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)
493493

494+
@runnable_by_pipeline
494495
def run(
495496
self,
496497
code,
@@ -1594,7 +1595,7 @@ def run( # type: ignore[override]
15941595
)
15951596

15961597
# Submit a processing job.
1597-
super().run(
1598+
return super().run(
15981599
code=s3_runproc_sh,
15991600
inputs=inputs,
16001601
outputs=outputs,

src/sagemaker/session.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import re
2020
import sys
2121
import time
22+
import typing
2223
import warnings
2324
from typing import List, Dict, Any, Sequence
2425

@@ -551,7 +552,6 @@ def train( # noqa: C901
551552
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
552553
* max_retry_attempts (int): Number of times a job should be retried.
553554
The key in RetryStrategy is 'MaxRetryAttempts'.
554-
555555
Returns:
556556
str: ARN of the training job, if it is created.
557557
"""
@@ -585,9 +585,13 @@ def train( # noqa: C901
585585
environment=environment,
586586
retry_strategy=retry_strategy,
587587
)
588-
LOGGER.info("Creating training-job with name: %s", job_name)
589-
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
590-
self.sagemaker_client.create_training_job(**train_request)
588+
589+
def submit(request):
590+
LOGGER.info("Creating training-job with name: %s", job_name)
591+
LOGGER.debug("train request: %s", json.dumps(request, indent=4))
592+
self.sagemaker_client.create_training_job(**request)
593+
594+
self._intercept_create_request(train_request, submit)
591595

592596
def _get_train_request( # noqa: C901
593597
self,
@@ -910,9 +914,13 @@ def process(
910914
tags=tags,
911915
experiment_config=experiment_config,
912916
)
913-
LOGGER.info("Creating processing-job with name %s", job_name)
914-
LOGGER.debug("process request: %s", json.dumps(process_request, indent=4))
915-
self.sagemaker_client.create_processing_job(**process_request)
917+
918+
def submit(request):
919+
LOGGER.info("Creating processing-job with name %s", job_name)
920+
LOGGER.debug("process request: %s", json.dumps(request, indent=4))
921+
self.sagemaker_client.create_processing_job(**request)
922+
923+
self._intercept_create_request(process_request, submit)
916924

917925
def _get_process_request(
918926
self,
@@ -2084,9 +2092,12 @@ def create_tuning_job(
20842092
tags=tags,
20852093
)
20862094

2087-
LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
2088-
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
2089-
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
2095+
def submit(request):
2096+
LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
2097+
LOGGER.debug("tune request: %s", json.dumps(request, indent=4))
2098+
self.sagemaker_client.create_hyper_parameter_tuning_job(**request)
2099+
2100+
self._intercept_create_request(tune_request, submit)
20902101

20912102
def _get_tuning_request(
20922103
self,
@@ -2547,9 +2558,12 @@ def transform(
25472558
model_client_config=model_client_config,
25482559
)
25492560

2550-
LOGGER.info("Creating transform job with name: %s", job_name)
2551-
LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4))
2552-
self.sagemaker_client.create_transform_job(**transform_request)
2561+
def submit(request):
2562+
LOGGER.info("Creating transform job with name: %s", job_name)
2563+
LOGGER.debug("Transform request: %s", json.dumps(request, indent=4))
2564+
self.sagemaker_client.create_transform_job(**request)
2565+
2566+
self._intercept_create_request(transform_request, submit)
25532567

25542568
def _create_model_request(
25552569
self,
@@ -4155,6 +4169,16 @@ def account_id(self) -> str:
41554169
)
41564170
return sts_client.get_caller_identity()["Account"]
41574171

4172+
def _intercept_create_request(self, request: typing.Dict, create):
4173+
"""
4174+
This function intercepts the create job request
4175+
4176+
Args:
4177+
request (dict): the create job request
4178+
create (functor): a functor that calls the sagemaker client create method
4179+
"""
4180+
create(**request)
4181+
41584182

41594183
def get_model_package_args(
41604184
content_types,

src/sagemaker/spark/processing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def run(
249249
"""
250250
self._current_job_name = self._generate_current_job_name(job_name=job_name)
251251

252-
super().run(
252+
return super().run(
253253
submit_app,
254254
inputs,
255255
outputs,
@@ -868,7 +868,7 @@ def run(
868868
spark_event_logs_s3_uri=spark_event_logs_s3_uri,
869869
)
870870

871-
super().run(
871+
return super().run(
872872
submit_app=submit_app,
873873
inputs=extended_inputs,
874874
outputs=extended_outputs,
@@ -1125,7 +1125,7 @@ def run(
11251125
spark_event_logs_s3_uri=spark_event_logs_s3_uri,
11261126
)
11271127

1128-
super().run(
1128+
return super().run(
11291129
submit_app=submit_app,
11301130
inputs=extended_inputs,
11311131
outputs=extended_outputs,

src/sagemaker/transformer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
from sagemaker.job import _Job
1919
from sagemaker.session import Session
20+
from sagemaker.workflow.pipeline_context import (
21+
runnable_by_pipeline,
22+
is_pipeline_entities,
23+
)
2024
from sagemaker.utils import base_name_from_image, name_from_base
2125

2226

@@ -106,6 +110,7 @@ def __init__(
106110

107111
self.sagemaker_session = sagemaker_session or Session()
108112

113+
@runnable_by_pipeline
109114
def transform(
110115
self,
111116
data,
@@ -197,7 +202,11 @@ def transform(
197202
base_name = self.base_transform_job_name
198203

199204
if base_name is None:
200-
base_name = self._retrieve_base_name()
205+
base_name = (
206+
"transform-job"
207+
if is_pipeline_entities(self.model_name)
208+
else self._retrieve_base_name()
209+
)
201210

202211
self._current_job_name = name_from_base(base_name)
203212

@@ -370,6 +379,7 @@ def start_new(
370379
experiment_config,
371380
model_client_config,
372381
)
382+
373383
transformer.sagemaker_session.transform(**transform_args)
374384

375385
return cls(transformer.sagemaker_session, transformer._current_job_name)

0 commit comments

Comments
 (0)