Skip to content

Commit 331093e

Browse files
justinessertTomAugspurgerjorisvandenbossche
authored
DF.__setitem__ creates extension column when given extension scalar (#34875)
* Bugfix to make DF.__setitem__ create extension column instead of object column when given an extension scalar * removed bad whitespace * Apply suggestions from code review Checking if extension dtype via built in function instead of manually Co-authored-by: Tom Augspurger <[email protected]> * added missing : * modified cast_extension_scalar_to_array test to include an Interval type * added user-facing test for extension type bug * fixed pep8 issues * added note about bug in setting series to scalar extension type * corrected order of imports * corrected order of imports * fixed black formatting errors * removed extra comma * updated cast_scalar_to_arr to support tuple shape for extension dtype * removed unneeded code * added coverage for datetime with timezone in extension_array test * added TODO * correct line that was too long * fixed dtype issue with tz test * creating distinct arrays for each column * resolving mypy error * added docstring info and test * removed unneeded import * flattened else case in init * refactored extension type column fix * reverted docstring changes * reverted docstring changes * removed unneeded imports * reverted test changes * fixed construct_1d_arraylike bug * reorganized if statements * moved what's new statement to correct file * created new test for period df construction * added assert_frame_equal to period_data test * Using pandas array instead of df constructor for better test Co-authored-by: Joris Van den Bossche <[email protected]> * changed wording * pylint fixes * parameterized test and added comment * removed extra comma * parameterized test * renamed test Co-authored-by: Tom Augspurger <[email protected]> Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent 4fc4622 commit 331093e

File tree

5 files changed

+103
-15
lines changed

5 files changed

+103
-15
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,7 @@ ExtensionArray
11461146
- Fixed bug where :meth:`StringArray.memory_usage` was not implemented (:issue:`33963`)
11471147
- Fixed bug where :meth:`DataFrameGroupBy` would ignore the ``min_count`` argument for aggregations on nullable boolean dtypes (:issue:`34051`)
11481148
- Fixed bug where ``DataFrame(columns=.., dtype='string')`` would fail (:issue:`27953`, :issue:`33623`)
1149+
- Fixed bug where a :class:`DataFrame` column set to a scalar extension type was considered an object type rather than the extension type (:issue:`34832`)
11491150
- Fixed bug in ``IntegerArray.astype`` to correctly copy the mask as well (:issue:`34931`).
11501151

11511152
Other

pandas/core/frame.py

+36-12
Original file line numberDiff line numberDiff line change
@@ -520,25 +520,43 @@ def __init__(
520520
mgr = init_ndarray(data, index, columns, dtype=dtype, copy=copy)
521521
else:
522522
mgr = init_dict({}, index, columns, dtype=dtype)
523+
# Remaining case: data is a scalar
523524
else:
524-
try:
525-
arr = np.array(data, dtype=dtype, copy=copy)
526-
except (ValueError, TypeError) as err:
527-
exc = TypeError(
528-
"DataFrame constructor called with "
529-
f"incompatible data and dtype: {err}"
530-
)
531-
raise exc from err
525+
if index is None or columns is None:
526+
raise ValueError("DataFrame constructor not properly called!")
527+
528+
if not dtype:
529+
dtype, _ = infer_dtype_from_scalar(data, pandas_dtype=True)
530+
531+
# Case where data is a scalar with an extension dtype
532+
if is_extension_array_dtype(dtype):
533+
534+
values = [
535+
construct_1d_arraylike_from_scalar(data, len(index), dtype)
536+
for _ in range(len(columns))
537+
]
538+
mgr = arrays_to_mgr(values, columns, index, columns, dtype=None)
539+
else:
540+
# Attempt to coerce to a numpy array
541+
try:
542+
arr = np.array(data, dtype=dtype, copy=copy)
543+
except (ValueError, TypeError) as err:
544+
exc = TypeError(
545+
"DataFrame constructor called with "
546+
f"incompatible data and dtype: {err}"
547+
)
548+
raise exc from err
549+
550+
if arr.ndim != 0:
551+
raise ValueError("DataFrame constructor not properly called!")
532552

533-
if arr.ndim == 0 and index is not None and columns is not None:
534553
values = cast_scalar_to_array(
535554
(len(index), len(columns)), data, dtype=dtype
536555
)
556+
537557
mgr = init_ndarray(
538558
values, index, columns, dtype=values.dtype, copy=False
539559
)
540-
else:
541-
raise ValueError("DataFrame constructor not properly called!")
542560

543561
NDFrame.__init__(self, mgr)
544562

@@ -3740,7 +3758,13 @@ def reindexer(value):
37403758
infer_dtype, _ = infer_dtype_from_scalar(value, pandas_dtype=True)
37413759

37423760
# upcast
3743-
value = cast_scalar_to_array(len(self.index), value)
3761+
if is_extension_array_dtype(infer_dtype):
3762+
value = construct_1d_arraylike_from_scalar(
3763+
value, len(self.index), infer_dtype
3764+
)
3765+
else:
3766+
value = cast_scalar_to_array(len(self.index), value)
3767+
37443768
value = maybe_cast_to_datetime(value, infer_dtype)
37453769

37463770
# return internal types directly

pandas/tests/frame/indexing/test_setitem.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
import numpy as np
22
import pytest
33

4-
from pandas import Categorical, DataFrame, Index, Series, Timestamp, date_range
4+
from pandas.core.dtypes.dtypes import DatetimeTZDtype, IntervalDtype, PeriodDtype
5+
6+
from pandas import (
7+
Categorical,
8+
DataFrame,
9+
Index,
10+
Interval,
11+
Period,
12+
Series,
13+
Timestamp,
14+
date_range,
15+
)
516
import pandas._testing as tm
617
from pandas.core.arrays import SparseArray
718

@@ -150,3 +161,23 @@ def test_setitem_dict_preserves_dtypes(self):
150161
"c": float(b),
151162
}
152163
tm.assert_frame_equal(df, expected)
164+
165+
@pytest.mark.parametrize(
166+
"obj,dtype",
167+
[
168+
(Period("2020-01"), PeriodDtype("M")),
169+
(Interval(left=0, right=5), IntervalDtype("int64")),
170+
(
171+
Timestamp("2011-01-01", tz="US/Eastern"),
172+
DatetimeTZDtype(tz="US/Eastern"),
173+
),
174+
],
175+
)
176+
def test_setitem_extension_types(self, obj, dtype):
177+
# GH: 34832
178+
expected = DataFrame({"idx": [1, 2, 3], "obj": Series([obj] * 3, dtype=dtype)})
179+
180+
df = DataFrame({"idx": [1, 2, 3]})
181+
df["obj"] = obj
182+
183+
tm.assert_frame_equal(df, expected)

pandas/tests/frame/methods/test_combine_first.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,14 @@ def test_combine_first_timezone(self):
199199
columns=["UTCdatetime", "abc"],
200200
data=data1,
201201
index=pd.date_range("20140627", periods=1),
202+
dtype="object",
202203
)
203204
data2 = pd.to_datetime("20121212 12:12").tz_localize("UTC")
204205
df2 = pd.DataFrame(
205206
columns=["UTCdatetime", "xyz"],
206207
data=data2,
207208
index=pd.date_range("20140628", periods=1),
209+
dtype="object",
208210
)
209211
res = df2[["UTCdatetime"]].combine_first(df1)
210212
exp = pd.DataFrame(
@@ -217,10 +219,14 @@ def test_combine_first_timezone(self):
217219
},
218220
columns=["UTCdatetime", "abc"],
219221
index=pd.date_range("20140627", periods=2, freq="D"),
222+
dtype="object",
220223
)
221-
tm.assert_frame_equal(res, exp)
222224
assert res["UTCdatetime"].dtype == "datetime64[ns, UTC]"
223225
assert res["abc"].dtype == "datetime64[ns, UTC]"
226+
# Need to cast all to "object" because combine_first does not retain dtypes:
227+
# GH Issue 7509
228+
res = res.astype("object")
229+
tm.assert_frame_equal(res, exp)
224230

225231
# see gh-10567
226232
dts1 = pd.date_range("2015-01-01", "2015-01-05", tz="UTC")

pandas/tests/frame/test_constructors.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@
1414
from pandas.compat.numpy import _np_version_under1p19
1515

1616
from pandas.core.dtypes.common import is_integer_dtype
17+
from pandas.core.dtypes.dtypes import DatetimeTZDtype, IntervalDtype, PeriodDtype
1718

1819
import pandas as pd
1920
from pandas import (
2021
Categorical,
2122
DataFrame,
2223
Index,
24+
Interval,
2325
MultiIndex,
26+
Period,
2427
RangeIndex,
2528
Series,
2629
Timedelta,
@@ -700,7 +703,7 @@ def create_data(constructor):
700703
tm.assert_frame_equal(result_timedelta, expected)
701704
tm.assert_frame_equal(result_Timedelta, expected)
702705

703-
def test_constructor_period(self):
706+
def test_constructor_period_dict(self):
704707
# PeriodIndex
705708
a = pd.PeriodIndex(["2012-01", "NaT", "2012-04"], freq="M")
706709
b = pd.PeriodIndex(["2012-02-01", "2012-03-01", "NaT"], freq="D")
@@ -713,6 +716,29 @@ def test_constructor_period(self):
713716
assert df["a"].dtype == a.dtype
714717
assert df["b"].dtype == b.dtype
715718

719+
@pytest.mark.parametrize(
720+
"data,dtype",
721+
[
722+
(Period("2020-01"), PeriodDtype("M")),
723+
(Interval(left=0, right=5), IntervalDtype("int64")),
724+
(
725+
Timestamp("2011-01-01", tz="US/Eastern"),
726+
DatetimeTZDtype(tz="US/Eastern"),
727+
),
728+
],
729+
)
730+
def test_constructor_extension_scalar_data(self, data, dtype):
731+
# GH 34832
732+
df = DataFrame(index=[0, 1], columns=["a", "b"], data=data)
733+
734+
assert df["a"].dtype == dtype
735+
assert df["b"].dtype == dtype
736+
737+
arr = pd.array([data] * 2, dtype=dtype)
738+
expected = DataFrame({"a": arr, "b": arr})
739+
740+
tm.assert_frame_equal(df, expected)
741+
716742
def test_nested_dict_frame_constructor(self):
717743
rng = pd.period_range("1/1/2000", periods=5)
718744
df = DataFrame(np.random.randn(10, 5), columns=rng)

0 commit comments

Comments
 (0)