Skip to content

Commit c796bd4

Browse files
committed
feat: Add code_location support to Processing
This change gives users the ability to control the upload location of ScriptProcessor code with a code_location parameter, in the same way as Framework estimators and FrameworkProcessor. It also extends both the unit and integration tests to verify the correct upload locations.
1 parent 98079ef commit c796bd4

File tree

5 files changed

+264
-148
lines changed

5 files changed

+264
-148
lines changed

src/sagemaker/processing.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ def __init__(
399399
env=None,
400400
tags=None,
401401
network_config=None,
402+
code_location=None,
402403
):
403404
"""Initializes a ``ScriptProcessor`` instance.
404405
@@ -443,10 +444,20 @@ def __init__(
443444
A :class:`~sagemaker.network.NetworkConfig`
444445
object that configures network isolation, encryption of
445446
inter-container traffic, security group IDs, and subnets.
447+
code_location (str): The S3 prefix URI where custom code will be uploaded
448+
(default: None) - don't include a trailing slash since a string prepended
449+
with a "/" is appended to ``code_location``. Code will be uploaded to
450+
folder 'code_location/job-name/input/code'. If not specified, the default
451+
``code_location`` is 's3://{sagemaker-default-bucket}'.
446452
"""
447453
self._CODE_CONTAINER_BASE_PATH = "/opt/ml/processing/input/"
448454
self._CODE_CONTAINER_INPUT_NAME = "code"
449455
self.command = command
456+
self.code_location = (
457+
code_location[:-1]
458+
if (code_location and code_location.endswith("/"))
459+
else code_location
460+
)
450461

451462
super(ScriptProcessor, self).__init__(
452463
role=role,
@@ -653,9 +664,15 @@ def _upload_code(self, code, kms_key=None):
653664
str: The S3 URI of the uploaded file or directory.
654665
655666
"""
667+
if self.code_location:
668+
code_bucket, key_prefix = s3.parse_s3_url(self.code_location)
669+
else:
670+
code_bucket = self.sagemaker_session.default_bucket()
671+
key_prefix = ""
656672
desired_s3_uri = s3.s3_path_join(
657673
"s3://",
658-
self.sagemaker_session.default_bucket(),
674+
code_bucket,
675+
key_prefix,
659676
self._current_job_name,
660677
"input",
661678
self._CODE_CONTAINER_INPUT_NAME,
@@ -1373,15 +1390,12 @@ def __init__(
13731390
env=env,
13741391
tags=tags,
13751392
network_config=network_config,
1393+
code_location=code_location,
13761394
)
13771395

13781396
self._FRAMEWORK_ENTRYPOINT_CONTAINER_INPUT_NAME = "entrypoint"
13791397
self._FRAMEWORK_ENTRYPOINT_SCRIPT_NAME = "runproc.sh"
13801398

1381-
self.code_location = (
1382-
code_location[:-1] if (code_location and code_location.endswith("/")) else code_location
1383-
)
1384-
13851399
if image_uri is None or base_job_name is None:
13861400
# For these default configuration purposes, we don't need the optional args:
13871401
est = self._create_estimator()

src/sagemaker/sklearn/processing.py

+7
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
env=None,
4242
tags=None,
4343
network_config=None,
44+
code_location=None,
4445
):
4546
"""Initialize an ``SKLearnProcessor`` instance.
4647
@@ -80,6 +81,11 @@ def __init__(
8081
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
8182
object that configures network isolation, encryption of
8283
inter-container traffic, security group IDs, and subnets.
84+
code_location (str): The S3 prefix URI where custom code will be uploaded
85+
(default: None) - don't include a trailing slash since a string prepended
86+
with a "/" is appended to ``code_location``. Code will be uploaded to
87+
folder 'code_location/job-name/input/code'. If not specified, the default
88+
``code_location`` is 's3://{sagemaker-default-bucket}'.
8389
"""
8490
if not command:
8591
command = ["python3"]
@@ -106,4 +112,5 @@ def __init__(
106112
env=env,
107113
tags=tags,
108114
network_config=network_config,
115+
code_location=code_location,
109116
)

tests/integ/test_processing.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -701,8 +701,9 @@ def test_sklearn_with_network_config(sagemaker_session, sklearn_latest_version,
701701

702702

703703
def test_processing_job_inputs_and_output_config(
704-
sagemaker_session, image_uri, cpu_instance_type, output_kms_key
704+
sagemaker_session, image_uri, cpu_instance_type, output_kms_key, custom_bucket_name,
705705
):
706+
custom_code_location = f"s3://{custom_bucket_name}/customized-processing-code"
706707
script_processor = ScriptProcessor(
707708
role=ROLE,
708709
image_uri=image_uri,
@@ -716,6 +717,7 @@ def test_processing_job_inputs_and_output_config(
716717
base_job_name="test-script-processor",
717718
env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
718719
tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
720+
code_location=custom_code_location,
719721
sagemaker_session=sagemaker_session,
720722
)
721723

@@ -734,6 +736,10 @@ def test_processing_job_inputs_and_output_config(
734736
assert (
735737
job_description["ProcessingInputs"][:-1] == expected_inputs_and_outputs["ProcessingInputs"]
736738
)
739+
assert job_description["ProcessingInputs"][-1]["InputName"] == "code"
740+
assert job_description["ProcessingInputs"][-1]["S3Input"]["S3Uri"].startswith(
741+
custom_code_location
742+
)
737743
assert (
738744
job_description["ProcessingOutputConfig"]
739745
== expected_inputs_and_outputs["ProcessingOutputConfig"]

tests/integ/test_s3.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717

1818
import pytest
1919

20-
from sagemaker.s3 import S3Uploader
21-
from sagemaker.s3 import S3Downloader
20+
from sagemaker.s3 import S3Downloader, S3Uploader, s3_path_join
2221

2322
from tests.integ.kms_utils import get_or_create_kms_key
2423

@@ -238,3 +237,8 @@ def test_s3_uploader_and_downloader_downloads_files_when_given_directory_uris_wi
238237

239238
with open(os.path.join(TMP_BASE_PATH, my_inner_directory_uuid, file_2_name), "r") as f:
240239
assert file_2_body == f.read()
240+
241+
242+
def test_s3_path_join_ignores_empty_elements():
243+
# (At the time of writing, ScriptProcessor code upload relies on this behavior)
244+
assert s3_path_join("s3://", "mybucket", "", "a", "b", "", "c") == "s3://mybucket/a/b/c"

0 commit comments

Comments
 (0)