|
46 | 46 | SageMakerJobStepRetryPolicy,
|
47 | 47 | )
|
48 | 48 | from sagemaker.xgboost import XGBoostModel
|
| 49 | +from sagemaker.lambda_helper import Lambda |
| 50 | +from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum |
49 | 51 | from tests.unit import DATA_DIR
|
50 | 52 |
|
51 | 53 | _IMAGE_URI = "fakeimage"
|
@@ -839,3 +841,36 @@ def _verify_register_model_container_definition(
|
839 | 841 | if submit_dir and not submit_dir.startswith("s3://"):
|
840 | 842 | # exclude the s3 path assertion as it contains timestamp
|
841 | 843 | assert submit_dir == expected_submit_dir
|
| 844 | + |
| 845 | + |
| 846 | +def test_model_step_with_lambda_property_reference(pipeline_session): |
| 847 | + lambda_step = LambdaStep( |
| 848 | + name="MyLambda", |
| 849 | + lambda_func=Lambda( |
| 850 | + function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda" |
| 851 | + ), |
| 852 | + outputs=[ |
| 853 | + LambdaOutput(output_name="model_image", output_type=LambdaOutputTypeEnum.String), |
| 854 | + LambdaOutput(output_name="model_artifact", output_type=LambdaOutputTypeEnum.String), |
| 855 | + ], |
| 856 | + ) |
| 857 | + |
| 858 | + model = PyTorchModel( |
| 859 | + name="MyModel", |
| 860 | + framework_version="1.8.0", |
| 861 | + py_version="py3", |
| 862 | + image_uri=lambda_step.properties.Outputs["model_image"], |
| 863 | + model_data=lambda_step.properties.Outputs["model_artifact"], |
| 864 | + sagemaker_session=pipeline_session, |
| 865 | + entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", |
| 866 | + role=_ROLE, |
| 867 | + ) |
| 868 | + |
| 869 | + step_create_model = ModelStep(name="mymodelstep", step_args=model.create()) |
| 870 | + |
| 871 | + pipeline = Pipeline( |
| 872 | + name="MyPipeline", |
| 873 | + steps=[lambda_step, step_create_model], |
| 874 | + sagemaker_session=pipeline_session, |
| 875 | + ) |
| 876 | + assert pipeline.definition() is not None |
0 commit comments