def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = None) -> dict:
    """Retrieves the json metadata of the HuggingFace Model via HuggingFace API.

    Args:
        model_id (str): The HuggingFace Model ID
        hf_hub_token (str): The HuggingFace Hub Token needed for Private/Gated HuggingFace Models

    Returns:
        dict: The model metadata retrieved with the HuggingFace API

    Raises:
        ValueError: If the API answers with HTTP 401 (gated/private model requested
            without valid credentials), or if no metadata could be retrieved
            (other HTTP/network failure, invalid JSON, or an empty payload).
    """

    hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"

    # Keep the plain URL and the object handed to ``urlopen`` apart so that the
    # warning below always logs the URL rather than a ``Request`` object repr.
    hf_model_metadata_request = hf_model_metadata_url
    if hf_hub_token:
        hf_model_metadata_request = urllib.request.Request(
            hf_model_metadata_url, None, {"Authorization": "Bearer " + hf_hub_token}
        )

    hf_model_metadata_json = None
    try:
        with urllib.request.urlopen(hf_model_metadata_request) as response:
            hf_model_metadata_json = json.load(response)
    except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
        # Decide on the status code: the text of an HTTPError embeds the reason
        # phrase sent by the server, which is not guaranteed to be "Unauthorized".
        if isinstance(e, HTTPError) and e.code == 401:
            raise ValueError(
                "Trying to access a gated/private HuggingFace model without valid credentials. "
                "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
            ) from e
        # Any other failure is only logged here; the missing payload is reported
        # through the ValueError raised below.
        logger.warning(
            "Exception encountered while trying to retrieve HuggingFace model metadata %s. "
            "Details: %s",
            hf_model_metadata_url,
            e,
        )
    if not hf_model_metadata_json:
        raise ValueError(
            "Did not find model metadata for the following HuggingFace Model ID %s" % model_id
        )
    return hf_model_metadata_json
class LlmUtilsTests(TestCase):
    """Unit tests for ``get_huggingface_model_metadata``.

    ``urllib`` (and, where a payload is needed, ``json``) are patched inside
    ``sagemaker.huggingface.llm_utils`` so no network request is made.
    """

    @patch("sagemaker.huggingface.llm_utils.urllib")
    @patch("sagemaker.huggingface.llm_utils.json")
    def test_huggingface_model_metadata_success(self, mock_json, mock_urllib):
        """Without a token, the plain metadata URL is opened and the parsed JSON returned."""
        mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
        ret_json = get_huggingface_model_metadata(MOCK_HF_ID)

        mock_urllib.request.urlopen.assert_called_once_with(
            f"https://huggingface.co/api/models/{MOCK_HF_ID}"
        )
        self.assertEqual(ret_json["mock_key"], "mock_value")

    @patch("sagemaker.huggingface.llm_utils.urllib")
    @patch("sagemaker.huggingface.llm_utils.json")
    def test_huggingface_model_metadata_gated_success(self, mock_json, mock_urllib):
        """With a token, a Request carrying the bearer header is built and opened."""
        mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
        mock_hf_model_metadata_url = Mock()
        mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url

        ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)

        mock_urllib.request.Request.assert_called_once_with(
            f"https://huggingface.co/api/models/{MOCK_HF_ID}",
            None,
            {"Authorization": "Bearer " + MOCK_HF_HUB_TOKEN},
        )
        # The authenticated request object, not the bare URL, must be what is opened.
        mock_urllib.request.urlopen.assert_called_once_with(
            mock_hf_model_metadata_url.return_value
        )
        self.assertEqual(ret_json["mock_key"], "mock_value")

    @patch("sagemaker.huggingface.llm_utils.urllib")
    def test_huggingface_model_metadata_unauthorized_exception(self, mock_urllib):
        """An HTTP 401 from the Hub is surfaced as a credentials ValueError."""
        mock_urllib.request.urlopen.side_effect = HTTPError(
            code=401, msg="Unauthorized", url=None, hdrs=None, fp=None
        )
        with self.assertRaises(ValueError) as context:
            get_huggingface_model_metadata(MOCK_HF_ID)

        expected_error_msg = (
            "Trying to access a gated/private HuggingFace model without valid credentials. "
            "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
        )
        # assertEqual, not the assertEquals alias: the alias was removed in Python 3.12.
        self.assertEqual(expected_error_msg, str(context.exception))

    @patch("sagemaker.huggingface.llm_utils.urllib")
    def test_huggingface_model_metadata_general_exception(self, mock_urllib):
        """Any other retrieval failure ends in the generic 'metadata not found' ValueError."""
        mock_urllib.request.urlopen.side_effect = TimeoutError("timed out")
        with self.assertRaises(ValueError) as context:
            get_huggingface_model_metadata(MOCK_HF_ID)

        expected_error_msg = (
            f"Did not find model metadata for the following HuggingFace Model ID {MOCK_HF_ID}"
        )
        self.assertEqual(expected_error_msg, str(context.exception))