 import os
 
 import sagemaker
-from sagemaker import job, model, utils
+from sagemaker import fw_utils, job, utils, session, vpc_utils
 from sagemaker.amazon import amazon_estimator
 
 
 def prepare_framework(estimator, s3_operations):
-    """Prepare S3 operations (specify where to upload source_dir) and environment variables
+    """Prepare S3 operations (specify where to upload `source_dir`) and environment variables
     related to framework.
 
     Args:
         estimator (sagemaker.estimator.Estimator): The framework estimator to get information from and update.
-        s3_operations (dict): The dict to specify s3 operations (upload source_dir).
+        s3_operations (dict): The dict to specify s3 operations (upload `source_dir`).
     """
     bucket = estimator.code_location if estimator.code_location else estimator.sagemaker_session._default_bucket
     key = '{}/source/sourcedir.tar.gz'.format(estimator._current_job_name)
     script = os.path.basename(estimator.entry_point)
     if estimator.source_dir and estimator.source_dir.lower().startswith('s3://'):
         code_dir = estimator.source_dir
+        estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
     else:
         code_dir = 's3://{}/{}'.format(bucket, key)
+        estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
         s3_operations['S3Upload'] = [{
             'Path': estimator.source_dir or script,
             'Bucket': bucket,
             'Key': key,
             'Tar': True
         }]
-    estimator._hyperparameters[model.DIR_PARAM_NAME] = code_dir
-    estimator._hyperparameters[model.SCRIPT_PARAM_NAME] = script
-    estimator._hyperparameters[model.CLOUDWATCH_METRICS_PARAM_NAME] = estimator.enable_cloudwatch_metrics
-    estimator._hyperparameters[model.CONTAINER_LOG_LEVEL_PARAM_NAME] = estimator.container_log_level
-    estimator._hyperparameters[model.JOB_NAME_PARAM_NAME] = estimator._current_job_name
-    estimator._hyperparameters[model.SAGEMAKER_REGION_PARAM_NAME] = estimator.sagemaker_session.boto_region_name
+    estimator._hyperparameters[sagemaker.model.DIR_PARAM_NAME] = code_dir
+    estimator._hyperparameters[sagemaker.model.SCRIPT_PARAM_NAME] = script
+    estimator._hyperparameters[sagemaker.model.CLOUDWATCH_METRICS_PARAM_NAME] = \
+        estimator.enable_cloudwatch_metrics
+    estimator._hyperparameters[sagemaker.model.CONTAINER_LOG_LEVEL_PARAM_NAME] = estimator.container_log_level
+    estimator._hyperparameters[sagemaker.model.JOB_NAME_PARAM_NAME] = estimator._current_job_name
+    estimator._hyperparameters[sagemaker.model.SAGEMAKER_REGION_PARAM_NAME] = \
+        estimator.sagemaker_session.boto_region_name
 
 
 def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None):
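
For reference, the hyperparameters set above land in the training job's hyperparameter map under the `sagemaker.model` constant names. A minimal sketch of what `estimator._hyperparameters` roughly looks like afterwards; the literal key strings and sample values are assumptions about the SDK's constants and a hypothetical job, not something this diff guarantees:

# Rough illustration only: assumed values of the sagemaker.model constants,
# shown with hypothetical job/bucket names.
expected_hyperparameters = {
    'sagemaker_submit_directory': 's3://my-bucket/my-job/source/sourcedir.tar.gz',  # DIR_PARAM_NAME
    'sagemaker_program': 'train.py',                                                 # SCRIPT_PARAM_NAME
    'sagemaker_enable_cloudwatch_metrics': False,                                    # CLOUDWATCH_METRICS_PARAM_NAME
    'sagemaker_container_log_level': 20,                                             # CONTAINER_LOG_LEVEL_PARAM_NAME
    'sagemaker_job_name': 'my-job',                                                  # JOB_NAME_PARAM_NAME
    'sagemaker_region': 'us-west-2',                                                 # SAGEMAKER_REGION_PARAM_NAME
}
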
@@ -102,8 +106,8 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
         mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
             Amazon algorithm. For other estimators, batch size should be specified in the estimator.
 
-    Returns (dict):
-        Training config that can be directly used by SageMakerTrainingOperator in Airflow.
+    Returns:
+        dict: Training config that can be directly used by SageMakerTrainingOperator in Airflow.
     """
     default_bucket = estimator.sagemaker_session.default_bucket()
     s3_operations = {}
@@ -181,8 +185,8 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
         mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
             Amazon algorithm. For other estimators, batch size should be specified in the estimator.
 
-    Returns (dict):
-        Training config that can be directly used by SageMakerTrainingOperator in Airflow.
+    Returns:
+        dict: Training config that can be directly used by SageMakerTrainingOperator in Airflow.
     """
 
     train_config = training_base_config(estimator, inputs, job_name, mini_batch_size)
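
To show where this config ends up, here is a minimal sketch of wiring `training_config` into an Airflow DAG. It assumes Airflow 1.10.x contrib operators and a generic `Estimator`; the image URI, role ARN, and S3 paths are placeholders:

from datetime import datetime

import sagemaker
from airflow import DAG
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
from sagemaker.workflow import airflow as sm_airflow

dag = DAG('sagemaker_training', start_date=datetime(2018, 11, 1), schedule_interval=None)

# Hypothetical estimator; any SageMaker estimator works here.
estimator = sagemaker.estimator.Estimator(
    image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest',
    role='arn:aws:iam::123456789012:role/SageMakerRole',
    train_instance_count=1,
    train_instance_type='ml.m4.xlarge')

# Build the training config (includes 'S3Operations' when code needs uploading).
train_config = sm_airflow.training_config(estimator=estimator,
                                          inputs='s3://my-bucket/training-data')

train_op = SageMakerTrainingOperator(
    task_id='train',
    config=train_config,
    wait_for_completion=True,
    dag=dag)
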
@@ -219,8 +223,8 @@ def tuning_config(tuner, inputs, job_name=None):
 
         job_name (str): Specify a tuning job name if needed.
 
-    Returns (dict):
-        Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
+    Returns:
+        dict: Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
     """
     train_config = training_base_config(tuner.estimator, inputs)
     hyperparameters = train_config.pop('HyperParameters', None)
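
The tuning config is consumed the same way; a short sketch, again assuming Airflow 1.10.x contrib operators, reusing the `estimator` and `dag` from the training sketch above with a hypothetical objective metric:

from airflow.contrib.operators.sagemaker_tuning_operator import SageMakerTuningOperator
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner
from sagemaker.workflow import airflow as sm_airflow

tuner = HyperparameterTuner(
    estimator=estimator,                                   # from the training sketch
    objective_metric_name='validation:accuracy',           # hypothetical metric
    hyperparameter_ranges={'learning_rate': ContinuousParameter(0.01, 0.2)},
    metric_definitions=[{'Name': 'validation:accuracy', 'Regex': 'accuracy = ([0-9\\.]+)'}],
    max_jobs=4,
    max_parallel_jobs=2)

tune_config = sm_airflow.tuning_config(tuner, inputs='s3://my-bucket/training-data')

tune_op = SageMakerTuningOperator(
    task_id='tune',
    config=tune_config,
    wait_for_completion=True,
    dag=dag)
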
@@ -269,3 +273,126 @@ def tuning_config(tuner, inputs, job_name=None):
     tune_config['S3Operations'] = s3_operations
 
     return tune_config
+
+
+def prepare_framework_container_def(model, instance_type, s3_operations):
+    """Prepare the framework model container information. Specify related S3 operations for Airflow to perform.
+    (Upload `source_dir`)
+
+    Args:
+        model (sagemaker.model.FrameworkModel): The framework model
+        instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
+        s3_operations (dict): The dict to specify S3 operations (upload `source_dir`).
+
+    Returns:
+        dict: The container information of this framework model.
+    """
+    deploy_image = model.image
+    if not deploy_image:
+        region_name = model.sagemaker_session.boto_session.region_name
+        deploy_image = fw_utils.create_image_uri(
+            region_name, model.__framework_name__, instance_type, model.framework_version, model.py_version)
+
+    base_name = utils.base_name_from_image(deploy_image)
+    model.name = model.name or utils.airflow_name_from_base(base_name)
+
+    bucket = model.bucket or model.sagemaker_session._default_bucket
+    script = os.path.basename(model.entry_point)
+    key = '{}/source/sourcedir.tar.gz'.format(model.name)
+
+    if model.source_dir and model.source_dir.lower().startswith('s3://'):
+        model.uploaded_code = fw_utils.UploadedCode(s3_prefix=model.source_dir, script_name=script)
+    else:
+        code_dir = 's3://{}/{}'.format(bucket, key)
+        model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
+        s3_operations['S3Upload'] = [{
+            'Path': model.source_dir or script,
+            'Bucket': bucket,
+            'Key': key,
+            'Tar': True
+        }]
+
+    deploy_env = dict(model.env)
+    deploy_env.update(model._framework_env_vars())
+
+    try:
+        if model.model_server_workers:
+            deploy_env[sagemaker.model.MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(model.model_server_workers)
+    except AttributeError:
+        # This applies to a FrameworkModel that is not a SageMaker Deep Learning Framework Model.
+        pass
+
+    return sagemaker.container_def(deploy_image, model.model_data, deploy_env)
+
+
+def model_config(instance_type, model, role=None, image=None):
+    """Export Airflow model config from a SageMaker model
+
+    Args:
+        instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'
+        model (sagemaker.model.FrameworkModel): The SageMaker model to export Airflow config from
+        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
+        image (str): A container image to use for deploying the model
+
+    Returns:
+        dict: Model config that can be directly used by SageMakerModelOperator in Airflow. It can also be part
+            of the config used by SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
+    """
+    s3_operations = {}
+    model.image = image or model.image
+
+    if isinstance(model, sagemaker.model.FrameworkModel):
+        container_def = prepare_framework_container_def(model, instance_type, s3_operations)
+    else:
+        container_def = model.prepare_container_def(instance_type)
+        base_name = utils.base_name_from_image(container_def['Image'])
+        model.name = model.name or utils.airflow_name_from_base(base_name)
+
+    primary_container = session._expand_container_def(container_def)
+
+    config = {
+        'ModelName': model.name,
+        'PrimaryContainer': primary_container,
+        'ExecutionRoleArn': role or model.role
+    }
+
+    if model.vpc_config:
+        config['VpcConfig'] = model.vpc_config
+
+    if s3_operations:
+        config['S3Operations'] = s3_operations
+
+    return config
+
+
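
A sketch of how the new `model_config` could be used from a DAG, assuming Airflow 1.10.x's SageMakerModelOperator and a TensorFlow serving model; the model artifact path, role ARN, and entry point are placeholders:

from airflow.contrib.operators.sagemaker_model_operator import SageMakerModelOperator
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.workflow import airflow as sm_airflow

# Hypothetical framework model; model_data, role, and entry_point are placeholders.
model = TensorFlowModel(
    model_data='s3://my-bucket/model/model.tar.gz',
    role='arn:aws:iam::123456789012:role/SageMakerRole',
    entry_point='inference.py')

# FrameworkModel branch: the source/entry point upload is recorded under 'S3Operations'.
model_cfg = sm_airflow.model_config(instance_type='ml.c4.xlarge', model=model)

model_op = SageMakerModelOperator(
    task_id='create_model',
    config=model_cfg,
    dag=dag)                     # `dag` as in the earlier sketches
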
+def model_config_from_estimator(instance_type, estimator, role=None, image=None, model_server_workers=None,
+                                vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT):
+    """Export Airflow model config from a SageMaker estimator
+
+    Args:
+        instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'
+        estimator (sagemaker.estimator.EstimatorBase): The SageMaker estimator to export Airflow config from.
+            It has to be an estimator associated with a training job.
+        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
+        image (str): A container image to use for deploying the model
+        model_server_workers (int): The number of worker processes used by the inference server.
+            If None, server will use one worker per vCPU. Only effective when estimator is a
+            SageMaker framework.
+        vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model.
+            Default: use subnets and security groups from this Estimator.
+            * 'Subnets' (list[str]): List of subnet ids.
+            * 'SecurityGroupIds' (list[str]): List of security group ids.
+
+    Returns:
+        dict: Model config that can be directly used by SageMakerModelOperator in Airflow. It can also be part
+            of the config used by SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
+    """
+    if isinstance(estimator, sagemaker.estimator.Estimator):
+        model = estimator.create_model(role=role, image=image, vpc_config_override=vpc_config_override)
+    elif isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
+        model = estimator.create_model(vpc_config_override=vpc_config_override)
+    elif isinstance(estimator, sagemaker.estimator.Framework):
+        model = estimator.create_model(model_server_workers=model_server_workers, role=role,
+                                       vpc_config_override=vpc_config_override)
+
+    return model_config(instance_type, model, role, image)
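
Alternatively, the model config can be derived straight from an estimator that has already run a training job; a minimal sketch (the role ARN is a placeholder, and `estimator` is assumed to be fitted, for example by a preceding Airflow training task):

from sagemaker.workflow import airflow as sm_airflow

model_cfg = sm_airflow.model_config_from_estimator(
    instance_type='ml.c4.xlarge',
    estimator=estimator,          # must already be associated with a training job
    role='arn:aws:iam::123456789012:role/SageMakerRole',
    model_server_workers=2)       # only used for framework estimators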