
Commit ae70340

Lokiiiiii and mchoi8739 authored
feature: Adding support for Multi Worker Mirrored Strategy in TF estimator (#3192)
Co-authored-by: Miyoung <[email protected]>
1 parent 5b68431 commit ae70340

File tree

10 files changed: +364 -298 lines changed

src/sagemaker/estimator.py

+43 -23

@@ -20,6 +20,8 @@
 import uuid
 from abc import ABCMeta, abstractmethod
 from typing import Any, Dict, Union, Optional, List
+from packaging.specifiers import SpecifierSet
+from packaging.version import Version

 from six import string_types, with_metaclass
 from six.moves.urllib.parse import urlparse
@@ -83,10 +85,7 @@
 )
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.pipeline_context import (
-    PipelineSession,
-    runnable_by_pipeline,
-)
+from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline

 logger = logging.getLogger(__name__)

@@ -106,6 +105,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):  # pylint: disable=too-man
     LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
     LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
     LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled"
+    LAUNCH_MWMS_ENV_NAME = "sagemaker_multi_worker_mirrored_strategy_enabled"
     INSTANCE_TYPE = "sagemaker_instance_type"
     MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
     MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
@@ -557,9 +557,7 @@ def __init__(
         self.dependencies = dependencies or []
         self.uploaded_code = None
         self.tags = add_jumpstart_tags(
-            tags=tags,
-            training_model_uri=self.model_uri,
-            training_script_uri=self.source_dir,
+            tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
         )
         if self.instance_type in ("local", "local_gpu"):
             if self.instance_type == "local_gpu" and self.instance_count > 1:
@@ -680,8 +678,7 @@ def _ensure_base_job_name(self):
             self.base_job_name
             or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
             or base_name_from_image(
-                self.training_image_uri(),
-                default_base_name=EstimatorBase.JOB_CLASS_NAME,
+                self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
             )
         )

@@ -744,7 +741,6 @@ def _prepare_for_training(self, job_name=None):
         self.dependencies = updated_paths["dependencies"]

         if self.source_dir or self.entry_point or self.dependencies:
-
             # validate source dir will raise a ValueError if there is something wrong with
             # the source directory. We are intentionally not handling it because this is a
             # critical error.
@@ -1023,10 +1019,7 @@ def _set_source_s3_uri(self, rule):
         parse_result = urlparse(rule.rule_parameters["source_s3_uri"])
         if parse_result.scheme != "s3":
             desired_s3_uri = os.path.join(
-                "s3://",
-                self.sagemaker_session.default_bucket(),
-                rule.name,
-                str(uuid.uuid4()),
+                "s3://", self.sagemaker_session.default_bucket(), rule.name, str(uuid.uuid4())
             )
             s3_uri = S3Uploader.upload(
                 local_path=rule.rule_parameters["source_s3_uri"],
@@ -1439,10 +1432,7 @@ def deploy(
         self._ensure_base_job_name()

         jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
-            kwargs.get("source_dir"),
-            self.source_dir,
-            kwargs.get("model_data"),
-            self.model_uri,
+            kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri
         )
         default_name = (
             name_from_base(jumpstart_base_name)
@@ -2240,11 +2230,7 @@ def _is_local_channel(cls, input_uri):

     @classmethod
     def update(
-        cls,
-        estimator,
-        profiler_rule_configs=None,
-        profiler_config=None,
-        resource_config=None,
+        cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
     ):
         """Update a running Amazon SageMaker training job.

@@ -3165,6 +3151,34 @@ def _validate_and_set_debugger_configs(self):
             )
             self.debugger_hook_config = False

+    def _validate_mwms_config(self, distribution):
+        """Validate Multi Worker Mirrored Strategy configuration."""
+        minimum_supported_framework_version = {"tensorflow": {"framework_version": "2.9"}}
+        if self._framework_name in minimum_supported_framework_version:
+            for version_argument in minimum_supported_framework_version[self._framework_name]:
+                current = getattr(self, version_argument)
+                threshold = minimum_supported_framework_version[self._framework_name][
+                    version_argument
+                ]
+                if Version(current) in SpecifierSet(f"< {threshold}"):
+                    raise ValueError(
+                        "Multi Worker Mirrored Strategy is only supported "
+                        "from {} {} but received {}".format(version_argument, threshold, current)
+                    )
+        else:
+            raise ValueError(
+                "Multi Worker Mirrored Strategy is currently only supported "
+                "with {} frameworks but received {}".format(
+                    minimum_supported_framework_version.keys(), self._framework_name
+                )
+            )
+        unsupported_distributions = ["smdistributed", "parameter_server"]
+        if any(i in distribution for i in unsupported_distributions):
+            raise ValueError(
+                "Multi Worker Mirrored Strategy is currently not supported with the"
+                " following distribution strategies: {}".format(unsupported_distributions)
+            )
+
     def _model_source_dir(self):
         """Get the appropriate value to pass as ``source_dir`` to a model constructor.

@@ -3528,6 +3542,12 @@ def _distribution_configuration(self, distribution):
                 "dataparallel"
             ].get("custom_mpi_options", "")

+        if "multi_worker_mirrored_strategy" in distribution:
+            mwms_enabled = distribution.get("multi_worker_mirrored_strategy").get("enabled", False)
+            if mwms_enabled:
+                self._validate_mwms_config(distribution)
+                distribution_config[self.LAUNCH_MWMS_ENV_NAME] = mwms_enabled
+
         if not (mpi_enabled or smdataparallel_enabled) and distribution_config.get(
             "sagemaker_distribution_instance_groups"
         ) not in [None, []]:
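
Note on the version gate: `_validate_mwms_config` compares versions with the `packaging` library rather than by string, so `"2.10"` correctly counts as newer than `"2.9"`. A minimal standalone sketch of the same check (the version list is chosen for illustration):

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Same gate as in _validate_mwms_config: reject framework versions below 2.9.
threshold = "2.9"
for framework_version in ["2.8", "2.9", "2.10"]:
    too_old = Version(framework_version) in SpecifierSet(f"< {threshold}")
    print(framework_version, "rejected" if too_old else "accepted")
# 2.8 rejected; 2.9 and 2.10 accepted ("2.10" > "2.9" despite string ordering)
```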

src/sagemaker/tensorflow/estimator.py

+17 -0

@@ -137,6 +137,23 @@ def __init__(
                 To find a complete list of parameters for SageMaker model parallelism,
                 see :ref:`sm-sdk-modelparallel-general`.

+            **To enable Multi Worker Mirrored Strategy:**
+
+                .. code:: python
+
+                    {
+                        "multi_worker_mirrored_strategy": {
+                            "enabled": True
+                        }
+                    }
+
+                This distribution strategy option is available for TensorFlow 2.9 and later in
+                the SageMaker Python SDK v2.xx.yy and later.
+                To learn more about the mirrored strategy for TensorFlow,
+                see `TensorFlow Distributed Training
+                <https://www.tensorflow.org/guide/distributed_training>`_
+                in the *TensorFlow documentation*.
+
             **To enable MPI:**

                 .. code:: python
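
For context, a user-facing configuration exercising this new docstring section might look like the sketch below; the entry point, role ARN, and instance settings are placeholders rather than values from this commit:

```python
from sagemaker.tensorflow import TensorFlow

# Hypothetical training setup; framework_version must be >= 2.9 so that
# _validate_mwms_config accepts the MWMS distribution option.
estimator = TensorFlow(
    entry_point="train.py",  # placeholder script
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder ARN
    instance_type="ml.p3.8xlarge",
    instance_count=2,
    framework_version="2.9",
    py_version="py39",
    distribution={"multi_worker_mirrored_strategy": {"enabled": True}},
)
# estimator.fit(...) would then launch a job with
# sagemaker_multi_worker_mirrored_strategy_enabled=True in its config.
```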

src/sagemaker/tensorflow/training_compiler/config.py

+11 -1

@@ -79,7 +79,7 @@ def validate(cls, estimator):
         """Checks if SageMaker Training Compiler is configured correctly.

         Args:
-            estimator (str): A estimator object
+            estimator (:class:`sagemaker.tensorflow.estimator.TensorFlow`): A estimator object
                 If SageMaker Training Compiler is enabled, it will validate whether
                 the estimator is configured to be compatible with Training Compiler.

@@ -102,3 +102,13 @@ def validate(cls, estimator):
                 cls.MIN_SUPPORTED_VERSION, estimator.framework_version
             )
             raise ValueError(error_helper_string)
+
+        if estimator.distribution and "multi_worker_mirrored_strategy" in estimator.distribution:
+            mwms_enabled = estimator.distribution.get("multi_worker_mirrored_strategy").get(
+                "enabled", False
+            )
+            if mwms_enabled:
+                raise ValueError(
+                    "Multi Worker Mirrored Strategy distributed training configuration "
+                    "is currently not compatible with SageMaker Training Compiler."
+                )
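
To illustrate the new guard: combining ``TrainingCompilerConfig`` with MWMS should now surface the ``ValueError`` above when the estimator is validated. A hedged sketch with placeholder role and script:

```python
from sagemaker.tensorflow import TensorFlow, TrainingCompilerConfig

# Hypothetical configuration: Training Compiler plus MWMS is rejected
# by TrainingCompilerConfig.validate() after this change.
estimator = TensorFlow(
    entry_point="train.py",  # placeholder script
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder ARN
    instance_type="ml.p3.2xlarge",
    instance_count=2,
    framework_version="2.9",
    py_version="py39",
    compiler_config=TrainingCompilerConfig(),
    distribution={"multi_worker_mirrored_strategy": {"enabled": True}},
)
# Expected: ValueError("Multi Worker Mirrored Strategy distributed training
# configuration is currently not compatible with SageMaker Training Compiler.")
```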

tests/conftest.py

+3 -8

@@ -281,9 +281,7 @@ def huggingface_training_compiler_pytorch_version(
     huggingface_training_compiler_version,
 ):
     versions = _huggingface_base_fm_version(
-        huggingface_training_compiler_version,
-        "pytorch",
-        "huggingface_training_compiler",
+        huggingface_training_compiler_version, "pytorch", "huggingface_training_compiler"
     )
     if not versions:
         pytest.skip(
@@ -298,9 +296,7 @@ def huggingface_training_compiler_tensorflow_version(
     huggingface_training_compiler_version,
 ):
     versions = _huggingface_base_fm_version(
-        huggingface_training_compiler_version,
-        "tensorflow",
-        "huggingface_training_compiler",
+        huggingface_training_compiler_version, "tensorflow", "huggingface_training_compiler"
     )
     if not versions:
         pytest.skip(
@@ -526,8 +522,7 @@ def pytorch_ddp_py_version():


 @pytest.fixture(
-    scope="module",
-    params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"],
+    scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"]
 )
 def pytorch_ddp_framework_version(request):
     return request.param
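
These conftest edits are formatting-only, but the pattern they touch is worth a note: a module-scoped parametrized fixture runs each dependent test once per value in ``params``. A minimal sketch with hypothetical fixture and test names:

```python
import pytest

# Each value in `params` yields one test invocation, received via request.param.
@pytest.fixture(scope="module", params=["1.10", "1.11", "1.12"])
def framework_version(request):
    return request.param


def test_version_is_supported(framework_version):
    # Runs three times, once per version in params.
    assert framework_version.split(".")[0] == "1"
```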
tests/data/tensorflow_mnist/mnist_mwms.py

+57 -0

@@ -0,0 +1,57 @@
+# https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
+
+import json
+import os
+import tensorflow as tf
+import numpy as np
+
+
+def mnist_dataset(batch_size):
+    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
+    # The `x` arrays are in uint8 and have values in the [0, 255] range.
+    # You need to convert them to float32 with values in the [0, 1] range.
+    x_train = x_train / np.float32(255)
+    y_train = y_train.astype(np.int64)
+    train_dataset = (
+        tf.data.Dataset.from_tensor_slices((x_train, y_train))
+        .shuffle(60000)
+        .repeat()
+        .batch(batch_size)
+    )
+    return train_dataset
+
+
+def build_and_compile_cnn_model():
+    model = tf.keras.Sequential(
+        [
+            tf.keras.layers.InputLayer(input_shape=(28, 28)),
+            tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
+            tf.keras.layers.Conv2D(32, 3, activation="relu"),
+            tf.keras.layers.Flatten(),
+            tf.keras.layers.Dense(128, activation="relu"),
+            tf.keras.layers.Dense(10),
+        ]
+    )
+    model.compile(
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
+        metrics=["accuracy"],
+    )
+    return model
+
+
+per_worker_batch_size = 64
+tf_config = json.loads(os.environ["TF_CONFIG"])
+num_workers = len(tf_config["cluster"]["worker"])
+
+strategy = tf.distribute.MultiWorkerMirroredStrategy()
+
+global_batch_size = per_worker_batch_size * num_workers
+multi_worker_dataset = mnist_dataset(global_batch_size)
+
+with strategy.scope():
+    multi_worker_model = build_and_compile_cnn_model()
+
+multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
+
+print(f"strategy.num_replicas_in_sync={strategy.num_replicas_in_sync}")

tests/integ/test_tf.py

+48 -17

@@ -38,6 +38,7 @@
 SCRIPT = "mnist.py"
 PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
 MPI_DISTRIBUTION = {"mpi": {"enabled": True}}
+MWMS_DISTRIBUTION = {"multi_worker_mirrored_strategy": {"enabled": True}}
 TAGS = [{"Key": "some-key", "Value": "some-value"}]
 ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}

@@ -68,12 +69,7 @@ def test_framework_processing_job_with_deps(
         sagemaker_session=sagemaker_session,
         base_job_name="test-tensorflow",
     )
-    processor.run(
-        code=entry_point,
-        source_dir=code_path,
-        inputs=[],
-        wait=True,
-    )
+    processor.run(code=entry_point, source_dir=code_path, inputs=[], wait=True)


 def test_mnist_with_checkpoint_config(
@@ -110,9 +106,7 @@ def test_mnist_with_checkpoint_config(
     with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
         estimator.fit(inputs=inputs, job_name=training_job_name)
     assert_s3_file_patterns_exist(
-        sagemaker_session,
-        estimator.model_dir,
-        [r"model\.ckpt-\d+\.index", r"checkpoint"],
+        sagemaker_session, estimator.model_dir, [r"model\.ckpt-\d+\.index", r"checkpoint"]
     )
     # remove dataframe assertion to unblock PR build
     # TODO: add independent integration test for `training_job_analytics`
@@ -130,9 +124,7 @@ def test_mnist_with_checkpoint_config(
         ]
     )

-    expected_retry_strategy = {
-        "MaximumRetryAttempts": 2,
-    }
+    expected_retry_strategy = {"MaximumRetryAttempts": 2}
     actual_retry_strategy = sagemaker_session.sagemaker_client.describe_training_job(
         TrainingJobName=training_job_name
     )["RetryStrategy"]
@@ -181,6 +173,48 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
     )


+@pytest.mark.release
+@pytest.mark.skipif(
+    tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS
+    and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS,
+    reason="no ml.p2 or ml.p3 instances in this region",
+)
+@retry_with_instance_list(gpu_list(tests.integ.test_region()))
+def test_mwms_gpu(
+    sagemaker_session,
+    tensorflow_training_latest_version,
+    tensorflow_training_latest_py_version,
+    capsys,
+    **kwargs,
+):
+    instance_count = 2
+    estimator = TensorFlow(
+        source_dir=os.path.join(RESOURCE_PATH, "tensorflow_mnist"),
+        entry_point="mnist_mwms.py",
+        model_dir=False,
+        instance_type=kwargs["instance_type"],
+        instance_count=instance_count,
+        framework_version=tensorflow_training_latest_version,
+        py_version=tensorflow_training_latest_py_version,
+        distribution=MWMS_DISTRIBUTION,
+        environment={"NCCL_DEBUG": "INFO"},
+        max_run=60 * 60 * 1,  # 1 hour
+        role=ROLE,
+        volume_size=400,
+        sagemaker_session=sagemaker_session,
+        disable_profiler=True,
+    )
+
+    with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
+        estimator.fit(job_name=unique_name_from_base("test-tf-mwms"))
+
+    captured = capsys.readouterr()
+    logs = captured.out + captured.err
+    print(logs)
+    assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
+    assert f"strategy.num_replicas_in_sync={instance_count}" in logs
+
+
 @pytest.mark.release
 def test_mnist_distributed_cpu(
     sagemaker_session,
@@ -237,9 +271,7 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
     with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
         estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed"))
     assert_s3_file_patterns_exist(
-        sagemaker_session,
-        estimator.model_dir,
-        [r"model\.ckpt-\d+\.index", r"checkpoint"],
+        sagemaker_session, estimator.model_dir, [r"model\.ckpt-\d+\.index", r"checkpoint"]
    )


@@ -346,8 +378,7 @@ def test_model_deploy_with_serverless_inference_config(
         sagemaker_session=sagemaker_session,
     )
     predictor = model.deploy(
-        serverless_inference_config=ServerlessInferenceConfig(),
-        endpoint_name=endpoint_name,
+        serverless_inference_config=ServerlessInferenceConfig(), endpoint_name=endpoint_name
     )

     input_data = {"instances": [1.0, 2.0, 5.0]}
