-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: Adding Jumpstart retrieval functions #2789
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
shreyapandit
merged 16 commits into
aws:master-jumpstart
from
evakravi:feat/jumpstart-retrieve-functions
Jan 12, 2022
Merged
Changes from 10 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
fa79a5d
feat: jumpstart retrieve functions (wip)
evakravi 84bf597
change: use packaging library for jumpstart versions
evakravi 86cf577
fix: linting, mypy, logical issues for jumpstart models
evakravi fea1020
fix: tox.ini
evakravi f85b7f0
change: retrieve script uri argument name, comments
evakravi ed7b772
change: improve jumpstart retrieve fx impl, cleanup tests, comments, …
evakravi a427d4b
change: improve jumpstart retrieve uri unit tests, fix logic for imag…
evakravi f6ade25
feat: integration tests for jumpstart sdk retrieve functions
evakravi 5b98f42
change: cleanup code, remove redundant tests
evakravi 64b51d6
change: minor updates to jumpstart retrieve functions
evakravi 4e4590c
change: create helper function for setting kwargs and region for Jump…
evakravi 3c10d48
change: update name of jumpstart models accessor, fix small issues
evakravi 42743aa
fix: image_uri_region argument formatting
evakravi 8d95d30
fix: jumpstart works without region in aws config
evakravi c3d158c
Merge branch 'master-jumpstart' into feat/jumpstart-retrieve-functions
evakravi 9a396de
fix: reduce MaxRuntimeInSeconds for jumpstart integ test training jobs
evakravi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,4 +27,5 @@ venv/ | |
*.swp | ||
.docker/ | ||
env/ | ||
.vscode/ | ||
.vscode/ | ||
**/tmp |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains accessors related to SageMaker JumpStart.""" | ||
from __future__ import absolute_import | ||
from typing import Any, Dict, Optional | ||
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs | ||
from sagemaker.jumpstart import cache | ||
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME | ||
|
||
|
||
class SageMakerSettings(object): | ||
"""Static class for storing the SageMaker settings.""" | ||
|
||
_parsed_sagemaker_version = "" | ||
|
||
@staticmethod | ||
def set_sagemaker_version(version: str) -> None: | ||
"""Set SageMaker version.""" | ||
SageMakerSettings._parsed_sagemaker_version = version | ||
|
||
@staticmethod | ||
def get_sagemaker_version() -> str: | ||
"""Return SageMaker version.""" | ||
return SageMakerSettings._parsed_sagemaker_version | ||
|
||
|
||
class JumpStartModelsCache(object): | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Static class for storing the JumpStart models cache.""" | ||
|
||
_cache: Optional[cache.JumpStartModelsCache] = None | ||
_curr_region = JUMPSTART_DEFAULT_REGION_NAME | ||
|
||
_cache_kwargs: Dict[str, Any] = {} | ||
shreyapandit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@staticmethod | ||
def _validate_and_mutate_region_cache_kwargs( | ||
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None | ||
) -> Dict[str, Any]: | ||
"""Returns cache_kwargs with region argument removed if present. | ||
|
||
Raises: | ||
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument. | ||
|
||
Args: | ||
cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate. | ||
region (str): The region to validate along with the kwargs. | ||
""" | ||
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs | ||
assert isinstance(cache_kwargs_dict, dict) | ||
shreyapandit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if region is not None and "region" in cache_kwargs_dict: | ||
if region != cache_kwargs_dict["region"]: | ||
raise ValueError( | ||
f"Inconsistent region definitions: {region}, {cache_kwargs_dict['region']}" | ||
) | ||
del cache_kwargs_dict["region"] | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
shreyapandit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return cache_kwargs_dict | ||
|
||
@staticmethod | ||
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader: | ||
"""Returns model header from JumpStart models cache. | ||
|
||
Args: | ||
region (str): region for which to retrieve header. | ||
model_id (str): model id to retrieve. | ||
version (str): semantic version to retrieve for the model id. | ||
""" | ||
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs( | ||
JumpStartModelsCache._cache_kwargs, region | ||
) | ||
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region: | ||
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs) | ||
JumpStartModelsCache._curr_region = region | ||
assert JumpStartModelsCache._cache is not None | ||
mufaddal-rohawala marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return JumpStartModelsCache._cache.get_header(model_id, version) | ||
|
||
@staticmethod | ||
def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs: | ||
"""Returns model specs from JumpStart models cache. | ||
|
||
Args: | ||
region (str): region for which to retrieve header. | ||
model_id (str): model id to retrieve. | ||
version (str): semantic version to retrieve for the model id. | ||
""" | ||
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs( | ||
JumpStartModelsCache._cache_kwargs, region | ||
) | ||
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region: | ||
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs) | ||
JumpStartModelsCache._curr_region = region | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: @evakravi This section is same as lines 80-82 above, can we refactor this to avoid duplicacy? |
||
assert JumpStartModelsCache._cache is not None | ||
return JumpStartModelsCache._cache.get_specs(model_id, version) | ||
|
||
@staticmethod | ||
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None: | ||
"""Sets cache kwargs, clears the cache. | ||
|
||
Raises: | ||
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument. | ||
|
||
Args: | ||
cache_kwargs (str): cache kwargs to validate. | ||
region (str): Optional. The region to validate along with the kwargs. | ||
""" | ||
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs( | ||
cache_kwargs, region | ||
) | ||
JumpStartModelsCache._cache_kwargs = cache_kwargs | ||
if region is None: | ||
JumpStartModelsCache._cache = cache.JumpStartModelsCache( | ||
**JumpStartModelsCache._cache_kwargs | ||
) | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
JumpStartModelsCache._curr_region = region | ||
JumpStartModelsCache._cache = cache.JumpStartModelsCache( | ||
region=region, **JumpStartModelsCache._cache_kwargs | ||
) | ||
|
||
@staticmethod | ||
def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = None) -> None: | ||
"""Resets cache, optionally allowing cache kwargs to be passed to the new cache. | ||
|
||
Raises: | ||
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument. | ||
|
||
Args: | ||
cache_kwargs (str): cache kwargs to validate. | ||
region (str): The region to validate along with the kwargs. | ||
""" | ||
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs | ||
JumpStartModelsCache.set_cache_kwargs(cache_kwargs_dict, region) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.