 import csv
 import json
 
+import mock
 import numpy as np
 import pytest
 import torch
@@ -40,7 +41,7 @@ def __call__(self, tensor):
         return 3 * tensor
 
 
-@pytest.fixture(scope='session', name='tensor')
+@pytest.fixture(scope="session", name="tensor")
 def fixture_tensor():
     tensor = torch.rand(5, 10, 7, 9)
     return tensor.to(device)
@@ -51,9 +52,14 @@ def inference_handler():
     return default_inference_handler.DefaultPytorchInferenceHandler()
 
 
+@pytest.fixture()
+def eia_inference_handler():
+    return default_inference_handler.DefaultPytorchInferenceHandler()
+
+
 def test_default_model_fn(inference_handler):
     with pytest.raises(NotImplementedError):
-        inference_handler.default_model_fn('model_dir')
+        inference_handler.default_model_fn("model_dir")
 
 
 def test_default_input_fn_json(inference_handler, tensor):
@@ -67,7 +73,7 @@ def test_default_input_fn_json(inference_handler, tensor):
 def test_default_input_fn_csv(inference_handler):
     array = [[1, 2, 3], [4, 5, 6]]
     str_io = StringIO()
-    csv.writer(str_io, delimiter=',').writerows(array)
+    csv.writer(str_io, delimiter=",").writerows(array)
 
     deserialized_np_array = inference_handler.default_input_fn(str_io.getvalue(), content_types.CSV)
 
@@ -78,7 +84,7 @@ def test_default_input_fn_csv(inference_handler):
 
 def test_default_input_fn_csv_bad_columns(inference_handler):
     str_io = StringIO()
-    csv_writer = csv.writer(str_io, delimiter=',')
+    csv_writer = csv.writer(str_io, delimiter=",")
     csv_writer.writerow([1, 2, 3])
     csv_writer.writerow([1, 2, 3, 4])
 
@@ -97,7 +103,7 @@ def test_default_input_fn_npy(inference_handler, tensor):
 
 def test_default_input_fn_bad_content_type(inference_handler):
     with pytest.raises(errors.UnsupportedFormatError):
-        inference_handler.default_input_fn('', 'application/not_supported')
+        inference_handler.default_input_fn("", "application/not_supported")
 
 
 def test_default_predict_fn(inference_handler, tensor):
@@ -162,7 +168,7 @@ def test_default_output_fn_csv_float(inference_handler):
 
 def test_default_output_fn_bad_accept(inference_handler):
     with pytest.raises(errors.UnsupportedFormatError):
-        inference_handler.default_output_fn('', 'application/not_supported')
+        inference_handler.default_output_fn("", "application/not_supported")
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
@@ -171,4 +177,34 @@ def test_default_output_fn_gpu(inference_handler):
 
     output = inference_handler.default_output_fn(tensor_gpu, content_types.CSV)
 
-    assert '1,2,3\n4,5,6\n'.encode("utf-8") == output
+    assert "1,2,3\n4,5,6\n".encode("utf-8") == output
+
+
+def test_eia_default_model_fn(eia_inference_handler):
+    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        mock_os.path.join.return_value = "model_dir"
+        mock_os.path.exists.return_value = True
+        with mock.patch("torch.jit.load") as mock_torch:
+            mock_torch.return_value = DummyModel()
+            model = eia_inference_handler.default_model_fn("model_dir")
+            assert model is not None
+
+
+def test_eia_default_model_fn_error(eia_inference_handler):
+    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        mock_os.path.join.return_value = "model_dir"
+        mock_os.path.exists.return_value = False
+        with pytest.raises(FileNotFoundError):
+            eia_inference_handler.default_model_fn("model_dir")
+
+
+def test_eia_default_predict_fn(eia_inference_handler, tensor):
+    model = DummyModel()
+    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        with mock.patch("torch.jit.optimized_execution") as mock_torch:
+            mock_torch.__enter__.return_value = "dummy"
+            eia_inference_handler.default_predict_fn(tensor, model)
+            mock_torch.assert_called_once()
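
Note: the new EIA tests assert only against mocks, so they imply rather than pin down the handler's behavior. A minimal sketch of handler logic consistent with the mocked calls above follows; the env-var name, model filename, and error message are illustrative assumptions, not the toolkit's actual identifiers.

# Hypothetical sketch inferred from the mocks in the tests above --
# not the actual default_inference_handler implementation.
import os
import torch

EIA_ENV_VAR = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"  # assumed name
DEFAULT_MODEL_FILENAME = "model.pt"                      # assumed name


def default_model_fn(model_dir):
    if os.getenv(EIA_ENV_VAR) == "true":
        model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
        if not os.path.exists(model_path):
            # test_eia_default_model_fn_error expects FileNotFoundError
            # when os.path.exists is mocked to False
            raise FileNotFoundError("missing TorchScript file in model_dir")
        # test_eia_default_model_fn patches torch.jit.load and expects
        # its return value to come back non-None
        return torch.jit.load(model_path)
    # non-EIA path: test_default_model_fn expects NotImplementedError
    raise NotImplementedError()


def default_predict_fn(data, model):
    if os.getenv(EIA_ENV_VAR) == "true":
        # test_eia_default_predict_fn asserts torch.jit.optimized_execution
        # is invoked exactly once around the forward pass
        with torch.no_grad(), torch.jit.optimized_execution(True):
            return model(data)
    with torch.no_grad():
        return model(data)

The tests themselves run under pytest as usual (e.g. pytest -k eia against this test module).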