Skip to content

BUG: preserve join keys dtype #13170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,7 @@ Modifying and Computations
Index.max
Index.reindex
Index.repeat
Index.where
Index.take
Index.putmask
Index.set_names
Expand Down
58 changes: 57 additions & 1 deletion doc/source/whatsnew/v0.18.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,20 @@ Other enhancements
- The ``pd.read_csv()`` with ``engine='python'`` has gained support for the ``decimal`` option (:issue:`12933`)

- ``Index.astype()`` now accepts an optional boolean argument ``copy``, which allows optional copying if the requirements on dtype are satisfied (:issue:`13209`)
- ``Index`` now supports the ``.where()`` function for same shape indexing (:issue:`13170`)

.. ipython:: python

idx = pd.Index(['a', 'b', 'c'])
idx.where([True, False, True])

- ``Categorical.astype()`` now accepts an optional boolean argument ``copy``, effective when dtype is categorical (:issue:`13209`)
- Consistent with the Python API, ``pd.read_csv()`` will now interpret ``+inf`` as positive infinity (:issue:`13274`)

- ``pd.read_html()`` has gained support for the ``decimal`` option (:issue:`12907`)



.. _whatsnew_0182.api:

API changes
Expand Down Expand Up @@ -119,7 +128,6 @@ New Behavior:

type(s.tolist()[0])


.. _whatsnew_0182.api.promote:

``Series`` type promotion on assignment
Expand Down Expand Up @@ -171,6 +179,54 @@ This will now convert integers/floats with the default unit of ``ns``.

pd.to_datetime([1, 'foo'], errors='coerce')

.. _whatsnew_0182.api.merging:

Merging changes
^^^^^^^^^^^^^^^

Merging will now preserve the dtype of the join keys (:issue:`8596`)

.. ipython:: python

df1 = pd.DataFrame({'key': [1], 'v1': [10]})
df1
df2 = pd.DataFrame({'key': [1, 2], 'v1': [20, 30]})
df2

Previous Behavior:

.. code-block:: ipython

In [5]: pd.merge(df1, df2, how='outer')
Out[5]:
key v1
0 1.0 10.0
1 1.0 20.0
2 2.0 30.0

In [6]: pd.merge(df1, df2, how='outer').dtypes
Out[6]:
key float64
v1 float64
dtype: object

New Behavior:

We are able to preserve the join keys

.. ipython:: python

pd.merge(df1, df2, how='outer')
pd.merge(df1, df2, how='outer').dtypes

Of course if you have missing values that are introduced, then the
resulting dtype will be upcast (unchanged from previous).

.. ipython:: python

pd.merge(df1, df2, how='outer', on='key')
pd.merge(df1, df2, how='outer', on='key').dtypes

.. _whatsnew_0182.api.other:

Other API changes
Expand Down
18 changes: 18 additions & 0 deletions pandas/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,24 @@ def repeat(self, n, *args, **kwargs):
nv.validate_repeat(args, kwargs)
return self._shallow_copy(self._values.repeat(n))

def where(self, cond, other=None):
"""
.. versionadded:: 0.18.2

Return an Index of same shape as self and whose corresponding
entries are from self where cond is True and otherwise are from
other.

Parameters
----------
cond : boolean same length as self
other : scalar, or array-like
"""
if other is None:
other = self._na_value
values = np.where(cond, self.values, other)
return self._shallow_copy_with_infer(values, dtype=self.dtype)

def ravel(self, order='C'):
"""
return an ndarray of the flattened values of the underlying data
Expand Down
23 changes: 23 additions & 0 deletions pandas/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,29 @@ def _can_reindex(self, indexer):
""" always allow reindexing """
pass

def where(self, cond, other=None):
"""
.. versionadded:: 0.18.2

Return an Index of same shape as self and whose corresponding
entries are from self where cond is True and otherwise are from
other.

Parameters
----------
cond : boolean same length as self
other : scalar, or array-like
"""
if other is None:
other = self._na_value
values = np.where(cond, self.values, other)

from pandas.core.categorical import Categorical
cat = Categorical(values,
categories=self.categories,
ordered=self.ordered)
return self._shallow_copy(cat, **self._get_attributes_dict())

def reindex(self, target, method=None, level=None, limit=None,
tolerance=None):
"""
Expand Down
4 changes: 4 additions & 0 deletions pandas/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,10 @@ def repeat(self, n, *args, **kwargs):
for label in self.labels], names=self.names,
sortorder=self.sortorder, verify_integrity=False)

def where(self, cond, other=None):
raise NotImplementedError(".where is not supported for "
"MultiIndex operations")

def drop(self, labels, level=None, errors='raise'):
"""
Make new MultiIndex with passed list of labels deleted
Expand Down
14 changes: 13 additions & 1 deletion pandas/tests/indexes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pandas import (Series, Index, Float64Index, Int64Index, RangeIndex,
MultiIndex, CategoricalIndex, DatetimeIndex,
TimedeltaIndex, PeriodIndex)
TimedeltaIndex, PeriodIndex, notnull)
from pandas.util.testing import assertRaisesRegexp

import pandas.util.testing as tm
Expand Down Expand Up @@ -363,6 +363,18 @@ def test_numpy_repeat(self):
tm.assertRaisesRegexp(ValueError, msg, np.repeat,
i, rep, axis=0)

def test_where(self):
i = self.create_index()
result = i.where(notnull(i))
expected = i
tm.assert_index_equal(result, expected)

i2 = i.copy()
i2 = pd.Index([np.nan, np.nan] + i[2:].tolist())
result = i.where(notnull(i2))
expected = i2
tm.assert_index_equal(result, expected)

def test_setops_errorcases(self):
for name, idx in compat.iteritems(self.indices):
# # non-iterable input
Expand Down
15 changes: 14 additions & 1 deletion pandas/tests/indexes/test_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import numpy as np

from pandas import Categorical, compat
from pandas import Categorical, compat, notnull
from pandas.util.testing import assert_almost_equal
import pandas.core.config as cf
import pandas as pd
Expand Down Expand Up @@ -230,6 +230,19 @@ def f(x):
ordered=False)
tm.assert_categorical_equal(result, exp)

def test_where(self):
i = self.create_index()
result = i.where(notnull(i))
expected = i
tm.assert_index_equal(result, expected)

i2 = i.copy()
i2 = pd.CategoricalIndex([np.nan, np.nan] + i[2:].tolist(),
categories=i.categories)
result = i.where(notnull(i2))
expected = i2
tm.assert_index_equal(result, expected)

def test_append(self):

ci = self.create_index()
Expand Down
67 changes: 66 additions & 1 deletion pandas/tests/indexes/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pandas import (DatetimeIndex, Float64Index, Index, Int64Index,
NaT, Period, PeriodIndex, Series, Timedelta,
TimedeltaIndex, date_range, period_range,
timedelta_range)
timedelta_range, notnull)

import pandas.util.testing as tm

Expand Down Expand Up @@ -449,6 +449,38 @@ def test_astype_raises(self):
self.assertRaises(ValueError, idx.astype, 'datetime64')
self.assertRaises(ValueError, idx.astype, 'datetime64[D]')

def test_where_other(self):

# other is ndarray or Index
i = pd.date_range('20130101', periods=3, tz='US/Eastern')

for arr in [np.nan, pd.NaT]:
result = i.where(notnull(i), other=np.nan)
expected = i
tm.assert_index_equal(result, expected)

i2 = i.copy()
i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
result = i.where(notnull(i2), i2)
tm.assert_index_equal(result, i2)

i2 = i.copy()
i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
result = i.where(notnull(i2), i2.values)
tm.assert_index_equal(result, i2)

def test_where_tz(self):
i = pd.date_range('20130101', periods=3, tz='US/Eastern')
result = i.where(notnull(i))
expected = i
tm.assert_index_equal(result, expected)

i2 = i.copy()
i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
result = i.where(notnull(i2))
expected = i2
tm.assert_index_equal(result, expected)

def test_get_loc(self):
idx = pd.date_range('2000-01-01', periods=3)

Expand Down Expand Up @@ -776,6 +808,39 @@ def test_get_loc(self):
with tm.assertRaises(KeyError):
idx.get_loc('2000-01-10', method='nearest', tolerance='1 day')

def test_where(self):
i = self.create_index()
result = i.where(notnull(i))
expected = i
tm.assert_index_equal(result, expected)

i2 = i.copy()
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
freq='D')
result = i.where(notnull(i2))
expected = i2
tm.assert_index_equal(result, expected)

def test_where_other(self):

i = self.create_index()
for arr in [np.nan, pd.NaT]:
result = i.where(notnull(i), other=np.nan)
expected = i
tm.assert_index_equal(result, expected)

i2 = i.copy()
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
freq='D')
result = i.where(notnull(i2), i2)
tm.assert_index_equal(result, i2)

i2 = i.copy()
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
freq='D')
result = i.where(notnull(i2), i2.values)
tm.assert_index_equal(result, i2)

def test_get_indexer(self):
idx = pd.period_range('2000-01-01', periods=3).asfreq('H', how='start')
tm.assert_numpy_array_equal(idx.get_indexer(idx), [0, 1, 2])
Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/indexes/test_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ def test_labels_dtypes(self):
self.assertTrue((i.labels[0] >= 0).all())
self.assertTrue((i.labels[1] >= 0).all())

def test_where(self):
i = MultiIndex.from_tuples([('A', 1), ('A', 2)])

def f():
i.where(True)

self.assertRaises(NotImplementedError, f)

def test_repeat(self):
reps = 2
numbers = [1, 2, 3]
Expand Down
40 changes: 40 additions & 0 deletions pandas/tests/types/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
import nose
import numpy as np

from pandas import NaT
from pandas.types.api import (DatetimeTZDtype, CategoricalDtype,
na_value_for_dtype, pandas_dtype)


def test_pandas_dtype():

assert pandas_dtype('datetime64[ns, US/Eastern]') == DatetimeTZDtype(
'datetime64[ns, US/Eastern]')
assert pandas_dtype('category') == CategoricalDtype()
for dtype in ['M8[ns]', 'm8[ns]', 'object', 'float64', 'int64']:
assert pandas_dtype(dtype) == np.dtype(dtype)


def test_na_value_for_dtype():
for dtype in [np.dtype('M8[ns]'), np.dtype('m8[ns]'),
DatetimeTZDtype('datetime64[ns, US/Eastern]')]:
assert na_value_for_dtype(dtype) is NaT

for dtype in ['u1', 'u2', 'u4', 'u8',
'i1', 'i2', 'i4', 'i8']:
assert na_value_for_dtype(np.dtype(dtype)) == 0

for dtype in ['bool']:
assert na_value_for_dtype(np.dtype(dtype)) is False

for dtype in ['f2', 'f4', 'f8']:
assert np.isnan(na_value_for_dtype(np.dtype(dtype)))

for dtype in ['O']:
assert np.isnan(na_value_for_dtype(np.dtype(dtype)))


if __name__ == '__main__':
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],
exit=False)
Loading