24
24
25
25
import tests .integ .lock as lock
26
26
from tests .integ import DATA_DIR
27
+ from mock import Mock , ANY
27
28
28
29
from sagemaker import image_uris
29
30
@@ -221,6 +222,13 @@ def test_mxnet_local_data_local_script(
221
222
):
222
223
data_path = os .path .join (DATA_DIR , "mxnet_mnist" )
223
224
script_path = os .path .join (data_path , "mnist.py" )
225
+ local_no_s3_session = LocalNoS3Session ()
226
+ local_no_s3_session .boto_session .resource = Mock (
227
+ side_effect = local_no_s3_session .boto_session .resource
228
+ )
229
+ local_no_s3_session .boto_session .client = Mock (
230
+ side_effect = local_no_s3_session .boto_session .client
231
+ )
224
232
225
233
mx = MXNet (
226
234
entry_point = script_path ,
@@ -229,7 +237,7 @@ def test_mxnet_local_data_local_script(
229
237
instance_type = "local" ,
230
238
framework_version = mxnet_training_latest_version ,
231
239
py_version = mxnet_training_latest_py_version ,
232
- sagemaker_session = LocalNoS3Session () ,
240
+ sagemaker_session = local_no_s3_session ,
233
241
)
234
242
235
243
train_input = "file://" + os .path .join (data_path , "train" )
@@ -243,6 +251,11 @@ def test_mxnet_local_data_local_script(
243
251
predictor = mx .deploy (1 , "local" , endpoint_name = endpoint_name )
244
252
data = numpy .zeros (shape = (1 , 1 , 28 , 28 ))
245
253
predictor .predict (data )
254
+ # check if no boto_session s3 calls were made
255
+ with pytest .raises (AssertionError ):
256
+ local_no_s3_session .boto_session .resource .assert_called_with ("s3" , region_name = ANY )
257
+ with pytest .raises (AssertionError ):
258
+ local_no_s3_session .boto_session .client .assert_called_with ("s3" , region_name = ANY )
246
259
finally :
247
260
predictor .delete_endpoint ()
248
261
0 commit comments