diff --git a/tests/unit/sagemaker/workflow/helpers.py b/tests/unit/sagemaker/workflow/helpers.py index 5fdc30f8ba..67405e9372 100644 --- a/tests/unit/sagemaker/workflow/helpers.py +++ b/tests/unit/sagemaker/workflow/helpers.py @@ -41,7 +41,10 @@ def __init__(self, name, display_name=None, description=None, depends_on=None): super(CustomStep, self).__init__( name, display_name, description, StepTypeEnum.TRAINING, depends_on ) - self._properties = Properties(path=f"Steps.{name}") + # for testing property reference, we just use DescribeTrainingJobResponse shape here. + self._properties = Properties( + path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse" + ) @property def arguments(self): diff --git a/tests/unit/sagemaker/workflow/test_transform_step.py b/tests/unit/sagemaker/workflow/test_transform_step.py index ee5c27f757..2a75349d56 100644 --- a/tests/unit/sagemaker/workflow/test_transform_step.py +++ b/tests/unit/sagemaker/workflow/test_transform_step.py @@ -23,10 +23,13 @@ from sagemaker.parameter import IntegerParameter from sagemaker.tuner import HyperparameterTuner from sagemaker.workflow.pipeline_context import PipelineSession +from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.steps import TransformStep, TransformInput from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.functions import Join +from sagemaker.workflow import is_pipeline_variable from sagemaker.transformer import Transformer @@ -53,6 +56,7 @@ def client(): client_mock._client_config.user_agent = ( "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" ) + client_mock.describe_model.return_value = {"PrimaryContainer": {}, "Containers": {}} return client_mock @@ -80,18 +84,44 @@ def pipeline_session(boto_session, client): ) -def test_transform_step_with_transformer(pipeline_session): - model_name = ParameterString("ModelName") +@pytest.mark.parametrize( + "model_name", + [ + "my-model", + ParameterString("ModelName"), + ParameterString("ModelName", default_value="my-model"), + Join(on="-", values=["my", "model"]), + CustomStep(name="custom-step").properties.RoleArn, + ], +) +@pytest.mark.parametrize( + "data", + [ + "s3://my-bucket/my-data", + ParameterString("MyTransformInput"), + ParameterString("MyTransformInput", default_value="s3://my-model"), + Join(on="/", values=["s3://my-bucket", "my-transform-data", "input"]), + CustomStep(name="custom-step").properties.OutputDataConfig.S3OutputPath, + ], +) +@pytest.mark.parametrize( + "output_path", + [ + "s3://my-bucket/my-output-path", + ParameterString("MyOutputPath"), + ParameterString("MyOutputPath", default_value="s3://my-output"), + Join(on="/", values=["s3://my-bucket", "my-transform-data", "output"]), + CustomStep(name="custom-step").properties.OutputDataConfig.S3OutputPath, + ], +) +def test_transform_step_with_transformer(model_name, data, output_path, pipeline_session): transformer = Transformer( model_name=model_name, instance_type="ml.m5.xlarge", instance_count=1, - output_path=f"s3://{pipeline_session.default_bucket()}/Transform", + output_path=output_path, sagemaker_session=pipeline_session, ) - data = ParameterString( - name="Data", default_value=f"s3://{pipeline_session.default_bucket()}/batch-data" - ) transform_inputs = TransformInput(data=data) with warnings.catch_warnings(record=True) as w: @@ -123,13 +153,27 @@ def test_transform_step_with_transformer(pipeline_session): parameters=[model_name, data], sagemaker_session=pipeline_session, ) - step_args.args["ModelName"] = model_name.expr - step_args.args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = data.expr - assert json.loads(pipeline.definition())["Steps"][0] == { - "Name": "MyTransformStep", - "Type": "Transform", - "Arguments": step_args.args, - } + step_args = step_args.args + step_def = json.loads(pipeline.definition())["Steps"][0] + step_args["ModelName"] = model_name.expr if is_pipeline_variable(model_name) else model_name + step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = ( + data.expr if is_pipeline_variable(data) else data + ) + step_args["TransformOutput"]["S3OutputPath"] = ( + output_path.expr if is_pipeline_variable(output_path) else output_path + ) + + del ( + step_args["ModelName"], + step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"], + step_args["TransformOutput"]["S3OutputPath"], + ) + del ( + step_def["Arguments"]["ModelName"], + step_def["Arguments"]["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"], + step_def["Arguments"]["TransformOutput"]["S3OutputPath"], + ) + assert step_def == {"Name": "MyTransformStep", "Type": "Transform", "Arguments": step_args} @pytest.mark.parametrize(