@@ -296,21 +296,9 @@ def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarra
296
296
297
297
assert 0 not in indexers
298
298
299
- needs_filling = False
300
- if 1 in indexers :
301
- # indexers[1] is shared by all the JoinUnits, so we can save time
302
- # by only doing this check once
303
- if (indexers [1 ] == - 1 ).any ():
304
- needs_filling = True
305
-
306
299
if mgr .is_single_block :
307
300
blk = mgr .blocks [0 ]
308
- return [
309
- (
310
- blk .mgr_locs ,
311
- JoinUnit (blk , mgr_shape , indexers , needs_filling = needs_filling ),
312
- )
313
- ]
301
+ return [(blk .mgr_locs , JoinUnit (blk , mgr_shape , indexers ))]
314
302
315
303
blknos = mgr .blknos
316
304
blklocs = mgr .blklocs
@@ -356,17 +344,15 @@ def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarra
356
344
# Assertions disabled for performance
357
345
# assert blk._mgr_locs.as_slice == placements.as_slice
358
346
# assert blk.shape[0] == shape[0]
359
- unit = JoinUnit (blk , shape , join_unit_indexers , needs_filling = needs_filling )
347
+ unit = JoinUnit (blk , shape , join_unit_indexers )
360
348
361
349
plan .append ((placements , unit ))
362
350
363
351
return plan
364
352
365
353
366
354
class JoinUnit :
367
- def __init__ (
368
- self , block : Block , shape : Shape , indexers = None , * , needs_filling : bool = False
369
- ):
355
+ def __init__ (self , block : Block , shape : Shape , indexers = None ):
370
356
# Passing shape explicitly is required for cases when block is None.
371
357
# Note: block is None implies indexers is None, but not vice-versa
372
358
if indexers is None :
@@ -376,11 +362,28 @@ def __init__(
376
362
self .indexers = indexers
377
363
self .shape = shape
378
364
379
- self .needs_filling = needs_filling
380
-
381
365
def __repr__ (self ) -> str :
382
366
return f"{ type (self ).__name__ } ({ repr (self .block )} , { self .indexers } )"
383
367
368
@cache_readonly
def needs_filling(self) -> bool:
    """
    Report whether any of this unit's indexers contains a ``-1`` entry.

    Returns
    -------
    bool
        True as soon as one array in ``self.indexers`` has an element
        equal to -1; False otherwise, including when ``self.indexers``
        is empty.

    Notes
    -----
    Presumably a ``-1`` marks a position that has to be filled during
    concatenation -- confirm against the code that consumes the indexers.
    """
    # FIXME: cache results of indexer == -1 checks.
    # any() stops at the first indexer holding a -1, so later arrays
    # are never scanned once a match has been found.
    return any((positions == -1).any() for positions in self.indexers.values())
376
+
377
@cache_readonly
def dtype(self):
    """
    Return the dtype this unit contributes to the concatenated result.

    Returns
    -------
    The block's own dtype when ``self.needs_filling`` is False; otherwise
    whatever ``ensure_dtype_can_hold_na`` returns for the block's dtype.

    Raises
    ------
    AssertionError
        If the block's values have a void-kind (``"V"``) dtype. The error
        message treats such a block as "None", i.e. as having no dtype.
    """
    block = self.block
    if block.values.dtype.kind == "V":
        raise AssertionError("Block is None, no dtype")

    if self.needs_filling:
        # Filling is required, so the block dtype is passed through
        # ensure_dtype_can_hold_na before being reported.
        return ensure_dtype_can_hold_na(block.dtype)
    return block.dtype
386
+
384
387
@cache_readonly
385
388
def is_na (self ) -> bool :
386
389
blk = self .block
@@ -538,12 +541,12 @@ def _get_empty_dtype(join_units: Sequence[JoinUnit]) -> DtypeObj:
538
541
empty_dtype = join_units [0 ].block .dtype
539
542
return empty_dtype
540
543
541
- needs_can_hold_na = any (unit .is_na or unit . needs_filling for unit in join_units )
544
+ has_none_blocks = any (unit .is_na for unit in join_units )
542
545
543
- dtypes = [unit .block . dtype for unit in join_units if not unit .is_na ]
546
+ dtypes = [unit .dtype for unit in join_units if not unit .is_na ]
544
547
545
548
dtype = find_common_type (dtypes )
546
- if needs_can_hold_na :
549
+ if has_none_blocks :
547
550
dtype = ensure_dtype_can_hold_na (dtype )
548
551
return dtype
549
552
@@ -606,13 +609,7 @@ def _trim_join_unit(join_unit: JoinUnit, length: int) -> JoinUnit:
606
609
extra_shape = (join_unit .shape [0 ] - length ,) + join_unit .shape [1 :]
607
610
join_unit .shape = (length ,) + join_unit .shape [1 :]
608
611
609
- # extra_indexers does not introduce any -1s, so we can inherit needs_filling
610
- return JoinUnit (
611
- block = extra_block ,
612
- indexers = extra_indexers ,
613
- shape = extra_shape ,
614
- needs_filling = join_unit .needs_filling ,
615
- )
612
+ return JoinUnit (block = extra_block , indexers = extra_indexers , shape = extra_shape )
616
613
617
614
618
615
def _combine_concat_plans (plans , concat_axis : int ):
0 commit comments