11 | 11 | # ANY KIND, either express or implied. See the License for the specific
12 | 12 | # language governing permissions and limitations under the License.
13 | 13 | from __future__ import absolute_import
14 |    | -import argparse
   | 14 | +
15 | 15 | import logging
16 | 16 | import os
17 | 17 | import sys
18 | 18 |
19 |    | -import sagemaker_containers
20 | 19 | import torch
21 |    | -import torch.distributed as dist
22 | 20 | import torch.nn as nn
23 | 21 | import torch.nn.functional as F
24 |    | -import torch.optim as optim
25 | 22 | import torch.utils.data
26 | 23 | import torch.utils.data.distributed
27 |    | -from torchvision import datasets, transforms
28 | 24 |
29 | 25 | logger = logging.getLogger(__name__)
30 | 26 | logger.setLevel(logging.DEBUG)
@@ -53,171 +49,9 @@ def forward(self, x):
53 | 49 |         return F.log_softmax(x, dim=1)
54 | 50 |
55 | 51 |
56 |    | -def _get_train_data_loader(batch_size, training_dir, is_distributed, **kwargs):
57 |    | -    logger.info("Get train data loader")
58 |    | -    dataset = datasets.MNIST(training_dir, train=True, transform=transforms.Compose([
59 |    | -        transforms.ToTensor(),
60 |    | -        transforms.Normalize((0.1307,), (0.3081,))
61 |    | -    ]))
62 |    | -    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
63 |    | -    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=train_sampler is None,
64 |    | -                                       sampler=train_sampler, **kwargs)
65 |    | -
66 |    | -
67 |    | -def _get_test_data_loader(test_batch_size, training_dir, **kwargs):
68 |    | -    logger.info("Get test data loader")
69 |    | -    return torch.utils.data.DataLoader(
70 |    | -        datasets.MNIST(training_dir, train=False, transform=transforms.Compose([
71 |    | -            transforms.ToTensor(),
72 |    | -            transforms.Normalize((0.1307,), (0.3081,))
73 |    | -        ])),
74 |    | -        batch_size=test_batch_size, shuffle=True, **kwargs)
75 |    | -
76 |    | -
77 |    | -def _average_gradients(model):
78 |    | -    # Gradient averaging.
79 |    | -    size = float(dist.get_world_size())
80 |    | -    for param in model.parameters():
81 |    | -        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
82 |    | -        param.grad.data /= size
83 |    | -
84 |    | -
85 |    | -def train(args):
86 |    | -    is_distributed = len(args.hosts) > 1 and args.backend is not None
87 |    | -    logger.debug("Distributed training - {}".format(is_distributed))
88 |    | -    use_cuda = (args.processor == 'gpu') or (args.num_gpus > 0)
89 |    | -    logger.debug("Number of gpus available - {}".format(args.num_gpus))
90 |    | -    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
91 |    | -    device = torch.device("cuda" if use_cuda else "cpu")
92 |    | -
93 |    | -    if is_distributed:
94 |    | -        # Initialize the distributed environment.
95 |    | -        world_size = len(args.hosts)
96 |    | -        os.environ['WORLD_SIZE'] = str(world_size)
97 |    | -        host_rank = args.hosts.index(args.current_host)
98 |    | -        os.environ['RANK'] = str(host_rank)
99 |    | -        dist.init_process_group(backend=args.backend, rank=host_rank, world_size=world_size)
100 |    | -        logger.info('Initialized the distributed environment: \'{}\' backend on {} nodes. '.format(
101 |    | -            args.backend, dist.get_world_size()) + 'Current host rank is {}. Number of gpus: {}'.format(
102 |    | -            dist.get_rank(), args.num_gpus))
103 |    | -
104 |    | -    # set the seed for generating random numbers
105 |    | -    torch.manual_seed(args.seed)
106 |    | -    if use_cuda:
107 |    | -        torch.cuda.manual_seed(args.seed)
108 |    | -
109 |    | -    train_loader = _get_train_data_loader(args.batch_size, args.data_dir, is_distributed, **kwargs)
110 |    | -    test_loader = _get_test_data_loader(args.test_batch_size, args.data_dir, **kwargs)
111 |    | -
112 |    | -    # TODO: assert the logs when we move to the SDK local mode
113 |    | -    logger.debug("Processes {}/{} ({:.0f}%) of train data".format(
114 |    | -        len(train_loader.sampler), len(train_loader.dataset),
115 |    | -        100. * len(train_loader.sampler) / len(train_loader.dataset)
116 |    | -    ))
117 |    | -
118 |    | -    logger.debug("Processes {}/{} ({:.0f}%) of test data".format(
119 |    | -        len(test_loader.sampler), len(test_loader.dataset),
120 |    | -        100. * len(test_loader.sampler) / len(test_loader.dataset)
121 |    | -    ))
122 |    | -
123 |    | -    model = Net().to(device)
124 |    | -    if is_distributed and use_cuda:
125 |    | -        # multi-machine multi-gpu case
126 |    | -        logger.debug("Multi-machine multi-gpu: using DistributedDataParallel.")
127 |    | -        model = torch.nn.parallel.DistributedDataParallel(model)
128 |    | -    elif use_cuda:
129 |    | -        # single-machine multi-gpu case
130 |    | -        logger.debug("Single-machine multi-gpu: using DataParallel().cuda().")
131 |    | -        model = torch.nn.DataParallel(model).to(device)
132 |    | -    else:
133 |    | -        # single-machine or multi-machine cpu case
134 |    | -        logger.debug("Single-machine/multi-machine cpu: using DataParallel.")
135 |    | -        model = torch.nn.DataParallel(model)
136 |    | -
137 |    | -    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
138 |    | -
139 |    | -    for epoch in range(1, args.epochs + 1):
140 |    | -        model.train()
141 |    | -        for batch_idx, (data, target) in enumerate(train_loader, 1):
142 |    | -            data, target = data.to(device), target.to(device)
143 |    | -            optimizer.zero_grad()
144 |    | -            output = model(data)
145 |    | -            loss = F.nll_loss(output, target)
146 |    | -            loss.backward()
147 |    | -            if is_distributed and not use_cuda:
148 |    | -                # average gradients manually for multi-machine cpu case only
149 |    | -                _average_gradients(model)
150 |    | -            optimizer.step()
151 |    | -            if batch_idx % args.log_interval == 0:
152 |    | -                logger.debug('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
153 |    | -                    epoch, batch_idx * len(data), len(train_loader.sampler),
154 |    | -                    100. * batch_idx / len(train_loader), loss.item()))
155 |    | -        test(model, test_loader, device)
156 |    | -    save_model(model, args.model_dir)
157 |    | -
158 |    | -
159 |    | -def test(model, test_loader, device):
160 |    | -    model.eval()
161 |    | -    test_loss = 0
162 |    | -    correct = 0
163 |    | -    with torch.no_grad():
164 |    | -        for data, target in test_loader:
165 |    | -            data, target = data.to(device), target.to(device)
166 |    | -            output = model(data)
167 |    | -            test_loss += F.nll_loss(output, target, size_average=None).item()  # sum up batch loss
168 |    | -            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
169 |    | -            correct += pred.eq(target.view_as(pred)).sum().item()
170 |    | -
171 |    | -    test_loss /= len(test_loader.dataset)
172 |    | -    logger.debug('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
173 |    | -        test_loss, correct, len(test_loader.dataset),
174 |    | -        100. * correct / len(test_loader.dataset)))
175 |    | -
176 |    | -
177 | 52 | def model_fn(model_dir):
178 | 53 |     logger.info('model_fn')
179 | 54 |     model = torch.nn.DataParallel(Net())
180 | 55 |     with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
181 | 56 |         model.load_state_dict(torch.load(f))
182 | 57 |     return model
183 |    | -
184 |    | -
185 |    | -def save_model(model, model_dir):
186 |    | -    logger.info("Saving the model.")
187 |    | -    path = os.path.join(model_dir, 'model.pth')
188 |    | -    # recommended way from http://pytorch.org/docs/master/notes/serialization.html
189 |    | -    torch.save(model.state_dict(), path)
190 |    | -
191 |    | -
192 |    | -if __name__ == '__main__':
193 |    | -    parser = argparse.ArgumentParser()
194 |    | -
195 |    | -    # Data and model checkpoints directories
196 |    | -    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
197 |    | -                        help='input batch size for training (default: 64)')
198 |    | -    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
199 |    | -                        help='input batch size for testing (default: 1000)')
200 |    | -    parser.add_argument('--epochs', type=int, default=1, metavar='N',
201 |    | -                        help='number of epochs to train (default: 1)')
202 |    | -    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
203 |    | -                        help='learning rate (default: 0.01)')
204 |    | -    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
205 |    | -                        help='SGD momentum (default: 0.5)')
206 |    | -    parser.add_argument('--seed', type=int, default=1, metavar='S',
207 |    | -                        help='random seed (default: 1)')
208 |    | -    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
209 |    | -                        help='how many batches to wait before logging training status')
210 |    | -    parser.add_argument('--backend', type=str, default=None,
211 |    | -                        help='backend for distributed training')
212 |    | -    parser.add_argument('--processor', type=str, default='cpu',
213 |    | -                        help='processor type used for training (cpu or gpu)')
214 |    | -
215 |    | -    # Container environment
216 |    | -    env = sagemaker_containers.training_env()
217 |    | -    parser.add_argument('--hosts', type=list, default=env.hosts)
218 |    | -    parser.add_argument('--current-host', type=str, default=env.current_host)
219 |    | -    parser.add_argument('--model-dir', type=str, default=env.model_dir)
220 |    | -    parser.add_argument('--data-dir', type=str, default=env.channel_input_dirs['training'])
221 |    | -    parser.add_argument('--num-gpus', type=int, default=env.num_gpus)
222 |    | -
223 |    | -    train(parser.parse_args())
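
Note: this change removes the training entry point, including `save_model`, while keeping `model_fn`, which still loads a `model.pth` whose state dict was saved from a `torch.nn.DataParallel`-wrapped `Net`. A minimal sketch of a compatible save step, carried over from the removed `save_model`; whatever training loop produces `model` is assumed, not shown:

```python
import os

import torch


def save_model(model, model_dir):
    # Persist only the state_dict; model_fn() reads this file back with
    # torch.load() and load_state_dict() on a DataParallel-wrapped Net.
    path = os.path.join(model_dir, 'model.pth')
    torch.save(model.state_dict(), path)


# Usage sketch: wrap the trained network the same way model_fn() does, so the
# saved parameter names carry the 'module.' prefix that DataParallel adds.
# save_model(torch.nn.DataParallel(net), model_dir)
```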
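For reference, the removed `_average_gradients` helper averaged gradients across hosts in the multi-machine CPU case: all-reduce each parameter gradient, then divide by the world size so every worker holds the mean gradient. A minimal sketch of the same idea written against the current `torch.distributed` spelling (`ReduceOp.SUM` rather than the older `reduce_op.SUM`); the `None`-gradient guard is an addition here, not part of the original helper:

```python
import torch.distributed as dist


def average_gradients(model):
    # Assumes dist.init_process_group() has already been called on every worker.
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            continue
        # Sum the gradient across all workers, then divide by the world size so
        # each worker applies the mean gradient in optimizer.step().
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= world_size
```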