@@ -301,23 +301,29 @@ def validate_and_initialize_user_module(self):
301
301
)
302
302
self .log_func_implementation_found_or_not (load_fn , MODEL_FN )
303
303
if load_fn is not None :
304
- self .load_extra_arg = self .function_extra_arg (self .load , load_fn )
304
+ self .load_extra_arg = self .function_extra_arg (HuggingFaceHandlerService .load , load_fn )
305
305
self .load = load_fn
306
306
self .log_func_implementation_found_or_not (preprocess_fn , INPUT_FN )
307
307
if preprocess_fn is not None :
308
- self .preprocess_extra_arg = self .function_extra_arg (self .preprocess , preprocess_fn )
308
+ self .preprocess_extra_arg = self .function_extra_arg (
309
+ HuggingFaceHandlerService .preprocess , preprocess_fn
310
+ )
309
311
self .preprocess = preprocess_fn
310
312
self .log_func_implementation_found_or_not (predict_fn , PREDICT_FN )
311
313
if predict_fn is not None :
312
- self .predict_extra_arg = self .function_extra_arg (self .predict , predict_fn )
314
+ self .predict_extra_arg = self .function_extra_arg (HuggingFaceHandlerService .predict , predict_fn )
313
315
self .predict = predict_fn
314
316
self .log_func_implementation_found_or_not (postprocess_fn , OUTPUT_FN )
315
317
if postprocess_fn is not None :
316
- self .postprocess_extra_arg = self .function_extra_arg (self .postprocess , postprocess_fn )
318
+ self .postprocess_extra_arg = self .function_extra_arg (
319
+ HuggingFaceHandlerService .postprocess , postprocess_fn
320
+ )
317
321
self .postprocess = postprocess_fn
318
322
self .log_func_implementation_found_or_not (transform_fn , TRANSFORM_FN )
319
323
if transform_fn is not None :
320
- self .transform_extra_arg = self .function_extra_arg (self .transform_fn , transform_fn )
324
+ self .transform_extra_arg = self .function_extra_arg (
325
+ HuggingFaceHandlerService .transform_fn , transform_fn
326
+ )
321
327
self .transform_fn = transform_fn
322
328
else :
323
329
logger .info (
@@ -342,8 +348,15 @@ def function_extra_arg(self, default_func, func):
342
348
1. the handle function takes context
343
349
2. the handle function does not take context
344
350
"""
345
- num_default_func_input = len (signature (default_func ).parameters )
346
- num_func_input = len (signature (func ).parameters )
351
+ default_params = signature (default_func ).parameters
352
+ func_params = signature (func ).parameters
353
+
354
+ if "self" in default_params :
355
+ num_default_func_input = len (default_params ) - 1
356
+ else :
357
+ num_default_func_input = len (default_params )
358
+
359
+ num_func_input = len (func_params )
347
360
if num_default_func_input == num_func_input :
348
361
# function takes context
349
362
extra_args = [self .context ]
0 commit comments