@@ -376,8 +376,8 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None,
376
376
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
377
377
image (str): A container image to use for deploying the model
378
378
model_server_workers (int): The number of worker processes used by the inference server.
379
- If None, server will use one worker per vCPU. Only effective when estimator is
380
- SageMaker framework.
379
+ If None, server will use one worker per vCPU. Only effective when estimator is a
380
+ SageMaker framework.
381
381
vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model.
382
382
Default: use subnets and security groups from this Estimator.
383
383
* 'Subnets' (list[str]): List of subnet ids.
@@ -394,5 +394,223 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None,
394
394
elif isinstance (estimator , sagemaker .estimator .Framework ):
395
395
model = estimator .create_model (model_server_workers = model_server_workers , role = role ,
396
396
vpc_config_override = vpc_config_override )
397
+ else :
398
+ raise TypeError ('Estimator must be one of sagemaker.estimator.Estimator, sagemaker.estimator.Framework'
399
+ ' or sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.' )
397
400
398
401
return model_config (instance_type , model , role , image )
402
+
403
+
404
def transform_config(transformer, data, data_type='S3Prefix', content_type=None, compression_type=None,
                     split_type=None, job_name=None):
    """Export an Airflow transform config from a SageMaker transformer.

    Args:
        transformer (sagemaker.transformer.Transformer): The SageMaker transformer to export
            Airflow config from.
        data (str): Input data location in S3.
        data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values:

            * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
              will be used as inputs for the transform job.
            * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
              object to use as an input for the transform job.

        content_type (str): MIME type of the input data (default: None).
        compression_type (str): Compression type of the input data, if compressed (default: None).
            Valid values: 'Gzip', None.
        split_type (str): The record delimiter for the input object (default: 'None').
            Valid values: 'None', 'Line', and 'RecordIO'.
        job_name (str): job name (default: None). If not specified, one will be generated.

    Returns:
        dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow.
    """
    # Resolve the job name: explicit name wins, then a generated Airflow-style name,
    # then fall back to the model name.
    if job_name is not None:
        transformer._current_job_name = job_name
    else:
        base_name = transformer.base_transform_job_name
        if base_name is None:
            transformer._current_job_name = transformer.model_name
        else:
            transformer._current_job_name = utils.airflow_name_from_base(base_name)

    # Default the output location to the session's default bucket, keyed by job name.
    if transformer.output_path is None:
        default_bucket = transformer.sagemaker_session.default_bucket()
        transformer.output_path = 's3://{}/{}'.format(default_bucket, transformer._current_job_name)

    job_config = sagemaker.transformer._TransformJob._load_config(
        data, data_type, content_type, compression_type, split_type, transformer)

    config = {
        'TransformJobName': transformer._current_job_name,
        'ModelName': transformer.model_name,
        'TransformInput': job_config['input_config'],
        'TransformOutput': job_config['output_config'],
        'TransformResources': job_config['resource_config'],
    }

    # Optional keys are only emitted when the transformer actually sets them.
    optional_fields = [
        ('BatchStrategy', transformer.strategy),
        ('MaxConcurrentTransforms', transformer.max_concurrent_transforms),
        ('MaxPayloadInMB', transformer.max_payload),
        ('Environment', transformer.env),
        ('Tags', transformer.tags),
    ]
    for key, value in optional_fields:
        if value is not None:
            config[key] = value

    return config
467
+
468
+
469
def transform_config_from_estimator(estimator, instance_count, instance_type, data, data_type='S3Prefix',
                                    content_type=None, compression_type=None, split_type=None,
                                    job_name=None, strategy=None, assemble_with=None, output_path=None,
                                    output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
                                    max_payload=None, tags=None, role=None, volume_kms_key=None,
                                    model_server_workers=None, image=None, vpc_config_override=None):
    """Export an Airflow transform config from a SageMaker estimator.

    Args:
        estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow config
            from. It has to be an estimator associated with a training job.
        instance_count (int): Number of EC2 instances to use.
        instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
        data (str): Input data location in S3.
        data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values:

            * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
              will be used as inputs for the transform job.
            * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
              object to use as an input for the transform job.

        content_type (str): MIME type of the input data (default: None).
        compression_type (str): Compression type of the input data, if compressed (default: None).
            Valid values: 'Gzip', None.
        split_type (str): The record delimiter for the input object (default: 'None').
            Valid values: 'None', 'Line', and 'RecordIO'.
        job_name (str): job name (default: None). If not specified, one will be generated.
        strategy (str): The strategy used to decide how to batch records in a single request
            (default: None). Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
        assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
        output_path (str): S3 location for saving the transform result. If not specified, results
            are stored to a default bucket.
        output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
        accept (str): The content type accepted by the endpoint deployed during the transform job.
        env (dict): Environment variables to be set for use during the transform job (default: None).
        max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
            each individual transform container at one time.
        max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
        tags (list[dict]): List of tags for labeling a transform job. If none specified, then the
            tags used for the training job are used for the transform job.
        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used
            during transform jobs. If not specified, the role from the Estimator will be used.
        volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
            compute instance (default: None).
        model_server_workers (int): Optional. The number of worker processes used by the inference
            server. If None, server will use one worker per vCPU.
        image (str): A container image to use for deploying the model.
        vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model.
            Default: use subnets and security groups from this Estimator.

            * 'Subnets' (list[str]): List of subnet ids.
            * 'SecurityGroupIds' (list[str]): List of security group ids.

    Returns:
        dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow.
    """
    model_base_config = model_config_from_estimator(instance_type=instance_type, estimator=estimator,
                                                    role=role, image=image,
                                                    model_server_workers=model_server_workers,
                                                    vpc_config_override=vpc_config_override)

    # Framework estimators accept an extra model_server_workers argument; build the
    # positional argument list once and insert it only for that case.
    transformer_args = [instance_count, instance_type, strategy, assemble_with, output_path,
                        output_kms_key, accept, env, max_concurrent_transforms, max_payload,
                        tags, role]
    if isinstance(estimator, sagemaker.estimator.Framework):
        transformer_args.append(model_server_workers)
    transformer_args.append(volume_kms_key)
    transformer = estimator.transformer(*transformer_args)

    transform_base_config = transform_config(transformer, data, data_type, content_type,
                                             compression_type, split_type, job_name)

    return {
        'Model': model_base_config,
        'Transform': transform_base_config
    }
546
+
547
+
548
def deploy_config(model, initial_instance_count, instance_type, endpoint_name=None, tags=None):
    """Export an Airflow deploy config from a SageMaker model.

    Args:
        model (sagemaker.model.Model): The SageMaker model to export the Airflow config from.
        initial_instance_count (int): The initial number of instances to run in the
            ``Endpoint`` created from this ``Model``.
        instance_type (str): The EC2 instance type to deploy this Model to.
            For example, 'ml.p2.xlarge'.
        endpoint_name (str): The name of the endpoint to create (default: None).
            If not specified, a unique endpoint name will be created.
        tags (list[dict]): List of tags for labeling the endpoint config. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.

    Returns:
        dict: Deploy config that can be directly used by SageMakerEndpointOperator in Airflow.
    """
    model_base_config = model_config(instance_type, model)

    # The model name doubles as the endpoint-config name (and the endpoint name
    # when no explicit endpoint_name is given).
    name = model.name
    production_variant = sagemaker.production_variant(name, instance_type, initial_instance_count)

    endpoint_config = {'EndpointConfigName': name, 'ProductionVariants': [production_variant]}
    if tags is not None:
        endpoint_config['Tags'] = tags

    endpoint_base_config = {
        'EndpointName': endpoint_name or name,
        'EndpointConfigName': name
    }

    config = {
        'Model': model_base_config,
        'EndpointConfig': endpoint_config,
        'Endpoint': endpoint_base_config
    }

    # If the model config carries S3 operations, hoist them to the root of the
    # Airflow config where the operator expects them.
    s3_operations = model_base_config.pop('S3Operations', None)
    if s3_operations is not None:
        config['S3Operations'] = s3_operations

    return config
591
+
592
+
593
def deploy_config_from_estimator(estimator, initial_instance_count, instance_type, endpoint_name=None,
                                 tags=None, **kwargs):
    """Export an Airflow deploy config from a SageMaker estimator.

    Args:
        estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow
            config from. It has to be an estimator associated with a training job.
        initial_instance_count (int): Minimum number of EC2 instances to deploy to an endpoint
            for prediction.
        instance_type (str): Type of EC2 instance to deploy to an endpoint for prediction,
            for example, 'ml.c4.xlarge'.
        endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not
            specified, the name of the training job is used.
        tags (list[dict]): List of tags for labeling the endpoint config. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
        **kwargs: Passed to invocation of ``create_model()``. Implementations may customize
            ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
            For more, see the implementation docs.

    Returns:
        dict: Deploy config that can be directly used by SageMakerEndpointOperator in Airflow.
    """
    # Build the model from the trained estimator, then delegate to deploy_config.
    model = estimator.create_model(**kwargs)
    return deploy_config(model, initial_instance_count, instance_type, endpoint_name, tags)
0 commit comments