Skip to content

Commit 50f49a2

Browse files
ravishankar-sivaramannavinnsmufaddal-rohawalajeniyat
authored andcommitted
feature: CompilationStep support for Sagemaker Pipelines (aws#2740)
Co-authored-by: Navin Soni <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]> Co-authored-by: Jeniya Tabassum <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]>
1 parent 960fc54 commit 50f49a2

File tree

5 files changed

+302
-12
lines changed

5 files changed

+302
-12
lines changed

src/sagemaker/inputs.py

+61
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,67 @@ class CreateModelInput(object):
136136
accelerator_type: str = attr.ib(default=None)
137137

138138

139+
@attr.s
140+
class CompilationInput(object):
141+
"""Create a class containing all the parameters.
142+
143+
It can be used when calling ``sagemaker.model.Model.compile_model()``
144+
145+
Parameters:
146+
target_instance_type(str): Identifies the device that you want to
147+
run your model after compilation, for example: ml_c5. For allowed
148+
strings see
149+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
150+
input_shape(str): Specifies the name and shape of the expected
151+
inputs for your trained model in json dictionary form, for
152+
example: {'data': [1,3,1024,1024]}, or {'var1': [1,1,28,28],
153+
'var2': [1,1,28,28]}
154+
output_path(str): Specifies where to store the compiled model
155+
framework (str, optional): The framework that is used to train the original
156+
model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
157+
'onnx', 'xgboost' (default: None)
158+
framework_version (str, optional): The version of the framework (default: None)
159+
compile_max_run (int, optional): Timeout in seconds for compilation (default:
160+
15 * 60). After this amount of time Amazon SageMaker Neo
161+
terminates the compilation job regardless of its current status.
162+
tags (list[dict], optional): List of tags for labeling a compilation job.
163+
For more, see
164+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
165+
job_name (str, optional): The name of the compilation job (default: None)
166+
target_platform_os (str, optional): Target Platform OS, for example: 'LINUX'.
167+
(default: None)
168+
For allowed strings see
169+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
170+
It can be used instead of target_instance_family.
171+
target_platform_arch (str, optional): Target Platform Architecture, for example: 'X86_64'.
172+
(default: None)
173+
For allowed strings see
174+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
175+
It can be used instead of target_instance_family.
176+
target_platform_accelerator (str, optional): Target Platform Accelerator,
177+
for example: 'NVIDIA'. (default: None)
178+
For allowed strings see
179+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
180+
It can be used instead of target_instance_family.
181+
compiler_options (dict, optional): Additional parameters for compiler. (default: None)
182+
Compiler Options are TargetPlatform / target_instance_family specific. See
183+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html for details.
184+
"""
185+
186+
target_instance_type: str = attr.ib(default=None)
187+
input_shape: dict = attr.ib(factory=dict)
188+
output_path: str = attr.ib(default=None)
189+
framework: str = attr.ib(default=None)
190+
framework_version: str = attr.ib(default=None)
191+
compile_max_run: int = attr.ib(default=15 * 60)
192+
tags: list = attr.ib(factory=list)
193+
job_name: str = attr.ib(default=None)
194+
target_platform_os: str = attr.ib(default=None)
195+
target_platform_arch: str = attr.ib(default=None)
196+
target_platform_accelerator: str = attr.ib(default=None)
197+
compiler_options: dict = attr.ib(default=None)
198+
199+
139200
@attr.s
140201
class TransformInput(object):
141202
"""Create a class containing all the parameters.

src/sagemaker/model.py

+53
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
utils,
3030
git_utils,
3131
)
32+
from sagemaker.inputs import CompilationInput
3233
from sagemaker.deprecations import removed_kwargs
3334
from sagemaker.predictor import PredictorBase
3435
from sagemaker.transformer import Transformer
@@ -410,6 +411,58 @@ def _compilation_job_config(
410411
"job_name": job_name,
411412
}
412413

414+
def _get_compilation_args(self, estimator, inputs):
415+
"""Constructs a dict of arguments for an Amazon SageMaker compilation job from estimator.
416+
417+
Args:
418+
estimator (sagemaker.estimator.EstimatorBase): Estimator object
419+
created by the user.
420+
inputs (CompilationInput): class containing all the parameters that
421+
can be used when calling ``sagemaker.model.Model.compile_model()``
422+
"""
423+
424+
if not isinstance(inputs, CompilationInput):
425+
raise TypeError("Your inputs must be provided as CompilationInput objects.")
426+
target_instance_family = inputs.target_instance_type
427+
input_shape = inputs.input_shape
428+
output_path = inputs.output_path
429+
role = estimator.role
430+
compile_max_run = inputs.compile_max_run
431+
job_name = estimator._compilation_job_name()
432+
framework = inputs.framework or self._framework()
433+
if framework is None:
434+
raise ValueError(
435+
"You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)
436+
)
437+
if framework not in NEO_ALLOWED_FRAMEWORKS:
438+
raise ValueError(
439+
"You must provide valid framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)
440+
)
441+
if self.model_data is None:
442+
raise ValueError("You must provide an S3 path to the compressed model artifacts.")
443+
tags = inputs.tags
444+
target_platform_os = inputs.target_platform_os
445+
target_platform_arch = inputs.target_platform_arch
446+
target_platform_accelerator = inputs.target_platform_accelerator
447+
compiler_options = inputs.compiler_options
448+
framework_version = inputs.framework_version or self._get_framework_version()
449+
450+
return self._compilation_job_config(
451+
target_instance_family,
452+
input_shape,
453+
output_path,
454+
role,
455+
compile_max_run,
456+
job_name,
457+
framework,
458+
tags,
459+
target_platform_os,
460+
target_platform_arch,
461+
target_platform_accelerator,
462+
compiler_options,
463+
framework_version,
464+
)
465+
413466
def _compilation_image_uri(self, region, target_instance_type, framework, framework_version):
414467
"""Retrieve the Neo or Inferentia image URI.
415468

src/sagemaker/session.py

+51
Original file line numberDiff line numberDiff line change
@@ -1840,6 +1840,57 @@ def compile_model(
18401840
LOGGER.info("Creating compilation-job with name: %s", job_name)
18411841
self.sagemaker_client.create_compilation_job(**compilation_job_request)
18421842

1843+
def _get_compilation_request(
1844+
self,
1845+
job_name,
1846+
input_model_config,
1847+
output_model_config,
1848+
role,
1849+
stop_condition,
1850+
tags=None,
1851+
vpc_config=None,
1852+
):
1853+
"""Construct CreateCompilationJob request
1854+
1855+
Args:
1856+
input_model_config (dict): the trained model and the Amazon S3 location where it is
1857+
stored.
1858+
output_model_config (dict): Identifies the Amazon S3 location where you want Amazon
1859+
SageMaker Neo to save the results of compilation job
1860+
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker Neo
1861+
compilation jobs use this role to access model artifacts. You must grant
1862+
sufficient permissions to this role.
1863+
job_name (str): Name of the compilation job being created.
1864+
stop_condition (dict): Defines when compilation job shall finish. Contains entries
1865+
that can be understood by the service like ``MaxRuntimeInSeconds``.
1866+
tags (list[dict]): List of tags for labeling a compile model job. For more, see
1867+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
1868+
vpc_config (dict): Contains values for VpcConfig:
1869+
* subnets (list[str]): List of subnet ids.
1870+
The key in vpc_config is 'Subnets'.
1871+
* security_group_ids (list[str]): List of security group ids.
1872+
The key in vpc_config is 'SecurityGroupIds'.
1873+
Returns:
1874+
dict: A dictionary for CreateCompilationJob request
1875+
"""
1876+
1877+
compilation_request = {
1878+
"InputConfig": input_model_config,
1879+
"OutputConfig": output_model_config,
1880+
"RoleArn": role,
1881+
"StoppingCondition": stop_condition,
1882+
"CompilationJobName": job_name,
1883+
}
1884+
1885+
tags = _append_project_tags(tags)
1886+
if tags is not None:
1887+
compilation_request["Tags"] = tags
1888+
1889+
if vpc_config is not None:
1890+
compilation_request["VpcConfig"] = vpc_config
1891+
1892+
return compilation_request
1893+
18431894
def package_model_for_edge(
18441895
self,
18451896
output_model_config,

src/sagemaker/workflow/steps.py

+88-11
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17-
1817
from enum import Enum
1918
from typing import Dict, List, Union
2019

2120
import attr
2221

2322
from sagemaker.estimator import EstimatorBase, _TrainingJob
24-
from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput
23+
from sagemaker.inputs import (
24+
CompilationInput,
25+
CreateModelInput,
26+
FileSystemInput,
27+
TrainingInput,
28+
TransformInput,
29+
)
2530
from sagemaker.model import Model
2631
from sagemaker.processing import (
2732
ProcessingInput,
@@ -31,16 +36,9 @@
3136
)
3237
from sagemaker.transformer import Transformer, _TransformJob
3338
from sagemaker.tuner import HyperparameterTuner, _TuningJob
34-
from sagemaker.workflow.entities import (
35-
DefaultEnumMeta,
36-
Entity,
37-
RequestType,
38-
)
39-
from sagemaker.workflow.properties import (
40-
PropertyFile,
41-
Properties,
42-
)
39+
from sagemaker.workflow.entities import DefaultEnumMeta, Entity, RequestType
4340
from sagemaker.workflow.functions import Join
41+
from sagemaker.workflow.properties import Properties, PropertyFile
4442
from sagemaker.workflow.retry import RetryPolicy
4543

4644

@@ -55,6 +53,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
5553
TRANSFORM = "Transform"
5654
CALLBACK = "Callback"
5755
TUNING = "Tuning"
56+
COMPILATION = "Compilation"
5857
LAMBDA = "Lambda"
5958
EMR = "EMR"
6059

@@ -682,3 +681,81 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
682681
"output/model.tar.gz",
683682
],
684683
)
684+
685+
686+
class CompilationStep(ConfigurableRetryStep):
687+
"""Compilation step for workflow."""
688+
689+
def __init__(
690+
self,
691+
name: str,
692+
estimator: EstimatorBase,
693+
model: Model,
694+
inputs: CompilationInput = None,
695+
job_arguments: List[str] = None,
696+
depends_on: Union[List[str], List[Step]] = None,
697+
retry_policies: List[RetryPolicy] = None,
698+
display_name: str = None,
699+
description: str = None,
700+
cache_config: CacheConfig = None,
701+
):
702+
"""Construct a CompilationStep.
703+
704+
Given an `EstimatorBase` and a `sagemaker.model.Model` instance construct a CompilationStep.
705+
706+
In addition to the estimator and Model instances, the other arguments are those that are
707+
supplied to the `compile_model` method of the `sagemaker.model.Model.compile_model`.
708+
709+
Args:
710+
name (str): The name of the compilation step.
711+
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
712+
model (Model): A `sagemaker.model.Model` instance.
713+
inputs (CompilationInput): A `sagemaker.inputs.CompilationInput` instance.
714+
Defaults to `None`.
715+
job_arguments (List[str]): A list of strings to be passed into the processing job.
716+
Defaults to `None`.
717+
depends_on (List[str] or List[Step]): A list of step names or step instances
718+
this `sagemaker.workflow.steps.CompilationStep` depends on
719+
retry_policies (List[RetryPolicy]): A list of retry policy
720+
display_name (str): The display name of the compilation step.
721+
description (str): The description of the compilation step.
722+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
723+
"""
724+
super(CompilationStep, self).__init__(
725+
name, StepTypeEnum.COMPILATION, display_name, description, depends_on, retry_policies
726+
)
727+
self.estimator = estimator
728+
self.model = model
729+
self.inputs = inputs
730+
self.job_arguments = job_arguments
731+
self._properties = Properties(
732+
path=f"Steps.{name}", shape_name="DescribeCompilationJobResponse"
733+
)
734+
self.cache_config = cache_config
735+
736+
@property
737+
def arguments(self) -> RequestType:
738+
"""The arguments dict that is used to call `create_compilation_job`.
739+
740+
NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
741+
The TrainingJobName and ExperimentConfig attributes cannot be included.
742+
"""
743+
744+
compilation_args = self.model._get_compilation_args(self.estimator, self.inputs)
745+
request_dict = self.model.sagemaker_session._get_compilation_request(**compilation_args)
746+
request_dict.pop("CompilationJobName")
747+
748+
return request_dict
749+
750+
@property
751+
def properties(self):
752+
"""A Properties object representing the DescribeTrainingJobResponse data model."""
753+
return self._properties
754+
755+
def to_request(self) -> RequestType:
756+
"""Updates the dictionary with cache configuration."""
757+
request_dict = super().to_request()
758+
if self.cache_config:
759+
request_dict.update(self.cache_config.config)
760+
761+
return request_dict

tests/unit/sagemaker/workflow/test_steps.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker.debugger import DEBUGGER_FLAG, ProfilerConfig
2727
from sagemaker.estimator import Estimator
2828
from sagemaker.tensorflow import TensorFlow
29-
from sagemaker.inputs import TrainingInput, TransformInput, CreateModelInput
29+
from sagemaker.inputs import TrainingInput, TransformInput, CreateModelInput, CompilationInput
3030
from sagemaker.model import Model
3131
from sagemaker.processing import (
3232
Processor,
@@ -52,6 +52,7 @@
5252
)
5353
from sagemaker.workflow.steps import (
5454
ProcessingStep,
55+
CompilationStep,
5556
ConfigurableRetryStep,
5657
StepTypeEnum,
5758
TrainingStep,
@@ -1054,3 +1055,50 @@ def test_multi_algo_tuning_step(sagemaker_session):
10541055
],
10551056
},
10561057
}
1058+
1059+
1060+
def test_compilation_step(sagemaker_session):
1061+
estimator = Estimator(
1062+
image_uri=IMAGE_URI,
1063+
role=ROLE,
1064+
instance_count=1,
1065+
instance_type="ml.c5.4xlarge",
1066+
profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
1067+
rules=[],
1068+
sagemaker_session=sagemaker_session,
1069+
)
1070+
1071+
model = Model(
1072+
image_uri=IMAGE_URI,
1073+
model_data="s3://output/tensorflow.tar.gz",
1074+
sagemaker_session=sagemaker_session,
1075+
)
1076+
1077+
compilation_input = CompilationInput(
1078+
target_instance_type="ml_inf",
1079+
input_shape={"data": [1, 3, 1024, 1024]},
1080+
output_path="s3://output",
1081+
compile_max_run=100,
1082+
framework="tensorflow",
1083+
job_name="compile-model",
1084+
compiler_options=None,
1085+
)
1086+
compilation_step = CompilationStep(
1087+
name="MyCompilationStep", estimator=estimator, model=model, inputs=compilation_input
1088+
)
1089+
1090+
assert compilation_step.to_request() == {
1091+
"Name": "MyCompilationStep",
1092+
"Type": "Compilation",
1093+
"Arguments": {
1094+
"InputConfig": {
1095+
"DataInputConfig": '{"data": [1, 3, 1024, 1024]}',
1096+
"Framework": "TENSORFLOW",
1097+
"S3Uri": "s3://output/tensorflow.tar.gz",
1098+
},
1099+
"OutputConfig": {"S3OutputLocation": "s3://output", "TargetDevice": "ml_inf"},
1100+
"RoleArn": ROLE,
1101+
"StoppingCondition": {"MaxRuntimeInSeconds": 100},
1102+
"Tags": [],
1103+
},
1104+
}

0 commit comments

Comments
 (0)