Skip to content

Commit 83a30c9

Browse files
committed
Add dtype parameters instead of fix-string-like
The original parameter was causing a lot of acrobatics with regards to string dtypes between 2.x and 3.x. The new parameters simplify the internal logic and pass the responsibility and motivation of memory efficiency back to the users.
1 parent c3335e3 commit 83a30c9

File tree

3 files changed

+178
-94
lines changed

3 files changed

+178
-94
lines changed

doc/source/whatsnew/v0.24.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ Other Enhancements
411411
- :meth:`DataFrame.to_stata` and :class:`pandas.io.stata.StataWriter117` can write mixed sting columns to Stata strl format (:issue:`23633`)
412412
- :meth:`DataFrame.between_time` and :meth:`DataFrame.at_time` have gained the ``axis`` parameter (:issue:`8839`)
413413
- The ``scatter_matrix``, ``andrews_curves``, ``parallel_coordinates``, ``lag_plot``, ``autocorrelation_plot``, ``bootstrap_plot``, and ``radviz`` plots from the ``pandas.plotting`` module are now accessible from calling :meth:`DataFrame.plot` (:issue:`11978`)
414-
- :meth:`DataFrame.to_records` now accepts a ``stringlike_as_fixed_length`` parameter to efficiently store string-likes as fixed-length string-like dtypes (e.g. ``S1``) instead of object dtype (``O``) (:issue:`18146`)
414+
- :meth:`DataFrame.to_records` now accepts ``index_dtypes`` and ``column_dtypes`` parameters to allow different data types in stored column and index records (:issue:`18146`)
415415
- :class:`IntervalIndex` has gained the :attr:`~IntervalIndex.is_overlapping` attribute to indicate if the ``IntervalIndex`` contains any overlapping intervals (:issue:`23309`)
416416
- :func:`pandas.DataFrame.to_sql` has gained the ``method`` argument to control SQL insertion clause. See the :ref:`insertion method <io.sql.method>` section in the documentation. (:issue:`8953`)
417417

pandas/core/frame.py

+57-58
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
OrderedDict, PY36, raise_with_traceback,
3636
string_and_binary_types)
3737
from pandas.compat.numpy import function as nv
38-
from pandas.api.types import infer_dtype
3938
from pandas.core.dtypes.cast import (
4039
maybe_upcast,
4140
cast_scalar_to_array,
@@ -1541,7 +1540,7 @@ def from_records(cls, data, index=None, exclude=None, columns=None,
15411540
return cls(mgr)
15421541

15431542
def to_records(self, index=True, convert_datetime64=None,
1544-
stringlike_as_fixed_length=False):
1543+
column_dtypes=None, index_dtypes=None):
15451544
"""
15461545
Convert DataFrame to a NumPy record array.
15471546
@@ -1558,11 +1557,20 @@ def to_records(self, index=True, convert_datetime64=None,
15581557
15591558
Whether to convert the index to datetime.datetime if it is a
15601559
DatetimeIndex.
1561-
stringlike_as_fixed_length : bool, default False
1562-
.. versionadded:: 0.24.0
1560+
column_dtypes : str, type, dict, default None
1561+
.. versionadded:: 0.24.0
1562+
1563+
If a string or type, the data type to store all columns. If
1564+
a dictionary, a mapping of column names and indices (zero-indexed)
1565+
to specific data types.
1566+
index_dtypes : str, type, dict, default None
1567+
.. versionadded:: 0.24.0
15631568
1564-
Store string-likes as fixed-length string-like dtypes
1565-
(e.g. ``S1`` dtype) instead of Python objects (``O`` dtype).
1569+
If a string or type, the data type to store all index levels. If
1570+
a dictionary, a mapping of index level names and indices
1571+
(zero-indexed) to specific data types.
1572+
1573+
This mapping is applied only if `index=True`.
15661574
15671575
Returns
15681576
-------
@@ -1605,26 +1613,22 @@ def to_records(self, index=True, convert_datetime64=None,
16051613
rec.array([(1, 0.5 ), (2, 0.75)],
16061614
dtype=[('A', '<i8'), ('B', '<f8')])
16071615
1608-
By default, strings are recorded as dtype 'O' for object:
1616+
Data types can be specified for the columns:
16091617
1610-
>>> df = pd.DataFrame({'A': [1, 2], 'B': ['abc', 'defg']},
1611-
... index=['a', 'b'])
1612-
>>> df.to_records()
1613-
rec.array([('a', 1, 'abc'), ('b', 2, 'defg')],
1614-
dtype=[('index', 'O'), ('A', '<i8'), ('B', 'O')])
1618+
>>> df.to_records(column_dtypes={"A": "int32"})
1619+
rec.array([('a', 1, 0.5 ), ('b', 2, 0.75)],
1620+
dtype=[('I', 'O'), ('A', '<i4'), ('B', '<f8')])
16151621
1616-
This can be inefficient (e.g. for short strings, or when storing with
1617-
`np.save()`). They can be recorded as fix-length string-like dtypes
1618-
such as 'S1' for zero-terminated bytes instead:
1622+
As well as for the index:
16191623
1620-
>>> df = pd.DataFrame({'A': [1, 2], 'B': ['abc', 'defg']},
1621-
... index=['a', 'b'])
1622-
>>> df.to_records(stringlike_as_fixed_length=True)
1623-
rec.array([('a', 1, 'abc'), ('b', 2, 'defg')],
1624-
dtype=[('index', '<U1'), ('A', '<i8'), ('B', '<U4')])
1624+
>>> df.to_records(index_dtypes="<S2")
1625+
rec.array([(b'a', 1, 0.5 ), (b'b', 2, 0.75)],
1626+
dtype=[('I', 'S2'), ('A', '<i8'), ('B', '<f8')])
16251627
1626-
Notice how the 'B' column is now stored as '<U4' for length-four
1627-
strings ('S4' for Python 2.x) instead of the 'O' object dtype.
1628+
>>> index_dtypes = "<S{}".format(df.index.str.len().max())
1629+
>>> df.to_records(index_dtypes=index_dtypes)
1630+
rec.array([(b'a', 1, 0.5 ), (b'b', 2, 0.75)],
1631+
dtype=[('I', 'S1'), ('A', '<i8'), ('B', '<f8')])
16281632
"""
16291633

16301634
if convert_datetime64 is not None:
@@ -1647,59 +1651,54 @@ def to_records(self, index=True, convert_datetime64=None,
16471651

16481652
count = 0
16491653
index_names = list(self.index.names)
1654+
16501655
if isinstance(self.index, MultiIndex):
16511656
for i, n in enumerate(index_names):
16521657
if n is None:
16531658
index_names[i] = 'level_%d' % count
16541659
count += 1
16551660
elif index_names[0] is None:
16561661
index_names = ['index']
1662+
16571663
names = (lmap(compat.text_type, index_names) +
16581664
lmap(compat.text_type, self.columns))
16591665
else:
16601666
arrays = [self[c].get_values() for c in self.columns]
16611667
names = lmap(compat.text_type, self.columns)
1668+
index_names = []
16621669

1670+
index_len = len(index_names)
16631671
formats = []
16641672

1665-
for v in arrays:
1666-
if not stringlike_as_fixed_length:
1667-
formats.append(v.dtype)
1673+
for i, v in enumerate(arrays):
1674+
index = i
1675+
1676+
if index < index_len:
1677+
dtype_mapping = index_dtypes
1678+
name = index_names[index]
16681679
else:
1669-
# gh-18146
1670-
#
1671-
# For string-like arrays, set dtype as zero-terminated bytes
1672-
# with max length equal to that of the longest string-like.
1673-
dtype = infer_dtype(v)
1674-
symbol = None
1675-
1676-
if dtype == "string":
1677-
# In Python 3.x, infer_dtype does not
1678-
# differentiate string from unicode
1679-
# like NumPy arrays do, so we
1680-
# specify unicode to be safe.
1681-
symbol = "S" if compat.PY2 else "U"
1682-
elif dtype == "unicode":
1683-
# In Python 3.x, infer_dtype does not
1684-
# differentiate string from unicode.
1685-
#
1686-
# Thus, we can only get this result
1687-
# in Python 2.x.
1688-
symbol = "U"
1689-
elif dtype == "bytes":
1690-
# In Python 2.x, infer_dtype does not
1691-
# differentiate string from bytes.
1692-
#
1693-
# Thus, we can only get this result
1694-
# in Python 3.x. However, NumPy does
1695-
# not have a fixed-length bytes dtype
1696-
# and just uses string instead.
1697-
symbol = "S"
1698-
1699-
if symbol is not None:
1700-
formats.append("{}{}".format(symbol, max(map(len, v))))
1680+
index -= index_len
1681+
dtype_mapping = column_dtypes
1682+
name = self.columns[index]
1683+
1684+
if isinstance(dtype_mapping, dict):
1685+
if name in dtype_mapping:
1686+
dtype_mapping = dtype_mapping[name]
1687+
elif index in dtype_mapping:
1688+
dtype_mapping = dtype_mapping[index]
17011689
else:
1702-
formats.append(v.dtype)
1690+
dtype_mapping = None
1691+
1692+
if dtype_mapping is None:
1693+
formats.append(v.dtype)
1694+
elif isinstance(dtype_mapping, (type, compat.string_types)):
1695+
formats.append(dtype_mapping)
1696+
else:
1697+
element = "row" if i < index_len else "column"
1698+
msg = ("Invalid dtype {dtype} specified for "
1699+
"{element} {name}").format(dtype=dtype_mapping,
1700+
element=element, name=name)
1701+
raise ValueError(msg)
17031702

17041703
return np.rec.fromarrays(
17051704
arrays,

pandas/tests/frame/test_convert_to.py

+120-35
Original file line numberDiff line numberDiff line change
@@ -191,42 +191,127 @@ def test_to_records_with_categorical(self):
191191
dtype=[('index', '=i8'), ('0', 'O')])
192192
tm.assert_almost_equal(result, expected)
193193

194-
@pytest.mark.parametrize("fixed_length", [True, False])
195-
@pytest.mark.parametrize("values,dtype_getter", [
196-
# Integer --> just take the dtype.
197-
([1, 2], lambda fixed, isPY2: "<i8"),
198-
199-
# Mixed --> cast to object.
200-
([1, "1"], lambda fixed, isPY2: "O"),
201-
202-
# String --> cast to string is PY2 else unicode in PY3.
203-
(["1", "2"], lambda fixed, isPY2: (
204-
("S" if isPY2 else "U") + "1") if fixed else "O"),
205-
206-
# String + max-length of longest string.
207-
(["12", "2"], lambda fixed, isPY2: (
208-
("S" if isPY2 else "U") + "2") if fixed else "O"),
209-
210-
# Unicode --> cast to unicode for both PY2 and PY3.
211-
([u"\u2120b", u"456"], lambda fixed, isPY2: "U3" if fixed else "O"),
212-
213-
# Bytes --> cast to string for both PY2 and PY3.
214-
([b"2", b"5"], lambda fixed, isPY2: "S1" if fixed else "O"),
215-
], ids=["int", "mixed", "str", "max-len", "unicode", "bytes"])
216-
def test_to_records_with_strings_as_fixed_length(self, fixed_length,
217-
values, dtype_getter):
218-
194+
@pytest.mark.parametrize("kwargs,expected", [
195+
# No dtypes --> default to array dtypes.
196+
(dict(),
197+
np.rec.array([(0, 1, 0.2, "a"), (1, 2, 1.5, "bc")],
198+
dtype=[("index", "<i8"), ("A", "<i8"),
199+
("B", "<f8"), ("C", "O")])),
200+
201+
# Should have no effect in this case.
202+
(dict(index=True),
203+
np.rec.array([(0, 1, 0.2, "a"), (1, 2, 1.5, "bc")],
204+
dtype=[("index", "<i8"), ("A", "<i8"),
205+
("B", "<f8"), ("C", "O")])),
206+
207+
# Column dtype applied across the board. Index unaffected.
208+
(dict(column_dtypes="<U4"),
209+
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
210+
dtype=[("index", "<i8"), ("A", "<U4"),
211+
("B", "<U4"), ("C", "<U4")])),
212+
213+
# Index dtype applied across the board. Columns unaffected.
214+
(dict(index_dtypes="<U1"),
215+
np.rec.array([("0", 1, 0.2, "a"), ("1", 2, 1.5, "bc")],
216+
dtype=[("index", "<U1"), ("A", "<i8"),
217+
("B", "<f8"), ("C", "O")])),
218+
219+
# Pass in a type instance.
220+
(dict(column_dtypes=np.unicode),
221+
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
222+
dtype=[("index", "<i8"), ("A", "<U"),
223+
("B", "<U"), ("C", "<U")])),
224+
225+
# Pass in a dictionary (name-only).
226+
(dict(column_dtypes={"A": np.int8, "B": np.float32, "C": "<U2"}),
227+
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
228+
dtype=[("index", "<i8"), ("A", "i1"),
229+
("B", "<f4"), ("C", "<U2")])),
230+
231+
# Pass in a dictionary (indices-only).
232+
(dict(index_dtypes={0: "int16"}),
233+
np.rec.array([(0, 1, 0.2, "a"), (1, 2, 1.5, "bc")],
234+
dtype=[("index", "i2"), ("A", "<i8"),
235+
("B", "<f8"), ("C", "O")])),
236+
237+
# Ignore index mappings if index is not True.
238+
(dict(index=False, index_dtypes="<U2"),
239+
np.rec.array([(1, 0.2, "a"), (2, 1.5, "bc")],
240+
dtype=[("A", "<i8"), ("B", "<f8"), ("C", "O")])),
241+
242+
# Non-existent names / indices in mapping should not error.
243+
(dict(index_dtypes={0: "int16", "not-there": "float32"}),
244+
np.rec.array([(0, 1, 0.2, "a"), (1, 2, 1.5, "bc")],
245+
dtype=[("index", "i2"), ("A", "<i8"),
246+
("B", "<f8"), ("C", "O")])),
247+
248+
# Names / indices not in mapping default to array dtype.
249+
(dict(column_dtypes={"A": np.int8, "B": np.float32}),
250+
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
251+
dtype=[("index", "<i8"), ("A", "i1"),
252+
("B", "<f4"), ("C", "O")])),
253+
254+
# Mixture of everything.
255+
(dict(column_dtypes={"A": np.int8, "B": np.float32},
256+
index_dtypes="<U2"),
257+
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
258+
dtype=[("index", "<U2"), ("A", "i1"),
259+
("B", "<f4"), ("C", "O")])),
260+
261+
# Invalid dype values.
262+
(dict(index=False, column_dtypes=list()),
263+
"Invalid dtype \\[\\] specified for column A"),
264+
265+
(dict(index=False, column_dtypes={"A": "int32", "B": 5}),
266+
"Invalid dtype 5 specified for column B"),
267+
])
268+
def test_to_records_dtype(self, kwargs, expected):
219269
# see gh-18146
220-
df = DataFrame({"values": values}, index=["a", "b"])
221-
result = df.to_records(stringlike_as_fixed_length=fixed_length)
222-
223-
ind_dtype = ((("S" if compat.PY2 else "U") + "1")
224-
if fixed_length else "O")
225-
val_dtype = dtype_getter(fixed_length, compat.PY2)
226-
227-
expected = np.rec.array([("a", values[0]), ("b", values[1])],
228-
dtype=[("index", ind_dtype),
229-
("values", val_dtype)])
270+
df = DataFrame({"A": [1, 2], "B": [0.2, 1.5], "C": ["a", "bc"]})
271+
272+
if isinstance(expected, str):
273+
with pytest.raises(ValueError, match=expected):
274+
df.to_records(**kwargs)
275+
else:
276+
result = df.to_records(**kwargs)
277+
tm.assert_almost_equal(result, expected)
278+
279+
@pytest.mark.parametrize("df,kwargs,expected", [
280+
# MultiIndex in the index.
281+
(DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
282+
columns=list("abc")).set_index(["a", "b"]),
283+
dict(column_dtypes="float64", index_dtypes={0: "int32", 1: "int8"}),
284+
np.rec.array([(1, 2, 3.), (4, 5, 6.), (7, 8, 9.)],
285+
dtype=[("a", "<i4"), ("b", "i1"), ("c", "<f8")])),
286+
287+
# MultiIndex in the columns.
288+
(DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
289+
columns=MultiIndex.from_tuples([("a", "d"), ("b", "e"),
290+
("c", "f")])),
291+
dict(column_dtypes={0: "<U1", 2: "float32"}, index_dtypes="float32"),
292+
np.rec.array([(0., u"1", 2, 3.), (1., u"4", 5, 6.),
293+
(2., u"7", 8, 9.)],
294+
dtype=[("index", "<f4"),
295+
("('a', 'd')", "<U1"),
296+
("('b', 'e')", "<i8"),
297+
("('c', 'f')", "<f4")])),
298+
299+
# MultiIndex in both the columns and index.
300+
(DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
301+
columns=MultiIndex.from_tuples([
302+
("a", "d"), ("b", "e"), ("c", "f")], names=list("ab")),
303+
index=MultiIndex.from_tuples([
304+
("d", -4), ("d", -5), ("f", -6)], names=list("cd"))),
305+
dict(column_dtypes="float64", index_dtypes={0: "<U2", 1: "int8"}),
306+
np.rec.array([("d", -4, 1., 2., 3.), ("d", -5, 4., 5., 6.),
307+
("f", -6, 7, 8, 9.)],
308+
dtype=[("c", "<U2"), ("d", "i1"),
309+
("('a', 'd')", "<f8"), ("('b', 'e')", "<f8"),
310+
("('c', 'f')", "<f8")]))
311+
])
312+
def test_to_records_dtype_mi(self, df, kwargs, expected):
313+
# see gh-18146
314+
result = df.to_records(**kwargs)
230315
tm.assert_almost_equal(result, expected)
231316

232317
@pytest.mark.parametrize('mapping', [

0 commit comments

Comments
 (0)