
Commit fedc1c2

Add integ tests and remove dev changes
1 parent 6b33897 commit fedc1c2

File tree

5 files changed: +306 −26 lines changed


VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.100.1.dev0
+2.99.1.dev0

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 21 deletions
@@ -207,26 +207,7 @@ def __init__(
         )
         self.framework_version = framework_version
         self.py_version = py_version
-
-        if distribution is not None:
-            instance_type = renamed_kwargs(
-                "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
-            )
-
-            validate_smdistributed(
-                instance_type=instance_type,
-                framework_name=self._framework_name,
-                framework_version=framework_version,
-                py_version=py_version,
-                distribution=distribution,
-                image_uri=image_uri,
-            )
-
-            warn_if_parameter_server_with_multi_gpu(
-                training_instance_type=instance_type, distribution=distribution
-            )
-
-            self.instance_type = instance_type
+        self.instance_type = instance_type

         if "enable_sagemaker_metrics" not in kwargs:
             # enable sagemaker metrics for PT v1.3 or greater:

@@ -260,7 +241,6 @@ def _pytorch_distribution_configuration(self, distribution):
         if "pytorchddp" in distribution:
             pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
             distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
-            LOGGER.info("viskaria Setting instance type to: %s", self.instance_type)
             distribution_config[self.INSTANCE_TYPE] = self.instance_type
         else:
             distribution_config = self._distribution_configuration(distribution=distribution)

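With the smdistributed validation call and the leftover "viskaria" debug log removed, the pytorchddp branch of _pytorch_distribution_configuration only records the enabled flag and the training instance type. A minimal usage sketch of the path this commit exercises (the role name, framework/py versions, and S3 input below are illustrative assumptions, not part of the commit; the integ test added later in this commit does the same with real fixtures):

from sagemaker.pytorch import PyTorch

# Illustrative values only: the role, framework_version, py_version, and S3 input
# channel are assumptions for the sketch, not values taken from this commit.
estimator = PyTorch(
    entry_point="mnist_pt.py",
    source_dir="tests/data/pytorch_ddp",
    role="SageMakerRole",
    instance_count=2,
    instance_type="ml.p3.16xlarge",
    framework_version="1.12.0",
    py_version="py38",
    distribution={"pytorchddp": {"enabled": True}},  # native PyTorch DDP launcher
)
estimator.fit({"training": "s3://my-bucket/mnist"})
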
tests/data/pytorch_ddp/mnist_pt.py

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import print_function

import argparse
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
import smdistributed.dataparallel.torch.distributed as dist

dist.init_process_group()


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0 and args.rank == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data) * args.world_size,
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
        if args.verbose:
            print("Batch", batch_idx, "from rank", args.rank)


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )


def main():
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=14,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)"
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--save-model", action="store_true", default=False, help="For Saving the current Model"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="For displaying SM Distributed Data Parallel-specific logs",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        default=os.environ["SM_CHANNEL_TRAINING"],
        help="Path for downloading the MNIST dataset",
    )

    args = parser.parse_args()
    args.world_size = dist.get_world_size()
    args.rank = rank = dist.get_rank()
    args.local_rank = local_rank = dist.get_local_rank()
    args.lr = 1.0
    args.batch_size //= args.world_size // 8
    args.batch_size = max(args.batch_size, 1)
    data_path = args.data_path

    if args.verbose:
        print(
            "Hello from rank",
            rank,
            "of local_rank",
            local_rank,
            "in world size of",
            args.world_size,
        )

    if not torch.cuda.is_available():
        raise Exception(
            "Must run SM Distributed Data Parallel MNIST example on CUDA-capable devices."
        )

    torch.manual_seed(args.seed)

    device = torch.device("cuda")

    if local_rank == 0:
        train_dataset = datasets.MNIST(
            data_path,
            train=True,
            download=False,  # True sets a dependency on an external site for our tests.
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
    else:
        time.sleep(8)
        train_dataset = datasets.MNIST(
            data_path,
            train=True,
            download=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=train_sampler,
    )
    if rank == 0:
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                data_path,
                train=False,
                transform=transforms.Compose(
                    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
                ),
            ),
            batch_size=args.test_batch_size,
            shuffle=True,
        )

    model = DDP(Net().to(device))
    torch.cuda.set_device(local_rank)
    model.cuda(local_rank)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        if rank == 0:
            test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == "__main__":
    main()

tests/integ/test_pytorchddp.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

import pytest

import sagemaker.utils
import tests.integ as integ

from sagemaker.pytorch import PyTorch
from tests.integ import timeout
from tests.integ.test_pytorch import _upload_training_data

pytorchddp_dir = os.path.join(os.path.dirname(__file__), "..", "data", "pytorch_ddp")


@pytest.mark.skip(
    reason="This test is skipped for now due ML capacity error."
    "This test should be re-enabled later."
)
@pytest.mark.skipif(
    integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS,
    reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge",
)
def test_pytorchddp_pt_mnist(
    sagemaker_session,
    pytorch_training_latest_version,
    pytorch_training_latest_py_version,
):
    job_name = sagemaker.utils.unique_name_from_base("pt-pytorch-ddp")
    estimator = PyTorch(
        entry_point="mnist_pt.py",
        role="SageMakerRole",
        source_dir=pytorchddp_dir,
        instance_count=2,
        instance_type="ml.p3.16xlarge",
        sagemaker_session=sagemaker_session,
        framework_version=pytorch_training_latest_version,
        py_version=pytorch_training_latest_py_version,
        distribution={"pytorchddp": {"enabled": True}},
    )

    with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
        estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name)

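The new integ test is skipped by default. A sketch of a local invocation, assuming AWS credentials, the SageMakerRole IAM role the integ suite expects, and the skip marker removed (this snippet is not part of the commit):

# Hypothetical local run of the new integ test; the @pytest.mark.skip decorator
# would have to be removed first, and the region must be in DATA_PARALLEL_TESTING_REGIONS.
import pytest

pytest.main(["tests/integ/test_pytorchddp.py::test_pytorchddp_pt_mnist", "-x", "-s"])
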
tests/unit/test_pytorch.py

Lines changed: 2 additions & 4 deletions
@@ -56,9 +56,7 @@
     "TrialComponentDisplayName": "tc",
 }

-DISTRIBUTION_PYTORCH_DDP_ENABLED = {
-    "pytorchddp": {"enabled": True, "nnodes": 2, "nproc_per_node": 8}
-}
+DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}}


 @pytest.fixture(name="sagemaker_session")

@@ -724,7 +722,7 @@ def test_pytorch_ddp_distribution_configuration(
         framework_version=pytorch_training_version,
         py_version=pytorch_training_py_version,
         distribution=DISTRIBUTION_PYTORCH_DDP_ENABLED,
-        instance_type = test_instance_type,
+        instance_type=test_instance_type,
     )
     actual_pytorch_ddp = pytorch._pytorch_distribution_configuration(
         distribution=pytorch.distribution
