Skip to content

Commit aaa4d90

Browse files
BUG: Handle all-NA blocks in concat (pandas-dev#20382)
* BUG: Handle all-NA blocks in concat. Previously we special-cased all-NA blocks; we should only do that for non-extension dtypes.
1 parent 670c2e4 commit aaa4d90

File tree

3 files changed

+28
-21
lines changed

3 files changed

+28
-21
lines changed

pandas/core/internals.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -5391,7 +5391,8 @@ def is_uniform_join_units(join_units):
53915391
# all blocks need to have the same type
53925392
all(type(ju.block) is type(join_units[0].block) for ju in join_units) and # noqa
53935393
# no blocks that would get missing values (can lead to type upcasts)
5394-
all(not ju.is_na for ju in join_units) and
5394+
# unless we're an extension dtype.
5395+
all(not ju.is_na or ju.block.is_extension for ju in join_units) and
53955396
# no blocks with indexers (as then the dimensions do not fit)
53965397
all(not ju.indexers for ju in join_units) and
53975398
# disregard Panels

pandas/tests/extension/base/reshaping.py

+15
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,21 @@ def test_concat(self, data, in_frame):
2525
assert dtype == data.dtype
2626
assert isinstance(result._data.blocks[0], ExtensionBlock)
2727

28+
@pytest.mark.parametrize('in_frame', [True, False])
29+
def test_concat_all_na_block(self, data_missing, in_frame):
30+
valid_block = pd.Series(data_missing.take([1, 1]), index=[0, 1])
31+
na_block = pd.Series(data_missing.take([0, 0]), index=[2, 3])
32+
if in_frame:
33+
valid_block = pd.DataFrame({"a": valid_block})
34+
na_block = pd.DataFrame({"a": na_block})
35+
result = pd.concat([valid_block, na_block])
36+
if in_frame:
37+
expected = pd.DataFrame({"a": data_missing.take([1, 1, 0, 0])})
38+
self.assert_frame_equal(result, expected)
39+
else:
40+
expected = pd.Series(data_missing.take([1, 1, 0, 0]))
41+
self.assert_series_equal(result, expected)
42+
2843
def test_align(self, data, na_value):
2944
a = data[:3]
3045
b = data[2:5]

pandas/tests/extension/decimal/test_decimal.py

+11 -20
Original file line number | Diff line number | Diff line change
@@ -36,31 +36,22 @@ def na_value():
3636

3737

3838
class BaseDecimal(object):
39-
@staticmethod
40-
def assert_series_equal(left, right, *args, **kwargs):
41-
# tm.assert_series_equal doesn't handle Decimal('NaN').
42-
# We will ensure that the NA values match, and then
43-
# drop those values before moving on.
39+
40+
def assert_series_equal(self, left, right, *args, **kwargs):
4441

4542
left_na = left.isna()
4643
right_na = right.isna()
4744

4845
tm.assert_series_equal(left_na, right_na)
49-
tm.assert_series_equal(left[~left_na], right[~right_na],
50-
*args, **kwargs)
51-
52-
@staticmethod
53-
def assert_frame_equal(left, right, *args, **kwargs):
54-
# TODO(EA): select_dtypes
55-
decimals = (left.dtypes == 'decimal').index
56-
57-
for col in decimals:
58-
BaseDecimal.assert_series_equal(left[col], right[col],
59-
*args, **kwargs)
60-
61-
left = left.drop(columns=decimals)
62-
right = right.drop(columns=decimals)
63-
tm.assert_frame_equal(left, right, *args, **kwargs)
46+
return tm.assert_series_equal(left[~left_na],
47+
right[~right_na],
48+
*args, **kwargs)
49+
50+
def assert_frame_equal(self, left, right, *args, **kwargs):
51+
self.assert_series_equal(left.dtypes, right.dtypes)
52+
for col in left.columns:
53+
self.assert_series_equal(left[col], right[col],
54+
*args, **kwargs)
6455

6556

6657
class TestDtype(BaseDecimal, base.BaseDtypeTests):

0 commit comments

Comments
 (0)