Skip to content

Commit 1699197

Browse files
authored
REF: avoid ravel/reshape in astype_nansafe, ndarray_to_mgr (#45817)
1 parent 17b15e9 commit 1699197

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

pandas/core/construction.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,16 @@ def _try_cast(
786786

787787
elif dtype.kind == "U":
788788
# TODO: test cases with arr.dtype.kind in ["m", "M"]
789-
return lib.ensure_string_array(arr, convert_na_value=False, copy=copy)
789+
if is_ndarray:
790+
arr = cast(np.ndarray, arr)
791+
shape = arr.shape
792+
if arr.ndim > 1:
793+
arr = arr.ravel()
794+
else:
795+
shape = (len(arr),)
796+
return lib.ensure_string_array(arr, convert_na_value=False, copy=copy).reshape(
797+
shape
798+
)
790799

791800
elif dtype.kind in ["m", "M"]:
792801
return maybe_cast_to_datetime(arr, dtype)

pandas/core/dtypes/astype.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,6 @@ def astype_nansafe(
8383
ValueError
8484
The dtype was a datetime64/timedelta64 dtype, but it had no unit.
8585
"""
86-
if arr.ndim > 1:
87-
flat = arr.ravel()
88-
result = astype_nansafe(flat, dtype, copy=copy, skipna=skipna)
89-
# error: Item "ExtensionArray" of "Union[ExtensionArray, ndarray]" has no
90-
# attribute "reshape"
91-
return result.reshape(arr.shape) # type: ignore[union-attr]
9286

9387
# We get here with 0-dim from sparse
9488
arr = np.atleast_1d(arr)
@@ -109,7 +103,12 @@ def astype_nansafe(
109103
return arr.astype(dtype, copy=copy)
110104

111105
if issubclass(dtype.type, str):
112-
return lib.ensure_string_array(arr, skipna=skipna, convert_na_value=False)
106+
shape = arr.shape
107+
if arr.ndim > 1:
108+
arr = arr.ravel()
109+
return lib.ensure_string_array(
110+
arr, skipna=skipna, convert_na_value=False
111+
).reshape(shape)
113112

114113
elif is_datetime64_dtype(arr.dtype):
115114
if dtype == np.int64:
@@ -146,7 +145,7 @@ def astype_nansafe(
146145
from pandas import to_datetime
147146

148147
return astype_nansafe(
149-
to_datetime(arr).values,
148+
to_datetime(arr.ravel()).values.reshape(arr.shape),
150149
dtype,
151150
copy=copy,
152151
)

pandas/core/internals/construction.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -329,18 +329,18 @@ def ndarray_to_mgr(
329329
values = _prep_ndarray(values, copy=copy_on_sanitize)
330330

331331
if dtype is not None and not is_dtype_equal(values.dtype, dtype):
332-
shape = values.shape
333-
flat = values.ravel()
334-
335332
# GH#40110 see similar check inside sanitize_array
336333
rcf = not (is_integer_dtype(dtype) and values.dtype.kind == "f")
337334

338335
values = sanitize_array(
339-
flat, None, dtype=dtype, copy=copy_on_sanitize, raise_cast_failure=rcf
336+
values,
337+
None,
338+
dtype=dtype,
339+
copy=copy_on_sanitize,
340+
raise_cast_failure=rcf,
341+
allow_2d=True,
340342
)
341343

342-
values = values.reshape(shape)
343-
344344
# _prep_ndarray ensures that values.ndim == 2 at this point
345345
index, columns = _get_axes(
346346
values.shape[0], values.shape[1], index=index, columns=columns

pandas/tests/frame/test_constructors.py

+8
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@
7070

7171

7272
class TestDataFrameConstructors:
73+
def test_constructor_from_ndarray_with_str_dtype(self):
74+
# If we don't ravel/reshape around ensure_str_array, we end up
75+
# with an array of strings each of which is e.g. "[0 1 2]"
76+
arr = np.arange(12).reshape(4, 3)
77+
df = DataFrame(arr, dtype=str)
78+
expected = DataFrame(arr.astype(str))
79+
tm.assert_frame_equal(df, expected)
80+
7381
def test_constructor_from_2d_datetimearray(self, using_array_manager):
7482
dti = date_range("2016-01-01", periods=6, tz="US/Pacific")
7583
dta = dti._data.reshape(3, 2)

0 commit comments

Comments
 (0)