Skip to content

Commit 02697a0

Browse files
committed
Adding tests for the TF trcomp BYOC path
1 parent 26729be commit 02697a0

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py

+58
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,64 @@ def test_default(
286286
actual_train_args == expected_train_args
287287
), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
288288

289+
def test_byoc(
290+
self,
291+
time,
292+
name_from_base,
293+
sagemaker_session,
294+
tensorflow_training_version,
295+
tensorflow_training_py_version,
296+
instance_class,
297+
):
298+
compiler_config = TrainingCompilerConfig()
299+
instance_type = f"ml.{instance_class}.2xlarge"
300+
301+
tf = TensorFlow(
302+
py_version=tensorflow_training_py_version,
303+
entry_point=SCRIPT_PATH,
304+
role=ROLE,
305+
sagemaker_session=sagemaker_session,
306+
instance_count=INSTANCE_COUNT,
307+
instance_type=instance_type,
308+
image_uri=_get_full_gpu_image_uri(
309+
tensorflow_training_version,
310+
instance_type,
311+
compiler_config,
312+
tensorflow_training_py_version,
313+
),
314+
enable_sagemaker_metrics=False,
315+
compiler_config=compiler_config,
316+
)
317+
318+
inputs = "s3://mybucket/train"
319+
320+
tf.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG)
321+
322+
sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
323+
assert sagemaker_call_names == ["train", "logs_for_job"]
324+
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
325+
assert boto_call_names == ["resource"]
326+
327+
expected_train_args = _create_train_job(
328+
tensorflow_training_version,
329+
instance_type,
330+
compiler_config,
331+
tensorflow_training_py_version,
332+
)
333+
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
334+
expected_train_args["enable_sagemaker_metrics"] = False
335+
expected_train_args["hyperparameters"][
336+
TrainingCompilerConfig.HP_ENABLE_COMPILER
337+
] = json.dumps(True)
338+
expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
339+
False
340+
)
341+
342+
actual_train_args = sagemaker_session.method_calls[0][2]
343+
assert (
344+
actual_train_args == expected_train_args
345+
), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
346+
289347
def test_debug_compiler_config(
290348
self,
291349
time,

0 commit comments

Comments
 (0)