Skip to content

Commit 25e57c3

Browse files
committed
improved code readability
1 parent 52cd37f commit 25e57c3

File tree

1 file changed

+10
-19
lines changed

1 file changed

+10
-19
lines changed

pandas/tests/extension/test_arrow.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -281,42 +281,33 @@ def test_compare_scalar(self, data, comparison_op):
281281

282282
@pytest.mark.parametrize("na_action", [None, "ignore"])
283283
def test_map(self, data_missing, na_action):
284-
if data_missing.dtype.kind in "mM":
285-
mapped = data_missing.map(lambda x: x, na_action=na_action)
286-
result = pd.Series(mapped)
287-
expected = pd.Series(data_missing.to_numpy())
284+
result = data_missing.map(lambda x: x, na_action=na_action)
288285

289-
orig_dtype = expected.dtype
286+
if data_missing.dtype == "float32[pyarrow]":
287+
expected = data_missing.to_numpy(dtype="float64", na_value=np.nan)
288+
tm.assert_numpy_array_equal(result, expected)
290289

291-
if result.dtype == "float64" and (
292-
is_datetime64_any_dtype(orig_dtype)
293-
or is_timedelta64_dtype(orig_dtype)
294-
or isinstance(orig_dtype, pd.DatetimeTZDtype)
295-
):
296-
result = result.astype(orig_dtype)
290+
elif data_missing.dtype.kind in "mM":
291+
expected = pd.Series(data_missing.to_numpy())
297292

293+
orig_dtype = expected.dtype
298294
if isinstance(orig_dtype, pd.DatetimeTZDtype):
299295
pass
300296
elif is_datetime64_any_dtype(orig_dtype):
301297
result = result.astype("datetime64[ns]").astype("int64")
302298
expected = expected.astype("datetime64[ns]").astype("int64")
303-
result = pd.Series(result)
304-
expected = pd.Series(expected)
305299
elif is_timedelta64_dtype(orig_dtype):
306300
result = result.astype("timedelta64[ns]")
307301
expected = expected.astype("timedelta64[ns]")
308302

303+
result = pd.Series(result)
304+
expected = pd.Series(expected)
309305
tm.assert_series_equal(
310306
result, expected, check_dtype=False, check_exact=False
311307
)
312308

313309
else:
314-
result = data_missing.map(lambda x: x, na_action=na_action)
315-
if data_missing.dtype == "float32[pyarrow]":
316-
# map roundtrips through objects, which converts to float64
317-
expected = data_missing.to_numpy(dtype="float64", na_value=np.nan)
318-
else:
319-
expected = data_missing.to_numpy()
310+
expected = data_missing.to_numpy()
320311
tm.assert_numpy_array_equal(result, expected)
321312

322313
def test_astype_str(self, data, request, using_infer_string):

0 commit comments

Comments
 (0)