-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add APIs to export transform and deploy config #497
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
43a2e53
7daf88d
18eec18
f022323
ab57396
ca0357c
901f0bf
6f90472
30f3030
0e4e314
d1a005e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -295,6 +295,7 @@ def model_data(self): | |
logging.warning('No finished training job found associated with this estimator. Please make sure' | ||
'this estimator is only used for building workflow config') | ||
model_uri = os.path.join(self.output_path, self._current_job_name, 'output', 'model.tar.gz') | ||
|
||
return model_uri | ||
|
||
@abstractmethod | ||
|
@@ -390,9 +391,13 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit | |
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML | ||
compute instance (default: None). | ||
""" | ||
self._ensure_latest_training_job() | ||
if self.latest_training_job is not None: | ||
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role) | ||
else: | ||
logging.warning('No finished training job found associated with this estimator. Please make sure' | ||
'this estimator is only used for building workflow config') | ||
model_name = self._current_job_name | ||
|
||
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role) | ||
tags = tags or self.tags | ||
|
||
return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, | ||
|
@@ -879,19 +884,23 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit | |
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML | ||
compute instance (default: None). | ||
""" | ||
self._ensure_latest_training_job() | ||
role = role or self.role | ||
|
||
model = self.create_model(role=role, model_server_workers=model_server_workers) | ||
|
||
container_def = model.prepare_container_def(instance_type) | ||
model_name = model.name or name_from_image(container_def['Image']) | ||
vpc_config = model.vpc_config | ||
self.sagemaker_session.create_model(model_name, role, container_def, vpc_config) | ||
|
||
transform_env = model.env.copy() | ||
if env is not None: | ||
transform_env.update(env) | ||
if self.latest_training_job is not None: | ||
model = self.create_model(role=role, model_server_workers=model_server_workers) | ||
|
||
container_def = model.prepare_container_def(instance_type) | ||
model_name = model.name or name_from_image(container_def['Image']) | ||
vpc_config = model.vpc_config | ||
self.sagemaker_session.create_model(model_name, role, container_def, vpc_config) | ||
transform_env = model.env.copy() | ||
if env is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Answered above. |
||
transform_env.update(env) | ||
else: | ||
logging.warning('No finished training job found associated with this estimator. Please make sure' | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are there unit tests for this change? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, the transformer-from-estimator tests will cover this. |
||
'this estimator is only used for building workflow config') | ||
model_name = self._current_job_name | ||
transform_env = env or {} | ||
|
||
tags = tags or self.tags | ||
return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -376,8 +376,8 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None, | |
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model | ||
image (str): A container image to use for deploying the model | ||
model_server_workers (int): The number of worker processes used by the inference server. | ||
If None, server will use one worker per vCPU. Only effective when estimator is | ||
SageMaker framework. | ||
If None, server will use one worker per vCPU. Only effective when estimator is a | ||
SageMaker framework. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. s/is/is a There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Updated. |
||
vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model. | ||
Default: use subnets and security groups from this Estimator. | ||
* 'Subnets' (list[str]): List of subnet ids. | ||
|
@@ -394,5 +394,223 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None, | |
elif isinstance(estimator, sagemaker.estimator.Framework): | ||
model = estimator.create_model(model_server_workers=model_server_workers, role=role, | ||
vpc_config_override=vpc_config_override) | ||
else: | ||
raise TypeError('Estimator must be one of sagemaker.estimator.Estimator, sagemaker.estimator.Framework' | ||
' or sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.') | ||
|
||
return model_config(instance_type, model, role, image) | ||
|
||
|
||
def transform_config(transformer, data, data_type='S3Prefix', content_type=None, compression_type=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. are these kinds of methods potentially reusable for There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, I do want it to be reused in session. Right now I cannot, because I need to bypass all the validations, describe calls, etc. in order to use Jinja templating. If session constructed the config first and applied all manipulations afterwards, that would be awesome, and these two parts could be coupled. (That would be really good: for now I am worried that if the session part gets updated, for example if new entries are introduced in the boto config, the Airflow part will not catch it.) |
||
split_type=None, job_name=None): | ||
"""Export Airflow transform config from a SageMaker transformer | ||
|
||
Args: | ||
transformer (sagemaker.transformer.Transformer): The SageMaker transformer to export Airflow | ||
config from. | ||
data (str): Input data location in S3. | ||
data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values: | ||
|
||
* 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix will be used as | ||
inputs for the transform job. | ||
* 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object to use as | ||
an input for the transform job. | ||
|
||
content_type (str): MIME type of the input data (default: None). | ||
compression_type (str): Compression type of the input data, if compressed (default: None). | ||
Valid values: 'Gzip', None. | ||
split_type (str): The record delimiter for the input object (default: 'None'). | ||
Valid values: 'None', 'Line', and 'RecordIO'. | ||
job_name (str): job name (default: None). If not specified, one will be generated. | ||
|
||
Returns: | ||
dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow. | ||
""" | ||
if job_name is not None: | ||
transformer._current_job_name = job_name | ||
else: | ||
base_name = transformer.base_transform_job_name | ||
transformer._current_job_name = utils.airflow_name_from_base(base_name) \ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. transformer._current_job_name = utils.airflow_name_from_base(base_name) if base_name else transformer.model_name There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Answered above. |
||
if base_name is not None else transformer.model_name | ||
|
||
if transformer.output_path is None: | ||
transformer.output_path = 's3://{}/{}'.format( | ||
transformer.sagemaker_session.default_bucket(), transformer._current_job_name) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. is this logic that's also in There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, it's similar, except the naming method is different: one is airflow_name_from_base and the other is name_from_base. So if I wrap them in a method, that method needs to take a function as an input argument, which I don't want to do for such a small amount of code for now. If we apply the different naming methods first, there's just one line left, which I guess doesn't need wrapping. I am still aiming for what I mentioned in other comments: we could restructure the code in session/estimator/transformer/etc. so that a lot of it is wrapped in methods that Airflow can be coupled with. |
||
|
||
job_config = sagemaker.transformer._TransformJob._load_config( | ||
data, data_type, content_type, compression_type, split_type, transformer) | ||
|
||
config = { | ||
'TransformJobName': transformer._current_job_name, | ||
'ModelName': transformer.model_name, | ||
'TransformInput': job_config['input_config'], | ||
'TransformOutput': job_config['output_config'], | ||
'TransformResources': job_config['resource_config'], | ||
} | ||
|
||
if transformer.strategy is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to my other comments, do all these as if transformer.strategy: There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Answered above. |
||
config['BatchStrategy'] = transformer.strategy | ||
|
||
if transformer.max_concurrent_transforms is not None: | ||
config['MaxConcurrentTransforms'] = transformer.max_concurrent_transforms | ||
|
||
if transformer.max_payload is not None: | ||
config['MaxPayloadInMB'] = transformer.max_payload | ||
|
||
if transformer.env is not None: | ||
config['Environment'] = transformer.env | ||
|
||
if transformer.tags is not None: | ||
config['Tags'] = transformer.tags | ||
|
||
return config | ||
|
||
|
||
def transform_config_from_estimator(estimator, instance_count, instance_type, data, data_type='S3Prefix', | ||
content_type=None, compression_type=None, split_type=None, | ||
job_name=None, strategy=None, assemble_with=None, output_path=None, | ||
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, | ||
max_payload=None, tags=None, role=None, volume_kms_key=None, | ||
model_server_workers=None, image=None, vpc_config_override=None): | ||
"""Export Airflow transform config from a SageMaker estimator | ||
|
||
Args: | ||
estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow config from. | ||
It has to be an estimator associated with a training job. | ||
instance_count (int): Number of EC2 instances to use. | ||
instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'. | ||
data (str): Input data location in S3. | ||
data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values: | ||
|
||
* 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix will be used as | ||
inputs for the transform job. | ||
* 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object to use as | ||
an input for the transform job. | ||
|
||
content_type (str): MIME type of the input data (default: None). | ||
compression_type (str): Compression type of the input data, if compressed (default: None). | ||
Valid values: 'Gzip', None. | ||
split_type (str): The record delimiter for the input object (default: 'None'). | ||
Valid values: 'None', 'Line', and 'RecordIO'. | ||
job_name (str): job name (default: None). If not specified, one will be generated. | ||
strategy (str): The strategy used to decide how to batch records in a single request (default: None). | ||
Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'. | ||
assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'. | ||
output_path (str): S3 location for saving the transform result. If not specified, results are stored to | ||
a default bucket. | ||
output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None). | ||
accept (str): The content type accepted by the endpoint deployed during the transform job. | ||
env (dict): Environment variables to be set for use during the transform job (default: None). | ||
max_concurrent_transforms (int): The maximum number of HTTP requests to be made to | ||
each individual transform container at one time. | ||
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. | ||
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for | ||
the training job are used for the transform job. | ||
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during | ||
transform jobs. If not specified, the role from the Estimator will be used. | ||
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML | ||
compute instance (default: None). | ||
model_server_workers (int): Optional. The number of worker processes used by the inference server. | ||
If None, server will use one worker per vCPU. | ||
image (str): A container image to use for deploying the model | ||
vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model. | ||
Default: use subnets and security groups from this Estimator. | ||
* 'Subnets' (list[str]): List of subnet ids. | ||
* 'SecurityGroupIds' (list[str]): List of security group ids. | ||
|
||
Returns: | ||
dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow. | ||
""" | ||
model_base_config = model_config_from_estimator(instance_type=instance_type, estimator=estimator, role=role, | ||
image=image, model_server_workers=model_server_workers, | ||
vpc_config_override=vpc_config_override) | ||
|
||
if isinstance(estimator, sagemaker.estimator.Framework): | ||
transformer = estimator.transformer(instance_count, instance_type, strategy, assemble_with, output_path, | ||
output_kms_key, accept, env, max_concurrent_transforms, | ||
max_payload, tags, role, model_server_workers, volume_kms_key) | ||
else: | ||
transformer = estimator.transformer(instance_count, instance_type, strategy, assemble_with, output_path, | ||
output_kms_key, accept, env, max_concurrent_transforms, | ||
max_payload, tags, role, volume_kms_key) | ||
|
||
transform_base_config = transform_config(transformer, data, data_type, content_type, compression_type, | ||
split_type, job_name) | ||
|
||
config = { | ||
'Model': model_base_config, | ||
'Transform': transform_base_config | ||
} | ||
|
||
return config | ||
|
||
|
||
def deploy_config(model, initial_instance_count, instance_type, endpoint_name=None, tags=None): | ||
"""Export Airflow deploy config from a SageMaker model | ||
|
||
Args: | ||
model (sagemaker.model.Model): The SageMaker model to export the Airflow config from. | ||
instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'. | ||
initial_instance_count (int): The initial number of instances to run in the | ||
``Endpoint`` created from this ``Model``. | ||
endpoint_name (str): The name of the endpoint to create (default: None). | ||
If not specified, a unique endpoint name will be created. | ||
tags (list[dict]): List of tags for labeling a training job. For more, see | ||
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. | ||
|
||
Returns: | ||
dict: Deploy config that can be directly used by SageMakerEndpointOperator in Airflow. | ||
|
||
""" | ||
model_base_config = model_config(instance_type, model) | ||
|
||
production_variant = sagemaker.production_variant(model.name, instance_type, initial_instance_count) | ||
name = model.name | ||
config_options = {'EndpointConfigName': name, 'ProductionVariants': [production_variant]} | ||
if tags is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. if tags: There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Answered above. |
||
config_options['Tags'] = tags | ||
|
||
endpoint_name = endpoint_name or name | ||
endpoint_base_config = { | ||
'EndpointName': endpoint_name, | ||
'EndpointConfigName': name | ||
} | ||
|
||
config = { | ||
'Model': model_base_config, | ||
'EndpointConfig': config_options, | ||
'Endpoint': endpoint_base_config | ||
} | ||
|
||
# if there is s3 operations needed for model, move it to root level of config | ||
s3_operations = model_base_config.pop('S3Operations', None) | ||
if s3_operations is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ^^ There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Answered above. |
||
config['S3Operations'] = s3_operations | ||
|
||
return config | ||
|
||
|
||
def deploy_config_from_estimator(estimator, initial_instance_count, instance_type, endpoint_name=None, | ||
tags=None, **kwargs): | ||
"""Export Airflow deploy config from a SageMaker estimator | ||
|
||
Args: | ||
estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow config from. | ||
It has to be an estimator associated with a training job. | ||
initial_instance_count (int): Minimum number of EC2 instances to deploy to an endpoint for prediction. | ||
instance_type (str): Type of EC2 instance to deploy to an endpoint for prediction, | ||
for example, 'ml.c4.xlarge'. | ||
endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of | ||
the training job is used. | ||
tags (list[dict]): List of tags for labeling a training job. For more, see | ||
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. | ||
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize | ||
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. | ||
For more, see the implementation docs. | ||
|
||
Returns: | ||
dict: Deploy config that can be directly used by SageMakerEndpointOperator in Airflow. | ||
""" | ||
model = estimator.create_model(**kwargs) | ||
config = deploy_config(model, initial_instance_count, instance_type, endpoint_name, tags) | ||
return config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can just do
if self.latest_training_job:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, that looks nice. But the problem is that we won't get into
if var
when var is something like [] or {}, not just None. Hence I want to say explicitly that the only thing I want to exclude is None. In particular, sometimes I have var = {} and need to call var.update() afterwards; if there's a truthiness check in between, I have to deal with the {} case separately.
Yep, we probably should do
if var
most of the time and
if var is not None
only when needed. But using the second one for now is consistent with what we have in the existing classes (like estimator, model, etc.), so I prefer not to change it for now. Later, when we introduce a style guide and can enforce it everywhere, we could probably make the change everywhere.