23
23
from sagemaker .estimator import Framework
24
24
import sagemaker .fw_utils as fw
25
25
from sagemaker .tensorflow import defaults
26
- from sagemaker .tensorflow .model import TensorFlowModel
27
26
from sagemaker .tensorflow .serving import Model
28
27
from sagemaker .transformer import Transformer
29
28
from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
@@ -252,10 +251,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
252
251
253
252
def create_model (
254
253
self ,
255
- model_server_workers = None ,
256
254
role = None ,
257
255
vpc_config_override = VPC_CONFIG_DEFAULT ,
258
- endpoint_type = None ,
259
256
entry_point = None ,
260
257
source_dir = None ,
261
258
dependencies = None ,
@@ -266,43 +263,25 @@ def create_model(
266
263
267
264
Args:
268
265
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
269
- used during transform jobs. If not specified, the role from the Estimator will be
270
- used.
271
- model_server_workers (int): Optional. The number of worker processes used by the
272
- inference server. If None, server will use one worker per vCPU.
266
+ used during transform jobs. If not specified, the role from the Estimator is used.
273
267
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the
274
- model.
275
- Default: use subnets and security groups from this Estimator.
268
+ model. Default: use subnets and security groups from this Estimator.
269
+
276
270
* 'Subnets' (list[str]): List of subnet ids.
277
271
* 'SecurityGroupIds' (list[str]): List of security group ids.
278
- endpoint_type (str): Optional. Selects the software stack used by the inference server.
279
- If not specified, the model will be configured to use the default
280
- SageMaker model server. If 'tensorflow-serving', the model will be configured to
281
- use the SageMaker Tensorflow Serving container.
272
+
282
273
entry_point (str): Path (absolute or relative) to the local Python source file which
283
- should be executed as the entry point to training. If not specified and
284
- ``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
285
- ``endpoint_type`` is also ``None``, then the training entry point is used.
274
+ should be executed as the entry point to training (default: None).
286
275
source_dir (str): Path (absolute or relative) to a directory with any other serving
287
- source code dependencies aside from the entry point file. If not specified and
288
- ``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If
289
- ``endpoint_type`` is also ``None``, then the model source directory from training
290
- is used.
276
+ source code dependencies aside from the entry point file (default: None).
291
277
dependencies (list[str]): A list of paths to directories (absolute or relative) with
292
- any additional libraries that will be exported to the container.
293
- If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
294
- set to ``None``.
295
- If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
296
- **kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`
297
- and :class:`~sagemaker.tensorflow.model.TensorFlowModel` constructors.
278
+ any additional libraries that will be exported to the container (default: None).
279
+ **kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`.
298
280
299
281
Returns:
300
- sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
301
- ``Model`` object. See :class:`~sagemaker.tensorflow.serving.Model` or
302
- :class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
282
+ sagemaker.tensorflow.serving.Model: A ``Model`` object.
283
+ See :class:`~sagemaker.tensorflow.serving.Model` for full details.
303
284
"""
304
- role = role or self .role
305
-
306
285
if "image" not in kwargs :
307
286
kwargs ["image" ] = self .image_name
308
287
@@ -312,41 +291,11 @@ def create_model(
312
291
if "enable_network_isolation" not in kwargs :
313
292
kwargs ["enable_network_isolation" ] = self .enable_network_isolation ()
314
293
315
- if endpoint_type == "tensorflow-serving" or self ._script_mode_enabled :
316
- return self ._create_tfs_model (
317
- role = role ,
318
- vpc_config_override = vpc_config_override ,
319
- entry_point = entry_point ,
320
- source_dir = source_dir ,
321
- dependencies = dependencies ,
322
- ** kwargs
323
- )
324
-
325
- return self ._create_default_model (
326
- model_server_workers = model_server_workers ,
327
- role = role ,
328
- vpc_config_override = vpc_config_override ,
329
- entry_point = entry_point ,
330
- source_dir = source_dir ,
331
- dependencies = dependencies ,
332
- ** kwargs
333
- )
334
-
335
- def _create_tfs_model (
336
- self ,
337
- role = None ,
338
- vpc_config_override = VPC_CONFIG_DEFAULT ,
339
- entry_point = None ,
340
- source_dir = None ,
341
- dependencies = None ,
342
- ** kwargs
343
- ):
344
- """Placeholder docstring"""
345
294
return Model (
346
295
model_data = self .model_data ,
347
- role = role ,
296
+ role = role or self . role ,
348
297
container_log_level = self .container_log_level ,
349
- framework_version = utils . get_short_version ( self .framework_version ) ,
298
+ framework_version = self .framework_version ,
350
299
sagemaker_session = self .sagemaker_session ,
351
300
vpc_config = self .get_vpc_config (vpc_config_override ),
352
301
entry_point = entry_point ,
@@ -355,34 +304,6 @@ def _create_tfs_model(
355
304
** kwargs
356
305
)
357
306
358
- def _create_default_model (
359
- self ,
360
- model_server_workers ,
361
- role ,
362
- vpc_config_override ,
363
- entry_point = None ,
364
- source_dir = None ,
365
- dependencies = None ,
366
- ** kwargs
367
- ):
368
- """Placeholder docstring"""
369
- return TensorFlowModel (
370
- self .model_data ,
371
- role ,
372
- entry_point or self .entry_point ,
373
- source_dir = source_dir or self ._model_source_dir (),
374
- enable_cloudwatch_metrics = self .enable_cloudwatch_metrics ,
375
- container_log_level = self .container_log_level ,
376
- code_location = self .code_location ,
377
- py_version = self .py_version ,
378
- framework_version = self .framework_version ,
379
- model_server_workers = model_server_workers ,
380
- sagemaker_session = self .sagemaker_session ,
381
- vpc_config = self .get_vpc_config (vpc_config_override ),
382
- dependencies = dependencies or self .dependencies ,
383
- ** kwargs
384
- )
385
-
386
307
def hyperparameters (self ):
387
308
"""Return hyperparameters used by your custom TensorFlow code during model training."""
388
309
hyperparameters = super (TensorFlow , self ).hyperparameters ()
@@ -479,9 +400,7 @@ def transformer(
479
400
max_payload = None ,
480
401
tags = None ,
481
402
role = None ,
482
- model_server_workers = None ,
483
403
volume_kms_key = None ,
484
- endpoint_type = None ,
485
404
entry_point = None ,
486
405
vpc_config_override = VPC_CONFIG_DEFAULT ,
487
406
enable_network_isolation = None ,
@@ -515,15 +434,8 @@ def transformer(
515
434
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
516
435
used during transform jobs. If not specified, the role from the Estimator will be
517
436
used.
518
- model_server_workers (int): Optional. The number of worker processes used by the
519
- inference server. If None, server will use one worker per vCPU.
520
437
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
521
438
compute instance (default: None).
522
- endpoint_type (str): Optional. Selects the software stack used by the inference server.
523
- If not specified, the model will be configured to use the default
524
- SageMaker model server.
525
- If 'tensorflow-serving', the model will be configured to
526
- use the SageMaker Tensorflow Serving container.
527
439
entry_point (str): Path (absolute or relative) to the local Python source file which
528
440
should be executed as the entry point to training. If not specified and
529
441
``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
@@ -575,10 +487,8 @@ def transformer(
575
487
enable_network_isolation = self .enable_network_isolation ()
576
488
577
489
model = self .create_model (
578
- model_server_workers = model_server_workers ,
579
490
role = role ,
580
491
vpc_config_override = vpc_config_override ,
581
- endpoint_type = endpoint_type ,
582
492
entry_point = entry_point ,
583
493
enable_network_isolation = enable_network_isolation ,
584
494
name = model_name ,
0 commit comments