Skip to content

Commit 52adeb8

Browse files
author
EC2 Default User
committed
query hf api for model md
1 parent ef7c5a0 commit 52adeb8

File tree

2 files changed

+122
-0
lines changed

2 files changed

+122
-0
lines changed

src/sagemaker/huggingface/llm_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,16 @@
1515

1616
from typing import Optional
1717

18+
import urllib.request
19+
from urllib.error import HTTPError, URLError
20+
import json
21+
from json import JSONDecodeError
22+
import logging
1823
from sagemaker import image_uris
1924
from sagemaker.session import Session
2025

26+
logger = logging.getLogger(__name__)
27+
2128

2229
def get_huggingface_llm_image_uri(
2330
backend: str,
@@ -54,3 +61,42 @@ def get_huggingface_llm_image_uri(
5461
version = version or "0.24.0"
5562
return image_uris.retrieve(framework="djl-deepspeed", region=region, version=version)
5663
raise ValueError("Unsupported backend: %s" % backend)
64+
65+
66+
def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = None) -> dict:
67+
"""Retrieves the json metadata of the HuggingFace Model via HuggingFace API.
68+
69+
Args:
70+
model_id (str): The HuggingFace Model ID
71+
hf_hub_token (str): The HuggingFace Hub Token needed for Private/Gated HuggingFace Models
72+
73+
Returns:
74+
dict: The model metadata retrieved with the HuggingFace API
75+
"""
76+
77+
hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
78+
hf_model_metadata_json = None
79+
try:
80+
if hf_hub_token:
81+
hf_model_metadata_url = urllib.request.Request(
82+
hf_model_metadata_url, None, {"Authorization": "Bearer " + hf_hub_token}
83+
)
84+
with urllib.request.urlopen(hf_model_metadata_url) as response:
85+
hf_model_metadata_json = json.load(response)
86+
except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
87+
if "HTTP Error 401: Unauthorized" in str(e):
88+
raise ValueError(
89+
"Trying to access a gated/private HuggingFace model without valid credentials. "
90+
"Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
91+
)
92+
logger.warning(
93+
"Exception encountered while trying to retrieve HuggingFace model metadata %s. "
94+
"Details: %s",
95+
hf_model_metadata_url,
96+
e,
97+
)
98+
if not hf_model_metadata_json:
99+
raise ValueError(
100+
"Did not find model metadata for the following HuggingFace Model ID %s" % model_id
101+
)
102+
return hf_model_metadata_json
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from unittest import TestCase
16+
from urllib.error import HTTPError
17+
from unittest.mock import Mock, patch
18+
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
19+
20+
MOCK_HF_ID = "mock_hf_id"
21+
MOCK_HF_HUB_TOKEN = "mock_hf_hub_token"
22+
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
23+
24+
25+
class LlmUtilsTests(TestCase):
26+
@patch("sagemaker.huggingface.llm_utils.urllib")
27+
@patch("sagemaker.huggingface.llm_utils.json")
28+
def test_huggingface_model_metadata_success(self, mock_json, mock_urllib):
29+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
30+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID)
31+
32+
mock_urllib.request.urlopen.assert_called_once_with(
33+
f"https://huggingface.co/api/models/{MOCK_HF_ID}"
34+
)
35+
self.assertEqual(ret_json["mock_key"], "mock_value")
36+
37+
@patch("sagemaker.huggingface.llm_utils.urllib")
38+
@patch("sagemaker.huggingface.llm_utils.json")
39+
def test_huggingface_model_metadata_gated_success(self, mock_json, mock_urllib):
40+
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
41+
mock_hf_model_metadata_url = Mock()
42+
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
43+
44+
ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
45+
46+
mock_urllib.request.Request.assert_called_once_with(
47+
f"https://huggingface.co/api/models/{MOCK_HF_ID}",
48+
None,
49+
{"Authorization": "Bearer " + MOCK_HF_HUB_TOKEN},
50+
)
51+
self.assertEqual(ret_json["mock_key"], "mock_value")
52+
53+
@patch("sagemaker.huggingface.llm_utils.urllib")
54+
def test_huggingface_model_metadata_unauthorized_exception(self, mock_urllib):
55+
mock_urllib.request.urlopen.side_effect = HTTPError(
56+
code=401, msg="Unauthorized", url=None, hdrs=None, fp=None
57+
)
58+
with self.assertRaises(ValueError) as context:
59+
get_huggingface_model_metadata(MOCK_HF_ID)
60+
61+
expected_error_msg = (
62+
"Trying to access a gated/private HuggingFace model without valid credentials. "
63+
"Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
64+
)
65+
self.assertEquals(expected_error_msg, str(context.exception))
66+
67+
@patch("sagemaker.huggingface.llm_utils.urllib")
68+
def test_huggingface_model_metadata_general_exception(self, mock_urllib):
69+
mock_urllib.request.urlopen.side_effect = TimeoutError("timed out")
70+
with self.assertRaises(ValueError) as context:
71+
get_huggingface_model_metadata(MOCK_HF_ID)
72+
73+
expected_error_msg = (
74+
f"Did not find model metadata for the following HuggingFace Model ID {MOCK_HF_ID}"
75+
)
76+
self.assertEquals(expected_error_msg, str(context.exception))

0 commit comments

Comments
 (0)