Skip to content

Commit fc2c946

Browse files
qidewenwhenDewen Qi
authored andcommitted
change: Add PipelineVariable annotation in framework models (aws#3188)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 851fd7f commit fc2c946

File tree

14 files changed

+372
-281
lines changed

14 files changed

+372
-281
lines changed

src/sagemaker/chainer/model.py

+42-32
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,25 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17+
from typing import Optional, Union, List, Dict
1718

1819
import sagemaker
19-
from sagemaker import image_uris
20+
from sagemaker import image_uris, ModelMetrics
21+
from sagemaker.drift_check_baselines import DriftCheckBaselines
2022
from sagemaker.fw_utils import (
2123
model_code_key_prefix,
2224
python_deprecation_warning,
2325
validate_version_or_image_args,
2426
)
27+
from sagemaker.metadata_properties import MetadataProperties
2528
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2629
from sagemaker.chainer import defaults
2730
from sagemaker.deserializers import NumpyDeserializer
2831
from sagemaker.predictor import Predictor
2932
from sagemaker.serializers import NumpySerializer
33+
from sagemaker.utils import to_string
34+
from sagemaker.workflow import is_pipeline_variable
35+
from sagemaker.workflow.entities import PipelineVariable
3036

3137
logger = logging.getLogger("sagemaker")
3238

@@ -75,14 +81,14 @@ class ChainerModel(FrameworkModel):
7581

7682
def __init__(
7783
self,
78-
model_data,
79-
role,
80-
entry_point,
81-
image_uri=None,
82-
framework_version=None,
83-
py_version=None,
84-
predictor_cls=ChainerPredictor,
85-
model_server_workers=None,
84+
model_data: Union[str, PipelineVariable],
85+
role: str,
86+
entry_point: str,
87+
image_uri: Optional[Union[str, PipelineVariable]] = None,
88+
framework_version: Optional[str] = None,
89+
py_version: Optional[str] = None,
90+
predictor_cls: callable = ChainerPredictor,
91+
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
8692
**kwargs
8793
):
8894
"""Initialize an ChainerModel.
@@ -142,27 +148,27 @@ def __init__(
142148

143149
def register(
144150
self,
145-
content_types,
146-
response_types,
147-
inference_instances,
148-
transform_instances,
149-
model_package_name=None,
150-
model_package_group_name=None,
151-
image_uri=None,
152-
model_metrics=None,
153-
metadata_properties=None,
154-
marketplace_cert=False,
155-
approval_status=None,
156-
description=None,
157-
drift_check_baselines=None,
158-
customer_metadata_properties=None,
159-
domain=None,
160-
sample_payload_url=None,
161-
task=None,
162-
framework=None,
163-
framework_version=None,
164-
nearest_model_name=None,
165-
data_input_configuration=None,
151+
content_types: List[Union[str, PipelineVariable]],
152+
response_types: List[Union[str, PipelineVariable]],
153+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
154+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
155+
model_package_name: Optional[Union[str, PipelineVariable]] = None,
156+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
157+
image_uri: Optional[Union[str, PipelineVariable]] = None,
158+
model_metrics: Optional[ModelMetrics] = None,
159+
metadata_properties: Optional[MetadataProperties] = None,
160+
marketplace_cert: bool = False,
161+
approval_status: Optional[Union[str, PipelineVariable]] = None,
162+
description: Optional[str] = None,
163+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
164+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
165+
domain: Optional[Union[str, PipelineVariable]] = None,
166+
sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
167+
task: Optional[Union[str, PipelineVariable]] = None,
168+
framework: Optional[Union[str, PipelineVariable]] = None,
169+
framework_version: Optional[Union[str, PipelineVariable]] = None,
170+
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
171+
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
166172
):
167173
"""Creates a model package for creating SageMaker models or listing on Marketplace.
168174
@@ -218,6 +224,8 @@ def register(
218224
region_name=self.sagemaker_session.boto_session.region_name,
219225
instance_type=instance_type,
220226
)
227+
if not is_pipeline_variable(framework):
228+
framework = (framework or self._framework_name).upper()
221229
return super(ChainerModel, self).register(
222230
content_types,
223231
response_types,
@@ -236,7 +244,7 @@ def register(
236244
domain=domain,
237245
sample_payload_url=sample_payload_url,
238246
task=task,
239-
framework=(framework or self._framework_name).upper(),
247+
framework=framework,
240248
framework_version=framework_version or self.framework_version,
241249
nearest_model_name=nearest_model_name,
242250
data_input_configuration=data_input_configuration,
@@ -282,7 +290,9 @@ def prepare_container_def(
282290
deploy_env.update(self._script_mode_env_vars())
283291

284292
if self.model_server_workers:
285-
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
293+
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string(
294+
self.model_server_workers
295+
)
286296
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
287297

288298
def serving_image_uri(

src/sagemaker/estimator.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
build_dict,
7777
get_config_value,
7878
name_from_base,
79+
to_string,
7980
)
8081
from sagemaker.workflow import is_pipeline_variable
8182
from sagemaker.workflow.entities import PipelineVariable
@@ -1947,10 +1948,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
19471948

19481949
current_hyperparameters = estimator.hyperparameters()
19491950
if current_hyperparameters is not None:
1950-
hyperparameters = {
1951-
str(k): (v.to_string() if is_pipeline_variable(v) else str(v))
1952-
for (k, v) in current_hyperparameters.items()
1953-
}
1951+
hyperparameters = {str(k): to_string(v) for (k, v) in current_hyperparameters.items()}
19541952

19551953
train_args = config.copy()
19561954
train_args["input_mode"] = estimator.input_mode

src/sagemaker/huggingface/model.py

+49-39
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,24 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17+
from typing import Optional, Union, List, Dict
1718

1819
import sagemaker
19-
from sagemaker import image_uris
20+
from sagemaker import image_uris, ModelMetrics
2021
from sagemaker.deserializers import JSONDeserializer
22+
from sagemaker.drift_check_baselines import DriftCheckBaselines
2123
from sagemaker.fw_utils import (
2224
model_code_key_prefix,
2325
validate_version_or_image_args,
2426
)
27+
from sagemaker.metadata_properties import MetadataProperties
2528
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2629
from sagemaker.predictor import Predictor
2730
from sagemaker.serializers import JSONSerializer
2831
from sagemaker.session import Session
32+
from sagemaker.utils import to_string
33+
from sagemaker.workflow import is_pipeline_variable
34+
from sagemaker.workflow.entities import PipelineVariable
2935

3036
logger = logging.getLogger("sagemaker")
3137

@@ -100,16 +106,16 @@ class HuggingFaceModel(FrameworkModel):
100106

101107
def __init__(
102108
self,
103-
role,
104-
model_data=None,
105-
entry_point=None,
106-
transformers_version=None,
107-
tensorflow_version=None,
108-
pytorch_version=None,
109-
py_version=None,
110-
image_uri=None,
111-
predictor_cls=HuggingFacePredictor,
112-
model_server_workers=None,
109+
role: str,
110+
model_data: Optional[Union[str, PipelineVariable]] = None,
111+
entry_point: Optional[str] = None,
112+
transformers_version: Optional[str] = None,
113+
tensorflow_version: Optional[str] = None,
114+
pytorch_version: Optional[str] = None,
115+
py_version: Optional[str] = None,
116+
image_uri: Optional[Union[str, PipelineVariable]] = None,
117+
predictor_cls: callable = HuggingFacePredictor,
118+
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
113119
**kwargs,
114120
):
115121
"""Initialize a HuggingFaceModel.
@@ -299,27 +305,27 @@ def deploy(
299305

300306
def register(
301307
self,
302-
content_types,
303-
response_types,
304-
inference_instances=None,
305-
transform_instances=None,
306-
model_package_name=None,
307-
model_package_group_name=None,
308-
image_uri=None,
309-
model_metrics=None,
310-
metadata_properties=None,
311-
marketplace_cert=False,
312-
approval_status=None,
313-
description=None,
314-
drift_check_baselines=None,
315-
customer_metadata_properties=None,
316-
domain=None,
317-
sample_payload_url=None,
318-
task=None,
319-
framework=None,
320-
framework_version=None,
321-
nearest_model_name=None,
322-
data_input_configuration=None,
308+
content_types: List[Union[str, PipelineVariable]],
309+
response_types: List[Union[str, PipelineVariable]],
310+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
311+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
312+
model_package_name: Optional[Union[str, PipelineVariable]] = None,
313+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
314+
image_uri: Optional[Union[str, PipelineVariable]] = None,
315+
model_metrics: Optional[ModelMetrics] = None,
316+
metadata_properties: Optional[MetadataProperties] = None,
317+
marketplace_cert: bool = False,
318+
approval_status: Optional[Union[str, PipelineVariable]] = None,
319+
description: Optional[str] = None,
320+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
321+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
322+
domain: Optional[Union[str, PipelineVariable]] = None,
323+
sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
324+
task: Optional[Union[str, PipelineVariable]] = None,
325+
framework: Optional[Union[str, PipelineVariable]] = None,
326+
framework_version: Optional[Union[str, PipelineVariable]] = None,
327+
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
328+
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
323329
):
324330
"""Creates a model package for creating SageMaker models or listing on Marketplace.
325331
@@ -377,6 +383,13 @@ def register(
377383
region_name=self.sagemaker_session.boto_session.region_name,
378384
instance_type=instance_type,
379385
)
386+
if not is_pipeline_variable(framework):
387+
framework = (
388+
framework
389+
or fetch_framework_and_framework_version(
390+
self.tensorflow_version, self.pytorch_version
391+
)[0]
392+
).upper()
380393
return super(HuggingFaceModel, self).register(
381394
content_types,
382395
response_types,
@@ -395,12 +408,7 @@ def register(
395408
domain=domain,
396409
sample_payload_url=sample_payload_url,
397410
task=task,
398-
framework=(
399-
framework
400-
or fetch_framework_and_framework_version(
401-
self.tensorflow_version, self.pytorch_version
402-
)[0]
403-
).upper(),
411+
framework=framework,
404412
framework_version=framework_version
405413
or fetch_framework_and_framework_version(self.tensorflow_version, self.pytorch_version)[
406414
1
@@ -449,7 +457,9 @@ def prepare_container_def(
449457
deploy_env.update(self._script_mode_env_vars())
450458

451459
if self.model_server_workers:
452-
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
460+
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string(
461+
self.model_server_workers
462+
)
453463
return sagemaker.container_def(
454464
deploy_image, self.repacked_model_data or self.model_data, deploy_env
455465
)

src/sagemaker/multidatamodel.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import os
17+
from typing import Union, Optional
1718

1819
from six.moves.urllib.parse import urlparse
1920

@@ -22,6 +23,8 @@
2223
from sagemaker.deprecations import removed_kwargs
2324
from sagemaker.model import Model
2425
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
27+
from sagemaker.workflow.entities import PipelineVariable
2528

2629
MULTI_MODEL_CONTAINER_MODE = "MultiModel"
2730

@@ -34,12 +37,12 @@ class MultiDataModel(Model):
3437

3538
def __init__(
3639
self,
37-
name,
38-
model_data_prefix,
39-
model=None,
40-
image_uri=None,
41-
role=None,
42-
sagemaker_session=None,
40+
name: str,
41+
model_data_prefix: str,
42+
model: Optional[Model] = None,
43+
image_uri: Optional[Union[str, PipelineVariable]] = None,
44+
role: Optional[str] = None,
45+
sagemaker_session: Optional[Session] = None,
4346
**kwargs,
4447
):
4548
"""Initialize a ``MultiDataModel``.
@@ -106,6 +109,7 @@ def __init__(
106109

107110
# Set the ``Model`` parameters if the model parameter is not specified
108111
if not self.model:
112+
pop_out_unused_kwarg("model_data", kwargs, self.model_data_prefix)
109113
super(MultiDataModel, self).__init__(
110114
image_uri,
111115
self.model_data_prefix,
@@ -115,7 +119,9 @@ def __init__(
115119
**kwargs,
116120
)
117121

118-
def prepare_container_def(self, instance_type=None, accelerator_type=None):
122+
def prepare_container_def(
123+
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
124+
):
119125
"""Return a container definition set.
120126
121127
Definition set includes MultiModel mode, model data and other parameters

0 commit comments

Comments
 (0)