|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""This module contains functions for obtaining JumpStart artifacts.""" |
| 14 | +from __future__ import absolute_import |
| 15 | +from typing import Optional |
| 16 | +from sagemaker import image_uris |
| 17 | +from sagemaker.jumpstart.constants import ( |
| 18 | + JUMPSTART_DEFAULT_REGION_NAME, |
| 19 | + INFERENCE, |
| 20 | + TRAINING, |
| 21 | + SUPPORTED_JUMPSTART_SCOPES, |
| 22 | +) |
| 23 | +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket |
| 24 | +from sagemaker.jumpstart import accessors as jumpstart_accessors |
| 25 | + |
| 26 | + |
| 27 | +def _retrieve_image_uri( |
| 28 | + model_id: str, |
| 29 | + model_version: str, |
| 30 | + framework: Optional[str], |
| 31 | + region: Optional[str], |
| 32 | + version: Optional[str], |
| 33 | + py_version: Optional[str], |
| 34 | + instance_type: Optional[str], |
| 35 | + accelerator_type: Optional[str], |
| 36 | + image_scope: Optional[str], |
| 37 | + container_version: Optional[str], |
| 38 | + distribution: Optional[str], |
| 39 | + base_framework_version: Optional[str], |
| 40 | + training_compiler_config: Optional[str], |
| 41 | +): |
| 42 | + """Retrieves the container image URI for JumpStart models. |
| 43 | +
|
| 44 | + Only `model_id` and `model_version` are required to be non-None; |
| 45 | + the rest of the fields are auto-populated. |
| 46 | +
|
| 47 | +
|
| 48 | + Args: |
| 49 | + model_id (str): JumpStart model id for which to retrieve image URI. |
| 50 | + model_version (str): JumpStart model version for which to retrieve image URI. |
| 51 | + framework (str): The name of the framework or algorithm. |
| 52 | + region (str): The AWS region. |
| 53 | + version (str): The framework or algorithm version. This is required if there is |
| 54 | + more than one supported version for the given framework or algorithm. |
| 55 | + py_version (str): The Python version. This is required if there is |
| 56 | + more than one supported Python version for the given framework version. |
| 57 | + instance_type (str): The SageMaker instance type. For supported types, see |
| 58 | + https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if |
| 59 | + there are different images for different processor types. |
| 60 | + accelerator_type (str): Elastic Inference accelerator type. For more, see |
| 61 | + https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html. |
| 62 | + image_scope (str): The image type, i.e. what it is used for. |
| 63 | + Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, |
| 64 | + ``image_scope`` is ignored. |
| 65 | + container_version (str): the version of docker image. |
| 66 | + Ideally the value of parameter should be created inside the framework. |
| 67 | + For custom use, see the list of supported container versions: |
| 68 | + https://github.com/aws/deep-learning-containers/blob/master/available_images.md |
| 69 | + (default: None). |
| 70 | + distribution (dict): A dictionary with information on how to run distributed training |
| 71 | + training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): |
| 72 | + A configuration class for the SageMaker Training Compiler |
| 73 | + (default: None). |
| 74 | +
|
| 75 | + Returns: |
| 76 | + str: the ECR URI for the corresponding SageMaker Docker image. |
| 77 | +
|
| 78 | + Raises: |
| 79 | + ValueError: If the combination of arguments specified is not supported. |
| 80 | + """ |
| 81 | + if region is None: |
| 82 | + region = JUMPSTART_DEFAULT_REGION_NAME |
| 83 | + |
| 84 | + assert region is not None |
| 85 | + |
| 86 | + if image_scope is None: |
| 87 | + raise ValueError( |
| 88 | + "Must specify `image_scope` argument to retrieve image uri for JumpStart models." |
| 89 | + ) |
| 90 | + if image_scope not in SUPPORTED_JUMPSTART_SCOPES: |
| 91 | + raise ValueError("JumpStart models only support inference and training.") |
| 92 | + |
| 93 | + model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs( |
| 94 | + region, model_id, model_version |
| 95 | + ) |
| 96 | + |
| 97 | + if image_scope == INFERENCE: |
| 98 | + ecr_specs = model_specs.hosting_ecr_specs |
| 99 | + elif image_scope == TRAINING: |
| 100 | + if not model_specs.training_supported: |
| 101 | + raise ValueError(f"JumpStart model id '{model_id}' does not support training.") |
| 102 | + assert model_specs.training_ecr_specs is not None |
| 103 | + ecr_specs = model_specs.training_ecr_specs |
| 104 | + |
| 105 | + if framework is not None and framework != ecr_specs.framework: |
| 106 | + raise ValueError(f"Bad value for container framework for JumpStart model: '{framework}'.") |
| 107 | + |
| 108 | + if version is not None and version != ecr_specs.framework_version: |
| 109 | + raise ValueError( |
| 110 | + f"Bad value for container framework version for JumpStart model: '{version}'." |
| 111 | + ) |
| 112 | + |
| 113 | + if py_version is not None and py_version != ecr_specs.py_version: |
| 114 | + raise ValueError( |
| 115 | + f"Bad value for container python version for JumpStart model: '{py_version}'." |
| 116 | + ) |
| 117 | + |
| 118 | + if framework == "huggingface": |
| 119 | + base_framework_version = ecr_specs.framework_version |
| 120 | + |
| 121 | + return image_uris.retrieve( |
| 122 | + framework=ecr_specs.framework, |
| 123 | + region=region, |
| 124 | + version=ecr_specs.framework_version, |
| 125 | + py_version=ecr_specs.py_version, |
| 126 | + instance_type=instance_type, |
| 127 | + accelerator_type=accelerator_type, |
| 128 | + image_scope=image_scope, |
| 129 | + container_version=container_version, |
| 130 | + distribution=distribution, |
| 131 | + base_framework_version=base_framework_version, |
| 132 | + training_compiler_config=training_compiler_config, |
| 133 | + ) |
| 134 | + |
| 135 | + |
| 136 | +def _retrieve_model_uri( |
| 137 | + model_id: str, |
| 138 | + model_version: str, |
| 139 | + model_scope: Optional[str], |
| 140 | + region: Optional[str], |
| 141 | +): |
| 142 | + """Retrieves the model artifact S3 URI for the model matching the given arguments. |
| 143 | +
|
| 144 | + Args: |
| 145 | + model_id (str): JumpStart model id for which to retrieve model S3 URI. |
| 146 | + model_version (str): JumpStart model version for which to retrieve model S3 URI. |
| 147 | + model_scope (str): The model type, i.e. what it is used for. |
| 148 | + Valid values: "training" and "inference". |
| 149 | + region (str): Region for which to retrieve model S3 URI. |
| 150 | + Returns: |
| 151 | + str: the model artifact S3 URI for the corresponding model. |
| 152 | +
|
| 153 | + Raises: |
| 154 | + ValueError: If the combination of arguments specified is not supported. |
| 155 | + """ |
| 156 | + if region is None: |
| 157 | + region = JUMPSTART_DEFAULT_REGION_NAME |
| 158 | + |
| 159 | + assert region is not None |
| 160 | + |
| 161 | + if model_scope is None: |
| 162 | + raise ValueError( |
| 163 | + "Must specify `model_scope` argument to retrieve model " |
| 164 | + "artifact uri for JumpStart models." |
| 165 | + ) |
| 166 | + |
| 167 | + if model_scope not in SUPPORTED_JUMPSTART_SCOPES: |
| 168 | + raise ValueError("JumpStart models only support inference and training.") |
| 169 | + |
| 170 | + model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs( |
| 171 | + region, model_id, model_version |
| 172 | + ) |
| 173 | + if model_scope == INFERENCE: |
| 174 | + model_artifact_key = model_specs.hosting_artifact_key |
| 175 | + elif model_scope == TRAINING: |
| 176 | + if not model_specs.training_supported: |
| 177 | + raise ValueError(f"JumpStart model id '{model_id}' does not support training.") |
| 178 | + assert model_specs.training_artifact_key is not None |
| 179 | + model_artifact_key = model_specs.training_artifact_key |
| 180 | + |
| 181 | + bucket = get_jumpstart_content_bucket(region) |
| 182 | + |
| 183 | + model_s3_uri = f"s3://{bucket}/{model_artifact_key}" |
| 184 | + |
| 185 | + return model_s3_uri |
| 186 | + |
| 187 | + |
| 188 | +def _retrieve_script_uri( |
| 189 | + model_id: str, |
| 190 | + model_version: str, |
| 191 | + script_scope: Optional[str], |
| 192 | + region: Optional[str], |
| 193 | +): |
| 194 | + """Retrieves the model script s3 URI for the model matching the given arguments. |
| 195 | +
|
| 196 | + Args: |
| 197 | + model_id (str): JumpStart model id for which to retrieve model script S3 URI. |
| 198 | + model_version (str): JumpStart model version for which to retrieve model script S3 URI. |
| 199 | + script_scope (str): The script type, i.e. what it is used for. |
| 200 | + Valid values: "training" and "inference". |
| 201 | + region (str): Region for which to retrieve model script S3 URI. |
| 202 | + Returns: |
| 203 | + str: the model script URI for the corresponding model. |
| 204 | +
|
| 205 | + Raises: |
| 206 | + ValueError: If the combination of arguments specified is not supported. |
| 207 | + """ |
| 208 | + if region is None: |
| 209 | + region = JUMPSTART_DEFAULT_REGION_NAME |
| 210 | + |
| 211 | + assert region is not None |
| 212 | + |
| 213 | + if script_scope is None: |
| 214 | + raise ValueError( |
| 215 | + "Must specify `script_scope` argument to retrieve model script uri for " |
| 216 | + "JumpStart models." |
| 217 | + ) |
| 218 | + |
| 219 | + if script_scope not in SUPPORTED_JUMPSTART_SCOPES: |
| 220 | + raise ValueError("JumpStart models only support inference and training.") |
| 221 | + |
| 222 | + model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs( |
| 223 | + region, model_id, model_version |
| 224 | + ) |
| 225 | + if script_scope == INFERENCE: |
| 226 | + model_script_key = model_specs.hosting_script_key |
| 227 | + elif script_scope == TRAINING: |
| 228 | + if not model_specs.training_supported: |
| 229 | + raise ValueError(f"JumpStart model id '{model_id}' does not support training.") |
| 230 | + assert model_specs.training_script_key is not None |
| 231 | + model_script_key = model_specs.training_script_key |
| 232 | + |
| 233 | + bucket = get_jumpstart_content_bucket(region) |
| 234 | + |
| 235 | + script_s3_uri = f"s3://{bucket}/{model_script_key}" |
| 236 | + |
| 237 | + return script_s3_uri |
0 commit comments