Skip to content

Commit a717819

Browse files
author
Alexandre Duverger
committed
Log if handler service is using default or custom functions implementations
1 parent 80634b3 commit a717819

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

src/sagemaker_huggingface_inference_toolkit/handler_service.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@
3434

3535
ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
3636
PYTHON_PATH_ENV = "PYTHONPATH"
37+
MODEL_FN = "model_fn"
38+
INPUT_FN = "input_fn"
39+
PREDICT_FN = "predict_fn"
40+
OUTPUT_FN = "output_fn"
41+
TRANSFORM_FN = "transform_fn"
3742

3843
logger = logging.getLogger(__name__)
3944

@@ -272,35 +277,58 @@ def validate_and_initialize_user_module(self):
272277
"""
273278
user_module_name = self.environment.module_name
274279
if importlib.util.find_spec(user_module_name) is not None:
280+
logger.info("Inference script implementation found at `{}`.".format(user_module_name))
275281
user_module = importlib.import_module(user_module_name)
276282

277-
load_fn = getattr(user_module, "model_fn", None)
278-
preprocess_fn = getattr(user_module, "input_fn", None)
279-
predict_fn = getattr(user_module, "predict_fn", None)
280-
postprocess_fn = getattr(user_module, "output_fn", None)
281-
transform_fn = getattr(user_module, "transform_fn", None)
283+
load_fn = getattr(user_module, MODEL_FN, None)
284+
preprocess_fn = getattr(user_module, INPUT_FN, None)
285+
predict_fn = getattr(user_module, PREDICT_FN, None)
286+
postprocess_fn = getattr(user_module, OUTPUT_FN, None)
287+
transform_fn = getattr(user_module, TRANSFORM_FN, None)
282288

283289
if transform_fn and (preprocess_fn or predict_fn or postprocess_fn):
284290
raise ValueError(
285-
"Cannot use transform_fn implementation in conjunction with "
286-
"input_fn, predict_fn, and/or output_fn implementation"
291+
"Cannot use {} implementation in conjunction with {}, {}, and/or {} implementation".format(
292+
TRANSFORM_FN, INPUT_FN, PREDICT_FN, OUTPUT_FN
293+
)
287294
)
288-
295+
self.log_func_implementation_found_or_not(load_fn, MODEL_FN)
289296
if load_fn is not None:
290297
self.load_extra_arg = self.function_extra_arg(self.load, load_fn)
291298
self.load = load_fn
299+
self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN)
292300
if preprocess_fn is not None:
293301
self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn)
294302
self.preprocess = preprocess_fn
303+
self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN)
295304
if predict_fn is not None:
296305
self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn)
297306
self.predict = predict_fn
307+
self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN)
298308
if postprocess_fn is not None:
299309
self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn)
300310
self.postprocess = postprocess_fn
311+
self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN)
301312
if transform_fn is not None:
302313
self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn)
303314
self.transform_fn = transform_fn
315+
else:
316+
logger.info(
317+
"No inference script implementation was found at `{}`. Default implementation of all functions will be used.".format(
318+
user_module_name
319+
)
320+
)
321+
322+
@staticmethod
323+
def log_func_implementation_found_or_not(func, func_name):
324+
if func is not None:
325+
logger.info("`{}` implementation found. It will be used in place of the default one.".format(func_name))
326+
else:
327+
logger.info(
328+
"No `{}` implementation was found. The default one from the handler service will be used.".format(
329+
func_name
330+
)
331+
)
304332

305333
def function_extra_arg(self, default_func, func):
306334
"""Helper to call the handler function which covers 2 cases:

0 commit comments

Comments
 (0)