from sagemaker.estimator import Framework
import sagemaker.fw_utils as fw
from sagemaker.tensorflow import defaults
- from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.tensorflow.serving import Model
- from sagemaker.transformer import Transformer
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

logger = logging.getLogger("sagemaker")
@@ -252,10 +250,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na

    def create_model(
        self,
-         model_server_workers=None,
        role=None,
        vpc_config_override=VPC_CONFIG_DEFAULT,
-         endpoint_type=None,
        entry_point=None,
        source_dir=None,
        dependencies=None,
@@ -266,43 +262,25 @@ def create_model(

        Args:
            role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
-                 used during transform jobs. If not specified, the role from the Estimator will be
-                 used.
-             model_server_workers (int): Optional. The number of worker processes used by the
-                 inference server. If None, server will use one worker per vCPU.
+                 used during transform jobs. If not specified, the role from the Estimator is used.
            vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the
-                 model.
-                 Default: use subnets and security groups from this Estimator.
+                 model. Default: use subnets and security groups from this Estimator.
+
                * 'Subnets' (list[str]): List of subnet ids.
                * 'SecurityGroupIds' (list[str]): List of security group ids.
-             endpoint_type (str): Optional. Selects the software stack used by the inference server.
-                 If not specified, the model will be configured to use the default
-                 SageMaker model server. If 'tensorflow-serving', the model will be configured to
-                 use the SageMaker Tensorflow Serving container.
+
            entry_point (str): Path (absolute or relative) to the local Python source file which
-                 should be executed as the entry point to training. If not specified and
-                 ``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
-                 ``endpoint_type`` is also ``None``, then the training entry point is used.
+                 should be executed as the entry point to training (default: None).
            source_dir (str): Path (absolute or relative) to a directory with any other serving
-                 source code dependencies aside from the entry point file. If not specified and
-                 ``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If
-                 ``endpoint_type`` is also ``None``, then the model source directory from training
-                 is used.
+                 source code dependencies aside from the entry point file (default: None).
            dependencies (list[str]): A list of paths to directories (absolute or relative) with
-                 any additional libraries that will be exported to the container.
-                 If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
-                 set to ``None``.
-                 If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
-             **kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`
-                 and :class:`~sagemaker.tensorflow.model.TensorFlowModel` constructors.
+                 any additional libraries that will be exported to the container (default: None).
+             **kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`.

        Returns:
-             sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
-             ``Model`` object. See :class:`~sagemaker.tensorflow.serving.Model` or
-             :class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
+             sagemaker.tensorflow.serving.Model: A ``Model`` object.
+             See :class:`~sagemaker.tensorflow.serving.Model` for full details.
        """
-         role = role or self.role
-
        if "image" not in kwargs:
            kwargs["image"] = self.image_name

@@ -312,41 +290,11 @@ def create_model(
        if "enable_network_isolation" not in kwargs:
            kwargs["enable_network_isolation"] = self.enable_network_isolation()

-         if endpoint_type == "tensorflow-serving" or self._script_mode_enabled:
-             return self._create_tfs_model(
-                 role=role,
-                 vpc_config_override=vpc_config_override,
-                 entry_point=entry_point,
-                 source_dir=source_dir,
-                 dependencies=dependencies,
-                 **kwargs
-             )
-
-         return self._create_default_model(
-             model_server_workers=model_server_workers,
-             role=role,
-             vpc_config_override=vpc_config_override,
-             entry_point=entry_point,
-             source_dir=source_dir,
-             dependencies=dependencies,
-             **kwargs
-         )
-
-     def _create_tfs_model(
-         self,
-         role=None,
-         vpc_config_override=VPC_CONFIG_DEFAULT,
-         entry_point=None,
-         source_dir=None,
-         dependencies=None,
-         **kwargs
-     ):
-         """Placeholder docstring"""
        return Model(
            model_data=self.model_data,
-             role=role,
+             role=role or self.role,
            container_log_level=self.container_log_level,
-             framework_version=utils.get_short_version(self.framework_version),
+             framework_version=self.framework_version,
            sagemaker_session=self.sagemaker_session,
            vpc_config=self.get_vpc_config(vpc_config_override),
            entry_point=entry_point,
@@ -355,34 +303,6 @@ def _create_tfs_model(
            **kwargs
        )

-     def _create_default_model(
-         self,
-         model_server_workers,
-         role,
-         vpc_config_override,
-         entry_point=None,
-         source_dir=None,
-         dependencies=None,
-         **kwargs
-     ):
-         """Placeholder docstring"""
-         return TensorFlowModel(
-             self.model_data,
-             role,
-             entry_point or self.entry_point,
-             source_dir=source_dir or self._model_source_dir(),
-             enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
-             container_log_level=self.container_log_level,
-             code_location=self.code_location,
-             py_version=self.py_version,
-             framework_version=self.framework_version,
-             model_server_workers=model_server_workers,
-             sagemaker_session=self.sagemaker_session,
-             vpc_config=self.get_vpc_config(vpc_config_override),
-             dependencies=dependencies or self.dependencies,
-             **kwargs
-         )
-
    def hyperparameters(self):
        """Return hyperparameters used by your custom TensorFlow code during model training."""
        hyperparameters = super(TensorFlow, self).hyperparameters()
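
With ``endpoint_type`` and ``model_server_workers`` removed, ``create_model`` always builds a :class:`~sagemaker.tensorflow.serving.Model`, and ``entry_point``/``source_dir``/``dependencies`` default to ``None`` rather than falling back to the training values. A minimal usage sketch of the new call path; the role ARN, script name, bucket, and instance types below are illustrative placeholders, not values taken from this change:

from sagemaker.tensorflow import TensorFlow

# Hypothetical training setup; all names and values here are placeholders.
estimator = TensorFlow(
    entry_point="train.py",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    train_instance_count=1,
    train_instance_type="ml.p3.2xlarge",
    framework_version="1.15.2",
    py_version="py3",
)
estimator.fit("s3://my-bucket/training-data")

# create_model() now always returns sagemaker.tensorflow.serving.Model;
# pass entry_point/source_dir explicitly if custom serving code is needed.
model = estimator.create_model()
predictor = model.deploy(initial_instance_count=1, instance_type="ml.c5.xlarge")
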
@@ -464,137 +384,3 @@ def train_image(self):
            )

        return super(TensorFlow, self).train_image()
-
-     def transformer(
-         self,
-         instance_count,
-         instance_type,
-         strategy=None,
-         assemble_with=None,
-         output_path=None,
-         output_kms_key=None,
-         accept=None,
-         env=None,
-         max_concurrent_transforms=None,
-         max_payload=None,
-         tags=None,
-         role=None,
-         model_server_workers=None,
-         volume_kms_key=None,
-         endpoint_type=None,
-         entry_point=None,
-         vpc_config_override=VPC_CONFIG_DEFAULT,
-         enable_network_isolation=None,
-         model_name=None,
-     ):
-         """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It
-         reuses the SageMaker Session and base job name used by the Estimator.
-
-         Args:
-             instance_count (int): Number of EC2 instances to use.
-             instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
-             strategy (str): The strategy used to decide how to batch records in a single request
-                 (default: None). Valid values: 'MultiRecord' and 'SingleRecord'.
-             assemble_with (str): How the output is assembled (default: None). Valid values: 'Line'
-                 or 'None'.
-             output_path (str): S3 location for saving the transform result. If not specified,
-                 results are stored to a default bucket.
-             output_kms_key (str): Optional. KMS key ID for encrypting the transform output
-                 (default: None).
-             accept (str): The accept header passed by the client to
-                 the inference endpoint. If it is supported by the endpoint,
-                 it will be the format of the batch transform output.
-             env (dict): Environment variables to be set for use during the transform job
-                 (default: None).
-             max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
-                 each individual transform container at one time.
-             max_payload (int): Maximum size of the payload in a single HTTP request to the
-                 container in MB.
-             tags (list[dict]): List of tags for labeling a transform job. If none specified, then
-                 the tags used for the training job are used for the transform job.
-             role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
-                 used during transform jobs. If not specified, the role from the Estimator will be
-                 used.
-             model_server_workers (int): Optional. The number of worker processes used by the
-                 inference server. If None, server will use one worker per vCPU.
-             volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
-                 compute instance (default: None).
-             endpoint_type (str): Optional. Selects the software stack used by the inference server.
-                 If not specified, the model will be configured to use the default
-                 SageMaker model server.
-                 If 'tensorflow-serving', the model will be configured to
-                 use the SageMaker Tensorflow Serving container.
-             entry_point (str): Path (absolute or relative) to the local Python source file which
-                 should be executed as the entry point to training. If not specified and
-                 ``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
-                 ``endpoint_type`` is also ``None``, then the training entry point is used.
-             vpc_config_override (dict[str, list[str]]): Optional override for
-                 the VpcConfig set on the model.
-                 Default: use subnets and security groups from this Estimator.
-
-                 * 'Subnets' (list[str]): List of subnet ids.
-                 * 'SecurityGroupIds' (list[str]): List of security group ids.
-
-             enable_network_isolation (bool): Specifies whether container will
-                 run in network isolation mode. Network isolation mode restricts
-                 the container access to outside networks (such as the internet).
-                 The container does not make any inbound or outbound network
-                 calls. If True, a channel named "code" will be created for any
-                 user entry script for inference. Also known as Internet-free mode.
-                 If not specified, this setting is taken from the estimator's
-                 current configuration.
-             model_name (str): Name to use for creating an Amazon SageMaker
-                 model. If not specified, the name of the training job is used.
-         """
-         role = role or self.role
-
-         if self.latest_training_job is None:
-             logging.warning(
-                 "No finished training job found associated with this estimator. Please make sure "
-                 "this estimator is only used for building workflow config"
-             )
-             return Transformer(
-                 model_name or self._current_job_name,
-                 instance_count,
-                 instance_type,
-                 strategy=strategy,
-                 assemble_with=assemble_with,
-                 output_path=output_path,
-                 output_kms_key=output_kms_key,
-                 accept=accept,
-                 max_concurrent_transforms=max_concurrent_transforms,
-                 max_payload=max_payload,
-                 env=env or {},
-                 tags=tags,
-                 base_transform_job_name=self.base_job_name,
-                 volume_kms_key=volume_kms_key,
-                 sagemaker_session=self.sagemaker_session,
-             )
-
-         if enable_network_isolation is None:
-             enable_network_isolation = self.enable_network_isolation()
-
-         model = self.create_model(
-             model_server_workers=model_server_workers,
-             role=role,
-             vpc_config_override=vpc_config_override,
-             endpoint_type=endpoint_type,
-             entry_point=entry_point,
-             enable_network_isolation=enable_network_isolation,
-             name=model_name,
-         )
-
-         return model.transformer(
-             instance_count,
-             instance_type,
-             strategy=strategy,
-             assemble_with=assemble_with,
-             output_path=output_path,
-             output_kms_key=output_kms_key,
-             accept=accept,
-             env=env,
-             max_concurrent_transforms=max_concurrent_transforms,
-             max_payload=max_payload,
-             tags=tags,
-             volume_kms_key=volume_kms_key,
-         )
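
With the TensorFlow-specific ``transformer()`` override removed, batch transform goes through the generic ``Framework.transformer()`` inherited from the base class or, equivalently, through the serving ``Model`` returned by ``create_model``. A rough sketch of the model-based path, assuming an already-fitted estimator; bucket names and instance types are placeholders:

# `estimator` is an already-fitted sagemaker.tensorflow.TensorFlow instance.
model = estimator.create_model()  # returns sagemaker.tensorflow.serving.Model

# Batch transform via the Model object instead of the removed estimator override.
transformer = model.transformer(
    instance_count=1,
    instance_type="ml.c5.xlarge",
    output_path="s3://my-bucket/batch-output",
)
transformer.transform("s3://my-bucket/batch-input", content_type="application/json")
transformer.wait()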