Skip to content

Commit ea18491

Browse files
committed
more concise testing
let Categoricals accept Index; let equals work symmetrically between Index / CategoricalIndex; allow dtype='category' to coerce to CategoricalIndex on creation of Index
1 parent 918c01a commit ea18491

File tree

5 files changed

+48
-25
lines changed

5 files changed

+48
-25
lines changed

pandas/core/categorical.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ def __init__(self, values, categories=None, ordered=False, name=None, fastpath=F
242242
values = values.__array__()
243243

244244
elif isinstance(values, Index):
245-
pass
245+
values = np.array(values)
246+
ordered = True
246247

247248
else:
248249

pandas/core/index.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def __new__(cls, data=None, dtype=None, copy=False, name=None, fastpath=False,
143143
return Float64Index(data, copy=copy, dtype=dtype, name=name)
144144
elif issubclass(data.dtype.type, np.bool) or is_bool_dtype(data):
145145
subarr = data.astype('object')
146-
elif is_categorical_dtype(data):
146+
elif is_categorical_dtype(data) or is_categorical_dtype(dtype):
147147
return CategoricalIndex(data, copy=copy, name=name, **kwargs)
148148
else:
149149
subarr = com._asarray_tuplesafe(data, dtype=object)
@@ -153,7 +153,7 @@ def __new__(cls, data=None, dtype=None, copy=False, name=None, fastpath=False,
153153
if copy:
154154
subarr = subarr.copy()
155155

156-
elif is_categorical_dtype(data):
156+
elif is_categorical_dtype(data) or is_categorical_dtype(dtype):
157157
return CategoricalIndex(data, copy=copy, name=name, **kwargs)
158158
elif hasattr(data, '__array__'):
159159
return Index(np.asarray(data), dtype=dtype, copy=copy, name=name,
@@ -2611,13 +2611,21 @@ def equals(self, other):
26112611
if self.is_(other):
26122612
return True
26132613

2614-
if not isinstance(other, CategoricalIndex):
2615-
return False
2616-
26172614
try:
2618-
return (self._data == other._data).all()
2615+
if is_categorical_dtype(other):
2616+
if not self.categories.equals(other.categories):
2617+
return False
2618+
else:
2619+
from pandas import Categorical
2620+
other = Categorical(other, categories=self.categories, ordered=True)
2621+
if isnull(other).any():
2622+
return False
2623+
2624+
return (self._data == other).all()
26192625
except:
2620-
return False
2626+
pass
2627+
2628+
return False
26212629

26222630
@property
26232631
def inferred_type(self):

pandas/tests/test_categorical.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,14 @@ def f():
221221
c_old2 = Categorical([0, 1, 2, 0, 1, 2], [1, 2, 3])
222222
cat = Categorical([1,2], categories=[1,2,3])
223223

224-
def test_constructor_with_categorical_index(self):
224+
def test_constructor_with_index(self):
225225

226226
ci = CategoricalIndex(list('aabbca'),categories=list('cab'))
227227
self.assertTrue(ci.values.equals(Categorical(ci)))
228228

229+
ci = CategoricalIndex(list('aabbca'),categories=list('cab'))
230+
self.assertTrue(ci.values.equals(Categorical(ci.astype(object),categories=ci.categories)))
231+
229232
def test_constructor_with_generator(self):
230233
# This was raising an Error in isnull(single_val).any() because isnull returned a scalar
231234
# for a generator

pandas/tests/test_index.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -1330,10 +1330,10 @@ def test_construction(self):
13301330
idx = self.create_index()
13311331

13321332
result = Index(idx)
1333-
self.assertTrue(result.equals(idx))
1333+
tm.assert_index_equal(result,idx,exact=True)
13341334

13351335
result = Index(idx.values)
1336-
self.assertTrue(result.equals(idx))
1336+
tm.assert_index_equal(result,idx,exact=True)
13371337

13381338
# empty
13391339
result = CategoricalIndex(categories=categories)
@@ -1371,22 +1371,29 @@ def test_construction(self):
13711371
self.assertIsInstance(result, Index)
13721372
self.assertNotIsInstance(result, CategoricalIndex)
13731373

1374+
# specify dtype
1375+
result = Index(np.array(idx), dtype='category')
1376+
tm.assert_index_equal(result,idx,exact=True)
1377+
1378+
result = Index(np.array(idx).tolist(), dtype='category')
1379+
tm.assert_index_equal(result,idx,exact=True)
1380+
13741381
def test_append(self):
13751382

13761383
categories = list('cab')
13771384
ci = CategoricalIndex(list('aabbca'), categories=categories)
13781385

13791386
# append cats with the same categories
13801387
result = ci[:3].append(ci[3:])
1381-
self.assertTrue(result.equals(ci))
1388+
tm.assert_index_equal(result,ci,exact=True)
13821389

13831390
foos = [ci[:1], ci[1:3], ci[3:]]
13841391
result = foos[0].append(foos[1:])
1385-
self.assertTrue(result.equals(ci))
1392+
tm.assert_index_equal(result,ci,exact=True)
13861393

13871394
# empty
13881395
result = ci.append([])
1389-
self.assertTrue(result.equals(ci))
1396+
tm.assert_index_equal(result,ci,exact=True)
13901397

13911398
# appending with different categories or reordered is not ok
13921399
self.assertRaises(TypeError, lambda : ci.append(ci.values.set_categories(list('abcd'))))
@@ -1395,7 +1402,7 @@ def test_append(self):
13951402
# with objects
13961403
result = ci.append(['c','a'])
13971404
expected = CategoricalIndex(list('aabbcaca'), categories=categories)
1398-
self.assertTrue(result.equals(expected))
1405+
tm.assert_index_equal(result,expected,exact=True)
13991406

14001407
# invalid objects
14011408
self.assertRaises(TypeError, lambda : ci.append(['a','d']))
@@ -1408,17 +1415,17 @@ def test_insert(self):
14081415
#test 0th element
14091416
result = ci.insert(0, 'a')
14101417
expected = CategoricalIndex(list('aaabbca'),categories=categories)
1411-
self.assertTrue(result.equals(expected))
1418+
tm.assert_index_equal(result,expected,exact=True)
14121419

14131420
#test Nth element that follows Python list behavior
14141421
result = ci.insert(-1, 'a')
14151422
expected = CategoricalIndex(list('aabbcaa'),categories=categories)
1416-
self.assertTrue(result.equals(expected))
1423+
tm.assert_index_equal(result,expected,exact=True)
14171424

14181425
#test empty
14191426
result = CategoricalIndex(categories=categories).insert(0, 'a')
14201427
expected = CategoricalIndex(['a'],categories=categories)
1421-
self.assertTrue(result.equals(expected))
1428+
tm.assert_index_equal(result,expected,exact=True)
14221429

14231430
# invalid
14241431
self.assertRaises(ValueError, lambda : ci.insert(0,'d'))
@@ -1430,11 +1437,11 @@ def test_delete(self):
14301437

14311438
result = ci.delete(0)
14321439
expected = CategoricalIndex(list('abbca'),categories=categories)
1433-
self.assertTrue(result.equals(expected))
1440+
tm.assert_index_equal(result,expected,exact=True)
14341441

14351442
result = ci.delete(-1)
14361443
expected = CategoricalIndex(list('aabbc'),categories=categories)
1437-
self.assertTrue(result.equals(expected))
1444+
tm.assert_index_equal(result,expected,exact=True)
14381445

14391446
with tm.assertRaises((IndexError, ValueError)):
14401447
# either depending on numpy version
@@ -1444,10 +1451,13 @@ def test_astype(self):
14441451

14451452
idx = self.create_index()
14461453
result = idx.astype('category')
1447-
self.assertTrue(result.equals(idx))
1454+
tm.assert_index_equal(result,idx,exact=True)
14481455

14491456
result = idx.astype(object)
14501457
self.assertTrue(result.equals(Index(np.array(idx))))
1458+
1459+
# this IS equal, but not the same class
1460+
self.assertTrue(result.equals(idx))
14511461
self.assertIsInstance(result, Index)
14521462
self.assertNotIsInstance(result, CategoricalIndex)
14531463

@@ -1487,6 +1497,9 @@ def test_equals(self):
14871497

14881498
self.assertTrue(ci1.equals(ci1))
14891499
self.assertFalse(ci1.equals(ci2))
1500+
self.assertTrue(ci1.equals(ci1.astype(object)))
1501+
self.assertTrue(ci1.astype(object).equals(ci1))
1502+
14901503
self.assertTrue((ci1 == ci1).all())
14911504
self.assertFalse((ci1 != ci1).all())
14921505
self.assertFalse((ci1 > ci1).all())

pandas/util/testing.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -531,16 +531,14 @@ def assert_equal(a, b, msg=""):
531531
assert a == b, "%s: %r != %r" % (msg.format(a,b), a, b)
532532

533533

534-
def assert_index_equal(left, right):
534+
def assert_index_equal(left, right, exact=False):
535535
assert_isinstance(left, Index, '[index] ')
536536
assert_isinstance(right, Index, '[index] ')
537-
if not left.equals(right):
537+
if not left.equals(right) or (exact and type(left) != type(right)):
538538
raise AssertionError("[index] left [{0} {1}], right [{2} {3}]".format(left.dtype,
539539
left,
540540
right,
541541
right.dtype))
542-
543-
544542
def assert_attr_equal(attr, left, right):
545543
"""checks attributes are equal. Both objects must have attribute."""
546544
left_attr = getattr(left, attr)

0 commit comments

Comments
 (0)