
Commit 70e59d6

Payton Staub (staubhp) authored and committed
feature: Support model pipelines in CreateModelStep (#2845)
Co-authored-by: Payton Staub <[email protected]>
1 parent 0a65986 commit 70e59d6

File tree

2 files changed (+99, -12 lines)
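
For context, this commit lets `CreateModelStep` accept a `sagemaker.pipeline.PipelineModel` (an inference pipeline of several containers served in sequence) in addition to a single `sagemaker.model.Model`. A minimal usage sketch of the new capability, assuming placeholder image URIs, S3 paths, and role ARN that are not part of the commit:

from sagemaker.model import Model
from sagemaker.pipeline import PipelineModel
from sagemaker.inputs import CreateModelInput
from sagemaker.workflow.steps import CreateModelStep

# Placeholder values -- substitute your own images, model artifacts, and role.
preprocessor = Model(
    image_uri="<preprocessor-image-uri>",
    model_data="s3://my-bucket/preprocessor/model.tar.gz",
    role="<execution-role-arn>",
)
predictor = Model(
    image_uri="<inference-image-uri>",
    model_data="s3://my-bucket/predictor/model.tar.gz",
    role="<execution-role-arn>",
)

# An inference pipeline: requests pass through the containers in order.
pipeline_model = PipelineModel(
    models=[preprocessor, predictor],
    role="<execution-role-arn>",
)

# With this commit, the PipelineModel can be passed to CreateModelStep directly.
step = CreateModelStep(
    name="CreatePipelineModel",
    model=pipeline_model,
    inputs=CreateModelInput(instance_type="ml.m5.xlarge"),
)

Before this change the `model` parameter was annotated as `Model` only, so a `PipelineModel` could not be wired into a pipeline's CreateModel step this way.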

src/sagemaker/workflow/steps.py (+23, -12)
@@ -30,6 +30,7 @@
     TransformInput,
 )
 from sagemaker.model import Model
+from sagemaker.pipeline import PipelineModel
 from sagemaker.processing import (
     ProcessingInput,
     ProcessingJob,
@@ -319,7 +320,7 @@ class CreateModelStep(ConfigurableRetryStep):
     def __init__(
         self,
         name: str,
-        model: Model,
+        model: Union[Model, PipelineModel],
         inputs: CreateModelInput,
         depends_on: Union[List[str], List[Step]] = None,
         retry_policies: List[RetryPolicy] = None,
@@ -333,7 +334,8 @@ def __init__(

         Args:
             name (str): The name of the CreateModel step.
-            model (Model): A `sagemaker.model.Model` instance.
+            model (Model or PipelineModel): A `sagemaker.model.Model`
+                or `sagemaker.pipeline.PipelineModel` instance.
             inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
                 Defaults to `None`.
             depends_on (List[str] or List[Step]): A list of step names or step instances
@@ -358,16 +360,25 @@ def arguments(self) -> RequestType:
         ModelName cannot be included in the arguments.
         """

-        request_dict = self.model.sagemaker_session._create_model_request(
-            name="",
-            role=self.model.role,
-            container_defs=self.model.prepare_container_def(
-                instance_type=self.inputs.instance_type,
-                accelerator_type=self.inputs.accelerator_type,
-            ),
-            vpc_config=self.model.vpc_config,
-            enable_network_isolation=self.model.enable_network_isolation(),
-        )
+        if isinstance(self.model, PipelineModel):
+            request_dict = self.model.sagemaker_session._create_model_request(
+                name="",
+                role=self.model.role,
+                container_defs=self.model.pipeline_container_def(self.inputs.instance_type),
+                vpc_config=self.model.vpc_config,
+                enable_network_isolation=self.model.enable_network_isolation,
+            )
+        else:
+            request_dict = self.model.sagemaker_session._create_model_request(
+                name="",
+                role=self.model.role,
+                container_defs=self.model.prepare_container_def(
+                    instance_type=self.inputs.instance_type,
+                    accelerator_type=self.inputs.accelerator_type,
+                ),
+                vpc_config=self.model.vpc_config,
+                enable_network_isolation=self.model.enable_network_isolation(),
+            )
         request_dict.pop("ModelName")

         return request_dict
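
The new branch differs from the single-model path in two ways: `PipelineModel.pipeline_container_def(instance_type)` yields a list of container definitions (one per model in the pipeline) rather than a single definition, and network isolation is read as the `enable_network_isolation` attribute on `PipelineModel` but called as the `enable_network_isolation()` method on `Model`. A condensed sketch of that dispatch, for illustration only (the helper function below is hypothetical and not part of the SDK):

from sagemaker.pipeline import PipelineModel

def _container_defs_and_isolation(model, instance_type, accelerator_type=None):
    """Hypothetical helper mirroring the branch in CreateModelStep.arguments."""
    if isinstance(model, PipelineModel):
        # A list of container definitions, one per model in the inference pipeline;
        # enable_network_isolation is a plain attribute on PipelineModel here.
        return (
            model.pipeline_container_def(instance_type),
            model.enable_network_isolation,
        )
    # A single container definition; enable_network_isolation() is a method on Model.
    return (
        model.prepare_container_def(
            instance_type=instance_type,
            accelerator_type=accelerator_type,
        ),
        model.enable_network_isolation(),
    )

In both cases the step builds the request through `_create_model_request` and then pops `ModelName`, since, per the docstring above, the model name cannot be included in the step's arguments.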

tests/unit/sagemaker/workflow/test_steps.py (+76, -0)
@@ -62,6 +62,10 @@
     CreateModelStep,
     CacheConfig,
 )
+from sagemaker.pipeline import PipelineModel
+from sagemaker.sparkml import SparkMLModel
+from sagemaker.predictor import Predictor
+from sagemaker.model import FrameworkModel
 from tests.unit import DATA_DIR

 DUMMY_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
@@ -89,6 +93,21 @@ def properties(self):
         return self._properties


+class DummyFrameworkModel(FrameworkModel):
+    def __init__(self, sagemaker_session, **kwargs):
+        super(DummyFrameworkModel, self).__init__(
+            "s3://bucket/model_1.tar.gz",
+            "mi-1",
+            ROLE,
+            os.path.join(DATA_DIR, "dummy_script.py"),
+            sagemaker_session=sagemaker_session,
+            **kwargs,
+        )
+
+    def create_predictor(self, endpoint_name):
+        return Predictor(endpoint_name, self.sagemaker_session)
+
+
 @pytest.fixture
 def boto_session():
     role_mock = Mock()
@@ -704,6 +723,63 @@ def test_create_model_step(sagemaker_session):
     assert step.properties.ModelName.expr == {"Get": "Steps.MyCreateModelStep.ModelName"}


+@patch("tarfile.open")
+@patch("time.strftime", return_value="2017-10-10-14-14-15")
+def test_create_model_step_with_model_pipeline(tfo, time, sagemaker_session):
+    framework_model = DummyFrameworkModel(sagemaker_session)
+    sparkml_model = SparkMLModel(
+        model_data="s3://bucket/model_2.tar.gz",
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
+    )
+    model = PipelineModel(
+        models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session
+    )
+    inputs = CreateModelInput(
+        instance_type="c4.4xlarge",
+        accelerator_type="ml.eia1.medium",
+    )
+    step = CreateModelStep(
+        name="MyCreateModelStep",
+        depends_on=["TestStep"],
+        display_name="MyCreateModelStep",
+        description="TestDescription",
+        model=model,
+        inputs=inputs,
+    )
+    step.add_depends_on(["SecondTestStep"])
+
+    assert step.to_request() == {
+        "Name": "MyCreateModelStep",
+        "Type": "Model",
+        "Description": "TestDescription",
+        "DisplayName": "MyCreateModelStep",
+        "DependsOn": ["TestStep", "SecondTestStep"],
+        "Arguments": {
+            "Containers": [
+                {
+                    "Environment": {
+                        "SAGEMAKER_PROGRAM": "dummy_script.py",
+                        "SAGEMAKER_SUBMIT_DIRECTORY": "s3://my-bucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz",
+                        "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
+                        "SAGEMAKER_REGION": "us-west-2",
+                    },
+                    "Image": "mi-1",
+                    "ModelDataUrl": "s3://bucket/model_1.tar.gz",
+                },
+                {
+                    "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
+                    "Image": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4",
+                    "ModelDataUrl": "s3://bucket/model_2.tar.gz",
+                },
+            ],
+            "ExecutionRoleArn": "DummyRole",
+        },
+    }
+    assert step.properties.ModelName.expr == {"Get": "Steps.MyCreateModelStep.ModelName"}
+
+
 def test_transform_step(sagemaker_session):
     transformer = Transformer(
         model_name=MODEL_NAME,
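
The test's final assertion confirms that the step still exposes `properties.ModelName` as a pipeline property reference, so downstream steps can consume the created pipeline model by reference. A hedged sketch of that pattern, using placeholder values that are not part of this commit:

from sagemaker.model import Model
from sagemaker.pipeline import PipelineModel
from sagemaker.inputs import CreateModelInput, TransformInput
from sagemaker.transformer import Transformer
from sagemaker.workflow.steps import CreateModelStep, TransformStep

# Illustrative placeholders only.
pipeline_model = PipelineModel(
    models=[Model(image_uri="<image-uri>", model_data="s3://my-bucket/model.tar.gz", role="<role-arn>")],
    role="<role-arn>",
)
create_step = CreateModelStep(
    name="MyCreateModelStep",
    model=pipeline_model,
    inputs=CreateModelInput(instance_type="ml.m5.xlarge"),
)

# A downstream step consuming the property reference asserted in the test above.
transformer = Transformer(
    model_name=create_step.properties.ModelName,  # resolves to {"Get": "Steps.MyCreateModelStep.ModelName"}
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://my-bucket/transform-output",
)
transform_step = TransformStep(
    name="MyTransformStep",
    transformer=transformer,
    inputs=TransformInput(data="s3://my-bucket/transform-input"),
)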
