Commit ff0e615

Black formatted files
1 parent c24c5f5 commit ff0e615

4 files changed: 94 additions, 56 deletions
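
Every hunk below shows the same Black conventions: double-quoted strings, calls and literals that overflow the line-length limit split one element per line with a trailing comma, and blank-line cleanup. A minimal before/after illustration of the pattern (hypothetical code, not taken from this repo; the default 88-character line length is an assumption, since no Black configuration appears in the commit):

# Before: single quotes and an over-long dict literal on one line.
hyperparameters = {'sagemaker_multi_worker_mirrored_enabled': True, 'sagemaker_parameter_server_enabled': False, 'model_dir': '/opt/ml/model'}

# After Black: double quotes; the literal no longer fits, so it is exploded one entry per line.
hyperparameters = {
    "sagemaker_multi_worker_mirrored_enabled": True,
    "sagemaker_parameter_server_enabled": False,
    "model_dir": "/opt/ml/model",
}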


src/sagemaker_tensorflow_container/training.py

Lines changed: 5 additions & 5 deletions
@@ -165,22 +165,22 @@ def train(env, cmd_args):
     multi_worker_mirrored_enabled = env.additional_framework_parameters.get(
         SAGEMAKER_MULTI_WORKER_MIRRORED_ENABLED, False
     )
-
+
     # Setup
     if parameter_server_enabled:
-
+
         tf_config = _build_tf_config_for_ps(hosts=env.hosts, current_host=env.current_host)
         logger.info("Running distributed training job with parameter servers")
-
+
     elif multi_worker_mirrored_enabled:
-
+
         tf_config = _build_tf_config_for_mwm(hosts=env.hosts, current_host=env.current_host)
         logger.info("Running distributed training job with multi_worker_mirrored setup")


     # Run
     if parameter_server_enabled:
-
+
         logger.info("Launching parameter server process")
         _run_ps(env, tf_config["cluster"])
         logger.info("Launching worker process")
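
The hunk above wires _build_tf_config_for_mwm into train(), but the helper's body falls outside this diff. For orientation, a minimal sketch of the TF_CONFIG layout that tf.distribute.MultiWorkerMirroredStrategy expects; the helper name, port 8890, and host names are borrowed from the unit-test constants further down, and the exact structure the real function returns is an assumption:

import json
import os


def build_mwm_tf_config(hosts, current_host, port=8890):
    # Illustrative only: every host becomes a worker, and the current host's
    # position in the list supplies the task index.
    workers = ["{}:{}".format(host, port) for host in hosts]
    return {
        "cluster": {"worker": workers},
        "task": {"type": "worker", "index": hosts.index(current_host)},
    }


# MultiWorkerMirroredStrategy reads this layout from the TF_CONFIG environment variable.
os.environ["TF_CONFIG"] = json.dumps(build_mwm_tf_config(["host1", "host2"], "host1"))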

test/integration/sagemaker/test_multi_worker_mirrored.py

Lines changed: 18 additions & 21 deletions
@@ -14,33 +14,30 @@

 import os

-import boto3
-import pytest
 from sagemaker.tensorflow import TensorFlow
 from sagemaker.utils import unique_name_from_base
-from six.moves.urllib.parse import urlparse
-
-from timeout import timeout
-


 RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources")


-
-def test_multi_node(sagemaker_session, instance_type, image_uri, tmpdir, framework_version):
+def test_multi_node(
+    sagemaker_session, instance_type, image_uri, tmpdir, framework_version
+):
     estimator = TensorFlow(
-        entry_point=os.path.join(RESOURCE_PATH, "multi_worker_mirrored", "train_sample.py"),
-        role="SageMakerRole",
-        instance_type=instance_type,
-        instance_count=2,
-        image_name=image_uri,
-        framework_version=framework_version,
-        py_version="py3",
-        hyperparameters={
-            'sagemaker_multi_worker_mirrored_enabled': True,
-        },
-        sagemaker_session=sagemaker_session,
-    )
+        entry_point=os.path.join(
+            RESOURCE_PATH, "multi_worker_mirrored", "train_sample.py"
+        ),
+        role="SageMakerRole",
+        instance_type=instance_type,
+        instance_count=2,
+        image_name=image_uri,
+        framework_version=framework_version,
+        py_version="py3",
+        hyperparameters={
+            "sagemaker_multi_worker_mirrored_enabled": True,
+        },
+        sagemaker_session=sagemaker_session,
+    )
     estimator.fit(job_name=unique_name_from_base("test-tf-mwms"))
-    raise NotImplementedError('Yet to add assertion')
+    raise NotImplementedError("Yet to add assertion")
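
The test still ends in a NotImplementedError placeholder. One possible assertion to replace it, sketched here but not part of this commit (estimator.fit already blocks and raises on a failed job, so this only makes the success check explicit):

    job_name = unique_name_from_base("test-tf-mwms")
    estimator.fit(job_name=job_name)

    # Possible follow-up check: confirm the job reports Completed through the boto3
    # client exposed on the SageMaker session.
    desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=job_name)
    assert desc["TrainingJobStatus"] == "Completed"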

test/resources/multi_worker_mirrored/train_sample.py

Lines changed: 14 additions & 9 deletions
@@ -1,20 +1,25 @@
 import tensorflow as tf
-
+import numpy as np


 strategy = tf.distribute.MultiWorkerMirroredStrategy()

 with strategy.scope():
-    model = tf.keras.Sequential([
-        tf.keras.layers.Dense(2, input_shape=(5,)),
-    ])
-    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
+    model = tf.keras.Sequential(
+        [
+            tf.keras.layers.Dense(2, input_shape=(5,)),
+        ]
+    )
+    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
+

 def dataset_fn(ctx):
-    x = np.random.random((2, 5)).astype(np.float32)
-    y = np.random.randint(2, size=(2, 1))
-    dataset = tf.data.Dataset.from_tensor_slices((x, y))
-    return dataset.repeat().batch(1, drop_remainder=True)
+    x = np.random.random((2, 5)).astype(np.float32)
+    y = np.random.randint(2, size=(2, 1))
+    dataset = tf.data.Dataset.from_tensor_slices((x, y))
+    return dataset.repeat().batch(1, drop_remainder=True)
+
+
 dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

 model.compile()
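
As committed, the sample stops at model.compile(): it builds the strategy, model, optimizer, and a distributed dataset, but never runs a training step. A sketch of how such a sample is typically driven under MultiWorkerMirroredStrategy (not part of the committed file; the loss choice and step count are assumptions):

@tf.function
def train_step(inputs):
    features, labels = inputs

    def step_fn(x, y):
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            loss = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(
                    tf.reshape(y, [-1]), logits, from_logits=True
                )
            )
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    # strategy.run executes step_fn on each replica with that replica's shard of the batch.
    return strategy.run(step_fn, args=(features, labels))


for step, batch in enumerate(dist_dataset):
    if step >= 5:  # a handful of steps is enough for a smoke test
        break
    train_step(batch)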

test/unit/test_training.py

Lines changed: 57 additions & 21 deletions
@@ -35,9 +35,7 @@
     "worker": ["{}:8890".format(HOST2)],
     "ps": ["{}:2223".format(HOST1), "{}:2223".format(HOST2)],
 }
-CLUSTER_WITH_MWMS = {
-    "worker": ["{}:8890".format(HOST) for HOST in (HOST1, HOST2)],
-}
+CLUSTER_WITH_MWMS = {"worker": ["{}:8890".format(HOST) for HOST in HOST_LIST]}

 MASTER_TASK = {"index": 0, "type": "master"}
 WORKER_TASK = {"index": 0, "type": "worker"}
@@ -54,7 +52,9 @@ def distributed_training_env():
     env = simple_training_env()

     env.hosts = HOST_LIST
-    env.additional_framework_parameters = {training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True}
+    env.additional_framework_parameters = {
+        training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True
+    }
     return env


@@ -98,7 +98,9 @@ def test_single_machine(run_module, single_machine_training_env):

 @patch("sagemaker_training.entry_point.run")
 def test_train_horovod(run_module, single_machine_training_env):
-    single_machine_training_env.additional_framework_parameters["sagemaker_mpi_enabled"] = True
+    single_machine_training_env.additional_framework_parameters[
+        "sagemaker_mpi_enabled"
+    ] = True

     training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
     run_module.assert_called_with(
@@ -113,22 +115,32 @@ def test_train_horovod(run_module, single_machine_training_env):

 @pytest.mark.skip_on_pipeline
 @pytest.mark.skipif(
-    sys.version_info.major != 3, reason="Skip this for python 2 because of dict key order mismatch"
+    sys.version_info.major != 3,
+    reason="Skip this for python 2 because of dict key order mismatch",
 )
 @patch("tensorflow.train.ClusterSpec")
 @patch("tensorflow.train.Server")
 @patch("sagemaker_training.entry_point.run")
 @patch("multiprocessing.Process", lambda target: target())
 @patch("time.sleep", MagicMock())
-def test_train_distributed_master(run, tf_server, cluster_spec, distributed_training_env):
+def test_train_distributed_master(
+    run, tf_server, cluster_spec, distributed_training_env
+):
     training.train(distributed_training_env, MODEL_DIR_CMD_LIST)

     cluster_spec.assert_called_with(
-        {"worker": ["host2:2222"], "master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"]}
+        {
+            "worker": ["host2:2222"],
+            "master": ["host1:2222"],
+            "ps": ["host1:2223", "host2:2223"],
+        }
     )

     tf_server.assert_called_with(
-        cluster_spec(), job_name="ps", task_index=0, config=tf.ConfigProto(device_count={"GPU": 0})
+        cluster_spec(),
+        job_name="ps",
+        task_index=0,
+        config=tf.ConfigProto(device_count={"GPU": 0}),
     )
     tf_server().join.assert_called_with()

@@ -152,24 +164,34 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_training_env):

 @pytest.mark.skip_on_pipeline
 @pytest.mark.skipif(
-    sys.version_info.major != 3, reason="Skip this for python 2 because of dict key order mismatch"
+    sys.version_info.major != 3,
+    reason="Skip this for python 2 because of dict key order mismatch",
 )
 @patch("tensorflow.train.ClusterSpec")
 @patch("tensorflow.train.Server")
 @patch("sagemaker_training.entry_point.run")
 @patch("multiprocessing.Process", lambda target: target())
 @patch("time.sleep", MagicMock())
-def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_training_env):
+def test_train_distributed_worker(
+    run, tf_server, cluster_spec, distributed_training_env
+):
     distributed_training_env.current_host = HOST2

     training.train(distributed_training_env, MODEL_DIR_CMD_LIST)

     cluster_spec.assert_called_with(
-        {"worker": ["host2:2222"], "master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"]}
+        {
+            "worker": ["host2:2222"],
+            "master": ["host1:2222"],
+            "ps": ["host1:2223", "host2:2223"],
+        }
     )

     tf_server.assert_called_with(
-        cluster_spec(), job_name="ps", task_index=1, config=tf.ConfigProto(device_count={"GPU": 0})
+        cluster_spec(),
+        job_name="ps",
+        task_index=1,
+        config=tf.ConfigProto(device_count={"GPU": 0}),
     )
     tf_server().join.assert_called_with()

@@ -248,8 +270,9 @@ def test_build_tf_config_for_ps():
 def test_build_tf_config_for_ps_error():
     with pytest.raises(ValueError) as error:
         training._build_tf_config_for_ps([HOST1], HOST1, ps_task=True)
-    assert "Cannot have a ps task if there are no parameter servers in the cluster" in str(
-        error.value
+    assert (
+        "Cannot have a ps task if there are no parameter servers in the cluster"
+        in str(error.value)
     )


@@ -271,7 +294,9 @@ def test_log_model_missing_warning_no_model(logger):

 @patch("sagemaker_tensorflow_container.training.logger")
 def test_log_model_missing_warning_wrong_format(logger):
-    training._log_model_missing_warning(os.path.join(RESOURCE_PATH, "test_dir_wrong_model"))
+    training._log_model_missing_warning(
+        os.path.join(RESOURCE_PATH, "test_dir_wrong_model")
+    )
     logger.warn.assert_called_with(
         "Your model will NOT be servable with SageMaker TensorFlow Serving container. "
         "The model artifact was not saved in the TensorFlow "
@@ -282,16 +307,22 @@ def test_log_model_missing_warning_wrong_format(logger):

 @patch("sagemaker_tensorflow_container.training.logger")
 def test_log_model_missing_warning_wrong_parent_dir(logger):
-    training._log_model_missing_warning(os.path.join(RESOURCE_PATH, "test_dir_wrong_parent_dir"))
+    training._log_model_missing_warning(
+        os.path.join(RESOURCE_PATH, "test_dir_wrong_parent_dir")
+    )
     logger.warn.assert_called_with(
         "Your model will NOT be servable with SageMaker TensorFlow Serving containers. "
-        'The SavedModel bundle is under directory "{}", not a numeric name.'.format("not-digit")
+        'The SavedModel bundle is under directory "{}", not a numeric name.'.format(
+            "not-digit"
+        )
     )


 @patch("sagemaker_tensorflow_container.training.logger")
 def test_log_model_missing_warning_correct(logger):
-    training._log_model_missing_warning(os.path.join(RESOURCE_PATH, "test_dir_correct_model"))
+    training._log_model_missing_warning(
+        os.path.join(RESOURCE_PATH, "test_dir_correct_model")
+    )
     logger.warn.assert_not_called()


@@ -323,7 +354,10 @@ def test_main(
 @patch("sagemaker_tensorflow_container.training.train")
 @patch("logging.Logger.setLevel")
 @patch("sagemaker_training.environment.Environment")
-@patch("sagemaker_training.environment.read_hyperparameters", return_value={"model_dir": MODEL_DIR})
+@patch(
+    "sagemaker_training.environment.read_hyperparameters",
+    return_value={"model_dir": MODEL_DIR},
+)
 @patch("sagemaker_tensorflow_container.s3_utils.configure")
 def test_main_simple_training_model_dir(
     configure_s3_env,
@@ -361,7 +395,9 @@ def test_main_tuning_model_dir(
     training_env.return_value = single_machine_training_env
     os.environ["SAGEMAKER_REGION"] = REGION
     training.main()
-    expected_model_dir = "{}/{}/model".format(MODEL_DIR, single_machine_training_env.job_name)
+    expected_model_dir = "{}/{}/model".format(
+        MODEL_DIR, single_machine_training_env.job_name
+    )
     configure_s3_env.assert_called_once_with(expected_model_dir, REGION)
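
The hunks above add the CLUSTER_WITH_MWMS constant, but no test exercising the new multi_worker_mirrored path appears in this diff. A sketch of what such a test could look like, modeled on the existing test_build_tf_config_for_ps tests (the return structure of _build_tf_config_for_mwm is an assumption; only its hosts/current_host signature is visible in training.py above):

def test_build_tf_config_for_mwm():
    # Assumed shape: all hosts become workers on port 8890 (matching CLUSTER_WITH_MWMS)
    # and the current host's position in HOST_LIST supplies the task index.
    tf_config = training._build_tf_config_for_mwm(hosts=HOST_LIST, current_host=HOST1)

    assert tf_config["cluster"] == CLUSTER_WITH_MWMS
    assert tf_config["task"] == {"index": 0, "type": "worker"}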