Skip to content

Feature: Cluster setup for MultiWorkerMirroredStrategy #415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Jun 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b45d840
Feature: Cluster setup for MultiWorkerMirroredStrategy
Lokiiiiii May 19, 2022
4564712
Configuring tests to use the new hyperparameter for MWMS
Lokiiiiii May 19, 2022
9e528c7
Black formatted files
Lokiiiiii May 19, 2022
ca61c3f
fixing failing tests
Lokiiiiii May 20, 2022
718e5c7
Removing references to py versions older than py37
Lokiiiiii May 20, 2022
86701b4
Converting py36 tests to py37
Lokiiiiii May 24, 2022
df94fc4
fix: linting and changed variable name to sagemaker_multi_worker_mirr…
Lokiiiiii Jun 3, 2022
c3e6819
fix: feezing protobuf version
Lokiiiiii Jun 3, 2022
0f7ee2f
fix: renaming MWMS variable name
Lokiiiiii Jun 3, 2022
56337da
fix: rename functions for _mwm to _mwms
Lokiiiiii Jun 3, 2022
4e65975
Revert "fix: feezing protobuf version"
Lokiiiiii Jun 3, 2022
24242ea
Revert "Converting py36 tests to py37"
Lokiiiiii Jun 3, 2022
e6fbbcc
Revert "Removing references to py versions older than py37"
Lokiiiiii Jun 3, 2022
f2773a7
fix: variable name changes for MWMS
Lokiiiiii Jun 3, 2022
0676291
fix: renaming training script to train_dummy.py
Lokiiiiii Jun 3, 2022
b19894d
fix: freezing latest sagemaker toolkit version
Lokiiiiii Jun 3, 2022
9ed20cf
trigger ci
nish21 Jun 3, 2022
36d81eb
Merge branch 'tf-2' of github.com:aws/sagemaker-tensorflow-training-t…
nish21 Jun 3, 2022
dd06073
fix: adding epochs and steps to failing MWMS test
Lokiiiiii Jun 4, 2022
f5cf636
fix: changing MWMS testcase
Lokiiiiii Jun 4, 2022
f8ce3f0
fix: logic error in MWMS
Lokiiiiii Jun 4, 2022
4e1f0af
fix: logic error in MWMS
Lokiiiiii Jun 4, 2022
edfc844
fix: Updating MWMS tests to check for log lines
Lokiiiiii Jun 4, 2022
385462e
fix: linting
Lokiiiiii Jun 4, 2022
0c420ce
trigger ci
nish21 Jun 4, 2022
36edae2
trigger ci
nish21 Jun 4, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def read_version():
"Programming Language :: Python :: 3.9",
],
install_requires=[
"sagemaker-training>=4.1.0",
"sagemaker-training>=4.1.3",
"numpy",
"scipy",
"sklearn",
Expand Down
61 changes: 53 additions & 8 deletions src/sagemaker_tensorflow_container/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@

# Names of the framework hyperparameters that select a distributed-training mode.
# Each one is an opt-in flag whose name ends in "_enabled".
SAGEMAKER_PARAMETER_SERVER_ENABLED = "sagemaker_parameter_server_enabled"
SAGEMAKER_DISTRIBUTED_DATAPARALLEL_ENABLED = "sagemaker_distributed_dataparallel_enabled"
# Flag for the MultiWorkerMirroredStrategy (MWMS) cluster setup.
SAGEMAKER_MULTI_WORKER_MIRRORED_STRATEGY_ENABLED = (
    "sagemaker_multi_worker_mirrored_strategy_enabled"
)
# NOTE(review): presumably the container's model output directory; confirm against usages.
MODEL_DIR = "/opt/ml/model"


def _is_host_master(hosts, current_host):
return current_host == hosts[0]


def _build_tf_config(hosts, current_host, ps_task=False):
def _build_tf_config_for_ps(hosts, current_host, ps_task=False):
"""Builds a dictionary containing cluster information based on number of hosts and number of
parameter servers.

Expand Down Expand Up @@ -85,6 +88,31 @@ def host_addresses(hosts, port=2222):
return tf_config


def _build_tf_config_for_mwms(hosts, current_host):
"""Builds a dictionary containing cluster information based on number of workers
for Multi Worker Mirrored distribution strategy.

Args:
hosts (list[str]): List of host names in the cluster
current_host (str): Current host name

Returns:
dict[str: dict]: A dictionary describing the cluster setup for distributed training.
For more information regarding TF_CONFIG:
https://cloud.google.com/ml-engine/docs/tensorflow/distributed-training-details
"""
workers = hosts

def host_addresses(hosts, port=8890):
return ["{}:{}".format(host, port) for host in hosts]

tf_config = {"cluster": {}, "environment": "cloud"}
tf_config["cluster"]["worker"] = host_addresses(workers)
tf_config["task"] = {"index": workers.index(current_host), "type": "worker"}

return tf_config


def _run_ps(env, cluster):
logger.info("Running distributed training job with parameter servers")

Expand Down Expand Up @@ -134,17 +162,35 @@ def train(env, cmd_args):
Args:
env (sagemaker_training.environment.Environment): Instance of Environment class
"""
parameter_server_enabled = env.additional_framework_parameters.get(
SAGEMAKER_PARAMETER_SERVER_ENABLED, False
parameter_server_enabled = (
env.additional_framework_parameters.get(SAGEMAKER_PARAMETER_SERVER_ENABLED, False)
and len(env.hosts) > 1
)
multi_worker_mirrored_strategy_enabled = env.additional_framework_parameters.get(
SAGEMAKER_MULTI_WORKER_MIRRORED_STRATEGY_ENABLED, False
)
sagemaker_distributed_dataparallel_enabled = env.additional_framework_parameters.get(
SAGEMAKER_DISTRIBUTED_DATAPARALLEL_ENABLED, False
)
if len(env.hosts) > 1 and parameter_server_enabled:

tf_config = _build_tf_config(hosts=env.hosts, current_host=env.current_host)
env_vars = env.to_env_vars()

# Setup
if parameter_server_enabled:

tf_config = _build_tf_config_for_ps(hosts=env.hosts, current_host=env.current_host)
logger.info("Running distributed training job with parameter servers")

elif multi_worker_mirrored_strategy_enabled:

env_vars["TF_CONFIG"] = json.dumps(
_build_tf_config_for_mwms(hosts=env.hosts, current_host=env.current_host)
)
logger.info("Running distributed training job with multi_worker_mirrored_strategy setup")

# Run
if parameter_server_enabled:

logger.info("Launching parameter server process")
_run_ps(env, tf_config["cluster"])
logger.info("Launching worker process")
Expand All @@ -168,7 +214,7 @@ def train(env, cmd_args):
uri=env.module_dir,
user_entry_point=env.user_entry_point,
args=cmd_args,
env_vars=env.to_env_vars(),
env_vars=env_vars,
capture_error=True,
runner_type=runner_type,
)
Expand Down Expand Up @@ -217,8 +263,7 @@ def _model_dir_with_training_job(model_dir, job_name):


def main():
"""Training entry point
"""
"""Training entry point"""
hyperparameters = environment.read_hyperparameters()
env = environment.Environment(hyperparameters=hyperparameters)

Expand Down
42 changes: 42 additions & 0 deletions test/integration/sagemaker/test_multi_worker_mirrored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2017-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

from sagemaker.tensorflow import TensorFlow
from sagemaker.utils import unique_name_from_base


RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources")


def test_multi_node(sagemaker_session, instance_type, image_uri, tmpdir, framework_version, capsys):
    """Runs the dummy MWMS training script on two instances and checks the captured
    output for the MWMS setup log line and the exported TF_CONFIG.
    """
    script_path = os.path.join(RESOURCE_PATH, "multi_worker_mirrored", "train_dummy.py")
    mwms_hyperparameters = {"sagemaker_multi_worker_mirrored_strategy_enabled": True}

    # NOTE(review): SageMaker Python SDK v2 calls this argument `image_uri`; confirm that
    # `image_name` is still accepted by the SDK version these tests run against.
    estimator = TensorFlow(
        entry_point=script_path,
        role="SageMakerRole",
        instance_type=instance_type,
        instance_count=2,
        image_name=image_uri,
        framework_version=framework_version,
        py_version="py3",
        hyperparameters=mwms_hyperparameters,
        sagemaker_session=sagemaker_session,
    )
    estimator.fit(job_name=unique_name_from_base("test-tf-mwms"))

    # The job output is streamed to stdout/stderr, so inspect both streams together.
    output = capsys.readouterr()
    logs = output.out + output.err
    expected_fragments = (
        "Running distributed training job with multi_worker_mirrored_strategy setup",
        "TF_CONFIG=",
    )
    for fragment in expected_fragments:
        assert fragment in logs
13 changes: 13 additions & 0 deletions test/resources/multi_worker_mirrored/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import
48 changes: 48 additions & 0 deletions test/resources/multi_worker_mirrored/train_dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Please refer to https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/multi_worker_with_keras.ipynb

import tensorflow as tf
import numpy as np
import os
import json


def mnist_dataset(batch_size):
    """Return the MNIST training split as a shuffled, endlessly repeating, batched dataset.

    Args:
        batch_size: Number of examples per batch.

    Returns:
        A `tf.data.Dataset` yielding `(images, labels)` batches.
    """
    train_split, _ = tf.keras.datasets.mnist.load_data()
    images, labels = train_split

    # Pixels arrive as uint8 in [0, 255]; rescale them to float32 in [0, 1].
    images = images / np.float32(255)
    labels = labels.astype(np.int64)

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.shuffle(60000)
    dataset = dataset.repeat()
    return dataset.batch(batch_size)

def build_and_compile_cnn_model():
    """Build and compile a small CNN classifier for 28x28 inputs with 10 output logits.

    Returns:
        A compiled `tf.keras.Sequential` model using SGD and sparse categorical
        cross-entropy on logits, tracking accuracy.
    """
    layers = tf.keras.layers
    stack = [
        layers.InputLayer(input_shape=(28, 28)),
        layers.Reshape(target_shape=(28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        # Final layer emits raw logits; the loss below is configured accordingly.
        layers.Dense(10),
    ]
    cnn = tf.keras.Sequential(stack)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    sgd = tf.keras.optimizers.SGD(learning_rate=0.001)
    cnn.compile(loss=loss_fn, optimizer=sgd, metrics=["accuracy"])
    return cnn


# Script body: trains the CNN under MultiWorkerMirroredStrategy using the cluster
# described by the TF_CONFIG environment variable.

# Batch size handled by each individual worker.
per_worker_batch_size = 64
# TF_CONFIG must already be set in the environment; a missing variable raises KeyError.
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

strategy = tf.distribute.MultiWorkerMirroredStrategy()

# The dataset is batched with the global size: per-worker size times number of workers.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

with strategy.scope():
    # Model building/compiling need to be within `strategy.scope()`.
    multi_worker_model = build_and_compile_cnn_model()

# The dataset repeats indefinitely, so steps_per_epoch bounds each of the 3 epochs.
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
56 changes: 43 additions & 13 deletions test/unit/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
"worker": ["{}:2222".format(HOST2)],
"ps": ["{}:2223".format(HOST1), "{}:2223".format(HOST2)],
}
CLUSTER_WITH_MWMS = {"worker": ["{}:8890".format(HOST) for HOST in HOST_LIST]}

MASTER_TASK = {"index": 0, "type": "master"}
WORKER_TASK = {"index": 0, "type": "worker"}
PS_TASK_1 = {"index": 0, "type": "ps"}
Expand Down Expand Up @@ -109,7 +111,9 @@ def test_train_horovod(run_module, single_machine_training_env):

@patch("sagemaker_training.entry_point.run")
def test_train_smdataparallel(run_module, single_machine_training_env):
single_machine_training_env.additional_framework_parameters["sagemaker_distributed_dataparallel_enabled"] = True
single_machine_training_env.additional_framework_parameters[
"sagemaker_distributed_dataparallel_enabled"
] = True

training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
run_module.assert_called_with(
Expand All @@ -124,7 +128,8 @@ def test_train_smdataparallel(run_module, single_machine_training_env):

@pytest.mark.skip_on_pipeline
@pytest.mark.skipif(
sys.version_info.major != 3, reason="Skip this for python 2 because of dict key order mismatch"
sys.version_info.major != 3,
reason="Skip this for python 2 because of dict key order mismatch",
)
@patch("tensorflow.train.ClusterSpec")
@patch("tensorflow.distribute.Server")
Expand All @@ -135,7 +140,11 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
training.train(distributed_training_env, MODEL_DIR_CMD_LIST)

cluster_spec.assert_called_with(
{"worker": ["host2:2222"], "master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"]}
{
"worker": ["host2:2222"],
"master": ["host1:2222"],
"ps": ["host1:2223", "host2:2223"],
}
)

tf_server.assert_called_with(
Expand Down Expand Up @@ -166,7 +175,8 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai

@pytest.mark.skip_on_pipeline
@pytest.mark.skipif(
sys.version_info.major != 3, reason="Skip this for python 2 because of dict key order mismatch"
sys.version_info.major != 3,
reason="Skip this for python 2 because of dict key order mismatch",
)
@patch("tensorflow.train.ClusterSpec")
@patch("tensorflow.distribute.Server")
Expand All @@ -179,7 +189,11 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
training.train(distributed_training_env, MODEL_DIR_CMD_LIST)

cluster_spec.assert_called_with(
{"worker": ["host2:2222"], "master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"]}
{
"worker": ["host2:2222"],
"master": ["host1:2222"],
"ps": ["host1:2223", "host2:2223"],
}
)

tf_server.assert_called_with(
Expand Down Expand Up @@ -226,32 +240,45 @@ def test_train_distributed_no_ps(run, distributed_training_env):
)


def test_build_tf_config():
assert training._build_tf_config(HOST_LIST, HOST1) == {
def test_build_tf_config_for_mwms():
    """Each host gets the shared MWMS cluster and a worker task indexed by its position."""
    for host in (HOST1, HOST2):
        expected = {
            "cluster": CLUSTER_WITH_MWMS,
            "environment": "cloud",
            "task": {"index": HOST_LIST.index(host), "type": "worker"},
        }
        assert training._build_tf_config_for_mwms(HOST_LIST, host) == expected


def test_build_tf_config_for_ps():
assert training._build_tf_config_for_ps(HOST_LIST, HOST1) == {
"cluster": CLUSTER_WITH_PS,
"environment": "cloud",
"task": MASTER_TASK,
}
assert training._build_tf_config(HOST_LIST, HOST1, ps_task=True) == {
assert training._build_tf_config_for_ps(HOST_LIST, HOST1, ps_task=True) == {
"cluster": CLUSTER_WITH_PS,
"environment": "cloud",
"task": PS_TASK_1,
}
assert training._build_tf_config(HOST_LIST, HOST2) == {
assert training._build_tf_config_for_ps(HOST_LIST, HOST2) == {
"cluster": CLUSTER_WITH_PS,
"environment": "cloud",
"task": WORKER_TASK,
}
assert training._build_tf_config(HOST_LIST, HOST2, ps_task=True) == {
assert training._build_tf_config_for_ps(HOST_LIST, HOST2, ps_task=True) == {
"cluster": CLUSTER_WITH_PS,
"environment": "cloud",
"task": PS_TASK_2,
}


def test_build_tf_config_error():
def test_build_tf_config_for_ps_error():
with pytest.raises(ValueError) as error:
training._build_tf_config([HOST1], HOST1, ps_task=True)
training._build_tf_config_for_ps([HOST1], HOST1, ps_task=True)
assert "Cannot have a ps task if there are no parameter servers in the cluster" in str(
error.value
)
Expand Down Expand Up @@ -327,7 +354,10 @@ def test_main(
@patch("sagemaker_tensorflow_container.training.train")
@patch("logging.Logger.setLevel")
@patch("sagemaker_training.environment.Environment")
@patch("sagemaker_training.environment.read_hyperparameters", return_value={"model_dir": MODEL_DIR})
@patch(
"sagemaker_training.environment.read_hyperparameters",
return_value={"model_dir": MODEL_DIR},
)
@patch("sagemaker_tensorflow_container.s3_utils.configure")
def test_main_simple_training_model_dir(
configure_s3_env,
Expand Down