Skip to content

Commit 63a1e5c

Browse files
StephenKappel authored and jreback committed
ENH: astype() allows col label -> dtype mapping as arg
closes #7271 closes #13375
1 parent 210fea9 commit 63a1e5c

File tree

5 files changed

+130
-6
lines changed

5 files changed

+130
-6
lines changed

doc/source/whatsnew/v0.19.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,8 @@ API changes
331331
- Passing ``Period`` with multiple frequencies to normal ``Index`` now returns ``Index`` with ``object`` dtype (:issue:`13664`)
332332
- ``PeriodIndex.fillna`` with a ``Period`` that has a different freq now coerces to ``object`` dtype (:issue:`13664`)
333333
- More informative exceptions are passed through the csv parser. The exception type would now be the original exception type instead of ``CParserError``. (:issue:`13652`)
334+
- ``astype()`` will now accept a dict mapping column names to data types as the ``dtype`` argument. (:issue:`12086`)
335+
334336

335337
.. _whatsnew_0190.api.tolist:
336338

pandas/core/generic.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# pylint: disable=W0231,E1101
2+
import collections
23
import warnings
34
import operator
45
import weakref
@@ -161,7 +162,7 @@ def _init_mgr(self, mgr, axes=None, dtype=None, copy=False):
161162

162163
@property
163164
def _constructor(self):
164-
"""Used when a manipulation result has the same dimesions as the
165+
"""Used when a manipulation result has the same dimensions as the
165166
original.
166167
"""
167168
raise AbstractMethodError(self)
@@ -3000,18 +3001,48 @@ def astype(self, dtype, copy=True, raise_on_error=True, **kwargs):
30003001
30013002
Parameters
30023003
----------
3003-
dtype : numpy.dtype or Python type
3004+
dtype : data type, or dict of column name -> data type
3005+
Use a numpy.dtype or Python type to cast entire pandas object to
3006+
the same type. Alternatively, use {col: dtype, ...}, where col is a
3007+
column label and dtype is a numpy.dtype or Python type to cast one
3008+
or more of the DataFrame's columns to column-specific types.
30043009
raise_on_error : raise on invalid input
30053010
kwargs : keyword arguments to pass on to the constructor
30063011
30073012
Returns
30083013
-------
30093014
casted : type of caller
30103015
"""
3016+
if isinstance(dtype, collections.Mapping):
3017+
if self.ndim == 1: # i.e. Series
3018+
if len(dtype) > 1 or list(dtype.keys())[0] != self.name:
3019+
raise KeyError('Only the Series name can be used for '
3020+
'the key in Series dtype mappings.')
3021+
new_type = list(dtype.values())[0]
3022+
return self.astype(new_type, copy, raise_on_error, **kwargs)
3023+
elif self.ndim > 2:
3024+
raise NotImplementedError(
3025+
'astype() only accepts a dtype arg of type dict when '
3026+
'invoked on Series and DataFrames. A single dtype must be '
3027+
'specified when invoked on a Panel.'
3028+
)
3029+
for col_name in dtype.keys():
3030+
if col_name not in self:
3031+
raise KeyError('Only a column name can be used for the '
3032+
'key in a dtype mappings argument.')
3033+
from pandas import concat
3034+
results = []
3035+
for col_name, col in self.iteritems():
3036+
if col_name in dtype:
3037+
results.append(col.astype(dtype[col_name], copy=copy))
3038+
else:
3039+
results.append(results.append(col.copy() if copy else col))
3040+
return concat(results, axis=1, copy=False)
30113041

3012-
mgr = self._data.astype(dtype=dtype, copy=copy,
3013-
raise_on_error=raise_on_error, **kwargs)
3014-
return self._constructor(mgr).__finalize__(self)
3042+
# else, only a single dtype is given
3043+
new_data = self._data.astype(dtype=dtype, copy=copy,
3044+
raise_on_error=raise_on_error, **kwargs)
3045+
return self._constructor(new_data).__finalize__(self)
30153046

30163047
def copy(self, deep=True):
30173048
"""

pandas/tests/frame/test_dtypes.py

+64-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77
from pandas import (DataFrame, Series, date_range, Timedelta, Timestamp,
8-
compat, option_context)
8+
compat, concat, option_context)
99
from pandas.compat import u
1010
from pandas.types.dtypes import DatetimeTZDtype
1111
from pandas.tests.frame.common import TestData
@@ -396,6 +396,69 @@ def test_astype_str(self):
396396
expected = DataFrame(['1.12345678901'])
397397
assert_frame_equal(result, expected)
398398

399+
def test_astype_dict(self):
    """astype() accepts a {column label: dtype} mapping (GH7271)."""
    dates = Series(date_range('2010-01-04', periods=5))
    ints = Series(range(5))
    floats = Series([0.0, 0.2, 0.4, 0.6, 0.8])
    strs = Series(['1.0', '2', '3.14', '4', '5.4'])
    df = DataFrame({'a': dates, 'b': ints, 'c': floats, 'd': strs})
    original = df.copy(deep=True)

    # Change the type of a subset of columns. Each case pairs the mapping
    # handed to astype() with the columns of the expected frame.
    subset_cases = [
        ({'b': 'str', 'd': 'float32'},
         {'a': dates,
          'b': Series(['0', '1', '2', '3', '4']),
          'c': floats,
          'd': Series([1.0, 2.0, 3.14, 4.0, 5.4], dtype='float32')}),
        ({'b': np.float32, 'c': 'float32', 'd': np.float64},
         {'a': dates,
          'b': Series([0.0, 1.0, 2.0, 3.0, 4.0], dtype='float32'),
          'c': Series([0.0, 0.2, 0.4, 0.6, 0.8], dtype='float32'),
          'd': Series([1.0, 2.0, 3.14, 4.0, 5.4], dtype='float64')}),
    ]
    for mapping, expected_columns in subset_cases:
        result = df.astype(mapping)
        expected = DataFrame(expected_columns)
        assert_frame_equal(result, expected)
        # the frame astype() was called on must be left unchanged
        assert_frame_equal(df, original)

    # Naming every column with the same dtype matches a plain astype().
    every_column_str = dict.fromkeys(['a', 'b', 'c', 'd'], str)
    assert_frame_equal(df.astype(every_column_str), df.astype(str))
    assert_frame_equal(df, original)

    # Keys that are not column labels are rejected with a KeyError.
    for bad_mapping in ({'b': str, 2: str}, {'e': str}):
        self.assertRaises(KeyError, df.astype, bad_mapping)
    assert_frame_equal(df, original)

    # Requesting each column's current dtype gives back an equal frame.
    current_dtypes = {col: df[col].dtype for col in df.columns}
    equiv = df.astype(current_dtypes)
    assert_frame_equal(df, equiv)
    assert_frame_equal(df, original)
443+
444+
def test_astype_duplicate_col(self):
    """astype() on a frame in which the column label 'a' appears twice."""
    first_a = Series([1, 2, 3, 4, 5], name='a')
    only_b = Series([0.1, 0.2, 0.4, 0.6, 0.8], name='b')
    second_a = Series([0, 1, 2, 3, 4], name='a')
    df = concat([first_a, only_b, second_a], axis=1)

    first_a_str = Series(['1', '2', '3', '4', '5'], dtype='str', name='a')
    only_b_str = Series(['0.1', '0.2', '0.4', '0.6', '0.8'], dtype=str,
                        name='b')
    second_a_str = Series(['0', '1', '2', '3', '4'], dtype='str', name='a')

    # A single dtype converts all three columns.
    all_converted = concat([first_a_str, only_b_str, second_a_str], axis=1)
    assert_frame_equal(df.astype(str), all_converted)

    # A mapping keyed by the repeated label converts both 'a' columns and
    # leaves 'b' as it was.
    only_a_converted = concat([first_a_str, only_b, second_a_str], axis=1)
    assert_frame_equal(df.astype({'a': 'str'}), only_a_converted)
461+
399462
def test_timedeltas(self):
400463
df = DataFrame(dict(A=Series(date_range('2012-1-1', periods=3,
401464
freq='D')),

pandas/tests/series/test_dtypes.py

+16
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,22 @@ def test_astype_unicode(self):
133133
reload(sys) # noqa
134134
sys.setdefaultencoding(former_encoding)
135135

136+
def test_astype_dict(self):
    """Series.astype() with a mapping keyed by the Series name (GH7271)."""
    ser = Series(range(0, 10, 2), name='abc')

    as_str = Series(['0', '2', '4', '6', '8'], name='abc')
    assert_series_equal(ser.astype({'abc': str}), as_str)

    as_float = Series([0.0, 2.0, 4.0, 6.0, 8.0], dtype='float64',
                      name='abc')
    assert_series_equal(ser.astype({'abc': 'float64'}), as_float)

    # More than one key, or a key other than the Series name, is an error.
    for bad_mapping in ({'abc': str, 'def': str}, {0: str}):
        self.assertRaises(KeyError, ser.astype, bad_mapping)
151+
136152
def test_complexx(self):
137153
# GH4819
138154
# complex access for ndarray compat

pandas/tests/test_panel.py

+12
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,18 @@ def test_dtypes(self):
12311231
expected = Series(np.dtype('float64'), index=self.panel.items)
12321232
assert_series_equal(result, expected)
12331233

1234+
def test_astype(self):
    """Panel.astype() takes a single dtype and rejects a mapping (GH7271)."""
    axes = (['a', 'b'], ['c', 'd'], ['e', 'f'])
    int_values = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    str_values = np.array([[['1', '2'], ['3', '4']],
                           [['5', '6'], ['7', '8']]])

    panel = Panel(int_values, *axes)
    expected = Panel(str_values, *axes)
    assert_panel_equal(panel.astype(str), expected)

    # dict-style dtype arguments are not implemented for Panel
    self.assertRaises(NotImplementedError, panel.astype, {0: str})
1245+
12341246
def test_apply(self):
12351247
# GH1148
12361248

0 commit comments

Comments
 (0)