Commit 68bd19b

Merge branch 'master' into bump-airflow
2 parents: f633b15 + 71a4c58

4 files changed: +132 −13 lines

src/sagemaker/huggingface/llm_utils.py (+46)

@@ -15,9 +15,16 @@
 
 from typing import Optional
 
+import urllib.request
+from urllib.error import HTTPError, URLError
+import json
+from json import JSONDecodeError
+import logging
 from sagemaker import image_uris
 from sagemaker.session import Session
 
+logger = logging.getLogger(__name__)
+
 
 def get_huggingface_llm_image_uri(
     backend: str,
@@ -54,3 +61,42 @@ def get_huggingface_llm_image_uri(
         version = version or "0.24.0"
         return image_uris.retrieve(framework="djl-deepspeed", region=region, version=version)
     raise ValueError("Unsupported backend: %s" % backend)
+
+
+def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = None) -> dict:
+    """Retrieves the json metadata of the HuggingFace Model via HuggingFace API.
+
+    Args:
+        model_id (str): The HuggingFace Model ID
+        hf_hub_token (str): The HuggingFace Hub Token needed for Private/Gated HuggingFace Models
+
+    Returns:
+        dict: The model metadata retrieved with the HuggingFace API
+    """
+
+    hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
+    hf_model_metadata_json = None
+    try:
+        if hf_hub_token:
+            hf_model_metadata_url = urllib.request.Request(
+                hf_model_metadata_url, None, {"Authorization": "Bearer " + hf_hub_token}
+            )
+        with urllib.request.urlopen(hf_model_metadata_url) as response:
+            hf_model_metadata_json = json.load(response)
+    except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
+        if "HTTP Error 401: Unauthorized" in str(e):
+            raise ValueError(
+                "Trying to access a gated/private HuggingFace model without valid credentials. "
+                "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
+            )
+        logger.warning(
+            "Exception encountered while trying to retrieve HuggingFace model metadata %s. "
+            "Details: %s",
+            hf_model_metadata_url,
+            e,
+        )
+    if not hf_model_metadata_json:
+        raise ValueError(
+            "Did not find model metadata for the following HuggingFace Model ID %s" % model_id
+        )
+    return hf_model_metadata_json
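
For orientation, here is a minimal usage sketch of the new helper (not part of the commit). The model id "gpt2" is only an illustrative public Hub model, and the private model id and the HUGGING_FACE_HUB_TOKEN environment variable are assumed placeholders.

import os

from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata

# Public model: no token required; the Hub's JSON metadata comes back as a dict.
metadata = get_huggingface_model_metadata("gpt2")
print(metadata.get("pipeline_tag"))

# Gated/private model: pass a Hub token, otherwise the helper raises ValueError on HTTP 401.
token = os.environ.get("HUGGING_FACE_HUB_TOKEN")  # assumed to be set by the caller
metadata = get_huggingface_model_metadata("my-org/private-model", hf_hub_token=token)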

tests/integ/sagemaker/workflow/test_notebook_job_step.py (+4)

@@ -17,6 +17,7 @@
 import tarfile
 import logging
 import nbformat as nbf
+import pytest
 
 from sagemaker import get_execution_role
 from sagemaker.s3 import S3Downloader
@@ -125,6 +126,9 @@ def verify_notebook_for_happy_case(cells):
         logging.error(error)
 
 
+@pytest.mark.skip(
+    reason="This test is skipped temporarily due to failures. Need to re-enable later after fix."
+)
 def test_notebook_job_with_more_configuration(sagemaker_session):
     """This test case is for more complex job configuration.
     1. a parent notebook file with %run magic to execute 'subfolder/sub.ipynb' and the

tests/integ/test_inference_component_based_endpoint.py (+6 −13)

@@ -15,7 +15,6 @@
 import os
 import sagemaker.predictor
 import sagemaker.utils
-import tests.integ
 import pytest
 
 from sagemaker import image_uris
@@ -114,10 +113,8 @@ def xgboost_model(sagemaker_session, resources, model_update_to_name):
     return xgb_model
 
 
-@pytest.mark.release
-@pytest.mark.skipif(
-    tests.integ.test_region() not in tests.integ.INFERENCE_COMPONENT_SUPPORTED_REGIONS,
-    reason="inference component based endpoint is not supported in certain regions",
+@pytest.mark.skip(
+    reason="This test is skipped temporarily due to failures. Need to re-enable later after fix."
 )
 def test_deploy_single_model_with_endpoint_name(tfs_model, resources):
     endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
@@ -145,10 +142,8 @@ def test_deploy_single_model_with_endpoint_name(tfs_model, resources):
     predictor.delete_endpoint()
 
 
-@pytest.mark.release
-@pytest.mark.skipif(
-    tests.integ.test_region() not in tests.integ.INFERENCE_COMPONENT_SUPPORTED_REGIONS,
-    reason="inference component based endpoint is not supported in certain regions",
+@pytest.mark.skip(
+    reason="This test is skipped temporarily due to failures. Need to re-enable later after fix."
 )
 def test_deploy_update_predictor_with_other_model(
     tfs_model,
@@ -206,10 +201,8 @@ def test_deploy_update_predictor_with_other_model(
    predictor_to_update.delete_endpoint()
 
 
-@pytest.mark.release
-@pytest.mark.skipif(
-    tests.integ.test_region() not in tests.integ.INFERENCE_COMPONENT_SUPPORTED_REGIONS,
-    reason="inference component based endpoint is not supported in certain regions",
+@pytest.mark.skip(
+    reason="This test is skipped temporarily due to failures. Need to re-enable later after fix."
 )
 def test_deploy_multi_models_without_endpoint_name(tfs_model, resources):
     input_data = {"instances": [1.0, 2.0, 5.0]}
New file (+76): unit tests for get_huggingface_model_metadata

@@ -0,0 +1,76 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest import TestCase
+from urllib.error import HTTPError
+from unittest.mock import Mock, patch
+from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
+
+MOCK_HF_ID = "mock_hf_id"
+MOCK_HF_HUB_TOKEN = "mock_hf_hub_token"
+MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
+
+
+class LlmUtilsTests(TestCase):
+    @patch("sagemaker.huggingface.llm_utils.urllib")
+    @patch("sagemaker.huggingface.llm_utils.json")
+    def test_huggingface_model_metadata_success(self, mock_json, mock_urllib):
+        mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
+        ret_json = get_huggingface_model_metadata(MOCK_HF_ID)
+
+        mock_urllib.request.urlopen.assert_called_once_with(
+            f"https://huggingface.co/api/models/{MOCK_HF_ID}"
+        )
+        self.assertEqual(ret_json["mock_key"], "mock_value")
+
+    @patch("sagemaker.huggingface.llm_utils.urllib")
+    @patch("sagemaker.huggingface.llm_utils.json")
+    def test_huggingface_model_metadata_gated_success(self, mock_json, mock_urllib):
+        mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
+        mock_hf_model_metadata_url = Mock()
+        mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
+
+        ret_json = get_huggingface_model_metadata(MOCK_HF_ID, MOCK_HF_HUB_TOKEN)
+
+        mock_urllib.request.Request.assert_called_once_with(
+            f"https://huggingface.co/api/models/{MOCK_HF_ID}",
+            None,
+            {"Authorization": "Bearer " + MOCK_HF_HUB_TOKEN},
+        )
+        self.assertEqual(ret_json["mock_key"], "mock_value")
+
+    @patch("sagemaker.huggingface.llm_utils.urllib")
+    def test_huggingface_model_metadata_unauthorized_exception(self, mock_urllib):
+        mock_urllib.request.urlopen.side_effect = HTTPError(
+            code=401, msg="Unauthorized", url=None, hdrs=None, fp=None
+        )
+        with self.assertRaises(ValueError) as context:
+            get_huggingface_model_metadata(MOCK_HF_ID)
+
+        expected_error_msg = (
+            "Trying to access a gated/private HuggingFace model without valid credentials. "
+            "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
+        )
+        self.assertEquals(expected_error_msg, str(context.exception))
+
+    @patch("sagemaker.huggingface.llm_utils.urllib")
+    def test_huggingface_model_metadata_general_exception(self, mock_urllib):
+        mock_urllib.request.urlopen.side_effect = TimeoutError("timed out")
+        with self.assertRaises(ValueError) as context:
+            get_huggingface_model_metadata(MOCK_HF_ID)
+
+        expected_error_msg = (
+            f"Did not find model metadata for the following HuggingFace Model ID {MOCK_HF_ID}"
+        )
+        self.assertEquals(expected_error_msg, str(context.exception))
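
For reference (not part of the commit), the unauthorized-path test works because urllib's HTTPError stringifies to exactly the substring that get_huggingface_model_metadata checks before raising its ValueError:

from urllib.error import HTTPError

# HTTPError.__str__ renders as "HTTP Error <code>: <msg>", which is the substring
# the helper looks for in its except branch.
err = HTTPError(url=None, code=401, msg="Unauthorized", hdrs=None, fp=None)
print(str(err))  # HTTP Error 401: Unauthorized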
