Skip to content

Commit ed23eb8

Browse files
authored
CI: avoid file leaks in sas_xport tests (#35693)
1 parent 542b20a commit ed23eb8

File tree

3 files changed

+42
-14
lines changed

3 files changed

+42
-14
lines changed

pandas/io/sas/sasreader.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from pandas._typing import FilePathOrBuffer, Label
88

9-
from pandas.io.common import stringify_path
9+
from pandas.io.common import get_filepath_or_buffer, stringify_path
1010

1111
if TYPE_CHECKING:
1212
from pandas import DataFrame # noqa: F401
@@ -109,6 +109,10 @@ def read_sas(
109109
else:
110110
raise ValueError("unable to infer format of SAS file")
111111

112+
filepath_or_buffer, _, _, should_close = get_filepath_or_buffer(
113+
filepath_or_buffer, encoding
114+
)
115+
112116
reader: ReaderBase
113117
if format.lower() == "xport":
114118
from pandas.io.sas.sas_xport import XportReader
@@ -129,5 +133,7 @@ def read_sas(
129133
return reader
130134

131135
data = reader.read()
132-
reader.close()
136+
137+
if should_close:
138+
reader.close()
133139
return data

pandas/tests/io/sas/test_xport.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import numpy as np
44
import pytest
55

6+
import pandas.util._test_decorators as td
7+
68
import pandas as pd
79
import pandas._testing as tm
810

@@ -26,10 +28,12 @@ def setup_method(self, datapath):
2628
self.dirpath = datapath("io", "sas", "data")
2729
self.file01 = os.path.join(self.dirpath, "DEMO_G.xpt")
2830
self.file02 = os.path.join(self.dirpath, "SSHSV1_A.xpt")
29-
self.file02b = open(os.path.join(self.dirpath, "SSHSV1_A.xpt"), "rb")
3031
self.file03 = os.path.join(self.dirpath, "DRXFCD_G.xpt")
3132
self.file04 = os.path.join(self.dirpath, "paxraw_d_short.xpt")
3233

34+
with td.file_leak_context():
35+
yield
36+
3337
def test1_basic(self):
3438
# Tests with DEMO_G.xpt (all numeric file)
3539

@@ -127,7 +131,12 @@ def test2_binary(self):
127131
data_csv = pd.read_csv(self.file02.replace(".xpt", ".csv"))
128132
numeric_as_float(data_csv)
129133

130-
data = read_sas(self.file02b, format="xport")
134+
with open(self.file02, "rb") as fd:
135+
with td.file_leak_context():
136+
# GH#35693 ensure that if we pass an open file, we
137+
# don't incorrectly close it in read_sas
138+
data = read_sas(fd, format="xport")
139+
131140
tm.assert_frame_equal(data, data_csv)
132141

133142
def test_multiple_types(self):

pandas/util/_test_decorators.py

+23-10
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def test_foo():
2323
2424
For more information, refer to the ``pytest`` documentation on ``skipif``.
2525
"""
26+
from contextlib import contextmanager
2627
from distutils.version import LooseVersion
27-
from functools import wraps
2828
import locale
2929
from typing import Callable, Optional
3030

@@ -237,23 +237,36 @@ def documented_fixture(fixture):
237237

238238
def check_file_leaks(func) -> Callable:
239239
"""
240-
Decorate a test function tot check that we are not leaking file descriptors.
240+
Decorate a test function to check that we are not leaking file descriptors.
241241
"""
242-
psutil = safe_import("psutil")
243-
if not psutil:
242+
with file_leak_context():
244243
return func
245244

246-
@wraps(func)
247-
def new_func(*args, **kwargs):
245+
246+
@contextmanager
247+
def file_leak_context():
248+
"""
249+
ContextManager analogue to check_file_leaks.
250+
"""
251+
psutil = safe_import("psutil")
252+
if not psutil:
253+
yield
254+
else:
248255
proc = psutil.Process()
249256
flist = proc.open_files()
257+
conns = proc.connections()
250258

251-
func(*args, **kwargs)
259+
yield
252260

253261
flist2 = proc.open_files()
254-
assert flist2 == flist
255-
256-
return new_func
262+
# on some builds open_files includes file position, which we _don't_
263+
# expect to remain unchanged, so we need to compare excluding that
264+
flist_ex = [(x.path, x.fd) for x in flist]
265+
flist2_ex = [(x.path, x.fd) for x in flist2]
266+
assert flist2_ex == flist_ex, (flist2, flist)
267+
268+
conns2 = proc.connections()
269+
assert conns2 == conns, (conns2, conns)
257270

258271

259272
def async_mark():

0 commit comments

Comments
 (0)