|
34 | 34 |
|
35 | 35 | ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
|
36 | 36 | PYTHON_PATH_ENV = "PYTHONPATH"
|
| 37 | +MODEL_FN = "model_fn" |
| 38 | +INPUT_FN = "input_fn" |
| 39 | +PREDICT_FN = "predict_fn" |
| 40 | +OUTPUT_FN = "output_fn" |
| 41 | +TRANSFORM_FN = "transform_fn" |
37 | 42 |
|
38 | 43 | logger = logging.getLogger(__name__)
|
39 | 44 |
|
@@ -272,35 +277,58 @@ def validate_and_initialize_user_module(self):
|
272 | 277 | """
|
273 | 278 | user_module_name = self.environment.module_name
|
274 | 279 | if importlib.util.find_spec(user_module_name) is not None:
|
| 280 | + logger.info("Inference script implementation found at `{}`.".format(user_module_name)) |
275 | 281 | user_module = importlib.import_module(user_module_name)
|
276 | 282 |
|
277 |
| - load_fn = getattr(user_module, "model_fn", None) |
278 |
| - preprocess_fn = getattr(user_module, "input_fn", None) |
279 |
| - predict_fn = getattr(user_module, "predict_fn", None) |
280 |
| - postprocess_fn = getattr(user_module, "output_fn", None) |
281 |
| - transform_fn = getattr(user_module, "transform_fn", None) |
| 283 | + load_fn = getattr(user_module, MODEL_FN, None) |
| 284 | + preprocess_fn = getattr(user_module, INPUT_FN, None) |
| 285 | + predict_fn = getattr(user_module, PREDICT_FN, None) |
| 286 | + postprocess_fn = getattr(user_module, OUTPUT_FN, None) |
| 287 | + transform_fn = getattr(user_module, TRANSFORM_FN, None) |
282 | 288 |
|
283 | 289 | if transform_fn and (preprocess_fn or predict_fn or postprocess_fn):
|
284 | 290 | raise ValueError(
|
285 |
| - "Cannot use transform_fn implementation in conjunction with " |
286 |
| - "input_fn, predict_fn, and/or output_fn implementation" |
| 291 | + "Cannot use {} implementation in conjunction with {}, {}, and/or {} implementation".format( |
| 292 | + TRANSFORM_FN, INPUT_FN, PREDICT_FN, OUTPUT_FN |
| 293 | + ) |
287 | 294 | )
|
288 |
| - |
| 295 | + self.log_func_implementation_found_or_not(load_fn, MODEL_FN) |
289 | 296 | if load_fn is not None:
|
290 | 297 | self.load_extra_arg = self.function_extra_arg(self.load, load_fn)
|
291 | 298 | self.load = load_fn
|
| 299 | + self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN) |
292 | 300 | if preprocess_fn is not None:
|
293 | 301 | self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn)
|
294 | 302 | self.preprocess = preprocess_fn
|
| 303 | + self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN) |
295 | 304 | if predict_fn is not None:
|
296 | 305 | self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn)
|
297 | 306 | self.predict = predict_fn
|
| 307 | + self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN) |
298 | 308 | if postprocess_fn is not None:
|
299 | 309 | self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn)
|
300 | 310 | self.postprocess = postprocess_fn
|
| 311 | + self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN) |
301 | 312 | if transform_fn is not None:
|
302 | 313 | self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn)
|
303 | 314 | self.transform_fn = transform_fn
|
| 315 | + else: |
| 316 | + logger.info( |
| 317 | + "No inference script implementation was found at `{}`. Default implementation of all functions will be used.".format( |
| 318 | + user_module_name |
| 319 | + ) |
| 320 | + ) |
| 321 | + |
| 322 | + @staticmethod |
| 323 | + def log_func_implementation_found_or_not(func, func_name): |
| 324 | + if func is not None: |
| 325 | + logger.info("`{}` implementation found. It will be used in place of the default one.".format(func_name)) |
| 326 | + else: |
| 327 | + logger.info( |
| 328 | + "No `{}` implementation was found. The default one from the handler service will be used.".format( |
| 329 | + func_name |
| 330 | + ) |
| 331 | + ) |
304 | 332 |
|
305 | 333 | def function_extra_arg(self, default_func, func):
|
306 | 334 | """Helper to call the handler function which covers 2 cases:
|
|
0 commit comments