-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Feat/gated model support #4510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/gated model support #4510
Changes from 13 commits
f7677d8
347b599
2973f23
5bc742f
9210e49
4905cee
82d0d92
4cef235
cb81d11
f50de6b
2eff8fb
b50c557
1af132e
c9f79fd
29733c8
352f1ac
aa8e4cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
# language governing permissions and limitations under the License. | ||
"""This module contains accessors for the SageMaker JumpStart Public Hub.""" | ||
from __future__ import absolute_import | ||
from typing import Dict, Any | ||
from typing import Dict, Any, Optional | ||
from sagemaker import model_uris, script_uris | ||
from sagemaker.jumpstart.curated_hub.types import ( | ||
HubContentDependencyType, | ||
|
@@ -21,7 +21,10 @@ | |
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri | ||
from sagemaker.jumpstart.enums import JumpStartScriptScope | ||
from sagemaker.jumpstart.types import JumpStartModelSpecs | ||
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket | ||
from sagemaker.jumpstart.utils import ( | ||
get_jumpstart_content_bucket, | ||
get_jumpstart_gated_content_bucket, | ||
) | ||
|
||
|
||
class PublicModelDataAccessor: | ||
|
@@ -34,7 +37,11 @@ def __init__( | |
studio_specs: Dict[str, Dict[str, Any]], | ||
): | ||
self._region = region | ||
self._bucket = get_jumpstart_content_bucket(region) | ||
self._bucket = ( | ||
get_jumpstart_gated_content_bucket(region) | ||
if model_specs.gated_bucket | ||
else get_jumpstart_content_bucket(region) | ||
) | ||
self.model_specs = model_specs | ||
self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift | ||
|
||
|
@@ -52,6 +59,8 @@ def inference_artifact_s3_reference(self): | |
@property | ||
def training_artifact_s3_reference(self): | ||
"""Retrieves s3 reference for model training artifact""" | ||
if not self.model_specs.training_supported: | ||
return None | ||
return create_s3_object_reference_from_uri( | ||
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING) | ||
) | ||
|
@@ -66,13 +75,17 @@ def inference_script_s3_reference(self): | |
@property | ||
def training_script_s3_reference(self): | ||
"""Retrieves s3 reference for model training script""" | ||
if not self.model_specs.training_supported: | ||
return None | ||
return create_s3_object_reference_from_uri( | ||
self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING) | ||
) | ||
|
||
@property | ||
def default_training_dataset_s3_reference(self): | ||
"""Retrieves s3 reference for s3 directory containing model training datasets""" | ||
if not self.model_specs.training_supported: | ||
return None | ||
return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix()) | ||
|
||
@property | ||
|
@@ -95,22 +108,28 @@ def _get_bucket_name(self) -> str: | |
|
||
def __get_training_dataset_prefix(self) -> str: | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Retrieves training dataset location""" | ||
return self.studio_specs["defaultDataKey"] | ||
return self.studio_specs.get("defaultDataKey") | ||
|
||
def _jumpstart_script_s3_uri(self, model_scope: str) -> str: | ||
def _jumpstart_script_s3_uri(self, model_scope: str) -> Optional[str]: | ||
"""Retrieves JumpStart script s3 location""" | ||
return script_uris.retrieve( | ||
region=self._region, | ||
model_id=self.model_specs.model_id, | ||
model_version=self.model_specs.version, | ||
script_scope=model_scope, | ||
) | ||
try: | ||
return script_uris.retrieve( | ||
region=self._region, | ||
model_id=self.model_specs.model_id, | ||
model_version=self.model_specs.version, | ||
script_scope=model_scope, | ||
) | ||
except ValueError: | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Log something perhaps?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, we can log something here, though I don't think we'll ever reach this since we're only calling this function if training is supported. |
||
return None | ||
|
||
def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str: | ||
def _jumpstart_artifact_s3_uri(self, model_scope: str) -> Optional[str]: | ||
"""Retrieves JumpStart artifact s3 location""" | ||
return model_uris.retrieve( | ||
region=self._region, | ||
model_id=self.model_specs.model_id, | ||
model_version=self.model_specs.version, | ||
model_scope=model_scope, | ||
) | ||
try: | ||
return model_uris.retrieve( | ||
region=self._region, | ||
model_id=self.model_specs.model_id, | ||
model_version=self.model_specs.version, | ||
model_scope=model_scope, | ||
) | ||
except ValueError: | ||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we checking whether the model is gated?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't want to (and can't) copy any files over from the gated bucket.