@@ -286,6 +286,64 @@ def test_default(
286
286
actual_train_args == expected_train_args
287
287
), f"{ json .dumps (actual_train_args , indent = 2 )} != { json .dumps (expected_train_args , indent = 2 )} "
288
288
289
+ def test_byoc (
290
+ self ,
291
+ time ,
292
+ name_from_base ,
293
+ sagemaker_session ,
294
+ tensorflow_training_version ,
295
+ tensorflow_training_py_version ,
296
+ instance_class ,
297
+ ):
298
+ compiler_config = TrainingCompilerConfig ()
299
+ instance_type = f"ml.{ instance_class } .2xlarge"
300
+
301
+ tf = TensorFlow (
302
+ py_version = tensorflow_training_py_version ,
303
+ entry_point = SCRIPT_PATH ,
304
+ role = ROLE ,
305
+ sagemaker_session = sagemaker_session ,
306
+ instance_count = INSTANCE_COUNT ,
307
+ instance_type = instance_type ,
308
+ image_uri = _get_full_gpu_image_uri (
309
+ tensorflow_training_version ,
310
+ instance_type ,
311
+ compiler_config ,
312
+ tensorflow_training_py_version ,
313
+ ),
314
+ enable_sagemaker_metrics = False ,
315
+ compiler_config = compiler_config ,
316
+ )
317
+
318
+ inputs = "s3://mybucket/train"
319
+
320
+ tf .fit (inputs = inputs , experiment_config = EXPERIMENT_CONFIG )
321
+
322
+ sagemaker_call_names = [c [0 ] for c in sagemaker_session .method_calls ]
323
+ assert sagemaker_call_names == ["train" , "logs_for_job" ]
324
+ boto_call_names = [c [0 ] for c in sagemaker_session .boto_session .method_calls ]
325
+ assert boto_call_names == ["resource" ]
326
+
327
+ expected_train_args = _create_train_job (
328
+ tensorflow_training_version ,
329
+ instance_type ,
330
+ compiler_config ,
331
+ tensorflow_training_py_version ,
332
+ )
333
+ expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
334
+ expected_train_args ["enable_sagemaker_metrics" ] = False
335
+ expected_train_args ["hyperparameters" ][
336
+ TrainingCompilerConfig .HP_ENABLE_COMPILER
337
+ ] = json .dumps (True )
338
+ expected_train_args ["hyperparameters" ][TrainingCompilerConfig .HP_ENABLE_DEBUG ] = json .dumps (
339
+ False
340
+ )
341
+
342
+ actual_train_args = sagemaker_session .method_calls [0 ][2 ]
343
+ assert (
344
+ actual_train_args == expected_train_args
345
+ ), f"{ json .dumps (actual_train_args , indent = 2 )} != { json .dumps (expected_train_args , indent = 2 )} "
346
+
289
347
def test_debug_compiler_config (
290
348
self ,
291
349
time ,
0 commit comments