Skip to content

Commit 5a9a9da

Browse files
committed
Merge pull request pandas-dev#10613 from jreback/stata
ENH: add StataReader context manager to ensure closing of the path
2 parents 4bb45b1 + 59dd18b commit 5a9a9da

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

pandas/io/stata.py

+15
Original file line numberDiff line numberDiff line change
@@ -949,6 +949,21 @@ def __init__(self, path_or_buf, convert_dates=True,
949949

950950
self._read_header()
951951

952+
def __enter__(self):
953+
""" enter context manager """
954+
return self
955+
956+
def __exit__(self, exc_type, exc_value, traceback):
957+
""" exit context manager """
958+
self.close()
959+
960+
def close(self):
961+
""" close the handle if its open """
962+
try:
963+
self.path_or_buf.close()
964+
except IOError:
965+
pass
966+
952967
def _read_header(self):
953968
first_char = self.path_or_buf.read(1)
954969
if struct.unpack('c', first_char)[0] == b'<':

pandas/io/tests/test_stata.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -430,10 +430,11 @@ def test_timestamp_and_label(self):
430430
data_label = 'This is a data file.'
431431
with tm.ensure_clean() as path:
432432
original.to_stata(path, time_stamp=time_stamp, data_label=data_label)
433-
reader = StataReader(path)
434-
parsed_time_stamp = dt.datetime.strptime(reader.time_stamp, ('%d %b %Y %H:%M'))
435-
assert parsed_time_stamp == time_stamp
436-
assert reader.data_label == data_label
433+
434+
with StataReader(path) as reader:
435+
parsed_time_stamp = dt.datetime.strptime(reader.time_stamp, ('%d %b %Y %H:%M'))
436+
assert parsed_time_stamp == time_stamp
437+
assert reader.data_label == data_label
437438

438439
def test_numeric_column_names(self):
439440
original = DataFrame(np.reshape(np.arange(25.0), (5, 5)))
@@ -599,13 +600,14 @@ def test_minimal_size_col(self):
599600
original = DataFrame(s)
600601
with tm.ensure_clean() as path:
601602
original.to_stata(path, write_index=False)
602-
sr = StataReader(path)
603-
typlist = sr.typlist
604-
variables = sr.varlist
605-
formats = sr.fmtlist
606-
for variable, fmt, typ in zip(variables, formats, typlist):
607-
self.assertTrue(int(variable[1:]) == int(fmt[1:-1]))
608-
self.assertTrue(int(variable[1:]) == typ)
603+
604+
with StataReader(path) as sr:
605+
typlist = sr.typlist
606+
variables = sr.varlist
607+
formats = sr.fmtlist
608+
for variable, fmt, typ in zip(variables, formats, typlist):
609+
self.assertTrue(int(variable[1:]) == int(fmt[1:-1]))
610+
self.assertTrue(int(variable[1:]) == typ)
609611

610612
def test_excessively_long_string(self):
611613
str_lens = (1, 244, 500)

0 commit comments

Comments
 (0)