Commit 36eb70d

Feature: Cluster setup for MultiWorkerMirroredStrategy
1 parent 38db16c commit 36eb70d

5 files changed: +147 -13 lines changed

src/sagemaker_tensorflow_container/training.py

Lines changed: 44 additions & 4 deletions
@@ -27,14 +27,15 @@
 logger = logging.getLogger(__name__)

 SAGEMAKER_PARAMETER_SERVER_ENABLED = "sagemaker_parameter_server_enabled"
+SAGEMAKER_MULTI_WORKER_MIRRORED_ENABLED = "sagemaker_multi_worker_mirrored_enabled"
 MODEL_DIR = "/opt/ml/model"


 def _is_host_master(hosts, current_host):
     return current_host == hosts[0]


-def _build_tf_config(hosts, current_host, ps_task=False):
+def _build_tf_config_for_ps(hosts, current_host, ps_task=False):
     """Builds a dictionary containing cluster information based on number of hosts and number of
     parameter servers.
@@ -84,6 +85,31 @@ def host_addresses(hosts, port=2222):
     return tf_config


+def _build_tf_config_for_mwm(hosts, current_host):
+    """Builds a dictionary containing cluster information based on the number of workers
+    for the Multi Worker Mirrored distribution strategy.
+
+    Args:
+        hosts (list[str]): List of host names in the cluster
+        current_host (str): Current host name
+
+    Returns:
+        dict[str: dict]: A dictionary describing the cluster setup for distributed training.
+            For more information regarding TF_CONFIG:
+            https://cloud.google.com/ml-engine/docs/tensorflow/distributed-training-details
+    """
+    workers = hosts
+
+    def host_addresses(hosts, port=8890):
+        return ["{}:{}".format(host, port) for host in hosts]
+
+    tf_config = {"cluster": {}, "environment": "cloud"}
+    tf_config["cluster"]["worker"] = host_addresses(workers)
+    tf_config["task"] = {"index": workers.index(current_host), "type": "worker"}
+
+    return tf_config
+
+
 def _run_ps(env, cluster):
     logger.info("Running distributed training job with parameter servers")
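For a concrete sense of what the new helper produces, here is a minimal sketch of its output for a two-host cluster (the algo-1/algo-2 host names are illustrative stand-ins, not taken from this commit):

# Illustrative two-host cluster; SageMaker assigns the real host names at runtime.
hosts = ["algo-1", "algo-2"]
tf_config = _build_tf_config_for_mwm(hosts, current_host="algo-2")
assert tf_config == {
    "cluster": {"worker": ["algo-1:8890", "algo-2:8890"]},
    "environment": "cloud",
    "task": {"index": 1, "type": "worker"},
}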
@@ -135,12 +161,26 @@ def train(env, cmd_args):
     """
     parameter_server_enabled = env.additional_framework_parameters.get(
         SAGEMAKER_PARAMETER_SERVER_ENABLED, False
+    ) and len(env.hosts) > 1
+    multi_worker_mirrored_enabled = env.additional_framework_parameters.get(
+        SAGEMAKER_MULTI_WORKER_MIRRORED_ENABLED, False
     )
-    if len(env.hosts) > 1 and parameter_server_enabled:
+
+    # Setup
+    if parameter_server_enabled:
+
+        tf_config = _build_tf_config_for_ps(hosts=env.hosts, current_host=env.current_host)
+        logger.info("Running distributed training job with parameter servers")
+
+    elif multi_worker_mirrored_enabled:
+
+        tf_config = _build_tf_config_for_mwm(hosts=env.hosts, current_host=env.current_host)
+        logger.info("Running distributed training job with multi_worker_mirrored setup")

-        tf_config = _build_tf_config(hosts=env.hosts, current_host=env.current_host)

-        logger.info("Running distributed training job with parameter servers")
+    # Run
+    if parameter_server_enabled:
+
         logger.info("Launching parameter server process")
         _run_ps(env, tf_config["cluster"])
         logger.info("Launching worker process")
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Copyright 2017-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import os
+
+import boto3
+import pytest
+from sagemaker.tensorflow import TensorFlow
+from sagemaker.utils import unique_name_from_base
+from six.moves.urllib.parse import urlparse
+
+from timeout import timeout
+
+
+RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources")
+
+
+def test_multi_node(sagemaker_session, instance_type, image_uri, tmpdir, framework_version):
+    estimator = TensorFlow(
+        entry_point=os.path.join(RESOURCE_PATH, "multi_worker_mirrored", "train_sample.py"),
+        role="SageMakerRole",
+        instance_type=instance_type,
+        instance_count=2,
+        image_name=image_uri,
+        framework_version=framework_version,
+        py_version="py3",
+        sagemaker_session=sagemaker_session,
+    )
+    estimator.fit(job_name=unique_name_from_base("test-tf-mwms"))
+    raise NotImplementedError('Yet to add assertions')
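Note that the test above starts a plain two-instance estimator and never sets the sagemaker_multi_worker_mirrored_enabled flag that train() checks, so the MWMS branch would not be exercised as written. A hypothetical sketch of wiring the flag through follows; the hyperparameter spelling is an assumption, not confirmed by this commit:

# Hypothetical: surface the toolkit flag as a "sagemaker_"-prefixed hyperparameter so
# that it appears in env.additional_framework_parameters inside the container.
estimator = TensorFlow(
    entry_point=os.path.join(RESOURCE_PATH, "multi_worker_mirrored", "train_sample.py"),
    role="SageMakerRole",
    instance_type=instance_type,
    instance_count=2,
    image_name=image_uri,
    framework_version=framework_version,
    py_version="py3",
    sagemaker_session=sagemaker_session,
    hyperparameters={"sagemaker_multi_worker_mirrored_enabled": True},  # assumed spelling
)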
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+from __future__ import absolute_import
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import numpy as np  # needed for the random sample data below
+import tensorflow as tf
+
+
+strategy = tf.distribute.MultiWorkerMirroredStrategy()
+
+with strategy.scope():
+    model = tf.keras.Sequential([
+        tf.keras.layers.Dense(2, input_shape=(5,)),
+    ])
+    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
+
+def dataset_fn(ctx):
+    x = np.random.random((2, 5)).astype(np.float32)
+    y = np.random.randint(2, size=(2, 1))
+    dataset = tf.data.Dataset.from_tensor_slices((x, y))
+    return dataset.repeat().batch(1, drop_remainder=True)
+dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
+
+model.compile(optimizer=optimizer, loss="mse")  # a loss is required for fit()
+model.fit(dist_dataset, epochs=1, steps_per_epoch=5)  # repeated dataset needs an explicit step count
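To make multi-worker runs of this sample easier to debug, the strategy's resolved cluster view can be logged; the lines below are an optional addition using standard tf.distribute APIs, not part of this commit:

# Optional diagnostics: report what this worker resolved from TF_CONFIG.
resolver = strategy.cluster_resolver
print("task_type={} task_id={} replicas_in_sync={}".format(
    resolver.task_type, resolver.task_id, strategy.num_replicas_in_sync
))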

test/unit/test_training.py

Lines changed: 26 additions & 9 deletions
@@ -31,10 +31,14 @@
 CURRENT_HOST = HOST1
 CMD_ARGS = {"some_key": "some_value"}
 CLUSTER_WITH_PS = {
-    "master": ["{}:2222".format(HOST1)],
-    "worker": ["{}:2222".format(HOST2)],
+    "master": ["{}:8890".format(HOST1)],
+    "worker": ["{}:8890".format(HOST2)],
     "ps": ["{}:2223".format(HOST1), "{}:2223".format(HOST2)],
 }
+CLUSTER_WITH_MWMS = {
+    "worker": ["{}:8890".format(host) for host in (HOST1, HOST2)],
+}
+
 MASTER_TASK = {"index": 0, "type": "master"}
 WORKER_TASK = {"index": 0, "type": "worker"}
 PS_TASK_1 = {"index": 0, "type": "ps"}
@@ -205,32 +209,45 @@ def test_train_distributed_no_ps(run, distributed_training_env):
     )


-def test_build_tf_config():
-    assert training._build_tf_config(HOST_LIST, HOST1) == {
+def test_build_tf_config_for_mwms():
+    assert training._build_tf_config_for_mwm(HOST_LIST, HOST1) == {
+        "cluster": CLUSTER_WITH_MWMS,
+        "environment": "cloud",
+        "task": {"index": HOST_LIST.index(HOST1), "type": "worker"},
+    }
+    assert training._build_tf_config_for_mwm(HOST_LIST, HOST2) == {
+        "cluster": CLUSTER_WITH_MWMS,
+        "environment": "cloud",
+        "task": {"index": HOST_LIST.index(HOST2), "type": "worker"},
+    }
+
+
+def test_build_tf_config_for_ps():
+    assert training._build_tf_config_for_ps(HOST_LIST, HOST1) == {
         "cluster": CLUSTER_WITH_PS,
         "environment": "cloud",
         "task": MASTER_TASK,
     }
-    assert training._build_tf_config(HOST_LIST, HOST1, ps_task=True) == {
+    assert training._build_tf_config_for_ps(HOST_LIST, HOST1, ps_task=True) == {
         "cluster": CLUSTER_WITH_PS,
         "environment": "cloud",
         "task": PS_TASK_1,
     }
-    assert training._build_tf_config(HOST_LIST, HOST2) == {
+    assert training._build_tf_config_for_ps(HOST_LIST, HOST2) == {
         "cluster": CLUSTER_WITH_PS,
         "environment": "cloud",
         "task": WORKER_TASK,
     }
-    assert training._build_tf_config(HOST_LIST, HOST2, ps_task=True) == {
+    assert training._build_tf_config_for_ps(HOST_LIST, HOST2, ps_task=True) == {
         "cluster": CLUSTER_WITH_PS,
         "environment": "cloud",
         "task": PS_TASK_2,
     }


-def test_build_tf_config_error():
+def test_build_tf_config_for_ps_error():
     with pytest.raises(ValueError) as error:
-        training._build_tf_config([HOST1], HOST1, ps_task=True)
+        training._build_tf_config_for_ps([HOST1], HOST1, ps_task=True)
     assert "Cannot have a ps task if there are no parameter servers in the cluster" in str(
         error.value
     )
