Skip to content

Commit b77797c

Browse files
authored
ENH: Index[complex] (#45256)
1 parent e681fcd commit b77797c

16 files changed

+99
-36
lines changed

pandas/_libs/index.pyi

+2
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,8 @@ class IndexEngine:
2929

3030
class Float64Engine(IndexEngine): ...
3131
class Float32Engine(IndexEngine): ...
32+
class Complex128Engine(IndexEngine): ...
33+
class Complex64Engine(IndexEngine): ...
3234
class Int64Engine(IndexEngine): ...
3335
class Int32Engine(IndexEngine): ...
3436
class Int16Engine(IndexEngine): ...

pandas/_libs/index_class_helper.pxi.in

+11-2
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@ dtypes = [('Float64', 'float64'),
2121
('UInt32', 'uint32'),
2222
('UInt16', 'uint16'),
2323
('UInt8', 'uint8'),
24+
('Complex64', 'complex64'),
25+
('Complex128', 'complex128'),
2426
]
2527
}}
2628

@@ -33,18 +35,25 @@ cdef class {{name}}Engine(IndexEngine):
3335
return _hash.{{name}}HashTable(n)
3436

3537
cdef _check_type(self, object val):
36-
{{if name not in {'Float64', 'Float32'} }}
38+
{{if name not in {'Float64', 'Float32', 'Complex64', 'Complex128'} }}
3739
if not util.is_integer_object(val):
3840
raise KeyError(val)
3941
{{if name.startswith("U")}}
4042
if val < 0:
4143
# cannot have negative values with unsigned int dtype
4244
raise KeyError(val)
4345
{{endif}}
44-
{{else}}
46+
{{elif name not in {'Complex64', 'Complex128'} }}
4547
if not util.is_integer_object(val) and not util.is_float_object(val):
4648
# in particular catch bool and avoid casting True -> 1.0
4749
raise KeyError(val)
50+
{{else}}
51+
if (not util.is_integer_object(val)
52+
and not util.is_float_object(val)
53+
and not util.is_complex_object(val)
54+
):
55+
# in particular catch bool and avoid casting True -> 1.0
56+
raise KeyError(val)
4857
{{endif}}
4958

5059

pandas/conftest.py

+2
Original file line number | Diff line number | Diff line change
@@ -539,6 +539,8 @@ def _create_mi_with_dt64tz_level():
539539
"uint": tm.makeUIntIndex(100),
540540
"range": tm.makeRangeIndex(100),
541541
"float": tm.makeFloatIndex(100),
542+
"complex64": tm.makeFloatIndex(100).astype("complex64"),
543+
"complex128": tm.makeFloatIndex(100).astype("complex128"),
542544
"num_int64": tm.makeNumericIndex(100, dtype="int64"),
543545
"num_int32": tm.makeNumericIndex(100, dtype="int32"),
544546
"num_int16": tm.makeNumericIndex(100, dtype="int16"),

pandas/core/indexes/base.py

+11-4
Original file line number | Diff line number | Diff line change
@@ -487,6 +487,8 @@ def __new__(
487487
if data.dtype.kind in ["i", "u", "f"]:
488488
# maybe coerce to a sub-class
489489
arr = data
490+
elif data.dtype.kind in ["c"]:
491+
arr = np.asarray(data)
490492
else:
491493
arr = com.asarray_tuplesafe(data, dtype=_dtype_obj)
492494

@@ -614,7 +616,9 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):
614616
# NB: assuming away MultiIndex
615617
return Index
616618

617-
elif issubclass(dtype.type, (str, bool, np.bool_)):
619+
elif issubclass(
620+
dtype.type, (str, bool, np.bool_, complex, np.complex64, np.complex128)
621+
):
618622
return Index
619623

620624
raise NotImplementedError(dtype)
@@ -858,6 +862,11 @@ def _engine(
858862
# TODO(ExtensionIndex): use libindex.ExtensionEngine(self._values)
859863
return libindex.ObjectEngine(self._get_engine_target())
860864

865+
elif self.values.dtype == np.complex64:
866+
return libindex.Complex64Engine(self._get_engine_target())
867+
elif self.values.dtype == np.complex128:
868+
return libindex.Complex128Engine(self._get_engine_target())
869+
861870
# to avoid a reference cycle, bind `target_values` to a local variable, so
862871
# `self` is not passed into the lambda.
863872
target_values = self._get_engine_target()
@@ -5980,8 +5989,6 @@ def _find_common_type_compat(self, target) -> DtypeObj:
59805989
# FIXME: find_common_type incorrect with Categorical GH#38240
59815990
# FIXME: some cases where float64 cast can be lossy?
59825991
dtype = np.dtype(np.float64)
5983-
if dtype.kind == "c":
5984-
dtype = _dtype_obj
59855992
return dtype
59865993

59875994
@final
@@ -7120,7 +7127,7 @@ def _maybe_cast_data_without_dtype(
71207127
FutureWarning,
71217128
stacklevel=3,
71227129
)
7123-
if result.dtype.kind in ["b", "c"]:
7130+
if result.dtype.kind in ["b"]:
71247131
return subarr
71257132
result = ensure_wrapped_if_datetimelike(result)
71267133
return result

pandas/core/indexes/numeric.py

+3
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,8 @@ def _can_hold_na(self) -> bool: # type: ignore[override]
114114
np.dtype(np.uint64): libindex.UInt64Engine,
115115
np.dtype(np.float32): libindex.Float32Engine,
116116
np.dtype(np.float64): libindex.Float64Engine,
117+
np.dtype(np.complex64): libindex.Complex64Engine,
118+
np.dtype(np.complex128): libindex.Complex128Engine,
117119
}
118120

119121
@property
@@ -128,6 +130,7 @@ def inferred_type(self) -> str:
128130
"i": "integer",
129131
"u": "integer",
130132
"f": "floating",
133+
"c": "complex",
131134
}[self.dtype.kind]
132135

133136
def __new__(cls, data=None, dtype: Dtype | None = None, copy=False, name=None):

pandas/tests/arrays/categorical/test_constructors.py

-1
Original file line number | Diff line number | Diff line change
@@ -676,7 +676,6 @@ def test_construction_with_ordered(self, ordered):
676676
cat = Categorical([0, 1, 2], ordered=ordered)
677677
assert cat.ordered == bool(ordered)
678678

679-
@pytest.mark.xfail(reason="Imaginary values not supported in Categorical")
680679
def test_constructor_imaginary(self):
681680
values = [1, 2, 3 + 1j]
682681
c1 = Categorical(values)

pandas/tests/base/test_misc.py

+6-1
Original file line number | Diff line number | Diff line change
@@ -137,14 +137,19 @@ def test_memory_usage_components_narrow_series(dtype):
137137
assert total_usage == non_index_usage + index_usage
138138

139139

140-
def test_searchsorted(index_or_series_obj):
140+
def test_searchsorted(index_or_series_obj, request):
141141
# numpy.searchsorted calls obj.searchsorted under the hood.
142142
# See gh-12238
143143
obj = index_or_series_obj
144144

145145
if isinstance(obj, pd.MultiIndex):
146146
# See gh-14833
147147
pytest.skip("np.searchsorted doesn't work on pd.MultiIndex")
148+
if obj.dtype.kind == "c" and isinstance(obj, Index):
149+
# TODO: Should Series cases also raise? Looks like they use numpy
150+
# comparison semantics https://github.com/numpy/numpy/issues/15981
151+
mark = pytest.mark.xfail(reason="complex objects are not comparable")
152+
request.node.add_marker(mark)
148153

149154
max_obj = max(obj, default=0)
150155
index = np.searchsorted(obj, max_obj)

pandas/tests/groupby/test_groupby.py

+2-3
Original file line number | Diff line number | Diff line change
@@ -1054,15 +1054,14 @@ def test_groupby_complex_numbers():
10541054
)
10551055
expected = DataFrame(
10561056
np.array([1, 1, 1], dtype=np.int64),
1057-
index=Index([(1 + 1j), (1 + 2j), (1 + 0j)], dtype="object", name="b"),
1057+
index=Index([(1 + 1j), (1 + 2j), (1 + 0j)], name="b"),
10581058
columns=Index(["a"], dtype="object"),
10591059
)
10601060
result = df.groupby("b", sort=False).count()
10611061
tm.assert_frame_equal(result, expected)
10621062

10631063
# Sorted by the magnitude of the complex numbers
1064-
# Complex Index dtype is cast to object
1065-
expected.index = Index([(1 + 0j), (1 + 1j), (1 + 2j)], dtype="object", name="b")
1064+
expected.index = Index([(1 + 0j), (1 + 1j), (1 + 2j)], name="b")
10661065
result = df.groupby("b", sort=True).count()
10671066
tm.assert_frame_equal(result, expected)
10681067

pandas/tests/indexes/multi/test_setops.py

+7-1
Original file line number | Diff line number | Diff line change
@@ -525,11 +525,17 @@ def test_union_nan_got_duplicated():
525525
tm.assert_index_equal(result, mi2)
526526

527527

528-
def test_union_duplicates(index):
528+
def test_union_duplicates(index, request):
529529
# GH#38977
530530
if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)):
531531
# No duplicates in empty indexes
532532
return
533+
if index.dtype.kind == "c":
534+
mark = pytest.mark.xfail(
535+
reason="sort_values() call raises bc complex objects are not comparable"
536+
)
537+
request.node.add_marker(mark)
538+
533539
values = index.unique().values.tolist()
534540
mi1 = MultiIndex.from_arrays([values, [1] * len(values)])
535541
mi2 = MultiIndex.from_arrays([[values[0]] + values, [1] * (len(values) + 1)])

pandas/tests/indexes/test_any_index.py

+7-1
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,14 @@ def test_mutability(index):
4646
index[0] = index[0]
4747

4848

49-
def test_map_identity_mapping(index):
49+
def test_map_identity_mapping(index, request):
5050
# GH#12766
51+
if index.dtype == np.complex64:
52+
mark = pytest.mark.xfail(
53+
reason="maybe_downcast_to_dtype doesn't handle complex"
54+
)
55+
request.node.add_marker(mark)
56+
5157
result = index.map(lambda x: x)
5258
tm.assert_index_equal(result, index, exact="equiv")
5359

pandas/tests/indexes/test_base.py

+8-2
Original file line number | Diff line number | Diff line change
@@ -526,14 +526,19 @@ def test_map_dictlike_simple(self, mapper):
526526
lambda values, index: Series(values, index),
527527
],
528528
)
529-
def test_map_dictlike(self, index, mapper):
529+
def test_map_dictlike(self, index, mapper, request):
530530
# GH 12756
531531
if isinstance(index, CategoricalIndex):
532532
# Tested in test_categorical
533533
return
534534
elif not index.is_unique:
535535
# Cannot map duplicated index
536536
return
537+
if index.dtype == np.complex64 and not isinstance(mapper(index, index), Series):
538+
mark = pytest.mark.xfail(
539+
reason="maybe_downcast_to_dtype doesn't handle complex"
540+
)
541+
request.node.add_marker(mark)
537542

538543
rng = np.arange(len(index), 0, -1)
539544

@@ -655,7 +660,8 @@ def test_format_missing(self, vals, nulls_fixture):
655660
# 2845
656661
vals = list(vals) # Copy for each iteration
657662
vals.append(nulls_fixture)
658-
index = Index(vals)
663+
index = Index(vals, dtype=object)
664+
# TODO: case with complex dtype?
659665

660666
formatted = index.format()
661667
expected = [str(index[0]), str(index[1]), str(index[2]), "NaN"]

pandas/tests/indexes/test_common.py

+7
Original file line number | Diff line number | Diff line change
@@ -386,13 +386,20 @@ def test_astype_preserves_name(self, index, dtype):
386386
if dtype in ["int64", "uint64"]:
387387
if needs_i8_conversion(index.dtype):
388388
warn = FutureWarning
389+
elif index.dtype.kind == "c":
390+
# imaginary components discarded
391+
warn = np.ComplexWarning
389392
elif (
390393
isinstance(index, DatetimeIndex)
391394
and index.tz is not None
392395
and dtype == "datetime64[ns]"
393396
):
394397
# This astype is deprecated in favor of tz_localize
395398
warn = FutureWarning
399+
elif index.dtype.kind == "c" and dtype == "float64":
400+
# imaginary components discarded
401+
warn = np.ComplexWarning
402+
396403
try:
397404
# Some of these conversions cannot succeed so we use a try / except
398405
with tm.assert_produces_warning(warn):

pandas/tests/indexes/test_numpy_compat.py

+8-4
Original file line number | Diff line number | Diff line change
@@ -54,8 +54,10 @@ def test_numpy_ufuncs_basic(index, func):
5454
with tm.external_error_raised((TypeError, AttributeError)):
5555
with np.errstate(all="ignore"):
5656
func(index)
57-
elif isinstance(index, NumericIndex) or (
58-
not isinstance(index.dtype, np.dtype) and index.dtype._is_numeric
57+
elif (
58+
isinstance(index, NumericIndex)
59+
or (not isinstance(index.dtype, np.dtype) and index.dtype._is_numeric)
60+
or (index.dtype.kind == "c" and func not in [np.deg2rad, np.rad2deg])
5961
):
6062
# coerces to float (e.g. np.sin)
6163
with np.errstate(all="ignore"):
@@ -99,8 +101,10 @@ def test_numpy_ufuncs_other(index, func, request):
99101
with tm.external_error_raised(TypeError):
100102
func(index)
101103

102-
elif isinstance(index, NumericIndex) or (
103-
not isinstance(index.dtype, np.dtype) and index.dtype._is_numeric
104+
elif (
105+
isinstance(index, NumericIndex)
106+
or (not isinstance(index.dtype, np.dtype) and index.dtype._is_numeric)
107+
or (index.dtype.kind == "c" and func is not np.signbit)
104108
):
105109
# Results in bool array
106110
result = func(index)

pandas/tests/indexes/test_setops.py

+22-2
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,25 @@ def test_union_different_types(index_flat, index_flat2, request):
6767

6868
common_dtype = find_common_type([idx1.dtype, idx2.dtype])
6969

70+
warn = None
71+
if not len(idx1) or not len(idx2):
72+
pass
73+
elif (
74+
idx1.dtype.kind == "c"
75+
and (
76+
idx2.dtype.kind not in ["i", "u", "f", "c"]
77+
or not isinstance(idx2.dtype, np.dtype)
78+
)
79+
) or (
80+
idx2.dtype.kind == "c"
81+
and (
82+
idx1.dtype.kind not in ["i", "u", "f", "c"]
83+
or not isinstance(idx1.dtype, np.dtype)
84+
)
85+
):
86+
# complex objects non-sortable
87+
warn = RuntimeWarning
88+
7089
any_uint64 = idx1.dtype == np.uint64 or idx2.dtype == np.uint64
7190
idx1_signed = is_signed_integer_dtype(idx1.dtype)
7291
idx2_signed = is_signed_integer_dtype(idx2.dtype)
@@ -76,8 +95,9 @@ def test_union_different_types(index_flat, index_flat2, request):
7695
idx1 = idx1.sort_values()
7796
idx2 = idx2.sort_values()
7897

79-
res1 = idx1.union(idx2)
80-
res2 = idx2.union(idx1)
98+
with tm.assert_produces_warning(warn, match="'<' not supported between"):
99+
res1 = idx1.union(idx2)
100+
res2 = idx2.union(idx1)
81101

82102
if any_uint64 and (idx1_signed or idx2_signed):
83103
assert res1.dtype == np.dtype("O")

pandas/tests/indexing/test_coercion.py

+2-13
Original file line number | Diff line number | Diff line change
@@ -433,9 +433,6 @@ def test_where_object(self, index_or_series, fill_val, exp_dtype):
433433
)
434434
def test_where_int64(self, index_or_series, fill_val, exp_dtype, request):
435435
klass = index_or_series
436-
if klass is pd.Index and exp_dtype is np.complex128:
437-
mark = pytest.mark.xfail(reason="Complex Index not supported")
438-
request.node.add_marker(mark)
439436

440437
obj = klass([1, 2, 3, 4])
441438
assert obj.dtype == np.int64
@@ -447,9 +444,6 @@ def test_where_int64(self, index_or_series, fill_val, exp_dtype, request):
447444
)
448445
def test_where_float64(self, index_or_series, fill_val, exp_dtype, request):
449446
klass = index_or_series
450-
if klass is pd.Index and exp_dtype is np.complex128:
451-
mark = pytest.mark.xfail(reason="Complex Index not supported")
452-
request.node.add_marker(mark)
453447

454448
obj = klass([1.1, 2.2, 3.3, 4.4])
455449
assert obj.dtype == np.float64
@@ -464,8 +458,8 @@ def test_where_float64(self, index_or_series, fill_val, exp_dtype, request):
464458
(True, object),
465459
],
466460
)
467-
def test_where_series_complex128(self, fill_val, exp_dtype):
468-
klass = pd.Series # TODO: use index_or_series once we have Index[complex]
461+
def test_where_series_complex128(self, index_or_series, fill_val, exp_dtype):
462+
klass = index_or_series
469463
obj = klass([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j])
470464
assert obj.dtype == np.complex128
471465
self._run_test(obj, fill_val, klass, exp_dtype)
@@ -624,11 +618,6 @@ def test_fillna_float64(self, index_or_series, fill_val, fill_dtype):
624618
assert obj.dtype == np.float64
625619

626620
exp = klass([1.1, fill_val, 3.3, 4.4])
627-
# float + complex -> we don't support a complex Index
628-
# complex for Series,
629-
# object for Index
630-
if fill_dtype == np.complex128 and klass == pd.Index:
631-
fill_dtype = object
632621
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
633622

634623
@pytest.mark.parametrize(

pandas/tests/series/methods/test_value_counts.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -216,13 +216,12 @@ def test_value_counts_bool_with_nan(self, ser, dropna, exp):
216216
Series([3, 2, 1], index=pd.Index([3j, 1 + 1j, 1], dtype=np.complex128)),
217217
),
218218
(
219-
[1 + 1j, 1 + 1j, 1, 3j, 3j, 3j],
219+
np.array([1 + 1j, 1 + 1j, 1, 3j, 3j, 3j], dtype=np.complex64),
220220
Series([3, 2, 1], index=pd.Index([3j, 1 + 1j, 1], dtype=np.complex64)),
221221
),
222222
],
223223
)
224224
def test_value_counts_complex_numbers(self, input_array, expected):
225225
# GH 17927
226-
# Complex Index dtype is cast to object
227226
result = Series(input_array).value_counts()
228227
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)