Skip to content

Commit 4526bb1

Browse files
committed
updated testing logic
1 parent f3545bf commit 4526bb1

File tree

1 file changed

+28
-9
lines changed

1 file changed

+28
-9
lines changed

pandas/tests/extension/test_arrow.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
is_string_dtype,
6565
is_unsigned_integer_dtype,
6666
)
67+
from pandas.core.dtypes.common import is_timedelta64_dtype
6768
from pandas.tests.extension import base
6869

6970
pa = pytest.importorskip("pyarrow")
@@ -281,15 +282,33 @@ def test_compare_scalar(self, data, comparison_op):
281282
@pytest.mark.parametrize("na_action", [None, "ignore"])
282283
def test_map(self, data_missing, na_action):
283284
if data_missing.dtype.kind in "mM":
284-
result = pd.Series(
285-
np.asarray(
286-
data_missing.map(lambda x: x, na_action=na_action), dtype="int64"
287-
)
288-
)
289-
expected = pd.Series(
290-
data_missing.to_numpy().astype(result.dtype).view("int64")
291-
)
292-
tm.assert_series_equal(result, expected, check_dtype=False)
285+
mapped = data_missing.map(lambda x: x, na_action=na_action)
286+
result = pd.Series(mapped)
287+
expected = pd.Series(data_missing.to_numpy())
288+
289+
orig_dtype = expected.dtype
290+
291+
if result.dtype == "float64" and (
292+
is_datetime64_any_dtype(orig_dtype)
293+
or is_timedelta64_dtype(orig_dtype)
294+
or isinstance(orig_dtype, pd.DatetimeTZDtype)
295+
):
296+
result = result.astype(orig_dtype)
297+
298+
if isinstance(orig_dtype, pd.DatetimeTZDtype):
299+
pass
300+
elif is_datetime64_any_dtype(orig_dtype):
301+
result = result.astype("datetime64[ns]").astype("int64")
302+
expected = expected.astype("datetime64[ns]").astype("int64")
303+
result = pd.Series(result)
304+
expected = pd.Series(expected)
305+
elif is_timedelta64_dtype(orig_dtype):
306+
result = result.astype("timedelta64[ns]")
307+
expected = expected.astype("timedelta64[ns]")
308+
309+
310+
tm.assert_series_equal(result, expected, check_dtype=False, check_exact=False)
311+
293312
else:
294313
result = data_missing.map(lambda x: x, na_action=na_action)
295314
if data_missing.dtype == "float32[pyarrow]":

0 commit comments

Comments
 (0)