|
46 | 46 | from sagemaker.fw_utils import UploadedCode
|
47 | 47 | from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig
|
48 | 48 | from sagemaker.workflow.functions import Join
|
49 |
| -from sagemaker.workflow.execution_variables import ExecutionVariables |
| 49 | +from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables |
50 | 50 | from tests.unit import SAGEMAKER_CONFIG_PROCESSING_JOB
|
| 51 | +from sagemaker.workflow.parameters import ParameterString |
51 | 52 |
|
52 | 53 | BUCKET_NAME = "mybucket"
|
53 | 54 | REGION = "us-west-2"
|
@@ -1717,3 +1718,249 @@ def _get_describe_response_inputs_and_ouputs():
|
1717 | 1718 | "ProcessingInputs": _get_expected_args_all_parameters(None)["inputs"],
|
1718 | 1719 | "ProcessingOutputConfig": _get_expected_args_all_parameters(None)["output_config"],
|
1719 | 1720 | }
|
| 1721 | + |
| 1722 | + |
| 1723 | +# Parameters |
def _get_data_inputs_with_parameters():
    """Return the processing inputs used by the pipeline-parameter tests.

    The list holds a single ``ProcessingInput`` named ``input-1`` whose
    ``source`` is a ``ParameterString`` (``input_data``, defaulting to
    ``s3://dummy-bucket/input``) rather than a plain S3 URI string, and whose
    ``destination`` is ``/opt/ml/processing/input``.
    """
    return [
        ProcessingInput(
            source=ParameterString(name="input_data", default_value="s3://dummy-bucket/input"),
            destination="/opt/ml/processing/input",
            input_name="input-1",
        )
    ]
| 1732 | + |
| 1733 | + |
def _get_data_outputs_with_parameters():
    """Return the processing outputs used by the pipeline-parameter tests.

    The list holds a single ``ProcessingOutput`` named ``output-1`` whose
    ``source`` is ``/opt/ml/processing/output`` and whose ``destination`` is a
    ``ParameterString`` (``output_data``, defaulting to
    ``s3://dummy-bucket/output``) rather than a plain S3 URI string.
    """
    return [
        ProcessingOutput(
            source="/opt/ml/processing/output",
            destination=ParameterString(
                name="output_data", default_value="s3://dummy-bucket/output"
            ),
            output_name="output-1",
        )
    ]
| 1744 | + |
| 1745 | + |
| 1746 | +def _get_expected_args_with_parameters(job_name): |
| 1747 | + return { |
| 1748 | + "inputs": [ |
| 1749 | + { |
| 1750 | + "InputName": "input-1", |
| 1751 | + "S3Input": { |
| 1752 | + "S3Uri": "s3://dummy-bucket/input", |
| 1753 | + "LocalPath": "/opt/ml/processing/input", |
| 1754 | + "S3DataType": "S3Prefix", |
| 1755 | + "S3InputMode": "File", |
| 1756 | + "S3DataDistributionType": "FullyReplicated", |
| 1757 | + "S3CompressionType": "None", |
| 1758 | + }, |
| 1759 | + } |
| 1760 | + ], |
| 1761 | + "output_config": { |
| 1762 | + "Outputs": [ |
| 1763 | + { |
| 1764 | + "OutputName": "output-1", |
| 1765 | + "S3Output": { |
| 1766 | + "S3Uri": "s3://dummy-bucket/output", |
| 1767 | + "LocalPath": "/opt/ml/processing/output", |
| 1768 | + "S3UploadMode": "EndOfJob", |
| 1769 | + }, |
| 1770 | + } |
| 1771 | + ] |
| 1772 | + }, |
| 1773 | + "job_name": job_name, |
| 1774 | + "resources": { |
| 1775 | + "ClusterConfig": { |
| 1776 | + "InstanceType": "ml.m4.xlarge", |
| 1777 | + "InstanceCount": 1, |
| 1778 | + "VolumeSizeInGB": 100, |
| 1779 | + "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key", |
| 1780 | + } |
| 1781 | + }, |
| 1782 | + "stopping_condition": {"MaxRuntimeInSeconds": 3600}, |
| 1783 | + "app_specification": { |
| 1784 | + "ImageUri": "custom-image-uri", |
| 1785 | + "ContainerArguments": [ |
| 1786 | + "--input-data", |
| 1787 | + "s3://dummy-bucket/input-param", |
| 1788 | + "--output-path", |
| 1789 | + "s3://dummy-bucket/output-param", |
| 1790 | + ], |
| 1791 | + "ContainerEntrypoint": ["python3"], |
| 1792 | + }, |
| 1793 | + "environment": {"my_env_variable": "my_env_variable_value"}, |
| 1794 | + "network_config": { |
| 1795 | + "EnableNetworkIsolation": True, |
| 1796 | + "EnableInterContainerTrafficEncryption": True, |
| 1797 | + "VpcConfig": { |
| 1798 | + "Subnets": ["my_subnet_id"], |
| 1799 | + "SecurityGroupIds": ["my_security_group_id"], |
| 1800 | + }, |
| 1801 | + }, |
| 1802 | + "role_arn": "dummy/role", |
| 1803 | + "tags": [{"Key": "my-tag", "Value": "my-tag-value"}], |
| 1804 | + "experiment_config": {"ExperimentName": "AnExperiment"}, |
| 1805 | + } |
| 1806 | + |
| 1807 | + |
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
@patch("sagemaker.utils.repack_model")
@patch("sagemaker.utils.create_tar_file")
@patch("sagemaker.session.Session.upload_data")
def test_script_processor_with_parameter_string(
    upload_data_mock,
    create_tar_file_mock,
    repack_model_mock,
    isfile_mock,
    exists_mock,
    sagemaker_session,
):
    """Test ScriptProcessor.run with pipeline variables as inputs, outputs and arguments.

    Asserts that ``sagemaker_session.process`` is called with:

    * ``ParameterString`` objects left in place as the ``S3Uri`` of the data
      input and output,
    * ``ParameterString``, ``ExecutionVariable`` and ``Join`` arguments rendered
      as JSON strings in ``ContainerArguments``, while a plain ``str`` argument
      is passed through unchanged,
    * the code input pointing at the URI returned by the mocked ``upload_data``.

    ``@patch`` decorators are applied bottom-up, so the mock parameters are
    listed in the reverse order of the decorator stack.
    """
    upload_data_mock.return_value = "s3://mocked_s3_uri_from_upload_data"

    # Setup processor
    processor = ScriptProcessor(
        role="arn:aws:iam::012345678901:role/SageMakerRole",
        image_uri="custom-image-uri",
        command=["python3"],
        instance_type="ml.m4.xlarge",
        instance_count=1,
        volume_size_in_gb=100,
        volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
        output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
        max_runtime_in_seconds=3600,
        base_job_name="test_processor",
        env={"my_env_variable": "my_env_variable_value"},
        tags=[{"Key": "my-tag", "Value": "my-tag-value"}],
        network_config=NetworkConfig(
            subnets=["my_subnet_id"],
            security_group_ids=["my_security_group_id"],
            enable_network_isolation=True,
            encrypt_inter_container_traffic=True,
        ),
        sagemaker_session=sagemaker_session,
    )

    # One argument of each supported kind: pipeline parameters, an execution
    # variable, a Join function and a plain string.
    input_param = ParameterString(name="input_param", default_value="s3://dummy-bucket/input-param")
    output_param = ParameterString(
        name="output_param", default_value="s3://dummy-bucket/output-param"
    )
    exec_var = ExecutionVariable(name="ExecutionTest")
    join_var = Join(on="/", values=["s3://bucket", "prefix", "file.txt"])
    dummy_str_var = "test-variable"

    # Define expected arguments
    expected_args = {
        "inputs": [
            {
                "InputName": "input-1",
                "AppManaged": False,
                "S3Input": {
                    "S3Uri": ParameterString(
                        name="input_data", default_value="s3://dummy-bucket/input"
                    ),
                    "LocalPath": "/opt/ml/processing/input",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3CompressionType": "None",
                },
            },
            {
                "InputName": "code",
                "AppManaged": False,
                "S3Input": {
                    "S3Uri": "s3://mocked_s3_uri_from_upload_data",
                    "LocalPath": "/opt/ml/processing/input/code",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3CompressionType": "None",
                },
            },
        ],
        "output_config": {
            "Outputs": [
                {
                    "OutputName": "output-1",
                    "AppManaged": False,
                    "S3Output": {
                        "S3Uri": ParameterString(
                            name="output_data", default_value="s3://dummy-bucket/output"
                        ),
                        "LocalPath": "/opt/ml/processing/output",
                        "S3UploadMode": "EndOfJob",
                    },
                }
            ],
            "KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
        },
        "job_name": "test_job",
        "resources": {
            "ClusterConfig": {
                "InstanceType": "ml.m4.xlarge",
                "InstanceCount": 1,
                "VolumeSizeInGB": 100,
                "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
            }
        },
        "stopping_condition": {"MaxRuntimeInSeconds": 3600},
        "app_specification": {
            "ImageUri": "custom-image-uri",
            "ContainerArguments": [
                "--input-data",
                '{"Get": "Parameters.input_param"}',
                "--output-path",
                '{"Get": "Parameters.output_param"}',
                "--exec-arg",
                '{"Get": "Execution.ExecutionTest"}',
                "--join-arg",
                '{"Std:Join": {"On": "/", "Values": ["s3://bucket", "prefix", "file.txt"]}}',
                "--string-param",
                "test-variable",
            ],
            "ContainerEntrypoint": ["python3", "/opt/ml/processing/input/code/processing_code.py"],
        },
        "environment": {"my_env_variable": "my_env_variable_value"},
        "network_config": {
            "EnableNetworkIsolation": True,
            "EnableInterContainerTrafficEncryption": True,
            "VpcConfig": {
                "SecurityGroupIds": ["my_security_group_id"],
                "Subnets": ["my_subnet_id"],
            },
        },
        "role_arn": "arn:aws:iam::012345678901:role/SageMakerRole",
        "tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
        "experiment_config": {"ExperimentName": "AnExperiment"},
    }

    # Run processor
    processor.run(
        code="/local/path/to/processing_code.py",
        inputs=_get_data_inputs_with_parameters(),
        outputs=_get_data_outputs_with_parameters(),
        arguments=[
            "--input-data",
            input_param,
            "--output-path",
            output_param,
            "--exec-arg",
            exec_var,
            "--join-arg",
            join_var,
            "--string-param",
            dummy_str_var,
        ],
        wait=True,
        logs=False,
        job_name="test_job",
        experiment_config={"ExperimentName": "AnExperiment"},
    )

    # Assert
    sagemaker_session.process.assert_called_with(**expected_args)
    assert "test_job" in processor._current_job_name
0 commit comments