Skip to content

Commit 35cc565

Browse files
authored
REF: simplify _sanitize_column (#38459)
1 parent deaf138 commit 35cc565

File tree

1 file changed

+86
-72
lines changed

1 file changed

+86
-72
lines changed

pandas/core/frame.py

+86-72
Original file line number | Diff line number | Diff line change
@@ -3160,6 +3160,8 @@ def __setitem__(self, key, value):
31603160
self._setitem_frame(key, value)
31613161
elif isinstance(key, (Series, np.ndarray, list, Index)):
31623162
self._setitem_array(key, value)
3163+
elif isinstance(value, DataFrame):
3164+
self._set_item_frame_value(key, value)
31633165
else:
31643166
# set column
31653167
self._set_item(key, value)
@@ -3213,15 +3215,47 @@ def _setitem_frame(self, key, value):
32133215
self._check_setitem_copy()
32143216
self._where(-key, value, inplace=True)
32153217

3218+
def _set_item_frame_value(self, key, value: "DataFrame") -> None:
    """
    Assign a DataFrame ``value`` to ``self[key]``.

    The right-hand-side frame is first aligned to this frame: its columns
    are reindexed to the sub-frame selected by ``key`` when ``self.columns``
    is a MultiIndex, and its rows are reindexed to ``self.index``. The
    aligned values are then transposed and handed to ``_set_item_mgr``.

    Parameters
    ----------
    key : object
        Column label being set.
    value : DataFrame
        Frame whose values are written under ``key``.
    """
    self._ensure_valid_index(value)

    # align right-hand-side columns if self.columns
    # is multi-index and self[key] is a sub-frame
    if isinstance(self.columns, MultiIndex) and key in self.columns:
        loc = self.columns.get_loc(key)
        # a non-scalar location means ``key`` selects several columns
        if isinstance(loc, (slice, Series, np.ndarray, Index)):
            cols = maybe_droplevels(self.columns[loc], key)
            if len(cols) and not cols.equals(value.columns):
                value = value.reindex(cols, axis=1)

    # now align rows
    value = _reindex_for_setitem(value, self.index)
    # transpose before passing the values on to _set_item_mgr
    value = value.T
    self._set_item_mgr(key, value)
3234+
32163235
def _iset_item_mgr(self, loc: int, value) -> None:
    """
    Write ``value`` at position ``loc`` via the manager, then clear the item cache.

    Parameters
    ----------
    loc : int
        Position passed through to ``self._mgr.iset``.
    value : object
        Values passed through to ``self._mgr.iset`` unchanged.
    """
    self._mgr.iset(loc, value)
    # clear the item cache after the manager has been updated
    self._clear_item_cache()
32193238

3220-
def _iset_item(self, loc: int, value, broadcast: bool = False):
3239+
def _set_item_mgr(self, key, value):
    """
    Store ``value`` under ``key`` in the manager, inserting or overwriting.

    ``value`` is first passed through ``_maybe_atleast_2d``. If ``key`` is
    already present on the info axis, the existing location is overwritten;
    otherwise the values are inserted as a new item at the end.

    Parameters
    ----------
    key : object
        Label to set.
    value : array-like
        Values to store.
    """
    value = _maybe_atleast_2d(value)

    try:
        loc = self._info_axis.get_loc(key)
    except KeyError:
        # This item wasn't present, just insert at end
        self._mgr.insert(len(self._info_axis), key, value)
    else:
        self._iset_item_mgr(loc, value)

    # check if we are modifying a copy
    # try to set first as we want an invalid
    # value exception to occur first
    if len(self):
        self._check_setitem_copy()
32213255

3222-
# technically _sanitize_column expects a label, not a position,
3223-
# but the behavior is the same as long as we pass broadcast=False
3224-
value = self._sanitize_column(loc, value, broadcast=broadcast)
3256+
def _iset_item(self, loc: int, value):
3257+
value = self._sanitize_column(value)
3258+
value = _maybe_atleast_2d(value)
32253259
self._iset_item_mgr(loc, value)
32263260

32273261
# check if we are modifying a copy
@@ -3240,21 +3274,20 @@ def _set_item(self, key, value):
32403274
Series/TimeSeries will be conformed to the DataFrames index to
32413275
ensure homogeneity.
32423276
"""
3243-
value = self._sanitize_column(key, value)
3277+
value = self._sanitize_column(value)
32443278

3245-
try:
3246-
loc = self._info_axis.get_loc(key)
3247-
except KeyError:
3248-
# This item wasn't present, just insert at end
3249-
self._mgr.insert(len(self._info_axis), key, value)
3250-
else:
3251-
self._iset_item_mgr(loc, value)
3279+
if (
3280+
key in self.columns
3281+
and value.ndim == 1
3282+
and not is_extension_array_dtype(value)
3283+
):
3284+
# broadcast across multiple columns if necessary
3285+
if not self.columns.is_unique or isinstance(self.columns, MultiIndex):
3286+
existing_piece = self[key]
3287+
if isinstance(existing_piece, DataFrame):
3288+
value = np.tile(value, (len(existing_piece.columns), 1))
32523289

3253-
# check if we are modifying a copy
3254-
# try to set first as we want an invalid
3255-
# value exception to occur first
3256-
if len(self):
3257-
self._check_setitem_copy()
3290+
self._set_item_mgr(key, value)
32583291

32593292
def _set_value(self, index, col, value, takeable: bool = False):
32603293
"""
@@ -3790,7 +3823,8 @@ def insert(self, loc, column, value, allow_duplicates: bool = False) -> None:
37903823
"Cannot specify 'allow_duplicates=True' when "
37913824
"'self.flags.allows_duplicate_labels' is False."
37923825
)
3793-
value = self._sanitize_column(column, value, broadcast=False)
3826+
value = self._sanitize_column(value)
3827+
value = _maybe_atleast_2d(value)
37943828
self._mgr.insert(loc, column, value, allow_duplicates=allow_duplicates)
37953829

37963830
def assign(self, **kwargs) -> DataFrame:
@@ -3861,63 +3895,24 @@ def assign(self, **kwargs) -> DataFrame:
38613895
data[k] = com.apply_if_callable(v, data)
38623896
return data
38633897

3864-
def _sanitize_column(self, key, value, broadcast: bool = True):
3898+
def _sanitize_column(self, value):
38653899
"""
38663900
Ensures new columns (which go into the BlockManager as new blocks) are
38673901
always copied and converted into an array.
38683902
38693903
Parameters
38703904
----------
3871-
key : object
38723905
value : scalar, Series, or array-like
3873-
broadcast : bool, default True
3874-
If ``key`` matches multiple duplicate column names in the
3875-
DataFrame, this parameter indicates whether ``value`` should be
3876-
tiled so that the returned array contains a (duplicated) column for
3877-
each occurrence of the key. If False, ``value`` will not be tiled.
38783906
38793907
Returns
38803908
-------
38813909
numpy.ndarray
38823910
"""
38833911
self._ensure_valid_index(value)
38843912

3885-
def reindexer(value):
3886-
# reindex if necessary
3887-
3888-
if value.index.equals(self.index) or not len(self.index):
3889-
value = value._values.copy()
3890-
else:
3891-
3892-
# GH 4107
3893-
try:
3894-
value = value.reindex(self.index)._values
3895-
except ValueError as err:
3896-
# raised in MultiIndex.from_tuples, see test_insert_error_msmgs
3897-
if not value.index.is_unique:
3898-
# duplicate axis
3899-
raise err
3900-
3901-
# other
3902-
raise TypeError(
3903-
"incompatible index of inserted column with frame index"
3904-
) from err
3905-
return value
3906-
3913+
# We should never get here with DataFrame value
39073914
if isinstance(value, Series):
3908-
value = reindexer(value)
3909-
3910-
elif isinstance(value, DataFrame):
3911-
# align right-hand-side columns if self.columns
3912-
# is multi-index and self[key] is a sub-frame
3913-
if isinstance(self.columns, MultiIndex) and key in self.columns:
3914-
loc = self.columns.get_loc(key)
3915-
if isinstance(loc, (slice, Series, np.ndarray, Index)):
3916-
cols = maybe_droplevels(self.columns[loc], key)
3917-
if len(cols) and not cols.equals(value.columns):
3918-
value = value.reindex(cols, axis=1)
3919-
# now align rows
3920-
value = reindexer(value).T
3915+
value = _reindex_for_setitem(value, self.index)
39213916

39223917
elif isinstance(value, ExtensionArray):
39233918
# Explicitly copy here, instead of in sanitize_index,
@@ -3948,18 +3943,7 @@ def reindexer(value):
39483943
else:
39493944
value = construct_1d_arraylike_from_scalar(value, len(self), dtype=None)
39503945

3951-
# return internal types directly
3952-
if is_extension_array_dtype(value):
3953-
return value
3954-
3955-
# broadcast across multiple columns if necessary
3956-
if broadcast and key in self.columns and value.ndim == 1:
3957-
if not self.columns.is_unique or isinstance(self.columns, MultiIndex):
3958-
existing_piece = self[key]
3959-
if isinstance(existing_piece, DataFrame):
3960-
value = np.tile(value, (len(existing_piece.columns), 1))
3961-
3962-
return np.atleast_2d(np.asarray(value))
3946+
return value
39633947

39643948
@property
39653949
def _series(self):
@@ -9557,3 +9541,33 @@ def _from_nested_dict(data) -> collections.defaultdict:
95579541
for col, v in s.items():
95589542
new_data[col][index] = v
95599543
return new_data
9544+
9545+
9546+
def _reindex_for_setitem(value, index: Index):
    """
    Return the values of ``value`` aligned to ``index`` for a setitem call.

    Parameters
    ----------
    value : Series or DataFrame
        Object exposing ``.index``, ``.reindex`` and ``._values``.
    index : Index
        Target index to align to.

    Returns
    -------
    array-like
        A copy of ``value._values`` when the indexes already match or
        ``index`` is empty; otherwise the values of ``value.reindex(index)``.

    Raises
    ------
    ValueError
        Re-raised from ``reindex`` when ``value.index`` is not unique.
    TypeError
        If ``reindex`` raises ValueError for any other reason.
    """
    # reindex if necessary

    if value.index.equals(index) or not len(index):
        # nothing to align; copy so the caller never shares the input's data
        return value._values.copy()

    # GH#4107
    try:
        value = value.reindex(index)._values
    except ValueError as err:
        # raised in MultiIndex.from_tuples, see test_insert_error_msmgs
        if not value.index.is_unique:
            # duplicate axis
            raise err

        raise TypeError(
            "incompatible index of inserted column with frame index"
        ) from err
    return value
9565+
9566+
9567+
def _maybe_atleast_2d(value):
    """
    Return ``value`` as an ndarray with at least two dimensions.

    Values for which ``is_extension_array_dtype`` is true are returned
    unchanged; everything else goes through ``np.asarray`` and
    ``np.atleast_2d``.

    Parameters
    ----------
    value : array-like

    Returns
    -------
    numpy.ndarray or the input object
    """
    # TODO(EA2D): not needed with 2D EAs

    if is_extension_array_dtype(value):
        return value

    return np.atleast_2d(np.asarray(value))

0 commit comments

Comments
 (0)