Skip to content

Commit c24a3c8

Browse files
authored
ENH: Index[complex] (#45845)
* ENH: Index[complex] * mypy fixup * whatsnew
1 parent 45eff30 commit c24a3c8

File tree

17 files changed

+91
-41
lines changed

17 files changed

+91
-41
lines changed

doc/source/whatsnew/v1.5.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ Other enhancements
3838
- :meth:`Series.reset_index` and :meth:`DataFrame.reset_index` now support the argument ``allow_duplicates`` (:issue:`44410`)
3939
- :meth:`.GroupBy.min` and :meth:`.GroupBy.max` now support `Numba <https://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`45428`)
4040
- Implemented a ``bool``-dtype :class:`Index`; passing a bool-dtype arraylike to ``pd.Index`` will now retain ``bool`` dtype instead of casting to ``object`` (:issue:`45061`)
41+
- Implemented a complex-dtype :class:`Index`; passing a complex-dtype arraylike to ``pd.Index`` will now retain complex dtype instead of casting to ``object`` (:issue:`45845`)
42+
4143
-
4244

4345
.. ---------------------------------------------------------------------------

pandas/_libs/index.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class IndexEngine:
3030

3131
class Float64Engine(IndexEngine): ...
3232
class Float32Engine(IndexEngine): ...
33+
class Complex128Engine(IndexEngine): ...
34+
class Complex64Engine(IndexEngine): ...
3335
class Int64Engine(IndexEngine): ...
3436
class Int32Engine(IndexEngine): ...
3537
class Int16Engine(IndexEngine): ...

pandas/_libs/index_class_helper.pxi.in

+11-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ dtypes = [('Float64', 'float64'),
2121
('UInt32', 'uint32'),
2222
('UInt16', 'uint16'),
2323
('UInt8', 'uint8'),
24+
('Complex64', 'complex64'),
25+
('Complex128', 'complex128'),
2426
]
2527
}}
2628

@@ -33,7 +35,7 @@ cdef class {{name}}Engine(IndexEngine):
3335
return _hash.{{name}}HashTable(n)
3436

3537
cdef _check_type(self, object val):
36-
{{if name not in {'Float64', 'Float32'} }}
38+
{{if name not in {'Float64', 'Float32', 'Complex64', 'Complex128'} }}
3739
if not util.is_integer_object(val):
3840
if util.is_float_object(val):
3941
# Make sure Int64Index.get_loc(2.0) works
@@ -45,10 +47,17 @@ cdef class {{name}}Engine(IndexEngine):
4547
# cannot have negative values with unsigned int dtype
4648
raise KeyError(val)
4749
{{endif}}
48-
{{else}}
50+
{{elif name not in {'Complex64', 'Complex128'} }}
4951
if not util.is_integer_object(val) and not util.is_float_object(val):
5052
# in particular catch bool and avoid casting True -> 1.0
5153
raise KeyError(val)
54+
{{else}}
55+
if (not util.is_integer_object(val)
56+
and not util.is_float_object(val)
57+
and not util.is_complex_object(val)
58+
):
59+
# in particular catch bool and avoid casting True -> 1.0
60+
raise KeyError(val)
5261
{{endif}}
5362
return val
5463

pandas/conftest.py

+2
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,8 @@ def _create_mi_with_dt64tz_level():
547547
"uint": tm.makeUIntIndex(100),
548548
"range": tm.makeRangeIndex(100),
549549
"float": tm.makeFloatIndex(100),
550+
"complex64": tm.makeFloatIndex(100).astype("complex64"),
551+
"complex128": tm.makeFloatIndex(100).astype("complex128"),
550552
"num_int64": tm.makeNumericIndex(100, dtype="int64"),
551553
"num_int32": tm.makeNumericIndex(100, dtype="int32"),
552554
"num_int16": tm.makeNumericIndex(100, dtype="int16"),

pandas/core/indexes/base.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def __new__(
505505
if data.dtype.kind in ["i", "u", "f"]:
506506
# maybe coerce to a sub-class
507507
arr = data
508-
elif data.dtype.kind == "b":
508+
elif data.dtype.kind in ["b", "c"]:
509509
# No special subclass, and Index._ensure_array won't do this
510510
# for us.
511511
arr = np.asarray(data)
@@ -636,7 +636,9 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):
636636
# NB: assuming away MultiIndex
637637
return Index
638638

639-
elif issubclass(dtype.type, (str, bool, np.bool_)):
639+
elif issubclass(
640+
dtype.type, (str, bool, np.bool_, complex, np.complex64, np.complex128)
641+
):
640642
return Index
641643

642644
raise NotImplementedError(dtype)
@@ -881,6 +883,10 @@ def _engine(
881883
# `self` is not passed into the lambda.
882884
if target_values.dtype == bool:
883885
return libindex.BoolEngine(target_values)
886+
elif target_values.dtype == np.complex64:
887+
return libindex.Complex64Engine(target_values)
888+
elif target_values.dtype == np.complex128:
889+
return libindex.Complex128Engine(target_values)
884890

885891
# error: Argument 1 to "ExtensionEngine" has incompatible type
886892
# "ndarray[Any, Any]"; expected "ExtensionArray"
@@ -6162,9 +6168,6 @@ def _find_common_type_compat(self, target) -> DtypeObj:
61626168

61636169
dtype = find_common_type([self.dtype, target_dtype])
61646170
dtype = common_dtype_categorical_compat([self, target], dtype)
6165-
6166-
if dtype.kind == "c":
6167-
dtype = _dtype_obj
61686171
return dtype
61696172

61706173
@final
@@ -7308,8 +7311,6 @@ def _maybe_cast_data_without_dtype(
73087311
FutureWarning,
73097312
stacklevel=3,
73107313
)
7311-
if result.dtype.kind in ["c"]:
7312-
return subarr
73137314
result = ensure_wrapped_if_datetimelike(result)
73147315
return result
73157316

pandas/core/indexes/numeric.py

+3
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ class NumericIndex(Index):
101101
np.dtype(np.uint64): libindex.UInt64Engine,
102102
np.dtype(np.float32): libindex.Float32Engine,
103103
np.dtype(np.float64): libindex.Float64Engine,
104+
np.dtype(np.complex64): libindex.Complex64Engine,
105+
np.dtype(np.complex128): libindex.Complex128Engine,
104106
}
105107

106108
@property
@@ -115,6 +117,7 @@ def inferred_type(self) -> str:
115117
"i": "integer",
116118
"u": "integer",
117119
"f": "floating",
120+
"c": "complex",
118121
}[self.dtype.kind]
119122

120123
def __new__(cls, data=None, dtype: Dtype | None = None, copy=False, name=None):

pandas/tests/arrays/categorical/test_constructors.py

-1
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,6 @@ def test_construction_with_ordered(self, ordered):
676676
cat = Categorical([0, 1, 2], ordered=ordered)
677677
assert cat.ordered == bool(ordered)
678678

679-
@pytest.mark.xfail(reason="Imaginary values not supported in Categorical")
680679
def test_constructor_imaginary(self):
681680
values = [1, 2, 3 + 1j]
682681
c1 = Categorical(values)

pandas/tests/base/test_misc.py

+5
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ def test_searchsorted(request, index_or_series_obj):
160160
reason="np.searchsorted doesn't work on pd.MultiIndex: GH 14833"
161161
)
162162
)
163+
elif obj.dtype.kind == "c" and isinstance(obj, Index):
164+
# TODO: Should Series cases also raise? Looks like they use numpy
165+
# comparison semantics https://github.com/numpy/numpy/issues/15981
166+
mark = pytest.mark.xfail(reason="complex objects are not comparable")
167+
request.node.add_marker(mark)
163168

164169
max_obj = max(obj, default=0)
165170
index = np.searchsorted(obj, max_obj)

pandas/tests/groupby/test_groupby.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1047,15 +1047,14 @@ def test_groupby_complex_numbers():
10471047
)
10481048
expected = DataFrame(
10491049
np.array([1, 1, 1], dtype=np.int64),
1050-
index=Index([(1 + 1j), (1 + 2j), (1 + 0j)], dtype="object", name="b"),
1050+
index=Index([(1 + 1j), (1 + 2j), (1 + 0j)], name="b"),
10511051
columns=Index(["a"], dtype="object"),
10521052
)
10531053
result = df.groupby("b", sort=False).count()
10541054
tm.assert_frame_equal(result, expected)
10551055

10561056
# Sorted by the magnitude of the complex numbers
1057-
# Complex Index dtype is cast to object
1058-
expected.index = Index([(1 + 0j), (1 + 1j), (1 + 2j)], dtype="object", name="b")
1057+
expected.index = Index([(1 + 0j), (1 + 1j), (1 + 2j)], name="b")
10591058
result = df.groupby("b", sort=True).count()
10601059
tm.assert_frame_equal(result, expected)
10611060

pandas/tests/indexes/multi/test_setops.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -525,11 +525,17 @@ def test_union_nan_got_duplicated():
525525
tm.assert_index_equal(result, mi2)
526526

527527

528-
def test_union_duplicates(index):
528+
def test_union_duplicates(index, request):
529529
# GH#38977
530530
if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)):
531531
# No duplicates in empty indexes
532532
return
533+
if index.dtype.kind == "c":
534+
mark = pytest.mark.xfail(
535+
reason="sort_values() call raises bc complex objects are not comparable"
536+
)
537+
request.node.add_marker(mark)
538+
533539
values = index.unique().values.tolist()
534540
mi1 = MultiIndex.from_arrays([values, [1] * len(values)])
535541
mi2 = MultiIndex.from_arrays([[values[0]] + values, [1] * (len(values) + 1)])

pandas/tests/indexes/test_any_index.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,14 @@ def test_mutability(index):
4646
index[0] = index[0]
4747

4848

49-
def test_map_identity_mapping(index):
49+
def test_map_identity_mapping(index, request):
5050
# GH#12766
51+
if index.dtype == np.complex64:
52+
mark = pytest.mark.xfail(
53+
reason="maybe_downcast_to_dtype doesn't handle complex"
54+
)
55+
request.node.add_marker(mark)
56+
5157
result = index.map(lambda x: x)
5258
if index.dtype == object and result.dtype == bool:
5359
assert (index == result).all()

pandas/tests/indexes/test_base.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -532,14 +532,19 @@ def test_map_dictlike_simple(self, mapper):
532532
lambda values, index: Series(values, index),
533533
],
534534
)
535-
def test_map_dictlike(self, index, mapper):
535+
def test_map_dictlike(self, index, mapper, request):
536536
# GH 12756
537537
if isinstance(index, CategoricalIndex):
538538
# Tested in test_categorical
539539
return
540540
elif not index.is_unique:
541541
# Cannot map duplicated index
542542
return
543+
if index.dtype == np.complex64 and not isinstance(mapper(index, index), Series):
544+
mark = pytest.mark.xfail(
545+
reason="maybe_downcast_to_dtype doesn't handle complex"
546+
)
547+
request.node.add_marker(mark)
543548

544549
rng = np.arange(len(index), 0, -1)
545550

@@ -664,7 +669,8 @@ def test_format_missing(self, vals, nulls_fixture):
664669
# 2845
665670
vals = list(vals) # Copy for each iteration
666671
vals.append(nulls_fixture)
667-
index = Index(vals)
672+
index = Index(vals, dtype=object)
673+
# TODO: case with complex dtype?
668674

669675
formatted = index.format()
670676
null_repr = "NaN" if isinstance(nulls_fixture, float) else str(nulls_fixture)

pandas/tests/indexes/test_common.py

+4
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ def test_astype_preserves_name(self, index, dtype):
389389
):
390390
# This astype is deprecated in favor of tz_localize
391391
warn = FutureWarning
392+
elif index.dtype.kind == "c" and dtype in ["float64", "int64", "uint64"]:
393+
# imaginary components discarded
394+
warn = np.ComplexWarning
395+
392396
try:
393397
# Some of these conversions cannot succeed so we use a try / except
394398
with tm.assert_produces_warning(warn):

pandas/tests/indexes/test_numpy_compat.py

+2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_numpy_ufuncs_basic(index, func):
7171
elif (
7272
isinstance(index, NumericIndex)
7373
or (not isinstance(index.dtype, np.dtype) and index.dtype._is_numeric)
74+
or (index.dtype.kind == "c" and func not in [np.deg2rad, np.rad2deg])
7475
or index.dtype == bool
7576
):
7677
# coerces to float (e.g. np.sin)
@@ -122,6 +123,7 @@ def test_numpy_ufuncs_other(index, func):
122123
elif (
123124
isinstance(index, NumericIndex)
124125
or (not isinstance(index.dtype, np.dtype) and index.dtype._is_numeric)
126+
or (index.dtype.kind == "c" and func is not np.signbit)
125127
or index.dtype == bool
126128
):
127129
# Results in bool array

pandas/tests/indexes/test_setops.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,25 @@ def test_union_different_types(index_flat, index_flat2, request):
7272

7373
common_dtype = find_common_type([idx1.dtype, idx2.dtype])
7474

75+
warn = None
76+
if not len(idx1) or not len(idx2):
77+
pass
78+
elif (
79+
idx1.dtype.kind == "c"
80+
and (
81+
idx2.dtype.kind not in ["i", "u", "f", "c"]
82+
or not isinstance(idx2.dtype, np.dtype)
83+
)
84+
) or (
85+
idx2.dtype.kind == "c"
86+
and (
87+
idx1.dtype.kind not in ["i", "u", "f", "c"]
88+
or not isinstance(idx1.dtype, np.dtype)
89+
)
90+
):
91+
# complex objects non-sortable
92+
warn = RuntimeWarning
93+
7594
any_uint64 = idx1.dtype == np.uint64 or idx2.dtype == np.uint64
7695
idx1_signed = is_signed_integer_dtype(idx1.dtype)
7796
idx2_signed = is_signed_integer_dtype(idx2.dtype)
@@ -81,8 +100,9 @@ def test_union_different_types(index_flat, index_flat2, request):
81100
idx1 = idx1.sort_values()
82101
idx2 = idx2.sort_values()
83102

84-
res1 = idx1.union(idx2)
85-
res2 = idx2.union(idx1)
103+
with tm.assert_produces_warning(warn, match="'<' not supported between"):
104+
res1 = idx1.union(idx2)
105+
res2 = idx2.union(idx1)
86106

87107
if any_uint64 and (idx1_signed or idx2_signed):
88108
assert res1.dtype == np.dtype("O")

pandas/tests/indexing/test_coercion.py

+5-20
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,6 @@ def test_where_object(self, index_or_series, fill_val, exp_dtype):
433433
)
434434
def test_where_int64(self, index_or_series, fill_val, exp_dtype, request):
435435
klass = index_or_series
436-
if klass is pd.Index and exp_dtype is np.complex128:
437-
mark = pytest.mark.xfail(reason="Complex Index not supported")
438-
request.node.add_marker(mark)
439436

440437
obj = klass([1, 2, 3, 4])
441438
assert obj.dtype == np.int64
@@ -447,9 +444,6 @@ def test_where_int64(self, index_or_series, fill_val, exp_dtype, request):
447444
)
448445
def test_where_float64(self, index_or_series, fill_val, exp_dtype, request):
449446
klass = index_or_series
450-
if klass is pd.Index and exp_dtype is np.complex128:
451-
mark = pytest.mark.xfail(reason="Complex Index not supported")
452-
request.node.add_marker(mark)
453447

454448
obj = klass([1.1, 2.2, 3.3, 4.4])
455449
assert obj.dtype == np.float64
@@ -464,9 +458,9 @@ def test_where_float64(self, index_or_series, fill_val, exp_dtype, request):
464458
(True, object),
465459
],
466460
)
467-
def test_where_series_complex128(self, fill_val, exp_dtype):
468-
klass = pd.Series # TODO: use index_or_series once we have Index[complex]
469-
obj = klass([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j])
461+
def test_where_complex128(self, index_or_series, fill_val, exp_dtype):
462+
klass = index_or_series
463+
obj = klass([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=np.complex128)
470464
assert obj.dtype == np.complex128
471465
self._run_test(obj, fill_val, klass, exp_dtype)
472466

@@ -608,11 +602,6 @@ def test_fillna_float64(self, index_or_series, fill_val, fill_dtype):
608602
assert obj.dtype == np.float64
609603

610604
exp = klass([1.1, fill_val, 3.3, 4.4])
611-
# float + complex -> we don't support a complex Index
612-
# complex for Series,
613-
# object for Index
614-
if fill_dtype == np.complex128 and klass == pd.Index:
615-
fill_dtype = object
616605
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
617606

618607
@pytest.mark.parametrize(
@@ -624,16 +613,12 @@ def test_fillna_float64(self, index_or_series, fill_val, fill_dtype):
624613
(True, object),
625614
],
626615
)
627-
def test_fillna_complex128(self, index_or_series, fill_val, fill_dtype, request):
616+
def test_fillna_complex128(self, index_or_series, fill_val, fill_dtype):
628617
klass = index_or_series
629-
if klass is pd.Index:
630-
mark = pytest.mark.xfail(reason="No Index[complex]")
631-
request.node.add_marker(mark)
632-
633618
obj = klass([1 + 1j, np.nan, 3 + 3j, 4 + 4j], dtype=np.complex128)
634619
assert obj.dtype == np.complex128
635620

636-
exp = klass([1 + 1j, fill_val, 3 + 3j, 4 + 4j], dtype=fill_dtype)
621+
exp = klass([1 + 1j, fill_val, 3 + 3j, 4 + 4j])
637622
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
638623

639624
@pytest.mark.parametrize(

pandas/tests/series/methods/test_value_counts.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,12 @@ def test_value_counts_bool_with_nan(self, ser, dropna, exp):
216216
Series([3, 2, 1], index=pd.Index([3j, 1 + 1j, 1], dtype=np.complex128)),
217217
),
218218
(
219-
[1 + 1j, 1 + 1j, 1, 3j, 3j, 3j],
219+
np.array([1 + 1j, 1 + 1j, 1, 3j, 3j, 3j], dtype=np.complex64),
220220
Series([3, 2, 1], index=pd.Index([3j, 1 + 1j, 1], dtype=np.complex64)),
221221
),
222222
],
223223
)
224224
def test_value_counts_complex_numbers(self, input_array, expected):
225225
# GH 17927
226-
# Complex Index dtype is cast to object
227226
result = Series(input_array).value_counts()
228227
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)