Skip to content

Commit 427dec6

Browse files
makungaj1Jonathan Makunga
and
Jonathan Makunga
authored
feat: Logic to detect hardware GPU count and aggregate GPU memory size in MiB (#4389)
* Add logic to detect hardware GPU count and aggregate GPU memory size in MiB * Fix all formatting * Addressed PR review comments * Addressed PR Review messages * Addressed PR Review Messages * Addressed PR Review comments * Addressed PR Review Comments * Add integration tests * Add config * Fix integration tests * Include Instance Types GPU info Config files * Addressed PR review comments * Fix unit tests * Fix unit test: 'Mock' object is not subscriptable --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent fc11ace commit 427dec6

File tree

6 files changed

+1107
-0
lines changed

6 files changed

+1107
-0
lines changed

src/sagemaker/image_uri_config/instance_gpu_info.json

+782
Large diffs are not rendered by default.
+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve instance types GPU info."""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import os
18+
from typing import Dict
19+
20+
21+
def retrieve(region: str) -> Dict[str, Dict[str, int]]:
    """Retrieves instance types GPU info of the given region.

    The data is read from ``image_uri_config/instance_gpu_info.json``, resolved
    relative to the directory that contains this module.

    Args:
        region (str): The AWS region.

    Returns:
        dict[str, dict[str, int]]: A dictionary that contains instance types as keys
            and GPU info as values, or an empty dictionary if the config has no
            entry for the given region.

    Raises:
        ValueError: If the GPU info config file cannot be found.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "image_uri_config", "instance_gpu_info.json"
    )
    try:
        # Read as UTF-8 explicitly so parsing does not depend on the platform locale.
        with open(config_path, encoding="utf-8") as f:
            instance_types_gpu_info_config = json.load(f)
    except FileNotFoundError as e:
        # Chain explicitly so the missing path stays visible in the traceback.
        raise ValueError("Could not find instance types gpu info.") from e

    return instance_types_gpu_info_config.get(region, {})
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Utilities for detecting available GPUs and Aggregate GPU Memory size of an instance"""
14+
from __future__ import absolute_import
15+
16+
import logging
17+
from typing import Tuple
18+
19+
from botocore.exceptions import ClientError
20+
21+
from sagemaker import Session
22+
from sagemaker import instance_types_gpu_info
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
def _get_gpu_info(instance_type: str, session: Session) -> Tuple[int, int]:
    """Get GPU info for the provided instance type from EC2.

    The instance type is first normalised with ``_format_instance_type`` and then
    looked up through the EC2 ``describe_instance_types`` call.

    Args:
        instance_type (str): The instance type name to look up.
        session: The session to use. Its ``boto_session`` creates the EC2 client.

    Returns:
        tuple[int, int]: A tuple that contains number of GPUs available at index 0,
            and aggregate memory size in MiB at index 1.

    Raises:
        ValueError: If the given instance type does not exist or GPU is not enabled.
    """
    ec2_client = session.boto_session.client("ec2")
    ec2_instance = _format_instance_type(instance_type)
    not_gpu_enabled = f"Provided instance_type is not GPU enabled: [{ec2_instance}]"

    try:
        response = ec2_client.describe_instance_types(InstanceTypes=[ec2_instance])
    except ClientError as e:
        raise ValueError(not_gpu_enabled) from e

    # A missing or empty "InstanceTypes" list is reported like any other
    # non-GPU result instead of surfacing as IndexError/TypeError.
    instance_types = response.get("InstanceTypes")
    instance_info = instance_types[0] if instance_types else None
    gpus_info = instance_info.get("GpuInfo") if instance_info else None
    gpus = gpus_info.get("Gpus") if gpus_info else None
    if not gpus:
        raise ValueError(not_gpu_enabled)

    # Only the first "Gpus" entry supplies the count; the memory figure is the
    # aggregate reported at the "GpuInfo" level.
    count = gpus[0].get("Count")
    total_gpu_memory_in_mib = gpus_info.get("TotalGpuMemoryInMiB")
    if not (count and total_gpu_memory_in_mib):
        raise ValueError(not_gpu_enabled)

    instance_gpu_info = (count, total_gpu_memory_in_mib)
    logger.info("GPU Info [%s]: %s", ec2_instance, instance_gpu_info)
    return instance_gpu_info
66+
67+
68+
def _get_gpu_info_fallback(instance_type: str, region: str) -> Tuple[int, int]:
    """Get GPU info for the provided instance type from the bundled config.

    The region's config is looked up with ``instance_type`` exactly as given;
    the name produced by ``_format_instance_type`` is only used in the log and
    error messages.

    Args:
        instance_type (str): The instance type name to look up.
        region: The AWS region.

    Returns:
        tuple[int, int]: A tuple that contains number of GPUs available at index 0,
            and aggregate memory size in MiB at index 1.

    Raises:
        ValueError: If the given instance type has no entry in the region's config.
    """
    region_gpu_info = instance_types_gpu_info.retrieve(region)
    config_entry = region_gpu_info.get(instance_type)
    ec2_instance = _format_instance_type(instance_type)

    if config_entry is not None:
        gpu_count_and_memory = (
            config_entry.get("Count"),
            config_entry.get("TotalGpuMemoryInMiB"),
        )
        logger.info("GPU Info [%s]: %s", ec2_instance, gpu_count_and_memory)
        return gpu_count_and_memory

    raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]")
94+
95+
96+
def _format_instance_type(instance_type: str) -> str:
97+
"""Formats provided instance type name
98+
99+
Args:
100+
instance_type (str):
101+
102+
Returns: formatted instance type.
103+
"""
104+
split_instance = instance_type.split(".")
105+
106+
if len(split_instance) > 2:
107+
split_instance.pop(0)
108+
109+
ec2_instance = ".".join(split_instance)
110+
return ec2_instance
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
17+
from sagemaker.serve.utils import hardware_detector
18+
19+
# Region whose GPU info config is used by the fallback lookups.
REGION = "us-west-2"
# Instance type expected to resolve to GPU info.
VALID_INSTANCE_TYPE = "ml.g5.48xlarge"
# Instance type name expected to be rejected with ValueError.
INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge"
# (GPU count, aggregate GPU memory in MiB): 8 GPUs x 24576 MiB each.
EXPECTED_INSTANCE_GPU_INFO = (8, 8 * 24576)
23+
24+
25+
def test_get_gpu_info_success(sagemaker_session):
    """A known GPU instance type resolves to the expected (count, MiB) pair."""
    assert (
        hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session)
        == EXPECTED_INSTANCE_GPU_INFO
    )
29+
30+
31+
def test_get_gpu_info_throws(sagemaker_session):
    """An invalid instance type name makes ``_get_gpu_info`` raise ValueError."""
    lookup_args = (INVALID_INSTANCE_TYPE, sagemaker_session)

    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info(*lookup_args)
34+
35+
36+
def test_get_gpu_info_fallback_success():
    """The config-based fallback returns the expected (count, MiB) pair."""
    assert (
        hardware_detector._get_gpu_info_fallback(VALID_INSTANCE_TYPE, REGION)
        == EXPECTED_INSTANCE_GPU_INFO
    )
40+
41+
42+
def test_get_gpu_info_fallback_throws():
    """An instance type missing from the region's config raises ValueError."""
    lookup_args = (INVALID_INSTANCE_TYPE, REGION)

    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info_fallback(*lookup_args)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from botocore.exceptions import ClientError
16+
import pytest
17+
18+
from sagemaker.serve.utils import hardware_detector
19+
20+
# Region whose GPU info config is used by the fallback lookups.
REGION = "us-west-2"
# Instance type expected to resolve to GPU info.
VALID_INSTANCE_TYPE = "ml.g5.48xlarge"
# Instance type name expected to be rejected with ValueError.
INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge"
# (GPU count, aggregate GPU memory in MiB): 8 GPUs x 24576 MiB each.
EXPECTED_INSTANCE_GPU_INFO = (8, 8 * 24576)
24+
25+
26+
def test_get_gpu_info_success(sagemaker_session, boto_session):
    """A stubbed EC2 response with GPU data yields the expected (count, MiB) pair.

    Also checks that the "ml." prefix is removed before EC2 is queried.
    """
    a10g_gpu = {
        "Name": "A10G",
        "Manufacturer": "NVIDIA",
        "Count": 8,
        "MemoryInfo": {"SizeInMiB": 24576},
    }
    gpu_section = {"Gpus": [a10g_gpu], "TotalGpuMemoryInMiB": 196608}
    boto_session.client("ec2").describe_instance_types.return_value = {
        "InstanceTypes": [{"GpuInfo": gpu_section}]
    }

    instance_gpu_info = hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session)

    boto_session.client("ec2").describe_instance_types.assert_called_once_with(
        InstanceTypes=["g5.48xlarge"]
    )
    assert instance_gpu_info == EXPECTED_INSTANCE_GPU_INFO
51+
52+
53+
def test_get_gpu_info_throws(sagemaker_session, boto_session):
    """An EC2 response without any "GpuInfo" makes ``_get_gpu_info`` raise ValueError."""
    response_without_gpu_info = {"InstanceTypes": [{}]}
    boto_session.client("ec2").describe_instance_types.return_value = response_without_gpu_info

    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session)
58+
59+
60+
def test_get_gpu_info_describe_instance_types_throws(sagemaker_session, boto_session):
    """A ClientError from ``describe_instance_types`` is surfaced as ValueError."""
    error_message = (
        f"An error occurred (InvalidInstanceType) when calling the DescribeInstanceTypes "
        f"operation: The following supplied instance types do not exist: [{INVALID_INSTANCE_TYPE}]"
    )
    error_response = {"Error": {"Code": "InvalidInstanceType", "Message": error_message}}
    boto_session.client("ec2").describe_instance_types.side_effect = ClientError(
        error_response, "DescribeInstanceTypes"
    )

    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session)
74+
75+
76+
def test_get_gpu_info_fallback_success():
    """The config-based fallback returns the expected (count, MiB) pair."""
    assert (
        hardware_detector._get_gpu_info_fallback(VALID_INSTANCE_TYPE, REGION)
        == EXPECTED_INSTANCE_GPU_INFO
    )
82+
83+
84+
def test_get_gpu_info_fallback_throws():
    """An instance type missing from the region's config raises ValueError."""
    lookup_args = (INVALID_INSTANCE_TYPE, REGION)

    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info_fallback(*lookup_args)
87+
88+
89+
def test_format_instance_type_success():
    """The leading "ml." segment is removed from a three-segment name."""
    assert hardware_detector._format_instance_type(VALID_INSTANCE_TYPE) == "g5.48xlarge"
93+
94+
95+
def test_format_instance_type_without_ml_success():
    """A two-segment name without the "ml." prefix is returned unchanged."""
    ec2_style_name = "g5.48xlarge"

    assert hardware_detector._format_instance_type(ec2_style_name) == "g5.48xlarge"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker import instance_types_gpu_info
16+
17+
# Region expected to have entries in the GPU info config.
REGION = "us-west-2"
# Region name expected to have no entry in the GPU info config.
INVALID_REGION = "invalid-region"
19+
20+
21+
def test_retrieve_success():
    """A region present in the config yields a non-empty mapping."""
    assert len(instance_types_gpu_info.retrieve(REGION)) > 0
25+
26+
27+
def test_retrieve_throws():
    """An unknown region yields an empty mapping.

    Despite the test name, nothing is expected to raise here: ``retrieve``
    returns an empty result when the region has no entry in the config.
    """
    assert len(instance_types_gpu_info.retrieve(INVALID_REGION)) == 0

0 commit comments

Comments
 (0)