Skip to content

Commit 0cacd24

Browse files
committed
Merge pull request #10638 from sinhrks/array_finalize
CLN: Make ufunc works for Index
2 parents 76520d9 + e244bdd commit 0cacd24

File tree

11 files changed

+315
-46
lines changed

11 files changed

+315
-46
lines changed

doc/source/whatsnew/v0.17.0.txt

+5
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,8 @@ Other enhancements
309309

310310
- ``DataFrame.apply`` will return a Series of dicts if the passed function returns a dict and ``reduce=True`` (:issue:`8735`).
311311

312+
- ``PeriodIndex`` now supports arithmetic with ``np.ndarray`` (:issue:`10638`)
313+
312314
- ``concat`` will now use existing Series names if provided (:issue:`10698`).
313315

314316
.. ipython:: python
@@ -333,6 +335,7 @@ Other enhancements
333335

334336
pd.concat([foo, bar, baz], 1)
335337

338+
336339
.. _whatsnew_0170.api:
337340

338341
.. _whatsnew_0170.api_breaking:
@@ -1005,3 +1008,5 @@ Bug Fixes
10051008
- Bug when constructing ``DataFrame`` where passing a dictionary with only scalar values and specifying columns did not raise an error (:issue:`10856`)
10061009
- Bug in ``.var()`` causing roundoff errors for highly similar values (:issue:`10242`)
10071010
- Bug in ``DataFrame.plot(subplots=True)`` with duplicated columns outputs incorrect result (:issue:`10962`)
1011+
- Bug in ``Index`` arithmetic may result in incorrect class (:issue:`10638`)
1012+

pandas/core/index.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,12 @@ def __array_wrap__(self, result, context=None):
273273
"""
274274
Gets called after a ufunc
275275
"""
276-
return self._shallow_copy(result)
276+
if is_bool_dtype(result):
277+
return result
278+
279+
attrs = self._get_attributes_dict()
280+
attrs = self._maybe_update_attributes(attrs)
281+
return Index(result, **attrs)
277282

278283
@cache_readonly
279284
def dtype(self):
@@ -2809,6 +2814,10 @@ def invalid_op(self, other=None):
28092814
cls.__abs__ = _make_invalid_op('__abs__')
28102815
cls.__inv__ = _make_invalid_op('__inv__')
28112816

2817+
def _maybe_update_attributes(self, attrs):
2818+
""" Update Index attributes (e.g. freq) depending on op """
2819+
return attrs
2820+
28122821
@classmethod
28132822
def _add_numeric_methods(cls):
28142823
""" add in numeric methods """
@@ -2849,7 +2858,9 @@ def _evaluate_numeric_binop(self, other):
28492858
if reversed:
28502859
values, other = other, values
28512860

2852-
return self._shallow_copy(op(values, other))
2861+
attrs = self._get_attributes_dict()
2862+
attrs = self._maybe_update_attributes(attrs)
2863+
return Index(op(values, other), **attrs)
28532864

28542865
return _evaluate_numeric_binop
28552866

@@ -2861,8 +2872,9 @@ def _evaluate_numeric_unary(self):
28612872
if not self._is_numeric_dtype:
28622873
raise TypeError("cannot evaluate a numeric op {opstr} for type: {typ}".format(opstr=opstr,
28632874
typ=type(self)))
2864-
2865-
return self._shallow_copy(op(self.values))
2875+
attrs = self._get_attributes_dict()
2876+
attrs = self._maybe_update_attributes(attrs)
2877+
return Index(op(self.values), **attrs)
28662878

28672879
return _evaluate_numeric_unary
28682880

pandas/core/ops.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,8 @@ def wrapper(left, right, name=name, na_op=na_op):
613613
else:
614614
# scalars
615615
if hasattr(lvalues, 'values') and not isinstance(lvalues, pd.DatetimeIndex):
616-
lvalues = lvalues.values
616+
lvalues = lvalues.values
617+
617618
return left._constructor(wrap_results(na_op(lvalues, rvalues)),
618619
index=left.index, name=left.name,
619620
dtype=dtype)

pandas/tests/test_index.py

+128-2
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,56 @@ def test_equals_op(self):
509509
tm.assert_numpy_array_equal(index_a == item, expected3)
510510
tm.assert_numpy_array_equal(series_a == item, expected3)
511511

512+
def test_numpy_ufuncs(self):
513+
# test ufuncs of numpy 1.9.2. see:
514+
# http://docs.scipy.org/doc/numpy/reference/ufuncs.html
515+
516+
# some functions are skipped because it may return different result
517+
# for unicode input depending on numpy version
518+
519+
for name, idx in compat.iteritems(self.indices):
520+
for func in [np.exp, np.exp2, np.expm1, np.log, np.log2, np.log10,
521+
np.log1p, np.sqrt, np.sin, np.cos,
522+
np.tan, np.arcsin, np.arccos, np.arctan,
523+
np.sinh, np.cosh, np.tanh, np.arcsinh, np.arccosh,
524+
np.arctanh, np.deg2rad, np.rad2deg]:
525+
if isinstance(idx, pd.tseries.base.DatetimeIndexOpsMixin):
526+
# raise TypeError or ValueError (PeriodIndex)
527+
# PeriodIndex behavior should be changed in future version
528+
with tm.assertRaises(Exception):
529+
func(idx)
530+
elif isinstance(idx, (Float64Index, Int64Index)):
531+
# coerces to float (e.g. np.sin)
532+
result = func(idx)
533+
exp = Index(func(idx.values), name=idx.name)
534+
self.assert_index_equal(result, exp)
535+
self.assertIsInstance(result, pd.Float64Index)
536+
else:
537+
# raise AttributeError or TypeError
538+
if len(idx) == 0:
539+
continue
540+
else:
541+
with tm.assertRaises(Exception):
542+
func(idx)
543+
544+
for func in [np.isfinite, np.isinf, np.isnan, np.signbit]:
545+
if isinstance(idx, pd.tseries.base.DatetimeIndexOpsMixin):
546+
# raise TypeError or ValueError (PeriodIndex)
547+
with tm.assertRaises(Exception):
548+
func(idx)
549+
elif isinstance(idx, (Float64Index, Int64Index)):
550+
# results in bool array
551+
result = func(idx)
552+
exp = func(idx.values)
553+
self.assertIsInstance(result, np.ndarray)
554+
tm.assertNotIsInstance(result, Index)
555+
else:
556+
if len(idx) == 0:
557+
continue
558+
else:
559+
with tm.assertRaises(Exception):
560+
func(idx)
561+
512562

513563
class TestIndex(Base, tm.TestCase):
514564
_holder = Index
@@ -2848,6 +2898,41 @@ def test_slice_keep_name(self):
28482898
idx = Int64Index([1, 2], name='asdf')
28492899
self.assertEqual(idx.name, idx[1:].name)
28502900

2901+
def test_ufunc_coercions(self):
2902+
idx = pd.Int64Index([1, 2, 3, 4, 5], name='x')
2903+
2904+
result = np.sqrt(idx)
2905+
tm.assertIsInstance(result, Float64Index)
2906+
exp = pd.Float64Index(np.sqrt(np.array([1, 2, 3, 4, 5])), name='x')
2907+
tm.assert_index_equal(result, exp)
2908+
2909+
result = np.divide(idx, 2.)
2910+
tm.assertIsInstance(result, Float64Index)
2911+
exp = pd.Float64Index([0.5, 1., 1.5, 2., 2.5], name='x')
2912+
tm.assert_index_equal(result, exp)
2913+
2914+
# _evaluate_numeric_binop
2915+
result = idx + 2.
2916+
tm.assertIsInstance(result, Float64Index)
2917+
exp = pd.Float64Index([3., 4., 5., 6., 7.], name='x')
2918+
tm.assert_index_equal(result, exp)
2919+
2920+
result = idx - 2.
2921+
tm.assertIsInstance(result, Float64Index)
2922+
exp = pd.Float64Index([-1., 0., 1., 2., 3.], name='x')
2923+
tm.assert_index_equal(result, exp)
2924+
2925+
result = idx * 1.
2926+
tm.assertIsInstance(result, Float64Index)
2927+
exp = pd.Float64Index([1., 2., 3., 4., 5.], name='x')
2928+
tm.assert_index_equal(result, exp)
2929+
2930+
result = idx / 2.
2931+
tm.assertIsInstance(result, Float64Index)
2932+
exp = pd.Float64Index([0.5, 1., 1.5, 2., 2.5], name='x')
2933+
tm.assert_index_equal(result, exp)
2934+
2935+
28512936
class DatetimeLike(Base):
28522937

28532938
def test_str(self):
@@ -3101,7 +3186,9 @@ def test_get_loc(self):
31013186
tolerance=timedelta(1)), 1)
31023187
with tm.assertRaisesRegexp(ValueError, 'must be convertible'):
31033188
idx.get_loc('2000-01-10', method='nearest', tolerance='foo')
3104-
with tm.assertRaisesRegexp(ValueError, 'different freq'):
3189+
3190+
msg = 'Input has different freq from PeriodIndex\\(freq=D\\)'
3191+
with tm.assertRaisesRegexp(ValueError, msg):
31053192
idx.get_loc('2000-01-10', method='nearest', tolerance='1 hour')
31063193
with tm.assertRaises(KeyError):
31073194
idx.get_loc('2000-01-10', method='nearest', tolerance='1 day')
@@ -3119,7 +3206,8 @@ def test_get_indexer(self):
31193206
idx.get_indexer(target, 'nearest', tolerance='1 hour'),
31203207
[0, -1, 1])
31213208

3122-
with self.assertRaisesRegexp(ValueError, 'different freq'):
3209+
msg = 'Input has different freq from PeriodIndex\\(freq=H\\)'
3210+
with self.assertRaisesRegexp(ValueError, msg):
31233211
idx.get_indexer(target, 'nearest', tolerance='1 minute')
31243212

31253213
tm.assert_numpy_array_equal(
@@ -3215,6 +3303,44 @@ def test_numeric_compat(self):
32153303
def test_pickle_compat_construction(self):
32163304
pass
32173305

3306+
def test_ufunc_coercions(self):
3307+
# normal ops are also tested in tseries/test_timedeltas.py
3308+
idx = TimedeltaIndex(['2H', '4H', '6H', '8H', '10H'],
3309+
freq='2H', name='x')
3310+
3311+
for result in [idx * 2, np.multiply(idx, 2)]:
3312+
tm.assertIsInstance(result, TimedeltaIndex)
3313+
exp = TimedeltaIndex(['4H', '8H', '12H', '16H', '20H'],
3314+
freq='4H', name='x')
3315+
tm.assert_index_equal(result, exp)
3316+
self.assertEqual(result.freq, '4H')
3317+
3318+
for result in [idx / 2, np.divide(idx, 2)]:
3319+
tm.assertIsInstance(result, TimedeltaIndex)
3320+
exp = TimedeltaIndex(['1H', '2H', '3H', '4H', '5H'],
3321+
freq='H', name='x')
3322+
tm.assert_index_equal(result, exp)
3323+
self.assertEqual(result.freq, 'H')
3324+
3325+
idx = TimedeltaIndex(['2H', '4H', '6H', '8H', '10H'],
3326+
freq='2H', name='x')
3327+
for result in [ - idx, np.negative(idx)]:
3328+
tm.assertIsInstance(result, TimedeltaIndex)
3329+
exp = TimedeltaIndex(['-2H', '-4H', '-6H', '-8H', '-10H'],
3330+
freq='-2H', name='x')
3331+
tm.assert_index_equal(result, exp)
3332+
self.assertEqual(result.freq, None)
3333+
3334+
idx = TimedeltaIndex(['-2H', '-1H', '0H', '1H', '2H'],
3335+
freq='H', name='x')
3336+
for result in [ abs(idx), np.absolute(idx)]:
3337+
tm.assertIsInstance(result, TimedeltaIndex)
3338+
exp = TimedeltaIndex(['2H', '1H', '0H', '1H', '2H'],
3339+
freq=None, name='x')
3340+
tm.assert_index_equal(result, exp)
3341+
self.assertEqual(result.freq, None)
3342+
3343+
32183344
class TestMultiIndex(Base, tm.TestCase):
32193345
_holder = MultiIndex
32203346
_multiprocess_can_split_ = True

pandas/tseries/index.py

-9
Original file line numberDiff line numberDiff line change
@@ -1077,15 +1077,6 @@ def _fast_union(self, other):
10771077
end=max(left_end, right_end),
10781078
freq=left.offset)
10791079

1080-
def __array_finalize__(self, obj):
1081-
if self.ndim == 0: # pragma: no cover
1082-
return self.item()
1083-
1084-
self.offset = getattr(obj, 'offset', None)
1085-
self.tz = getattr(obj, 'tz', None)
1086-
self.name = getattr(obj, 'name', None)
1087-
self._reset_identity()
1088-
10891080
def __iter__(self):
10901081
"""
10911082
Return an iterator over the boxed values

pandas/tseries/period.py

+38-10
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import pandas.core.common as com
2020
from pandas.core.common import (isnull, _INT64_DTYPE, _maybe_box,
2121
_values_from_object, ABCSeries,
22-
is_integer, is_float, is_object_dtype)
22+
is_integer, is_float, is_object_dtype,
23+
is_float_dtype)
2324
from pandas import compat
2425
from pandas.util.decorators import cache_readonly
2526

@@ -307,6 +308,30 @@ def __contains__(self, key):
307308
return False
308309
return key.ordinal in self._engine
309310

311+
def __array_wrap__(self, result, context=None):
312+
"""
313+
Gets called after a ufunc. Needs additional handling as
314+
PeriodIndex stores internal data as int dtype
315+
316+
Replace this to __numpy_ufunc__ in future version
317+
"""
318+
if isinstance(context, tuple) and len(context) > 0:
319+
func = context[0]
320+
if (func is np.add):
321+
return self._add_delta(context[1][1])
322+
elif (func is np.subtract):
323+
return self._add_delta(-context[1][1])
324+
elif isinstance(func, np.ufunc):
325+
if 'M->M' not in func.types:
326+
msg = "ufunc '{0}' not supported for the PeriodIndex"
327+
# This should be TypeError, but TypeError cannot be raised
328+
# from here because numpy catches.
329+
raise ValueError(msg.format(func.__name__))
330+
331+
if com.is_bool_dtype(result):
332+
return result
333+
return PeriodIndex(result, freq=self.freq, name=self.name)
334+
310335
@property
311336
def _box_func(self):
312337
return lambda x: Period._from_ordinal(ordinal=x, freq=self.freq)
@@ -522,7 +547,18 @@ def _maybe_convert_timedelta(self, other):
522547
base = frequencies.get_base_alias(freqstr)
523548
if base == self.freq.rule_code:
524549
return other.n
525-
raise ValueError("Input has different freq from PeriodIndex(freq={0})".format(self.freq))
550+
elif isinstance(other, np.ndarray):
551+
if com.is_integer_dtype(other):
552+
return other
553+
elif com.is_timedelta64_dtype(other):
554+
offset = frequencies.to_offset(self.freq)
555+
if isinstance(offset, offsets.Tick):
556+
nanos = tslib._delta_to_nanoseconds(other)
557+
offset_nanos = tslib._delta_to_nanoseconds(offset)
558+
if (nanos % offset_nanos).all() == 0:
559+
return nanos // offset_nanos
560+
msg = "Input has different freq from PeriodIndex(freq={0})"
561+
raise ValueError(msg.format(self.freqstr))
526562

527563
def _add_delta(self, other):
528564
ordinal_delta = self._maybe_convert_timedelta(other)
@@ -775,14 +811,6 @@ def _format_native_types(self, na_rep=u('NaT'), date_format=None, **kwargs):
775811
values[imask] = np.array([formatter(dt) for dt in values[imask]])
776812
return values
777813

778-
def __array_finalize__(self, obj):
779-
if not self.ndim: # pragma: no cover
780-
return self.item()
781-
782-
self.freq = getattr(obj, 'freq', None)
783-
self.name = getattr(obj, 'name', None)
784-
self._reset_identity()
785-
786814
def take(self, indices, axis=0):
787815
"""
788816
Analogous to ndarray.take

pandas/tseries/tdi.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,14 @@ def __setstate__(self, state):
278278
raise Exception("invalid pickle state")
279279
_unpickle_compat = __setstate__
280280

281+
def _maybe_update_attributes(self, attrs):
282+
""" Update Index attributes (e.g. freq) depending on op """
283+
freq = attrs.get('freq', None)
284+
if freq is not None:
285+
# no need to infer if freq is None
286+
attrs['freq'] = 'infer'
287+
return attrs
288+
281289
def _add_delta(self, delta):
282290
if isinstance(delta, (Tick, timedelta, np.timedelta64)):
283291
new_values = self._add_delta_td(delta)
@@ -560,14 +568,6 @@ def _fast_union(self, other):
560568
else:
561569
return left
562570

563-
def __array_finalize__(self, obj):
564-
if self.ndim == 0: # pragma: no cover
565-
return self.item()
566-
567-
self.name = getattr(obj, 'name', None)
568-
self.freq = getattr(obj, 'freq', None)
569-
self._reset_identity()
570-
571571
def _wrap_union_result(self, other, result):
572572
name = self.name if self.name == other.name else None
573573
return self._simple_new(result, name=name, freq=None)

0 commit comments

Comments
 (0)