|
64 | 64 | is_string_dtype,
|
65 | 65 | is_unsigned_integer_dtype,
|
66 | 66 | )
|
| 67 | +from pandas.core.dtypes.common import is_timedelta64_dtype |
67 | 68 | from pandas.tests.extension import base
|
68 | 69 |
|
69 | 70 | pa = pytest.importorskip("pyarrow")
|
@@ -281,15 +282,33 @@ def test_compare_scalar(self, data, comparison_op):
|
281 | 282 | @pytest.mark.parametrize("na_action", [None, "ignore"])
|
282 | 283 | def test_map(self, data_missing, na_action):
|
283 | 284 | if data_missing.dtype.kind in "mM":
|
284 |
| - result = pd.Series( |
285 |
| - np.asarray( |
286 |
| - data_missing.map(lambda x: x, na_action=na_action), dtype="int64" |
287 |
| - ) |
288 |
| - ) |
289 |
| - expected = pd.Series( |
290 |
| - data_missing.to_numpy().astype(result.dtype).view("int64") |
291 |
| - ) |
292 |
| - tm.assert_series_equal(result, expected, check_dtype=False) |
| 285 | + mapped = data_missing.map(lambda x: x, na_action=na_action) |
| 286 | + result = pd.Series(mapped) |
| 287 | + expected = pd.Series(data_missing.to_numpy()) |
| 288 | + |
| 289 | + orig_dtype = expected.dtype |
| 290 | + |
| 291 | + if result.dtype == "float64" and ( |
| 292 | + is_datetime64_any_dtype(orig_dtype) |
| 293 | + or is_timedelta64_dtype(orig_dtype) |
| 294 | + or isinstance(orig_dtype, pd.DatetimeTZDtype) |
| 295 | + ): |
| 296 | + result = result.astype(orig_dtype) |
| 297 | + |
| 298 | + if isinstance(orig_dtype, pd.DatetimeTZDtype): |
| 299 | + pass |
| 300 | + elif is_datetime64_any_dtype(orig_dtype): |
| 301 | + result = result.astype("datetime64[ns]").astype("int64") |
| 302 | + expected = expected.astype("datetime64[ns]").astype("int64") |
| 303 | + result = pd.Series(result) |
| 304 | + expected = pd.Series(expected) |
| 305 | + elif is_timedelta64_dtype(orig_dtype): |
| 306 | + result = result.astype("timedelta64[ns]") |
| 307 | + expected = expected.astype("timedelta64[ns]") |
| 308 | + |
| 309 | + |
| 310 | + tm.assert_series_equal(result, expected, check_dtype=False, check_exact=False) |
| 311 | + |
293 | 312 | else:
|
294 | 313 | result = data_missing.map(lambda x: x, na_action=na_action)
|
295 | 314 | if data_missing.dtype == "float32[pyarrow]":
|
|
0 commit comments