Implement subclass compatibility for workflow pipeline job steps #3040

Merged · 5 commits · Apr 27, 2022

2 changes: 2 additions & 0 deletions src/sagemaker/amazon/amazon_estimator.py
@@ -27,6 +27,7 @@
from sagemaker.estimator import EstimatorBase, _TrainingJob
from sagemaker.inputs import FileSystemInput, TrainingInput
from sagemaker.utils import sagemaker_timestamp
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline

logger = logging.getLogger(__name__)

@@ -192,6 +193,7 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
self.feature_dim = feature_dim
self.mini_batch_size = mini_batch_size

+@runnable_by_pipeline
def fit(
self,
records,
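
The `@runnable_by_pipeline` decorator added above is what lets a plain `fit()` call double as a pipeline step definition. A minimal sketch of the idea, assuming the decorator dispatches on the attached session and that a PipelineSession exposes the captured request on a `context` attribute (both are assumptions, not confirmed by this diff):

import functools

from sagemaker.workflow.pipeline_context import PipelineSession


def runnable_by_pipeline(run_func):
    """Sketch: defer job execution when called under a PipelineSession."""

    @functools.wraps(run_func)
    def wrapper(self, *args, **kwargs):
        if isinstance(self.sagemaker_session, PipelineSession):
            # The call still builds the full request, but the session
            # intercepts it (see _intercept_create_request in session.py
            # below) and the captured request becomes the step arguments.
            run_func(self, *args, **kwargs)
            return self.sagemaker_session.context  # assumed attribute
        # Outside a pipeline, behave exactly as before.
        return run_func(self, *args, **kwargs)

    return wrapper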
11 changes: 6 additions & 5 deletions src/sagemaker/clarify.py
@@ -803,7 +803,8 @@ def _run(
output_name="analysis_result",
s3_upload_mode="EndOfJob",
)
-super().run(
+
+return super().run(
inputs=[data_input, config_input],
outputs=[result_output],
wait=wait,
@@ -871,7 +872,7 @@ def run_pre_training_bias(
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
-self._run(
+return self._run(
data_config,
analysis_config,
wait,
@@ -957,7 +958,7 @@ def run_post_training_bias(
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
-self._run(
+return self._run(
data_config,
analysis_config,
wait,
@@ -1060,7 +1061,7 @@ def run_bias(
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Bias")
-self._run(
+return self._run(
data_config,
analysis_config,
wait,
@@ -1167,7 +1168,7 @@ def run_explainability(
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Explainability")
-self._run(
+return self._run(
data_config,
analysis_config,
wait,
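
The pattern in this file is uniform: every public `run_*` method now returns the value of `_run`, which in turn returns `super().run(...)`. Under a PipelineSession the decorated `run()` returns captured step arguments rather than None, so each intermediate layer must pass the value through. A toy illustration of the chain, with stand-in classes rather than SDK API:

class Processor:
    def run(self, **kwargs):
        # Under a PipelineSession this would be the captured request.
        return {"request": kwargs}


class ClarifyProcessor(Processor):
    def _run(self, **kwargs):
        return super().run(**kwargs)  # the added `return` propagates it

    def run_bias(self, **kwargs):
        return self._run(**kwargs)  # and again, one level up


print(ClarifyProcessor().run_bias(wait=False))  # {'request': {'wait': False}}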
10 changes: 9 additions & 1 deletion src/sagemaker/estimator.py
@@ -75,6 +75,10 @@
name_from_base,
)
from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.pipeline_context import (
+    PipelineSession,
+    runnable_by_pipeline,
+)

logger = logging.getLogger(__name__)

@@ -896,6 +900,7 @@ def latest_job_profiler_artifacts_path(self):
)
return None

+@runnable_by_pipeline
def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None):
"""Train a model using the input training dataset.

@@ -1341,7 +1346,9 @@ def register(
@property
def model_data(self):
"""str: The model location in S3. Only set if Estimator has been ``fit()``."""
-if self.latest_training_job is not None:
+if self.latest_training_job is not None and not isinstance(
+    self.sagemaker_session, PipelineSession
+):
model_uri = self.sagemaker_session.sagemaker_client.describe_training_job(
TrainingJobName=self.latest_training_job.name
)["ModelArtifacts"]["S3ModelArtifacts"]
@@ -1767,6 +1774,7 @@ def start_new(cls, estimator, inputs, experiment_config):
all information about the started training job.
"""
train_args = cls._get_train_args(estimator, inputs, experiment_config)
+
estimator.sagemaker_session.train(**train_args)

return cls(estimator.sagemaker_session, estimator._current_job_name)
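
Two changes cooperate here: `fit()` becomes runnable by a pipeline, and `model_data` stops calling DescribeTrainingJob under a PipelineSession, where `latest_training_job` can be set even though no real job was started. The intended usage looks roughly like this (a sketch; the image URI, role, and S3 URIs are placeholders, and `TrainingStep(step_args=...)` is the pipelines-side consumer of the captured request):

from sagemaker.estimator import Estimator
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import TrainingStep

pipeline_session = PipelineSession()

estimator = Estimator(
    image_uri="<training-image-uri>",  # placeholder
    role="<execution-role-arn>",       # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=pipeline_session,
)

# fit() does not start a training job here; it returns the captured
# request, which the step stores until the pipeline executes.
step_train = TrainingStep(
    name="Train",
    step_args=estimator.fit(inputs="s3://<bucket>/train"),  # placeholder URI
)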
15 changes: 7 additions & 8 deletions src/sagemaker/processing.py
@@ -28,16 +28,13 @@

from six.moves.urllib.parse import urlparse
from six.moves.urllib.request import url2pathname

from sagemaker import s3
from sagemaker.job import _Job
from sagemaker.local import LocalSession
from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
from sagemaker.session import Session
from sagemaker.workflow import is_pipeline_variable
-from sagemaker.workflow.properties import Properties
-from sagemaker.workflow.parameters import Parameter
-from sagemaker.workflow.entities import Expression
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
from sagemaker.apiutils._base_types import ApiObject
from sagemaker.s3 import S3Uploader
@@ -133,6 +130,7 @@ def __init__(

self.sagemaker_session = sagemaker_session or Session()

+@runnable_by_pipeline
def run(
self,
inputs=None,
@@ -314,10 +312,10 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
if file_input.input_name is None:
file_input.input_name = "input-{}".format(count)

-if isinstance(file_input.source, Properties) or file_input.dataset_definition:
+if is_pipeline_variable(file_input.source) or file_input.dataset_definition:
normalized_inputs.append(file_input)
continue
-if isinstance(file_input.s3_input.s3_uri, (Parameter, Expression, Properties)):
+if is_pipeline_variable(file_input.s3_input.s3_uri):
normalized_inputs.append(file_input)
continue
# If the source is a local path, upload it to S3
@@ -367,7 +365,7 @@ def _normalize_outputs(self, outputs=None):
# Generate a name for the ProcessingOutput if it doesn't have one.
if output.output_name is None:
output.output_name = "output-{}".format(count)
-if isinstance(output.destination, (Parameter, Expression, Properties)):
+if is_pipeline_variable(output.destination):
normalized_outputs.append(output)
continue
# If the output's destination is not an s3_uri, create one.
@@ -497,6 +495,7 @@ def get_run_args(
"""
return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)

+@runnable_by_pipeline
def run(
self,
code,
@@ -1600,7 +1599,7 @@ def run( # type: ignore[override]
)

# Submit a processing job.
-super().run(
+return super().run(
code=s3_runproc_sh,
inputs=inputs,
outputs=outputs,
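
The repeated three-type `isinstance` checks against Parameter, Expression, and Properties collapse into the `is_pipeline_variable` helper, which plausibly reduces to a single isinstance check against a shared base class. A sketch under that assumption (the PipelineVariable name and import path are assumptions, not confirmed by this diff):

from sagemaker.workflow.entities import PipelineVariable  # assumed location


def is_pipeline_variable(var) -> bool:
    """Sketch: True if `var` is a pipeline construct resolved only at
    runtime, e.g. a Parameter, an Expression, or a step Properties ref."""
    return isinstance(var, PipelineVariable)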
52 changes: 39 additions & 13 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import re
import sys
import time
+import typing
import warnings
from typing import List, Dict, Any, Sequence

@@ -551,7 +552,6 @@ def train( # noqa: C901
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
* max_retry_attempts (int): Number of times a job should be retried.
The key in RetryStrategy is 'MaxRetryAttempts'.
-
Returns:
str: ARN of the training job, if it is created.
"""
@@ -585,9 +585,13 @@
environment=environment,
retry_strategy=retry_strategy,
)
LOGGER.info("Creating training-job with name: %s", job_name)
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
self.sagemaker_client.create_training_job(**train_request)

def submit(request):
LOGGER.info("Creating training-job with name: %s", job_name)
LOGGER.debug("train request: %s", json.dumps(request, indent=4))
self.sagemaker_client.create_training_job(**request)

self._intercept_create_request(train_request, submit)

def _get_train_request( # noqa: C901
self,
@@ -912,9 +916,13 @@ def process(
tags=tags,
experiment_config=experiment_config,
)
LOGGER.info("Creating processing-job with name %s", job_name)
LOGGER.debug("process request: %s", json.dumps(process_request, indent=4))
self.sagemaker_client.create_processing_job(**process_request)

def submit(request):
LOGGER.info("Creating processing-job with name %s", job_name)
LOGGER.debug("process request: %s", json.dumps(request, indent=4))
self.sagemaker_client.create_processing_job(**request)

self._intercept_create_request(process_request, submit)

def _get_process_request(
self,
@@ -2086,9 +2094,12 @@ def create_tuning_job(
tags=tags,
)

LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
def submit(request):
LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
LOGGER.debug("tune request: %s", json.dumps(request, indent=4))
self.sagemaker_client.create_hyper_parameter_tuning_job(**request)

self._intercept_create_request(tune_request, submit)

def _get_tuning_request(
self,
@@ -2553,9 +2564,12 @@ def transform(
model_client_config=model_client_config,
)

LOGGER.info("Creating transform job with name: %s", job_name)
LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4))
self.sagemaker_client.create_transform_job(**transform_request)
def submit(request):
LOGGER.info("Creating transform job with name: %s", job_name)
LOGGER.debug("Transform request: %s", json.dumps(request, indent=4))
self.sagemaker_client.create_transform_job(**request)

self._intercept_create_request(transform_request, submit)

def _create_model_request(
self,
@@ -4161,6 +4175,18 @@ def account_id(self) -> str:
)
return sts_client.get_caller_identity()["Account"]

+def _intercept_create_request(self, request: typing.Dict, create):
+    """Intercept the create job request.
+
+    PipelineSession inherits this Session class and will override
+    this function to intercept the create request.
+
+    Args:
+        request (dict): the create job request
+        create (functor): a functor that calls the sagemaker client create method
+    """
+    create(request)
+

def get_model_package_args(
content_types,
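
`_intercept_create_request` is the seam that makes all four job types pipeline-aware: the base Session submits the request immediately, while a subclass can capture it instead of calling the SageMaker API. A sketch of how PipelineSession might override the hook, consistent with the decorator sketch earlier (the `context` attribute is an assumption):

from sagemaker.session import Session


class PipelineSession(Session):
    """Sketch: capture create-job requests instead of sending them."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.context = None  # last captured request (assumed attribute)

    def _intercept_create_request(self, request, create):
        # Deliberately skip create(request): at pipeline definition time
        # no job should start; only the request is recorded.
        self.context = request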
6 changes: 3 additions & 3 deletions src/sagemaker/spark/processing.py
@@ -249,7 +249,7 @@ def run(
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

-super().run(
+return super().run(
submit_app,
inputs,
outputs,
@@ -868,7 +868,7 @@ def run(
spark_event_logs_s3_uri=spark_event_logs_s3_uri,
)

-super().run(
+return super().run(
submit_app=submit_app,
inputs=extended_inputs,
outputs=extended_outputs,
@@ -1125,7 +1125,7 @@ def run(
spark_event_logs_s3_uri=spark_event_logs_s3_uri,
)

-super().run(
+return super().run(
submit_app=submit_app,
inputs=extended_inputs,
outputs=extended_outputs,
10 changes: 9 additions & 1 deletion src/sagemaker/transformer.py
@@ -17,6 +17,8 @@

from sagemaker.job import _Job
from sagemaker.session import Session
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow import is_pipeline_variable
from sagemaker.utils import base_name_from_image, name_from_base


@@ -106,6 +108,7 @@ def __init__(

self.sagemaker_session = sagemaker_session or Session()

+@runnable_by_pipeline
def transform(
self,
data,
@@ -197,7 +200,11 @@ def transform(
base_name = self.base_transform_job_name

if base_name is None:
-base_name = self._retrieve_base_name()
+base_name = (
+    "transform-job"
+    if is_pipeline_variable(self.model_name)
+    else self._retrieve_base_name()
+)

self._current_job_name = name_from_base(base_name)

@@ -370,6 +377,7 @@ def start_new(
experiment_config,
model_client_config,
)
+
transformer.sagemaker_session.transform(**transform_args)

return cls(transformer.sagemaker_session, transformer._current_job_name)
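
The base-name fallback is needed because `_retrieve_base_name()` derives a default from the model's container image, which means resolving `self.model_name`; a pipeline variable has no string value at definition time, so the static "transform-job" default is used instead. Sketched usage with placeholder values, where the model name comes from an earlier step:

from sagemaker.transformer import Transformer
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import TransformStep

pipeline_session = PipelineSession()

transformer = Transformer(
    # Assumed: step_create_model is a CreateModelStep defined earlier,
    # so model_name is a pipeline variable, not a plain string.
    model_name=step_create_model.properties.ModelName,
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://<bucket>/output",  # placeholder
    sagemaker_session=pipeline_session,
)

step_transform = TransformStep(
    name="Transform",
    step_args=transformer.transform(data="s3://<bucket>/input"),  # placeholder
)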
7 changes: 6 additions & 1 deletion src/sagemaker/tuner.py
@@ -38,6 +38,7 @@
IntegerParameter,
ParameterRange,
)
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline

from sagemaker.session import Session
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
@@ -380,6 +381,7 @@ def _prepare_static_hyperparameters(

return static_hyperparameters

+@runnable_by_pipeline
def fit(
self,
inputs=None,
@@ -466,7 +468,9 @@ def _fit_with_estimator_dict(self, inputs, job_name, include_cls_metadata, estimator_kwargs):
estimator_names = sorted(self.estimator_dict.keys())
self._validate_dict_argument(name="inputs", value=inputs, allowed_keys=estimator_names)
self._validate_dict_argument(
name="include_cls_metadata", value=include_cls_metadata, allowed_keys=estimator_names
name="include_cls_metadata",
value=include_cls_metadata,
allowed_keys=estimator_names,
)
self._validate_dict_argument(
name="estimator_kwargs", value=estimator_kwargs, allowed_keys=estimator_names
@@ -1468,6 +1472,7 @@ def start_new(cls, tuner, inputs):
information about the started job.
"""
tuner_args = cls._get_tuner_args(tuner, inputs)
+
tuner.sagemaker_session.create_tuning_job(**tuner_args)

return cls(tuner.sagemaker_session, tuner._current_job_name)