@@ -37,7 +37,7 @@ class PyTorch(Framework):
37
37
38
38
_framework_name = "pytorch"
39
39
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
40
- INSTANCE_TYPE = "sagemaker_instance_type"
40
+ INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
41
41
42
42
def __init__ (
43
43
self ,
@@ -207,13 +207,15 @@ def __init__(
207
207
)
208
208
self .framework_version = framework_version
209
209
self .py_version = py_version
210
- self .instance_type = instance_type
211
210
212
211
if "enable_sagemaker_metrics" not in kwargs :
213
212
# enable sagemaker metrics for PT v1.3 or greater:
214
213
if self .framework_version and Version (self .framework_version ) >= Version ("1.3" ):
215
214
kwargs ["enable_sagemaker_metrics" ] = True
216
215
216
+ if "instance_type" in kwargs :
217
+ self .instance_type = kwargs ["instance_type" ]
218
+
217
219
super (PyTorch , self ).__init__ (
218
220
entry_point , source_dir , hyperparameters , image_uri = image_uri , ** kwargs
219
221
)
@@ -231,17 +233,19 @@ def __init__(
231
233
self .distribution = distribution or {}
232
234
233
235
def _pytorch_distribution_configuration (self , distribution ):
234
- """Returns a dict of distribution config
236
+ """Returns a dict of distribution config for PyTorch training
237
+
235
238
Args:
236
- None
239
+ distribution (dict): A dictionary with information on how to run distributed training.
237
240
Returns:
238
- dict containing torch ddp config
241
+ dict containing Pytorch DDP config
239
242
"""
240
243
distribution_config = {}
241
244
if "pytorchddp" in distribution :
242
245
pytorch_ddp_enabled = distribution .get ("pytorchddp" ).get ("enabled" , False )
243
246
distribution_config [self .LAUNCH_PYTORCH_DDP_ENV_NAME ] = pytorch_ddp_enabled
244
- distribution_config [self .INSTANCE_TYPE ] = self .instance_type
247
+ if self .instance_type is not None :
248
+ distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
245
249
else :
246
250
distribution_config = self ._distribution_configuration (distribution = distribution )
247
251
return distribution_config
0 commit comments