Skip to content

Commit 64b51d6

Browse files
committed
change: minor updates to jumpstart retrieve functions
1 parent 5b98f42 commit 64b51d6

File tree

6 files changed

+16
-14
lines changed

6 files changed

+16
-14
lines changed

src/sagemaker/image_uris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,13 @@ def retrieve(
9595
return artifacts._retrieve_image_uri(
9696
model_id,
9797
model_version,
98+
image_scope,
9899
framework,
99100
region,
100101
version,
101102
py_version,
102103
instance_type,
103104
accelerator_type,
104-
image_scope,
105105
container_version,
106106
distribution,
107107
base_framework_version,

src/sagemaker/jumpstart/artifacts.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@
2828
def _retrieve_image_uri(
2929
model_id: str,
3030
model_version: str,
31+
image_scope: str,
3132
framework: Optional[str],
3233
region: Optional[str],
3334
version: Optional[str],
3435
py_version: Optional[str],
3536
instance_type: Optional[str],
3637
accelerator_type: Optional[str],
37-
image_scope: Optional[str],
3838
container_version: Optional[str],
3939
distribution: Optional[str],
4040
base_framework_version: Optional[str],
@@ -50,6 +50,9 @@ def _retrieve_image_uri(
5050
model_id (str): JumpStart model ID for which to retrieve image URI.
5151
model_version (str): Version of the JumpStart model for which to retrieve
5252
the image URI (default: None).
53+
image_scope (str): The image type, i.e. what it is used for.
54+
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
55+
``image_scope`` is ignored.
5356
framework (str): The name of the framework or algorithm.
5457
region (str): The AWS region.
5558
version (str): The framework or algorithm version. This is required if there is
@@ -61,9 +64,6 @@ def _retrieve_image_uri(
6164
there are different images for different processor types.
6265
accelerator_type (str): Elastic Inference accelerator type. For more, see
6366
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
64-
image_scope (str): The image type, i.e. what it is used for.
65-
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
66-
``image_scope`` is ignored.
6767
container_version (str): the version of docker image.
6868
Ideally the value of parameter should be created inside the framework.
6969
For custom use, see the list of supported container versions:
@@ -112,7 +112,7 @@ def _retrieve_image_uri(
112112
if framework is not None and framework != ecr_specs.framework:
113113
raise ValueError(
114114
f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' "
115-
"and version {model_version}'."
115+
f"and version {model_version}'."
116116
)
117117

118118
if version is not None and version != ecr_specs.framework_version:
@@ -124,7 +124,7 @@ def _retrieve_image_uri(
124124
if py_version is not None and py_version != ecr_specs.py_version:
125125
raise ValueError(
126126
f"Incorrect python version '{py_version}' for JumpStart model ID '{model_id}' "
127-
"and version {model_version}'."
127+
f"and version {model_version}'."
128128
)
129129

130130
base_framework_version_override: Optional[str] = None

src/sagemaker/script_uris.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Accessors to retrieve the script S3 URI to be run pretrained ML models
14-
in SageMaker containers.
15-
"""
13+
"""Accessors to retrieve the script S3 URI to run pretrained ML models."""
14+
1615
from __future__ import absolute_import
1716

1817
import logging

tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ def __init__(
5151
region=JUMPSTART_DEFAULT_REGION_NAME,
5252
boto_config=Config(retries={"max_attempts": 10, "mode": "standard"}),
5353
base_name="jumpstart-inference-job",
54-
execution_role=Session().get_caller_identity_arn(),
54+
execution_role=None,
5555
) -> None:
5656

5757
self.suffix = suffix
5858
self.test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
5959
self.region = region
6060
self.config = boto_config
6161
self.base_name = base_name
62-
self.execution_role = execution_role
62+
self.execution_role = execution_role or Session().get_caller_identity_arn()
6363
self.account_id = boto3.client("sts").get_caller_identity()["Account"]
6464
self.image_uri = image_uri
6565
self.script_uri = script_uri

tests/integ/sagemaker/jumpstart/retrieve_uri/training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
region=JUMPSTART_DEFAULT_REGION_NAME,
4545
boto_config=Config(retries={"max_attempts": 10, "mode": "standard"}),
4646
base_name="jumpstart-training-job",
47-
execution_role=Session().get_caller_identity_arn(),
47+
execution_role=None,
4848
) -> None:
4949

5050
self.account_id = boto3.client("sts").get_caller_identity()["Account"]
@@ -53,7 +53,7 @@ def __init__(
5353
self.region = region
5454
self.config = boto_config
5555
self.base_name = base_name
56-
self.execution_role = execution_role
56+
self.execution_role = execution_role or Session().get_caller_identity_arn()
5757
self.image_uri = image_uri
5858
self.script_uri = script_uri
5959
self.model_uri = model_uri

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def get_header_from_base_header(
4848
def get_prototype_model_spec(
4949
region: str = None, model_id: str = None, version: str = None
5050
) -> JumpStartModelSpecs:
51+
"""This function mocks cache accessor functions. For this mock,
52+
we only retrieve model specs based on the model id.
53+
"""
5154

5255
specs = JumpStartModelSpecs(PROTOTYPICAL_MODEL_SPECS_DICT[model_id])
5356
return specs

0 commit comments

Comments
 (0)