
Commit f814430

fix: black formatting
1 parent 20434df commit f814430

File tree

5 files changed: +74 -53 lines changed


src/sagemaker/fw_utils.py (+33 -22)
@@ -134,14 +134,9 @@
     "1.12.0",
 ]
 
-TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = [
-    "1.11",
-    "1.11.0"
-]
+TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]
 
-TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = [
-    "torch_distributed"
-]
+TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
 
 SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
 
@@ -710,7 +705,14 @@ def _validate_smdataparallel_args(
 
 
 def validate_distribution(
-    distribution, instance_groups, framework_name, framework_version, py_version, image_uri, entry_point, kwargs
+    distribution,
+    instance_groups,
+    framework_name,
+    framework_version,
+    py_version,
+    image_uri,
+    entry_point,
+    kwargs,
 ):
     """Check if distribution strategy is correctly invoked by the user.
 
@@ -850,9 +852,8 @@ def validate_distribution(
         )
     return distribution
 
-def validate_distribution_for_instance_type(
-    instance_type, distribution
-):
+
+def validate_distribution_for_instance_type(instance_type, distribution):
     """Check if the provided distribution strategy is supported for the instance_type
 
     Args:
@@ -869,11 +870,11 @@ def validate_distribution_for_instance_type(
         distribution_strategy = keys[0]
         if distribution_strategy != "torch_distributed":
             err_msg += (
-                    f"Provided distribution strategy {distribution_strategy} is not supported for"
-                    " Trainium instances.\n"
-                    "Please specify one of the following supported distribution strategies:"
-                    f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
-                    )
+                f"Provided distribution strategy {distribution_strategy} is not supported for"
+                " Trainium instances.\n"
+                "Please specify one of the following supported distribution strategies:"
+                f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
+            )
     elif len(keys) > 1:
         err_msg += (
             f"Multiple distribution strategies are not supported for Trainium instances.\n"
@@ -884,6 +885,7 @@ def validate_distribution_for_instance_type(
     if err_msg:
         raise ValueError(err_msg)
 
+
 def validate_pytorch_distribution(
     distribution, framework_name, framework_version, py_version, image_uri
 ):
@@ -940,8 +942,15 @@ def validate_pytorch_distribution(
     if err_msg:
         raise ValueError(err_msg)
 
+
 def validate_torch_distributed_distribution(
-    instance_type, distribution, framework_name, framework_version, py_version, image_uri, entry_point,
+    instance_type,
+    distribution,
+    framework_name,
+    framework_version,
+    py_version,
+    image_uri,
+    entry_point,
 ):
     """Check if torch_distributed distribution strategy is correctly invoked by the user.
 
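
For reference, the reflowed signature above matches how this commit's unit tests call the validator; a minimal sketch (the argument values are illustrative, not copied from the tests):

    from sagemaker import fw_utils

    # Expected to pass silently: PyTorch 1.11 on a Trainium instance with a
    # Python entry point is a supported torch_distributed configuration.
    fw_utils.validate_torch_distributed_distribution(
        instance_type="ml.trn1.2xlarge",
        distribution={"torch_distributed": {"enabled": True}},
        framework_name="pytorch",
        framework_version="1.11.0",
        py_version="py3",
        image_uri=None,
        entry_point="train.py",
    )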
@@ -1003,20 +1012,22 @@ def validate_torch_distributed_distribution(
             return
     else:
         err_msg += (
-                f"torch_distributed is currently supported only for trainium instances."
-                " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \
+            f"torch_distributed is currently supported only for trainium instances."
+            " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \
 for information regarding distributed training on non-trainium instances"
         )
 
     # Check entry point type
     if not entry_point.endswith(".py"):
-        err_msg += ("Unsupported entry point type for torch_distributed.\n"
-            "Only python programs (*.py) are supported."
+        err_msg += (
+            "Unsupported entry point type for torch_distributed.\n"
+            "Only python programs (*.py) are supported."
         )
-
+
     if err_msg:
         raise ValueError(err_msg)
 
+
 def python_deprecation_warning(framework, latest_supported_version):
     """Placeholder docstring"""
     return PYTHON_2_DEPRECATION_WARNING.format(
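
As a quick check on the validator reshaped above, this commit's unit tests exercise it with Trainium instance types; a minimal sketch along the same lines (the printed error text comes from the function itself, everything else mirrors the tests):

    from sagemaker.fw_utils import validate_distribution_for_instance_type

    # torch_distributed is the only entry in
    # TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES, so this returns quietly.
    validate_distribution_for_instance_type(
        instance_type="ml.trn1.32xlarge",
        distribution={"torch_distributed": {"enabled": True}},
    )

    # Any other strategy on a trn1 instance raises ValueError.
    try:
        validate_distribution_for_instance_type(
            instance_type="ml.trn1.32xlarge",
            distribution={"mpi": {"enabled": True}},
        )
    except ValueError as err:
        print(err)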

src/sagemaker/pytorch/estimator.py (+2 -2)
@@ -168,7 +168,7 @@ def __init__(
 
                 To learn more, see `Distributed PyTorch Training
                 <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.
-
+
                 **To enable Torch Distributed (Trainium Instances):**
 
                 .. code:: python
@@ -177,7 +177,7 @@ def __init__(
                             "enabled": True
                         }
                     }
-                To learn more, see `Distributed PyTorch Training on Trainium
+                To learn more, see `Distributed PyTorch Training on Trainium
                 <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.
 
                 **To enable MPI:**
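
Pulling the documented snippet together, a hedged end-to-end sketch of enabling torch_distributed on a Trainium estimator (the entry point, role ARN, and instance settings are placeholders; the framework and Python versions follow this commit's test fixtures):

    from sagemaker.pytorch import PyTorch

    estimator = PyTorch(
        entry_point="train.py",  # placeholder training script
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
        framework_version="1.11.0",
        py_version="py3",
        instance_count=2,
        instance_type="ml.trn1.32xlarge",
        distribution={"torch_distributed": {"enabled": True}},
    )
    estimator.fit()  # add training inputs as appropriate for the placeholder script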

tests/conftest.py (+2 -3)
@@ -452,12 +452,11 @@ def torch_distributed_py_version():
     return "py3"
 
 
-@pytest.fixture(
-    scope="module", params=["1.11.0"]
-)
+@pytest.fixture(scope="module", params=["1.11.0"])
 def torch_distributed_framework_version(request):
     return request.param
 
+
 @pytest.fixture(scope="session")
 def cpu_instance_type(sagemaker_session, request):
     region = sagemaker_session.boto_session.region_name
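
For context, a module-scoped parametrized fixture like the one above runs each requesting test once per param; a self-contained sketch (the test function is hypothetical, only the fixture mirrors conftest.py):

    import pytest

    @pytest.fixture(scope="module", params=["1.11.0"])
    def torch_distributed_framework_version(request):
        return request.param

    def test_framework_version_is_supported(torch_distributed_framework_version):
        # Runs once for each param; here only "1.11.0".
        assert torch_distributed_framework_version in ["1.11", "1.11.0"]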

tests/data/torch_distributed/mnist_mlp_trainium.py (+31 -23)
@@ -16,16 +16,22 @@
 
 # Initialize XLA process group for torchrun
 import torch_xla.distributed.xla_backend
-torch.distributed.init_process_group('xla')
+
+torch.distributed.init_process_group("xla")
 
 # Global constants
 EPOCHS = 4
 WARMUP_STEPS = 2
 BATCH_SIZE = 32
 
 # Load MNIST train dataset
-train_dataset = mnist.MNIST(root=os.path.join('./MNIST_DATA_train', str(xm.get_ordinal())),
-                            train=True, download=True, transform=ToTensor())
+train_dataset = mnist.MNIST(
+    root=os.path.join("./MNIST_DATA_train", str(xm.get_ordinal())),
+    train=True,
+    download=True,
+    transform=ToTensor(),
+)
+
 
 def main():
     # XLA MP: get world size
@@ -34,7 +40,7 @@ def main():
     torch.manual_seed(0)
 
     # Move model to device and declare optimizer and loss function
-    device = 'xla'
+    device = "xla"
     model = MLP().to(device)
     # For multiprocessing, scale up learning rate
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * world_size)
@@ -43,45 +49,47 @@ def main():
     # Prepare data loader
     train_sampler = None
     if world_size > 1:
-        train_sampler = DistributedSampler(train_dataset,
-                                           num_replicas=world_size,
-                                           rank=xm.get_ordinal(),
-                                           shuffle=True)
-    train_loader = DataLoader(train_dataset,
-                              batch_size=BATCH_SIZE,
-                              sampler=train_sampler,
-                              shuffle=False if train_sampler else True)
+        train_sampler = DistributedSampler(
+            train_dataset, num_replicas=world_size, rank=xm.get_ordinal(), shuffle=True
+        )
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=BATCH_SIZE,
+        sampler=train_sampler,
+        shuffle=False if train_sampler else True,
+    )
     # XLA MP: use MpDeviceLoader from torch_xla.distributed
     train_device_loader = pl.MpDeviceLoader(train_loader, device)
 
     # Run the training loop
-    print('----------Training ---------------')
+    print("----------Training ---------------")
     model.train()
     for epoch in range(EPOCHS):
         start = time.time()
-        print(f'Epoch: {epoch}')
+        print(f"Epoch: {epoch}")
         for idx, (train_x, train_label) in enumerate(train_device_loader):
             optimizer.zero_grad()
             train_x = train_x.view(train_x.size(0), -1)
             output = model(train_x)
             loss = loss_fn(output, train_label)
             loss.backward()
-            xm.optimizer_step(optimizer) # XLA MP: performs grad allreduce and optimizer step
-            if idx < WARMUP_STEPS: # skip warmup iterations
+            xm.optimizer_step(optimizer)  # XLA MP: performs grad allreduce and optimizer step
+            if idx < WARMUP_STEPS:  # skip warmup iterations
                 start = time.time()
 
     # Compute statistics for the last epoch
-    interval = idx - WARMUP_STEPS # skip warmup iterations
+    interval = idx - WARMUP_STEPS  # skip warmup iterations
     throughput = interval / (time.time() - start)
     print("Train throughput (iter/sec): {}".format(throughput))
-    print("Final loss is {:0.4f}".format(loss.detach().to('cpu')))
+    print("Final loss is {:0.4f}".format(loss.detach().to("cpu")))
 
     # Save checkpoint for evaluation (xm.save ensures only one process save)
     os.makedirs("checkpoints", exist_ok=True)
-    checkpoint = {'state_dict': model.state_dict()}
-    xm.save(checkpoint,'checkpoints/checkpoint.pt')
+    checkpoint = {"state_dict": model.state_dict()}
+    xm.save(checkpoint, "checkpoints/checkpoint.pt")
+
+    print("----------End Training ---------------")
 
-    print('----------End Training ---------------')
 
-if __name__ == '__main__':
-    main()
+if __name__ == "__main__":
+    main()
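
Since xm.save writes the checkpoint from a single process, it can be inspected afterwards with plain torch.load; a minimal sketch (loading on CPU is an assumption, the path and dict key come from the script above):

    import torch

    checkpoint = torch.load("checkpoints/checkpoint.pt", map_location="cpu")
    state_dict = checkpoint["state_dict"]
    print(sorted(state_dict.keys()))  # the MLP parameters saved by xm.save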

tests/unit/test_fw_utils.py (+6 -3)
@@ -947,6 +947,7 @@ def test_validate_pytorchddp_raises():
         image_uri=None,
     )
 
+
 def test_validate_torch_distributed_not_raises():
     # Case 1: Framework is not PyTorch
     fw_utils.validate_torch_distributed_distribution(
@@ -979,6 +980,7 @@ def test_validate_torch_distributed_not_raises():
         image_uri="custom-container",
     )
 
+
 def test_validate_torch_distributed_raises():
     torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
     # Case 1: Unsupported framework version
@@ -1001,6 +1003,7 @@ def test_validate_torch_distributed_raises():
         image_uri=None,
     )
 
+
 def test_validate_unsupported_distributions_trainium_raises():
     with pytest.raises(ValueError):
         mpi_enabled = {"mpi": {"enabled": True}}
@@ -1015,17 +1018,17 @@ def test_validate_unsupported_distributions_trainium_raises():
             distribution=mpi_enabled,
             instance_type="ml.trn1.32xlarge",
         )
-
+
     with pytest.raises(ValueError):
         pytorch_ddp_enabled = {"pytorch_ddp": {"enabled": True}}
         fw_utils.validate_distribution_for_instance_type(
             distribution=pytorch_ddp_enabled,
             instance_type="ml.trn1.32xlarge",
        )
-
+
     with pytest.raises(ValueError):
         smdataparallel_enabled = {"smdataparallel": {"enabled": True}}
         fw_utils.validate_distribution_for_instance_type(
             distribution=smdataparallel_enabled,
             instance_type="ml.trn1.32xlarge",
-        )
+        )
