Skip to content

Commit a356d7d

Browse files
Dewen Qiqidewenwhen
Dewen Qi
authored andcommitted
change: Add PipelineVariable annotation in framework models
1 parent 5cf83df commit a356d7d

File tree

14 files changed

+274
-189
lines changed

14 files changed

+274
-189
lines changed

src/sagemaker/chainer/model.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17+
from typing import Optional, Union
1718

1819
import sagemaker
1920
from sagemaker import image_uris
@@ -27,6 +28,8 @@
2728
from sagemaker.deserializers import NumpyDeserializer
2829
from sagemaker.predictor import Predictor
2930
from sagemaker.serializers import NumpySerializer
31+
from sagemaker.utils import to_string
32+
from sagemaker.workflow.entities import PipelineVariable
3033

3134
logger = logging.getLogger("sagemaker")
3235

@@ -75,14 +78,14 @@ class ChainerModel(FrameworkModel):
7578

7679
def __init__(
7780
self,
78-
model_data,
79-
role,
80-
entry_point,
81-
image_uri=None,
82-
framework_version=None,
83-
py_version=None,
84-
predictor_cls=ChainerPredictor,
85-
model_server_workers=None,
81+
model_data: Union[str, PipelineVariable],
82+
role: str,
83+
entry_point: str,
84+
image_uri: Optional[Union[str, PipelineVariable]] = None,
85+
framework_version: Optional[str] = None,
86+
py_version: Optional[str] = None,
87+
predictor_cls: callable = ChainerPredictor,
88+
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
8689
**kwargs
8790
):
8891
"""Initialize an ChainerModel.
@@ -180,7 +183,9 @@ def prepare_container_def(
180183
deploy_env.update(self._script_mode_env_vars())
181184

182185
if self.model_server_workers:
183-
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
186+
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string(
187+
self.model_server_workers
188+
)
184189
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
185190

186191
def serving_image_uri(

src/sagemaker/estimator.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
build_dict,
7474
get_config_value,
7575
name_from_base,
76+
to_string,
7677
)
7778
from sagemaker.workflow import is_pipeline_variable
7879
from sagemaker.workflow.pipeline_context import (
@@ -1848,10 +1849,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
18481849

18491850
current_hyperparameters = estimator.hyperparameters()
18501851
if current_hyperparameters is not None:
1851-
hyperparameters = {
1852-
str(k): (v.to_string() if is_pipeline_variable(v) else str(v))
1853-
for (k, v) in current_hyperparameters.items()
1854-
}
1852+
hyperparameters = {str(k): to_string(v) for (k, v) in current_hyperparameters.items()}
18551853

18561854
train_args = config.copy()
18571855
train_args["input_mode"] = estimator.input_mode

src/sagemaker/huggingface/model.py

+34-27
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,23 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17+
from typing import Optional, Union, List, Dict
1718

1819
import sagemaker
19-
from sagemaker import image_uris
20+
from sagemaker import image_uris, ModelMetrics
2021
from sagemaker.deserializers import JSONDeserializer
22+
from sagemaker.drift_check_baselines import DriftCheckBaselines
2123
from sagemaker.fw_utils import (
2224
model_code_key_prefix,
2325
validate_version_or_image_args,
2426
)
27+
from sagemaker.metadata_properties import MetadataProperties
2528
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2629
from sagemaker.predictor import Predictor
2730
from sagemaker.serializers import JSONSerializer
2831
from sagemaker.session import Session
32+
from sagemaker.utils import to_string
33+
from sagemaker.workflow.entities import PipelineVariable
2934

3035
logger = logging.getLogger("sagemaker")
3136

@@ -92,16 +97,16 @@ class HuggingFaceModel(FrameworkModel):
9297

9398
def __init__(
9499
self,
95-
role,
96-
model_data=None,
97-
entry_point=None,
98-
transformers_version=None,
99-
tensorflow_version=None,
100-
pytorch_version=None,
101-
py_version=None,
102-
image_uri=None,
103-
predictor_cls=HuggingFacePredictor,
104-
model_server_workers=None,
100+
role: str,
101+
model_data: Optional[Union[str, PipelineVariable]] = None,
102+
entry_point: Optional[str] = None,
103+
transformers_version: Optional[str] = None,
104+
tensorflow_version: Optional[str] = None,
105+
pytorch_version: Optional[str] = None,
106+
py_version: Optional[str] = None,
107+
image_uri: Optional[Union[str, PipelineVariable]] = None,
108+
predictor_cls: callable = HuggingFacePredictor,
109+
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
105110
**kwargs,
106111
):
107112
"""Initialize a HuggingFaceModel.
@@ -291,21 +296,21 @@ def deploy(
291296

292297
def register(
293298
self,
294-
content_types,
295-
response_types,
296-
inference_instances=None,
297-
transform_instances=None,
298-
model_package_name=None,
299-
model_package_group_name=None,
300-
image_uri=None,
301-
model_metrics=None,
302-
metadata_properties=None,
303-
marketplace_cert=False,
304-
approval_status=None,
305-
description=None,
306-
drift_check_baselines=None,
307-
customer_metadata_properties=None,
308-
domain=None,
299+
content_types: List[Union[str, PipelineVariable]],
300+
response_types: List[Union[str, PipelineVariable]],
301+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
302+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
303+
model_package_name: Optional[Union[str, PipelineVariable]] = None,
304+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
305+
image_uri: Optional[Union[str, PipelineVariable]] = None,
306+
model_metrics: Optional[ModelMetrics] = None,
307+
metadata_properties: Optional[MetadataProperties] = None,
308+
marketplace_cert: bool = False,
309+
approval_status: Optional[Union[str, PipelineVariable]] = None,
310+
description: Optional[str] = None,
311+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
312+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
313+
domain: Optional[Union[str, PipelineVariable]] = None,
309314
):
310315
"""Creates a model package for creating SageMaker models or listing on Marketplace.
311316
@@ -409,7 +414,9 @@ def prepare_container_def(
409414
deploy_env.update(self._script_mode_env_vars())
410415

411416
if self.model_server_workers:
412-
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
417+
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string(
418+
self.model_server_workers
419+
)
413420
return sagemaker.container_def(
414421
deploy_image, self.repacked_model_data or self.model_data, deploy_env
415422
)

src/sagemaker/multidatamodel.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import os
17+
from typing import Union, Optional
1718

1819
from six.moves.urllib.parse import urlparse
1920

@@ -22,6 +23,8 @@
2223
from sagemaker.deprecations import removed_kwargs
2324
from sagemaker.model import Model
2425
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
27+
from sagemaker.workflow.entities import PipelineVariable
2528

2629
MULTI_MODEL_CONTAINER_MODE = "MultiModel"
2730

@@ -34,12 +37,12 @@ class MultiDataModel(Model):
3437

3538
def __init__(
3639
self,
37-
name,
38-
model_data_prefix,
39-
model=None,
40-
image_uri=None,
41-
role=None,
42-
sagemaker_session=None,
40+
name: str,
41+
model_data_prefix: str,
42+
model: Optional[Model] = None,
43+
image_uri: Optional[Union[str, PipelineVariable]] = None,
44+
role: Optional[str] = None,
45+
sagemaker_session: Optional[Session] = None,
4346
**kwargs,
4447
):
4548
"""Initialize a ``MultiDataModel``.
@@ -106,6 +109,7 @@ def __init__(
106109

107110
# Set the ``Model`` parameters if the model parameter is not specified
108111
if not self.model:
112+
pop_out_unused_kwarg("model_data", kwargs, self.model_data_prefix)
109113
super(MultiDataModel, self).__init__(
110114
image_uri,
111115
self.model_data_prefix,
@@ -115,7 +119,9 @@ def __init__(
115119
**kwargs,
116120
)
117121

118-
def prepare_container_def(self, instance_type=None, accelerator_type=None):
122+
def prepare_container_def(
123+
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
124+
):
119125
"""Return a container definition set.
120126
121127
Definition set includes MultiModel mode, model data and other parameters

src/sagemaker/mxnet/model.py

+33-26
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,26 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17+
from typing import Union, Optional, List, Dict
1718

1819
import packaging.version
1920

2021
import sagemaker
21-
from sagemaker import image_uris
22+
from sagemaker import image_uris, ModelMetrics
2223
from sagemaker.deserializers import JSONDeserializer
24+
from sagemaker.drift_check_baselines import DriftCheckBaselines
2325
from sagemaker.fw_utils import (
2426
model_code_key_prefix,
2527
python_deprecation_warning,
2628
validate_version_or_image_args,
2729
)
30+
from sagemaker.metadata_properties import MetadataProperties
2831
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2932
from sagemaker.mxnet import defaults
3033
from sagemaker.predictor import Predictor
3134
from sagemaker.serializers import JSONSerializer
35+
from sagemaker.utils import to_string
36+
from sagemaker.workflow.entities import PipelineVariable
3237

3338
logger = logging.getLogger("sagemaker")
3439

@@ -77,14 +82,14 @@ class MXNetModel(FrameworkModel):
7782

7883
def __init__(
7984
self,
80-
model_data,
81-
role,
82-
entry_point,
83-
framework_version=None,
84-
py_version=None,
85-
image_uri=None,
86-
predictor_cls=MXNetPredictor,
87-
model_server_workers=None,
85+
model_data: Union[str, PipelineVariable],
86+
role: str,
87+
entry_point: str,
88+
framework_version: str = _LOWEST_MMS_VERSION,
89+
py_version: Optional[str] = None,
90+
image_uri: Optional[Union[str, PipelineVariable]] = None,
91+
predictor_cls: callable = MXNetPredictor,
92+
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
8893
**kwargs
8994
):
9095
"""Initialize an MXNetModel.
@@ -102,7 +107,7 @@ def __init__(
102107
hosting. If ``source_dir`` is specified, then ``entry_point``
103108
must point to a file located at the root of ``source_dir``.
104109
framework_version (str): MXNet version you want to use for executing
105-
your model training code. Defaults to ``None``. Required unless
110+
your model training code. Defaults to ``1.4.0``. Required unless
106111
``image_uri`` is provided.
107112
py_version (str): Python version you want to use for executing your
108113
model training code. Defaults to ``None``. Required unless
@@ -144,21 +149,21 @@ def __init__(
144149

145150
def register(
146151
self,
147-
content_types,
148-
response_types,
149-
inference_instances=None,
150-
transform_instances=None,
151-
model_package_name=None,
152-
model_package_group_name=None,
153-
image_uri=None,
154-
model_metrics=None,
155-
metadata_properties=None,
156-
marketplace_cert=False,
157-
approval_status=None,
158-
description=None,
159-
drift_check_baselines=None,
160-
customer_metadata_properties=None,
161-
domain=None,
152+
content_types: List[Union[str, PipelineVariable]],
153+
response_types: List[Union[str, PipelineVariable]],
154+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
155+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
156+
model_package_name: Optional[Union[str, PipelineVariable]] = None,
157+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
158+
image_uri: Optional[Union[str, PipelineVariable]] = None,
159+
model_metrics: Optional[ModelMetrics] = None,
160+
metadata_properties: Optional[MetadataProperties] = None,
161+
marketplace_cert: bool = False,
162+
approval_status: Optional[Union[str, PipelineVariable]] = None,
163+
description: Optional[str] = None,
164+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
165+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
166+
domain: Optional[Union[str, PipelineVariable]] = None,
162167
):
163168
"""Creates a model package for creating SageMaker models or listing on Marketplace.
164169
@@ -262,7 +267,9 @@ def prepare_container_def(
262267
deploy_env.update(self._script_mode_env_vars())
263268

264269
if self.model_server_workers:
265-
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
270+
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string(
271+
self.model_server_workers
272+
)
266273
return sagemaker.container_def(
267274
deploy_image, self.repacked_model_data or self.model_data, deploy_env
268275
)

src/sagemaker/parameter.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import json
1717

18-
from sagemaker.workflow import is_pipeline_variable
18+
from sagemaker.utils import to_string
1919

2020

2121
class ParameterRange(object):
@@ -71,12 +71,8 @@ def as_tuning_range(self, name):
7171
"""
7272
return {
7373
"Name": name,
74-
"MinValue": str(self.min_value)
75-
if not is_pipeline_variable(self.min_value)
76-
else self.min_value.to_string(),
77-
"MaxValue": str(self.max_value)
78-
if not is_pipeline_variable(self.max_value)
79-
else self.max_value.to_string(),
74+
"MinValue": to_string(self.min_value),
75+
"MaxValue": to_string(self.max_value),
8076
"ScalingType": self.scaling_type,
8177
}
8278

@@ -110,7 +106,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called
110106
This input will be converted into a list of strings.
111107
"""
112108
values = values if isinstance(values, list) else [values]
113-
self.values = [str(v) if not is_pipeline_variable(v) else v.to_string() for v in values]
109+
self.values = [to_string(v) for v in values]
114110

115111
def as_tuning_range(self, name):
116112
"""Represent the parameter range as a dictionary.

0 commit comments

Comments
 (0)