add input parameterization tests for workflow job steps #3150

Merged
merged 4 commits on Jun 3, 2022
285 changes: 170 additions & 115 deletions tests/unit/sagemaker/workflow/test_processing_step.py
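
This PR extends the processing step unit tests so that each framework processor case also runs against parameterized ProcessingInput sources and ProcessingOutput destinations: literal S3 URIs, ParameterString values (with and without defaults), and Join expressions. A minimal sketch of the pattern under test, using illustrative names not taken from the diff:

from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.functions import Join

# Hypothetical pipeline parameter standing in for the input location.
input_uri = ParameterString(name="InputUri", default_value="s3://my-bucket/my-processing")

inputs = [ProcessingInput(source=input_uri, destination="/opt/ml/processing/input")]
outputs = [
    ProcessingOutput(
        source="/opt/ml/processing/output",
        # Destinations can likewise be composed from pipeline values at run time.
        destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
    )
]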
@@ -24,7 +24,13 @@
from sagemaker.tuner import HyperparameterTuner
from sagemaker.workflow.pipeline_context import PipelineSession

from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor
from sagemaker.processing import (
Processor,
ScriptProcessor,
FrameworkProcessor,
ProcessingOutput,
ProcessingInput,
)
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.tensorflow.processing import TensorFlowProcessor
@@ -34,11 +40,12 @@
from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.spark.processing import SparkJarProcessor, PySparkProcessor

from sagemaker.processing import ProcessingInput

from sagemaker.workflow.steps import CacheConfig, ProcessingStep
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.properties import PropertyFile
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.functions import Join

from sagemaker.network import NetworkConfig
from sagemaker.pytorch.estimator import PyTorch
@@ -62,6 +69,144 @@
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
INSTANCE_TYPE = "ml.m4.xlarge"

FRAMEWORK_PROCESSOR = [
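# (processor, run() kwargs) pairs, one per framework processor flavor under test.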
(
FrameworkProcessor(
framework_version="1.8",
instance_type=INSTANCE_TYPE,
instance_count=1,
role=ROLE,
estimator_cls=PyTorch,
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
SKLearnProcessor(
framework_version="0.23-1",
instance_type=INSTANCE_TYPE,
instance_count=1,
role=ROLE,
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
PyTorchProcessor(
role=ROLE,
instance_type=INSTANCE_TYPE,
instance_count=1,
framework_version="1.8.0",
py_version="py3",
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
TensorFlowProcessor(
role=ROLE,
instance_type=INSTANCE_TYPE,
instance_count=1,
framework_version="2.0",
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
HuggingFaceProcessor(
transformers_version="4.6",
pytorch_version="1.7",
role=ROLE,
instance_count=1,
instance_type="ml.p3.2xlarge",
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
XGBoostProcessor(
framework_version="1.3-1",
py_version="py3",
role=ROLE,
instance_count=1,
instance_type=INSTANCE_TYPE,
base_job_name="test-xgboost",
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
MXNetProcessor(
framework_version="1.4.1",
py_version="py3",
role=ROLE,
instance_count=1,
instance_type=INSTANCE_TYPE,
base_job_name="test-mxnet",
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
DataWranglerProcessor(
role=ROLE,
data_wrangler_flow_source="s3://my-bucket/dw.flow",
instance_count=1,
instance_type=INSTANCE_TYPE,
),
{},
),
(
SparkJarProcessor(
role=ROLE,
framework_version="2.4",
instance_count=1,
instance_type=INSTANCE_TYPE,
),
{
"submit_app": "s3://my-jar",
"submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
"arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
},
),
(
PySparkProcessor(
role=ROLE,
framework_version="2.4",
instance_count=1,
instance_type=INSTANCE_TYPE,
),
{
"submit_app": "s3://my-jar",
"arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
},
),
]

PROCESSING_INPUT = [
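# Input sources under test: a literal S3 URI, a ParameterString with and
# without a default value, and a Join expression built from pipeline values.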
ProcessingInput(source="s3://my-bucket/processing_manifest", destination="processing_manifest"),
ProcessingInput(
source=ParameterString(name="my-processing-input"),
destination="processing-input",
),
ProcessingInput(
source=ParameterString(
name="my-processing-input", default_value="s3://my-bucket/my-processing"
),
destination="processing-input",
),
ProcessingInput(
source=Join(on="/", values=["s3://my-bucket", "my-input"]),
destination="processing-input",
),
]

PROCESSING_OUTPUT = [
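# Output destinations under test, mirroring the input source variants above.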
ProcessingOutput(source="/opt/ml/output", destination="s3://my-bucket/my-output"),
ProcessingOutput(source="/opt/ml/output", destination=ParameterString(name="my-output")),
ProcessingOutput(
source="/opt/ml/output",
destination=ParameterString(name="my-output", default_value="s3://my-bucket/my-output"),
),
ProcessingOutput(
source="/opt/ml/output",
destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
),
]


@pytest.fixture
def client():
@@ -253,117 +398,11 @@ def test_processing_step_with_script_processor(pipeline_session, processing_input
}


@pytest.mark.parametrize(
"framework_processor",
[
(
FrameworkProcessor(
framework_version="1.8",
instance_type=INSTANCE_TYPE,
instance_count=1,
role=ROLE,
estimator_cls=PyTorch,
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
SKLearnProcessor(
framework_version="0.23-1",
instance_type=INSTANCE_TYPE,
instance_count=1,
role=ROLE,
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
PyTorchProcessor(
role=ROLE,
instance_type=INSTANCE_TYPE,
instance_count=1,
framework_version="1.8.0",
py_version="py3",
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
TensorFlowProcessor(
role=ROLE,
instance_type=INSTANCE_TYPE,
instance_count=1,
framework_version="2.0",
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
HuggingFaceProcessor(
transformers_version="4.6",
pytorch_version="1.7",
role=ROLE,
instance_count=1,
instance_type="ml.p3.2xlarge",
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
XGBoostProcessor(
framework_version="1.3-1",
py_version="py3",
role=ROLE,
instance_count=1,
instance_type=INSTANCE_TYPE,
base_job_name="test-xgboost",
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
MXNetProcessor(
framework_version="1.4.1",
py_version="py3",
role=ROLE,
instance_count=1,
instance_type=INSTANCE_TYPE,
base_job_name="test-mxnet",
),
{"code": DUMMY_S3_SCRIPT_PATH},
),
(
DataWranglerProcessor(
role=ROLE,
data_wrangler_flow_source=f"s3://{BUCKET}/dw.flow",
instance_count=1,
instance_type=INSTANCE_TYPE,
),
{},
),
(
SparkJarProcessor(
role=ROLE,
framework_version="2.4",
instance_count=1,
instance_type=INSTANCE_TYPE,
),
{
"submit_app": "s3://my-jar",
"submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
"arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
},
),
(
PySparkProcessor(
role=ROLE,
framework_version="2.4",
instance_count=1,
instance_type=INSTANCE_TYPE,
),
{
"submit_app": "s3://my-jar",
"arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
},
),
],
)
@pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
@pytest.mark.parametrize("processing_input", PROCESSING_INPUT)
@pytest.mark.parametrize("processing_output", PROCESSING_OUTPUT)
def test_processing_step_with_framework_processor(
framework_processor, pipeline_session, processing_input, network_config
framework_processor, pipeline_session, processing_input, processing_output, network_config
):

processor, run_inputs = framework_processor
@@ -373,7 +412,8 @@ def test_processing_step_with_framework_processor(
processor.volume_kms_key = "volume-kms-key"
processor.network_config = network_config

run_inputs["inputs"] = processing_input
run_inputs["inputs"] = [processing_input]
run_inputs["outputs"] = [processing_output]

step_args = processor.run(**run_inputs)

@@ -387,10 +427,25 @@
sagemaker_session=pipeline_session,
)

assert json.loads(pipeline.definition())["Steps"][0] == {
step_args = step_args.args
step_def = json.loads(pipeline.definition())["Steps"][0]

assert step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] == processing_input.source
assert (
step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
== processing_output.destination
)
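# Parametrized URIs serialize as references in the pipeline definition
# (e.g. {"Get": "Parameters.my-processing-input"} or a {"Std:Join": ...} dict),
# while step_args still holds the ParameterString/Join objects themselves,
# so drop S3Uri from both sides before comparing the remaining structure.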

del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]

del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]

assert step_def == {
"Name": "MyProcessingStep",
"Type": "Processing",
"Arguments": step_args.args,
"Arguments": step_args,
}
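
For context on the assertions above, a parametrized S3Uri does not render as a literal string in the definition JSON, which is why the test compares it separately and then deletes it. A rough sketch of the rendered forms, assuming the fixtures above (shapes are illustrative of the SDK's serialization, not captured from this test run):

definition = json.loads(pipeline.definition())
s3_uri = definition["Steps"][0]["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
# For a ParameterString source, s3_uri is a reference such as:
#   {"Get": "Parameters.my-processing-input"}
# For a Join expression, it is a join construct such as:
#   {"Std:Join": {"On": "/", "Values": ["s3://my-bucket", "my-input"]}}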

