14
14
from __future__ import absolute_import
15
15
16
16
import logging
17
+ from typing import Union , Optional , List , Dict
17
18
18
19
import packaging .version
19
20
20
21
import sagemaker
21
- from sagemaker import image_uris
22
+ from sagemaker import image_uris , ModelMetrics
22
23
from sagemaker .deserializers import JSONDeserializer
24
+ from sagemaker .drift_check_baselines import DriftCheckBaselines
23
25
from sagemaker .fw_utils import (
24
26
model_code_key_prefix ,
25
27
python_deprecation_warning ,
26
28
validate_version_or_image_args ,
27
29
)
30
+ from sagemaker .metadata_properties import MetadataProperties
28
31
from sagemaker .model import FrameworkModel , MODEL_SERVER_WORKERS_PARAM_NAME
29
32
from sagemaker .mxnet import defaults
30
33
from sagemaker .predictor import Predictor
31
34
from sagemaker .serializers import JSONSerializer
35
+ from sagemaker .utils import to_string
36
+ from sagemaker .workflow .entities import PipelineVariable
32
37
33
38
logger = logging .getLogger ("sagemaker" )
34
39
@@ -77,14 +82,14 @@ class MXNetModel(FrameworkModel):
77
82
78
83
def __init__ (
79
84
self ,
80
- model_data ,
81
- role ,
82
- entry_point ,
83
- framework_version = None ,
84
- py_version = None ,
85
- image_uri = None ,
86
- predictor_cls = MXNetPredictor ,
87
- model_server_workers = None ,
85
+ model_data : Union [ str , PipelineVariable ] ,
86
+ role : str ,
87
+ entry_point : str ,
88
+ framework_version : str = _LOWEST_MMS_VERSION ,
89
+ py_version : Optional [ str ] = None ,
90
+ image_uri : Optional [ Union [ str , PipelineVariable ]] = None ,
91
+ predictor_cls : callable = MXNetPredictor ,
92
+ model_server_workers : Optional [ Union [ int , PipelineVariable ]] = None ,
88
93
** kwargs
89
94
):
90
95
"""Initialize an MXNetModel.
@@ -102,7 +107,7 @@ def __init__(
102
107
hosting. If ``source_dir`` is specified, then ``entry_point``
103
108
must point to a file located at the root of ``source_dir``.
104
109
framework_version (str): MXNet version you want to use for executing
105
- your model training code. Defaults to ``None ``. Required unless
110
+ your model training code. Defaults to ``1.4.0 ``. Required unless
106
111
``image_uri`` is provided.
107
112
py_version (str): Python version you want to use for executing your
108
113
model training code. Defaults to ``None``. Required unless
@@ -144,21 +149,21 @@ def __init__(
144
149
145
150
def register (
146
151
self ,
147
- content_types ,
148
- response_types ,
149
- inference_instances = None ,
150
- transform_instances = None ,
151
- model_package_name = None ,
152
- model_package_group_name = None ,
153
- image_uri = None ,
154
- model_metrics = None ,
155
- metadata_properties = None ,
156
- marketplace_cert = False ,
157
- approval_status = None ,
158
- description = None ,
159
- drift_check_baselines = None ,
160
- customer_metadata_properties = None ,
161
- domain = None ,
152
+ content_types : List [ Union [ str , PipelineVariable ]] ,
153
+ response_types : List [ Union [ str , PipelineVariable ]] ,
154
+ inference_instances : Optional [ List [ Union [ str , PipelineVariable ]]] = None ,
155
+ transform_instances : Optional [ List [ Union [ str , PipelineVariable ]]] = None ,
156
+ model_package_name : Optional [ Union [ str , PipelineVariable ]] = None ,
157
+ model_package_group_name : Optional [ Union [ str , PipelineVariable ]] = None ,
158
+ image_uri : Optional [ Union [ str , PipelineVariable ]] = None ,
159
+ model_metrics : Optional [ ModelMetrics ] = None ,
160
+ metadata_properties : Optional [ MetadataProperties ] = None ,
161
+ marketplace_cert : bool = False ,
162
+ approval_status : Optional [ Union [ str , PipelineVariable ]] = None ,
163
+ description : Optional [ str ] = None ,
164
+ drift_check_baselines : Optional [ DriftCheckBaselines ] = None ,
165
+ customer_metadata_properties : Optional [ Dict [ str , Union [ str , PipelineVariable ]]] = None ,
166
+ domain : Optional [ Union [ str , PipelineVariable ]] = None ,
162
167
):
163
168
"""Creates a model package for creating SageMaker models or listing on Marketplace.
164
169
@@ -262,7 +267,9 @@ def prepare_container_def(
262
267
deploy_env .update (self ._script_mode_env_vars ())
263
268
264
269
if self .model_server_workers :
265
- deploy_env [MODEL_SERVER_WORKERS_PARAM_NAME .upper ()] = str (self .model_server_workers )
270
+ deploy_env [MODEL_SERVER_WORKERS_PARAM_NAME .upper ()] = to_string (
271
+ self .model_server_workers
272
+ )
266
273
return sagemaker .container_def (
267
274
deploy_image , self .repacked_model_data or self .model_data , deploy_env
268
275
)
0 commit comments