Skip to content

Commit c437191

Browse files
staubhpPayton Staub
authored andcommitted
fix: Revert "feature: CompilationStep support for Sagemaker Pipelines (#2900)
Co-authored-by: Payton Staub <[email protected]>
1 parent dd5d9be commit c437191

File tree

6 files changed

+12
-304
lines changed

6 files changed

+12
-304
lines changed

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

-2
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,6 @@ Steps
140140

141141
.. autoclass:: sagemaker.workflow.lambda_step.LambdaStep
142142

143-
.. autoclass:: sagemaker.workflow.steps.CompilationStep
144-
145143
.. autoclass:: sagemaker.workflow.quality_check_step.QualityCheckConfig
146144

147145
.. autoclass:: sagemaker.workflow.quality_check_step.QualityCheckStep

src/sagemaker/inputs.py

-61
Original file line numberDiff line numberDiff line change
@@ -136,67 +136,6 @@ class CreateModelInput(object):
136136
accelerator_type: str = attr.ib(default=None)
137137

138138

139-
@attr.s
140-
class CompilationInput(object):
141-
"""Create a class containing all the parameters.
142-
143-
It can be used when calling ``sagemaker.model.Model.compile_model()``
144-
145-
Parameters:
146-
target_instance_type(str): Identifies the device that you want to
147-
run your model after compilation, for example: ml_c5. For allowed
148-
strings see
149-
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
150-
input_shape(str): Specifies the name and shape of the expected
151-
inputs for your trained model in json dictionary form, for
152-
example: {'data': [1,3,1024,1024]}, or {'var1': [1,1,28,28],
153-
'var2': [1,1,28,28]}
154-
output_path(str): Specifies where to store the compiled model
155-
framework (str, optional): The framework that is used to train the original
156-
model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
157-
'onnx', 'xgboost' (default: None)
158-
framework_version (str, optional): The version of the framework (default: None)
159-
compile_max_run (int, optional): Timeout in seconds for compilation (default:
160-
15 * 60). After this amount of time Amazon SageMaker Neo
161-
terminates the compilation job regardless of its current status.
162-
tags (list[dict], optional): List of tags for labeling a compilation job.
163-
For more, see
164-
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
165-
job_name (str, optional): The name of the compilation job (default: None)
166-
target_platform_os (str, optional): Target Platform OS, for example: 'LINUX'.
167-
(default: None)
168-
For allowed strings see
169-
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
170-
It can be used instead of target_instance_family.
171-
target_platform_arch (str, optional): Target Platform Architecture, for example: 'X86_64'.
172-
(default: None)
173-
For allowed strings see
174-
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
175-
It can be used instead of target_instance_family.
176-
target_platform_accelerator (str, optional): Target Platform Accelerator,
177-
for example: 'NVIDIA'. (default: None)
178-
For allowed strings see
179-
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
180-
It can be used instead of target_instance_family.
181-
compiler_options (dict, optional): Additional parameters for compiler. (default: None)
182-
Compiler Options are TargetPlatform / target_instance_family specific. See
183-
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html for details.
184-
"""
185-
186-
target_instance_type: str = attr.ib(default=None)
187-
input_shape: dict = attr.ib(factory=dict)
188-
output_path: str = attr.ib(default=None)
189-
framework: str = attr.ib(default=None)
190-
framework_version: str = attr.ib(default=None)
191-
compile_max_run: int = attr.ib(default=15 * 60)
192-
tags: list = attr.ib(factory=list)
193-
job_name: str = attr.ib(default=None)
194-
target_platform_os: str = attr.ib(default=None)
195-
target_platform_arch: str = attr.ib(default=None)
196-
target_platform_accelerator: str = attr.ib(default=None)
197-
compiler_options: dict = attr.ib(default=None)
198-
199-
200139
@attr.s
201140
class TransformInput(object):
202141
"""Create a class containing all the parameters.

src/sagemaker/model.py

-53
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
utils,
3131
git_utils,
3232
)
33-
from sagemaker.inputs import CompilationInput
3433
from sagemaker.deprecations import removed_kwargs
3534
from sagemaker.predictor import PredictorBase
3635
from sagemaker.serverless import ServerlessInferenceConfig
@@ -658,58 +657,6 @@ def _compilation_job_config(
658657
"job_name": job_name,
659658
}
660659

661-
def _get_compilation_args(self, estimator, inputs):
662-
"""Constructs a dict of arguments for an Amazon SageMaker compilation job from estimator.
663-
664-
Args:
665-
estimator (sagemaker.estimator.EstimatorBase): Estimator object
666-
created by the user.
667-
inputs (CompilationInput): class containing all the parameters that
668-
can be used when calling ``sagemaker.model.Model.compile_model()``
669-
"""
670-
671-
if not isinstance(inputs, CompilationInput):
672-
raise TypeError("Your inputs must be provided as CompilationInput objects.")
673-
target_instance_family = inputs.target_instance_type
674-
input_shape = inputs.input_shape
675-
output_path = inputs.output_path
676-
role = estimator.role
677-
compile_max_run = inputs.compile_max_run
678-
job_name = estimator._compilation_job_name()
679-
framework = inputs.framework or self._framework()
680-
if framework is None:
681-
raise ValueError(
682-
"You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)
683-
)
684-
if framework not in NEO_ALLOWED_FRAMEWORKS:
685-
raise ValueError(
686-
"You must provide valid framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)
687-
)
688-
if self.model_data is None:
689-
raise ValueError("You must provide an S3 path to the compressed model artifacts.")
690-
tags = inputs.tags
691-
target_platform_os = inputs.target_platform_os
692-
target_platform_arch = inputs.target_platform_arch
693-
target_platform_accelerator = inputs.target_platform_accelerator
694-
compiler_options = inputs.compiler_options
695-
framework_version = inputs.framework_version or self._get_framework_version()
696-
697-
return self._compilation_job_config(
698-
target_instance_family,
699-
input_shape,
700-
output_path,
701-
role,
702-
compile_max_run,
703-
job_name,
704-
framework,
705-
tags,
706-
target_platform_os,
707-
target_platform_arch,
708-
target_platform_accelerator,
709-
compiler_options,
710-
framework_version,
711-
)
712-
713660
def _compilation_image_uri(self, region, target_instance_type, framework, framework_version):
714661
"""Retrieve the Neo or Inferentia image URI.
715662

src/sagemaker/session.py

-51
Original file line numberDiff line numberDiff line change
@@ -1845,57 +1845,6 @@ def compile_model(
18451845
LOGGER.info("Creating compilation-job with name: %s", job_name)
18461846
self.sagemaker_client.create_compilation_job(**compilation_job_request)
18471847

1848-
def _get_compilation_request(
1849-
self,
1850-
job_name,
1851-
input_model_config,
1852-
output_model_config,
1853-
role,
1854-
stop_condition,
1855-
tags=None,
1856-
vpc_config=None,
1857-
):
1858-
"""Construct CreateCompilationJob request
1859-
1860-
Args:
1861-
input_model_config (dict): the trained model and the Amazon S3 location where it is
1862-
stored.
1863-
output_model_config (dict): Identifies the Amazon S3 location where you want Amazon
1864-
SageMaker Neo to save the results of compilation job
1865-
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker Neo
1866-
compilation jobs use this role to access model artifacts. You must grant
1867-
sufficient permissions to this role.
1868-
job_name (str): Name of the compilation job being created.
1869-
stop_condition (dict): Defines when compilation job shall finish. Contains entries
1870-
that can be understood by the service like ``MaxRuntimeInSeconds``.
1871-
tags (list[dict]): List of tags for labeling a compile model job. For more, see
1872-
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
1873-
vpc_config (dict): Contains values for VpcConfig:
1874-
* subnets (list[str]): List of subnet ids.
1875-
The key in vpc_config is 'Subnets'.
1876-
* security_group_ids (list[str]): List of security group ids.
1877-
The key in vpc_config is 'SecurityGroupIds'.
1878-
Returns:
1879-
dict: A dictionary for CreateCompilationJob request
1880-
"""
1881-
1882-
compilation_request = {
1883-
"InputConfig": input_model_config,
1884-
"OutputConfig": output_model_config,
1885-
"RoleArn": role,
1886-
"StoppingCondition": stop_condition,
1887-
"CompilationJobName": job_name,
1888-
}
1889-
1890-
tags = _append_project_tags(tags)
1891-
if tags is not None:
1892-
compilation_request["Tags"] = tags
1893-
1894-
if vpc_config is not None:
1895-
compilation_request["VpcConfig"] = vpc_config
1896-
1897-
return compilation_request
1898-
18991848
def package_model_for_edge(
19001849
self,
19011850
output_model_config,

src/sagemaker/workflow/steps.py

+11-88
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,15 @@
1515

1616
import abc
1717
import warnings
18+
1819
from enum import Enum
1920
from typing import Dict, List, Union
2021
from urllib.parse import urlparse
2122

2223
import attr
2324

2425
from sagemaker.estimator import EstimatorBase, _TrainingJob
25-
from sagemaker.inputs import (
26-
CompilationInput,
27-
CreateModelInput,
28-
FileSystemInput,
29-
TrainingInput,
30-
TransformInput,
31-
)
26+
from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput
3227
from sagemaker.model import Model
3328
from sagemaker.pipeline import PipelineModel
3429
from sagemaker.processing import (
@@ -39,9 +34,16 @@
3934
)
4035
from sagemaker.transformer import Transformer, _TransformJob
4136
from sagemaker.tuner import HyperparameterTuner, _TuningJob
42-
from sagemaker.workflow.entities import DefaultEnumMeta, Entity, RequestType
37+
from sagemaker.workflow.entities import (
38+
DefaultEnumMeta,
39+
Entity,
40+
RequestType,
41+
)
42+
from sagemaker.workflow.properties import (
43+
PropertyFile,
44+
Properties,
45+
)
4346
from sagemaker.workflow.functions import Join
44-
from sagemaker.workflow.properties import Properties, PropertyFile
4547
from sagemaker.workflow.retry import RetryPolicy
4648

4749

@@ -56,7 +58,6 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
5658
TRANSFORM = "Transform"
5759
CALLBACK = "Callback"
5860
TUNING = "Tuning"
59-
COMPILATION = "Compilation"
6061
LAMBDA = "Lambda"
6162
QUALITY_CHECK = "QualityCheck"
6263
CLARIFY_CHECK = "ClarifyCheck"
@@ -730,81 +731,3 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
730731
"output/model.tar.gz",
731732
],
732733
)
733-
734-
735-
class CompilationStep(ConfigurableRetryStep):
736-
"""Compilation step for workflow."""
737-
738-
def __init__(
739-
self,
740-
name: str,
741-
estimator: EstimatorBase,
742-
model: Model,
743-
inputs: CompilationInput = None,
744-
job_arguments: List[str] = None,
745-
depends_on: Union[List[str], List[Step]] = None,
746-
retry_policies: List[RetryPolicy] = None,
747-
display_name: str = None,
748-
description: str = None,
749-
cache_config: CacheConfig = None,
750-
):
751-
"""Construct a CompilationStep.
752-
753-
Given an `EstimatorBase` and a `sagemaker.model.Model` instance construct a CompilationStep.
754-
755-
In addition to the estimator and Model instances, the other arguments are those that are
756-
supplied to the `compile_model` method of the `sagemaker.model.Model.compile_model`.
757-
758-
Args:
759-
name (str): The name of the compilation step.
760-
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
761-
model (Model): A `sagemaker.model.Model` instance.
762-
inputs (CompilationInput): A `sagemaker.inputs.CompilationInput` instance.
763-
Defaults to `None`.
764-
job_arguments (List[str]): A list of strings to be passed into the processing job.
765-
Defaults to `None`.
766-
depends_on (List[str] or List[Step]): A list of step names or step instances
767-
this `sagemaker.workflow.steps.CompilationStep` depends on
768-
retry_policies (List[RetryPolicy]): A list of retry policy
769-
display_name (str): The display name of the compilation step.
770-
description (str): The description of the compilation step.
771-
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
772-
"""
773-
super(CompilationStep, self).__init__(
774-
name, StepTypeEnum.COMPILATION, display_name, description, depends_on, retry_policies
775-
)
776-
self.estimator = estimator
777-
self.model = model
778-
self.inputs = inputs
779-
self.job_arguments = job_arguments
780-
self._properties = Properties(
781-
path=f"Steps.{name}", shape_name="DescribeCompilationJobResponse"
782-
)
783-
self.cache_config = cache_config
784-
785-
@property
786-
def arguments(self) -> RequestType:
787-
"""The arguments dict that is used to call `create_compilation_job`.
788-
789-
NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
790-
The TrainingJobName and ExperimentConfig attributes cannot be included.
791-
"""
792-
793-
compilation_args = self.model._get_compilation_args(self.estimator, self.inputs)
794-
request_dict = self.model.sagemaker_session._get_compilation_request(**compilation_args)
795-
request_dict.pop("CompilationJobName")
796-
797-
return request_dict
798-
799-
@property
800-
def properties(self):
801-
"""A Properties object representing the DescribeTrainingJobResponse data model."""
802-
return self._properties
803-
804-
def to_request(self) -> RequestType:
805-
"""Updates the dictionary with cache configuration."""
806-
request_dict = super().to_request()
807-
if self.cache_config:
808-
request_dict.update(self.cache_config.config)
809-
810-
return request_dict

0 commit comments

Comments
 (0)