Skip to content

Commit 554b932

Browse files
jbrockmendelluckyvs1
authored andcommitted
REF: implement sanitize_masked_array (pandas-dev#38398)
1 parent 255bc29 commit 554b932

File tree

3 files changed

+25
-23
lines changed

3 files changed

+25
-23
lines changed

pandas/core/construction.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,20 @@ def ensure_wrapped_if_datetimelike(arr):
419419
return arr
420420

421421

422+
def sanitize_masked_array(data: ma.MaskedArray) -> np.ndarray:
423+
"""
424+
Convert numpy MaskedArray to ensure mask is softened.
425+
"""
426+
mask = ma.getmaskarray(data)
427+
if mask.any():
428+
data, fill_value = maybe_upcast(data, copy=True)
429+
data.soften_mask() # set hardmask False if it was True
430+
data[mask] = fill_value
431+
else:
432+
data = data.copy()
433+
return data
434+
435+
422436
def sanitize_array(
423437
data,
424438
index: Optional[Index],
@@ -432,13 +446,7 @@ def sanitize_array(
432446
"""
433447

434448
if isinstance(data, ma.MaskedArray):
435-
mask = ma.getmaskarray(data)
436-
if mask.any():
437-
data, fill_value = maybe_upcast(data, copy=True)
438-
data.soften_mask() # set hardmask False if it was True
439-
data[mask] = fill_value
440-
else:
441-
data = data.copy()
449+
data = sanitize_masked_array(data)
442450

443451
# extract ndarray or ExtensionArray, ensure we have no PandasArray
444452
data = extract_array(data, extract_numpy=True)

pandas/core/frame.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@
8787
maybe_convert_platform,
8888
maybe_downcast_to_dtype,
8989
maybe_infer_to_datetimelike,
90-
maybe_upcast,
9190
validate_numeric_casting,
9291
)
9392
from pandas.core.dtypes.common import (
@@ -126,7 +125,7 @@
126125
from pandas.core.arraylike import OpsMixin
127126
from pandas.core.arrays import Categorical, ExtensionArray
128127
from pandas.core.arrays.sparse import SparseFrameAccessor
129-
from pandas.core.construction import extract_array
128+
from pandas.core.construction import extract_array, sanitize_masked_array
130129
from pandas.core.generic import NDFrame, _shared_docs
131130
from pandas.core.indexes import base as ibase
132131
from pandas.core.indexes.api import (
@@ -535,13 +534,7 @@ def __init__(
535534

536535
# a masked array
537536
else:
538-
mask = ma.getmaskarray(data)
539-
if mask.any():
540-
data, fill_value = maybe_upcast(data, copy=True)
541-
data.soften_mask() # set hardmask False if it was True
542-
data[mask] = fill_value
543-
else:
544-
data = data.copy()
537+
data = sanitize_masked_array(data)
545538
mgr = init_ndarray(data, index, columns, dtype=dtype, copy=copy)
546539

547540
elif isinstance(data, (np.ndarray, Series, Index)):

pandas/core/internals/construction.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
)
5454

5555
if TYPE_CHECKING:
56+
from numpy.ma.mrecords import MaskedRecords
57+
5658
from pandas import Series
5759

5860
# ---------------------------------------------------------------------
@@ -96,13 +98,12 @@ def arrays_to_mgr(
9698

9799

98100
def masked_rec_array_to_mgr(
99-
data, index, columns, dtype: Optional[DtypeObj], copy: bool
101+
data: "MaskedRecords", index, columns, dtype: Optional[DtypeObj], copy: bool
100102
):
101103
"""
102104
Extract from a masked rec array and create the manager.
103105
"""
104106
# essentially process a record array then fill it
105-
fill_value = data.fill_value
106107
fdata = ma.getdata(data)
107108
if index is None:
108109
index = get_names_from_index(fdata)
@@ -116,11 +117,11 @@ def masked_rec_array_to_mgr(
116117

117118
# fill if needed
118119
new_arrays = []
119-
for fv, arr, col in zip(fill_value, arrays, arr_columns):
120-
# TODO: numpy docs suggest fv must be scalar, but could it be
121-
# non-scalar for object dtype?
122-
assert lib.is_scalar(fv), fv
123-
mask = ma.getmaskarray(data[col])
120+
for col in arr_columns:
121+
arr = data[col]
122+
fv = arr.fill_value
123+
124+
mask = ma.getmaskarray(arr)
124125
if mask.any():
125126
arr, fv = maybe_upcast(arr, fill_value=fv, copy=True)
126127
arr[mask] = fv

0 commit comments

Comments
 (0)