feature: Inferentia Neuron support for HuggingFace #2976

Merged · 35 commits merged into aws:master from jeniyat/hf-inf-neuron on Mar 15, 2022
Commits (35)
e1e5fc4
feature: Inferentia support for huggingface
jeniyat Mar 8, 2022
6343907
inf/neuron for HF
jeniyat Mar 8, 2022
4f9ec06
update w/ separate json files
jeniyat Mar 8, 2022
428917c
removed changes from huggingface.json
jeniyat Mar 8, 2022
a17a6c7
added version as arguments
jeniyat Mar 8, 2022
58247e1
updated json file and retrieve function with image_versions
jeniyat Mar 9, 2022
61a7840
updated retrieve function to incorporate inference tool extraction
jeniyat Mar 9, 2022
8138849
removed unused comma from json
jeniyat Mar 9, 2022
abc643c
updated if-else condition
jeniyat Mar 9, 2022
4291fb2
updated if-else condition
jeniyat Mar 9, 2022
18dee73
remove unused if
jeniyat Mar 9, 2022
c3e2e7f
updated with black
jeniyat Mar 9, 2022
d6dd56e
compacted the if condition
jeniyat Mar 9, 2022
355bf9c
compacted the if condition
jeniyat Mar 9, 2022
06e215c
handling of partial version name
jeniyat Mar 9, 2022
056f7db
handling of partial version name
jeniyat Mar 9, 2022
5335ad6
removed image_version
jeniyat Mar 10, 2022
63209b0
added HF in neo compilable list
jeniyat Mar 11, 2022
9c1107d
added HF in neo compilable list
jeniyat Mar 11, 2022
e7051a3
added HF in neo compilable list
jeniyat Mar 11, 2022
d40f048
exception handling for None image_uri
jeniyat Mar 11, 2022
fff80f2
added unit test
jeniyat Mar 11, 2022
b13cf78
updated unit test
jeniyat Mar 11, 2022
e9e43f5
addressed nit comments
jeniyat Mar 11, 2022
2702611
black test
jeniyat Mar 11, 2022
39cb429
black test
jeniyat Mar 11, 2022
49594ae
added exception check
jeniyat Mar 11, 2022
e590bf5
updated name
jeniyat Mar 11, 2022
fe4bdd3
Merge branch 'aws:master' into jeniyat/hf-inf-neuron
jeniyat Mar 13, 2022
393fa75
add session value for HF model
jeniyat Mar 14, 2022
d993da0
update w/ fstring
jeniyat Mar 14, 2022
02445ba
update w/ fstring
jeniyat Mar 14, 2022
7210d2f
added TODO
jeniyat Mar 14, 2022
33adabc
added TODO
jeniyat Mar 14, 2022
db0607f
update comment
jeniyat Mar 15, 2022
113 changes: 113 additions & 0 deletions src/sagemaker/huggingface/model.py
@@ -25,6 +25,7 @@
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.session import Session

logger = logging.getLogger("sagemaker")

@@ -169,9 +170,121 @@ def __init__(
super(HuggingFaceModel, self).__init__(
model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
)
self.sagemaker_session = self.sagemaker_session or Session()

self.model_server_workers = model_server_workers

def deploy(
self,
initial_instance_count=None,
instance_type=None,
serializer=None,
deserializer=None,
accelerator_type=None,
endpoint_name=None,
tags=None,
kms_key=None,
wait=True,
data_capture_config=None,
async_inference_config=None,
serverless_inference_config=None,
**kwargs,
):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.

Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an
``Endpoint`` from this ``Model``. If ``self.predictor_cls`` is not None,
this method returns the result of invoking ``self.predictor_cls`` on
the created endpoint name.

The name of the created model is accessible in the ``name`` field of
this ``Model`` after deploy returns.

The name of the created endpoint is accessible in the
``endpoint_name`` field of this ``Model`` after deploy returns.

Args:
initial_instance_count (int): The initial number of instances to run
in the ``Endpoint`` created from this ``Model``. If not using
serverless inference, then it needs to be a number greater than or
equal to 1 (default: None)
instance_type (str): The EC2 instance type to deploy this Model to.
For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
serverless inference, then it is required to deploy a model.
(default: None)
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
serializer object, used to encode data for an inference endpoint
(default: None). If ``serializer`` is not None, then
``serializer`` will override the default serializer. The
default serializer is set by the ``predictor_cls``.
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
deserializer object, used to decode data from an inference
endpoint (default: None). If ``deserializer`` is not None, then
``deserializer`` will override the default deserializer. The
default deserializer is set by the ``predictor_cls``.
accelerator_type (str): Type of Elastic Inference accelerator to
deploy this model for model loading and inference, for example,
'ml.eia1.medium'. If not specified, no Elastic Inference
accelerator will be attached to the endpoint. For more
information:
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
endpoint_name (str): The name of the endpoint to create (default:
None). If not specified, a unique endpoint name will be created.
tags (List[dict[str, str]]): The list of tags to attach to this
specific endpoint.
kms_key (str): The ARN of the KMS key that is used to encrypt the
data on the storage volume attached to the instance hosting the
endpoint.
wait (bool): Whether the call should wait until the deployment of
this model completes (default: True).
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
configuration related to Endpoint data capture for use with
Amazon SageMaker Model Monitoring. Default: None.
async_inference_config (sagemaker.async_inference.AsyncInferenceConfig): Specifies
configuration related to async endpoint. Use this configuration when trying
to create an async endpoint and make async inference. If an empty config object
is passed through, the default config is used to deploy an async endpoint. A
real-time endpoint is deployed if it's None. (default: None)
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Specifies configuration related to serverless endpoint. Use this configuration
when trying to create a serverless endpoint and make serverless inference. If an
empty object is passed through, the pre-defined values in the
``ServerlessInferenceConfig`` class are used to deploy a serverless endpoint.
An instance-based endpoint is deployed if it's None. (default: None)
Raises:
ValueError: If the argument combination check fails in any of these circumstances:
- If no role is specified or
- If serverless inference config is not specified and instance type and instance
count are also not specified or
- If a wrong type of object is provided as serverless inference config or async
inference config
Returns:
callable[string, sagemaker.session.Session] or None: Invocation of
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
is not None. Otherwise, return None.
"""

if not self.image_uri and instance_type.startswith("ml.inf"):
self.image_uri = self.serving_image_uri(
region_name=self.sagemaker_session.boto_session.region_name,
instance_type=instance_type,
)

return super(HuggingFaceModel, self).deploy(
initial_instance_count,
instance_type,
serializer,
deserializer,
accelerator_type,
endpoint_name,
tags,
kms_key,
wait,
data_capture_config,
async_inference_config,
serverless_inference_config,
)

def register(
self,
content_types,
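Net effect of the `model.py` change: when no `image_uri` is supplied and the instance type is an Inferentia (`ml.inf*`) type, `deploy()` resolves the Neuron serving image before delegating to `Model.deploy()`. A minimal usage sketch (the bucket, role ARN, and endpoint details below are placeholders for illustration, not values from this PR):

```python
from sagemaker.huggingface import HuggingFaceModel

# Placeholder model artifact and role ARN.
model = HuggingFaceModel(
    model_data="s3://my-bucket/model.tar.gz",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    transformers_version="4.12",
    pytorch_version="1.9",
    py_version="py37",
)

# No image_uri was given and the instance type starts with "ml.inf",
# so deploy() calls serving_image_uri() to select the Neuron DLC first.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.inf1.xlarge",
)
```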
44 changes: 44 additions & 0 deletions src/sagemaker/image_uri_config/huggingface-neuron.json
@@ -0,0 +1,44 @@
{
"inference": {
"processors": ["inf"],
"version_aliases": {"4.12": "4.12.3"},
"versions": {
"4.12.3": {
"version_aliases": {"pytorch1.9": "pytorch1.9.1"},
"pytorch1.9.1": {
"py_versions": ["py37"],
"repository": "huggingface-pytorch-inference-neuron",
"registries": {
"af-south-1": "626614931356",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ca-central-1": "763104351884",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"container_version": {"inf": "ubuntu18.04"},
"sdk_versions": ["sdk1.17.1"]
}
}
}
}
}
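For reference, a sketch of how this config file is consumed by `image_uris.retrieve()` once the Neuron path is wired in (the resolved URI in the comment is an expectation derived from the registry, repository, and tag fields above, not captured output):

```python
from sagemaker import image_uris

uri = image_uris.retrieve(
    framework="huggingface",
    region="us-west-2",
    version="4.12.3",
    py_version="py37",
    base_framework_version="pytorch1.9.1",
    image_scope="inference",
    instance_type="ml.inf1.xlarge",  # implies inference_tool="neuron"
)
# Expected shape:
# 763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference-neuron:
#   1.9.1-transformers4.12.3-neuron-py37-sdk1.17.1-ubuntu18.04
```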
67 changes: 53 additions & 14 deletions src/sagemaker/image_uris.py
@@ -24,7 +24,6 @@
from sagemaker.spark import defaults
from sagemaker.jumpstart import artifacts


logger = logging.getLogger(__name__)

ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
@@ -47,6 +46,8 @@ def retrieve(
model_version=None,
tolerate_vulnerable_model=False,
tolerate_deprecated_model=False,
sdk_version=None,
inference_tool=None,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

@@ -88,6 +89,11 @@
tolerate_deprecated_model (bool): True if deprecated versions of model specifications
should be tolerated without an exception raised. If False, raises an exception
if the version of the model is deprecated. (Default: False).
sdk_version (str): the version of the SDK used in the image tag (for Neuron
images this is the Neuron SDK version, e.g. "sdk1.17.1").
(default: None).
inference_tool (str): the tool that will be used to aid in the inference.
Valid values: "neuron", None
(default: None).

Returns:
str: The ECR URI for the corresponding SageMaker Docker image.
@@ -100,7 +106,6 @@
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if is_jumpstart_model_input(model_id, model_version):

return artifacts._retrieve_image_uri(
model_id,
model_version,
@@ -118,9 +123,13 @@
tolerate_vulnerable_model,
tolerate_deprecated_model,
)

if training_compiler_config is None:
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
_framework = framework
if framework == HUGGING_FACE_FRAMEWORK:
inference_tool = _get_inference_tool(inference_tool, instance_type)
if inference_tool == "neuron":
_framework = f"{framework}-{inference_tool}"
config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)
elif framework == HUGGING_FACE_FRAMEWORK:
config = _config_for_framework_and_scope(
framework + "-training-compiler", image_scope, accelerator_type
Expand All @@ -129,6 +138,7 @@ def retrieve(
raise ValueError(
"Unsupported Configuration: Training Compiler is only supported with HuggingFace"
)

original_version = version
version = _validate_version_and_set_if_needed(version, config, framework)
version_config = config["versions"][_version_for_config(version, config)]
@@ -138,7 +148,6 @@
full_base_framework_version = version_config["version_aliases"].get(
base_framework_version, base_framework_version
)

_validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
version_config = version_config.get(full_base_framework_version)

@@ -161,25 +170,37 @@
pt_or_tf_version = (
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
)

_version = original_version

if repo in [
"huggingface-pytorch-trcomp-training",
"huggingface-tensorflow-trcomp-training",
]:
_version = version
if repo in ["huggingface-pytorch-inference-neuron"]:
if not sdk_version:
sdk_version = _get_latest_versions(version_config["sdk_versions"])
container_version = sdk_version + "-" + container_version
if config.get("version_aliases").get(original_version):
_version = config.get("version_aliases")[original_version]
if (
config.get("versions", {})
.get(_version, {})
.get("version_aliases", {})
.get(base_framework_version, {})
):
_base_framework_version = config.get("versions")[_version]["version_aliases"][
base_framework_version
]
pt_or_tf_version = (
re.compile("^(pytorch|tensorflow)(.*)$").match(_base_framework_version).group(2)
)

tag_prefix = f"{pt_or_tf_version}-transformers{_version}"

else:
tag_prefix = version_config.get("tag_prefix", version)

tag = _format_tag(
tag_prefix,
processor,
py_version,
container_version,
)
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)

if _should_auto_select_container_version(instance_type, distribution):
container_versions = {
@@ -248,6 +269,20 @@ def config_for_framework(framework):
return json.load(f)


def _get_inference_tool(inference_tool, instance_type):
"""Extract the inference tool name from instance type."""
if not inference_tool and instance_type:
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match and match[1].startswith("inf"):
return "neuron"
return inference_tool


def _get_latest_versions(list_of_versions):
"""Extract the latest version from the input list of available versions."""
return sorted(list_of_versions, reverse=True)[0]


def _validate_accelerator_type(accelerator_type):
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook":
@@ -310,6 +345,8 @@ def _processor(instance_type, available_processors):

if instance_type.startswith("local"):
processor = "cpu" if instance_type == "local" else "gpu"
elif instance_type.startswith("neuron"):
processor = "neuron"
else:
# looks for either "ml.<family>.<size>" or "ml_<family>"
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
@@ -387,8 +424,10 @@ def _validate_arg(arg, available_options, arg_name):
)


def _format_tag(tag_prefix, processor, py_version, container_version):
def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
"""Creates a tag for the image URI."""
if inference_tool:
return "-".join(x for x in (tag_prefix, inference_tool, py_version, container_version) if x)
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)


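To make the helper behavior concrete, a few illustrative assertions against the private functions above (a sketch, not part of this PR's test suite). Note that `_get_latest_versions` picks the maximum by plain string sort, which is adequate while each `sdk_versions` list holds a single entry like `"sdk1.17.1"`:

```python
from sagemaker.image_uris import _format_tag, _get_inference_tool

# An explicitly passed tool is returned untouched.
assert _get_inference_tool("neuron", "ml.c5.xlarge") == "neuron"
# Inferentia instance families (inf1, ...) imply the Neuron tool.
assert _get_inference_tool(None, "ml.inf1.xlarge") == "neuron"
# Non-Inferentia instances leave the tool unset.
assert _get_inference_tool(None, "ml.c5.xlarge") is None

# When an inference tool is set, it replaces the processor segment of the tag.
assert (
    _format_tag("1.9.1-transformers4.12.3", "inf", "py37", "sdk1.17.1-ubuntu18.04", "neuron")
    == "1.9.1-transformers4.12.3-neuron-py37-sdk1.17.1-ubuntu18.04"
)
```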
15 changes: 15 additions & 0 deletions tests/conftest.py
@@ -269,6 +269,21 @@ def huggingface_tensorflow_latest_training_py_version():
return "py37"


@pytest.fixture(scope="module")
def huggingface_neuron_latest_inference_pytorch_version():
return "1.9"


@pytest.fixture(scope="module")
def huggingface_neuron_latest_inference_transformer_version():
return "4.12"


@pytest.fixture(scope="module")
def huggingface_neuron_latest_inference_py_version():
return "py37"


@pytest.fixture(scope="module")
def pytorch_eia_py_version():
return "py3"
22 changes: 21 additions & 1 deletion tests/unit/sagemaker/huggingface/test_estimator.py
@@ -19,7 +19,7 @@
import pytest
from mock import MagicMock, Mock, patch

from sagemaker.huggingface import HuggingFace
from sagemaker.huggingface import HuggingFace, HuggingFaceModel

from .huggingface_utils import get_full_gpu_image_uri, GPU_INSTANCE_TYPE, REGION

@@ -252,6 +252,26 @@ def test_huggingface(
assert actual_train_args == expected_train_args


def test_huggingface_neuron(
sagemaker_session,
huggingface_neuron_latest_inference_pytorch_version,
huggingface_neuron_latest_inference_transformer_version,
huggingface_neuron_latest_inference_py_version,
):

inputs = "s3://mybucket/train"
huggingface_model = HuggingFaceModel(
model_data=inputs,
transformers_version=huggingface_neuron_latest_inference_transformer_version,
role=ROLE,
sagemaker_session=sagemaker_session,
pytorch_version=huggingface_neuron_latest_inference_pytorch_version,
py_version=huggingface_neuron_latest_inference_py_version,
)
container = huggingface_model.prepare_container_def("ml.inf.xlarge")
assert container["Image"]


def test_attach(
sagemaker_session,
huggingface_training_version,