@@ -43,14 +43,14 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session):
43
43
with timeout (minutes = 20 ):
44
44
estimator = PyTorch .attach (pytorch_training_job , sagemaker_session = sagemaker_session )
45
45
predictor = estimator .deploy (1 , 'ml.c4.xlarge' , endpoint_name = endpoint_name )
46
- data = numpy .zeros (shape = (1 , 1 , 28 , 28 ))
46
+ data = numpy .zeros (shape = (1 , 1 , 28 , 28 ), dtype = numpy . float32 )
47
47
predictor .predict (data )
48
48
49
49
batch_size = 100
50
- data = numpy .random .rand (batch_size , 1 , 28 , 28 )
50
+ data = numpy .random .rand (batch_size , 1 , 28 , 28 ). astype ( numpy . float32 )
51
51
output = predictor .predict (data )
52
52
53
- assert numpy . asarray ( output ) .shape == (batch_size , 10 )
53
+ assert output .shape == (batch_size , 10 )
54
54
55
55
56
56
def test_deploy_model (pytorch_training_job , sagemaker_session ):
@@ -63,10 +63,10 @@ def test_deploy_model(pytorch_training_job, sagemaker_session):
63
63
predictor = model .deploy (1 , 'ml.m4.xlarge' , endpoint_name = endpoint_name )
64
64
65
65
batch_size = 100
66
- data = numpy .random .rand (batch_size , 1 , 28 , 28 )
66
+ data = numpy .random .rand (batch_size , 1 , 28 , 28 ). astype ( numpy . float32 )
67
67
output = predictor .predict (data )
68
68
69
- assert numpy . asarray ( output ) .shape == (batch_size , 10 )
69
+ assert output .shape == (batch_size , 10 )
70
70
71
71
72
72
def test_async_fit_deploy (sagemaker_session , pytorch_full_version ):
@@ -92,10 +92,10 @@ def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
92
92
predictor = estimator .deploy (1 , instance_type , endpoint_name = endpoint_name )
93
93
94
94
batch_size = 100
95
- data = numpy .random .rand (batch_size , 1 , 28 , 28 )
95
+ data = numpy .random .rand (batch_size , 1 , 28 , 28 ). astype ( numpy . float32 )
96
96
output = predictor .predict (data )
97
97
98
- assert numpy . asarray ( output ) .shape == (batch_size , 10 )
98
+ assert output .shape == (batch_size , 10 )
99
99
100
100
101
101
# TODO(nadiaya): Run against local mode when errors will be propagated
0 commit comments