Skip to content

Commit 5cce7a0

Browse files
committed
Fix failing tests
1 parent 6c9325e commit 5cce7a0

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

src/sagemaker/pytorch/estimator.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class PyTorch(Framework):
3737

3838
_framework_name = "pytorch"
3939
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
40-
INSTANCE_TYPE = "sagemaker_instance_type"
40+
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
4141

4242
def __init__(
4343
self,
@@ -207,13 +207,15 @@ def __init__(
207207
)
208208
self.framework_version = framework_version
209209
self.py_version = py_version
210-
self.instance_type = instance_type
211210

212211
if "enable_sagemaker_metrics" not in kwargs:
213212
# enable sagemaker metrics for PT v1.3 or greater:
214213
if self.framework_version and Version(self.framework_version) >= Version("1.3"):
215214
kwargs["enable_sagemaker_metrics"] = True
216215

216+
if "instance_type" in kwargs:
217+
self.instance_type = kwargs["instance_type"]
218+
217219
super(PyTorch, self).__init__(
218220
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
219221
)
@@ -231,17 +233,19 @@ def __init__(
231233
self.distribution = distribution or {}
232234

233235
def _pytorch_distribution_configuration(self, distribution):
234-
"""Returns a dict of distribution config
236+
"""Returns a dict of distribution config for PyTorch training
237+
235238
Args:
236-
None
239+
distribution (dict): A dictionary with information on how to run distributed training.
237240
Returns:
238-
dict containing torch ddp config
241+
dict containing PyTorch DDP config
239242
"""
240243
distribution_config = {}
241244
if "pytorchddp" in distribution:
242245
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
243246
distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
244-
distribution_config[self.INSTANCE_TYPE] = self.instance_type
247+
if self.instance_type is not None:
248+
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
245249
else:
246250
distribution_config = self._distribution_configuration(distribution=distribution)
247251
return distribution_config

tests/data/pytorch_ddp/mnist_pt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
2525
import smdistributed.dataparallel.torch.distributed as dist
2626

27-
dist.init_process_group()
27+
dist.init_process_group(backend="nccl")
2828

2929

3030
class Net(nn.Module):

tests/integ/test_pytorchddp.py

-3
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,9 @@
1313
from __future__ import absolute_import
1414

1515
import os
16-
1716
import pytest
18-
1917
import sagemaker.utils
2018
import tests.integ as integ
21-
2219
from sagemaker.pytorch import PyTorch
2320
from tests.integ import timeout
2421
from tests.integ.test_pytorch import _upload_training_data

0 commit comments

Comments
 (0)