Skip to content

Commit 95b3877

Browse files
mxnet-sdk-team-mmsvdantu
authored andcommitted
Added tests to test the API changes
1 parent 28d7f3b commit 95b3877

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import os
2+
3+
4+
def model_fn(model_dir):
5+
return f"Loading {os.path.basename(__file__)}"
6+
7+
8+
def transform_fn(a, b, c, d):
9+
return f"output {b}"

tests/unit/test_handler_service.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import json
1515
import os
16+
import sys
1617
import tempfile
1718

1819
import pytest
@@ -117,7 +118,7 @@ def test_predict(inference_handler):
117118
model_dir=tmpdirname,
118119
)
119120
inference_handler.model = get_pipeline(task=TASK, device=-1, model_dir=storage_folder)
120-
prediction = inference_handler.predict(INPUT)
121+
prediction = inference_handler.predict(INPUT, inference_handler.model)
121122
assert "label" in prediction[0]
122123
assert "score" in prediction[0]
123124

@@ -128,9 +129,8 @@ def test_postprocess(inference_handler):
128129

129130

130131
def test_validate_and_initialize_user_module(inference_handler):
131-
model_dir = os.path.join(os.getcwd(), "tests/resources")
132+
model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn")
132133
CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4")
133-
inference_handler.environment
134134

135135
inference_handler.initialize(CONTEXT)
136136
CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
@@ -141,5 +141,19 @@ def test_validate_and_initialize_user_module(inference_handler):
141141

142142
assert inference_handler.load({}) == "model"
143143
assert inference_handler.preprocess({}, "") == "data"
144-
assert inference_handler.predict({}) == "output"
144+
assert inference_handler.predict({}, "model") == "output"
145145
assert inference_handler.postprocess("output", "") == "output"
146+
147+
148+
def test_validate_and_initialize_user_module_transform_fn():
149+
os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
150+
inference_handler = handler_service.HuggingFaceHandlerService()
151+
model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn")
152+
CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")
153+
154+
inference_handler.initialize(CONTEXT)
155+
CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
156+
CONTEXT.metrics = MetricsStore(1, MODEL)
157+
assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
158+
assert inference_handler.load({}) == "Loading inference_tranform_fn.py"
159+
assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy"

0 commit comments

Comments
 (0)