@@ -28,6 +28,7 @@
 )
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel
+from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
 from sagemaker.workflow.entities import PipelineVariable
 
@@ -51,7 +52,8 @@ def __init__(
         hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
         image_uri: Optional[Union[str, PipelineVariable]] = None,
         distribution: Optional[Dict] = None,
-        **kwargs
+        compiler_config: Optional[TrainingCompilerConfig] = None,
+        **kwargs,
     ):
         """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
 
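A minimal sketch of how the new ``compiler_config`` argument would be passed to the estimator. The entry point, role, versions, and instance settings below are illustrative placeholders, not values from this diff; only the ``compiler_config`` parameter and the ``sagemaker.pytorch.TrainingCompilerConfig`` class come from the change itself.

.. code:: python

    from sagemaker.pytorch import PyTorch, TrainingCompilerConfig

    # Hypothetical script, role, and instance settings, for illustration only.
    estimator = PyTorch(
        entry_point="train.py",
        role="arn:aws:iam::111122223333:role/SageMakerRole",
        instance_count=1,
        instance_type="ml.p3.2xlarge",
        framework_version="1.12.0",
        py_version="py38",
        # New in this change: opt in to SageMaker Training Compiler.
        compiler_config=TrainingCompilerConfig(),
    )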
@@ -208,6 +210,31 @@ def __init__(
                 To learn more, see `Training with parameter servers
                 <https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#training-with-parameter-servers>`_.
 
+                **To enable distributed training with
+                `SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
+                for PyTorch:**
+
+                .. code:: python
+
+                    {
+                        "pytorchxla": {
+                            "enabled": True
+                        }
+                    }
+
+                To learn more, see `SageMaker Training Compiler
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
+                in the *Amazon SageMaker Developer Guide*.
+
+                .. note::
+
+                    When you use this PyTorch XLA option for distributed training strategy,
+                    you must add the ``compiler_config`` parameter and activate SageMaker
+                    Training Compiler.
+
+            compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`):
+                Configures SageMaker Training Compiler to accelerate training.
+
             **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
                 constructor.
 
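Putting the two documented pieces together, a sketch of the distributed-training setup the docstring describes: the ``pytorchxla`` distribution option combined with ``compiler_config``, as the note requires. Script name, role, versions, and instance type are illustrative assumptions.

.. code:: python

    from sagemaker.pytorch import PyTorch, TrainingCompilerConfig

    # Per the note above, the "pytorchxla" distribution strategy requires
    # compiler_config to be set as well.
    estimator = PyTorch(
        entry_point="train.py",          # hypothetical training script
        role="arn:aws:iam::111122223333:role/SageMakerRole",
        instance_count=2,
        instance_type="ml.p4d.24xlarge",  # illustrative GPU instance
        framework_version="1.12.0",
        py_version="py38",
        distribution={"pytorchxla": {"enabled": True}},
        compiler_config=TrainingCompilerConfig(),
    )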
@@ -250,6 +277,25 @@ def __init__(
 
         self.distribution = distribution or {}
 
+        if compiler_config is not None:
+            if not isinstance(compiler_config, TrainingCompilerConfig):
+                error_string = (
+                    f"Expected instance of type {TrainingCompilerConfig} "
+                    f"for argument compiler_config. "
+                    f"Instead got {type(compiler_config)} "
+                )
+                raise ValueError(error_string)
+            if compiler_config:
+                compiler_config.validate(self)
+        elif distribution is not None and "pytorchxla" in distribution:
+            raise ValueError(
+                "Distributed training through PyTorch XLA is currently only supported "
+                "when SageMaker Training Compiler is enabled. To learn more, "
+                "see Enable SageMaker Training Compiler at "
+                "https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html."
+            )
+        self.compiler_config = compiler_config
+
     def _pytorch_distribution_configuration(self, distribution):
         """Returns a dict of distribution config for PyTorch training
 
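A sketch of the two constructor-time failure modes the validation above introduces, both raised before any training job is created. The estimator arguments are hypothetical, and the sketch assumes an environment with an AWS region configured so the estimator can be constructed.

.. code:: python

    from sagemaker.pytorch import PyTorch, TrainingCompilerConfig

    # Hypothetical estimator arguments, abbreviated for illustration.
    common_args = dict(
        entry_point="train.py",
        role="arn:aws:iam::111122223333:role/SageMakerRole",
        instance_count=1,
        instance_type="ml.p3.2xlarge",
        framework_version="1.12.0",
        py_version="py38",
    )

    # Anything other than a TrainingCompilerConfig instance is rejected.
    try:
        PyTorch(compiler_config={"enabled": True}, **common_args)
    except ValueError as err:
        print(err)  # "Expected instance of type ... for argument compiler_config ..."

    # Requesting "pytorchxla" without compiler_config hits the elif branch.
    try:
        PyTorch(distribution={"pytorchxla": {"enabled": True}}, **common_args)
    except ValueError as err:
        print(err)  # "Distributed training through PyTorch XLA is currently only supported ..."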
@@ -289,6 +335,12 @@ def hyperparameters(self):
         hyperparameters.update(
             EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
         )
+        if self.compiler_config:
+            training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
+            hyperparameters.update(
+                EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
+            )
+
         return hyperparameters
 
     def create_model(
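For context on the merge above: the compiler flags are JSON-encoded exactly like the other framework hyperparameters. The flag names below are an assumption about what ``_to_hyperparameter_dict()`` might return, not part of this diff, and the encoding step only mimics ``EstimatorBase._json_encode_hyperparameters`` for illustration.

.. code:: python

    import json

    # Stand-in for self.compiler_config._to_hyperparameter_dict();
    # the exact keys are assumed here.
    training_compiler_hyperparameters = {
        "sagemaker_training_compiler_enabled": True,
        "sagemaker_training_compiler_debug_mode": False,
    }

    # JSON-encode each value, as the estimator does for all hyperparameters.
    encoded = {k: json.dumps(v) for k, v in training_compiler_hyperparameters.items()}
    print(encoded)  # {'sagemaker_training_compiler_enabled': 'true', ...}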
@@ -299,7 +351,7 @@ def create_model(
         entry_point=None,
         source_dir=None,
         dependencies=None,
-        **kwargs
+        **kwargs,
     ):
         """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
 
@@ -350,7 +402,7 @@ def create_model(
             sagemaker_session=self.sagemaker_session,
             vpc_config=self.get_vpc_config(vpc_config_override),
             dependencies=(dependencies or self.dependencies),
-            **kwargs
+            **kwargs,
         )
 
     @classmethod
@@ -371,6 +423,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_name
         )
         image_uri = init_params.pop("image_uri")
         framework, py_version, tag, _ = framework_name_from_image(image_uri)
+        if framework:
+            framework = framework.split("-")[0]
 
         if tag is None:
             framework_version = None
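The added ``split("-")[0]`` normalizes compound framework names parsed from an image URI back to the base framework. The example values below are assumptions about what ``framework_name_from_image`` can return (e.g. for a compiler-enabled ``pytorch-trcomp`` image), not taken from this diff.

.. code:: python

    # Hypothetical framework names as parsed from an ECR image URI.
    for framework in ("pytorch", "pytorch-trcomp"):
        normalized = framework.split("-")[0]
        print(framework, "->", normalized)  # both normalize to "pytorch"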