From 78cb375fe992ef604215ba82cf0ddd064b092c58 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 4 Oct 2022 07:53:25 +0300 Subject: [PATCH] TST: use `with` where possible instead of manual `close` Coincidentally fixes some StataReaders being left open in tests. --- pandas/tests/io/parser/test_encoding.py | 10 +- .../tests/io/pytables/test_file_handling.py | 23 ++-- pandas/tests/io/pytables/test_read.py | 18 +-- pandas/tests/io/pytables/test_select.py | 28 ++--- pandas/tests/io/pytables/test_store.py | 5 +- pandas/tests/io/test_common.py | 10 +- pandas/tests/io/test_compression.py | 11 +- pandas/tests/io/test_fsspec.py | 24 ++-- pandas/tests/io/test_sql.py | 40 +++---- pandas/tests/io/test_stata.py | 109 +++++++++--------- 10 files changed, 127 insertions(+), 151 deletions(-) diff --git a/pandas/tests/io/parser/test_encoding.py b/pandas/tests/io/parser/test_encoding.py index c06ac9e76bd7f..0dff14dad21c4 100644 --- a/pandas/tests/io/parser/test_encoding.py +++ b/pandas/tests/io/parser/test_encoding.py @@ -66,13 +66,9 @@ def test_utf16_bom_skiprows(all_parsers, sep, encoding): with open(path, "wb") as f: f.write(bytes_data) - bytes_buffer = BytesIO(data.encode(utf8)) - bytes_buffer = TextIOWrapper(bytes_buffer, encoding=utf8) - - result = parser.read_csv(path, encoding=encoding, **kwargs) - expected = parser.read_csv(bytes_buffer, encoding=utf8, **kwargs) - - bytes_buffer.close() + with TextIOWrapper(BytesIO(data.encode(utf8)), encoding=utf8) as bytes_buffer: + result = parser.read_csv(path, encoding=encoding, **kwargs) + expected = parser.read_csv(bytes_buffer, encoding=utf8, **kwargs) tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/io/pytables/test_file_handling.py b/pandas/tests/io/pytables/test_file_handling.py index 9b20820e355a6..28c6d13d58043 100644 --- a/pandas/tests/io/pytables/test_file_handling.py +++ b/pandas/tests/io/pytables/test_file_handling.py @@ -41,9 +41,8 @@ def test_mode(setup_path, tmp_path, mode): HDFStore(path, mode=mode) else: - store = HDFStore(path, mode=mode) - assert store._handle.mode == mode - store.close() + with HDFStore(path, mode=mode) as store: + assert store._handle.mode == mode path = tmp_path / setup_path @@ -253,16 +252,14 @@ def test_complibs(tmp_path, setup_path): result = read_hdf(tmpfile, gname) tm.assert_frame_equal(result, df) - # Open file and check metadata - # for correct amount of compression - h5table = tables.open_file(tmpfile, mode="r") - for node in h5table.walk_nodes(where="/" + gname, classname="Leaf"): - assert node.filters.complevel == lvl - if lvl == 0: - assert node.filters.complib is None - else: - assert node.filters.complib == lib - h5table.close() + # Open file and check metadata for correct amount of compression + with tables.open_file(tmpfile, mode="r") as h5table: + for node in h5table.walk_nodes(where="/" + gname, classname="Leaf"): + assert node.filters.complevel == lvl + if lvl == 0: + assert node.filters.complib is None + else: + assert node.filters.complib == lib @pytest.mark.skipif( diff --git a/pandas/tests/io/pytables/test_read.py b/pandas/tests/io/pytables/test_read.py index 1163b9e11a367..6d92c15f1ea10 100644 --- a/pandas/tests/io/pytables/test_read.py +++ b/pandas/tests/io/pytables/test_read.py @@ -1,3 +1,4 @@ +from contextlib import closing from pathlib import Path import re @@ -207,11 +208,10 @@ def test_read_hdf_open_store(tmp_path, setup_path): path = tmp_path / setup_path df.to_hdf(path, "df", mode="w") direct = read_hdf(path, "df") - store = HDFStore(path, mode="r") - indirect = 
read_hdf(store, "df") - tm.assert_frame_equal(direct, indirect) - assert store.is_open - store.close() + with HDFStore(path, mode="r") as store: + indirect = read_hdf(store, "df") + tm.assert_frame_equal(direct, indirect) + assert store.is_open def test_read_hdf_iterator(tmp_path, setup_path): @@ -223,10 +223,10 @@ def test_read_hdf_iterator(tmp_path, setup_path): df.to_hdf(path, "df", mode="w", format="t") direct = read_hdf(path, "df") iterator = read_hdf(path, "df", iterator=True) - assert isinstance(iterator, TableIterator) - indirect = next(iterator.__iter__()) - tm.assert_frame_equal(direct, indirect) - iterator.store.close() + with closing(iterator.store): + assert isinstance(iterator, TableIterator) + indirect = next(iterator.__iter__()) + tm.assert_frame_equal(direct, indirect) def test_read_nokey(tmp_path, setup_path): diff --git a/pandas/tests/io/pytables/test_select.py b/pandas/tests/io/pytables/test_select.py index e28c70d83baa7..76d5cc5672e40 100644 --- a/pandas/tests/io/pytables/test_select.py +++ b/pandas/tests/io/pytables/test_select.py @@ -682,10 +682,9 @@ def test_frame_select_complex2(tmp_path): # scope with list like l0 = selection.index.tolist() # noqa:F841 - store = HDFStore(hh) - result = store.select("df", where="l1=l0") - tm.assert_frame_equal(result, expected) - store.close() + with HDFStore(hh) as store: + result = store.select("df", where="l1=l0") + tm.assert_frame_equal(result, expected) result = read_hdf(hh, "df", where="l1=l0") tm.assert_frame_equal(result, expected) @@ -705,21 +704,18 @@ def test_frame_select_complex2(tmp_path): tm.assert_frame_equal(result, expected) # scope with index - store = HDFStore(hh) - - result = store.select("df", where="l1=index") - tm.assert_frame_equal(result, expected) - - result = store.select("df", where="l1=selection.index") - tm.assert_frame_equal(result, expected) + with HDFStore(hh) as store: + result = store.select("df", where="l1=index") + tm.assert_frame_equal(result, expected) - result = store.select("df", where="l1=selection.index.tolist()") - tm.assert_frame_equal(result, expected) + result = store.select("df", where="l1=selection.index") + tm.assert_frame_equal(result, expected) - result = store.select("df", where="l1=list(selection.index)") - tm.assert_frame_equal(result, expected) + result = store.select("df", where="l1=selection.index.tolist()") + tm.assert_frame_equal(result, expected) - store.close() + result = store.select("df", where="l1=list(selection.index)") + tm.assert_frame_equal(result, expected) def test_invalid_filtering(setup_path): diff --git a/pandas/tests/io/pytables/test_store.py b/pandas/tests/io/pytables/test_store.py index ccea10f5b2612..08b1ee3f0ddbe 100644 --- a/pandas/tests/io/pytables/test_store.py +++ b/pandas/tests/io/pytables/test_store.py @@ -917,9 +917,8 @@ def do_copy(f, new_f=None, keys=None, propindexes=True, **kwargs): df = tm.makeDataFrame() with tm.ensure_clean() as path: - st = HDFStore(path) - st.append("df", df, data_columns=["A"]) - st.close() + with HDFStore(path) as st: + st.append("df", df, data_columns=["A"]) do_copy(f=path) do_copy(f=path, propindexes=False) diff --git a/pandas/tests/io/test_common.py b/pandas/tests/io/test_common.py index 3f95dd616a09c..4a6ec7cfd2ae3 100644 --- a/pandas/tests/io/test_common.py +++ b/pandas/tests/io/test_common.py @@ -118,11 +118,11 @@ def test_get_handle_with_path(self, path_type): assert os.path.expanduser(filename) == handles.handle.name def test_get_handle_with_buffer(self): - input_buffer = StringIO() - with 
icom.get_handle(input_buffer, "r") as handles:
-            assert handles.handle == input_buffer
-        assert not input_buffer.closed
-        input_buffer.close()
+        with StringIO() as input_buffer:
+            with icom.get_handle(input_buffer, "r") as handles:
+                assert handles.handle == input_buffer
+            assert not input_buffer.closed
+        assert input_buffer.closed
 
     # Test that BytesIOWrapper(get_handle) returns correct amount of bytes every time
     def test_bytesiowrapper_returns_correct_bytes(self):
diff --git a/pandas/tests/io/test_compression.py b/pandas/tests/io/test_compression.py
index 89ecde58735a7..782753177f245 100644
--- a/pandas/tests/io/test_compression.py
+++ b/pandas/tests/io/test_compression.py
@@ -282,18 +282,17 @@ def test_bzip_compression_level(obj, method):
 )
 def test_empty_archive_zip(suffix, archive):
     with tm.ensure_clean(filename=suffix) as path:
-        file = archive(path, "w")
-        file.close()
+        with archive(path, "w"):
+            pass
         with pytest.raises(ValueError, match="Zero files found"):
             pd.read_csv(path)
 
 
 def test_ambiguous_archive_zip():
     with tm.ensure_clean(filename=".zip") as path:
-        file = zipfile.ZipFile(path, "w")
-        file.writestr("a.csv", "foo,bar")
-        file.writestr("b.csv", "foo,bar")
-        file.close()
+        with zipfile.ZipFile(path, "w") as file:
+            file.writestr("a.csv", "foo,bar")
+            file.writestr("b.csv", "foo,bar")
         with pytest.raises(ValueError, match="Multiple files found in ZIP file"):
             pd.read_csv(path)
diff --git a/pandas/tests/io/test_fsspec.py b/pandas/tests/io/test_fsspec.py
index 4f033fd63f978..82f5bdda2a4c5 100644
--- a/pandas/tests/io/test_fsspec.py
+++ b/pandas/tests/io/test_fsspec.py
@@ -95,22 +95,18 @@ def test_to_csv_fsspec_object(cleared_fs, binary_mode, df1):
     path = "memory://test/test.csv"
     mode = "wb" if binary_mode else "w"
 
-    fsspec_object = fsspec.open(path, mode=mode).open()
-
-    df1.to_csv(fsspec_object, index=True)
-    assert not fsspec_object.closed
-    fsspec_object.close()
+    with fsspec.open(path, mode=mode).open() as fsspec_object:
+        df1.to_csv(fsspec_object, index=True)
+        assert not fsspec_object.closed
 
     mode = mode.replace("w", "r")
-    fsspec_object = fsspec.open(path, mode=mode).open()
-
-    df2 = read_csv(
-        fsspec_object,
-        parse_dates=["dt"],
-        index_col=0,
-    )
-    assert not fsspec_object.closed
-    fsspec_object.close()
+    with fsspec.open(path, mode=mode) as fsspec_object:
+        df2 = read_csv(
+            fsspec_object,
+            parse_dates=["dt"],
+            index_col=0,
+        )
+        assert not fsspec_object.closed
 
     tm.assert_frame_equal(df1, df2)
diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py
index 375f66d545ed4..0594afda252d4 100644
--- a/pandas/tests/io/test_sql.py
+++ b/pandas/tests/io/test_sql.py
@@ -18,6 +18,7 @@
 """
 from __future__ import annotations
 
+from contextlib import closing
 import csv
 from datetime import (
     date,
@@ -455,9 +456,8 @@ def sqlite_iris_conn(sqlite_iris_engine):
 
 @pytest.fixture
 def sqlite_buildin():
-    conn = sqlite3.connect(":memory:")
-    yield conn
-    conn.close()
+    with closing(sqlite3.connect(":memory:")) as conn:
+        yield conn
 
 
 @pytest.fixture
@@ -1532,13 +1532,14 @@ def test_sql_open_close(self, test_frame3):
 
         with tm.ensure_clean() as name:
 
-            conn = self.connect(name)
-            assert sql.to_sql(test_frame3, "test_frame3_legacy", conn, index=False) == 4
-            conn.close()
+            with closing(self.connect(name)) as conn:
+                assert (
+                    sql.to_sql(test_frame3, "test_frame3_legacy", conn, index=False)
+                    == 4
+                )
 
-            conn = self.connect(name)
-            result = sql.read_sql_query("SELECT * FROM test_frame3_legacy;", conn)
-            conn.close()
+            with closing(self.connect(name)) as conn:
+                result = 
sql.read_sql_query("SELECT * FROM test_frame3_legacy;", conn) tm.assert_frame_equal(test_frame3, result) @@ -2371,18 +2372,15 @@ class Test(BaseModel): BaseModel.metadata.create_all(self.conn) Session = sessionmaker(bind=self.conn) - session = Session() - - df = DataFrame({"id": [0, 1], "foo": ["hello", "world"]}) - assert ( - df.to_sql("test_frame", con=self.conn, index=False, if_exists="replace") - == 2 - ) - - session.commit() - foo = session.query(Test.id, Test.foo) - df = DataFrame(foo) - session.close() + with Session() as session: + df = DataFrame({"id": [0, 1], "foo": ["hello", "world"]}) + assert ( + df.to_sql("test_frame", con=self.conn, index=False, if_exists="replace") + == 2 + ) + session.commit() + foo = session.query(Test.id, Test.foo) + df = DataFrame(foo) assert list(df.columns) == ["id", "foo"] diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py index 745d0691e8d86..a4e4751d75347 100644 --- a/pandas/tests/io/test_stata.py +++ b/pandas/tests/io/test_stata.py @@ -600,9 +600,8 @@ def test_value_labels_old_format(self, datapath): # Test that value_labels() returns an empty dict if the file format # predates supporting value labels. dpath = datapath("io", "data", "stata", "S4_EDUC1.dta") - reader = StataReader(dpath) - assert reader.value_labels() == {} - reader.close() + with StataReader(dpath) as reader: + assert reader.value_labels() == {} def test_date_export_formats(self): columns = ["tc", "td", "tw", "tm", "tq", "th", "ty"] @@ -1108,29 +1107,26 @@ def test_read_chunks_117( convert_categoricals=convert_categoricals, convert_dates=convert_dates, ) - itr = read_stata( + with read_stata( fname, iterator=True, convert_categoricals=convert_categoricals, convert_dates=convert_dates, - ) - - pos = 0 - for j in range(5): - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - try: - chunk = itr.read(chunksize) - except StopIteration: - break - from_frame = parsed.iloc[pos : pos + chunksize, :].copy() - from_frame = self._convert_categorical(from_frame) - tm.assert_frame_equal( - from_frame, chunk, check_dtype=False, check_datetimelike_compat=True - ) - - pos += chunksize - itr.close() + ) as itr: + pos = 0 + for j in range(5): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + try: + chunk = itr.read(chunksize) + except StopIteration: + break + from_frame = parsed.iloc[pos : pos + chunksize, :].copy() + from_frame = self._convert_categorical(from_frame) + tm.assert_frame_equal( + from_frame, chunk, check_dtype=False, check_datetimelike_compat=True + ) + pos += chunksize @staticmethod def _convert_categorical(from_frame: DataFrame) -> DataFrame: @@ -1206,28 +1202,26 @@ def test_read_chunks_115( ) # Compare to what we get when reading by chunk - itr = read_stata( + with read_stata( fname, iterator=True, convert_dates=convert_dates, convert_categoricals=convert_categoricals, - ) - pos = 0 - for j in range(5): - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - try: - chunk = itr.read(chunksize) - except StopIteration: - break - from_frame = parsed.iloc[pos : pos + chunksize, :].copy() - from_frame = self._convert_categorical(from_frame) - tm.assert_frame_equal( - from_frame, chunk, check_dtype=False, check_datetimelike_compat=True - ) - - pos += chunksize - itr.close() + ) as itr: + pos = 0 + for j in range(5): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + try: + chunk = itr.read(chunksize) + except StopIteration: + break + from_frame = 
parsed.iloc[pos : pos + chunksize, :].copy() + from_frame = self._convert_categorical(from_frame) + tm.assert_frame_equal( + from_frame, chunk, check_dtype=False, check_datetimelike_compat=True + ) + pos += chunksize def test_read_chunks_columns(self, datapath): fname = datapath("io", "data", "stata", "stata3_117.dta") @@ -1820,9 +1814,9 @@ def test_utf8_writer(self, version): data["β"].replace(value_labels["β"]).astype("category").cat.as_ordered() ) tm.assert_frame_equal(data, reread_encoded) - reader = StataReader(path) - assert reader.data_label == data_label - assert reader.variable_labels() == variable_labels + with StataReader(path) as reader: + assert reader.data_label == data_label + assert reader.variable_labels() == variable_labels data.to_stata(path, version=version, write_index=False) reread_to_stata = read_stata(path) @@ -1922,11 +1916,11 @@ def test_chunked_categorical(version): df.index.name = "index" with tm.ensure_clean() as path: df.to_stata(path, version=version) - reader = StataReader(path, chunksize=2, order_categoricals=False) - for i, block in enumerate(reader): - block = block.set_index("index") - assert "cats" in block - tm.assert_series_equal(block.cats, df.cats.iloc[2 * i : 2 * (i + 1)]) + with StataReader(path, chunksize=2, order_categoricals=False) as reader: + for i, block in enumerate(reader): + block = block.set_index("index") + assert "cats" in block + tm.assert_series_equal(block.cats, df.cats.iloc[2 * i : 2 * (i + 1)]) def test_chunked_categorical_partial(datapath): @@ -1952,7 +1946,8 @@ def test_chunked_categorical_partial(datapath): def test_iterator_errors(datapath, chunksize): dta_file = datapath("io", "data", "stata", "stata-dta-partially-labeled.dta") with pytest.raises(ValueError, match="chunksize must be a positive"): - StataReader(dta_file, chunksize=chunksize) + with StataReader(dta_file, chunksize=chunksize): + pass def test_iterator_value_labels(): @@ -2051,9 +2046,9 @@ def test_non_categorical_value_labels(): writer = StataWriter(path, data, value_labels=value_labels) writer.write_file() - reader = StataReader(path) - reader_value_labels = reader.value_labels() - assert reader_value_labels == expected + with StataReader(path) as reader: + reader_value_labels = reader.value_labels() + assert reader_value_labels == expected msg = "Can't create value labels for notY, it wasn't found in the dataset." with pytest.raises(KeyError, match=msg): @@ -2101,9 +2096,9 @@ def test_non_categorical_value_label_name_conversion(): with tm.assert_produces_warning(InvalidColumnName): data.to_stata(path, value_labels=value_labels) - reader = StataReader(path) - reader_value_labels = reader.value_labels() - assert reader_value_labels == expected + with StataReader(path) as reader: + reader_value_labels = reader.value_labels() + assert reader_value_labels == expected def test_non_categorical_value_label_convert_categoricals_error(): @@ -2122,8 +2117,8 @@ def test_non_categorical_value_label_convert_categoricals_error(): with tm.ensure_clean() as path: data.to_stata(path, value_labels=value_labels) - reader = StataReader(path, convert_categoricals=False) - reader_value_labels = reader.value_labels() + with StataReader(path, convert_categoricals=False) as reader: + reader_value_labels = reader.value_labels() assert reader_value_labels == value_labels col = "repeated_labels"