
Commit 55f2365

feature: Caching Improvements for SM Pipeline Workflows (#3441)
1 parent 6d974cf commit 55f2365

19 files changed: +1013 −103 lines
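Before this change, artifacts uploaded while building a pipeline (training code, processing code, local processing inputs, EMR configuration files) landed under S3 prefixes that embedded the timestamped job name, so every pipeline build produced new step arguments and defeated step caching. This commit switches those uploads to deterministic prefixes derived from the pipeline name plus content hashes (code_hash, config_hash) or the step name, so an unchanged step resolves to the same arguments and can be served from the cache. A rough usage sketch of the workflow this enables, assuming the sagemaker SDK is installed; the role ARN, script name, and pipeline name below are placeholders, not part of the diff:

from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import CacheConfig, ProcessingStep

pipeline_session = PipelineSession()

processor = SKLearnProcessor(
    framework_version="1.0-1",
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role ARN
    instance_type="ml.m5.xlarge",
    instance_count=1,
    sagemaker_session=pipeline_session,
)

# Under a PipelineSession, run() returns step arguments instead of starting a job;
# with this commit the uploaded script lands under <pipeline-name>/code/<code-hash>,
# so rebuilding the pipeline with an unchanged script yields identical arguments.
step_args = processor.run(code="preprocess.py")  # placeholder script

step = ProcessingStep(
    name="Preprocess",
    step_args=step_args,
    cache_config=CacheConfig(enable_caching=True, expire_after="P30D"),
)

pipeline = Pipeline(
    name="MyCachedPipeline",  # placeholder pipeline name
    steps=[step],
    sagemaker_session=pipeline_session,
)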

src/sagemaker/algorithm.py

+2 −0
@@ -27,6 +27,7 @@
 from sagemaker.predictor import Predictor
 from sagemaker.session import Session
 from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline

 from sagemaker.workflow import is_pipeline_variable

@@ -429,6 +430,7 @@ def _prepare_for_training(self, job_name=None):

         super(AlgorithmEstimator, self)._prepare_for_training(job_name)

+    @runnable_by_pipeline
     def fit(
         self,
         inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None,
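The only change here decorates AlgorithmEstimator.fit with @runnable_by_pipeline, the same decorator already applied to other estimators and processors in the SDK. Roughly, when the attached session is a PipelineSession, the decorated call is captured as step arguments instead of launching a job immediately. The simplified sketch below only illustrates the idea and is not the SDK's actual implementation:

import functools

def runnable_by_pipeline_sketch(run_func):
    """Simplified stand-in for the SDK's @runnable_by_pipeline decorator."""
    @functools.wraps(run_func)
    def wrapper(self, *args, **kwargs):
        session = getattr(self, "sagemaker_session", None)
        if type(session).__name__ == "PipelineSession":
            # Defer execution: record the call so a pipeline step can replay it later.
            return {"caller": run_func.__name__, "args": args, "kwargs": kwargs}
        return run_func(self, *args, **kwargs)
    return wrapper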

src/sagemaker/estimator.py

+36 −13
@@ -778,13 +778,11 @@ def _stage_user_code_in_s3(self) -> str:
         if is_pipeline_variable(self.output_path):
             if self.code_location is None:
                 code_bucket = self.sagemaker_session.default_bucket()
-                code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
+                code_s3_prefix = self._assign_s3_prefix()
                 kms_key = None
             else:
                 code_bucket, key_prefix = parse_s3_url(self.code_location)
-                code_s3_prefix = "/".join(
-                    filter(None, [key_prefix, self._current_job_name, "source"])
-                )
+                code_s3_prefix = self._assign_s3_prefix(key_prefix)

             output_bucket = self.sagemaker_session.default_bucket()
             kms_key = self.output_kms_key if code_bucket == output_bucket else None
@@ -793,24 +791,20 @@ def _stage_user_code_in_s3(self) -> str:
         if local_mode:
             if self.code_location is None:
                 code_bucket = self.sagemaker_session.default_bucket()
-                code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
+                code_s3_prefix = self._assign_s3_prefix()
                 kms_key = None
             else:
                 code_bucket, key_prefix = parse_s3_url(self.code_location)
-                code_s3_prefix = "/".join(
-                    filter(None, [key_prefix, self._current_job_name, "source"])
-                )
+                code_s3_prefix = self._assign_s3_prefix(key_prefix)
                 kms_key = None
         else:
             if self.code_location is None:
                 code_bucket, _ = parse_s3_url(self.output_path)
-                code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
+                code_s3_prefix = self._assign_s3_prefix()
                 kms_key = self.output_kms_key
             else:
                 code_bucket, key_prefix = parse_s3_url(self.code_location)
-                code_s3_prefix = "/".join(
-                    filter(None, [key_prefix, self._current_job_name, "source"])
-                )
+                code_s3_prefix = self._assign_s3_prefix(key_prefix)

             output_bucket, _ = parse_s3_url(self.output_path)
             kms_key = self.output_kms_key if code_bucket == output_bucket else None
@@ -827,6 +821,36 @@ def _stage_user_code_in_s3(self) -> str:
             settings=self.sagemaker_session.settings,
         )

+    def _assign_s3_prefix(self, key_prefix=""):
+        """Include pipeline name+step name instead of job name in s3 path
+
+        Assign new s3 path structure if within a pipeline workflow that has
+        set the _pipeline_config and respective name/hash variables
+
+        Args:
+            key_prefix (str): Prefix for the S3 key, often netloc of url:
+                https://docs.python.org/3.9/library/urllib.parse.html#urllib.parse.netloc
+
+        Returns:
+            str: S3 path prefix that occurs before filename
+        """
+        from sagemaker.workflow.utilities import _pipeline_config
+
+        code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"]))
+        if _pipeline_config and _pipeline_config.code_hash:
+            code_s3_prefix = "/".join(
+                filter(
+                    None,
+                    [
+                        key_prefix,
+                        _pipeline_config.pipeline_name,
+                        "code",
+                        _pipeline_config.code_hash,
+                    ],
+                )
+            )
+        return code_s3_prefix
+
     def _prepare_rules(self):
         """Rules list includes both debugger and profiler rules.
@@ -1539,7 +1563,6 @@ def model_data(self):
             model_uri = os.path.join(
                 self.output_path, self._current_job_name, "output", "model.tar.gz"
             )
-
         return model_uri

     @abstractmethod
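The _assign_s3_prefix helper added above moves staged training code from a <job-name>/source prefix, which changes on every build because the job name is timestamped, to a <pipeline-name>/code/<code-hash> prefix that stays stable while the code is unchanged. A standalone illustration of the same logic with hypothetical values (not the SDK code itself):

def assign_s3_prefix(key_prefix="", job_name="my-job-2023-01-01-12-00-00",
                     pipeline_name=None, code_hash=None):
    """Mirror of the new prefix logic, outside the SDK, with placeholder values."""
    prefix = "/".join(filter(None, [key_prefix, job_name, "source"]))
    if pipeline_name and code_hash:
        prefix = "/".join(filter(None, [key_prefix, pipeline_name, "code", code_hash]))
    return prefix

print(assign_s3_prefix())
# -> my-job-2023-01-01-12-00-00/source   (changes with every timestamped job name)
print(assign_s3_prefix(pipeline_name="my-pipeline", code_hash="5c3e2a"))
# -> my-pipeline/code/5c3e2a              (stable while the code hash is unchanged)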

src/sagemaker/processing.py

+74 −21
@@ -35,7 +35,9 @@
 from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
 from sagemaker.session import Session
 from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.functions import Join
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
 from sagemaker.apiutils._base_types import ApiObject
@@ -314,6 +316,8 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
         Raises:
             TypeError: if the inputs are not ``ProcessingInput`` objects.
         """
+        from sagemaker.workflow.utilities import _pipeline_config
+
         # Initialize a list of normalized ProcessingInput objects.
         normalized_inputs = []
         if inputs is not None:
@@ -335,13 +339,23 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
                 # and save the S3 uri in the ProcessingInput source.
                 parse_result = urlparse(file_input.s3_input.s3_uri)
                 if parse_result.scheme != "s3":
-                    desired_s3_uri = s3.s3_path_join(
-                        "s3://",
-                        self.sagemaker_session.default_bucket(),
-                        self._current_job_name,
-                        "input",
-                        file_input.input_name,
-                    )
+                    if _pipeline_config:
+                        desired_s3_uri = s3.s3_path_join(
+                            "s3://",
+                            self.sagemaker_session.default_bucket(),
+                            _pipeline_config.pipeline_name,
+                            _pipeline_config.step_name,
+                            "input",
+                            file_input.input_name,
+                        )
+                    else:
+                        desired_s3_uri = s3.s3_path_join(
+                            "s3://",
+                            self.sagemaker_session.default_bucket(),
+                            self._current_job_name,
+                            "input",
+                            file_input.input_name,
+                        )
                     s3_uri = s3.S3Uploader.upload(
                         local_path=file_input.s3_input.s3_uri,
                         desired_s3_uri=desired_s3_uri,
@@ -369,6 +383,8 @@ def _normalize_outputs(self, outputs=None):
             TypeError: if the outputs are not ``ProcessingOutput`` objects.
         """
         # Initialize a list of normalized ProcessingOutput objects.
+        from sagemaker.workflow.utilities import _pipeline_config
+
         normalized_outputs = []
         if outputs is not None:
             # Iterate through the provided list of outputs.
@@ -384,13 +400,27 @@ def _normalize_outputs(self, outputs=None):
                 # If the output's destination is not an s3_uri, create one.
                 parse_result = urlparse(output.destination)
                 if parse_result.scheme != "s3":
-                    s3_uri = s3.s3_path_join(
-                        "s3://",
-                        self.sagemaker_session.default_bucket(),
-                        self._current_job_name,
-                        "output",
-                        output.output_name,
-                    )
+                    if _pipeline_config:
+                        s3_uri = Join(
+                            on="/",
+                            values=[
+                                "s3:/",
+                                self.sagemaker_session.default_bucket(),
+                                _pipeline_config.pipeline_name,
+                                ExecutionVariables.PIPELINE_EXECUTION_ID,
+                                _pipeline_config.step_name,
+                                "output",
+                                output.output_name,
+                            ],
+                        )
+                    else:
+                        s3_uri = s3.s3_path_join(
+                            "s3://",
+                            self.sagemaker_session.default_bucket(),
+                            self._current_job_name,
+                            "output",
+                            output.output_name,
+                        )
                     output.destination = s3_uri
                 normalized_outputs.append(output)
         return normalized_outputs
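Unlike code and inputs, processing outputs are not content-addressed: in the pipeline branch the destination is built with Join and ExecutionVariables.PIPELINE_EXECUTION_ID, so each execution writes under its own prefix while the step arguments themselves remain deterministic (the execution id stays a placeholder until run time). A hedged sketch of the resulting expression, with placeholder bucket, pipeline, step, and output names:

from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.functions import Join

destination = Join(
    on="/",
    values=[
        "s3:/",  # joined with "/" this produces the leading "s3://"
        "my-bucket",
        "my-pipeline",
        ExecutionVariables.PIPELINE_EXECUTION_ID,  # resolved when the pipeline runs
        "MyStep",
        "output",
        "output-1",
    ],
)
# At execution time this resolves to roughly:
#   s3://my-bucket/my-pipeline/<execution-id>/MyStep/output/output-1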
@@ -507,6 +537,11 @@ def get_run_args(
             arguments (list[str]): A list of string arguments to be passed to a
                 processing job (default: None).
         """
+        logger.warning(
+            "This function has been deprecated and could break pipeline step caching. "
+            "We recommend using the run() function directly with pipeline sessions "
+            "to access step arguments."
+        )
         return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)

     @runnable_by_pipeline
@@ -679,13 +714,24 @@ def _upload_code(self, code, kms_key=None):
             str: The S3 URI of the uploaded file or directory.

         """
-        desired_s3_uri = s3.s3_path_join(
-            "s3://",
-            self.sagemaker_session.default_bucket(),
-            self._current_job_name,
-            "input",
-            self._CODE_CONTAINER_INPUT_NAME,
-        )
+        from sagemaker.workflow.utilities import _pipeline_config
+
+        if _pipeline_config and _pipeline_config.code_hash:
+            desired_s3_uri = s3.s3_path_join(
+                "s3://",
+                self.sagemaker_session.default_bucket(),
+                _pipeline_config.pipeline_name,
+                self._CODE_CONTAINER_INPUT_NAME,
+                _pipeline_config.code_hash,
+            )
+        else:
+            desired_s3_uri = s3.s3_path_join(
+                "s3://",
+                self.sagemaker_session.default_bucket(),
+                self._current_job_name,
+                "input",
+                self._CODE_CONTAINER_INPUT_NAME,
+            )
         return s3.S3Uploader.upload(
             local_path=code,
             desired_s3_uri=desired_s3_uri,
@@ -1499,6 +1545,12 @@ def get_run_args(
             job_name (str): Processing job name. If not specified, the processor generates
                 a default job name, based on the base job name and current timestamp.
         """
+        logger.warning(
+            "This function has been deprecated and could break pipeline step caching. "
+            "We recommend using the run() function directly with pipeline sessions "
+            "to access step arguments."
+        )
+
         # When job_name is None, the job_name to upload code (+payload) will
         # differ from job_name used by run().
         s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
@@ -1512,6 +1564,7 @@ def get_run_args(
             arguments=arguments,
         )

+    @runnable_by_pipeline
     def run(  # type: ignore[override]
         self,
         code: str,

src/sagemaker/spark/processing.py

+31 −9
@@ -42,6 +42,7 @@
 from sagemaker.spark import defaults

 from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.functions import Join

@@ -211,6 +212,7 @@ def get_run_args(
             arguments=arguments,
         )

+    @runnable_by_pipeline
     def run(
         self,
         submit_app,
@@ -399,12 +401,22 @@ def _stage_configuration(self, configuration):
         Args:
             configuration (Dict): the configuration dict for the EMR application configuration.
         """
+        from sagemaker.workflow.utilities import _pipeline_config

         serialized_configuration = BytesIO(json.dumps(configuration).encode("utf-8"))
-        s3_uri = (
-            f"s3://{self.sagemaker_session.default_bucket()}/{self._current_job_name}/"
-            f"input/{self._conf_container_input_name}/{self._conf_file_name}"
-        )
+
+        if _pipeline_config and _pipeline_config.config_hash:
+            s3_uri = (
+                f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/"
+                f"{_pipeline_config.step_name}/input/"
+                f"{self._conf_container_input_name}/{_pipeline_config.config_hash}/"
+                f"{self._conf_file_name}"
+            )
+        else:
+            s3_uri = (
+                f"s3://{self.sagemaker_session.default_bucket()}/{self._current_job_name}/"
+                f"input/{self._conf_container_input_name}/{self._conf_file_name}"
+            )

         S3Uploader.upload_string_as_file_body(
             body=serialized_configuration,
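The EMR configuration file is likewise staged under a path that includes config_hash, so an unchanged configuration maps to the same S3 object across builds. The hash itself is computed by a helper in sagemaker.workflow.utilities; the snippet below is only a hypothetical illustration of deriving such a deterministic digest, not the SDK's implementation:

import hashlib
import json

def config_hash_sketch(configuration) -> str:
    """Deterministic digest of an EMR-style configuration (illustrative only)."""
    serialized = json.dumps(configuration, sort_keys=True).encode("utf-8")
    return hashlib.sha256(serialized).hexdigest()

spark_conf = [
    {"Classification": "spark-defaults",
     "Properties": {"spark.executor.memory": "2g"}},
]
print(config_hash_sketch(spark_conf))  # same configuration -> same hash -> same S3 key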
@@ -443,11 +455,6 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
         if not input_channel_name:
             raise ValueError("input_channel_name value may not be empty.")

-        input_channel_s3_uri = (
-            f"s3://{self.sagemaker_session.default_bucket()}"
-            f"/{self._current_job_name}/input/{input_channel_name}"
-        )
-
         use_input_channel = False
         spark_opt_s3_uris = []
         spark_opt_s3_uris_has_pipeline_var = False
@@ -481,6 +488,19 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):

             # If any local files were found and copied, upload the temp directory to S3
             if os.listdir(tmpdir):
+                from sagemaker.workflow.utilities import _pipeline_config
+
+                if _pipeline_config and _pipeline_config.code_hash:
+                    input_channel_s3_uri = (
+                        f"s3://{self.sagemaker_session.default_bucket()}"
+                        f"/{_pipeline_config.pipeline_name}/code/{_pipeline_config.code_hash}"
+                        f"/{input_channel_name}"
+                    )
+                else:
+                    input_channel_s3_uri = (
+                        f"s3://{self.sagemaker_session.default_bucket()}"
+                        f"/{self._current_job_name}/input/{input_channel_name}"
+                    )
                 logger.info(
                     "Uploading dependencies from tmpdir %s to S3 %s", tmpdir, input_channel_s3_uri
                 )
@@ -824,6 +844,7 @@ def get_run_args(
             arguments=arguments,
         )

+    @runnable_by_pipeline
     def run(
         self,
         submit_app: str,
@@ -1083,6 +1104,7 @@ def get_run_args(
             arguments=arguments,
         )

+    @runnable_by_pipeline
     def run(
         self,
         submit_app: str,
