34
34
@pytest .fixture ()
35
35
def config_file_as_yaml (get_data_dir ):
36
36
config_file_path = os .path .join (get_data_dir , "config.yaml" )
37
- return open (config_file_path , "r" ).read ()
37
+ with open (config_file_path , "r" ) as f :
38
+ content = f .read ()
39
+ return content
38
40
39
41
40
42
@pytest .fixture ()
41
43
def expected_merged_config (get_data_dir ):
42
44
expected_merged_config_file_path = os .path .join (
43
45
get_data_dir , "expected_output_config_after_merge.yaml"
44
46
)
45
- return yaml .safe_load (open (expected_merged_config_file_path , "r" ).read ())
47
+ with open (expected_merged_config_file_path , "r" ) as f :
48
+ content = yaml .safe_load (f .read ())
49
+ return content
50
+
51
+
52
+ def _raise_valueerror (* args ):
53
+ raise ValueError (args )
46
54
47
55
48
56
def test_config_when_default_config_file_and_user_config_file_is_not_found ():
@@ -60,7 +68,8 @@ def test_config_when_overriden_default_config_file_is_not_found(get_data_dir):
60
68
def test_invalid_config_file_which_has_python_code (get_data_dir ):
61
69
invalid_config_file_path = os .path .join (get_data_dir , "config_file_with_code.yaml" )
62
70
# no exceptions will be thrown with yaml.unsafe_load
63
- yaml .unsafe_load (open (invalid_config_file_path , "r" ))
71
+ with open (invalid_config_file_path , "r" ) as f :
72
+ yaml .unsafe_load (f )
64
73
# PyYAML will throw exceptions for yaml.safe_load. SageMaker Config is using
65
74
# yaml.safe_load internally
66
75
with pytest .raises (ConstructorError ) as exception_info :
@@ -228,7 +237,8 @@ def test_merge_of_s3_default_config_file_and_regular_config_file(
228
237
get_data_dir , expected_merged_config , s3_resource_mock
229
238
):
230
239
config_file_content_path = os .path .join (get_data_dir , "sample_config_for_merge.yaml" )
231
- config_file_as_yaml = open (config_file_content_path , "r" ).read ()
240
+ with open (config_file_content_path , "r" ) as f :
241
+ config_file_as_yaml = f .read ()
232
242
config_file_bucket = "config-file-bucket"
233
243
config_file_s3_prefix = "config/config.yaml"
234
244
config_file_s3_uri = "s3://{}/{}" .format (config_file_bucket , config_file_s3_prefix )
@@ -440,8 +450,11 @@ def test_load_local_mode_config(mock_load_config):
440
450
mock_load_config .assert_called_with (_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH )
441
451
442
452
443
- def test_load_local_mode_config_when_config_file_is_not_found ():
453
+ @patch ("sagemaker.config.config._load_config_from_file" , side_effect = _raise_valueerror )
454
+ def test_load_local_mode_config_when_config_file_is_not_found (mock_load_config ):
455
+ # Patch is needed because one might actually have a local config file
444
456
assert load_local_mode_config () is None
457
+ mock_load_config .assert_called_with (_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH )
445
458
446
459
447
460
@pytest .mark .parametrize (
0 commit comments