Skip to content

Commit d9f4dd1

Browse files
authored
feature: inferentia support (#1373)
* feature: inferentia support
1 parent 9e354e4 commit d9f4dd1

File tree

6 files changed

+222
-0
lines changed

6 files changed

+222
-0
lines changed

src/sagemaker/fw_utils.py

+64
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@
104104
"pytorch-serving-eia": [1, 3, 1],
105105
}
106106

107+
# Inclusive [lowest, highest] framework versions accepted for Inferentia,
# keyed by Neo framework name. Each bound is a [major, minor, patch] list
# (kept as lists because they are compared against parsed version lists).
INFERENTIA_VERSION_RANGES = {
    "neo-mxnet": [[1, 5, 1], [1, 5, 1]],
    "neo-tensorflow": [[1, 15, 0], [1, 15, 0]],
}

# Regions for which create_image_uri accepts Inferentia ("inf*") instances.
INFERENTIA_SUPPORTED_REGIONS = ["us-east-1", "us-west-2"]
113+
107114
DEBUGGER_UNSUPPORTED_REGIONS = ["us-gov-west-1", "us-iso-east-1"]
108115

109116

@@ -124,6 +131,23 @@ def is_version_equal_or_higher(lowest_version, framework_version):
124131
return version_list >= lowest_version[0 : len(version_list)]
125132

126133

134+
def is_version_equal_or_lower(highest_version, framework_version):
    """Determine whether the ``framework_version`` is equal to or lower than
    ``highest_version``

    Only as many components as ``framework_version`` provides are compared,
    so ``"1.5"`` is checked against just the first two components of the
    bound. When ``framework_version`` has *more* components than
    ``highest_version``, the bound is padded with zeros so that, for example,
    ``"1.5.1.0"`` is treated as equal to ``[1, 5, 1]`` instead of being
    rejected as higher merely because it is longer.

    Args:
        highest_version (List[int]): highest version represented in an integer
            list
        framework_version (str): framework version string

    Returns:
        bool: Whether or not ``framework_version`` is equal to or lower than
            ``highest_version``

    Raises:
        ValueError: If a dot-separated component of ``framework_version`` is
            not an integer.
    """
    version_list = [int(s) for s in framework_version.split(".")]
    upper_bound = list(highest_version[0 : len(version_list)])
    # Pad a shorter bound with zeros so both lists have the same length;
    # otherwise list comparison ranks the longer list as greater.
    upper_bound.extend([0] * (len(version_list) - len(upper_bound)))
    return version_list <= upper_bound
149+
150+
127151
def _is_dlc_version(framework, framework_version, py_version):
128152
"""Return if the framework's version uses the corresponding DLC image.
129153
@@ -144,6 +168,23 @@ def _is_dlc_version(framework, framework_version, py_version):
144168
return False
145169

146170

171+
def _is_inferentia_supported(framework, framework_version):
    """Return if Inferentia supports the framework and its version.

    Args:
        framework (str): The framework name as used for the keys of
            ``INFERENTIA_VERSION_RANGES``, e.g. "neo-tensorflow"
        framework_version (str): The framework version

    Returns:
        bool: Whether or not Inferentia supports the framework and its version.
            ``False`` is returned when ``framework`` has no entry in
            ``INFERENTIA_VERSION_RANGES``.
    """
    version_range = INFERENTIA_VERSION_RANGES.get(framework)
    if version_range is None:
        # Unknown framework: report "unsupported" instead of failing with a
        # TypeError from subscripting None.
        return False

    lowest_version_list = version_range[0]
    highest_version_list = version_range[1]
    return is_version_equal_or_higher(
        lowest_version_list, framework_version
    ) and is_version_equal_or_lower(highest_version_list, framework_version)
186+
187+
147188
def _registry_id(region, framework, py_version, account, framework_version):
148189
"""Return the Amazon ECR registry number (or AWS account ID) for
149190
the given framework, framework version, Python version, and region.
@@ -240,11 +281,34 @@ def create_image_uri(
240281
# 'cpu' or 'gpu'.
241282
if family in optimized_families:
242283
device_type = family
284+
elif family.startswith("inf"):
285+
device_type = "inf"
243286
elif family[0] in ["g", "p"]:
244287
device_type = "gpu"
245288
else:
246289
device_type = "cpu"
247290

291+
if device_type == "inf":
292+
if region not in INFERENTIA_SUPPORTED_REGIONS:
293+
raise ValueError(
294+
"Inferentia is not supported in region {}. Supported regions are {}".format(
295+
region, ", ".join(INFERENTIA_SUPPORTED_REGIONS)
296+
)
297+
)
298+
if framework not in INFERENTIA_VERSION_RANGES:
299+
raise ValueError(
300+
"Inferentia does not support {}. Currently it supports "
301+
"MXNet and TensorFlow with more frameworks coming soon.".format(
302+
framework.split("-")[-1]
303+
)
304+
)
305+
if not _is_inferentia_supported(framework, framework_version):
306+
raise ValueError(
307+
"Inferentia is not supported with {} version {}.".format(
308+
framework.split("-")[-1], framework_version
309+
)
310+
)
311+
248312
use_dlc_image = _is_dlc_version(framework, framework_version, py_version)
249313

250314
if not py_version or (use_dlc_image and framework == "tensorflow-serving-eia"):

src/sagemaker/model.py

+32
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
"us-gov-west-1": "263933020539",
5151
}
5252

53+
# Prefix of a target instance family that compile() treats as an Inferentia
# target (matched with str.startswith).
INFERENTIA_INSTANCE_PREFIX = "ml_inf"
54+
5355

5456
class Model(object):
5557
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
@@ -286,6 +288,23 @@ def _neo_image(self, region, target_instance_type, framework, framework_version)
286288
account=self._neo_image_account(region),
287289
)
288290

291+
def _inferentia_image(self, region, target_instance_type, framework, framework_version):
    """Build the image URI used to serve a model compiled for Inferentia.

    Args:
        region (str): AWS region used both for the image URI and to look up
            the Neo image account.
        target_instance_type (str): Target instance identifier written with
            underscores; every underscore is replaced by a dot before the
            lookup.
        framework (str): Framework name; it is lower-cased and prefixed
            with ``neo-``.
        framework_version (str): Framework version passed through unchanged.

    Returns:
        The value returned by ``fw_utils.create_image_uri`` for the ``py3``
        image.
    """
    neo_framework = "neo-" + framework.lower()
    dotted_instance_type = target_instance_type.replace("_", ".")
    neo_account = self._neo_image_account(region)
    return fw_utils.create_image_uri(
        region,
        neo_framework,
        dotted_instance_type,
        framework_version,
        py_version="py3",
        account=neo_account,
    )
307+
289308
def compile(
290309
self,
291310
target_instance_family,
@@ -364,6 +383,14 @@ def compile(
364383
framework_version,
365384
)
366385
self._is_compiled_model = True
386+
elif target_instance_family.startswith(INFERENTIA_INSTANCE_PREFIX):
387+
self.image = self._inferentia_image(
388+
self.sagemaker_session.boto_region_name,
389+
target_instance_family,
390+
framework,
391+
framework_version,
392+
)
393+
self._is_compiled_model = True
367394
else:
368395
LOGGER.warning(
369396
"The instance type %s is not supported to deploy via SageMaker,"
@@ -437,6 +464,11 @@ def deploy(
437464
if self.role is None:
438465
raise ValueError("Role can not be null for deploying a model")
439466

467+
if instance_type.startswith("ml.inf") and not self._is_compiled_model:
468+
LOGGER.warning(
469+
"Your model is not compiled. Please compile your model before using Inferentia."
470+
)
471+
440472
compiled_model_suffix = "-".join(instance_type.split(".")[:-1])
441473
if self._is_compiled_model:
442474
name_prefix = self.name or utils.name_from_image(self.image)

tests/conftest.py

+10
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,11 @@ def cpu_instance_type(sagemaker_session, request):
269269
return "ml.m4.xlarge"
270270

271271

272+
@pytest.fixture(scope="session")
def inf_instance_type(sagemaker_session, request):
    """Session-scoped fixture providing the Inferentia instance type for tests."""
    instance_type = "ml.inf1.xlarge"
    return instance_type
275+
276+
272277
@pytest.fixture(scope="session")
273278
def ec2_instance_type(cpu_instance_type):
274279
return cpu_instance_type[3:]
@@ -289,6 +294,11 @@ def cpu_instance_family(cpu_instance_type):
289294
return "_".join(cpu_instance_type.split(".")[0:2])
290295

291296

297+
@pytest.fixture(scope="session")
def inf_instance_family(inf_instance_type):
    """Session-scoped fixture deriving the instance family from the type.

    Joins the first two dot-separated parts of ``inf_instance_type`` with an
    underscore.
    """
    leading_parts = inf_instance_type.split(".")[:2]
    return "_".join(leading_parts)
300+
301+
292302
def pytest_generate_tests(metafunc):
293303
if "instance_type" in metafunc.fixturenames:
294304
boto_config = metafunc.config.getoption("--boto-config")

tests/integ/test_neo_mxnet.py

+36
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2525

2626
NEO_MXNET_VERSION = "1.4.1" # Neo doesn't support MXNet 1.6 yet.
27+
# MXNet framework version used by the Inferentia deployment test in this module.
INF_MXNET_VERSION = "1.5.1"
2728

2829

2930
@pytest.fixture(scope="module")
@@ -110,3 +111,38 @@ def test_deploy_model(
110111
predictor.content_type = "application/vnd+python.numpy+binary"
111112
data = numpy.zeros(shape=(1, 1, 28, 28))
112113
predictor.predict(data)
114+
115+
116+
@pytest.mark.skip(reason="Inferentia is not supported yet.")
def test_inferentia_deploy_model(
    mxnet_training_job, sagemaker_session, inf_instance_type, inf_instance_family
):
    """Compile an MXNet model for Inferentia, deploy it, and run one prediction.

    Currently skipped. The endpoint is created inside
    ``timeout_and_delete_endpoint_by_name`` under a unique endpoint name.
    """
    endpoint_name = unique_name_from_base("test-neo-deploy-model")

    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        training_job_desc = sagemaker_session.sagemaker_client.describe_training_job(
            TrainingJobName=mxnet_training_job
        )
        model_data = training_job_desc["ModelArtifacts"]["S3ModelArtifacts"]
        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_neo.py")
        role = "SageMakerRole"

        model = MXNetModel(
            model_data,
            role,
            entry_point=script_path,
            framework_version=INF_MXNET_VERSION,
            sagemaker_session=sagemaker_session,
        )

        # Compilation output goes next to the model artifact: drop the last
        # path segment of the artifact URI.
        compilation_output_path = "/".join(model_data.split("/")[:-1])
        model.compile(
            target_instance_family=inf_instance_family,
            input_shape={"data": [1, 1, 28, 28]},
            role=role,
            job_name=unique_name_from_base("test-deploy-model-compilation-job"),
            output_path=compilation_output_path,
        )

        predictor = model.deploy(1, inf_instance_type, endpoint_name=endpoint_name)
        predictor.content_type = "application/vnd+python.numpy+binary"

        sample_input = numpy.zeros(shape=(1, 1, 28, 28))
        predictor.predict(sample_input)

tests/unit/test_fw_utils.py

+56
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,62 @@ def test_invalid_instance_type():
721721
fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, "p3.2xlarge", "1.0.0", "py3")
722722

723723

724+
def test_valid_inferentia_image():
    """A supported region/framework/version yields the ``-inf-py3`` Neo image URI."""
    expected_uri = "{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3".format(
        MOCK_ACCOUNT, REGION
    )

    image_uri = fw_utils.create_image_uri(
        REGION,
        "neo-tensorflow",
        "ml.inf1.2xlarge",
        "1.15.0",
        py_version="py3",
        account=MOCK_ACCOUNT,
    )

    assert image_uri == expected_uri
739+
740+
741+
def test_invalid_inferentia_region():
    """An Inferentia instance in an unsupported region raises ``ValueError``."""
    with pytest.raises(ValueError) as e:
        fw_utils.create_image_uri(
            "ap-south-1",
            "neo-tensorflow",
            "ml.inf1.2xlarge",
            "1.15.0",
            py_version="py3",
            account=MOCK_ACCOUNT,
        )
    # Inspect the raised exception itself; str() of the ExceptionInfo wrapper
    # is not guaranteed to be the exception message.
    assert "Inferentia is not supported in region ap-south-1." in str(e.value)
752+
753+
754+
def test_inferentia_invalid_framework():
    """A framework without Inferentia support raises ``ValueError``."""
    with pytest.raises(ValueError) as e:
        fw_utils.create_image_uri(
            REGION,
            "neo-pytorch",
            "ml.inf1.2xlarge",
            "1.4.0",
            py_version="py3",
            account=MOCK_ACCOUNT,
        )
    # Inspect the raised exception itself; str() of the ExceptionInfo wrapper
    # is not guaranteed to be the exception message.
    assert "Inferentia does not support pytorch." in str(e.value)
765+
766+
767+
def test_invalid_inferentia_framework_version():
    """A framework version outside the supported range raises ``ValueError``."""
    with pytest.raises(ValueError) as e:
        fw_utils.create_image_uri(
            REGION,
            "neo-tensorflow",
            "ml.inf1.2xlarge",
            "1.15.2",
            py_version="py3",
            account=MOCK_ACCOUNT,
        )
    # Inspect the raised exception itself; str() of the ExceptionInfo wrapper
    # is not guaranteed to be the exception message.
    assert "Inferentia is not supported with tensorflow version 1.15.2." in str(e.value)
778+
779+
724780
@patch(
725781
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
726782
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),

tests/unit/test_model.py

+24
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
ACCELERATOR_TYPE = "ml.eia.medium"
4040
IMAGE_NAME = "fakeimage"
4141
REGION = "us-west-2"
42+
# Account ID expected in the image URI asserted by the Inferentia compile test.
# NOTE(review): presumably the Neo image account for REGION — confirm against
# the account mapping in sagemaker/model.py.
NEO_REGION_ACCOUNT = "301217895009"
4243
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
4344
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
4445
BRANCH = "test-branch-git-config"
@@ -546,6 +547,29 @@ def test_delete_non_deployed_model(sagemaker_session):
546547
model.delete_model()
547548

548549

550+
def test_compile_model_for_inferentia(sagemaker_session, tmpdir):
    """Compiling for an ``ml_inf`` target sets the Inferentia Neo image.

    The compilation job wait is mocked; the test checks the resulting
    ``model.image`` and that the model is flagged as compiled.
    """
    sagemaker_session.wait_for_compilation_job = Mock(
        return_value=DESCRIBE_COMPILATION_JOB_RESPONSE
    )
    expected_image = "{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3".format(
        NEO_REGION_ACCOUNT, REGION
    )

    model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
    compile_kwargs = {
        "target_instance_family": "ml_inf",
        "input_shape": {"data": [1, 3, 1024, 1024]},
        "output_path": "s3://output",
        "role": "role",
        "framework": "tensorflow",
        "framework_version": "1.15.0",
        "job_name": "compile-model",
    }
    model.compile(**compile_kwargs)

    assert expected_image == model.image
    assert model._is_compiled_model is True
571+
572+
549573
def test_compile_model_for_edge_device(sagemaker_session, tmpdir):
550574
sagemaker_session.wait_for_compilation_job = Mock(
551575
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE

0 commit comments

Comments
 (0)