Skip to content

Commit ec2f0db

Browse files
authored
BUG: bad display for complex series with nan (#53764)
* BUG: bad display for complex series with nan
* added comments
* added more test cases
1 parent 9afa6f1 commit ec2f0db

File tree

2 files changed

+51
-20
lines changed

2 files changed

+51
-20
lines changed

pandas/io/formats/format.py

+26-20
Original file line numberDiff line numberDiff line change
@@ -1506,14 +1506,16 @@ def format_values_with(float_format):
15061506

15071507
# default formatter leaves a space to the left when formatting
15081508
# floats, must be consistent for left-justifying NaNs (GH #25061)
1509-
if self.justify == "left":
1510-
na_rep = " " + self.na_rep
1511-
else:
1512-
na_rep = self.na_rep
1509+
na_rep = " " + self.na_rep if self.justify == "left" else self.na_rep
15131510

1514-
# separate the wheat from the chaff
1511+
# different formatting strategies for complex and non-complex data
1512+
# need to distinguish complex and float NaNs (GH #53762)
15151513
values = self.values
15161514
is_complex = is_complex_dtype(values)
1515+
if is_complex:
1516+
na_rep = f"{na_rep}+{0:.{self.digits}f}j"
1517+
1518+
# separate the wheat from the chaff
15171519
values = format_with_na_rep(values, formatter, na_rep)
15181520

15191521
if self.fixed_width:
@@ -1912,22 +1914,26 @@ def _trim_zeros_complex(str_complexes: np.ndarray, decimal: str = ".") -> list[s
19121914
Separates the real and imaginary parts from the complex number, and
19131915
executes the _trim_zeros_float method on each of those.
19141916
"""
1915-
trimmed = [
1916-
"".join(_trim_zeros_float(re.split(r"([j+-])", x), decimal))
1917-
for x in str_complexes
1918-
]
1919-
1920-
# pad strings to the length of the longest trimmed string for alignment
1921-
lengths = [len(s) for s in trimmed]
1922-
max_length = max(lengths)
1917+
real_part, imag_part = [], []
1918+
for x in str_complexes:
1919+
# Complex numbers are represented as "(-)xxx(+/-)xxxj"
1920+
# The split will give [maybe "-", "xxx", "+/-", "xxx", "j", ""]
1921+
# Therefore, the imaginary part is the 4th and 3rd last elements,
1922+
# and the real part is everything before the imaginary part
1923+
trimmed = re.split(r"([j+-])", x)
1924+
real_part.append("".join(trimmed[:-4]))
1925+
imag_part.append("".join(trimmed[-4:-2]))
1926+
1927+
# We want to align the lengths of the real and imaginary parts of each complex
1928+
# number, as well as the lengths of the real (resp. imaginary) parts of all numbers
1929+
# in the array
1930+
n = len(str_complexes)
1931+
padded_parts = _trim_zeros_float(real_part + imag_part, decimal)
19231932
padded = [
1924-
s[: -((k - 1) // 2 + 1)] # real part
1925-
+ (max_length - k) // 2 * "0"
1926-
+ s[-((k - 1) // 2 + 1) : -((k - 1) // 2)] # + / -
1927-
+ s[-((k - 1) // 2) : -1] # imaginary part
1928-
+ (max_length - k) // 2 * "0"
1929-
+ s[-1]
1930-
for s, k in zip(trimmed, lengths)
1933+
padded_parts[i] # real part (including - or space, possibly "NaN")
1934+
+ padded_parts[i + n] # imaginary part (including + or -)
1935+
+ "j"
1936+
for i in range(n)
19311937
]
19321938
return padded
19331939

pandas/tests/io/formats/test_printing.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import string
22

33
import numpy as np
4+
import pytest
45

56
import pandas._config.config as cf
67

@@ -207,3 +208,27 @@ def test_multiindex_long_element():
207208
"cccccccccccccccccccccc',)],\n )"
208209
)
209210
assert str(data) == expected
211+
212+
213+
@pytest.mark.parametrize(
214+
"data,output",
215+
[
216+
([2, complex("nan"), 1], [" 2.0+0.0j", " NaN+0.0j", " 1.0+0.0j"]),
217+
([2, complex("nan"), -1], [" 2.0+0.0j", " NaN+0.0j", "-1.0+0.0j"]),
218+
([-2, complex("nan"), -1], ["-2.0+0.0j", " NaN+0.0j", "-1.0+0.0j"]),
219+
([-1.23j, complex("nan"), -1], ["-0.00-1.23j", " NaN+0.00j", "-1.00+0.00j"]),
220+
([1.23j, complex("nan"), 1.23], [" 0.00+1.23j", " NaN+0.00j", " 1.23+0.00j"]),
221+
],
222+
)
223+
@pytest.mark.parametrize("as_frame", [True, False])
224+
def test_ser_df_with_complex_nans(data, output, as_frame):
225+
# GH#53762
226+
obj = pd.Series(data)
227+
if as_frame:
228+
obj = obj.to_frame(name="val")
229+
reprs = [f"{i} {val}" for i, val in enumerate(output)]
230+
expected = f"{'val': >{len(reprs[0])}}\n" + "\n".join(reprs)
231+
else:
232+
reprs = [f"{i} {val}" for i, val in enumerate(output)]
233+
expected = "\n".join(reprs) + "\ndtype: complex128"
234+
assert str(obj) == expected

0 commit comments

Comments
 (0)