13
13
# limitations under the License.
14
14
import json
15
15
import os
16
+ import sys
16
17
import tempfile
17
18
18
19
import pytest
@@ -117,7 +118,7 @@ def test_predict(inference_handler):
117
118
model_dir = tmpdirname ,
118
119
)
119
120
inference_handler .model = get_pipeline (task = TASK , device = - 1 , model_dir = storage_folder )
120
- prediction = inference_handler .predict (INPUT )
121
+ prediction = inference_handler .predict (INPUT , inference_handler . model )
121
122
assert "label" in prediction [0 ]
122
123
assert "score" in prediction [0 ]
123
124
@@ -128,9 +129,8 @@ def test_postprocess(inference_handler):
128
129
129
130
130
131
def test_validate_and_initialize_user_module (inference_handler ):
131
- model_dir = os .path .join (os .getcwd (), "tests/resources" )
132
+ model_dir = os .path .join (os .getcwd (), "tests/resources/model_input_predict_output_fn " )
132
133
CONTEXT = Context ("" , model_dir , {}, 1 , - 1 , "1.1.4" )
133
- inference_handler .environment
134
134
135
135
inference_handler .initialize (CONTEXT )
136
136
CONTEXT .request_processor = [RequestProcessor ({"Content-Type" : "application/json" })]
@@ -141,5 +141,19 @@ def test_validate_and_initialize_user_module(inference_handler):
141
141
142
142
assert inference_handler .load ({}) == "model"
143
143
assert inference_handler .preprocess ({}, "" ) == "data"
144
- assert inference_handler .predict ({}) == "output"
144
+ assert inference_handler .predict ({}, "model" ) == "output"
145
145
assert inference_handler .postprocess ("output" , "" ) == "output"
146
+
147
+
148
+ def test_validate_and_initialize_user_module_transform_fn ():
149
+ os .environ ["SAGEMAKER_PROGRAM" ] = "inference_tranform_fn.py"
150
+ inference_handler = handler_service .HuggingFaceHandlerService ()
151
+ model_dir = os .path .join (os .getcwd (), "tests/resources/model_transform_fn" )
152
+ CONTEXT = Context ("dummy" , model_dir , {}, 1 , - 1 , "1.1.4" )
153
+
154
+ inference_handler .initialize (CONTEXT )
155
+ CONTEXT .request_processor = [RequestProcessor ({"Content-Type" : "application/json" })]
156
+ CONTEXT .metrics = MetricsStore (1 , MODEL )
157
+ assert "output" in inference_handler .handle ([{"body" : b"dummy" }], CONTEXT )[0 ]
158
+ assert inference_handler .load ({}) == "Loading inference_tranform_fn.py"
159
+ assert inference_handler .transform_fn ("model" , "dummy" , "application/json" , "application/json" ) == "output dummy"
0 commit comments