Skip to content

Commit 49e49bc

Browse files
jreback authored and max-sixty committed
provide proper index coercion with _shallow_copy for insert,delete,append operations
1 parent 6ec9624 commit 49e49bc

File tree

4 files changed

+117
-43
lines changed

4 files changed

+117
-43
lines changed

pandas/core/index.py

+68-24
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,6 @@ class Index(IndexOpsMixin, PandasObject):
8989
_left_indexer = _algos.left_join_indexer_object
9090
_inner_indexer = _algos.inner_join_indexer_object
9191
_outer_indexer = _algos.outer_join_indexer_object
92-
9392
_box_scalars = False
9493

9594
_typ = 'index'
@@ -204,6 +203,17 @@ def __new__(cls, data=None, dtype=None, copy=False, name=None, fastpath=False,
204203

205204
@classmethod
206205
def _simple_new(cls, values, name=None, **kwargs):
206+
"""
207+
we require that the values have a dtype compatible with this index type;
208+
if we are passed values with an incompatible dtype, then coerce using the constructor
209+
210+
Must be careful not to recurse.
211+
"""
212+
if not hasattr(values, 'dtype'):
213+
values = np.array(values,copy=False)
214+
if is_object_dtype(values):
215+
values = cls(values, name=name, **kwargs).values
216+
207217
result = object.__new__(cls)
208218
result._data = values
209219
result.name = name
@@ -341,15 +351,41 @@ def view(self, cls=None):
341351
result._id = self._id
342352
return result
343353

344-
def _shallow_copy(self, values=None, **kwargs):
345-
""" create a new Index, don't copy the data, use the same object attributes
346-
with passed in attributes taking precedence """
354+
def _shallow_copy(self, values=None, infer=False, **kwargs):
355+
"""
356+
create a new Index, don't copy the data, use the same object attributes
357+
with passed in attributes taking precedence
358+
359+
*this is an internal non-public method*
360+
361+
Parameters
362+
----------
363+
values : the values to create the new Index, optional
364+
infer : boolean, default False
365+
if True, infer the new type of the passed values
366+
kwargs : updates the default attributes for this Index
367+
"""
347368
if values is None:
348369
values = self.values
349370
attributes = self._get_attributes_dict()
350371
attributes.update(kwargs)
372+
373+
if infer:
374+
attributes['copy'] = False
375+
return Index(values, **attributes)
376+
351377
return self.__class__._simple_new(values,**attributes)
352378

379+
def _coerce_scalar_to_index(self, item):
380+
"""
381+
we need to coerce a scalar to an Index compatible with our index type
382+
383+
Parameters
384+
----------
385+
item : scalar item to coerce
386+
"""
387+
return Index([item], dtype=self.dtype, **self._get_attributes_dict())
388+
353389
def copy(self, names=None, name=None, dtype=None, deep=False):
354390
"""
355391
Make a copy of this object. Name and dtype sets those attributes on
@@ -1132,7 +1168,9 @@ def append(self, other):
11321168
appended : Index
11331169
"""
11341170
to_concat, name = self._ensure_compat_append(other)
1135-
return Index(np.concatenate(to_concat), name=name)
1171+
attribs = self._get_attributes_dict()
1172+
attribs['name'] = name
1173+
return self._shallow_copy(np.concatenate(to_concat), infer=True, **attribs)
11361174

11371175
@staticmethod
11381176
def _ensure_compat_concat(indexes):
@@ -1548,16 +1586,12 @@ def sym_diff(self, other, result_name=None):
15481586
other, result_name_update = self._convert_can_do_setop(other)
15491587
if result_name is None:
15501588
result_name = result_name_update
1551-
the_diff_sorted = sorted(set((self.difference(other)).union(other.difference(self))))
1552-
if isinstance(self, MultiIndex):
1553-
# multiindexes don't currently work well with _shallow_copy & can't be supplied a name
1554-
return Index(the_diff_sorted, **self._get_attributes_dict())
1555-
else:
1556-
# convert list to Index and pull values out - required for subtypes such as PeriodIndex
1557-
diff_array = Index(the_diff_sorted).values
1558-
return self._shallow_copy(diff_array, name=result_name)
1559-
1560-
1589+
the_diff = sorted(set((self.difference(other)).union(other.difference(self))))
1590+
attribs = self._get_attributes_dict()
1591+
attribs['name'] = result_name
1592+
if 'freq' in attribs:
1593+
attribs['freq'] = None
1594+
return self._shallow_copy(the_diff, infer=True, **attribs)
15611595

15621596
def get_loc(self, key, method=None):
15631597
"""
@@ -2535,7 +2569,8 @@ def delete(self, loc):
25352569
-------
25362570
new_index : Index
25372571
"""
2538-
return self._shallow_copy(np.delete(self._data, loc))
2572+
attribs = self._get_attributes_dict()
2573+
return self._shallow_copy(np.delete(self._data, loc), **attribs)
25392574

25402575
def insert(self, loc, item):
25412576
"""
@@ -2551,12 +2586,13 @@ def insert(self, loc, item):
25512586
-------
25522587
new_index : Index
25532588
"""
2554-
indexes=[self[:loc],
2555-
Index([item]),
2556-
self[loc:]]
2557-
2558-
return indexes[0].append(indexes[1]).append(indexes[2])
2589+
_self = np.asarray(self)
2590+
item = self._coerce_scalar_to_index(item).values
25592591

2592+
idx = np.concatenate(
2593+
(_self[:loc], item, _self[loc:]))
2594+
attribs = self._get_attributes_dict()
2595+
return self._shallow_copy(idx, infer=True, **attribs)
25602596

25612597
def drop(self, labels, errors='raise'):
25622598
"""
@@ -3687,7 +3723,7 @@ class MultiIndex(Index):
36873723
Level of sortedness (must be lexicographically sorted by that
36883724
level)
36893725
names : optional sequence of objects
3690-
Names for each of the index levels.
3726+
Names for each of the index levels. (name is accepted for compat)
36913727
copy : boolean, default False
36923728
Copy the meta-data
36933729
verify_integrity : boolean, default True
@@ -3703,8 +3739,11 @@ class MultiIndex(Index):
37033739
rename = Index.set_names
37043740

37053741
def __new__(cls, levels=None, labels=None, sortorder=None, names=None,
3706-
copy=False, verify_integrity=True, _set_identity=True, **kwargs):
3742+
copy=False, verify_integrity=True, _set_identity=True, name=None, **kwargs):
37073743

3744+
# compat with Index
3745+
if name is not None:
3746+
names = name
37083747
if levels is None or labels is None:
37093748
raise TypeError("Must pass both levels and labels")
37103749
if len(levels) != len(labels):
@@ -4013,7 +4052,12 @@ def view(self, cls=None):
40134052
result._id = self._id
40144053
return result
40154054

4016-
_shallow_copy = view
4055+
def _shallow_copy(self, values=None, infer=False, **kwargs):
4056+
if values is not None:
4057+
if 'name' in kwargs:
4058+
kwargs['names'] = kwargs.pop('name',None)
4059+
return MultiIndex.from_tuples(values, **kwargs)
4060+
return self.view()
40174061

40184062
@cache_readonly
40194063
def dtype(self):

pandas/tests/test_index.py

+19-16
Original file line number | Diff line number | Diff line change
@@ -401,30 +401,33 @@ def test_insert_base(self):
401401
for name, idx in compat.iteritems(self.indices):
402402
result = idx[1:4]
403403

404-
if len(idx)>0:
405-
#test 0th element
406-
self.assertTrue(idx[0:4].equals(
407-
result.insert(0, idx[0])))
404+
if not len(idx):
405+
continue
406+
407+
#test 0th element
408+
self.assertTrue(idx[0:4].equals(
409+
result.insert(0, idx[0])))
408410

409411
def test_delete_base(self):
410412

411413
for name, idx in compat.iteritems(self.indices):
412414

413-
if len(idx)>0:
415+
if not len(idx):
416+
continue
414417

415-
expected = idx[1:]
416-
result = idx.delete(0)
417-
self.assertTrue(result.equals(expected))
418-
self.assertEqual(result.name, expected.name)
418+
expected = idx[1:]
419+
result = idx.delete(0)
420+
self.assertTrue(result.equals(expected))
421+
self.assertEqual(result.name, expected.name)
419422

420-
expected = idx[:-1]
421-
result = idx.delete(-1)
422-
self.assertTrue(result.equals(expected))
423-
self.assertEqual(result.name, expected.name)
423+
expected = idx[:-1]
424+
result = idx.delete(-1)
425+
self.assertTrue(result.equals(expected))
426+
self.assertEqual(result.name, expected.name)
424427

425-
with tm.assertRaises((IndexError, ValueError)):
426-
# either depending on numpy version
427-
result = idx.delete(len(idx))
428+
with tm.assertRaises((IndexError, ValueError)):
429+
# either depending on numpy version
430+
result = idx.delete(len(idx))
428431

429432

430433
def test_equals_op(self):

pandas/tseries/index.py

+10-2
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@
55
import numpy as np
66
from pandas.core.common import (_NS_DTYPE, _INT64_DTYPE,
77
_values_from_object, _maybe_box,
8-
ABCSeries, is_integer, is_float)
8+
ABCSeries, is_integer, is_float,
9+
is_object_dtype, is_datetime64_dtype)
910
from pandas.core.index import Index, Int64Index, Float64Index
1011
import pandas.compat as compat
1112
from pandas.compat import u
@@ -494,9 +495,16 @@ def _local_timestamps(self):
494495

495496
@classmethod
496497
def _simple_new(cls, values, name=None, freq=None, tz=None, **kwargs):
498+
"""
499+
we require that the values have a dtype compatible with this index type;
500+
if we are passed values with an incompatible dtype, then coerce using the constructor
501+
"""
502+
497503
if not getattr(values,'dtype',None):
498504
values = np.array(values,copy=False)
499-
if values.dtype != _NS_DTYPE:
505+
if is_object_dtype(values):
506+
return cls(values, name=name, freq=freq, tz=tz, **kwargs).values
507+
elif not is_datetime64_dtype(values):
500508
values = com._ensure_int64(values).view(_NS_DTYPE)
501509

502510
result = object.__new__(cls)

pandas/tseries/period.py

+20-1
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import pandas.core.common as com
2020
from pandas.core.common import (isnull, _INT64_DTYPE, _maybe_box,
2121
_values_from_object, ABCSeries,
22-
is_integer, is_float)
22+
is_integer, is_float, is_object_dtype)
2323
from pandas import compat
2424
from pandas.lib import Timestamp, Timedelta
2525
import pandas.lib as lib
@@ -259,13 +259,32 @@ def _from_arraylike(cls, data, freq, tz):
259259

260260
@classmethod
261261
def _simple_new(cls, values, name=None, freq=None, **kwargs):
262+
if not getattr(values,'dtype',None):
263+
values = np.array(values,copy=False)
264+
if is_object_dtype(values):
265+
return PeriodIndex(values, name=name, freq=freq, **kwargs)
266+
262267
result = object.__new__(cls)
263268
result._data = values
264269
result.name = name
265270
result.freq = freq
266271
result._reset_identity()
267272
return result
268273

274+
def _shallow_copy(self, values=None, infer=False, **kwargs):
275+
""" we always want to return a PeriodIndex """
276+
return super(PeriodIndex, self)._shallow_copy(values=values, infer=False, **kwargs)
277+
278+
def _coerce_scalar_to_index(self, item):
279+
"""
280+
we need to coerce a scalar to a PeriodIndex compatible with our index type
281+
282+
Parameters
283+
----------
284+
item : scalar item to coerce
285+
"""
286+
return PeriodIndex([item], **self._get_attributes_dict())
287+
269288
@property
270289
def _na_value(self):
271290
return self._box_func(tslib.iNaT)

0 commit comments

Comments
 (0)