@@ -293,9 +293,21 @@ def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarra
293
293
294
294
assert 0 not in indexers
295
295
296
+ needs_filling = False
297
+ if 1 in indexers :
298
+ # indexers[1] is shared by all the JoinUnits, so we can save time
299
+ # by only doing this check once
300
+ if (indexers [1 ] == - 1 ).any ():
301
+ needs_filling = True
302
+
296
303
if mgr .is_single_block :
297
304
blk = mgr .blocks [0 ]
298
- return [(blk .mgr_locs , JoinUnit (blk , mgr_shape , indexers ))]
305
+ return [
306
+ (
307
+ blk .mgr_locs ,
308
+ JoinUnit (blk , mgr_shape , indexers , needs_filling = needs_filling ),
309
+ )
310
+ ]
299
311
300
312
blknos = mgr .blknos
301
313
blklocs = mgr .blklocs
@@ -339,15 +351,17 @@ def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarra
339
351
# Assertions disabled for performance
340
352
# assert blk._mgr_locs.as_slice == placements.as_slice
341
353
# assert blk.shape[0] == shape[0]
342
- unit = JoinUnit (blk , shape , join_unit_indexers )
354
+ unit = JoinUnit (blk , shape , join_unit_indexers , needs_filling = needs_filling )
343
355
344
356
plan .append ((placements , unit ))
345
357
346
358
return plan
347
359
348
360
349
361
class JoinUnit :
350
- def __init__ (self , block : Block , shape : Shape , indexers = None ):
362
+ def __init__ (
363
+ self , block : Block , shape : Shape , indexers = None , * , needs_filling : bool = False
364
+ ):
351
365
# Passing shape explicitly is required for cases when block is None.
352
366
# Note: block is None implies indexers is None, but not vice-versa
353
367
if indexers is None :
@@ -357,28 +371,11 @@ def __init__(self, block: Block, shape: Shape, indexers=None):
357
371
self .indexers = indexers
358
372
self .shape = shape
359
373
374
+ self .needs_filling = needs_filling
375
+
360
376
def __repr__ (self ) -> str :
361
377
return f"{ type (self ).__name__ } ({ repr (self .block )} , { self .indexers } )"
362
378
363
- @cache_readonly
364
- def needs_filling (self ) -> bool :
365
- for indexer in self .indexers .values ():
366
- # FIXME: cache results of indexer == -1 checks.
367
- if (indexer == - 1 ).any ():
368
- return True
369
-
370
- return False
371
-
372
- @cache_readonly
373
- def dtype (self ):
374
- blk = self .block
375
- if blk .values .dtype .kind == "V" :
376
- raise AssertionError ("Block is None, no dtype" )
377
-
378
- if not self .needs_filling :
379
- return blk .dtype
380
- return ensure_dtype_can_hold_na (blk .dtype )
381
-
382
379
@cache_readonly
383
380
def is_na (self ) -> bool :
384
381
blk = self .block
@@ -535,12 +532,12 @@ def _get_empty_dtype(join_units: Sequence[JoinUnit]) -> DtypeObj:
535
532
empty_dtype = join_units [0 ].block .dtype
536
533
return empty_dtype
537
534
538
- has_none_blocks = any (unit .is_na for unit in join_units )
535
+ needs_can_hold_na = any (unit .is_na or unit . needs_filling for unit in join_units )
539
536
540
- dtypes = [unit .dtype for unit in join_units if not unit .is_na ]
537
+ dtypes = [unit .block . dtype for unit in join_units if not unit .is_na ]
541
538
542
539
dtype = find_common_type (dtypes )
543
- if has_none_blocks :
540
+ if needs_can_hold_na :
544
541
dtype = ensure_dtype_can_hold_na (dtype )
545
542
return dtype
546
543
@@ -603,7 +600,13 @@ def _trim_join_unit(join_unit: JoinUnit, length: int) -> JoinUnit:
603
600
extra_shape = (join_unit .shape [0 ] - length ,) + join_unit .shape [1 :]
604
601
join_unit .shape = (length ,) + join_unit .shape [1 :]
605
602
606
- return JoinUnit (block = extra_block , indexers = extra_indexers , shape = extra_shape )
603
+ # extra_indexers does not introduce any -1s, so we can inherit needs_filling
604
+ return JoinUnit (
605
+ block = extra_block ,
606
+ indexers = extra_indexers ,
607
+ shape = extra_shape ,
608
+ needs_filling = join_unit .needs_filling ,
609
+ )
607
610
608
611
609
612
def _combine_concat_plans (plans , concat_axis : int ):
0 commit comments