Skip to content

Commit 2b43f7c

Browse files
authored
Merge branch 'master' into dw-tlv-3x-updates
2 parents 8cdc741 + dd4be4c commit 2b43f7c

File tree

7 files changed

+187
-2
lines changed

7 files changed

+187
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
{
2+
"inference": {
3+
"processors": ["gpu"],
4+
"version_aliases": {
5+
"0.1": "0.1.0"
6+
},
7+
"versions": {
8+
"0.1.0": {
9+
"py_versions": ["py310"],
10+
"registries": {
11+
"af-south-1": "626614931356",
12+
"il-central-1": "780543022126",
13+
"ap-east-1": "871362719292",
14+
"ap-northeast-1": "763104351884",
15+
"ap-northeast-2": "763104351884",
16+
"ap-northeast-3": "364406365360",
17+
"ap-south-1": "763104351884",
18+
"ap-south-2": "772153158452",
19+
"ap-southeast-1": "763104351884",
20+
"ap-southeast-2": "763104351884",
21+
"ap-southeast-3": "907027046896",
22+
"ap-southeast-4": "457447274322",
23+
"ca-central-1": "763104351884",
24+
"eu-central-1": "763104351884",
25+
"eu-central-2": "380420809688",
26+
"eu-north-1": "763104351884",
27+
"eu-west-1": "763104351884",
28+
"eu-west-2": "763104351884",
29+
"eu-west-3": "763104351884",
30+
"eu-south-1": "692866216735",
31+
"eu-south-2": "503227376785",
32+
"me-south-1": "217643126080",
33+
"me-central-1": "914824155844",
34+
"sa-east-1": "763104351884",
35+
"us-east-1": "763104351884",
36+
"us-east-2": "763104351884",
37+
"us-west-1": "763104351884",
38+
"us-west-2": "763104351884"
39+
},
40+
"tag_prefix": "2.0.1-sgm0.1.0",
41+
"repository": "stabilityai-pytorch-inference",
42+
"container_version": {
43+
"gpu": "cu118-ubuntu20.04-sagemaker"
44+
}
45+
}
46+
}
47+
}
48+
}

src/sagemaker/image_uris.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
4040
INFERENCE_GRAVITON = "inference_graviton"
4141
DATA_WRANGLER_FRAMEWORK = "data-wrangler"
42+
STABILITYAI_FRAMEWORK = "stabilityai"
4243

4344

4445
@override_pipeline_parameter_var
@@ -476,7 +477,11 @@ def _validate_version_and_set_if_needed(version, config, framework):
476477

477478
return available_versions[0]
478479

479-
if version is None and framework in [DATA_WRANGLER_FRAMEWORK, HUGGING_FACE_LLM_FRAMEWORK]:
480+
if version is None and framework in [
481+
DATA_WRANGLER_FRAMEWORK,
482+
HUGGING_FACE_LLM_FRAMEWORK,
483+
STABILITYAI_FRAMEWORK,
484+
]:
480485
version = _get_latest_versions(available_versions)
481486

482487
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))

src/sagemaker/stabilityai/__init__.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""StabilityAI module."""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.stabilityai.stability_utils import get_stabilityai_image_uri # noqa: F401
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Utility functions."""
14+
15+
from __future__ import absolute_import
16+
17+
from typing import Optional
18+
19+
from sagemaker import image_uris
20+
from sagemaker.session import Session
21+
22+
23+
def get_stabilityai_image_uri(
24+
session: Optional[Session] = None,
25+
region: Optional[str] = None,
26+
version: Optional[str] = None,
27+
image_scope: Optional[str] = "inference",
28+
) -> str:
29+
"""Very basic utility function to fetch image URI of StabilityAI images.
30+
31+
Args:
32+
session (Session): SageMaker session.
33+
region (str): AWS region of image URI.
34+
version (str): Framework version. Latest version used if not specified.
35+
image_scope (str): Image type. e.g. inference, training
36+
Returns:
37+
Image URI string.
38+
"""
39+
40+
if region is None:
41+
if session is None:
42+
region = Session().boto_session.region_name
43+
else:
44+
region = session.boto_session.region_name
45+
return image_uris.retrieve(
46+
framework="stabilityai",
47+
region=region,
48+
version=version,
49+
image_scope=image_scope,
50+
)

tests/scripts/run-notebook-test.sh

+3-1
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,14 @@ echo "set SAGEMAKER_ROLE_ARN=$SAGEMAKER_ROLE_ARN"
123123
--platformIdentifier notebook-al2-v2 \
124124
--consider-skips-failures \
125125
./amazon-sagemaker-examples/sagemaker_processing/spark_distributed_data_processing/sagemaker-spark-processing.ipynb \
126-
./amazon-sagemaker-examples/advanced_functionality/kmeans_bring_your_own_model/kmeans_bring_your_own_model.ipynb \
127126
./amazon-sagemaker-examples/advanced_functionality/tensorflow_iris_byom/tensorflow_BYOM_iris.ipynb \
128127
./amazon-sagemaker-examples/sagemaker-python-sdk/1P_kmeans_highlevel/kmeans_mnist.ipynb \
129128
./amazon-sagemaker-examples/sagemaker-python-sdk/scikit_learn_randomforest/Sklearn_on_SageMaker_end2end.ipynb \
130129
./amazon-sagemaker-examples/sagemaker-pipelines/tabular/abalone_build_train_deploy/sagemaker-pipelines-preprocess-train-evaluate-batch-transform.ipynb \
131130
131+
# Skipping test until fix in example notebook to move to new conda environment
132+
#./amazon-sagemaker-examples/advanced_functionality/kmeans_bring_your_own_model/kmeans_bring_your_own_model.ipynb \
133+
132134
# Skipping test until fix in example notebook to install docker-compose is complete
133135
#./amazon-sagemaker-examples/sagemaker-python-sdk/tensorflow_moving_from_framework_mode_to_script_mode/tensorflow_moving_from_framework_mode_to_script_mode.ipynb \
134136

tests/unit/sagemaker/image_uris/expected_uris.py

+5
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ def huggingface_llm_framework_uri(
9191
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
9292

9393

94+
def stabilityai_framework_uri(repo, account, tag, region=REGION):
95+
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
96+
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
97+
98+
9499
def base_python_uri(repo, account, region=REGION):
95100
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
96101
tag = "1.0"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
17+
from sagemaker.stabilityai import get_stabilityai_image_uri
18+
from tests.unit.sagemaker.image_uris import expected_uris
19+
20+
ACCOUNTS = {
21+
"af-south-1": "626614931356",
22+
"il-central-1": "780543022126",
23+
"ap-east-1": "871362719292",
24+
"ap-northeast-1": "763104351884",
25+
"ap-northeast-2": "763104351884",
26+
"ap-northeast-3": "364406365360",
27+
"ap-south-1": "763104351884",
28+
"ap-southeast-1": "763104351884",
29+
"ap-southeast-2": "763104351884",
30+
"ap-southeast-3": "907027046896",
31+
"ca-central-1": "763104351884",
32+
"eu-central-1": "763104351884",
33+
"eu-north-1": "763104351884",
34+
"eu-west-1": "763104351884",
35+
"eu-west-2": "763104351884",
36+
"eu-west-3": "763104351884",
37+
"eu-south-1": "692866216735",
38+
"me-south-1": "217643126080",
39+
"sa-east-1": "763104351884",
40+
"us-east-1": "763104351884",
41+
"us-east-2": "763104351884",
42+
"us-west-1": "763104351884",
43+
"us-west-2": "763104351884",
44+
}
45+
SAI_VERSIONS = ["0.1.0"]
46+
SAI_VERSIONS_MAPPING = {"0.1.0": "2.0.1-sgm0.1.0-gpu-py310-cu118-ubuntu20.04-sagemaker"}
47+
48+
49+
@pytest.mark.parametrize("version", SAI_VERSIONS)
50+
def test_stabilityai_image_uris(version):
51+
for region in ACCOUNTS.keys():
52+
result = get_stabilityai_image_uri(region=region, version=version)
53+
expected = expected_uris.stabilityai_framework_uri(
54+
"stabilityai-pytorch-inference",
55+
ACCOUNTS[region],
56+
SAI_VERSIONS_MAPPING[version],
57+
region=region,
58+
)
59+
assert expected == result

0 commit comments

Comments
 (0)