From 297bce138deaa64b8bad25640db4fda466918a6c Mon Sep 17 00:00:00 2001
From: Btara Truhandarien
Date: Fri, 31 May 2024 16:06:41 -0400
Subject: [PATCH] feat(initialize): default to first GPU when gpu_id not provided

---
 .../handler_service.py                           |  5 ++++-
 tests/unit/test_handler_service_with_context.py | 16 +++++++++++++++-
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/src/sagemaker_huggingface_inference_toolkit/handler_service.py b/src/sagemaker_huggingface_inference_toolkit/handler_service.py
index e4ef6c6..00b9b34 100644
--- a/src/sagemaker_huggingface_inference_toolkit/handler_service.py
+++ b/src/sagemaker_huggingface_inference_toolkit/handler_service.py
@@ -99,7 +99,10 @@ def get_device(self):
         The get device function will return the device for the DL Framework.
         """
         if _is_gpu_available():
-            return int(self.context.system_properties.get("gpu_id"))
+            # gpu_id may not be provided in system_properties,
+            # in which case we default to the first GPU
+            gpu_id = self.context.system_properties.get("gpu_id") or 0
+            return int(gpu_id)
         else:
             return -1
 
diff --git a/tests/unit/test_handler_service_with_context.py b/tests/unit/test_handler_service_with_context.py
index a8b5b71..3d448b6 100644
--- a/tests/unit/test_handler_service_with_context.py
+++ b/tests/unit/test_handler_service_with_context.py
@@ -23,7 +23,7 @@
 from mms.metrics.metrics_store import MetricsStore
 from mock import Mock
 from sagemaker_huggingface_inference_toolkit import handler_service
-from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline
+from sagemaker_huggingface_inference_toolkit.transformers_utils import _is_gpu_available, _load_model_from_hub, get_pipeline
 
 
 TASK = "text-classification"
@@ -63,6 +63,20 @@ def test_test_initialize(inference_handler):
     inference_handler.initialize(CONTEXT)
     assert inference_handler.initialized is True
 
+@require_torch
+@pytest.mark.skipif(not _is_gpu_available(), reason="No GPU available")
+@slow
+def test_initialize_without_gpu_id_fallback_to_first_gpu(inference_handler):
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        storage_folder = _load_model_from_hub(
+            model_id=MODEL,
+            model_dir=tmpdirname,
+        )
+        CONTEXT = Context(MODEL, storage_folder, {}, 1, None, "1.1.4")
+
+        inference_handler.initialize(CONTEXT)
+        assert inference_handler.initialized is True
+        assert inference_handler.device == 0
 
 @require_torch
 def test_handle(inference_handler):
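
Note (illustration only, not part of the diff): a minimal standalone sketch of the
fallback behaviour the patched get_device() is intended to have, assuming
system_properties may omit "gpu_id" or carry None. The helper name resolve_gpu_id
is hypothetical and only mirrors the patched logic:

    # hypothetical helper mirroring the patched get_device() logic
    def resolve_gpu_id(system_properties, gpu_available):
        if not gpu_available:
            return -1
        # a missing or None "gpu_id" falls back to the first GPU (device 0)
        return int(system_properties.get("gpu_id") or 0)

    assert resolve_gpu_id({"gpu_id": "1"}, True) == 1   # explicit gpu_id is honoured
    assert resolve_gpu_id({}, True) == 0                # gpu_id absent -> device 0
    assert resolve_gpu_id({"gpu_id": None}, True) == 0  # gpu_id present but None -> device 0
    assert resolve_gpu_id({}, False) == -1              # no GPU available -> CPU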