Commit 5011a37

ENH: context-manager for chunksize/iterator-reader (#38225)
1 parent b5233c4 commit 5011a37

19 files changed, +265 -189 lines changed

doc/source/user_guide/io.rst

+21 -19

@@ -1577,19 +1577,21 @@ value will be an iterable object of type ``TextFileReader``:
 
 .. ipython:: python
 
-    reader = pd.read_csv("tmp.sv", sep="|", chunksize=4)
-    reader
+    with pd.read_csv("tmp.sv", sep="|", chunksize=4) as reader:
+        reader
+        for chunk in reader:
+            print(chunk)
 
-    for chunk in reader:
-        print(chunk)
+.. versionchanged:: 1.2
 
+   ``read_csv/json/sas`` return a context-manager when iterating through a file.
 
 Specifying ``iterator=True`` will also return the ``TextFileReader`` object:
 
 .. ipython:: python
 
-    reader = pd.read_csv("tmp.sv", sep="|", iterator=True)
-    reader.get_chunk(5)
+    with pd.read_csv("tmp.sv", sep="|", iterator=True) as reader:
+        reader.get_chunk(5)
 
 .. ipython:: python
    :suppress:

@@ -2238,10 +2240,10 @@ For line-delimited json files, pandas can also return an iterator which reads in
     df.to_json(orient="records", lines=True)
 
     # reader is an iterator that returns ``chunksize`` lines each iteration
-    reader = pd.read_json(StringIO(jsonl), lines=True, chunksize=1)
-    reader
-    for chunk in reader:
-        print(chunk)
+    with pd.read_json(StringIO(jsonl), lines=True, chunksize=1) as reader:
+        reader
+        for chunk in reader:
+            print(chunk)
 
 .. _io.table_schema:
 

@@ -5471,19 +5473,19 @@ object can be used as an iterator.
 
 .. ipython:: python
 
-    reader = pd.read_stata("stata.dta", chunksize=3)
-    for df in reader:
-        print(df.shape)
+    with pd.read_stata("stata.dta", chunksize=3) as reader:
+        for df in reader:
+            print(df.shape)
 
 For more fine-grained control, use ``iterator=True`` and specify
 ``chunksize`` with each call to
 :func:`~pandas.io.stata.StataReader.read`.
 
 .. ipython:: python
 
-    reader = pd.read_stata("stata.dta", iterator=True)
-    chunk1 = reader.read(5)
-    chunk2 = reader.read(5)
+    with pd.read_stata("stata.dta", iterator=True) as reader:
+        chunk1 = reader.read(5)
+        chunk2 = reader.read(5)
 
 Currently the ``index`` is retrieved as a column.
 

@@ -5595,9 +5597,9 @@ Obtain an iterator and read an XPORT file 100,000 lines at a time:
         pass
 
-    rdr = pd.read_sas("sas_xport.xpt", chunk=100000)
-    for chunk in rdr:
-        do_something(chunk)
+    with pd.read_sas("sas_xport.xpt", chunk=100000) as rdr:
+        for chunk in rdr:
+            do_something(chunk)
 
 The specification_ for the xport file format is available from the SAS
 web site.
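
The documentation changes above describe the user-facing effect of this commit: when ``chunksize=`` or ``iterator=`` is passed, the returned reader can be used in a ``with`` block so the underlying file handle is closed automatically. A minimal sketch of the new usage, reusing the ``tmp.sv`` file and ``|`` separator from the example (pandas >= 1.2 assumed):

    import pandas as pd

    # Stream the file four rows at a time; the handle is closed when the
    # with-block exits, even if an exception is raised mid-iteration.
    total_rows = 0
    with pd.read_csv("tmp.sv", sep="|", chunksize=4) as reader:
        for chunk in reader:
            total_rows += len(chunk)
    print(total_rows)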

doc/source/whatsnew/v1.2.0.rst

+1

@@ -291,6 +291,7 @@ Other enhancements
 - Improve numerical stability for :meth:`.Rolling.skew`, :meth:`.Rolling.kurt`, :meth:`Expanding.skew` and :meth:`Expanding.kurt` through implementation of Kahan summation (:issue:`6929`)
 - Improved error reporting for subsetting columns of a :class:`.DataFrameGroupBy` with ``axis=1`` (:issue:`37725`)
 - Implement method ``cross`` for :meth:`DataFrame.merge` and :meth:`DataFrame.join` (:issue:`5401`)
+- When :func:`read_csv/sas/json` are called with ``chunksize``/``iterator`` they can be used in a ``with`` statement as they return context-managers (:issue:`38225`)
 
 .. ---------------------------------------------------------------------------

pandas/io/html.py

+2 -3

@@ -794,9 +794,8 @@ def _data_to_frame(**kwargs):
 
     # fill out elements of body that are "ragged"
    _expand_elements(body)
-    tp = TextParser(body, header=header, **kwargs)
-    df = tp.read()
-    return df
+    with TextParser(body, header=header, **kwargs) as tp:
+        return tp.read()
 
 
 _valid_parsers = {

pandas/io/json/_json.py

+12 -1

@@ -437,6 +437,10 @@ def read_json(
         This can only be passed if `lines=True`.
         If this is None, the file will be read into memory all at once.
 
+        .. versionchanged:: 1.2
+
+           ``JsonReader`` is a context manager.
+
     compression : {{'infer', 'gzip', 'bz2', 'zip', 'xz', None}}, default 'infer'
         For on-the-fly decompression of on-disk data. If 'infer', then use
         gzip, bz2, zip or xz if path_or_buf is a string ending in

@@ -555,7 +559,8 @@ def read_json(
     if chunksize:
         return json_reader
 
-    return json_reader.read()
+    with json_reader:
+        return json_reader.read()
 
 
 class JsonReader(abc.Iterator):

@@ -747,6 +752,12 @@ def __next__(self):
             self.close()
             raise StopIteration
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
 
 class Parser:
     _split_keys: Tuple[str, ...]
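
The ``__enter__``/``__exit__`` pair added to ``JsonReader`` is the standard Python context-manager protocol: ``__enter__`` hands the reader to the ``with`` statement, and ``__exit__`` closes it however the block exits. A minimal sketch of the same pattern on a generic reader class (illustrative only, not the pandas code):

    class ChunkReader:
        """Toy reader that owns a file handle."""

        def __init__(self, path):
            self.handle = open(path)

        def read_chunk(self, nrows):
            # Placeholder for the real chunked-read logic.
            return [self.handle.readline() for _ in range(nrows)]

        def close(self):
            self.handle.close()

        def __enter__(self):
            # Returning self lets `with ChunkReader(path) as reader:` bind the object.
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Runs on normal exit and on exceptions; returning None (falsy)
            # lets any exception propagate after the handle is closed.
            self.close()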

pandas/io/parsers.py

+16 -6

@@ -276,11 +276,19 @@
 iterator : bool, default False
     Return TextFileReader object for iteration or getting chunks with
     ``get_chunk()``.
+
+    .. versionchanged:: 1.2
+
+       ``TextFileReader`` is a context manager.
 chunksize : int, optional
     Return TextFileReader object for iteration.
     See the `IO Tools docs
     <https://pandas.pydata.org/pandas-docs/stable/io.html#io-chunking>`_
     for more information on ``iterator`` and ``chunksize``.
+
+    .. versionchanged:: 1.2
+
+       ``TextFileReader`` is a context manager.
 compression : {{'infer', 'gzip', 'bz2', 'zip', 'xz', None}}, default 'infer'
     For on-the-fly decompression of on-disk data. If 'infer' and
     `filepath_or_buffer` is path-like, then detect compression from the

@@ -451,12 +459,8 @@ def _read(filepath_or_buffer: FilePathOrBuffer, kwds):
     if chunksize or iterator:
         return parser
 
-    try:
-        data = parser.read(nrows)
-    finally:
-        parser.close()
-
-    return data
+    with parser:
+        return parser.read(nrows)
 
 
 _parser_defaults = {

@@ -1074,6 +1078,12 @@ def get_chunk(self, size=None):
             size = min(size, self.nrows - self._currow)
         return self.read(nrows=size)
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
 
 def _is_index_col(col):
     return col is not None and col is not False
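
The ``_read`` change above swaps a ``try``/``finally`` for ``with parser: return parser.read(nrows)``. The guarantee is the same: the return value is computed inside the block, and ``__exit__`` (which calls ``close()``) runs before the function actually returns, and also runs if ``read`` raises. A small illustrative sketch with a toy class (not the pandas implementation) showing that behaviour:

    class ToyParser:
        def __init__(self):
            self.closed = False

        def read(self, nrows=None):
            return "data"

        def close(self):
            self.closed = True

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            self.close()


    def _read(parser):
        # Equivalent to: try: return parser.read() / finally: parser.close()
        with parser:
            return parser.read()


    p = ToyParser()
    assert _read(p) == "data"
    assert p.closed  # close() ran even though we returned from inside the with-block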

pandas/io/sas/sasreader.py

+16 -4

@@ -6,7 +6,7 @@
 
 from pandas._typing import FilePathOrBuffer, Label
 
-from pandas.io.common import IOHandles, stringify_path
+from pandas.io.common import stringify_path
 
 if TYPE_CHECKING:
     from pandas import DataFrame

@@ -18,8 +18,6 @@ class ReaderBase(metaclass=ABCMeta):
     Protocol for XportReader and SAS7BDATReader classes.
     """
 
-    handles: IOHandles
-
     @abstractmethod
     def read(self, nrows=None):
         pass

@@ -28,6 +26,12 @@ def read(self, nrows=None):
     def close(self):
         pass
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
 
 @overload
 def read_sas(

@@ -87,9 +91,17 @@ def read_sas(
        Encoding for text data. If None, text data are stored as raw bytes.
    chunksize : int
        Read file `chunksize` lines at a time, returns iterator.
+
+        .. versionchanged:: 1.2
+
+           ``TextFileReader`` is a context manager.
    iterator : bool, defaults to False
        If True, returns an iterator for reading the file incrementally.
 
+        .. versionchanged:: 1.2
+
+           ``TextFileReader`` is a context manager.
+
    Returns
    -------
    DataFrame if iterator=False and chunksize=None, else SAS7BDATReader

@@ -136,5 +148,5 @@ def read_sas(
     if iterator or chunksize:
         return reader
 
-    with reader.handles:
+    with reader:
         return reader.read()
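
Because ``__enter__``/``__exit__`` now live on the ``ReaderBase`` ABC, both ``XportReader`` and ``SAS7BDATReader`` inherit the context-manager behaviour, and ``read_sas`` can write ``with reader:`` instead of reaching into ``reader.handles``. A hedged usage sketch of the resulting API (the file name is a placeholder, pandas >= 1.2 assumed):

    import pandas as pd

    # Iterate over a SAS file in chunks; the reader and its file handle are
    # closed when the with-block ends.
    with pd.read_sas("data.sas7bdat", chunksize=10000) as reader:  # placeholder path
        for chunk in reader:
            print(chunk.shape)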

pandas/tests/io/json/test_compression.py

+4 -2

@@ -65,8 +65,10 @@ def test_chunksize_with_compression(compression):
     df = pd.read_json('{"a": ["foo", "bar", "baz"], "b": [4, 5, 6]}')
     df.to_json(path, orient="records", lines=True, compression=compression)
 
-    res = pd.read_json(path, lines=True, chunksize=1, compression=compression)
-    roundtripped_df = pd.concat(res)
+    with pd.read_json(
+        path, lines=True, chunksize=1, compression=compression
+    ) as res:
+        roundtripped_df = pd.concat(res)
     tm.assert_frame_equal(df, roundtripped_df)

pandas/tests/io/json/test_readlines.py

+23-14
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,17 @@ def test_readjson_chunks(lines_json_df, chunksize):
7777
# GH17048: memory usage when lines=True
7878

7979
unchunked = read_json(StringIO(lines_json_df), lines=True)
80-
reader = read_json(StringIO(lines_json_df), lines=True, chunksize=chunksize)
81-
chunked = pd.concat(reader)
80+
with read_json(StringIO(lines_json_df), lines=True, chunksize=chunksize) as reader:
81+
chunked = pd.concat(reader)
8282

8383
tm.assert_frame_equal(chunked, unchunked)
8484

8585

8686
def test_readjson_chunksize_requires_lines(lines_json_df):
8787
msg = "chunksize can only be passed if lines=True"
8888
with pytest.raises(ValueError, match=msg):
89-
pd.read_json(StringIO(lines_json_df), lines=False, chunksize=2)
89+
with pd.read_json(StringIO(lines_json_df), lines=False, chunksize=2) as _:
90+
pass
9091

9192

9293
def test_readjson_chunks_series():
@@ -97,15 +98,17 @@ def test_readjson_chunks_series():
9798
unchunked = pd.read_json(strio, lines=True, typ="Series")
9899

99100
strio = StringIO(s.to_json(lines=True, orient="records"))
100-
chunked = pd.concat(pd.read_json(strio, lines=True, typ="Series", chunksize=1))
101+
with pd.read_json(strio, lines=True, typ="Series", chunksize=1) as reader:
102+
chunked = pd.concat(reader)
101103

102104
tm.assert_series_equal(chunked, unchunked)
103105

104106

105107
def test_readjson_each_chunk(lines_json_df):
106108
# Other tests check that the final result of read_json(chunksize=True)
107109
# is correct. This checks the intermediate chunks.
108-
chunks = list(pd.read_json(StringIO(lines_json_df), lines=True, chunksize=2))
110+
with pd.read_json(StringIO(lines_json_df), lines=True, chunksize=2) as reader:
111+
chunks = list(reader)
109112
assert chunks[0].shape == (2, 2)
110113
assert chunks[1].shape == (1, 2)
111114

@@ -114,7 +117,8 @@ def test_readjson_chunks_from_file():
114117
with tm.ensure_clean("test.json") as path:
115118
df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
116119
df.to_json(path, lines=True, orient="records")
117-
chunked = pd.concat(pd.read_json(path, lines=True, chunksize=1))
120+
with pd.read_json(path, lines=True, chunksize=1) as reader:
121+
chunked = pd.concat(reader)
118122
unchunked = pd.read_json(path, lines=True)
119123
tm.assert_frame_equal(unchunked, chunked)
120124

@@ -141,7 +145,8 @@ def test_readjson_chunks_closes(chunksize):
141145
compression=None,
142146
nrows=None,
143147
)
144-
reader.read()
148+
with reader:
149+
reader.read()
145150
assert (
146151
reader.handles.handle.closed
147152
), f"didn't close stream with chunksize = {chunksize}"
@@ -152,7 +157,10 @@ def test_readjson_invalid_chunksize(lines_json_df, chunksize):
152157
msg = r"'chunksize' must be an integer >=1"
153158

154159
with pytest.raises(ValueError, match=msg):
155-
pd.read_json(StringIO(lines_json_df), lines=True, chunksize=chunksize)
160+
with pd.read_json(
161+
StringIO(lines_json_df), lines=True, chunksize=chunksize
162+
) as _:
163+
pass
156164

157165

158166
@pytest.mark.parametrize("chunksize", [None, 1, 2])
@@ -176,7 +184,8 @@ def test_readjson_chunks_multiple_empty_lines(chunksize):
176184
orig = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
177185
test = pd.read_json(j, lines=True, chunksize=chunksize)
178186
if chunksize is not None:
179-
test = pd.concat(test)
187+
with test:
188+
test = pd.concat(test)
180189
tm.assert_frame_equal(orig, test, obj=f"chunksize: {chunksize}")
181190

182191

@@ -212,8 +221,8 @@ def test_readjson_nrows_chunks(nrows, chunksize):
212221
{"a": 3, "b": 4}
213222
{"a": 5, "b": 6}
214223
{"a": 7, "b": 8}"""
215-
reader = read_json(jsonl, lines=True, nrows=nrows, chunksize=chunksize)
216-
chunked = pd.concat(reader)
224+
with read_json(jsonl, lines=True, nrows=nrows, chunksize=chunksize) as reader:
225+
chunked = pd.concat(reader)
217226
expected = DataFrame({"a": [1, 3, 5, 7], "b": [2, 4, 6, 8]}).iloc[:nrows]
218227
tm.assert_frame_equal(chunked, expected)
219228

@@ -240,6 +249,6 @@ def test_readjson_lines_chunks_fileurl(datapath):
240249
]
241250
os_path = datapath("io", "json", "data", "line_delimited.json")
242251
file_url = Path(os_path).as_uri()
243-
url_reader = pd.read_json(file_url, lines=True, chunksize=1)
244-
for index, chuck in enumerate(url_reader):
245-
tm.assert_frame_equal(chuck, df_list_expected[index])
252+
with pd.read_json(file_url, lines=True, chunksize=1) as url_reader:
253+
for index, chuck in enumerate(url_reader):
254+
tm.assert_frame_equal(chuck, df_list_expected[index])

pandas/tests/io/parser/test_c_parser_only.py

+6 -6

@@ -376,25 +376,25 @@ def test_parse_trim_buffers(c_parser_only):
     )
 
     # Iterate over the CSV file in chunks of `chunksize` lines
-    chunks_ = parser.read_csv(
+    with parser.read_csv(
         StringIO(csv_data), header=None, dtype=object, chunksize=chunksize
-    )
-    result = concat(chunks_, axis=0, ignore_index=True)
+    ) as chunks_:
+        result = concat(chunks_, axis=0, ignore_index=True)
 
     # Check for data corruption if there was no segfault
     tm.assert_frame_equal(result, expected)
 
     # This extra test was added to replicate the fault in gh-5291.
     # Force 'utf-8' encoding, so that `_string_convert` would take
     # a different execution branch.
-    chunks_ = parser.read_csv(
+    with parser.read_csv(
         StringIO(csv_data),
         header=None,
         dtype=object,
         chunksize=chunksize,
         encoding="utf_8",
-    )
-    result = concat(chunks_, axis=0, ignore_index=True)
+    ) as chunks_:
+        result = concat(chunks_, axis=0, ignore_index=True)
     tm.assert_frame_equal(result, expected)
