@@ -179,7 +179,9 @@ def test_additional_hyperparameters(sagemaker_session, chainer_version, chainer_
179
179
)
180
180
181
181
182
- def test_attach_with_additional_hyperparameters (sagemaker_session , chainer_version , chainer_py_version ):
182
+ def test_attach_with_additional_hyperparameters (
183
+ sagemaker_session , chainer_version , chainer_py_version
184
+ ):
183
185
training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-cpu-{}" .format (
184
186
chainer_version , chainer_py_version
185
187
)
@@ -388,7 +390,9 @@ def test_model(sagemaker_session, chainer_version, chainer_py_version):
388
390
389
391
390
392
@patch ("sagemaker.fw_utils.tar_and_upload_dir" , MagicMock ())
391
- def test_model_prepare_container_def_accelerator_error (sagemaker_session , chainer_version , chainer_py_version ):
393
+ def test_model_prepare_container_def_accelerator_error (
394
+ sagemaker_session , chainer_version , chainer_py_version
395
+ ):
392
396
model = ChainerModel (
393
397
MODEL_DATA ,
394
398
role = ROLE ,
@@ -433,29 +437,44 @@ def test_train_image_default(sagemaker_session, chainer_version, chainer_py_vers
433
437
434
438
def test_train_image_cpu_instances (sagemaker_session , chainer_version , chainer_py_version ):
435
439
chainer = _chainer_estimator (
436
- sagemaker_session , framework_version = chainer_version , py_version = chainer_py_version , train_instance_type = "ml.c2.2xlarge"
440
+ sagemaker_session ,
441
+ framework_version = chainer_version ,
442
+ py_version = chainer_py_version ,
443
+ train_instance_type = "ml.c2.2xlarge" ,
437
444
)
438
445
assert chainer .train_image () == _get_full_cpu_image_uri (chainer_version , chainer_py_version )
439
446
440
447
chainer = _chainer_estimator (
441
- sagemaker_session , framework_version = chainer_version , py_version = chainer_py_version , train_instance_type = "ml.c4.2xlarge"
448
+ sagemaker_session ,
449
+ framework_version = chainer_version ,
450
+ py_version = chainer_py_version ,
451
+ train_instance_type = "ml.c4.2xlarge" ,
442
452
)
443
453
assert chainer .train_image () == _get_full_cpu_image_uri (chainer_version , chainer_py_version )
444
454
445
455
chainer = _chainer_estimator (
446
- sagemaker_session , framework_version = chainer_version , py_version = chainer_py_version , train_instance_type = "ml.m16"
456
+ sagemaker_session ,
457
+ framework_version = chainer_version ,
458
+ py_version = chainer_py_version ,
459
+ train_instance_type = "ml.m16" ,
447
460
)
448
461
assert chainer .train_image () == _get_full_cpu_image_uri (chainer_version , chainer_py_version )
449
462
450
463
451
464
def test_train_image_gpu_instances (sagemaker_session , chainer_version , chainer_py_version ):
452
465
chainer = _chainer_estimator (
453
- sagemaker_session , framework_version = chainer_version , py_version = chainer_py_version , train_instance_type = "ml.g2.2xlarge"
466
+ sagemaker_session ,
467
+ framework_version = chainer_version ,
468
+ py_version = chainer_py_version ,
469
+ train_instance_type = "ml.g2.2xlarge" ,
454
470
)
455
471
assert chainer .train_image () == _get_full_gpu_image_uri (chainer_version , chainer_py_version )
456
472
457
473
chainer = _chainer_estimator (
458
- sagemaker_session , framework_version = chainer_version , py_version = chainer_py_version , train_instance_type = "ml.p2.2xlarge"
474
+ sagemaker_session ,
475
+ framework_version = chainer_version ,
476
+ py_version = chainer_py_version ,
477
+ train_instance_type = "ml.p2.2xlarge" ,
459
478
)
460
479
assert chainer .train_image () == _get_full_gpu_image_uri (chainer_version , chainer_py_version )
461
480
0 commit comments