Skip to content

Commit 0a267cf

Browse files
committed
BUG: preserve merge keys dtypes when possible
closes pandas-dev#8596 xref to pandas-dev#13169 as assignment of Index of bools not retaining dtype
1 parent 4173dbf commit 0a267cf

File tree

14 files changed

+433
-127
lines changed

14 files changed

+433
-127
lines changed

doc/source/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1333,6 +1333,7 @@ Modifying and Computations
13331333
Index.max
13341334
Index.reindex
13351335
Index.repeat
1336+
Index.where
13361337
Index.take
13371338
Index.putmask
13381339
Index.set_names

doc/source/whatsnew/v0.18.2.txt

+57-1
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,20 @@ Other enhancements
7777
- The ``pd.read_csv()`` with ``engine='python'`` has gained support for the ``decimal`` option (:issue:`12933`)
7878

7979
- ``Index.astype()`` now accepts an optional boolean argument ``copy``, which allows optional copying if the requirements on dtype are satisfied (:issue:`13209`)
80+
- ``Index`` now supports the ``.where()`` function for same shape indexing (:issue:`13170`)
81+
82+
.. ipython:: python
83+
84+
idx = pd.Index(['a', 'b', 'c'])
85+
idx.where([True, False, True])
86+
8087
- ``Categorical.astype()`` now accepts an optional boolean argument ``copy``, effective when dtype is categorical (:issue:`13209`)
8188
- Consistent with the Python API, ``pd.read_csv()`` will now interpret ``+inf`` as positive infinity (:issue:`13274`)
8289

8390
- ``pd.read_html()`` has gained support for the ``decimal`` option (:issue:`12907`)
8491

92+
93+
8594
.. _whatsnew_0182.api:
8695

8796
API changes
@@ -119,7 +128,6 @@ New Behavior:
119128

120129
type(s.tolist()[0])
121130

122-
123131
.. _whatsnew_0182.api.promote:
124132

125133
``Series`` type promotion on assignment
@@ -171,6 +179,54 @@ This will now convert integers/floats with the default unit of ``ns``.
171179

172180
pd.to_datetime([1, 'foo'], errors='coerce')
173181

182+
.. _whatsnew_0182.api.merging:
183+
184+
Merging changes
185+
^^^^^^^^^^^^^^^
186+
187+
Merging will now preserve the dtype of the join keys (:issue:`8596`)
188+
189+
.. ipython:: python
190+
191+
df1 = pd.DataFrame({'key': [1], 'v1': [10]})
192+
df1
193+
df2 = pd.DataFrame({'key': [1, 2], 'v1': [20, 30]})
194+
df2
195+
196+
Previous Behavior:
197+
198+
.. code-block:: ipython
199+
200+
In [5]: pd.merge(df1, df2, how='outer')
201+
Out[5]:
202+
key v1
203+
0 1.0 10.0
204+
1 1.0 20.0
205+
2 2.0 30.0
206+
207+
In [6]: pd.merge(df1, df2, how='outer').dtypes
208+
Out[6]:
209+
key float64
210+
v1 float64
211+
dtype: object
212+
213+
New Behavior:
214+
215+
We are able to preserve the join keys
216+
217+
.. ipython:: python
218+
219+
pd.merge(df1, df2, how='outer')
220+
pd.merge(df1, df2, how='outer').dtypes
221+
222+
Of course if you have missing values that are introduced, then the
223+
resulting dtype will be upcast (unchanged from previous).
224+
225+
.. ipython:: python
226+
227+
pd.merge(df1, df2, how='outer', on='key')
228+
pd.merge(df1, df2, how='outer', on='key').dtypes
229+
174230
.. _whatsnew_0182.api.other:
175231

176232
Other API changes

pandas/indexes/base.py

+18
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,24 @@ def repeat(self, n, *args, **kwargs):
465465
nv.validate_repeat(args, kwargs)
466466
return self._shallow_copy(self._values.repeat(n))
467467

468+
def where(self, cond, other=None):
469+
"""
470+
.. versionadded:: 0.18.2
471+
472+
Return an Index of same shape as self and whose corresponding
473+
entries are from self where cond is True and otherwise are from
474+
other.
475+
476+
Parameters
477+
----------
478+
cond : boolean same length as self
479+
other : scalar, or array-like
480+
"""
481+
if other is None:
482+
other = self._na_value
483+
values = np.where(cond, self.values, other)
484+
return self._shallow_copy_with_infer(values, dtype=self.dtype)
485+
468486
def ravel(self, order='C'):
469487
"""
470488
return an ndarray of the flattened values of the underlying data

pandas/indexes/category.py

+23
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,29 @@ def _can_reindex(self, indexer):
307307
""" always allow reindexing """
308308
pass
309309

310+
def where(self, cond, other=None):
311+
"""
312+
.. versionadded:: 0.18.2
313+
314+
Return an Index of same shape as self and whose corresponding
315+
entries are from self where cond is True and otherwise are from
316+
other.
317+
318+
Parameters
319+
----------
320+
cond : boolean same length as self
321+
other : scalar, or array-like
322+
"""
323+
if other is None:
324+
other = self._na_value
325+
values = np.where(cond, self.values, other)
326+
327+
from pandas.core.categorical import Categorical
328+
cat = Categorical(values,
329+
categories=self.categories,
330+
ordered=self.ordered)
331+
return self._shallow_copy(cat, **self._get_attributes_dict())
332+
310333
def reindex(self, target, method=None, level=None, limit=None,
311334
tolerance=None):
312335
"""

pandas/indexes/multi.py

+4
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,10 @@ def repeat(self, n, *args, **kwargs):
10841084
for label in self.labels], names=self.names,
10851085
sortorder=self.sortorder, verify_integrity=False)
10861086

1087+
def where(self, cond, other=None):
1088+
raise NotImplementedError(".where is not supported for "
1089+
"MultiIndex operations")
1090+
10871091
def drop(self, labels, level=None, errors='raise'):
10881092
"""
10891093
Make new MultiIndex with passed list of labels deleted

pandas/tests/indexes/common.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from pandas import (Series, Index, Float64Index, Int64Index, RangeIndex,
99
MultiIndex, CategoricalIndex, DatetimeIndex,
10-
TimedeltaIndex, PeriodIndex)
10+
TimedeltaIndex, PeriodIndex, notnull)
1111
from pandas.util.testing import assertRaisesRegexp
1212

1313
import pandas.util.testing as tm
@@ -363,6 +363,18 @@ def test_numpy_repeat(self):
363363
tm.assertRaisesRegexp(ValueError, msg, np.repeat,
364364
i, rep, axis=0)
365365

366+
def test_where(self):
367+
i = self.create_index()
368+
result = i.where(notnull(i))
369+
expected = i
370+
tm.assert_index_equal(result, expected)
371+
372+
i2 = i.copy()
373+
i2 = pd.Index([np.nan, np.nan] + i[2:].tolist())
374+
result = i.where(notnull(i2))
375+
expected = i2
376+
tm.assert_index_equal(result, expected)
377+
366378
def test_setops_errorcases(self):
367379
for name, idx in compat.iteritems(self.indices):
368380
# # non-iterable input

pandas/tests/indexes/test_category.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import numpy as np
1313

14-
from pandas import Categorical, compat
14+
from pandas import Categorical, compat, notnull
1515
from pandas.util.testing import assert_almost_equal
1616
import pandas.core.config as cf
1717
import pandas as pd
@@ -230,6 +230,19 @@ def f(x):
230230
ordered=False)
231231
tm.assert_categorical_equal(result, exp)
232232

233+
def test_where(self):
234+
i = self.create_index()
235+
result = i.where(notnull(i))
236+
expected = i
237+
tm.assert_index_equal(result, expected)
238+
239+
i2 = i.copy()
240+
i2 = pd.CategoricalIndex([np.nan, np.nan] + i[2:].tolist(),
241+
categories=i.categories)
242+
result = i.where(notnull(i2))
243+
expected = i2
244+
tm.assert_index_equal(result, expected)
245+
233246
def test_append(self):
234247

235248
ci = self.create_index()

pandas/tests/indexes/test_datetimelike.py

+66-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pandas import (DatetimeIndex, Float64Index, Index, Int64Index,
88
NaT, Period, PeriodIndex, Series, Timedelta,
99
TimedeltaIndex, date_range, period_range,
10-
timedelta_range)
10+
timedelta_range, notnull)
1111

1212
import pandas.util.testing as tm
1313

@@ -449,6 +449,38 @@ def test_astype_raises(self):
449449
self.assertRaises(ValueError, idx.astype, 'datetime64')
450450
self.assertRaises(ValueError, idx.astype, 'datetime64[D]')
451451

452+
def test_where_other(self):
453+
454+
# other is ndarray or Index
455+
i = pd.date_range('20130101', periods=3, tz='US/Eastern')
456+
457+
for arr in [np.nan, pd.NaT]:
458+
result = i.where(notnull(i), other=np.nan)
459+
expected = i
460+
tm.assert_index_equal(result, expected)
461+
462+
i2 = i.copy()
463+
i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
464+
result = i.where(notnull(i2), i2)
465+
tm.assert_index_equal(result, i2)
466+
467+
i2 = i.copy()
468+
i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
469+
result = i.where(notnull(i2), i2.values)
470+
tm.assert_index_equal(result, i2)
471+
472+
def test_where_tz(self):
473+
i = pd.date_range('20130101', periods=3, tz='US/Eastern')
474+
result = i.where(notnull(i))
475+
expected = i
476+
tm.assert_index_equal(result, expected)
477+
478+
i2 = i.copy()
479+
i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
480+
result = i.where(notnull(i2))
481+
expected = i2
482+
tm.assert_index_equal(result, expected)
483+
452484
def test_get_loc(self):
453485
idx = pd.date_range('2000-01-01', periods=3)
454486

@@ -776,6 +808,39 @@ def test_get_loc(self):
776808
with tm.assertRaises(KeyError):
777809
idx.get_loc('2000-01-10', method='nearest', tolerance='1 day')
778810

811+
def test_where(self):
812+
i = self.create_index()
813+
result = i.where(notnull(i))
814+
expected = i
815+
tm.assert_index_equal(result, expected)
816+
817+
i2 = i.copy()
818+
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
819+
freq='D')
820+
result = i.where(notnull(i2))
821+
expected = i2
822+
tm.assert_index_equal(result, expected)
823+
824+
def test_where_other(self):
825+
826+
i = self.create_index()
827+
for arr in [np.nan, pd.NaT]:
828+
result = i.where(notnull(i), other=np.nan)
829+
expected = i
830+
tm.assert_index_equal(result, expected)
831+
832+
i2 = i.copy()
833+
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
834+
freq='D')
835+
result = i.where(notnull(i2), i2)
836+
tm.assert_index_equal(result, i2)
837+
838+
i2 = i.copy()
839+
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
840+
freq='D')
841+
result = i.where(notnull(i2), i2.values)
842+
tm.assert_index_equal(result, i2)
843+
779844
def test_get_indexer(self):
780845
idx = pd.period_range('2000-01-01', periods=3).asfreq('H', how='start')
781846
tm.assert_numpy_array_equal(idx.get_indexer(idx), [0, 1, 2])

pandas/tests/indexes/test_multi.py

+8
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ def test_labels_dtypes(self):
7878
self.assertTrue((i.labels[0] >= 0).all())
7979
self.assertTrue((i.labels[1] >= 0).all())
8080

81+
def test_where(self):
82+
i = MultiIndex.from_tuples([('A', 1), ('A', 2)])
83+
84+
def f():
85+
i.where(True)
86+
87+
self.assertRaises(NotImplementedError, f)
88+
8189
def test_repeat(self):
8290
reps = 2
8391
numbers = [1, 2, 3]

pandas/tests/types/test_types.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# -*- coding: utf-8 -*-
2+
import nose
3+
import numpy as np
4+
5+
from pandas import NaT
6+
from pandas.types.api import (DatetimeTZDtype, CategoricalDtype,
7+
na_value_for_dtype, pandas_dtype)
8+
9+
10+
def test_pandas_dtype():
11+
12+
assert pandas_dtype('datetime64[ns, US/Eastern]') == DatetimeTZDtype(
13+
'datetime64[ns, US/Eastern]')
14+
assert pandas_dtype('category') == CategoricalDtype()
15+
for dtype in ['M8[ns]', 'm8[ns]', 'object', 'float64', 'int64']:
16+
assert pandas_dtype(dtype) == np.dtype(dtype)
17+
18+
19+
def test_na_value_for_dtype():
20+
for dtype in [np.dtype('M8[ns]'), np.dtype('m8[ns]'),
21+
DatetimeTZDtype('datetime64[ns, US/Eastern]')]:
22+
assert na_value_for_dtype(dtype) is NaT
23+
24+
for dtype in ['u1', 'u2', 'u4', 'u8',
25+
'i1', 'i2', 'i4', 'i8']:
26+
assert na_value_for_dtype(np.dtype(dtype)) == 0
27+
28+
for dtype in ['bool']:
29+
assert na_value_for_dtype(np.dtype(dtype)) is False
30+
31+
for dtype in ['f2', 'f4', 'f8']:
32+
assert np.isnan(na_value_for_dtype(np.dtype(dtype)))
33+
34+
for dtype in ['O']:
35+
assert np.isnan(na_value_for_dtype(np.dtype(dtype)))
36+
37+
38+
if __name__ == '__main__':
39+
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],
40+
exit=False)

0 commit comments

Comments
 (0)