10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
- """Placeholder docstring """
13
+ """Classes for using TensorFlow on Amazon SageMaker for inference. """
14
14
from __future__ import absolute_import
15
15
16
16
import logging
22
22
from sagemaker .tensorflow .defaults import TF_VERSION
23
23
24
24
25
- class Predictor (sagemaker .RealTimePredictor ):
25
+ class TensorFlowPredictor (sagemaker .RealTimePredictor ):
26
26
"""A ``RealTimePredictor`` implementation for inference against TensorFlow
27
27
Serving endpoints.
28
28
"""
@@ -37,7 +37,7 @@ def __init__(
37
37
model_name = None ,
38
38
model_version = None ,
39
39
):
40
- """Initialize a ``TFSPredictor ``. See `` sagemaker.RealTimePredictor` `
40
+ """Initialize a ``TensorFlowPredictor ``. See :class:`~ sagemaker.predictor. RealTimePredictor`
41
41
for more info about parameters.
42
42
43
43
Args:
@@ -61,7 +61,7 @@ def __init__(
61
61
that should handle the request. If not specified, the latest
62
62
version of the model will be used.
63
63
"""
64
- super (Predictor , self ).__init__ (
64
+ super (TensorFlowPredictor , self ).__init__ (
65
65
endpoint_name , sagemaker_session , serializer , deserializer , content_type
66
66
)
67
67
@@ -115,13 +115,13 @@ def predict(self, data, initial_args=None):
115
115
else :
116
116
args ["CustomAttributes" ] = self ._model_attributes
117
117
118
- return super (Predictor , self ).predict (data , args )
118
+ return super (TensorFlowPredictor , self ).predict (data , args )
119
119
120
120
121
- class Model (sagemaker .model .FrameworkModel ):
122
- """Placeholder docstring """
121
+ class TensorFlowModel (sagemaker .model .FrameworkModel ):
122
+ """A ``FrameworkModel`` implementation for inference with TensorFlow Serving. """
123
123
124
- FRAMEWORK_NAME = "tensorflow-serving"
124
+ __framework_name__ = "tensorflow-serving"
125
125
LOG_LEVEL_PARAM_NAME = "SAGEMAKER_TFS_NGINX_LOGLEVEL"
126
126
LOG_LEVEL_MAP = {
127
127
logging .DEBUG : "debug" ,
@@ -140,7 +140,7 @@ def __init__(
140
140
image = None ,
141
141
framework_version = TF_VERSION ,
142
142
container_log_level = None ,
143
- predictor_cls = Predictor ,
143
+ predictor_cls = TensorFlowPredictor ,
144
144
** kwargs
145
145
):
146
146
"""Initialize a Model.
@@ -171,15 +171,15 @@ def __init__(
171
171
:class:`~sagemaker.model.FrameworkModel` and
172
172
:class:`~sagemaker.model.Model`.
173
173
"""
174
- super (Model , self ).__init__ (
174
+ super (TensorFlowModel , self ).__init__ (
175
175
model_data = model_data ,
176
176
role = role ,
177
177
image = image ,
178
178
predictor_cls = predictor_cls ,
179
179
entry_point = entry_point ,
180
180
** kwargs
181
181
)
182
- self ._framework_version = framework_version
182
+ self .framework_version = framework_version
183
183
self ._container_log_level = container_log_level
184
184
185
185
def deploy (
@@ -196,10 +196,10 @@ def deploy(
196
196
):
197
197
198
198
if accelerator_type and not self ._eia_supported ():
199
- msg = "The TensorFlow version %s doesn't support EIA." % self ._framework_version
200
-
199
+ msg = "The TensorFlow version %s doesn't support EIA." % self .framework_version
201
200
raise AttributeError (msg )
202
- return super (Model , self ).deploy (
201
+
202
+ return super (TensorFlowModel , self ).deploy (
203
203
initial_instance_count = initial_instance_count ,
204
204
instance_type = instance_type ,
205
205
accelerator_type = accelerator_type ,
@@ -213,7 +213,7 @@ def deploy(
213
213
214
214
def _eia_supported (self ):
215
215
"""Return true if TF version is EIA enabled"""
216
- return [int (s ) for s in self ._framework_version .split ("." )][:2 ] <= self .LATEST_EIA_VERSION
216
+ return [int (s ) for s in self .framework_version .split ("." )][:2 ] <= self .LATEST_EIA_VERSION
217
217
218
218
def prepare_container_def (self , instance_type , accelerator_type = None ):
219
219
"""
@@ -249,12 +249,12 @@ def _get_container_env(self):
249
249
if not self ._container_log_level :
250
250
return self .env
251
251
252
- if self ._container_log_level not in Model .LOG_LEVEL_MAP :
252
+ if self ._container_log_level not in self .LOG_LEVEL_MAP :
253
253
logging .warning ("ignoring invalid container log level: %s" , self ._container_log_level )
254
254
return self .env
255
255
256
256
env = dict (self .env )
257
- env [Model .LOG_LEVEL_PARAM_NAME ] = Model .LOG_LEVEL_MAP [self ._container_log_level ]
257
+ env [self .LOG_LEVEL_PARAM_NAME ] = self .LOG_LEVEL_MAP [self ._container_log_level ]
258
258
return env
259
259
260
260
def _get_image_uri (self , instance_type , accelerator_type = None ):
@@ -269,9 +269,9 @@ def _get_image_uri(self, instance_type, accelerator_type=None):
269
269
region_name = self .sagemaker_session .boto_region_name
270
270
return create_image_uri (
271
271
region_name ,
272
- Model . FRAMEWORK_NAME ,
272
+ self . __framework_name__ ,
273
273
instance_type ,
274
- self ._framework_version ,
274
+ self .framework_version ,
275
275
accelerator_type = accelerator_type ,
276
276
)
277
277
0 commit comments