|
46 | 46 | SageMakerJobStepRetryPolicy,
|
47 | 47 | )
|
48 | 48 | from sagemaker.xgboost import XGBoostModel
|
| 49 | +from sagemaker.lambda_helper import Lambda |
| 50 | +from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum |
49 | 51 | from tests.unit import DATA_DIR
|
50 | 52 | from tests.unit.sagemaker.workflow.helpers import CustomStep
|
51 | 53 |
|
@@ -844,3 +846,44 @@ def _verify_register_model_container_definition(
|
844 | 846 | if submit_dir and not submit_dir.startswith("s3://"):
|
845 | 847 | # exclude the s3 path assertion as it contains timestamp
|
846 | 848 | assert submit_dir == expected_submit_dir
|
| 849 | + |
| 850 | + |
| 851 | +def test_model_step_with_lambda_property_reference(pipeline_session): |
| 852 | + lambda_step = LambdaStep( |
| 853 | + name="MyLambda", |
| 854 | + lambda_func=Lambda( |
| 855 | + function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda" |
| 856 | + ), |
| 857 | + outputs=[ |
| 858 | + LambdaOutput(output_name="model_image", output_type=LambdaOutputTypeEnum.String), |
| 859 | + LambdaOutput(output_name="model_artifact", output_type=LambdaOutputTypeEnum.String), |
| 860 | + ], |
| 861 | + ) |
| 862 | + |
| 863 | + model = PyTorchModel( |
| 864 | + name="MyModel", |
| 865 | + framework_version="1.8.0", |
| 866 | + py_version="py3", |
| 867 | + image_uri=lambda_step.properties.Outputs["model_image"], |
| 868 | + model_data=lambda_step.properties.Outputs["model_artifact"], |
| 869 | + sagemaker_session=pipeline_session, |
| 870 | + entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", |
| 871 | + role=_ROLE, |
| 872 | + ) |
| 873 | + |
| 874 | + step_create_model = ModelStep(name="mymodelstep", step_args=model.create()) |
| 875 | + |
| 876 | + pipeline = Pipeline( |
| 877 | + name="MyPipeline", |
| 878 | + steps=[lambda_step, step_create_model], |
| 879 | + sagemaker_session=pipeline_session, |
| 880 | + ) |
| 881 | + steps = json.loads(pipeline.definition())["Steps"] |
| 882 | + repack_step = steps[1] |
| 883 | + assert repack_step["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"][ |
| 884 | + "S3Uri" |
| 885 | + ] == {"Get": "Steps.MyLambda.OutputParameters['model_artifact']"} |
| 886 | + register_step = steps[2] |
| 887 | + assert register_step["Arguments"]["PrimaryContainer"]["Image"] == { |
| 888 | + "Get": "Steps.MyLambda.OutputParameters['model_image']" |
| 889 | + } |
0 commit comments