Skip to content

Commit 28e07cf

Browse files
rsareddy0329Roja Reddy Sareddy
and
Roja Reddy Sareddy
authored
Added handler for pipeline variable while creating process job (#5122)
* change: Allow telemetry only in supported regions * change: Allow telemetry only in supported regions * change: Allow telemetry only in supported regions * change: Allow telemetry only in supported regions * change: Allow telemetry only in supported regions * documentation: Removed a line about python version requirements of training script which can misguide users.Training script can be of latest version based on the support provided by framework_version of the container * feature: Enabled update_endpoint through model_builder * fix: fix unit test, black-check, pylint errors * fix: fix black-check, pylint errors * fix:Added handler for pipeline variable while creating process job * fix: Added handler for pipeline variable while creating process job --------- Co-authored-by: Roja Reddy Sareddy <[email protected]>
1 parent fb22b91 commit 28e07cf

File tree

3 files changed

+272
-5
lines changed

3 files changed

+272
-5
lines changed

src/sagemaker/processing.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
and interpretation on Amazon SageMaker.
1818
"""
1919
from __future__ import absolute_import
20-
20+
import json
2121
import logging
2222
import os
2323
import pathlib
@@ -314,6 +314,15 @@ def _normalize_args(
314314
"code argument has to be a valid S3 URI or local file path "
315315
+ "rather than a pipeline variable"
316316
)
317+
if arguments is not None:
318+
processed_arguments = []
319+
for arg in arguments:
320+
if isinstance(arg, PipelineVariable):
321+
processed_value = json.dumps(arg.expr)
322+
processed_arguments.append(processed_value)
323+
else:
324+
processed_arguments.append(str(arg))
325+
arguments = processed_arguments
317326

318327
self._current_job_name = self._generate_current_job_name(job_name=job_name)
319328

tests/unit/sagemaker/workflow/test_processing_step.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,12 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
824824
processor, run_inputs = spark_processor
825825
processor.sagemaker_session = pipeline_session
826826
processor.role = ROLE
827-
827+
arguments_output = [
828+
"--input",
829+
"input-data-uri",
830+
"--output",
831+
'{"Get": "Parameters.MyArgOutput"}',
832+
]
828833
run_inputs["inputs"] = processing_input
829834

830835
step_args = processor.run(**run_inputs)
@@ -835,7 +840,7 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
835840

836841
step_args = get_step_args_helper(step_args, "Processing")
837842

838-
assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
843+
assert step_args["AppSpecification"]["ContainerArguments"] == arguments_output
839844

840845
entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
841846
entry_points_expr = []
@@ -1019,6 +1024,12 @@ def test_spark_processor_local_code(spark_processor, processing_input, pipeline_
10191024
processor, run_inputs = spark_processor
10201025
processor.sagemaker_session = pipeline_session
10211026
processor.role = ROLE
1027+
arguments_output = [
1028+
"--input",
1029+
"input-data-uri",
1030+
"--output",
1031+
'{"Get": "Parameters.MyArgOutput"}',
1032+
]
10221033

10231034
run_inputs["inputs"] = processing_input
10241035

@@ -1030,7 +1041,7 @@ def test_spark_processor_local_code(spark_processor, processing_input, pipeline_
10301041

10311042
step_args = get_step_args_helper(step_args, "Processing")
10321043

1033-
assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
1044+
assert step_args["AppSpecification"]["ContainerArguments"] == arguments_output
10341045

10351046
entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
10361047
entry_points_expr = []

tests/unit/test_processing.py

+248-1
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@
4646
from sagemaker.fw_utils import UploadedCode
4747
from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig
4848
from sagemaker.workflow.functions import Join
49-
from sagemaker.workflow.execution_variables import ExecutionVariables
49+
from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables
5050
from tests.unit import SAGEMAKER_CONFIG_PROCESSING_JOB
51+
from sagemaker.workflow.parameters import ParameterString
5152

5253
BUCKET_NAME = "mybucket"
5354
REGION = "us-west-2"
@@ -1717,3 +1718,249 @@ def _get_describe_response_inputs_and_ouputs():
17171718
"ProcessingInputs": _get_expected_args_all_parameters(None)["inputs"],
17181719
"ProcessingOutputConfig": _get_expected_args_all_parameters(None)["output_config"],
17191720
}
1721+
1722+
1723+
# Parameters
1724+
def _get_data_inputs_with_parameters():
1725+
return [
1726+
ProcessingInput(
1727+
source=ParameterString(name="input_data", default_value="s3://dummy-bucket/input"),
1728+
destination="/opt/ml/processing/input",
1729+
input_name="input-1",
1730+
)
1731+
]
1732+
1733+
1734+
def _get_data_outputs_with_parameters():
1735+
return [
1736+
ProcessingOutput(
1737+
source="/opt/ml/processing/output",
1738+
destination=ParameterString(
1739+
name="output_data", default_value="s3://dummy-bucket/output"
1740+
),
1741+
output_name="output-1",
1742+
)
1743+
]
1744+
1745+
1746+
def _get_expected_args_with_parameters(job_name):
1747+
return {
1748+
"inputs": [
1749+
{
1750+
"InputName": "input-1",
1751+
"S3Input": {
1752+
"S3Uri": "s3://dummy-bucket/input",
1753+
"LocalPath": "/opt/ml/processing/input",
1754+
"S3DataType": "S3Prefix",
1755+
"S3InputMode": "File",
1756+
"S3DataDistributionType": "FullyReplicated",
1757+
"S3CompressionType": "None",
1758+
},
1759+
}
1760+
],
1761+
"output_config": {
1762+
"Outputs": [
1763+
{
1764+
"OutputName": "output-1",
1765+
"S3Output": {
1766+
"S3Uri": "s3://dummy-bucket/output",
1767+
"LocalPath": "/opt/ml/processing/output",
1768+
"S3UploadMode": "EndOfJob",
1769+
},
1770+
}
1771+
]
1772+
},
1773+
"job_name": job_name,
1774+
"resources": {
1775+
"ClusterConfig": {
1776+
"InstanceType": "ml.m4.xlarge",
1777+
"InstanceCount": 1,
1778+
"VolumeSizeInGB": 100,
1779+
"VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
1780+
}
1781+
},
1782+
"stopping_condition": {"MaxRuntimeInSeconds": 3600},
1783+
"app_specification": {
1784+
"ImageUri": "custom-image-uri",
1785+
"ContainerArguments": [
1786+
"--input-data",
1787+
"s3://dummy-bucket/input-param",
1788+
"--output-path",
1789+
"s3://dummy-bucket/output-param",
1790+
],
1791+
"ContainerEntrypoint": ["python3"],
1792+
},
1793+
"environment": {"my_env_variable": "my_env_variable_value"},
1794+
"network_config": {
1795+
"EnableNetworkIsolation": True,
1796+
"EnableInterContainerTrafficEncryption": True,
1797+
"VpcConfig": {
1798+
"Subnets": ["my_subnet_id"],
1799+
"SecurityGroupIds": ["my_security_group_id"],
1800+
},
1801+
},
1802+
"role_arn": "dummy/role",
1803+
"tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
1804+
"experiment_config": {"ExperimentName": "AnExperiment"},
1805+
}
1806+
1807+
1808+
@patch("os.path.exists", return_value=True)
1809+
@patch("os.path.isfile", return_value=True)
1810+
@patch("sagemaker.utils.repack_model")
1811+
@patch("sagemaker.utils.create_tar_file")
1812+
@patch("sagemaker.session.Session.upload_data")
1813+
def test_script_processor_with_parameter_string(
1814+
upload_data_mock,
1815+
create_tar_file_mock,
1816+
repack_model_mock,
1817+
exists_mock,
1818+
isfile_mock,
1819+
sagemaker_session,
1820+
):
1821+
"""Test ScriptProcessor with ParameterString arguments"""
1822+
upload_data_mock.return_value = "s3://mocked_s3_uri_from_upload_data"
1823+
1824+
# Setup processor
1825+
processor = ScriptProcessor(
1826+
role="arn:aws:iam::012345678901:role/SageMakerRole", # Updated role ARN
1827+
image_uri="custom-image-uri",
1828+
command=["python3"],
1829+
instance_type="ml.m4.xlarge",
1830+
instance_count=1,
1831+
volume_size_in_gb=100,
1832+
volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
1833+
output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
1834+
max_runtime_in_seconds=3600,
1835+
base_job_name="test_processor",
1836+
env={"my_env_variable": "my_env_variable_value"},
1837+
tags=[{"Key": "my-tag", "Value": "my-tag-value"}],
1838+
network_config=NetworkConfig(
1839+
subnets=["my_subnet_id"],
1840+
security_group_ids=["my_security_group_id"],
1841+
enable_network_isolation=True,
1842+
encrypt_inter_container_traffic=True,
1843+
),
1844+
sagemaker_session=sagemaker_session,
1845+
)
1846+
1847+
input_param = ParameterString(name="input_param", default_value="s3://dummy-bucket/input-param")
1848+
output_param = ParameterString(
1849+
name="output_param", default_value="s3://dummy-bucket/output-param"
1850+
)
1851+
exec_var = ExecutionVariable(name="ExecutionTest")
1852+
join_var = Join(on="/", values=["s3://bucket", "prefix", "file.txt"])
1853+
dummy_str_var = "test-variable"
1854+
1855+
# Define expected arguments
1856+
expected_args = {
1857+
"inputs": [
1858+
{
1859+
"InputName": "input-1",
1860+
"AppManaged": False,
1861+
"S3Input": {
1862+
"S3Uri": ParameterString(
1863+
name="input_data", default_value="s3://dummy-bucket/input"
1864+
),
1865+
"LocalPath": "/opt/ml/processing/input",
1866+
"S3DataType": "S3Prefix",
1867+
"S3InputMode": "File",
1868+
"S3DataDistributionType": "FullyReplicated",
1869+
"S3CompressionType": "None",
1870+
},
1871+
},
1872+
{
1873+
"InputName": "code",
1874+
"AppManaged": False,
1875+
"S3Input": {
1876+
"S3Uri": "s3://mocked_s3_uri_from_upload_data",
1877+
"LocalPath": "/opt/ml/processing/input/code",
1878+
"S3DataType": "S3Prefix",
1879+
"S3InputMode": "File",
1880+
"S3DataDistributionType": "FullyReplicated",
1881+
"S3CompressionType": "None",
1882+
},
1883+
},
1884+
],
1885+
"output_config": {
1886+
"Outputs": [
1887+
{
1888+
"OutputName": "output-1",
1889+
"AppManaged": False,
1890+
"S3Output": {
1891+
"S3Uri": ParameterString(
1892+
name="output_data", default_value="s3://dummy-bucket/output"
1893+
),
1894+
"LocalPath": "/opt/ml/processing/output",
1895+
"S3UploadMode": "EndOfJob",
1896+
},
1897+
}
1898+
],
1899+
"KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
1900+
},
1901+
"job_name": "test_job",
1902+
"resources": {
1903+
"ClusterConfig": {
1904+
"InstanceType": "ml.m4.xlarge",
1905+
"InstanceCount": 1,
1906+
"VolumeSizeInGB": 100,
1907+
"VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
1908+
}
1909+
},
1910+
"stopping_condition": {"MaxRuntimeInSeconds": 3600},
1911+
"app_specification": {
1912+
"ImageUri": "custom-image-uri",
1913+
"ContainerArguments": [
1914+
"--input-data",
1915+
'{"Get": "Parameters.input_param"}',
1916+
"--output-path",
1917+
'{"Get": "Parameters.output_param"}',
1918+
"--exec-arg",
1919+
'{"Get": "Execution.ExecutionTest"}',
1920+
"--join-arg",
1921+
'{"Std:Join": {"On": "/", "Values": ["s3://bucket", "prefix", "file.txt"]}}',
1922+
"--string-param",
1923+
"test-variable",
1924+
],
1925+
"ContainerEntrypoint": ["python3", "/opt/ml/processing/input/code/processing_code.py"],
1926+
},
1927+
"environment": {"my_env_variable": "my_env_variable_value"},
1928+
"network_config": {
1929+
"EnableNetworkIsolation": True,
1930+
"EnableInterContainerTrafficEncryption": True,
1931+
"VpcConfig": {
1932+
"SecurityGroupIds": ["my_security_group_id"],
1933+
"Subnets": ["my_subnet_id"],
1934+
},
1935+
},
1936+
"role_arn": "arn:aws:iam::012345678901:role/SageMakerRole",
1937+
"tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
1938+
"experiment_config": {"ExperimentName": "AnExperiment"},
1939+
}
1940+
1941+
# Run processor
1942+
processor.run(
1943+
code="/local/path/to/processing_code.py",
1944+
inputs=_get_data_inputs_with_parameters(),
1945+
outputs=_get_data_outputs_with_parameters(),
1946+
arguments=[
1947+
"--input-data",
1948+
input_param,
1949+
"--output-path",
1950+
output_param,
1951+
"--exec-arg",
1952+
exec_var,
1953+
"--join-arg",
1954+
join_var,
1955+
"--string-param",
1956+
dummy_str_var,
1957+
],
1958+
wait=True,
1959+
logs=False,
1960+
job_name="test_job",
1961+
experiment_config={"ExperimentName": "AnExperiment"},
1962+
)
1963+
1964+
# Assert
1965+
sagemaker_session.process.assert_called_with(**expected_args)
1966+
assert "test_job" in processor._current_job_name

0 commit comments

Comments
 (0)