Skip to content

Commit 6e1e1a9

Browse files
committed
Revert "Removed unnecessary mnist scripts."
This reverts commit 7d84091.
1 parent 7d84091 commit 6e1e1a9

File tree

7 files changed

+159
-5
lines changed

7 files changed

+159
-5
lines changed

test-toolkit/integration/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424
eia_sub_dir = 'model_eia'
2525

2626
model_cpu_dir = os.path.join(mnist_path, cpu_sub_dir)
27+
mnist_cpu_script = os.path.join(model_cpu_dir, 'mnist.py')
2728
model_cpu_1d_dir = os.path.join(model_cpu_dir, '1d')
2829
mnist_1d_script = os.path.join(model_cpu_1d_dir, 'mnist_1d.py')
2930
model_gpu_dir = os.path.join(mnist_path, gpu_sub_dir)
31+
mnist_gpu_script = os.path.join(model_gpu_dir, 'mnist.py')
32+
model_gpu_1d_dir = os.path.join(model_gpu_dir, '1d')
3033
model_eia_dir = os.path.join(mnist_path, eia_sub_dir)
3134
mnist_eia_script = os.path.join(model_eia_dir, 'mnist.py')
3235
call_model_fn_once_script = os.path.join(resources_path, 'call_model_fn_once.py')

test-toolkit/integration/sagemaker/test_mnist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,22 @@
1919
import sagemaker
2020
from sagemaker.pytorch import PyTorchModel
2121

22-
from integration import model_cpu_dir, mnist_script, model_eia_dir, mnist_eia_script
22+
from integration import model_cpu_dir, mnist_cpu_script, mnist_gpu_script, model_eia_dir, mnist_eia_script
2323
from integration.sagemaker.timeout import timeout_and_delete_endpoint
2424

2525

2626
@pytest.mark.cpu_test
2727
def test_mnist_cpu(sagemaker_session, image_uri, instance_type):
2828
instance_type = instance_type or 'ml.c4.xlarge'
2929
model_dir = os.path.join(model_cpu_dir, 'model_mnist.tar.gz')
30-
_test_mnist_distributed(sagemaker_session, image_uri, instance_type, model_dir, mnist_script)
30+
_test_mnist_distributed(sagemaker_session, image_uri, instance_type, model_dir, mnist_cpu_script)
3131

3232

3333
@pytest.mark.gpu_test
3434
def test_mnist_gpu(sagemaker_session, image_uri, instance_type):
3535
instance_type = instance_type or 'ml.p2.xlarge'
3636
model_dir = os.path.join(model_cpu_dir, 'model_mnist.tar.gz')
37-
_test_mnist_distributed(sagemaker_session, image_uri, instance_type, model_dir, mnist_script)
37+
_test_mnist_distributed(sagemaker_session, image_uri, instance_type, model_dir, mnist_gpu_script)
3838

3939

4040
@pytest.mark.eia_test
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
18+
def model_fn(model_dir):
19+
lock_file = os.path.join(model_dir, 'model_fn.lock.{}'.format(os.getpid()))
20+
if os.path.exists(lock_file):
21+
raise RuntimeError('model_fn called more than once (lock: {})'.format(lock_file))
22+
23+
open(lock_file, 'a').close()
24+
25+
return 'model'
26+
27+
28+
def input_fn(data, content_type):
29+
return data
30+
31+
32+
def predict_fn(data, model):
33+
return b'output'
34+
35+
36+
def output_fn(prediction, accept):
37+
return prediction
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import logging
16+
import os
17+
import sys
18+
19+
import torch
20+
import torch.nn as nn
21+
import torch.nn.functional as F
22+
import torch.utils.data
23+
import torch.utils.data.distributed
24+
25+
logger = logging.getLogger(__name__)
26+
logger.setLevel(logging.DEBUG)
27+
logger.addHandler(logging.StreamHandler(sys.stdout))
28+
29+
30+
# Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
31+
class Net(nn.Module):
32+
def __init__(self):
33+
logger.info("Create neural network module")
34+
35+
super(Net, self).__init__()
36+
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
37+
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
38+
self.conv2_drop = nn.Dropout2d()
39+
self.fc1 = nn.Linear(320, 50)
40+
self.fc2 = nn.Linear(50, 10)
41+
42+
def forward(self, x):
43+
x = F.relu(F.max_pool2d(self.conv1(x), 2))
44+
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
45+
x = x.view(-1, 320)
46+
x = F.relu(self.fc1(x))
47+
x = F.dropout(x, training=self.training)
48+
x = self.fc2(x)
49+
return F.log_softmax(x, dim=1)
50+
51+
52+
def model_fn(model_dir):
53+
logger.info('model_fn')
54+
model = torch.nn.DataParallel(Net())
55+
with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
56+
model.load_state_dict(torch.load(f))
57+
return model

test-toolkit/resources/mnist/model_eia/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import logging
16+
import os
17+
import sys
18+
19+
import torch
20+
import torch.nn as nn
21+
import torch.nn.functional as F
22+
import torch.utils.data
23+
import torch.utils.data.distributed
24+
25+
logger = logging.getLogger(__name__)
26+
logger.setLevel(logging.DEBUG)
27+
logger.addHandler(logging.StreamHandler(sys.stdout))
28+
29+
30+
# Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
31+
class Net(nn.Module):
32+
def __init__(self):
33+
logger.info("Create neural network module")
34+
35+
super(Net, self).__init__()
36+
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
37+
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
38+
self.conv2_drop = nn.Dropout2d()
39+
self.fc1 = nn.Linear(320, 50)
40+
self.fc2 = nn.Linear(50, 10)
41+
42+
def forward(self, x):
43+
x = F.relu(F.max_pool2d(self.conv1(x), 2))
44+
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
45+
x = x.view(-1, 320)
46+
x = F.relu(self.fc1(x))
47+
x = F.dropout(x, training=self.training)
48+
x = self.fc2(x)
49+
return F.log_softmax(x, dim=1)
50+
51+
52+
def model_fn(model_dir):
53+
logger.info('model_fn')
54+
model = torch.nn.DataParallel(Net())
55+
with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
56+
model.load_state_dict(torch.load(f))
57+
return model

test-toolkit/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of

0 commit comments

Comments
 (0)