Skip to content

feature: CompilationStep support for Sagemaker Pipelines #2740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/sagemaker/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,62 @@ class CreateModelInput(object):
accelerator_type: str = attr.ib(default=None)


@attr.s
class CompilationInput(object):
"""Create a class containing all the parameters.

It can be used when calling ``sagemaker.model.Model.compile_model()``

Parameters:
target_instance_type(str): Identifies the device that you want to
run your model after compilation, for example: ml_c5. For allowed
strings see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
input_shape(str): Specifies the name and shape of the expected
inputs for your trained model in json dictionary form, for
example: {'data': [1,3,1024,1024]}, or {'var1': [1,1,28,28],
'var2': [1,1,28,28]}
output_path(str): Specifies where to store the compiled model
framework (str): The framework that is used to train the original
model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
'onnx', 'xgboost'
framework_version (str): The version of the framework
compile_max_run (int): Timeout in seconds for compilation (default:
15 * 60). After this amount of time Amazon SageMaker Neo
terminates the compilation job regardless of its current status.
tags (list[dict]): List of tags for labeling a compilation job. For
more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
For allowed strings see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
It can be used instead of target_instance_family.
target_platform_arch (str): Target Platform Architecture, for example: 'X86_64'.
For allowed strings see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
It can be used instead of target_instance_family.
target_platform_accelerator (str, optional): Target Platform Accelerator,
for example: 'NVIDIA'. For allowed strings see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
It can be used instead of target_instance_family.
compiler_options (dict, optional): Additional parameters for compiler.
Compiler Options are TargetPlatform / target_instance_family specific. See
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html for details.
"""

target_instance_type: str = attr.ib(default=None)
input_shape: dict = attr.ib(factory=dict)
output_path: str = attr.ib(default=None)
framework: str = attr.ib(default=None)
framework_version: str = attr.ib(default=None)
compile_max_run: int = attr.ib(default=15 * 60)
tags: list = attr.ib(factory=list)
target_platform_os: str = attr.ib(default=None)
target_platform_arch: str = attr.ib(default=None)
target_platform_accelerator: str = attr.ib(default=None)
compiler_options: dict = attr.ib(factory=dict)


@attr.s
class TransformInput(object):
"""Create a class containing all the parameters.
Expand Down
53 changes: 53 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
utils,
git_utils,
)
from sagemaker.inputs import CompilationInput
from sagemaker.deprecations import removed_kwargs
from sagemaker.predictor import PredictorBase
from sagemaker.transformer import Transformer
Expand Down Expand Up @@ -410,6 +411,58 @@ def _compilation_job_config(
"job_name": job_name,
}

def _get_compilation_args(self, estimator, inputs):
"""Constructs a dict of arguments for an Amazon SageMaker compilation job from estimator.

Args:
estimator (sagemaker.estimator.EstimatorBase): Estimator object
created by the user.
inputs (CompilationInput): class containing all the parameters that
can be used when calling ``sagemaker.model.Model.compile_model()``
"""

if not isinstance(inputs, CompilationInput):
raise TypeError("Your inputs must be provided as ProcessingInput objects.")
target_instance_family = inputs.target_instance_type
input_shape = inputs.input_shape
output_path = inputs.output_path
role = estimator.role
compile_max_run = inputs.compile_max_run
job_name = estimator._compilation_job_name()
framework = inputs.framework or self._framework()
if framework is None:
raise ValueError(
"You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)
)
if framework not in NEO_ALLOWED_FRAMEWORKS:
raise ValueError(
"You must provide valid framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)
)
if self.model_data is None:
raise ValueError("You must provide an S3 path to the compressed model artifacts.")
tags = inputs.tags
target_platform_os = inputs.target_platform_os
target_platform_arch = inputs.target_platform_arch
target_platform_accelerator = inputs.target_platform_accelerator
compiler_options = inputs.compiler_options
framework_version = inputs.framework_version or self._get_framework_version()

return self._compilation_job_config(
target_instance_family,
input_shape,
output_path,
role,
compile_max_run,
job_name,
framework,
tags,
target_platform_os,
target_platform_arch,
target_platform_accelerator,
compiler_options,
framework_version,
)

def _compilation_image_uri(self, region, target_instance_type, framework, framework_version):
"""Retrieve the Neo or Inferentia image URI.

Expand Down
51 changes: 51 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1840,6 +1840,57 @@ def compile_model(
LOGGER.info("Creating compilation-job with name: %s", job_name)
self.sagemaker_client.create_compilation_job(**compilation_job_request)

def _get_compilation_request(
self,
job_name,
input_model_config,
output_model_config,
role,
stop_condition,
tags=None,
vpc_config=None,
):
"""Construct CreateCompilationJob request

Args:
input_model_config (dict): the trained model and the Amazon S3 location where it is
stored.
output_model_config (dict): Identifies the Amazon S3 location where you want Amazon
SageMaker Neo to save the results of compilation job
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker Neo
compilation jobs use this role to access model artifacts. You must grant
sufficient permissions to this role.
job_name (str): Name of the compilation job being created.
stop_condition (dict): Defines when compilation job shall finish. Contains entries
that can be understood by the service like ``MaxRuntimeInSeconds``.
tags (list[dict]): List of tags for labeling a compile model job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
vpc_config (dict): Contains values for VpcConfig:
* subnets (list[str]): List of subnet ids.
The key in vpc_config is 'Subnets'.
* security_group_ids (list[str]): List of security group ids.
The key in vpc_config is 'SecurityGroupIds'.
Returns:
dict: A dictionary for CreateCompilationJob request
"""

compilation_request = {
"InputConfig": input_model_config,
"OutputConfig": output_model_config,
"RoleArn": role,
"StoppingCondition": stop_condition,
"CompilationJobName": job_name,
}

tags = _append_project_tags(tags)
if tags is not None:
compilation_request["Tags"] = tags

if vpc_config is not None:
compilation_request["VpcConfig"] = vpc_config

return compilation_request

def package_model_for_edge(
self,
output_model_config,
Expand Down
99 changes: 88 additions & 11 deletions src/sagemaker/workflow/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@
from __future__ import absolute_import

import abc

from enum import Enum
from typing import Dict, List, Union

import attr

from sagemaker.estimator import EstimatorBase, _TrainingJob
from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput
from sagemaker.inputs import (
CompilationInput,
CreateModelInput,
FileSystemInput,
TrainingInput,
TransformInput,
)
from sagemaker.model import Model
from sagemaker.processing import (
ProcessingInput,
Expand All @@ -31,16 +36,9 @@
)
from sagemaker.transformer import Transformer, _TransformJob
from sagemaker.tuner import HyperparameterTuner, _TuningJob
from sagemaker.workflow.entities import (
DefaultEnumMeta,
Entity,
RequestType,
)
from sagemaker.workflow.properties import (
PropertyFile,
Properties,
)
from sagemaker.workflow.entities import DefaultEnumMeta, Entity, RequestType
from sagemaker.workflow.functions import Join
from sagemaker.workflow.properties import Properties, PropertyFile
from sagemaker.workflow.retry import RetryPolicy


Expand All @@ -55,6 +53,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
TRANSFORM = "Transform"
CALLBACK = "Callback"
TUNING = "Tuning"
COMPILATION = "Compilation"
LAMBDA = "Lambda"


Expand Down Expand Up @@ -681,3 +680,81 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
"output/model.tar.gz",
],
)


class CompilationStep(Step):
"""Compilation step for workflow."""

def __init__(
self,
name: str,
estimator: EstimatorBase,
model: Model,
inputs: CompilationInput = None,
job_arguments: List[str] = None,
depends_on: Union[List[str], List[Step]] = None,
retry_policies: List[RetryPolicy] = None,
display_name: str = None,
description: str = None,
cache_config: CacheConfig = None,
):
"""Construct a CompilationStep.

Given an `EstimatorBase` and a `sagemaker.model.Model` instance construct a CompilationStep.

In addition to the estimator and Model instances, the other arguments are those that are
supplied to the `compile_model` method of the `sagemaker.model.Model.compile_model`.

Args:
name (str): The name of the compilation step.
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
model (Model): A `sagemaker.model.Model` instance.
inputs (CompilationInput): A `sagemaker.inputs.CompilationInput` instance.
Defaults to `None`.
job_arguments (List[str]): A list of strings to be passed into the processing job.
Defaults to `None`.
depends_on (List[str] or List[Step]): A list of step names or step instances
this `sagemaker.workflow.steps.CompilationStep` depends on
retry_policies (List[RetryPolicy]): A list of retry policy
display_name (str): The display name of the compilation step.
description (str): The description of the compilation step.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
"""
super(CompilationStep, self).__init__(
name, StepTypeEnum.COMPILATION, display_name, description, depends_on, retry_policies
)
self.estimator = estimator
self.model = model
self.inputs = inputs
self.job_arguments = job_arguments
self._properties = Properties(
path=f"Steps.{name}", shape_name="DescribeCompilationJobResponse"
)
self.cache_config = cache_config

@property
def arguments(self) -> RequestType:
"""The arguments dict that is used to call `create_compilation_job`.

NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
The TrainingJobName and ExperimentConfig attributes cannot be included.
"""

compilation_args = self.model._get_compilation_args(self.estimator, self.inputs)
request_dict = self.model.sagemaker_session._get_compilation_request(**compilation_args)
request_dict.pop("CompilationJobName")

return request_dict

@property
def properties(self):
"""A Properties object representing the DescribeTrainingJobResponse data model."""
return self._properties

def to_request(self) -> RequestType:
"""Updates the dictionary with cache configuration."""
request_dict = super().to_request()
if self.cache_config:
request_dict.update(self.cache_config.config)

return request_dict