
Commit 172381b

yangawslaurenyu authored and committed
Add APIs to export Airflow model config (aws#492)
1 parent 537ba79 commit 172381b

File tree: 5 files changed, +346 -20 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ CHANGELOG
 
 * bug-fix: Changes to use correct S3 bucket and time range for dataframes in TrainingJobAnalytics.
 * bug-fix: Local Mode: correctly handle the case where the model output folder doesn't exist yet
-* feature: Add APIs to export Airflow training and tuning config
+* feature: Add APIs to export Airflow training, tuning and model config
 * doc-fix: Fix typos in tensorflow serving documentation
 * doc-fix: Add estimator base classes to API docs
 * feature: HyperparameterTuner: add support for Automatic Model Tuning's Warm Start Jobs

src/sagemaker/estimator.py

Lines changed: 8 additions & 2 deletions
@@ -288,8 +288,14 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kw
     @property
     def model_data(self):
         """str: The model location in S3. Only set if Estimator has been ``fit()``."""
-        return self.sagemaker_session.sagemaker_client.describe_training_job(
-            TrainingJobName=self.latest_training_job.name)['ModelArtifacts']['S3ModelArtifacts']
+        if self.latest_training_job is not None:
+            model_uri = self.sagemaker_session.sagemaker_client.describe_training_job(
+                TrainingJobName=self.latest_training_job.name)['ModelArtifacts']['S3ModelArtifacts']
+        else:
+            logging.warning('No finished training job found associated with this estimator. Please make sure '
+                            'this estimator is only used for building workflow config')
+            model_uri = os.path.join(self.output_path, self._current_job_name, 'output', 'model.tar.gz')
+        return model_uri

     @abstractmethod
     def create_model(self, **kwargs):
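A note on the new fallback: once a job name has been assigned (for example by the Airflow config export in this commit), an estimator that was never ``fit()`` now predicts its artifact location instead of failing. A minimal sketch with hypothetical values:

# Hedged illustration -- bucket and job name below are hypothetical placeholders.
# Given an estimator with no finished training job but a job name assigned by
# the workflow config export:
#
#   estimator.output_path        == 's3://my-bucket/output'
#   estimator._current_job_name  == 'train-2018-11-08-00-00-00'
#
# estimator.model_data now logs a warning and returns
#   's3://my-bucket/output/train-2018-11-08-00-00-00/output/model.tar.gz'
# instead of failing when latest_training_job is None.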

src/sagemaker/model.py

Lines changed: 1 addition & 0 deletions
@@ -166,6 +166,7 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic
             self.bucket, self.key_prefix = parse_s3_url(code_location)
         else:
             self.bucket, self.key_prefix = None, None
+        self.uploaded_code = None

     def prepare_container_def(self, instance_type):  # pylint: disable=unused-argument
         """Return a container definition with framework configuration set in model environment variables.

src/sagemaker/workflow/airflow.py

Lines changed: 142 additions & 15 deletions
@@ -15,37 +15,41 @@
 import os
 
 import sagemaker
-from sagemaker import job, model, utils
+from sagemaker import fw_utils, job, utils, session, vpc_utils
 from sagemaker.amazon import amazon_estimator
 
 
 def prepare_framework(estimator, s3_operations):
-    """Prepare S3 operations (specify where to upload source_dir) and environment variables
+    """Prepare S3 operations (specify where to upload `source_dir`) and environment variables
     related to framework.
 
     Args:
         estimator (sagemaker.estimator.Estimator): The framework estimator to get information from and update.
-        s3_operations (dict): The dict to specify s3 operations (upload source_dir).
+        s3_operations (dict): The dict to specify s3 operations (upload `source_dir`).
     """
     bucket = estimator.code_location if estimator.code_location else estimator.sagemaker_session._default_bucket
     key = '{}/source/sourcedir.tar.gz'.format(estimator._current_job_name)
     script = os.path.basename(estimator.entry_point)
     if estimator.source_dir and estimator.source_dir.lower().startswith('s3://'):
         code_dir = estimator.source_dir
+        estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
     else:
         code_dir = 's3://{}/{}'.format(bucket, key)
+        estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
         s3_operations['S3Upload'] = [{
             'Path': estimator.source_dir or script,
             'Bucket': bucket,
             'Key': key,
             'Tar': True
         }]
-    estimator._hyperparameters[model.DIR_PARAM_NAME] = code_dir
-    estimator._hyperparameters[model.SCRIPT_PARAM_NAME] = script
-    estimator._hyperparameters[model.CLOUDWATCH_METRICS_PARAM_NAME] = estimator.enable_cloudwatch_metrics
-    estimator._hyperparameters[model.CONTAINER_LOG_LEVEL_PARAM_NAME] = estimator.container_log_level
-    estimator._hyperparameters[model.JOB_NAME_PARAM_NAME] = estimator._current_job_name
-    estimator._hyperparameters[model.SAGEMAKER_REGION_PARAM_NAME] = estimator.sagemaker_session.boto_region_name
+    estimator._hyperparameters[sagemaker.model.DIR_PARAM_NAME] = code_dir
+    estimator._hyperparameters[sagemaker.model.SCRIPT_PARAM_NAME] = script
+    estimator._hyperparameters[sagemaker.model.CLOUDWATCH_METRICS_PARAM_NAME] = \
+        estimator.enable_cloudwatch_metrics
+    estimator._hyperparameters[sagemaker.model.CONTAINER_LOG_LEVEL_PARAM_NAME] = estimator.container_log_level
+    estimator._hyperparameters[sagemaker.model.JOB_NAME_PARAM_NAME] = estimator._current_job_name
+    estimator._hyperparameters[sagemaker.model.SAGEMAKER_REGION_PARAM_NAME] = \
+        estimator.sagemaker_session.boto_region_name
 
 
 def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None):
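Two things happen in this hunk: `estimator.uploaded_code` is now recorded in both branches so a model created later from this estimator can locate the repacked source, and the hyperparameter constants move to fully qualified `sagemaker.model.*` names because the new helpers below take a `model` parameter that would shadow a module-level `model` import. For reference, a hedged sketch of what the helper injects (keys are the sagemaker.model constants; values are hypothetical placeholders):

# What prepare_framework leaves in estimator._hyperparameters, illustratively:
# {
#     'sagemaker_submit_directory':          's3://my-bucket/train-2018-11-08/source/sourcedir.tar.gz',
#     'sagemaker_program':                   'train.py',
#     'sagemaker_enable_cloudwatch_metrics': False,
#     'sagemaker_container_log_level':       20,                  # logging.INFO
#     'sagemaker_job_name':                  'train-2018-11-08',
#     'sagemaker_region':                    'us-west-2',
# }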
@@ -102,8 +106,8 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
         mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
             Amazon algorithm. For other estimators, batch size should be specified in the estimator.
 
-    Returns (dict):
-        Training config that can be directly used by SageMakerTrainingOperator in Airflow.
+    Returns:
+        dict: Training config that can be directly used by SageMakerTrainingOperator in Airflow.
     """
     default_bucket = estimator.sagemaker_session.default_bucket()
     s3_operations = {}
@@ -181,8 +185,8 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None)
         mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
             Amazon algorithm. For other estimators, batch size should be specified in the estimator.
 
-    Returns (dict):
-        Training config that can be directly used by SageMakerTrainingOperator in Airflow.
+    Returns:
+        dict: Training config that can be directly used by SageMakerTrainingOperator in Airflow.
     """
 
     train_config = training_base_config(estimator, inputs, job_name, mini_batch_size)
@@ -219,8 +223,8 @@ def tuning_config(tuner, inputs, job_name=None):
 
         job_name (str): Specify a tuning job name if needed.
 
-    Returns (dict):
-        Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
+    Returns:
+        dict: Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
     """
     train_config = training_base_config(tuner.estimator, inputs)
     hyperparameters = train_config.pop('HyperParameters', None)
@@ -269,3 +273,126 @@ def tuning_config(tuner, inputs, job_name=None):
     tune_config['S3Operations'] = s3_operations
 
     return tune_config
+
+
+def prepare_framework_container_def(model, instance_type, s3_operations):
+    """Prepare the framework model container information. Specify related S3 operations for Airflow to perform.
+    (Upload `source_dir`)
+
+    Args:
+        model (sagemaker.model.FrameworkModel): The framework model
+        instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
+        s3_operations (dict): The dict to specify S3 operations (upload `source_dir`).
+
+    Returns:
+        dict: The container information of this framework model.
+    """
+    deploy_image = model.image
+    if not deploy_image:
+        region_name = model.sagemaker_session.boto_session.region_name
+        deploy_image = fw_utils.create_image_uri(
+            region_name, model.__framework_name__, instance_type, model.framework_version, model.py_version)
+
+    base_name = utils.base_name_from_image(deploy_image)
+    model.name = model.name or utils.airflow_name_from_base(base_name)
+
+    bucket = model.bucket or model.sagemaker_session._default_bucket
+    script = os.path.basename(model.entry_point)
+    key = '{}/source/sourcedir.tar.gz'.format(model.name)
+
+    if model.source_dir and model.source_dir.lower().startswith('s3://'):
+        model.uploaded_code = fw_utils.UploadedCode(s3_prefix=model.source_dir, script_name=script)
+    else:
+        code_dir = 's3://{}/{}'.format(bucket, key)
+        model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
+        s3_operations['S3Upload'] = [{
+            'Path': model.source_dir or script,
+            'Bucket': bucket,
+            'Key': key,
+            'Tar': True
+        }]
+
+    deploy_env = dict(model.env)
+    deploy_env.update(model._framework_env_vars())
+
+    try:
+        if model.model_server_workers:
+            deploy_env[sagemaker.model.MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(model.model_server_workers)
+    except AttributeError:
+        # This applies to a FrameworkModel which is not a SageMaker Deep Learning Framework Model
+        pass
+
+    return sagemaker.container_def(deploy_image, model.model_data, deploy_env)
+
+
+def model_config(instance_type, model, role=None, image=None):
+    """Export Airflow model config from a SageMaker model
+
+    Args:
+        instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'
+        model (sagemaker.model.FrameworkModel): The SageMaker model to export Airflow config from
+        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
+        image (str): A container image to use for deploying the model
+
+    Returns:
+        dict: Model config that can be directly used by SageMakerModelOperator in Airflow. It can also be part
+            of the config used by SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
+    """
+    s3_operations = {}
+    model.image = image or model.image
+
+    if isinstance(model, sagemaker.model.FrameworkModel):
+        container_def = prepare_framework_container_def(model, instance_type, s3_operations)
+    else:
+        container_def = model.prepare_container_def(instance_type)
+        base_name = utils.base_name_from_image(container_def['Image'])
+        model.name = model.name or utils.airflow_name_from_base(base_name)
+
+    primary_container = session._expand_container_def(container_def)
+
+    config = {
+        'ModelName': model.name,
+        'PrimaryContainer': primary_container,
+        'ExecutionRoleArn': role or model.role
+    }
+
+    if model.vpc_config:
+        config['VpcConfig'] = model.vpc_config
+
+    if s3_operations:
+        config['S3Operations'] = s3_operations
+
+    return config
+
+
+def model_config_from_estimator(instance_type, estimator, role=None, image=None, model_server_workers=None,
+                                vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT):
+    """Export Airflow model config from a SageMaker estimator
+
+    Args:
+        instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'
+        estimator (sagemaker.estimator.EstimatorBase): The SageMaker estimator to export Airflow config from.
+            It has to be an estimator associated with a training job.
+        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
+        image (str): A container image to use for deploying the model
+        model_server_workers (int): The number of worker processes used by the inference server.
+            If None, server will use one worker per vCPU. Only effective when estimator is a
+            SageMaker framework estimator.
+        vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model.
+            Default: use subnets and security groups from this Estimator.
+            * 'Subnets' (list[str]): List of subnet ids.
+            * 'SecurityGroupIds' (list[str]): List of security group ids.
+
+    Returns:
+        dict: Model config that can be directly used by SageMakerModelOperator in Airflow. It can also be part
+            of the config used by SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
+    """
+    if isinstance(estimator, sagemaker.estimator.Estimator):
+        model = estimator.create_model(role=role, image=image, vpc_config_override=vpc_config_override)
+    elif isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
+        model = estimator.create_model(vpc_config_override=vpc_config_override)
+    elif isinstance(estimator, sagemaker.estimator.Framework):
+        model = estimator.create_model(model_server_workers=model_server_workers, role=role,
+                                       vpc_config_override=vpc_config_override)
+
+    return model_config(instance_type, model, role, image)
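To show how the new API composes with the existing ones, a hedged usage sketch (not part of this commit): the image URI, role ARN, S3 paths, and DAG settings below are hypothetical placeholders, and the operators are Airflow's airflow.contrib SageMaker operators referenced in the docstrings above.

import airflow
from airflow import DAG
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
from airflow.contrib.operators.sagemaker_model_operator import SageMakerModelOperator

from sagemaker.estimator import Estimator
from sagemaker.workflow.airflow import training_config, model_config_from_estimator

# Hypothetical training image, role, and bucket.
estimator = Estimator(image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest',
                      role='arn:aws:iam::123456789012:role/SageMakerRole',
                      train_instance_count=1,
                      train_instance_type='ml.m4.xlarge',
                      output_path='s3://my-bucket/output')

# Exporting the training config also assigns the estimator a job name, which
# the model_data fallback in estimator.py relies on.
train_config = training_config(estimator=estimator, inputs='s3://my-bucket/train')

# New in this commit: a CreateModel request body from the same estimator.
m_config = model_config_from_estimator(instance_type='ml.c4.xlarge', estimator=estimator)

default_args = {'owner': 'airflow', 'start_date': airflow.utils.dates.days_ago(1)}
dag = DAG('sagemaker_example', default_args=default_args, schedule_interval='@once')

train_op = SageMakerTrainingOperator(task_id='training', config=train_config,
                                     wait_for_completion=True, dag=dag)
model_op = SageMakerModelOperator(task_id='create_model', config=m_config, dag=dag)
model_op.set_upstream(train_op)

Note that model_config_from_estimator runs before any training job has finished, which is exactly the case the model_data fallback in estimator.py covers.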
