
Commit 92b57dd

Merge pull request #136 from benieric/main-fix-extra-args
fix: correctly handle method parameter counting in function_extra_arg
2 parents 5a7519d + a204e71 commit 92b57dd

File tree

3 files changed: +57 -7 lines


src/sagemaker_huggingface_inference_toolkit/handler_service.py

Lines changed: 20 additions & 7 deletions
```diff
@@ -301,23 +301,29 @@ def validate_and_initialize_user_module(self):
             )
             self.log_func_implementation_found_or_not(load_fn, MODEL_FN)
             if load_fn is not None:
-                self.load_extra_arg = self.function_extra_arg(self.load, load_fn)
+                self.load_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.load, load_fn)
                 self.load = load_fn
             self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN)
             if preprocess_fn is not None:
-                self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn)
+                self.preprocess_extra_arg = self.function_extra_arg(
+                    HuggingFaceHandlerService.preprocess, preprocess_fn
+                )
                 self.preprocess = preprocess_fn
             self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN)
             if predict_fn is not None:
-                self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn)
+                self.predict_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.predict, predict_fn)
                 self.predict = predict_fn
             self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN)
             if postprocess_fn is not None:
-                self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn)
+                self.postprocess_extra_arg = self.function_extra_arg(
+                    HuggingFaceHandlerService.postprocess, postprocess_fn
+                )
                 self.postprocess = postprocess_fn
             self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN)
             if transform_fn is not None:
-                self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn)
+                self.transform_extra_arg = self.function_extra_arg(
+                    HuggingFaceHandlerService.transform_fn, transform_fn
+                )
                 self.transform_fn = transform_fn
         else:
             logger.info(
@@ -342,8 +348,15 @@ def function_extra_arg(self, default_func, func):
        1. the handle function takes context
        2. the handle function does not take context
        """
-        num_default_func_input = len(signature(default_func).parameters)
-        num_func_input = len(signature(func).parameters)
+        default_params = signature(default_func).parameters
+        func_params = signature(func).parameters
+
+        if "self" in default_params:
+            num_default_func_input = len(default_params) - 1
+        else:
+            num_default_func_input = len(default_params)
+
+        num_func_input = len(func_params)
         if num_default_func_input == num_func_input:
             # function takes context
             extra_args = [self.context]
```
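The first hunk switches every function_extra_arg call from a bound method (self.load, self.preprocess, ...) to the plain class attribute (HuggingFaceHandlerService.load, ...), and the second hunk drops the explicit self parameter before comparing counts. This matters because initialize() overwrites those instance attributes with the user's functions: on a repeated initialize(), the old code compared the user function with itself, always saw matching counts, and wrongly concluded the function accepts a context argument. The sketch below is not part of the PR; the Handler class, its predict signature, and user_predict are invented purely to illustrate how inspect.signature counts parameters for bound versus unbound callables.

```python
# Minimal sketch (not from this repository) of the behaviour the patch relies on.
from inspect import signature


class Handler:
    def predict(self, data, model, context=None):
        """Default implementation: three parameters besides self."""


handler = Handler()

print(len(signature(handler.predict).parameters))   # 3 -- bound method, "self" already consumed
print(len(signature(Handler.predict).parameters))   # 4 -- class-level function, includes "self"


def user_predict(data, model):
    """User override that does not accept a context argument."""


# Old approach: compare against the (re)bound attribute. After a first
# initialize() has replaced handler.predict with user_predict, a second pass
# compares the user function with itself, always sees equal counts, and so
# wrongly decides the user function takes a context argument.
handler.predict = user_predict
print(len(signature(handler.predict).parameters)
      == len(signature(user_predict).parameters))   # True

# Patched approach: always count the class-level default and drop "self";
# the result is stable no matter how often initialize() runs.
default_params = signature(Handler.predict).parameters
num_default = len(default_params) - 1 if "self" in default_params else len(default_params)
print(num_default, len(signature(user_predict).parameters))  # 3 2 -> counts differ, no context passed
```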

tests/unit/test_handler_service_with_context.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -166,3 +166,23 @@ def test_validate_and_initialize_user_module_transform_fn():
         inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
         == "output dummy"
     )
+
+
+def test_validate_and_initialize_user_module_transform_fn_race_condition():
+    os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
+    inference_handler = handler_service.HuggingFaceHandlerService()
+    model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context")
+    CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")
+
+    # Similuate 2 threads bypassing check in handle() - calling initialize twice
+    inference_handler.initialize(CONTEXT)
+    inference_handler.initialize(CONTEXT)
+
+    CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
+    CONTEXT.metrics = MetricsStore(1, MODEL)
+    assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
+    assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py"
+    assert (
+        inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
+        == "output dummy"
+    )
```
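The new test calls initialize() twice on the same handler to reproduce deterministically what two concurrent requests could do when both slip past the initialization guard in handle(). The snippet below is only an illustration of that check-then-act pattern; the class and attribute names are invented and not taken from the toolkit.

```python
# Toy stand-in (names invented) for a handler with a lazy-initialization guard.
class RacyService:
    def __init__(self):
        self.initialized = False

    def initialize(self, context):
        # resolve user-module overrides, count parameters, load the model, ...
        self.initialized = True

    def handle(self, data, context):
        if not self.initialized:       # threads A and B can both observe False here
            self.initialize(context)   # ...and both run initialize()
        return ["output"]
```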

tests/unit/test_handler_service_without_context.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -154,3 +154,20 @@ def test_validate_and_initialize_user_module_transform_fn():
     assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
     assert inference_handler.load({}) == "Loading inference_tranform_fn.py"
     assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy"
+
+
+def test_validate_and_initialize_user_module_transform_fn_race_condition():
+    os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
+    inference_handler = handler_service.HuggingFaceHandlerService()
+    model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context")
+    CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")
+
+    # Similuate 2 threads bypassing check in handle() - calling initialize twice
+    inference_handler.initialize(CONTEXT)
+    inference_handler.initialize(CONTEXT)
+
+    CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
+    CONTEXT.metrics = MetricsStore(1, MODEL)
+    assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
+    assert inference_handler.load({}) == "Loading inference_tranform_fn.py"
+    assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy"
```
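The two fixtures differ only in whether the user transform_fn accepts a trailing context argument, which is exactly what the corrected parameter counting has to detect. The sketch below shows the two signature shapes implied by the assertions above; the bodies are illustrative and not copied from the test resources.

```python
# Without-context variant (exercised by tests/unit/test_handler_service_without_context.py):
def transform_fn(model, input_data, content_type, accept):
    return "output " + input_data


# With-context variant (exercised by tests/unit/test_handler_service_with_context.py);
# in its own inference script it would also be named transform_fn. Because
# function_extra_arg concludes it takes a context argument, the toolkit
# appends the MMS Context as the final positional argument.
def transform_fn_with_context(model, input_data, content_type, accept, context):
    return "output " + input_data
```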
