Skip to content

Commit 190fcb5

Browse files
Merge branch 'master' into docstring-fix
2 parents 350e164 + 617bfab commit 190fcb5

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

src/sagemaker/fw_utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
import os
1818
import re
19+
import time
1920
import shutil
2021
import tempfile
2122
from collections import namedtuple
@@ -24,6 +25,7 @@
2425
import sagemaker.image_uris
2526
from sagemaker.session_settings import SessionSettings
2627
import sagemaker.utils
28+
from sagemaker.workflow import is_pipeline_variable
2729

2830
from sagemaker.deprecations import renamed_warning
2931

@@ -395,8 +397,10 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
395397
Returns:
396398
str: the key prefix to be used in uploading code
397399
"""
398-
training_job_name = sagemaker.utils.name_from_image(image)
399-
return "/".join(filter(None, [code_location_key_prefix, model_name or training_job_name]))
400+
name_from_image = f"/model_code/{int(time.time())}"
401+
if not is_pipeline_variable(image):
402+
name_from_image = sagemaker.utils.name_from_image(image)
403+
return "/".join(filter(None, [code_location_key_prefix, model_name or name_from_image]))
400404

401405

402406
def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution):

tests/unit/sagemaker/workflow/test_model_step.py

+43
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
SageMakerJobStepRetryPolicy,
4747
)
4848
from sagemaker.xgboost import XGBoostModel
49+
from sagemaker.lambda_helper import Lambda
50+
from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
4951
from tests.unit import DATA_DIR
5052
from tests.unit.sagemaker.workflow.helpers import CustomStep
5153

@@ -844,3 +846,44 @@ def _verify_register_model_container_definition(
844846
if submit_dir and not submit_dir.startswith("s3://"):
845847
# exclude the s3 path assertion as it contains timestamp
846848
assert submit_dir == expected_submit_dir
849+
850+
851+
def test_model_step_with_lambda_property_reference(pipeline_session):
852+
lambda_step = LambdaStep(
853+
name="MyLambda",
854+
lambda_func=Lambda(
855+
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda"
856+
),
857+
outputs=[
858+
LambdaOutput(output_name="model_image", output_type=LambdaOutputTypeEnum.String),
859+
LambdaOutput(output_name="model_artifact", output_type=LambdaOutputTypeEnum.String),
860+
],
861+
)
862+
863+
model = PyTorchModel(
864+
name="MyModel",
865+
framework_version="1.8.0",
866+
py_version="py3",
867+
image_uri=lambda_step.properties.Outputs["model_image"],
868+
model_data=lambda_step.properties.Outputs["model_artifact"],
869+
sagemaker_session=pipeline_session,
870+
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
871+
role=_ROLE,
872+
)
873+
874+
step_create_model = ModelStep(name="mymodelstep", step_args=model.create())
875+
876+
pipeline = Pipeline(
877+
name="MyPipeline",
878+
steps=[lambda_step, step_create_model],
879+
sagemaker_session=pipeline_session,
880+
)
881+
steps = json.loads(pipeline.definition())["Steps"]
882+
repack_step = steps[1]
883+
assert repack_step["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"][
884+
"S3Uri"
885+
] == {"Get": "Steps.MyLambda.OutputParameters['model_artifact']"}
886+
register_step = steps[2]
887+
assert register_step["Arguments"]["PrimaryContainer"]["Image"] == {
888+
"Get": "Steps.MyLambda.OutputParameters['model_image']"
889+
}

0 commit comments

Comments
 (0)