Skip to content

Commit fff0c92

Browse files
jbrockmendel authored and JulianWgs committed
REF: re-use sanitize_array in DataFrame._sanitize_columns (pandas-dev#41611)
1 parent fe591bb commit fff0c92

File tree

4 files changed

+61
-61
lines changed

4 files changed

+61
-61
lines changed

pandas/core/construction.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,8 @@ def sanitize_array(
468468
dtype: DtypeObj | None = None,
469469
copy: bool = False,
470470
raise_cast_failure: bool = True,
471+
*,
472+
allow_2d: bool = False,
471473
) -> ArrayLike:
472474
"""
473475
Sanitize input data to an ndarray or ExtensionArray, copy if specified,
@@ -480,6 +482,8 @@ def sanitize_array(
480482
dtype : np.dtype, ExtensionDtype, or None, default None
481483
copy : bool, default False
482484
raise_cast_failure : bool, default True
485+
allow_2d : bool, default False
486+
If False, raise if we have a 2D Arraylike.
483487
484488
Returns
485489
-------
@@ -554,7 +558,7 @@ def sanitize_array(
554558
# type "Union[ExtensionArray, ndarray]"; expected "ndarray"
555559
subarr = maybe_infer_to_datetimelike(subarr) # type: ignore[arg-type]
556560

557-
subarr = _sanitize_ndim(subarr, data, dtype, index)
561+
subarr = _sanitize_ndim(subarr, data, dtype, index, allow_2d=allow_2d)
558562

559563
if not (
560564
isinstance(subarr.dtype, ExtensionDtype) or isinstance(dtype, ExtensionDtype)
@@ -591,7 +595,12 @@ def range_to_ndarray(rng: range) -> np.ndarray:
591595

592596

593597
def _sanitize_ndim(
594-
result: ArrayLike, data, dtype: DtypeObj | None, index: Index | None
598+
result: ArrayLike,
599+
data,
600+
dtype: DtypeObj | None,
601+
index: Index | None,
602+
*,
603+
allow_2d: bool = False,
595604
) -> ArrayLike:
596605
"""
597606
Ensure we have a 1-dimensional result array.
@@ -605,6 +614,8 @@ def _sanitize_ndim(
605614

606615
elif result.ndim > 1:
607616
if isinstance(data, np.ndarray):
617+
if allow_2d:
618+
return result
608619
raise ValueError("Data must be 1-dimensional")
609620
if is_object_dtype(dtype) and isinstance(dtype, ExtensionDtype):
610621
# i.e. PandasDtype("O")

pandas/core/frame.py

+3-28
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@
9494
infer_dtype_from_scalar,
9595
invalidate_string_dtypes,
9696
maybe_box_native,
97-
maybe_convert_platform,
9897
maybe_downcast_to_dtype,
9998
validate_numeric_casting,
10099
)
@@ -4500,35 +4499,11 @@ def _sanitize_column(self, value) -> ArrayLike:
45004499

45014500
# We should never get here with DataFrame value
45024501
if isinstance(value, Series):
4503-
value = _reindex_for_setitem(value, self.index)
4502+
return _reindex_for_setitem(value, self.index)
45044503

4505-
elif isinstance(value, ExtensionArray):
4506-
# Explicitly copy here
4507-
value = value.copy()
4504+
if is_list_like(value):
45084505
com.require_length_match(value, self.index)
4509-
4510-
elif is_sequence(value):
4511-
com.require_length_match(value, self.index)
4512-
4513-
# turn me into an ndarray
4514-
if not isinstance(value, (np.ndarray, Index)):
4515-
if isinstance(value, list) and len(value) > 0:
4516-
value = maybe_convert_platform(value)
4517-
else:
4518-
value = com.asarray_tuplesafe(value)
4519-
elif isinstance(value, Index):
4520-
value = value.copy(deep=True)._values
4521-
else:
4522-
value = value.copy()
4523-
4524-
# possibly infer to datetimelike
4525-
if is_object_dtype(value.dtype):
4526-
value = sanitize_array(value, None)
4527-
4528-
else:
4529-
value = construct_1d_arraylike_from_scalar(value, len(self), dtype=None)
4530-
4531-
return value
4506+
return sanitize_array(value, self.index, copy=True, allow_2d=True)
45324507

45334508
@property
45344509
def _series(self):

pandas/tests/extension/test_numpy.py

+20
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,25 @@ def _can_hold_element_patched(obj, element) -> bool:
5555
return can_hold_element(obj, element)
5656

5757

58+
orig_assert_attr_equal = tm.assert_attr_equal
59+
60+
61+
def _assert_attr_equal(attr: str, left, right, obj: str = "Attributes"):
62+
"""
63+
patch tm.assert_attr_equal so PandasDtype("object") is close enough to
64+
np.dtype("object")
65+
"""
66+
if attr == "dtype":
67+
lattr = getattr(left, "dtype", None)
68+
rattr = getattr(right, "dtype", None)
69+
if isinstance(lattr, PandasDtype) and not isinstance(rattr, PandasDtype):
70+
left = left.astype(lattr.numpy_dtype)
71+
elif isinstance(rattr, PandasDtype) and not isinstance(lattr, PandasDtype):
72+
right = right.astype(rattr.numpy_dtype)
73+
74+
orig_assert_attr_equal(attr, left, right, obj)
75+
76+
5877
@pytest.fixture(params=["float", "object"])
5978
def dtype(request):
6079
return PandasDtype(np.dtype(request.param))
@@ -81,6 +100,7 @@ def allow_in_pandas(monkeypatch):
81100
m.setattr(PandasArray, "_typ", "extension")
82101
m.setattr(managers, "_extract_array", _extract_array_patched)
83102
m.setattr(blocks, "can_hold_element", _can_hold_element_patched)
103+
m.setattr(tm.asserters, "assert_attr_equal", _assert_attr_equal)
84104
yield
85105

86106

pandas/tests/indexing/test_partial.py

+25-31
Original file line numberDiff line numberDiff line change
@@ -386,58 +386,51 @@ def test_partial_set_empty_frame(self):
386386
with pytest.raises(ValueError, match=msg):
387387
df.loc[:, 1] = 1
388388

389+
def test_partial_set_empty_frame2(self):
389390
# these work as they don't really change
390391
# anything but the index
391392
# GH5632
392393
expected = DataFrame(columns=["foo"], index=Index([], dtype="object"))
393394

394-
def f():
395-
df = DataFrame(index=Index([], dtype="object"))
396-
df["foo"] = Series([], dtype="object")
397-
return df
395+
df = DataFrame(index=Index([], dtype="object"))
396+
df["foo"] = Series([], dtype="object")
398397

399-
tm.assert_frame_equal(f(), expected)
398+
tm.assert_frame_equal(df, expected)
400399

401-
def f():
402-
df = DataFrame()
403-
df["foo"] = Series(df.index)
404-
return df
400+
df = DataFrame()
401+
df["foo"] = Series(df.index)
405402

406-
tm.assert_frame_equal(f(), expected)
403+
tm.assert_frame_equal(df, expected)
407404

408-
def f():
409-
df = DataFrame()
410-
df["foo"] = df.index
411-
return df
405+
df = DataFrame()
406+
df["foo"] = df.index
412407

413-
tm.assert_frame_equal(f(), expected)
408+
tm.assert_frame_equal(df, expected)
414409

410+
def test_partial_set_empty_frame3(self):
415411
expected = DataFrame(columns=["foo"], index=Index([], dtype="int64"))
416412
expected["foo"] = expected["foo"].astype("float64")
417413

418-
def f():
419-
df = DataFrame(index=Index([], dtype="int64"))
420-
df["foo"] = []
421-
return df
414+
df = DataFrame(index=Index([], dtype="int64"))
415+
df["foo"] = []
422416

423-
tm.assert_frame_equal(f(), expected)
417+
tm.assert_frame_equal(df, expected)
424418

425-
def f():
426-
df = DataFrame(index=Index([], dtype="int64"))
427-
df["foo"] = Series(np.arange(len(df)), dtype="float64")
428-
return df
419+
df = DataFrame(index=Index([], dtype="int64"))
420+
df["foo"] = Series(np.arange(len(df)), dtype="float64")
429421

430-
tm.assert_frame_equal(f(), expected)
422+
tm.assert_frame_equal(df, expected)
431423

432-
def f():
433-
df = DataFrame(index=Index([], dtype="int64"))
434-
df["foo"] = range(len(df))
435-
return df
424+
def test_partial_set_empty_frame4(self):
425+
df = DataFrame(index=Index([], dtype="int64"))
426+
df["foo"] = range(len(df))
436427

437428
expected = DataFrame(columns=["foo"], index=Index([], dtype="int64"))
438-
expected["foo"] = expected["foo"].astype("float64")
439-
tm.assert_frame_equal(f(), expected)
429+
# range is int-dtype-like, so we get int64 dtype
430+
expected["foo"] = expected["foo"].astype("int64")
431+
tm.assert_frame_equal(df, expected)
440432

433+
def test_partial_set_empty_frame5(self):
441434
df = DataFrame()
442435
tm.assert_index_equal(df.columns, Index([], dtype=object))
443436
df2 = DataFrame()
@@ -446,6 +439,7 @@ def f():
446439
tm.assert_frame_equal(df, DataFrame([[1]], index=["foo"], columns=[1]))
447440
tm.assert_frame_equal(df, df2)
448441

442+
def test_partial_set_empty_frame_no_index(self):
449443
# no index to start
450444
expected = DataFrame({0: Series(1, index=range(4))}, columns=["A", "B", 0])
451445

0 commit comments

Comments
 (0)