|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""This module stores notebook utils related to SageMaker JumpStart.""" |
| 14 | +from __future__ import absolute_import |
| 15 | + |
| 16 | +from functools import cmp_to_key |
| 17 | +from typing import Any, Collection, List, Tuple, Union, Set, Dict |
| 18 | +from packaging.version import Version |
| 19 | +from sagemaker.jumpstart import accessors |
| 20 | +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME |
| 21 | +from sagemaker.jumpstart.enums import JumpStartScriptScope |
| 22 | +from sagemaker.jumpstart.utils import get_sagemaker_version |
| 23 | + |
| 24 | + |
def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
    """Parse the input model id, return a tuple framework, task, rest-of-id.

    Args:
        model_id (str): The model id for which to extract the framework/task/model.

    Returns:
        Tuple[str, str, str]: The framework, the task, and the remainder of the
            model id (which may itself contain dashes).

    Raises:
        ValueError: If the model id has fewer than three dash-separated parts.
    """
    # Split on the first two dashes only, so the trailing model name keeps its own dashes.
    _id_parts = model_id.split("-", 2)

    if len(_id_parts) < 3:
        raise ValueError(f"incorrect model id: {model_id}.")

    framework, task, name = _id_parts
    return framework, task, name
| 41 | + |
| 42 | + |
| 43 | +def _compare_model_version_tuples( # pylint: disable=too-many-return-statements |
| 44 | + model_version_1: Tuple[str, str] = None, model_version_2: Tuple[str, str] = None |
| 45 | +) -> int: |
| 46 | + """Performs comparison of sdk specs paths, in order to sort them in manifest. |
| 47 | +
|
| 48 | + Args: |
| 49 | + model_version_1 (Tuple[str, str]): The first model id and version tuple to compare. |
| 50 | + model_version_2 (Tuple[str, str]): The second model id and version tuple to compare. |
| 51 | + """ |
| 52 | + if model_version_1 is None or model_version_2 is None: |
| 53 | + if model_version_2 is not None: |
| 54 | + return -1 |
| 55 | + if model_version_1 is not None: |
| 56 | + return 1 |
| 57 | + return 0 |
| 58 | + |
| 59 | + version_1 = model_version_1[1] |
| 60 | + model_id_1 = model_version_1[0] |
| 61 | + |
| 62 | + version_2 = model_version_2[1] |
| 63 | + model_id_2 = model_version_2[0] |
| 64 | + |
| 65 | + if model_id_1 < model_id_2: |
| 66 | + return -1 |
| 67 | + |
| 68 | + if model_id_2 < model_id_1: |
| 69 | + return 1 |
| 70 | + |
| 71 | + if Version(version_1) < Version(version_2): |
| 72 | + return 1 |
| 73 | + |
| 74 | + if Version(version_2) < Version(version_1): |
| 75 | + return -1 |
| 76 | + |
| 77 | + return 0 |
| 78 | + |
| 79 | + |
def list_jumpstart_frameworks(**kwargs: Any) -> List[str]:
    """List frameworks actively in use by JumpStart.

    Args:
        kwargs (Dict[str, Any]): kwarg arguments to supply to
            ``list_jumpstart_models``.

    Returns:
        List[str]: The distinct frameworks of the listed models, sorted.
    """
    models_list = list_jumpstart_models(**kwargs)
    # The framework is the first component of each model id.
    frameworks = {extract_framework_task_model(model_id)[0] for model_id, _ in models_list}
    return sorted(frameworks)
| 95 | + |
| 96 | + |
def list_jumpstart_tasks(**kwargs: Any) -> List[str]:
    """List tasks actively in use by JumpStart.

    Args:
        kwargs (Dict[str, Any]): kwarg arguments to supply to
            ``list_jumpstart_models``.

    Returns:
        List[str]: The distinct tasks of the listed models, sorted.
    """
    models_list = list_jumpstart_models(**kwargs)
    # The task is the second component of each model id.
    tasks = {extract_framework_task_model(model_id)[1] for model_id, _ in models_list}
    return sorted(tasks)
| 112 | + |
| 113 | + |
def list_jumpstart_scripts(**kwargs: Any) -> List[str]:
    """List scripts actively in use by JumpStart.

    Note: Using this function will result in slow execution speed, as it requires
    making many http calls and parsing metadata files. To-Do: store script
    information for all models in a single file.

    Check ``sagemaker.jumpstart.enums.JumpStartScriptScope`` for possible types
    of JumpStart scripts.

    Args:
        kwargs (Dict[str, Any]): kwarg arguments to supply to
            ``list_jumpstart_models``.

    Returns:
        List[str]: The distinct script scopes used by the listed models, sorted.
    """
    models_list = list_jumpstart_models(**kwargs)

    # Loop invariants: resolve the region and the full set of scopes once.
    region = kwargs.get("region", JUMPSTART_DEFAULT_REGION_NAME)
    all_scripts = {e.value for e in JumpStartScriptScope}

    scripts = set()
    for model_id, version in models_list:
        # Every listed model is counted as supporting inference.
        scripts.add(JumpStartScriptScope.INFERENCE.value)
        model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
            region=region,
            model_id=model_id,
            version=version,
        )
        if model_specs.training_supported:
            scripts.add(JumpStartScriptScope.TRAINING.value)

        # Stop fetching specs as soon as every known scope has been seen.
        if scripts == all_scripts:
            break
    return sorted(scripts)
| 145 | + |
| 146 | + |
def _to_filter_set(values: Union[str, Collection[str], None]) -> Union[Set[str], None]:
    """Normalize an allowlist/denylist argument to a set, leaving ``None`` untouched."""
    if values is None:
        return None
    if isinstance(values, str):
        # A bare string is a single entry, not a collection of characters.
        return {values}
    return set(values)


def list_jumpstart_models(
    script_allowlist: Union[str, Collection[str], None] = None,
    task_allowlist: Union[str, Collection[str], None] = None,
    framework_allowlist: Union[str, Collection[str], None] = None,
    model_id_allowlist: Union[str, Collection[str], None] = None,
    script_denylist: Union[str, Collection[str], None] = None,
    task_denylist: Union[str, Collection[str], None] = None,
    framework_denylist: Union[str, Collection[str], None] = None,
    model_id_denylist: Union[str, Collection[str], None] = None,
    region: str = JUMPSTART_DEFAULT_REGION_NAME,
    accept_unsupported_models: bool = False,
    accept_old_models: bool = False,
    accept_vulnerable_models: bool = True,
    accept_deprecated_models: bool = True,
) -> List[Tuple[str, str]]:
    """List models in JumpStart, and optionally apply filters to result.

    Args:
        script_allowlist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing scripts. All models returned by this function
            must use a script which is specified in this argument. Note: Using this
            filter will result in slow execution speed, as it requires making more
            http calls and parsing many metadata files. To-Do: store script
            information for all models in a single file.
            (Default: None).
        task_allowlist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing tasks. All models returned by this function
            must use a task which is specified in this argument.
            (Default: None).
        framework_allowlist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing frameworks. All models returned by this function
            must use a framework which is specified in this argument.
            (Default: None).
        model_id_allowlist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing model ids. All models returned by this function
            must use a model id which is specified in this argument.
            (Default: None).
        script_denylist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing scripts. All models returned by this function
            must not use a script which is specified in this argument. Note: Using
            this filter will result in slow execution speed, as it requires making
            more http calls and parsing many metadata files. To-Do: store script
            information for all models in a single file.
            (Default: None).
        task_denylist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing tasks. All models returned by this function
            must not use a task which is specified in this argument.
            (Default: None).
        framework_denylist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing frameworks. All models returned by this function
            must not use a framework which is specified in this argument.
            (Default: None).
        model_id_denylist (Union[str, Collection[str]]): Optional. String or
            ``Collection`` storing model ids. All models returned by this function
            must not use a model id which is specified in this argument.
            (Default: None).
        region (str): Optional. Region to use when fetching JumpStart metadata.
            (Default: ``JUMPSTART_DEFAULT_REGION_NAME``).
        accept_unsupported_models (bool): Optional. Set to True to accept models that
            are not supported with the current SageMaker library version
            (Default: False).
        accept_old_models (bool): Optional. Set to True to accept model and version
            tuples for which a model with the same name and a newer version exists.
            (Default: False).
        accept_vulnerable_models (bool): Optional. Set to False to reject models that
            have a vulnerable inference or training script dependency. Note: accessing
            vulnerability information requires making many http calls and parsing many
            metadata files. To-Do: store vulnerability information for all models in a
            single file, and change default value to False. (Default: True).
        accept_deprecated_models (bool): Optional. Set to False to reject models that
            have been flagged as deprecated. Note: accessing deprecation information
            requires making many http calls and parsing many metadata files. To-Do:
            store deprecation information for all models in a single file, and change
            default value to False. (Default: True).

    Returns:
        List[Tuple[str, str]]: The (model id, version) tuples that pass every
            filter, ordered by ``_compare_model_version_tuples`` (model id
            ascending, then version descending).

    Raises:
        ValueError: If both an allowlist and a denylist are supplied for the
            same filter (script, task, framework, model id).
    """
    bad_script_filter = script_allowlist is not None and script_denylist is not None
    bad_task_filter = task_allowlist is not None and task_denylist is not None
    bad_framework_filter = framework_allowlist is not None and framework_denylist is not None
    bad_model_id_filter = model_id_allowlist is not None and model_id_denylist is not None

    if bad_script_filter or bad_task_filter or bad_framework_filter or bad_model_id_filter:
        raise ValueError(
            (
                "Cannot use an allowlist and denylist at the same time "
                "for a filter (script, task, framework, model id)"
            )
        )

    # Normalize every filter to a set exactly once, instead of on each iteration.
    script_allowlist = _to_filter_set(script_allowlist)
    task_allowlist = _to_filter_set(task_allowlist)
    framework_allowlist = _to_filter_set(framework_allowlist)
    model_id_allowlist = _to_filter_set(model_id_allowlist)
    script_denylist = _to_filter_set(script_denylist)
    task_denylist = _to_filter_set(task_denylist)
    framework_denylist = _to_filter_set(framework_denylist)
    model_id_denylist = _to_filter_set(model_id_denylist)

    filter_on_scripts = script_allowlist is not None or script_denylist is not None
    # Model specs require a per-model fetch, so only request them when a filter needs them.
    needs_model_specs = (
        filter_on_scripts or not accept_vulnerable_models or not accept_deprecated_models
    )

    models_manifest_list = accessors.JumpStartModelsAccessor.get_manifest(region=region)

    # ``None`` means the SDK-version compatibility check is disabled.
    sdk_version = None if accept_unsupported_models else Version(get_sagemaker_version())

    model_id_version_dict: Dict[str, Set[Version]] = {}
    for model_manifest in models_manifest_list:
        model_id = model_manifest.model_id
        model_version = model_manifest.version

        if sdk_version is not None and sdk_version < Version(model_manifest.min_version):
            continue

        if model_id_allowlist is not None and model_id not in model_id_allowlist:
            continue
        if model_id_denylist is not None and model_id in model_id_denylist:
            continue

        framework, task, _ = extract_framework_task_model(model_id)

        if task_allowlist is not None and task not in task_allowlist:
            continue
        if task_denylist is not None and task in task_denylist:
            continue

        if framework_allowlist is not None and framework not in framework_allowlist:
            continue
        if framework_denylist is not None and framework in framework_denylist:
            continue

        if needs_model_specs:
            # Fetch the specs once and reuse them for every spec-based filter.
            model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
                region=region, model_id=model_id, version=model_version
            )

            if filter_on_scripts:
                # Every model is counted as supporting inference.
                supported_scripts = {JumpStartScriptScope.INFERENCE.value}
                if model_specs.training_supported:
                    supported_scripts.add(JumpStartScriptScope.TRAINING.value)

                if script_allowlist is not None and supported_scripts.isdisjoint(
                    script_allowlist
                ):
                    continue
                if script_denylist is not None and not supported_scripts.isdisjoint(
                    script_denylist
                ):
                    continue

            if not accept_vulnerable_models and (
                model_specs.inference_vulnerable or model_specs.training_vulnerable
            ):
                continue

            if not accept_deprecated_models and model_specs.deprecated:
                continue

        model_id_version_dict.setdefault(model_id, set()).add(Version(model_version))

    if not accept_old_models:
        # Keep only the newest surviving version of each model.
        model_id_version_dict = {
            model: {max(versions)} for model, versions in model_id_version_dict.items()
        }

    model_id_set = set()
    for model_id, versions in model_id_version_dict.items():
        model_id_set.update((model_id, str(version)) for version in versions)

    return sorted(model_id_set, key=cmp_to_key(_compare_model_version_tuples))
0 commit comments