diff --git a/doc/source/whatsnew/v0.19.0.txt b/doc/source/whatsnew/v0.19.0.txt index 0b9695125c0a9..10009f2ff8e43 100644 --- a/doc/source/whatsnew/v0.19.0.txt +++ b/doc/source/whatsnew/v0.19.0.txt @@ -267,6 +267,7 @@ API changes - ``.filter()`` enforces mutual exclusion of the keyword arguments. (:issue:`12399`) - ``PeridIndex`` can now accept ``list`` and ``array`` which contains ``pd.NaT`` (:issue:`13430`) - ``__setitem__`` will no longer apply a callable rhs as a function instead of storing it. Call ``where`` directly to get the previous behavior. (:issue:`13299`) +- ``astype()`` will now accept a dict of column name to data types mapping as the ``dtype`` argument. (:issue:`12086`) .. _whatsnew_0190.api.tolist: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index d6e6f571be53a..0c19ccbc40a9f 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -1,4 +1,5 @@ # pylint: disable=W0231,E1101 +import collections import warnings import operator import weakref @@ -161,7 +162,7 @@ def _init_mgr(self, mgr, axes=None, dtype=None, copy=False): @property def _constructor(self): - """Used when a manipulation result has the same dimesions as the + """Used when a manipulation result has the same dimensions as the original. """ raise AbstractMethodError(self) @@ -3001,7 +3002,11 @@ def astype(self, dtype, copy=True, raise_on_error=True, **kwargs): Parameters ---------- - dtype : numpy.dtype or Python type + dtype : data type, or dict of column name -> data type + Use a numpy.dtype or Python type to cast entire pandas object to + the same type. Alternatively, use {col: dtype, ...}, where col is a + column label and dtype is a numpy.dtype or Python type to cast one + or more of the DataFrame's columns to column-specific types. 
raise_on_error : raise on invalid input kwargs : keyword arguments to pass on to the constructor @@ -3009,10 +3014,36 @@ def astype(self, dtype, copy=True, raise_on_error=True, **kwargs): ------- casted : type of caller """ + if isinstance(dtype, collections.Mapping): + if self.ndim == 1: # i.e. Series + if len(dtype) > 1 or self.name not in dtype: + raise KeyError('Only the Series name can be used for ' + 'the key in Series dtype mappings.') + new_type = dtype[self.name] + return self.astype(new_type, copy, raise_on_error, **kwargs) + elif self.ndim > 2: + raise NotImplementedError( + 'astype() only accepts a dtype arg of type dict when ' + 'invoked on Series and DataFrames. A single dtype must be ' + 'specified when invoked on a Panel.' + ) + for col_name in dtype.keys(): + if col_name not in self: + raise KeyError('Only a column name can be used for the ' + 'key in a dtype mappings argument.') + from pandas import concat + results = [] + for col_name, col in self.iteritems(): + if col_name in dtype: + results.append(col.astype(dtype[col_name], copy=copy)) + else: + results.append(col.copy() if copy else col) + return concat(results, axis=1, copy=False) - mgr = self._data.astype(dtype=dtype, copy=copy, - raise_on_error=raise_on_error, **kwargs) - return self._constructor(mgr).__finalize__(self) + # else, only a single dtype is given + new_data = self._data.astype(dtype=dtype, copy=copy, + raise_on_error=raise_on_error, **kwargs) + return self._constructor(new_data).__finalize__(self) def copy(self, deep=True): """ diff --git a/pandas/tests/frame/test_dtypes.py b/pandas/tests/frame/test_dtypes.py index c650436eefaf3..817770b9da610 100644 --- a/pandas/tests/frame/test_dtypes.py +++ b/pandas/tests/frame/test_dtypes.py @@ -5,7 +5,7 @@ import numpy as np from pandas import (DataFrame, Series, date_range, Timedelta, Timestamp, - compat, concat, option_context) + compat, concat, option_context) from pandas.compat import u from
pandas.types.dtypes import DatetimeTZDtype from pandas.tests.frame.common import TestData @@ -396,6 +396,69 @@ def test_astype_str(self): expected = DataFrame(['1.12345678901']) assert_frame_equal(result, expected) + def test_astype_dict(self): + # GH7271 + a = Series(date_range('2010-01-04', periods=5)) + b = Series(range(5)) + c = Series([0.0, 0.2, 0.4, 0.6, 0.8]) + d = Series(['1.0', '2', '3.14', '4', '5.4']) + df = DataFrame({'a': a, 'b': b, 'c': c, 'd': d}) + original = df.copy(deep=True) + + # change type of a subset of columns + result = df.astype({'b': 'str', 'd': 'float32'}) + expected = DataFrame({ + 'a': a, + 'b': Series(['0', '1', '2', '3', '4']), + 'c': c, + 'd': Series([1.0, 2.0, 3.14, 4.0, 5.4], dtype='float32')}) + assert_frame_equal(result, expected) + assert_frame_equal(df, original) + + result = df.astype({'b': np.float32, 'c': 'float32', 'd': np.float64}) + expected = DataFrame({ + 'a': a, + 'b': Series([0.0, 1.0, 2.0, 3.0, 4.0], dtype='float32'), + 'c': Series([0.0, 0.2, 0.4, 0.6, 0.8], dtype='float32'), + 'd': Series([1.0, 2.0, 3.14, 4.0, 5.4], dtype='float64')}) + assert_frame_equal(result, expected) + assert_frame_equal(df, original) + + # change all columns + assert_frame_equal(df.astype({'a': str, 'b': str, 'c': str, 'd': str}), + df.astype(str)) + assert_frame_equal(df, original) + + # error should be raised when using something other than column labels + # in the keys of the dtype dict + self.assertRaises(KeyError, df.astype, {'b': str, 2: str}) + self.assertRaises(KeyError, df.astype, {'e': str}) + assert_frame_equal(df, original) + + # if the dtypes provided are the same as the original dtypes, the + # resulting DataFrame should be the same as the original DataFrame + equiv = df.astype({col: df[col].dtype for col in df.columns}) + assert_frame_equal(df, equiv) + assert_frame_equal(df, original) + + def test_astype_duplicate_col(self): + a1 = Series([1, 2, 3, 4, 5], name='a') + b = Series([0.1, 0.2, 0.4, 0.6, 0.8], name='b') + a2 = 
Series([0, 1, 2, 3, 4], name='a') + df = concat([a1, b, a2], axis=1) + + result = df.astype(str) + a1_str = Series(['1', '2', '3', '4', '5'], dtype='str', name='a') + b_str = Series(['0.1', '0.2', '0.4', '0.6', '0.8'], dtype=str, + name='b') + a2_str = Series(['0', '1', '2', '3', '4'], dtype='str', name='a') + expected = concat([a1_str, b_str, a2_str], axis=1) + assert_frame_equal(result, expected) + + result = df.astype({'a': 'str'}) + expected = concat([a1_str, b, a2_str], axis=1) + assert_frame_equal(result, expected) + def test_timedeltas(self): df = DataFrame(dict(A=Series(date_range('2012-1-1', periods=3, freq='D')), diff --git a/pandas/tests/series/test_dtypes.py b/pandas/tests/series/test_dtypes.py index 6864eac603ded..9a406dfa10c35 100644 --- a/pandas/tests/series/test_dtypes.py +++ b/pandas/tests/series/test_dtypes.py @@ -133,6 +133,22 @@ def test_astype_unicode(self): reload(sys) # noqa sys.setdefaultencoding(former_encoding) + def test_astype_dict(self): + # GH7271 + s = Series(range(0, 10, 2), name='abc') + + result = s.astype({'abc': str}) + expected = Series(['0', '2', '4', '6', '8'], name='abc') + assert_series_equal(result, expected) + + result = s.astype({'abc': 'float64'}) + expected = Series([0.0, 2.0, 4.0, 6.0, 8.0], dtype='float64', + name='abc') + assert_series_equal(result, expected) + + self.assertRaises(KeyError, s.astype, {'abc': str, 'def': str}) + self.assertRaises(KeyError, s.astype, {0: str}) + def test_complexx(self): # GH4819 # complex access for ndarray compat diff --git a/pandas/tests/test_panel.py b/pandas/tests/test_panel.py index f2e13867d3bf0..d9c7c1dc0dc62 100644 --- a/pandas/tests/test_panel.py +++ b/pandas/tests/test_panel.py @@ -1231,6 +1231,18 @@ def test_dtypes(self): expected = Series(np.dtype('float64'), index=self.panel.items) assert_series_equal(result, expected) + def test_astype(self): + # GH7271 + data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + panel = Panel(data, ['a', 'b'], ['c', 'd'], ['e', 'f']) + + 
str_data = np.array([[['1', '2'], ['3', '4']], + [['5', '6'], ['7', '8']]]) + expected = Panel(str_data, ['a', 'b'], ['c', 'd'], ['e', 'f']) + assert_panel_equal(panel.astype(str), expected) + + self.assertRaises(NotImplementedError, panel.astype, {0: str}) + def test_apply(self): # GH1148