Commit d8e031f
bug fix in inference_endpoint wait function for proper waiting on update (#2867)
* bug fix in inference_endpoint wait for proper waiting on update
* Update src/huggingface_hub/_inference_endpoints.py: improve code clarity and add logging based on review (Co-authored-by: Célina <[email protected]>)
* changes in the inference_endpoint wait function for robust behaviour, plus a new test case in test_inference_endpoints covering the wait changes
* changes in test case test_wait_update

Co-authored-by: Célina <[email protected]>
1 parent c280a3c · commit d8e031f

2 files changed: +68 −8 lines

src/huggingface_hub/_inference_endpoints.py (+11 −6)
```diff
@@ -207,16 +207,21 @@ def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint":

         start = time.time()
         while True:
-            if self.url is not None:
-                # Means the URL is provisioned => check if the endpoint is reachable
-                response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token))
-                if response.status_code == 200:
-                    logger.info("Inference Endpoint is ready to be used.")
-                    return self
             if self.status == InferenceEndpointStatus.FAILED:
                 raise InferenceEndpointError(
                     f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information."
                 )
+            if self.status == InferenceEndpointStatus.UPDATE_FAILED:
+                raise InferenceEndpointError(
+                    f"Inference Endpoint {self.name} failed to update. Please check the logs for more information."
+                )
+            if self.status == InferenceEndpointStatus.RUNNING and self.url is not None:
+                # Verify the endpoint is actually reachable
+                response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token))
+                if response.status_code == 200:
+                    logger.info("Inference Endpoint is ready to be used.")
+                    return self
+
             if timeout is not None:
                 if time.time() - start > timeout:
                     raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.")
```
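The effect of the change: `wait()` no longer returns as soon as the (possibly stale) URL answers with a 200, but only once the endpoint reports `RUNNING` and is reachable, and it now fails fast on `UPDATE_FAILED`. A minimal usage sketch of the fixed behaviour; the endpoint name and `instance_size` value are illustrative, not taken from the commit:

```python
from huggingface_hub import HfApi

api = HfApi()

# Trigger an update; the endpoint keeps serving its old URL while its
# status transitions through "updating".
endpoint = api.update_inference_endpoint("my-endpoint-name", instance_size="x4")

# With this fix, wait() blocks until the status is RUNNING *and* the URL
# returns HTTP 200; it raises InferenceEndpointError on FAILED or
# UPDATE_FAILED, and InferenceEndpointTimeoutError once `timeout` elapses.
endpoint.wait(timeout=300, refresh_every=5)
print(endpoint.status)  # "running"
```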

tests/test_inference_endpoints.py (+57 −2)
```diff
@@ -1,5 +1,6 @@
 from datetime import datetime, timezone
-from unittest.mock import Mock, patch
+from itertools import chain, repeat
+from unittest.mock import MagicMock, Mock, patch

 import pytest

```
```diff
@@ -109,6 +110,39 @@
         "targetReplica": 1,
     },
 }
+# added for test_wait_update function
+MOCK_UPDATE = {
+    "name": "my-endpoint-name",
+    "type": "protected",
+    "accountId": None,
+    "provider": {"vendor": "aws", "region": "us-east-1"},
+    "compute": {
+        "accelerator": "cpu",
+        "instanceType": "intel-icl",
+        "instanceSize": "x2",
+        "scaling": {"minReplica": 0, "maxReplica": 1},
+    },
+    "model": {
+        "repository": "gpt2",
+        "revision": "11c5a3d5811f50298f278a704980280950aedb10",
+        "task": "text-generation",
+        "framework": "pytorch",
+        "image": {"huggingface": {}},
+        "secret": {"token": "my-token"},
+    },
+    "status": {
+        "createdAt": "2023-10-26T12:41:53.263078506Z",
+        "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
+        "updatedAt": "2023-10-26T12:41:53.263079138Z",
+        "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
+        "private": None,
+        "state": "updating",
+        "url": "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud",
+        "message": "Endpoint waiting for the update",
+        "readyReplica": 0,
+        "targetReplica": 1,
+    },
+}


 def test_from_raw_initialization():
```
```diff
@@ -189,7 +223,7 @@ def test_fetch(mock_get: Mock):
 @patch("huggingface_hub._inference_endpoints.get_session")
 @patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint")
 def test_wait_until_running(mock_get: Mock, mock_session: Mock):
-    """Test waits waits until the endpoint is ready."""
+    """Test waits until the endpoint is ready."""
     endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo")

     mock_get.side_effect = [
```
```diff
@@ -244,6 +278,27 @@ def test_wait_failed(mock_get: Mock):
         endpoint.wait(refresh_every=0.001)


+@patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint")
+@patch("huggingface_hub._inference_endpoints.get_session")
+def test_wait_update(mock_get_session, mock_get_inference_endpoint):
+    """Test that wait() returns when the endpoint transitions to running."""
+    endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo")
+    # Create an iterator that yields three MOCK_UPDATE responses, and then infinitely yields MOCK_RUNNING responses.
+    responses = chain(
+        [InferenceEndpoint.from_raw(MOCK_UPDATE, namespace="foo")] * 3,
+        repeat(InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo")),
+    )
+    mock_get_inference_endpoint.side_effect = lambda *args, **kwargs: next(responses)
+
+    # Patch the get_session().get() call to always return a fake response with status_code 200.
+    fake_response = MagicMock()
+    fake_response.status_code = 200
+    mock_get_session.return_value.get.return_value = fake_response
+
+    endpoint.wait(refresh_every=0.05)
+    assert endpoint.status == "running"
+
+
 @patch("huggingface_hub.hf_api.HfApi.pause_inference_endpoint")
 def test_pause(mock: Mock):
     """Test `pause` calls the correct alias."""
```
