Skip to content

Feat: function to call hf api for model md #4346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/sagemaker/huggingface/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@

from typing import Optional

import urllib.request
from urllib.error import HTTPError, URLError
import json
from json import JSONDecodeError
import logging
from sagemaker import image_uris
from sagemaker.session import Session

logger = logging.getLogger(__name__)


def get_huggingface_llm_image_uri(
backend: str,
Expand Down Expand Up @@ -54,3 +61,42 @@ def get_huggingface_llm_image_uri(
version = version or "0.24.0"
return image_uris.retrieve(framework="djl-deepspeed", region=region, version=version)
raise ValueError("Unsupported backend: %s" % backend)


def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = None) -> dict:
"""Retrieves the json metadata of the HuggingFace Model via HuggingFace API.

Args:
model_id (str): The HuggingFace Model ID
hf_hub_token (str): The HuggingFace Hub Token needed for Private/Gated HuggingFace Models

Returns:
dict: The model metadata retrieved with the HuggingFace API
"""

hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
hf_model_metadata_json = None
try:
if hf_hub_token:
hf_model_metadata_url = urllib.request.Request(
hf_model_metadata_url, None, {"Authorization": "Bearer " + hf_hub_token}
)
with urllib.request.urlopen(hf_model_metadata_url) as response:
hf_model_metadata_json = json.load(response)
except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
if "HTTP Error 401: Unauthorized" in str(e):
raise ValueError(
"Trying to access a gated/private HuggingFace model without valid credentials. "
"Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
)
logger.warning(
"Exception encountered while trying to retrieve HuggingFace model metadata %s. "
"Details: %s",
hf_model_metadata_url,
e,
)
if not hf_model_metadata_json:
raise ValueError(
"Did not find model metadata for the following HuggingFace Model ID %s" % model_id
)
return hf_model_metadata_json
76 changes: 76 additions & 0 deletions tests/unit/sagemaker/huggingface/test_llm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from unittest import TestCase
from urllib.error import HTTPError
from unittest.mock import Mock, patch
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata

MOCK_HF_ID = "mock_hf_id"
MOCK_HF_HUB_TOKEN = "mock_hf_hub_token"
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}


class LlmUtilsTests(TestCase):
@patch("sagemaker.huggingface.llm_utils.urllib")
@patch("sagemaker.huggingface.llm_utils.json")
def test_huggingface_model_metadata_success(self, mock_json, mock_urllib):
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
ret_json = get_huggingface_model_metadata(MOCK_HF_ID)

mock_urllib.request.urlopen.assert_called_once_with(
f"https://huggingface.co/api/models/{MOCK_HF_ID}"
)
self.assertEqual(ret_json["mock_key"], "mock_value")

@patch("sagemaker.huggingface.llm_utils.urllib")
@patch("sagemaker.huggingface.llm_utils.json")
def test_huggingface_model_metadata_gated_success(self, mock_json, mock_urllib):
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
mock_hf_model_metadata_url = Mock()
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url

ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)

mock_urllib.request.Request.assert_called_once_with(
f"https://huggingface.co/api/models/{MOCK_HF_ID}",
None,
{"Authorization": "Bearer " + MOCK_HF_HUB_TOKEN},
)
self.assertEqual(ret_json["mock_key"], "mock_value")

@patch("sagemaker.huggingface.llm_utils.urllib")
def test_huggingface_model_metadata_unauthorized_exception(self, mock_urllib):
mock_urllib.request.urlopen.side_effect = HTTPError(
code=401, msg="Unauthorized", url=None, hdrs=None, fp=None
)
with self.assertRaises(ValueError) as context:
get_huggingface_model_metadata(MOCK_HF_ID)

expected_error_msg = (
"Trying to access a gated/private HuggingFace model without valid credentials. "
"Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
)
self.assertEquals(expected_error_msg, str(context.exception))

@patch("sagemaker.huggingface.llm_utils.urllib")
def test_huggingface_model_metadata_general_exception(self, mock_urllib):
mock_urllib.request.urlopen.side_effect = TimeoutError("timed out")
with self.assertRaises(ValueError) as context:
get_huggingface_model_metadata(MOCK_HF_ID)

expected_error_msg = (
f"Did not find model metadata for the following HuggingFace Model ID {MOCK_HF_ID}"
)
self.assertEquals(expected_error_msg, str(context.exception))