diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 34bf6e2b5f..42d3e593f9 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,11 @@ CHANGELOG ========= +1.15.1.dev +========== + +* feature: Add APIs to export Airflow transform and deploy config + 1.15.0 ====== diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 920ef814ba..d6cf2c195b 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -295,6 +295,7 @@ def model_data(self): logging.warning('No finished training job found associated with this estimator. Please make sure' 'this estimator is only used for building workflow config') model_uri = os.path.join(self.output_path, self._current_job_name, 'output', 'model.tar.gz') + return model_uri @abstractmethod @@ -390,9 +391,13 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML compute instance (default: None). """ - self._ensure_latest_training_job() + if self.latest_training_job is not None: + model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role) + else: + logging.warning('No finished training job found associated with this estimator. Please make sure' + 'this estimator is only used for building workflow config') + model_name = self._current_job_name - model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role) tags = tags or self.tags return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, @@ -879,19 +884,23 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML compute instance (default: None). """ - self._ensure_latest_training_job() role = role or self.role - model = self.create_model(role=role, model_server_workers=model_server_workers) - - container_def = model.prepare_container_def(instance_type) - model_name = model.name or name_from_image(container_def['Image']) - vpc_config = model.vpc_config - self.sagemaker_session.create_model(model_name, role, container_def, vpc_config) - - transform_env = model.env.copy() - if env is not None: - transform_env.update(env) + if self.latest_training_job is not None: + model = self.create_model(role=role, model_server_workers=model_server_workers) + + container_def = model.prepare_container_def(instance_type) + model_name = model.name or name_from_image(container_def['Image']) + vpc_config = model.vpc_config + self.sagemaker_session.create_model(model_name, role, container_def, vpc_config) + transform_env = model.env.copy() + if env is not None: + transform_env.update(env) + else: + logging.warning('No finished training job found associated with this estimator. Please make sure' + 'this estimator is only used for building workflow config') + model_name = self._current_job_name + transform_env = env or {} tags = tags or self.tags return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index eda684609d..c6016e9a1a 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -90,7 +90,7 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t an input for the transform job. content_type (str): MIME type of the input data (default: None). 
- compression (str): Compression type of the input data, if compressed (default: None). + compression_type (str): Compression type of the input data, if compressed (default: None). Valid values: 'Gzip', None. split_type (str): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', and 'RecordIO'. diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index 3e366b253c..e5078c4234 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -376,8 +376,8 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None, role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model image (str): An container image to use for deploying the model model_server_workers (int): The number of worker processes used by the inference server. - If None, server will use one worker per vCPU. Only effective when estimator is - SageMaker framework. + If None, server will use one worker per vCPU. Only effective when estimator is a + SageMaker framework. vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model. Default: use subnets and security groups from this Estimator. * 'Subnets' (list[str]): List of subnet ids. @@ -394,5 +394,223 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None, elif isinstance(estimator, sagemaker.estimator.Framework): model = estimator.create_model(model_server_workers=model_server_workers, role=role, vpc_config_override=vpc_config_override) + else: + raise TypeError('Estimator must be one of sagemaker.estimator.Estimator, sagemaker.estimator.Framework' + ' or sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.') return model_config(instance_type, model, role, image) + + +def transform_config(transformer, data, data_type='S3Prefix', content_type=None, compression_type=None, + split_type=None, job_name=None): + """Export Airflow transform config from a SageMaker transformer + + Args: + transformer (sagemaker.transformer.Transformer): The SageMaker transformer to export Airflow + config from. + data (str): Input data location in S3. + data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values: + + * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix will be used as + inputs for the transform job. + * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object to use as + an input for the transform job. + + content_type (str): MIME type of the input data (default: None). + compression_type (str): Compression type of the input data, if compressed (default: None). + Valid values: 'Gzip', None. + split_type (str): The record delimiter for the input object (default: 'None'). + Valid values: 'None', 'Line', and 'RecordIO'. + job_name (str): job name (default: None). If not specified, one will be generated. + + Returns: + dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow. 
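+
+    Example:
+        An illustrative sketch only; the model name, S3 locations and instance settings below are
+        placeholders rather than values required by this API::
+
+            from sagemaker.transformer import Transformer
+            from sagemaker.workflow import airflow
+
+            xgb_transformer = Transformer(model_name='my-model',
+                                          instance_count=1,
+                                          instance_type='ml.m4.xlarge',
+                                          output_path='s3://my-bucket/transform-output')
+            config = airflow.transform_config(xgb_transformer,
+                                              data='s3://my-bucket/transform-input',
+                                              content_type='text/csv',
+                                              split_type='Line')
+            # 'config' is the dict consumed by Airflow's SageMakerTransformOperator.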
+ """ + if job_name is not None: + transformer._current_job_name = job_name + else: + base_name = transformer.base_transform_job_name + transformer._current_job_name = utils.airflow_name_from_base(base_name) \ + if base_name is not None else transformer.model_name + + if transformer.output_path is None: + transformer.output_path = 's3://{}/{}'.format( + transformer.sagemaker_session.default_bucket(), transformer._current_job_name) + + job_config = sagemaker.transformer._TransformJob._load_config( + data, data_type, content_type, compression_type, split_type, transformer) + + config = { + 'TransformJobName': transformer._current_job_name, + 'ModelName': transformer.model_name, + 'TransformInput': job_config['input_config'], + 'TransformOutput': job_config['output_config'], + 'TransformResources': job_config['resource_config'], + } + + if transformer.strategy is not None: + config['BatchStrategy'] = transformer.strategy + + if transformer.max_concurrent_transforms is not None: + config['MaxConcurrentTransforms'] = transformer.max_concurrent_transforms + + if transformer.max_payload is not None: + config['MaxPayloadInMB'] = transformer.max_payload + + if transformer.env is not None: + config['Environment'] = transformer.env + + if transformer.tags is not None: + config['Tags'] = transformer.tags + + return config + + +def transform_config_from_estimator(estimator, instance_count, instance_type, data, data_type='S3Prefix', + content_type=None, compression_type=None, split_type=None, + job_name=None, strategy=None, assemble_with=None, output_path=None, + output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, + max_payload=None, tags=None, role=None, volume_kms_key=None, + model_server_workers=None, image=None, vpc_config_override=None): + """Export Airflow transform config from a SageMaker estimator + + Args: + estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow config from. + It has to be an estimator associated with a training job. + instance_count (int): Number of EC2 instances to use. + instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'. + data (str): Input data location in S3. + data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values: + + * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix will be used as + inputs for the transform job. + * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object to use as + an input for the transform job. + + content_type (str): MIME type of the input data (default: None). + compression_type (str): Compression type of the input data, if compressed (default: None). + Valid values: 'Gzip', None. + split_type (str): The record delimiter for the input object (default: 'None'). + Valid values: 'None', 'Line', and 'RecordIO'. + job_name (str): job name (default: None). If not specified, one will be generated. + strategy (str): The strategy used to decide how to batch records in a single request (default: None). + Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'. + assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'. + output_path (str): S3 location for saving the transform result. If not specified, results are stored to + a default bucket. + output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None). + accept (str): The content type accepted by the endpoint deployed during the transform job. 
+ env (dict): Environment variables to be set for use during the transform job (default: None). + max_concurrent_transforms (int): The maximum number of HTTP requests to be made to + each individual transform container at one time. + max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. + tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for + the training job are used for the transform job. + role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during + transform jobs. If not specified, the role from the Estimator will be used. + volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML + compute instance (default: None). + model_server_workers (int): Optional. The number of worker processes used by the inference server. + If None, server will use one worker per vCPU. + image (str): An container image to use for deploying the model + vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model. + Default: use subnets and security groups from this Estimator. + * 'Subnets' (list[str]): List of subnet ids. + * 'SecurityGroupIds' (list[str]): List of security group ids. + + Returns: + dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow. + """ + model_base_config = model_config_from_estimator(instance_type=instance_type, estimator=estimator, role=role, + image=image, model_server_workers=model_server_workers, + vpc_config_override=vpc_config_override) + + if isinstance(estimator, sagemaker.estimator.Framework): + transformer = estimator.transformer(instance_count, instance_type, strategy, assemble_with, output_path, + output_kms_key, accept, env, max_concurrent_transforms, + max_payload, tags, role, model_server_workers, volume_kms_key) + else: + transformer = estimator.transformer(instance_count, instance_type, strategy, assemble_with, output_path, + output_kms_key, accept, env, max_concurrent_transforms, + max_payload, tags, role, volume_kms_key) + + transform_base_config = transform_config(transformer, data, data_type, content_type, compression_type, + split_type, job_name) + + config = { + 'Model': model_base_config, + 'Transform': transform_base_config + } + + return config + + +def deploy_config(model, initial_instance_count, instance_type, endpoint_name=None, tags=None): + """Export Airflow deploy config from a SageMaker model + + Args: + model (sagemaker.model.Model): The SageMaker model to export the Airflow config from. + instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'. + initial_instance_count (int): The initial number of instances to run in the + ``Endpoint`` created from this ``Model``. + endpoint_name (str): The name of the endpoint to create (default: None). + If not specified, a unique endpoint name will be created. + tags (list[dict]): List of tags for labeling a training job. For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + + Returns: + dict: Deploy config that can be directly used by SageMakerEndpointOperator in Airflow. 
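+
+    Example:
+        An illustrative sketch only; the image URI, model data location, role ARN and endpoint name
+        below are placeholders rather than values required by this API::
+
+            from sagemaker.model import Model
+            from sagemaker.workflow import airflow
+
+            model = Model(model_data='s3://my-bucket/model.tar.gz',
+                          image='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-inference-image:latest',
+                          role='arn:aws:iam::123456789012:role/MySageMakerRole')
+            config = airflow.deploy_config(model,
+                                           initial_instance_count=1,
+                                           instance_type='ml.m4.xlarge',
+                                           endpoint_name='my-endpoint')
+            # 'config' is the dict consumed by Airflow's SageMakerEndpointOperator.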
+ + """ + model_base_config = model_config(instance_type, model) + + production_variant = sagemaker.production_variant(model.name, instance_type, initial_instance_count) + name = model.name + config_options = {'EndpointConfigName': name, 'ProductionVariants': [production_variant]} + if tags is not None: + config_options['Tags'] = tags + + endpoint_name = endpoint_name or name + endpoint_base_config = { + 'EndpointName': endpoint_name, + 'EndpointConfigName': name + } + + config = { + 'Model': model_base_config, + 'EndpointConfig': config_options, + 'Endpoint': endpoint_base_config + } + + # if there is s3 operations needed for model, move it to root level of config + s3_operations = model_base_config.pop('S3Operations', None) + if s3_operations is not None: + config['S3Operations'] = s3_operations + + return config + + +def deploy_config_from_estimator(estimator, initial_instance_count, instance_type, endpoint_name=None, + tags=None, **kwargs): + """Export Airflow deploy config from a SageMaker estimator + + Args: + estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow config from. + It has to be an estimator associated with a training job. + initial_instance_count (int): Minimum number of EC2 instances to deploy to an endpoint for prediction. + instance_type (str): Type of EC2 instance to deploy to an endpoint for prediction, + for example, 'ml.c4.xlarge'. + endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of + the training job is used. + tags (list[dict]): List of tags for labeling a training job. For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + **kwargs: Passed to invocation of ``create_model()``. Implementations may customize + ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. + For more, see the implementation docs. + + Returns: + dict: Deploy config that can be directly used by SageMakerEndpointOperator in Airflow. 
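+
+    Example:
+        An illustrative sketch only; ``my_estimator`` stands in for any estimator that already has a
+        training job associated with it (for example after calling ``fit()``), and the instance type
+        and endpoint name are placeholders::
+
+            from sagemaker.workflow import airflow
+
+            config = airflow.deploy_config_from_estimator(estimator=my_estimator,
+                                                          initial_instance_count=1,
+                                                          instance_type='ml.c4.xlarge',
+                                                          endpoint_name='my-endpoint')
+            # 'config' is the dict consumed by Airflow's SageMakerEndpointOperator.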
+ """ + model = estimator.create_model(**kwargs) + config = deploy_config(model, initial_instance_count, instance_type, endpoint_name, tags) + return config diff --git a/tests/unit/test_airflow.py b/tests/unit/test_airflow.py index 21345f5a5c..0dcb2d0b0a 100644 --- a/tests/unit/test_airflow.py +++ b/tests/unit/test_airflow.py @@ -16,7 +16,7 @@ import pytest import mock -from sagemaker import chainer, estimator, model, mxnet, tensorflow, tuner +from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner from sagemaker.workflow import airflow from sagemaker.amazon import amazon_estimator from sagemaker.amazon import knn, ntm, pca @@ -757,3 +757,374 @@ def test_model_config_from_amazon_alg_estimator(sagemaker_session): } assert config == expected_config + + +def test_transformer_config(sagemaker_session): + tf_transformer = transformer.Transformer( + model_name="tensorflow-model", + instance_count="{{ instance_count }}", + instance_type="ml.p2.xlarge", + strategy="SingleRecord", + assemble_with='Line', + output_path="{{ output_path }}", + output_kms_key="{{ kms_key }}", + accept="{{ accept }}", + max_concurrent_transforms="{{ max_parallel_job }}", + max_payload="{{ max_payload }}", + tags=[{"{{ key }}": "{{ value }}"}], + env={"{{ key }}": "{{ value }}"}, + base_transform_job_name="tensorflow-transform", + sagemaker_session=sagemaker_session, + volume_kms_key="{{ kms_key }}") + + data = "{{ transform_data }}" + + config = airflow.transform_config(tf_transformer, data, data_type='S3Prefix', content_type="{{ content_type }}", + compression_type="{{ compression_type }}", split_type="{{ split_type }}") + expected_config = { + 'TransformJobName': "tensorflow-transform-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'ModelName': 'tensorflow-model', + 'TransformInput': { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': '{{ transform_data }}' + } + }, + 'ContentType': '{{ content_type }}', + 'CompressionType': '{{ compression_type }}', + 'SplitType': '{{ split_type }}'}, + 'TransformOutput': { + 'S3OutputPath': '{{ output_path }}', + 'KmsKeyId': '{{ kms_key }}', + 'AssembleWith': 'Line', + 'Accept': '{{ accept }}' + }, + 'TransformResources': { + 'InstanceCount': '{{ instance_count }}', + 'InstanceType': 'ml.p2.xlarge', + 'VolumeKmsKeyId': '{{ kms_key }}' + }, + 'BatchStrategy': 'SingleRecord', + 'MaxConcurrentTransforms': '{{ max_parallel_job }}', + 'MaxPayloadInMB': '{{ max_payload }}', + 'Environment': {'{{ key }}': '{{ value }}'}, + 'Tags': [{'{{ key }}': '{{ value }}'}] + } + + assert config == expected_config + + +def test_transform_config_from_framework_estimator(sagemaker_session): + mxnet_estimator = mxnet.MXNet( + entry_point="{{ entry_point }}", + source_dir="{{ source_dir }}", + py_version='py3', + framework_version='1.3.0', + role="{{ role }}", + train_instance_count=1, + train_instance_type='ml.m4.xlarge', + sagemaker_session=sagemaker_session, + base_job_name="{{ base_job_name }}", + hyperparameters={'batch_size': 100}) + + train_data = "{{ train_data }}" + transform_data = "{{ transform_data }}" + + # simulate training + airflow.training_config(mxnet_estimator, train_data) + + config = airflow.transform_config_from_estimator( + estimator=mxnet_estimator, + instance_count="{{ instance_count }}", + instance_type="ml.p2.xlarge", + data=transform_data) + expected_config = { + 'Model': { + 'ModelName': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'PrimaryContainer': { + 'Image': 
'520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-gpu-py3', + 'Environment': {'SAGEMAKER_PROGRAM': '{{ entry_point }}', + 'SAGEMAKER_SUBMIT_DIRECTORY': "s3://output/{{ base_job_name }}-" + "{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}/" + "source/sourcedir.tar.gz", + 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', + 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', + 'SAGEMAKER_REGION': 'us-west-2' + }, + 'ModelDataUrl': "s3://output/{{ base_job_name }}-" + "{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}/output/model.tar.gz"}, + 'ExecutionRoleArn': '{{ role }}' + }, + 'Transform': { + 'TransformJobName': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'ModelName': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'TransformInput': { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': '{{ transform_data }}' + } + } + }, + 'TransformOutput': { + 'S3OutputPath': "s3://output/{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}" + }, + 'TransformResources': { + 'InstanceCount': '{{ instance_count }}', + 'InstanceType': 'ml.p2.xlarge' + }, + 'Environment': {} + } + } + + assert config == expected_config + + +def test_transform_config_from_amazon_alg_estimator(sagemaker_session): + knn_estimator = knn.KNN( + role="{{ role }}", + train_instance_count="{{ instance_count }}", + train_instance_type='ml.m4.xlarge', + k=16, + sample_size=128, + predictor_type='regressor', + sagemaker_session=sagemaker_session) + + record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix') + transform_data = "{{ transform_data }}" + + # simulate training + airflow.training_config(knn_estimator, record, mini_batch_size=256) + + config = airflow.transform_config_from_estimator( + estimator=knn_estimator, + instance_count="{{ instance_count }}", + instance_type="ml.p2.xlarge", + data=transform_data) + expected_config = { + 'Model': { + 'ModelName': "knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'PrimaryContainer': { + 'Image': '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1', + 'Environment': {}, + 'ModelDataUrl': "s3://output/knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}" + "/output/model.tar.gz"}, + 'ExecutionRoleArn': '{{ role }}'}, + 'Transform': {'TransformJobName': "knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'ModelName': "knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'TransformInput': { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': '{{ transform_data }}'} + } + }, + 'TransformOutput': { + 'S3OutputPath': "s3://output/knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}" + }, + 'TransformResources': { + 'InstanceCount': '{{ instance_count }}', + 'InstanceType': 'ml.p2.xlarge'} + } + } + + assert config == expected_config + + +def test_deploy_framework_model_config(sagemaker_session): + chainer_model = chainer.ChainerModel( + model_data="{{ model_data }}", + role="{{ role }}", + entry_point="{{ entry_point }}", + source_dir="{{ source_dir }}", + image=None, + py_version='py3', + framework_version='5.0.0', + model_server_workers="{{ model_server_worker }}", + sagemaker_session=sagemaker_session) + + config = airflow.deploy_config(chainer_model, + initial_instance_count="{{ instance_count }}", + instance_type="ml.m4.xlarge") + expected_config = { + 'Model': { + 'ModelName': "sagemaker-chainer-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'PrimaryContainer': { + 'Image': 
'520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:5.0.0-cpu-py3', + 'Environment': { + 'SAGEMAKER_PROGRAM': '{{ entry_point }}', + 'SAGEMAKER_SUBMIT_DIRECTORY': "s3://output/sagemaker-chainer-" + "{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}" + "/source/sourcedir.tar.gz", + 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', + 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', + 'SAGEMAKER_REGION': 'us-west-2', + 'SAGEMAKER_MODEL_SERVER_WORKERS': '{{ model_server_worker }}' + }, + 'ModelDataUrl': '{{ model_data }}'}, + 'ExecutionRoleArn': '{{ role }}' + }, + 'EndpointConfig': { + 'EndpointConfigName': "sagemaker-chainer-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'ProductionVariants': [{ + 'InstanceType': 'ml.m4.xlarge', + 'InitialInstanceCount': '{{ instance_count }}', + 'ModelName': "sagemaker-chainer-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'VariantName': 'AllTraffic', + 'InitialVariantWeight': 1 + }] + }, + 'Endpoint': { + 'EndpointName': "sagemaker-chainer-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'EndpointConfigName': "sagemaker-chainer-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}" + }, + 'S3Operations': { + 'S3Upload': [{ + 'Path': '{{ source_dir }}', + 'Bucket': 'output', + 'Key': "sagemaker-chainer-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}/source/sourcedir.tar.gz", + 'Tar': True + }] + } + } + + assert config == expected_config + + +def test_deploy_amazon_alg_model_config(sagemaker_session): + pca_model = pca.PCAModel( + model_data="{{ model_data }}", + role="{{ role }}", + sagemaker_session=sagemaker_session) + + config = airflow.deploy_config(pca_model, + initial_instance_count="{{ instance_count }}", + instance_type='ml.c4.xlarge') + expected_config = { + 'Model': { + 'ModelName': "pca-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'PrimaryContainer': { + 'Image': '174872318107.dkr.ecr.us-west-2.amazonaws.com/pca:1', + 'Environment': {}, + 'ModelDataUrl': '{{ model_data }}'}, + 'ExecutionRoleArn': '{{ role }}'}, + 'EndpointConfig': { + 'EndpointConfigName': "pca-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'ProductionVariants': [{ + 'InstanceType': 'ml.c4.xlarge', + 'InitialInstanceCount': '{{ instance_count }}', + 'ModelName': "pca-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'VariantName': 'AllTraffic', + 'InitialVariantWeight': 1 + }] + }, + 'Endpoint': { + 'EndpointName': "pca-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'EndpointConfigName': "pca-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}" + } + } + + assert config == expected_config + + +def test_deploy_config_from_framework_estimator(sagemaker_session): + mxnet_estimator = mxnet.MXNet( + entry_point="{{ entry_point }}", + source_dir="{{ source_dir }}", + py_version='py3', + framework_version='1.3.0', + role="{{ role }}", + train_instance_count=1, + train_instance_type='ml.m4.xlarge', + sagemaker_session=sagemaker_session, + base_job_name="{{ base_job_name }}", + hyperparameters={'batch_size': 100}) + + train_data = "{{ train_data }}" + + # simulate training + airflow.training_config(mxnet_estimator, train_data) + + config = airflow.deploy_config_from_estimator(estimator=mxnet_estimator, + initial_instance_count="{{ instance_count}}", + instance_type="ml.c4.large", + endpoint_name="mxnet-endpoint") + expected_config = { + 'Model': { + 'ModelName': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'PrimaryContainer': { + 'Image': 
'520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-cpu-py3', + 'Environment': { + 'SAGEMAKER_PROGRAM': '{{ entry_point }}', + 'SAGEMAKER_SUBMIT_DIRECTORY': "s3://output/{{ base_job_name }}-" + "{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}/" + "source/sourcedir.tar.gz", + 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', + 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', + 'SAGEMAKER_REGION': 'us-west-2'}, + 'ModelDataUrl': "s3://output/{{ base_job_name }}-" + "{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}/output/model.tar.gz"}, + 'ExecutionRoleArn': '{{ role }}' + }, + 'EndpointConfig': { + 'EndpointConfigName': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'ProductionVariants': [{ + 'InstanceType': 'ml.c4.large', + 'InitialInstanceCount': '{{ instance_count}}', + 'ModelName': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'VariantName': 'AllTraffic', + 'InitialVariantWeight': 1 + }] + }, + 'Endpoint': { + 'EndpointName': 'mxnet-endpoint', + 'EndpointConfigName': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}" + } + } + + assert config == expected_config + + +def test_deploy_config_from_amazon_alg_estimator(sagemaker_session): + knn_estimator = knn.KNN( + role="{{ role }}", + train_instance_count="{{ instance_count }}", + train_instance_type='ml.m4.xlarge', + k=16, + sample_size=128, + predictor_type='regressor', + sagemaker_session=sagemaker_session) + + record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix') + + # simulate training + airflow.training_config(knn_estimator, record, mini_batch_size=256) + + config = airflow.deploy_config_from_estimator(estimator=knn_estimator, + initial_instance_count="{{ instance_count }}", + instance_type="ml.p2.xlarge") + expected_config = { + 'Model': { + 'ModelName': "knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'PrimaryContainer': { + 'Image': '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1', + 'Environment': {}, + 'ModelDataUrl': "s3://output/knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}" + "/output/model.tar.gz"}, 'ExecutionRoleArn': '{{ role }}'}, + 'EndpointConfig': { + 'EndpointConfigName': "knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'ProductionVariants': [{ + 'InstanceType': 'ml.p2.xlarge', + 'InitialInstanceCount': '{{ instance_count }}', + 'ModelName': "knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'VariantName': 'AllTraffic', 'InitialVariantWeight': 1 + }] + }, + 'Endpoint': { + 'EndpointName': "knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}", + 'EndpointConfigName': "knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}" + } + } + + assert config == expected_config
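
For reference, a minimal sketch of how these exported configs are meant to be wired into an Airflow DAG. It assumes an Airflow release that provides the SageMaker operators under airflow.contrib.operators; the estimator (my_estimator), S3 locations, instance types, endpoint name and task ids are placeholders, not values required by these APIs.

    from airflow import DAG
    from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
    from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
    from airflow.contrib.operators.sagemaker_endpoint_operator import SageMakerEndpointOperator
    from airflow.utils.dates import days_ago

    from sagemaker.workflow import airflow as sm_airflow

    # 'my_estimator' is a placeholder for any configured SageMaker estimator. The training
    # config is generated first so that the transform and deploy configs can reference the
    # training job's model artifacts, mirroring the "simulate training" step in the tests above.
    train_config = sm_airflow.training_config(my_estimator, inputs='s3://my-bucket/train')
    transform_config = sm_airflow.transform_config_from_estimator(
        estimator=my_estimator,
        instance_count=1,
        instance_type='ml.m4.xlarge',
        data='s3://my-bucket/transform-input')
    deploy_config = sm_airflow.deploy_config_from_estimator(
        estimator=my_estimator,
        initial_instance_count=1,
        instance_type='ml.c4.xlarge',
        endpoint_name='my-endpoint')

    dag = DAG('sagemaker_example', default_args={'start_date': days_ago(1)}, schedule_interval=None)

    train_op = SageMakerTrainingOperator(task_id='train', config=train_config, dag=dag)
    transform_op = SageMakerTransformOperator(task_id='batch_transform', config=transform_config, dag=dag)
    deploy_op = SageMakerEndpointOperator(task_id='deploy', config=deploy_config, dag=dag)

    train_op >> transform_op >> deploy_op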