Skip to content

Commit 5fe82e3

Browse files
committed
ENH: astype() allows col label -> dtype mapping as arg; GH7271
1 parent 043879f commit 5fe82e3

File tree

5 files changed

+129
-6
lines changed

5 files changed

+129
-6
lines changed

doc/source/whatsnew/v0.19.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ API changes
267267
- ``.filter()`` enforces mutual exclusion of the keyword arguments. (:issue:`12399`)
268268
- ``PeridIndex`` can now accept ``list`` and ``array`` which contains ``pd.NaT`` (:issue:`13430`)
269269
- ``__setitem__`` will no longer apply a callable rhs as a function instead of storing it. Call ``where`` directly to get the previous behavior. (:issue:`13299`)
270+
- ``astype()`` will now accept a dict of column name to data types mapping as the ``dtype`` argument. (:issue:`12086`)
270271

271272

272273
.. _whatsnew_0190.api.tolist:

pandas/core/generic.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# pylint: disable=W0231,E1101
2+
import collections
23
import warnings
34
import operator
45
import weakref
@@ -161,7 +162,7 @@ def _init_mgr(self, mgr, axes=None, dtype=None, copy=False):
161162

162163
@property
163164
def _constructor(self):
164-
"""Used when a manipulation result has the same dimesions as the
165+
"""Used when a manipulation result has the same dimensions as the
165166
original.
166167
"""
167168
raise AbstractMethodError(self)
@@ -3001,18 +3002,48 @@ def astype(self, dtype, copy=True, raise_on_error=True, **kwargs):
30013002
30023003
Parameters
30033004
----------
3004-
dtype : numpy.dtype or Python type
3005+
dtype : data type, or dict of column name -> data type
3006+
Use a numpy.dtype or Python type to cast entire pandas object to
3007+
the same type. Alternatively, use {col: dtype, ...}, where col is a
3008+
column label and dtype is a numpy.dtype or Python type to cast one
3009+
or more of the DataFrame's columns to column-specific types.
30053010
raise_on_error : raise on invalid input
30063011
kwargs : keyword arguments to pass on to the constructor
30073012
30083013
Returns
30093014
-------
30103015
casted : type of caller
30113016
"""
3017+
if isinstance(dtype, collections.Mapping):
3018+
if self.ndim == 1: # i.e. Series
3019+
if len(dtype) > 1 or list(dtype.keys())[0] != self.name:
3020+
raise KeyError('Only the Series name can be used for '
3021+
'the key in Series dtype mappings.')
3022+
new_type = list(dtype.values())[0]
3023+
return self.astype(new_type, copy, raise_on_error, **kwargs)
3024+
elif self.ndim > 2:
3025+
raise NotImplementedError(
3026+
'astype() only accepts a dtype arg of type dict when '
3027+
'invoked on Series and DataFrames. A single dtype must be '
3028+
'specified when invoked on a Panel.'
3029+
)
3030+
for col_name in dtype.keys():
3031+
if col_name not in self:
3032+
raise KeyError('Only a column name can be used for the '
3033+
'key in a dtype mappings argument.')
3034+
from pandas import concat
3035+
results = []
3036+
for col_name, col in self.iteritems():
3037+
if col_name in dtype:
3038+
results.append(col.astype(dtype[col_name], copy=copy))
3039+
else:
3040+
results.append(results.append(col.copy() if copy else col))
3041+
return concat(results, axis=1, copy=False)
30123042

3013-
mgr = self._data.astype(dtype=dtype, copy=copy,
3014-
raise_on_error=raise_on_error, **kwargs)
3015-
return self._constructor(mgr).__finalize__(self)
3043+
# else, only a single dtype is given
3044+
new_data = self._data.astype(dtype=dtype, copy=copy,
3045+
raise_on_error=raise_on_error, **kwargs)
3046+
return self._constructor(new_data).__finalize__(self)
30163047

30173048
def copy(self, deep=True):
30183049
"""

pandas/tests/frame/test_dtypes.py

+64-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77
from pandas import (DataFrame, Series, date_range, Timedelta, Timestamp,
8-
compat, option_context)
8+
compat, concat, option_context)
99
from pandas.compat import u
1010
from pandas.types.dtypes import DatetimeTZDtype
1111
from pandas.tests.frame.common import TestData
@@ -396,6 +396,69 @@ def test_astype_str(self):
396396
expected = DataFrame(['1.12345678901'])
397397
assert_frame_equal(result, expected)
398398

399+
def test_astype_dict(self):
400+
# GH7271
401+
a = Series(date_range('2010-01-04', periods=5))
402+
b = Series(range(5))
403+
c = Series([0.0, 0.2, 0.4, 0.6, 0.8])
404+
d = Series(['1.0', '2', '3.14', '4', '5.4'])
405+
df = DataFrame({'a': a, 'b': b, 'c': c, 'd': d})
406+
original = df.copy(deep=True)
407+
408+
# change type of a subset of columns
409+
result = df.astype({'b': 'str', 'd': 'float32'})
410+
expected = DataFrame({
411+
'a': a,
412+
'b': Series(['0', '1', '2', '3', '4']),
413+
'c': c,
414+
'd': Series([1.0, 2.0, 3.14, 4.0, 5.4], dtype='float32')})
415+
assert_frame_equal(result, expected)
416+
assert_frame_equal(df, original)
417+
418+
result = df.astype({'b': np.float32, 'c': 'float32', 'd': np.float64})
419+
expected = DataFrame({
420+
'a': a,
421+
'b': Series([0.0, 1.0, 2.0, 3.0, 4.0], dtype='float32'),
422+
'c': Series([0.0, 0.2, 0.4, 0.6, 0.8], dtype='float32'),
423+
'd': Series([1.0, 2.0, 3.14, 4.0, 5.4], dtype='float64')})
424+
assert_frame_equal(result, expected)
425+
assert_frame_equal(df, original)
426+
427+
# change all columns
428+
assert_frame_equal(df.astype({'a': str, 'b': str, 'c': str, 'd': str}),
429+
df.astype(str))
430+
assert_frame_equal(df, original)
431+
432+
# error should be raised when using something other than column labels
433+
# in the keys of the dtype dict
434+
self.assertRaises(KeyError, df.astype, {'b': str, 2: str})
435+
self.assertRaises(KeyError, df.astype, {'e': str})
436+
assert_frame_equal(df, original)
437+
438+
# if the dtypes provided are the same as the original dtypes, the
439+
# resulting DataFrame should be the same as the original DataFrame
440+
equiv = df.astype({col: df[col].dtype for col in df.columns})
441+
assert_frame_equal(df, equiv)
442+
assert_frame_equal(df, original)
443+
444+
def test_astype_duplicate_col(self):
445+
a1 = Series([1, 2, 3, 4, 5], name='a')
446+
b = Series([0.1, 0.2, 0.4, 0.6, 0.8], name='b')
447+
a2 = Series([0, 1, 2, 3, 4], name='a')
448+
df = concat([a1, b, a2], axis=1)
449+
450+
result = df.astype(str)
451+
a1_str = Series(['1', '2', '3', '4', '5'], dtype='str', name='a')
452+
b_str = Series(['0.1', '0.2', '0.4', '0.6', '0.8'], dtype=str,
453+
name='b')
454+
a2_str = Series(['0', '1', '2', '3', '4'], dtype='str', name='a')
455+
expected = concat([a1_str, b_str, a2_str], axis=1)
456+
assert_frame_equal(result, expected)
457+
458+
result = df.astype({'a': 'str'})
459+
expected = concat([a1_str, b, a2_str], axis=1)
460+
assert_frame_equal(result, expected)
461+
399462
def test_timedeltas(self):
400463
df = DataFrame(dict(A=Series(date_range('2012-1-1', periods=3,
401464
freq='D')),

pandas/tests/series/test_dtypes.py

+16
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,22 @@ def test_astype_unicode(self):
133133
reload(sys) # noqa
134134
sys.setdefaultencoding(former_encoding)
135135

136+
def test_astype_dict(self):
137+
# GH7271
138+
s = Series(range(0, 10, 2), name='abc')
139+
140+
result = s.astype({'abc': str})
141+
expected = Series(['0', '2', '4', '6', '8'], name='abc')
142+
assert_series_equal(result, expected)
143+
144+
result = s.astype({'abc': 'float64'})
145+
expected = Series([0.0, 2.0, 4.0, 6.0, 8.0], dtype='float64',
146+
name='abc')
147+
assert_series_equal(result, expected)
148+
149+
self.assertRaises(KeyError, s.astype, {'abc': str, 'def': str})
150+
self.assertRaises(KeyError, s.astype, {0: str})
151+
136152
def test_complexx(self):
137153
# GH4819
138154
# complex access for ndarray compat

pandas/tests/test_panel.py

+12
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,18 @@ def test_dtypes(self):
12311231
expected = Series(np.dtype('float64'), index=self.panel.items)
12321232
assert_series_equal(result, expected)
12331233

1234+
def test_astype(self):
1235+
# GH7271
1236+
data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
1237+
panel = Panel(data, ['a', 'b'], ['c', 'd'], ['e', 'f'])
1238+
1239+
str_data = np.array([[['1', '2'], ['3', '4']],
1240+
[['5', '6'], ['7', '8']]])
1241+
expected = Panel(str_data, ['a', 'b'], ['c', 'd'], ['e', 'f'])
1242+
assert_panel_equal(panel.astype(str), expected)
1243+
1244+
self.assertRaises(NotImplementedError, panel.astype, {0: str})
1245+
12341246
def test_apply(self):
12351247
# GH1148
12361248

0 commit comments

Comments
 (0)