Skip to content

Commit 02619cb

Browse files
committed
Implement subclass compatibility for workflow pipeline job steps
1 parent 0e4fd55 commit 02619cb

20 files changed

+1440
-175
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from sagemaker.estimator import EstimatorBase, _TrainingJob
2828
from sagemaker.inputs import FileSystemInput, TrainingInput
2929
from sagemaker.utils import sagemaker_timestamp
30+
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3031

3132
logger = logging.getLogger(__name__)
3233

@@ -192,6 +193,7 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
192193
self.feature_dim = feature_dim
193194
self.mini_batch_size = mini_batch_size
194195

196+
@runnable_by_pipeline
195197
def fit(
196198
self,
197199
records,

src/sagemaker/clarify.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,8 @@ def _run(
803803
output_name="analysis_result",
804804
s3_upload_mode="EndOfJob",
805805
)
806-
super().run(
806+
807+
return super().run(
807808
inputs=[data_input, config_input],
808809
outputs=[result_output],
809810
wait=wait,
@@ -871,7 +872,7 @@ def run_pre_training_bias(
871872
job_name = utils.name_from_base(self.job_name_prefix)
872873
else:
873874
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
874-
self._run(
875+
return self._run(
875876
data_config,
876877
analysis_config,
877878
wait,
@@ -957,7 +958,7 @@ def run_post_training_bias(
957958
job_name = utils.name_from_base(self.job_name_prefix)
958959
else:
959960
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
960-
self._run(
961+
return self._run(
961962
data_config,
962963
analysis_config,
963964
wait,
@@ -1060,7 +1061,7 @@ def run_bias(
10601061
job_name = utils.name_from_base(self.job_name_prefix)
10611062
else:
10621063
job_name = utils.name_from_base("Clarify-Bias")
1063-
self._run(
1064+
return self._run(
10641065
data_config,
10651066
analysis_config,
10661067
wait,
@@ -1167,7 +1168,7 @@ def run_explainability(
11671168
job_name = utils.name_from_base(self.job_name_prefix)
11681169
else:
11691170
job_name = utils.name_from_base("Clarify-Explainability")
1170-
self._run(
1171+
return self._run(
11711172
data_config,
11721173
analysis_config,
11731174
wait,

src/sagemaker/estimator.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,12 @@
7373
get_config_value,
7474
name_from_base,
7575
)
76-
from sagemaker.workflow.entities import Expression
77-
from sagemaker.workflow.parameters import Parameter
78-
from sagemaker.workflow.properties import Properties
76+
77+
from sagemaker.workflow.pipeline_context import (
78+
PipelineSession,
79+
runnable_by_pipeline,
80+
is_pipeline_entities
81+
)
7982

8083
logger = logging.getLogger(__name__)
8184

@@ -598,7 +601,7 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A
598601
current_hyperparameters = hyperparameters
599602
if current_hyperparameters is not None:
600603
hyperparameters = {
601-
str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else json.dumps(v))
604+
str(k): (v if is_pipeline_entities(v) else json.dumps(v))
602605
for (k, v) in current_hyperparameters.items()
603606
}
604607
return hyperparameters
@@ -894,6 +897,7 @@ def latest_job_profiler_artifacts_path(self):
894897
)
895898
return None
896899

900+
@runnable_by_pipeline
897901
def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None):
898902
"""Train a model using the input training dataset.
899903
@@ -1331,7 +1335,7 @@ def register(
13311335
@property
13321336
def model_data(self):
13331337
"""str: The model location in S3. Only set if Estimator has been ``fit()``."""
1334-
if self.latest_training_job is not None:
1338+
if self.latest_training_job is not None and type(self.sagemaker_session) is not PipelineSession:
13351339
model_uri = self.sagemaker_session.sagemaker_client.describe_training_job(
13361340
TrainingJobName=self.latest_training_job.name
13371341
)["ModelArtifacts"]["S3ModelArtifacts"]
@@ -1757,6 +1761,9 @@ def start_new(cls, estimator, inputs, experiment_config):
17571761
all information about the started training job.
17581762
"""
17591763
train_args = cls._get_train_args(estimator, inputs, experiment_config)
1764+
if type(estimator.sagemaker_session) is PipelineSession:
1765+
train_args['pipeline_session'] = estimator.sagemaker_session
1766+
17601767
estimator.sagemaker_session.train(**train_args)
17611768

17621769
return cls(estimator.sagemaker_session, estimator._current_job_name)
@@ -1801,7 +1808,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
18011808
current_hyperparameters = estimator.hyperparameters()
18021809
if current_hyperparameters is not None:
18031810
hyperparameters = {
1804-
str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v))
1811+
str(k): (v if is_pipeline_entities(v) else str(v))
18051812
for (k, v) in current_hyperparameters.items()
18061813
}
18071814

src/sagemaker/processing.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,21 @@
2828

2929
from six.moves.urllib.parse import urlparse
3030
from six.moves.urllib.request import url2pathname
31-
31+
import sagemaker
3232
from sagemaker import s3
3333
from sagemaker.job import _Job
3434
from sagemaker.local import LocalSession
3535
from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
3636
from sagemaker.session import Session
37-
from sagemaker.workflow.properties import Properties
38-
from sagemaker.workflow.parameters import Parameter
39-
from sagemaker.workflow.entities import Expression
37+
from sagemaker.workflow.pipeline_context import (
38+
PipelineSession,
39+
runnable_by_pipeline,
40+
is_pipeline_entities,
41+
)
4042
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
4143
from sagemaker.apiutils._base_types import ApiObject
4244
from sagemaker.s3 import S3Uploader
4345

44-
4546
logger = logging.getLogger(__name__)
4647

4748

@@ -133,6 +134,7 @@ def __init__(
133134

134135
self.sagemaker_session = sagemaker_session or Session()
135136

137+
@runnable_by_pipeline
136138
def run(
137139
self,
138140
inputs=None,
@@ -308,10 +310,10 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
308310
if file_input.input_name is None:
309311
file_input.input_name = "input-{}".format(count)
310312

311-
if isinstance(file_input.source, Properties) or file_input.dataset_definition:
313+
if is_pipeline_entities(file_input.source) or file_input.dataset_definition:
312314
normalized_inputs.append(file_input)
313315
continue
314-
if isinstance(file_input.s3_input.s3_uri, (Parameter, Expression, Properties)):
316+
if is_pipeline_entities(file_input.s3_input.s3_uri):
315317
normalized_inputs.append(file_input)
316318
continue
317319
# If the source is a local path, upload it to S3
@@ -361,7 +363,7 @@ def _normalize_outputs(self, outputs=None):
361363
# Generate a name for the ProcessingOutput if it doesn't have one.
362364
if output.output_name is None:
363365
output.output_name = "output-{}".format(count)
364-
if isinstance(output.destination, (Parameter, Expression, Properties)):
366+
if is_pipeline_entities(output.destination):
365367
normalized_outputs.append(output)
366368
continue
367369
# If the output's destination is not an s3_uri, create one.
@@ -491,6 +493,7 @@ def get_run_args(
491493
"""
492494
return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)
493495

496+
@runnable_by_pipeline
494497
def run(
495498
self,
496499
code,
@@ -765,6 +768,9 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
765768
print("Inputs: ", process_args["inputs"])
766769
print("Outputs: ", process_args["output_config"]["Outputs"])
767770

771+
if type(processor.sagemaker_session) is PipelineSession:
772+
process_args["pipeline_session"] = processor.sagemaker_session
773+
768774
# Call sagemaker_session.process using the arguments dictionary.
769775
processor.sagemaker_session.process(**process_args)
770776

@@ -1594,7 +1600,7 @@ def run( # type: ignore[override]
15941600
)
15951601

15961602
# Submit a processing job.
1597-
super().run(
1603+
return super().run(
15981604
code=s3_runproc_sh,
15991605
inputs=inputs,
16001606
outputs=outputs,
@@ -1753,3 +1759,4 @@ def _set_entrypoint(self, command, user_script_name):
17531759
)
17541760
)
17551761
self.entrypoint = self.framework_entrypoint_command + [user_script_location]
1762+

src/sagemaker/session.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ def train( # noqa: C901
466466
profiler_config=None,
467467
environment=None,
468468
retry_strategy=None,
469+
pipeline_session=None,
469470
):
470471
"""Create an Amazon SageMaker training job.
471472
@@ -551,7 +552,7 @@ def train( # noqa: C901
551552
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
552553
* max_retry_attsmpts (int): Number of times a job should be retried.
553554
The key in RetryStrategy is 'MaxRetryAttempts'.
554-
555+
pipeline_session (`sagemaker.workflow.pipeline_session.PipelineSession`) A pipeline session
555556
Returns:
556557
str: ARN of the training job, if it is created.
557558
"""
@@ -585,6 +586,10 @@ def train( # noqa: C901
585586
environment=environment,
586587
retry_strategy=retry_strategy,
587588
)
589+
if pipeline_session:
590+
pipeline_session.context = train_request
591+
return
592+
588593
LOGGER.info("Creating training-job with name: %s", job_name)
589594
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
590595
self.sagemaker_client.create_training_job(**train_request)
@@ -856,6 +861,7 @@ def process(
856861
role_arn,
857862
tags,
858863
experiment_config=None,
864+
pipeline_session=None,
859865
):
860866
"""Create an Amazon SageMaker processing job.
861867
@@ -895,6 +901,7 @@ def process(
895901
* If both `ExperimentName` and `TrialName` are not supplied the trial component
896902
will be unassociated.
897903
* `TrialComponentDisplayName` is used for display in Studio.
904+
pipeline_session (`sagemaker.workflow.pipeline_session.PipelineSession`) A pipeline session
898905
"""
899906
tags = _append_project_tags(tags)
900907
process_request = self._get_process_request(
@@ -910,6 +917,9 @@ def process(
910917
tags=tags,
911918
experiment_config=experiment_config,
912919
)
920+
if pipeline_session:
921+
pipeline_session.context = process_request
922+
return
913923
LOGGER.info("Creating processing-job with name %s", job_name)
914924
LOGGER.debug("process request: %s", json.dumps(process_request, indent=4))
915925
self.sagemaker_client.create_processing_job(**process_request)
@@ -2045,6 +2055,7 @@ def create_tuning_job(
20452055
training_config_list=None,
20462056
warm_start_config=None,
20472057
tags=None,
2058+
pipeline_session=None,
20482059
):
20492060
"""Create an Amazon SageMaker hyperparameter tuning job.
20502061
@@ -2081,7 +2092,9 @@ def create_tuning_job(
20812092
warm_start_config=warm_start_config,
20822093
tags=tags,
20832094
)
2084-
2095+
if pipeline_session:
2096+
pipeline_session.context = tune_request
2097+
return
20852098
LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
20862099
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
20872100
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
@@ -2492,6 +2505,7 @@ def transform(
24922505
tags,
24932506
data_processing,
24942507
model_client_config=None,
2508+
pipeline_session=None,
24952509
):
24962510
"""Create an Amazon SageMaker transform job.
24972511
@@ -2544,6 +2558,9 @@ def transform(
25442558
data_processing=data_processing,
25452559
model_client_config=model_client_config,
25462560
)
2561+
if pipeline_session:
2562+
pipeline_session.context = transform_request
2563+
return
25472564

25482565
LOGGER.info("Creating transform job with name: %s", job_name)
25492566
LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4))

src/sagemaker/spark/processing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def run(
249249
"""
250250
self._current_job_name = self._generate_current_job_name(job_name=job_name)
251251

252-
super().run(
252+
return super().run(
253253
submit_app,
254254
inputs,
255255
outputs,
@@ -868,7 +868,7 @@ def run(
868868
spark_event_logs_s3_uri=spark_event_logs_s3_uri,
869869
)
870870

871-
super().run(
871+
return super().run(
872872
submit_app=submit_app,
873873
inputs=extended_inputs,
874874
outputs=extended_outputs,
@@ -1125,7 +1125,7 @@ def run(
11251125
spark_event_logs_s3_uri=spark_event_logs_s3_uri,
11261126
)
11271127

1128-
super().run(
1128+
return super().run(
11291129
submit_app=submit_app,
11301130
inputs=extended_inputs,
11311131
outputs=extended_outputs,

src/sagemaker/transformer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from sagemaker.job import _Job
1919
from sagemaker.session import Session
20+
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline, is_pipeline_entities
2021
from sagemaker.utils import base_name_from_image, name_from_base
2122

2223

@@ -106,6 +107,7 @@ def __init__(
106107

107108
self.sagemaker_session = sagemaker_session or Session()
108109

110+
@runnable_by_pipeline
109111
def transform(
110112
self,
111113
data,
@@ -196,8 +198,10 @@ def transform(
196198
else:
197199
base_name = self.base_transform_job_name
198200

199-
if base_name is None:
201+
if base_name is None and not is_pipeline_entities(self.model_name):
200202
base_name = self._retrieve_base_name()
203+
else:
204+
base_name = f"transform-job"
201205

202206
self._current_job_name = name_from_base(base_name)
203207

@@ -370,6 +374,8 @@ def start_new(
370374
experiment_config,
371375
model_client_config,
372376
)
377+
if type(transformer.sagemaker_session) is PipelineSession:
378+
transform_args['pipeline_session'] = transformer.sagemaker_session
373379
transformer.sagemaker_session.transform(**transform_args)
374380

375381
return cls(transformer.sagemaker_session, transformer._current_job_name)

0 commit comments

Comments
 (0)