Skip to content

Commit 1fb2fcd

Browse files
committed
fix: properly close sagemaker config file after loading config
Closes aws#4456
1 parent c799d1a commit 1fb2fcd

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

src/sagemaker/config/config.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ def _load_config_from_file(file_path: str) -> dict:
181181
f"Provide a valid file path"
182182
)
183183
logger.debug("Fetching defaults config from location: %s", file_path)
184-
return yaml.safe_load(open(inferred_file_path, "r"))
184+
with open(inferred_file_path, "r") as f:
185+
content = yaml.safe_load(f)
186+
return content
185187

186188

187189
def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict:

tests/unit/sagemaker/config/test_config.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,23 @@
3434
@pytest.fixture()
3535
def config_file_as_yaml(get_data_dir):
3636
config_file_path = os.path.join(get_data_dir, "config.yaml")
37-
return open(config_file_path, "r").read()
37+
with open(config_file_path, "r") as f:
38+
content = f.read()
39+
return content
3840

3941

4042
@pytest.fixture()
4143
def expected_merged_config(get_data_dir):
4244
expected_merged_config_file_path = os.path.join(
4345
get_data_dir, "expected_output_config_after_merge.yaml"
4446
)
45-
return yaml.safe_load(open(expected_merged_config_file_path, "r").read())
47+
with open(expected_merged_config_file_path, "r") as f:
48+
content = yaml.safe_load(f.read())
49+
return content
50+
51+
52+
def _raise_valueerror(*args):
53+
raise ValueError(args)
4654

4755

4856
def test_config_when_default_config_file_and_user_config_file_is_not_found():
@@ -60,7 +68,8 @@ def test_config_when_overriden_default_config_file_is_not_found(get_data_dir):
6068
def test_invalid_config_file_which_has_python_code(get_data_dir):
6169
invalid_config_file_path = os.path.join(get_data_dir, "config_file_with_code.yaml")
6270
# no exceptions will be thrown with yaml.unsafe_load
63-
yaml.unsafe_load(open(invalid_config_file_path, "r"))
71+
with open(invalid_config_file_path, "r") as f:
72+
yaml.unsafe_load(f)
6473
# PyYAML will throw exceptions for yaml.safe_load. SageMaker Config is using
6574
# yaml.safe_load internally
6675
with pytest.raises(ConstructorError) as exception_info:
@@ -228,7 +237,8 @@ def test_merge_of_s3_default_config_file_and_regular_config_file(
228237
get_data_dir, expected_merged_config, s3_resource_mock
229238
):
230239
config_file_content_path = os.path.join(get_data_dir, "sample_config_for_merge.yaml")
231-
config_file_as_yaml = open(config_file_content_path, "r").read()
240+
with open(config_file_content_path, "r") as f:
241+
config_file_as_yaml = f.read()
232242
config_file_bucket = "config-file-bucket"
233243
config_file_s3_prefix = "config/config.yaml"
234244
config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix)
@@ -440,8 +450,11 @@ def test_load_local_mode_config(mock_load_config):
440450
mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
441451

442452

443-
def test_load_local_mode_config_when_config_file_is_not_found():
453+
@patch("sagemaker.config.config._load_config_from_file", side_effect=_raise_valueerror)
454+
def test_load_local_mode_config_when_config_file_is_not_found(mock_load_config):
455+
# Patch is needed because one might actually have a local config file
444456
assert load_local_mode_config() is None
457+
mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
445458

446459

447460
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)