Skip to content

ENH: add StataReader context manager to ensure closing of the path #10613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,21 @@ def __init__(self, path_or_buf, convert_dates=True,

self._read_header()

def __enter__(self):
""" enter context manager """
return self

def __exit__(self, exc_type, exc_value, traceback):
""" exit context manager """
self.close()

def close(self):
""" close the handle if its open """
try:
self.path_or_buf.close()
except IOError:
pass

def _read_header(self):
first_char = self.path_or_buf.read(1)
if struct.unpack('c', first_char)[0] == b'<':
Expand Down
24 changes: 13 additions & 11 deletions pandas/io/tests/test_stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,11 @@ def test_timestamp_and_label(self):
data_label = 'This is a data file.'
with tm.ensure_clean() as path:
original.to_stata(path, time_stamp=time_stamp, data_label=data_label)
reader = StataReader(path)
parsed_time_stamp = dt.datetime.strptime(reader.time_stamp, ('%d %b %Y %H:%M'))
assert parsed_time_stamp == time_stamp
assert reader.data_label == data_label

with StataReader(path) as reader:
parsed_time_stamp = dt.datetime.strptime(reader.time_stamp, ('%d %b %Y %H:%M'))
assert parsed_time_stamp == time_stamp
assert reader.data_label == data_label

def test_numeric_column_names(self):
original = DataFrame(np.reshape(np.arange(25.0), (5, 5)))
Expand Down Expand Up @@ -599,13 +600,14 @@ def test_minimal_size_col(self):
original = DataFrame(s)
with tm.ensure_clean() as path:
original.to_stata(path, write_index=False)
sr = StataReader(path)
typlist = sr.typlist
variables = sr.varlist
formats = sr.fmtlist
for variable, fmt, typ in zip(variables, formats, typlist):
self.assertTrue(int(variable[1:]) == int(fmt[1:-1]))
self.assertTrue(int(variable[1:]) == typ)

with StataReader(path) as sr:
typlist = sr.typlist
variables = sr.varlist
formats = sr.fmtlist
for variable, fmt, typ in zip(variables, formats, typlist):
self.assertTrue(int(variable[1:]) == int(fmt[1:-1]))
self.assertTrue(int(variable[1:]) == typ)

def test_excessively_long_string(self):
str_lens = (1, 244, 500)
Expand Down