
Commit 53a43f6

Add APIs to export transform and deploy config (aws#497)
1 parent e37ac12 commit 53a43f6

File tree

5 files changed: +620 -17 lines changed


CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,11 @@
 CHANGELOG
 =========
 
+1.15.1.dev
+==========
+
+* feature: Add APIs to export Airflow transform and deploy config
+
 1.15.0
 ======
 
src/sagemaker/estimator.py

Lines changed: 22 additions & 13 deletions
@@ -295,6 +295,7 @@ def model_data(self):
             logging.warning('No finished training job found associated with this estimator. Please make sure'
                             'this estimator is only used for building workflow config')
             model_uri = os.path.join(self.output_path, self._current_job_name, 'output', 'model.tar.gz')
+
         return model_uri
 
     @abstractmethod

@@ -390,9 +391,13 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
             volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
                 compute instance (default: None).
         """
-        self._ensure_latest_training_job()
+        if self.latest_training_job is not None:
+            model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role)
+        else:
+            logging.warning('No finished training job found associated with this estimator. Please make sure'
+                            'this estimator is only used for building workflow config')
+            model_name = self._current_job_name
 
-        model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role)
         tags = tags or self.tags
 
         return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,

@@ -879,19 +884,23 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
             volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
                 compute instance (default: None).
         """
-        self._ensure_latest_training_job()
         role = role or self.role
 
-        model = self.create_model(role=role, model_server_workers=model_server_workers)
-
-        container_def = model.prepare_container_def(instance_type)
-        model_name = model.name or name_from_image(container_def['Image'])
-        vpc_config = model.vpc_config
-        self.sagemaker_session.create_model(model_name, role, container_def, vpc_config)
-
-        transform_env = model.env.copy()
-        if env is not None:
-            transform_env.update(env)
+        if self.latest_training_job is not None:
+            model = self.create_model(role=role, model_server_workers=model_server_workers)
+
+            container_def = model.prepare_container_def(instance_type)
+            model_name = model.name or name_from_image(container_def['Image'])
+            vpc_config = model.vpc_config
+            self.sagemaker_session.create_model(model_name, role, container_def, vpc_config)
+            transform_env = model.env.copy()
+            if env is not None:
+                transform_env.update(env)
+        else:
+            logging.warning('No finished training job found associated with this estimator. Please make sure'
+                            'this estimator is only used for building workflow config')
+            model_name = self._current_job_name
+            transform_env = env or {}
 
         tags = tags or self.tags
         return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
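
For context, a minimal sketch of what this relaxed check allows, assuming AWS credentials and a default region are configured; the image URI, role ARN, and bucket below are placeholders, not values from this commit. In a real Airflow workflow the training config for this estimator would normally be exported first, so the estimator carries the job name that the fallback path uses.

from sagemaker.estimator import Estimator

# Placeholder image, role, and output location.
estimator = Estimator(image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algo:latest',
                      role='arn:aws:iam::123456789012:role/SageMakerRole',
                      train_instance_count=1,
                      train_instance_type='ml.c4.xlarge',
                      output_path='s3://my-bucket/output')

# Previously this raised because no finished training job was attached; now it
# logs a warning, falls back to the current job name, and still returns a
# Transformer that can be used to build workflow config.
transformer = estimator.transformer(instance_count=1, instance_type='ml.c4.xlarge')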

src/sagemaker/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
                 an input for the transform job.
 
             content_type (str): MIME type of the input data (default: None).
-            compression (str): Compression type of the input data, if compressed (default: None).
+            compression_type (str): Compression type of the input data, if compressed (default: None).
                 Valid values: 'Gzip', None.
             split_type (str): The record delimiter for the input object (default: 'None').
                 Valid values: 'None', 'Line', and 'RecordIO'.

src/sagemaker/workflow/airflow.py

Lines changed: 220 additions & 2 deletions
@@ -376,8 +376,8 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None,
         role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
         image (str): An container image to use for deploying the model
         model_server_workers (int): The number of worker processes used by the inference server.
-            If None, server will use one worker per vCPU. Only effective when estimator is
-            SageMaker framework.
+            If None, server will use one worker per vCPU. Only effective when estimator is a
+            SageMaker framework.
         vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model.
             Default: use subnets and security groups from this Estimator.
             * 'Subnets' (list[str]): List of subnet ids.

@@ -394,5 +394,223 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None,
     elif isinstance(estimator, sagemaker.estimator.Framework):
         model = estimator.create_model(model_server_workers=model_server_workers, role=role,
                                        vpc_config_override=vpc_config_override)
+    else:
+        raise TypeError('Estimator must be one of sagemaker.estimator.Estimator, sagemaker.estimator.Framework'
+                        ' or sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.')
 
     return model_config(instance_type, model, role, image)
+
+
+def transform_config(transformer, data, data_type='S3Prefix', content_type=None, compression_type=None,
+                     split_type=None, job_name=None):
+    """Export Airflow transform config from a SageMaker transformer
+
+    Args:
+        transformer (sagemaker.transformer.Transformer): The SageMaker transformer to export Airflow
+            config from.
+        data (str): Input data location in S3.
+        data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values:
+
+            * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix will be used as
+                inputs for the transform job.
+            * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object to use as
+                an input for the transform job.
+
+        content_type (str): MIME type of the input data (default: None).
+        compression_type (str): Compression type of the input data, if compressed (default: None).
+            Valid values: 'Gzip', None.
+        split_type (str): The record delimiter for the input object (default: 'None').
+            Valid values: 'None', 'Line', and 'RecordIO'.
+        job_name (str): job name (default: None). If not specified, one will be generated.
+
+    Returns:
+        dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow.
+    """
+    if job_name is not None:
+        transformer._current_job_name = job_name
+    else:
+        base_name = transformer.base_transform_job_name
+        transformer._current_job_name = utils.airflow_name_from_base(base_name) \
+            if base_name is not None else transformer.model_name
+
+    if transformer.output_path is None:
+        transformer.output_path = 's3://{}/{}'.format(
+            transformer.sagemaker_session.default_bucket(), transformer._current_job_name)
+
+    job_config = sagemaker.transformer._TransformJob._load_config(
+        data, data_type, content_type, compression_type, split_type, transformer)
+
+    config = {
+        'TransformJobName': transformer._current_job_name,
+        'ModelName': transformer.model_name,
+        'TransformInput': job_config['input_config'],
+        'TransformOutput': job_config['output_config'],
+        'TransformResources': job_config['resource_config'],
+    }
+
+    if transformer.strategy is not None:
+        config['BatchStrategy'] = transformer.strategy
+
+    if transformer.max_concurrent_transforms is not None:
+        config['MaxConcurrentTransforms'] = transformer.max_concurrent_transforms
+
+    if transformer.max_payload is not None:
+        config['MaxPayloadInMB'] = transformer.max_payload
+
+    if transformer.env is not None:
+        config['Environment'] = transformer.env
+
+    if transformer.tags is not None:
+        config['Tags'] = transformer.tags
+
+    return config
+
+
+def transform_config_from_estimator(estimator, instance_count, instance_type, data, data_type='S3Prefix',
+                                    content_type=None, compression_type=None, split_type=None,
+                                    job_name=None, strategy=None, assemble_with=None, output_path=None,
+                                    output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
+                                    max_payload=None, tags=None, role=None, volume_kms_key=None,
+                                    model_server_workers=None, image=None, vpc_config_override=None):
+    """Export Airflow transform config from a SageMaker estimator
+
+    Args:
+        estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow config from.
+            It has to be an estimator associated with a training job.
+        instance_count (int): Number of EC2 instances to use.
+        instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
+        data (str): Input data location in S3.
+        data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values:
+
+            * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix will be used as
+                inputs for the transform job.
+            * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object to use as
+                an input for the transform job.
+
+        content_type (str): MIME type of the input data (default: None).
+        compression_type (str): Compression type of the input data, if compressed (default: None).
+            Valid values: 'Gzip', None.
+        split_type (str): The record delimiter for the input object (default: 'None').
+            Valid values: 'None', 'Line', and 'RecordIO'.
+        job_name (str): job name (default: None). If not specified, one will be generated.
+        strategy (str): The strategy used to decide how to batch records in a single request (default: None).
+            Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
+        assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
+        output_path (str): S3 location for saving the transform result. If not specified, results are stored to
+            a default bucket.
+        output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
+        accept (str): The content type accepted by the endpoint deployed during the transform job.
+        env (dict): Environment variables to be set for use during the transform job (default: None).
+        max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
+            each individual transform container at one time.
+        max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
+        tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
+            the training job are used for the transform job.
+        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
+            transform jobs. If not specified, the role from the Estimator will be used.
+        volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
+            compute instance (default: None).
+        model_server_workers (int): Optional. The number of worker processes used by the inference server.
+            If None, server will use one worker per vCPU.
+        image (str): An container image to use for deploying the model
+        vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model.
+            Default: use subnets and security groups from this Estimator.
+            * 'Subnets' (list[str]): List of subnet ids.
+            * 'SecurityGroupIds' (list[str]): List of security group ids.
+
+    Returns:
+        dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow.
+    """
+    model_base_config = model_config_from_estimator(instance_type=instance_type, estimator=estimator, role=role,
+                                                    image=image, model_server_workers=model_server_workers,
+                                                    vpc_config_override=vpc_config_override)
+
+    if isinstance(estimator, sagemaker.estimator.Framework):
+        transformer = estimator.transformer(instance_count, instance_type, strategy, assemble_with, output_path,
+                                            output_kms_key, accept, env, max_concurrent_transforms,
+                                            max_payload, tags, role, model_server_workers, volume_kms_key)
+    else:
+        transformer = estimator.transformer(instance_count, instance_type, strategy, assemble_with, output_path,
+                                            output_kms_key, accept, env, max_concurrent_transforms,
+                                            max_payload, tags, role, volume_kms_key)
+
+    transform_base_config = transform_config(transformer, data, data_type, content_type, compression_type,
+                                             split_type, job_name)
+
+    config = {
+        'Model': model_base_config,
+        'Transform': transform_base_config
+    }
+
+    return config
+
+
+def deploy_config(model, initial_instance_count, instance_type, endpoint_name=None, tags=None):
+    """Export Airflow deploy config from a SageMaker model
+
+    Args:
+        model (sagemaker.model.Model): The SageMaker model to export the Airflow config from.
+        instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
+        initial_instance_count (int): The initial number of instances to run in the
+            ``Endpoint`` created from this ``Model``.
+        endpoint_name (str): The name of the endpoint to create (default: None).
+            If not specified, a unique endpoint name will be created.
+        tags (list[dict]): List of tags for labeling a training job. For more, see
+            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+
+    Returns:
+        dict: Deploy config that can be directly used by SageMakerEndpointOperator in Airflow.
+
+    """
+    model_base_config = model_config(instance_type, model)
+
+    production_variant = sagemaker.production_variant(model.name, instance_type, initial_instance_count)
+    name = model.name
+    config_options = {'EndpointConfigName': name, 'ProductionVariants': [production_variant]}
+    if tags is not None:
+        config_options['Tags'] = tags
+
+    endpoint_name = endpoint_name or name
+    endpoint_base_config = {
+        'EndpointName': endpoint_name,
+        'EndpointConfigName': name
+    }
+
+    config = {
+        'Model': model_base_config,
+        'EndpointConfig': config_options,
+        'Endpoint': endpoint_base_config
+    }
+
+    # if there is s3 operations needed for model, move it to root level of config
+    s3_operations = model_base_config.pop('S3Operations', None)
+    if s3_operations is not None:
+        config['S3Operations'] = s3_operations
+
+    return config
+
+
+def deploy_config_from_estimator(estimator, initial_instance_count, instance_type, endpoint_name=None,
+                                 tags=None, **kwargs):
+    """Export Airflow deploy config from a SageMaker estimator
+
+    Args:
+        estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow config from.
+            It has to be an estimator associated with a training job.
+        initial_instance_count (int): Minimum number of EC2 instances to deploy to an endpoint for prediction.
+        instance_type (str): Type of EC2 instance to deploy to an endpoint for prediction,
+            for example, 'ml.c4.xlarge'.
+        endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of
+            the training job is used.
+        tags (list[dict]): List of tags for labeling a training job. For more, see
+            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+        **kwargs: Passed to invocation of ``create_model()``. Implementations may customize
+            ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
+            For more, see the implementation docs.
+
+    Returns:
+        dict: Deploy config that can be directly used by SageMakerEndpointOperator in Airflow.
+    """
+    model = estimator.create_model(**kwargs)
+    config = deploy_config(model, initial_instance_count, instance_type, endpoint_name, tags)
+    return config
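
To illustrate how the new transform export might be wired up, here is a minimal sketch. It assumes AWS credentials and a default region are configured, and that the training config export (``training_config``) from this module is already available from an earlier release; the image URI, role ARN, bucket, and S3 prefixes are placeholders.

from sagemaker.estimator import Estimator
from sagemaker.workflow import airflow as sm_airflow

# Placeholder image, role, and S3 locations.
estimator = Estimator(image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algo:latest',
                      role='arn:aws:iam::123456789012:role/SageMakerRole',
                      train_instance_count=1,
                      train_instance_type='ml.c4.xlarge',
                      output_path='s3://my-bucket/output')

# Export the training config first so the estimator carries a job name that the
# transform config can reference.
train_cfg = sm_airflow.training_config(estimator=estimator,
                                       inputs='s3://my-bucket/train')

# Export the batch transform config; the returned dict is the `config` argument
# expected by Airflow's SageMakerTransformOperator.
transform_cfg = sm_airflow.transform_config_from_estimator(
    estimator=estimator,
    instance_count=1,
    instance_type='ml.c4.xlarge',
    data='s3://my-bucket/batch-input',
    content_type='text/csv',
    output_path='s3://my-bucket/batch-output')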

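Similarly, a hedged sketch of the deploy-side export built directly from a standalone ``sagemaker.model.Model`` (the model artifact, image, role, and endpoint name are placeholders); the same dict can also be produced from a trained estimator via ``deploy_config_from_estimator``.

from sagemaker.model import Model
from sagemaker.workflow import airflow as sm_airflow

# Placeholder model artifact, image, and role.
model = Model(model_data='s3://my-bucket/output/my-training-job/output/model.tar.gz',
              image='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algo:latest',
              role='arn:aws:iam::123456789012:role/SageMakerRole',
              name='my-model')

# The returned dict bundles 'Model', 'EndpointConfig', and 'Endpoint' sections and
# is the `config` argument expected by Airflow's SageMakerEndpointOperator.
endpoint_cfg = sm_airflow.deploy_config(model,
                                        initial_instance_count=1,
                                        instance_type='ml.m4.xlarge',
                                        endpoint_name='my-endpoint')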