28
28
29
29
30
30
def concatenate_block_managers (
31
- mgrs_indexers , axes , concat_axis : int , copy : bool
31
+ mgrs_indexers , axes , concat_axis : int , copy : bool , ignore_2d_ea : bool = False ,
32
32
) -> BlockManager :
33
33
"""
34
34
Concatenate block managers into one.
@@ -65,7 +65,9 @@ def concatenate_block_managers(
65
65
b .mgr_locs = placement
66
66
else :
67
67
b = make_block (
68
- _concatenate_join_units (join_units , concat_axis , copy = copy ),
68
+ _concatenate_join_units (
69
+ join_units , concat_axis , copy = copy , ignore_2d_ea = ignore_2d_ea
70
+ ),
69
71
placement = placement ,
70
72
)
71
73
blocks .append (b )
@@ -247,6 +249,16 @@ def get_reindexed_values(self, empty_dtype, upcasted_na):
247
249
pass
248
250
elif getattr (self .block , "is_extension" , False ):
249
251
pass
252
+ elif is_extension_array_dtype (empty_dtype ):
253
+ missing_arr = empty_dtype .construct_array_type ()._from_sequence (
254
+ [], dtype = empty_dtype
255
+ )
256
+ ncols , nrows = self .shape
257
+ assert ncols == 1 , ncols
258
+ empty_arr = - 1 * np .ones ((nrows ,), dtype = "int8" )
259
+ return missing_arr .take (
260
+ empty_arr , allow_fill = True , fill_value = fill_value
261
+ )
250
262
else :
251
263
missing_arr = np .empty (self .shape , dtype = empty_dtype )
252
264
missing_arr .fill (fill_value )
@@ -280,7 +292,7 @@ def get_reindexed_values(self, empty_dtype, upcasted_na):
280
292
return values
281
293
282
294
283
- def _concatenate_join_units (join_units , concat_axis , copy ):
295
+ def _concatenate_join_units (join_units , concat_axis , copy , ignore_2d_ea = False ):
284
296
"""
285
297
Concatenate values from several join units along selected axis.
286
298
"""
@@ -307,7 +319,9 @@ def _concatenate_join_units(join_units, concat_axis, copy):
307
319
else :
308
320
concat_values = concat_values .copy ()
309
321
else :
310
- concat_values = concat_compat (to_concat , axis = concat_axis )
322
+ concat_values = concat_compat (
323
+ to_concat , axis = concat_axis , ignore_2d_ea = ignore_2d_ea
324
+ )
311
325
312
326
return concat_values
313
327
@@ -344,6 +358,7 @@ def _get_empty_dtype_and_na(join_units):
344
358
345
359
upcast_classes = defaultdict (list )
346
360
null_upcast_classes = defaultdict (list )
361
+
347
362
for dtype , unit in zip (dtypes , join_units ):
348
363
if dtype is None :
349
364
continue
@@ -352,6 +367,11 @@ def _get_empty_dtype_and_na(join_units):
352
367
upcast_cls = "category"
353
368
elif is_datetime64tz_dtype (dtype ):
354
369
upcast_cls = "datetimetz"
370
+
371
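+ # NOTE: this extension-dtype check now runs before the is_sparse branch below; if sparse dtypes also satisfy is_extension_array_dtype, the sparse check may need to move back above this one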
372
+ elif is_extension_array_dtype (dtype ):
373
+ upcast_cls = "extension"
374
+
355
375
elif issubclass (dtype .type , np .bool_ ):
356
376
upcast_cls = "bool"
357
377
elif issubclass (dtype .type , np .object_ ):
@@ -362,8 +382,6 @@ def _get_empty_dtype_and_na(join_units):
362
382
upcast_cls = "timedelta"
363
383
elif is_sparse (dtype ):
364
384
upcast_cls = dtype .subtype .name
365
- elif is_extension_array_dtype (dtype ):
366
- upcast_cls = "object"
367
385
elif is_float_dtype (dtype ) or is_numeric_dtype (dtype ):
368
386
upcast_cls = dtype .name
369
387
else :
@@ -379,6 +397,12 @@ def _get_empty_dtype_and_na(join_units):
379
397
380
398
if not upcast_classes :
381
399
upcast_classes = null_upcast_classes
400
+ if "extension" in upcast_classes :
401
+ if len (upcast_classes ) == 1 :
402
+ cls = upcast_classes ["extension" ][0 ]
403
+ return cls , cls .na_value
404
+ else :
405
+ return np .dtype ("object" ), np .nan
382
406
383
407
# TODO: de-duplicate with maybe_promote?
384
408
# create the result
0 commit comments