20
20
import uuid
21
21
from abc import ABCMeta , abstractmethod
22
22
from typing import Any , Dict , Union , Optional , List
23
+ from packaging .specifiers import SpecifierSet
24
+ from packaging .version import Version
23
25
24
26
from six import string_types , with_metaclass
25
27
from six .moves .urllib .parse import urlparse
83
85
)
84
86
from sagemaker .workflow import is_pipeline_variable
85
87
from sagemaker .workflow .entities import PipelineVariable
86
- from sagemaker .workflow .pipeline_context import (
87
- PipelineSession ,
88
- runnable_by_pipeline ,
89
- )
88
+ from sagemaker .workflow .pipeline_context import PipelineSession , runnable_by_pipeline
90
89
91
90
logger = logging .getLogger (__name__ )
92
91
@@ -106,6 +105,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
106
105
LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
107
106
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
108
107
LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled"
108
+ LAUNCH_MWMS_ENV_NAME = "sagemaker_multi_worker_mirrored_strategy_enabled"
109
109
INSTANCE_TYPE = "sagemaker_instance_type"
110
110
MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
111
111
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
@@ -557,9 +557,7 @@ def __init__(
557
557
self .dependencies = dependencies or []
558
558
self .uploaded_code = None
559
559
self .tags = add_jumpstart_tags (
560
- tags = tags ,
561
- training_model_uri = self .model_uri ,
562
- training_script_uri = self .source_dir ,
560
+ tags = tags , training_model_uri = self .model_uri , training_script_uri = self .source_dir
563
561
)
564
562
if self .instance_type in ("local" , "local_gpu" ):
565
563
if self .instance_type == "local_gpu" and self .instance_count > 1 :
@@ -680,8 +678,7 @@ def _ensure_base_job_name(self):
680
678
self .base_job_name
681
679
or get_jumpstart_base_name_if_jumpstart_model (self .source_dir , self .model_uri )
682
680
or base_name_from_image (
683
- self .training_image_uri (),
684
- default_base_name = EstimatorBase .JOB_CLASS_NAME ,
681
+ self .training_image_uri (), default_base_name = EstimatorBase .JOB_CLASS_NAME
685
682
)
686
683
)
687
684
@@ -744,7 +741,6 @@ def _prepare_for_training(self, job_name=None):
744
741
self .dependencies = updated_paths ["dependencies" ]
745
742
746
743
if self .source_dir or self .entry_point or self .dependencies :
747
-
748
744
# validate source dir will raise a ValueError if there is something wrong with
749
745
# the source directory. We are intentionally not handling it because this is a
750
746
# critical error.
@@ -1023,10 +1019,7 @@ def _set_source_s3_uri(self, rule):
1023
1019
parse_result = urlparse (rule .rule_parameters ["source_s3_uri" ])
1024
1020
if parse_result .scheme != "s3" :
1025
1021
desired_s3_uri = os .path .join (
1026
- "s3://" ,
1027
- self .sagemaker_session .default_bucket (),
1028
- rule .name ,
1029
- str (uuid .uuid4 ()),
1022
+ "s3://" , self .sagemaker_session .default_bucket (), rule .name , str (uuid .uuid4 ())
1030
1023
)
1031
1024
s3_uri = S3Uploader .upload (
1032
1025
local_path = rule .rule_parameters ["source_s3_uri" ],
@@ -1439,10 +1432,7 @@ def deploy(
1439
1432
self ._ensure_base_job_name ()
1440
1433
1441
1434
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model (
1442
- kwargs .get ("source_dir" ),
1443
- self .source_dir ,
1444
- kwargs .get ("model_data" ),
1445
- self .model_uri ,
1435
+ kwargs .get ("source_dir" ), self .source_dir , kwargs .get ("model_data" ), self .model_uri
1446
1436
)
1447
1437
default_name = (
1448
1438
name_from_base (jumpstart_base_name )
@@ -2240,11 +2230,7 @@ def _is_local_channel(cls, input_uri):
2240
2230
2241
2231
@classmethod
2242
2232
def update (
2243
- cls ,
2244
- estimator ,
2245
- profiler_rule_configs = None ,
2246
- profiler_config = None ,
2247
- resource_config = None ,
2233
+ cls , estimator , profiler_rule_configs = None , profiler_config = None , resource_config = None
2248
2234
):
2249
2235
"""Update a running Amazon SageMaker training job.
2250
2236
@@ -3165,6 +3151,34 @@ def _validate_and_set_debugger_configs(self):
3165
3151
)
3166
3152
self .debugger_hook_config = False
3167
3153
3154
+ def _validate_mwms_config (self , distribution ):
3155
+ """Validate Multi Worker Mirrored Strategy configuration."""
3156
+ minimum_supported_framework_version = {"tensorflow" : {"framework_version" : "2.9" }}
3157
+ if self ._framework_name in minimum_supported_framework_version :
3158
+ for version_argument in minimum_supported_framework_version [self ._framework_name ]:
3159
+ current = getattr (self , version_argument )
3160
+ threshold = minimum_supported_framework_version [self ._framework_name ][
3161
+ version_argument
3162
+ ]
3163
+ if Version (current ) in SpecifierSet (f"< { threshold } " ):
3164
+ raise ValueError (
3165
+ "Multi Worker Mirrored Strategy is only supported "
3166
+ "from {} {} but received {}" .format (version_argument , threshold , current )
3167
+ )
3168
+ else :
3169
+ raise ValueError (
3170
+ "Multi Worker Mirrored Strategy is currently only supported "
3171
+ "with {} frameworks but received {}" .format (
3172
+ minimum_supported_framework_version .keys (), self ._framework_name
3173
+ )
3174
+ )
3175
+ unsupported_distributions = ["smdistributed" , "parameter_server" ]
3176
+ if any (i in distribution for i in unsupported_distributions ):
3177
+ raise ValueError (
3178
+ "Multi Worker Mirrored Strategy is currently not supported with the"
3179
+ " following distribution strategies: {}" .format (unsupported_distributions )
3180
+ )
3181
+
3168
3182
def _model_source_dir (self ):
3169
3183
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.
3170
3184
@@ -3528,6 +3542,12 @@ def _distribution_configuration(self, distribution):
3528
3542
"dataparallel"
3529
3543
].get ("custom_mpi_options" , "" )
3530
3544
3545
+ if "multi_worker_mirrored_strategy" in distribution :
3546
+ mwms_enabled = distribution .get ("multi_worker_mirrored_strategy" ).get ("enabled" , False )
3547
+ if mwms_enabled :
3548
+ self ._validate_mwms_config (distribution )
3549
+ distribution_config [self .LAUNCH_MWMS_ENV_NAME ] = mwms_enabled
3550
+
3531
3551
if not (mpi_enabled or smdataparallel_enabled ) and distribution_config .get (
3532
3552
"sagemaker_distribution_instance_groups"
3533
3553
) not in [None , []]:
0 commit comments