|
1 | 1 | from datetime import datetime, timezone
|
2 |
| -from unittest.mock import Mock, patch |
| 2 | +from itertools import chain, repeat |
| 3 | +from unittest.mock import MagicMock, Mock, patch |
3 | 4 |
|
4 | 5 | import pytest
|
5 | 6 |
|
|
109 | 110 | "targetReplica": 1,
|
110 | 111 | },
|
111 | 112 | }
|
| 113 | +# added for test_wait_update function |
| 114 | +MOCK_UPDATE = { |
| 115 | + "name": "my-endpoint-name", |
| 116 | + "type": "protected", |
| 117 | + "accountId": None, |
| 118 | + "provider": {"vendor": "aws", "region": "us-east-1"}, |
| 119 | + "compute": { |
| 120 | + "accelerator": "cpu", |
| 121 | + "instanceType": "intel-icl", |
| 122 | + "instanceSize": "x2", |
| 123 | + "scaling": {"minReplica": 0, "maxReplica": 1}, |
| 124 | + }, |
| 125 | + "model": { |
| 126 | + "repository": "gpt2", |
| 127 | + "revision": "11c5a3d5811f50298f278a704980280950aedb10", |
| 128 | + "task": "text-generation", |
| 129 | + "framework": "pytorch", |
| 130 | + "image": {"huggingface": {}}, |
| 131 | + "secret": {"token": "my-token"}, |
| 132 | + }, |
| 133 | + "status": { |
| 134 | + "createdAt": "2023-10-26T12:41:53.263078506Z", |
| 135 | + "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, |
| 136 | + "updatedAt": "2023-10-26T12:41:53.263079138Z", |
| 137 | + "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"}, |
| 138 | + "private": None, |
| 139 | + "state": "updating", |
| 140 | + "url": "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud", |
| 141 | + "message": "Endpoint waiting for the update", |
| 142 | + "readyReplica": 0, |
| 143 | + "targetReplica": 1, |
| 144 | + }, |
| 145 | +} |
112 | 146 |
|
113 | 147 |
|
114 | 148 | def test_from_raw_initialization():
|
@@ -189,7 +223,7 @@ def test_fetch(mock_get: Mock):
|
189 | 223 | @patch("huggingface_hub._inference_endpoints.get_session")
|
190 | 224 | @patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint")
|
191 | 225 | def test_wait_until_running(mock_get: Mock, mock_session: Mock):
|
192 |
| - """Test waits waits until the endpoint is ready.""" |
| 226 | + """Test waits until the endpoint is ready.""" |
193 | 227 | endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo")
|
194 | 228 |
|
195 | 229 | mock_get.side_effect = [
|
@@ -244,6 +278,27 @@ def test_wait_failed(mock_get: Mock):
|
244 | 278 | endpoint.wait(refresh_every=0.001)
|
245 | 279 |
|
246 | 280 |
|
@patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint")
@patch("huggingface_hub._inference_endpoints.get_session")
def test_wait_update(mock_get_session, mock_get_inference_endpoint):
    """Test that `wait()` keeps polling while the endpoint is updating.

    The mocked server reports "updating" for the first three refreshes and
    "running" forever after; `wait()` must only return once the endpoint has
    transitioned to "running".
    """
    endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo")

    # Yield three MOCK_UPDATE responses, and then infinitely yield MOCK_RUNNING
    # responses. `side_effect` accepts an iterable directly: one item per call.
    mock_get_inference_endpoint.side_effect = chain(
        [InferenceEndpoint.from_raw(MOCK_UPDATE, namespace="foo")] * 3,
        repeat(InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo")),
    )

    # Health-check route always succeeds: get_session().get(...) returns HTTP 200.
    mock_get_session.return_value.get.return_value = Mock(status_code=200)

    # Short refresh interval keeps the test fast (~4 polls needed).
    endpoint.wait(refresh_every=0.01)
    assert endpoint.status == "running"
| 300 | + |
| 301 | + |
247 | 302 | @patch("huggingface_hub.hf_api.HfApi.pause_inference_endpoint")
|
248 | 303 | def test_pause(mock: Mock):
|
249 | 304 | """Test `pause` calls the correct alias."""
|
|
0 commit comments