20
20
import warnings
21
21
22
22
from copy import copy
23
- from typing import Any , Dict , List , Optional , Sequence , Union , cast
23
+ from typing import Any , Dict , List , Optional , Sequence , Tuple , Union , cast
24
24
25
25
import aesara
26
26
import aesara .tensor as at
@@ -466,9 +466,15 @@ def align_minibatches(batches=None):
466
466
rng .seed ()
467
467
468
468
469
- def determine_coords (model , value , dims : Optional [Sequence [str ]] = None ) -> Dict [str , Sequence ]:
469
+ def determine_coords (
470
+ model ,
471
+ value ,
472
+ dims : Optional [Sequence [Optional [str ]]] = None ,
473
+ coords : Optional [Dict [str , Sequence ]] = None ,
474
+ ) -> Tuple [Dict [str , Sequence ], Sequence [Optional [str ]]]:
470
475
"""Determines coordinate values from data or the model (via ``dims``)."""
471
- coords = {}
476
+ if coords is None :
477
+ coords = {}
472
478
473
479
# If value is a df or a series, we interpret the index as coords:
474
480
if hasattr (value , "index" ):
@@ -499,17 +505,22 @@ def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict
499
505
)
500
506
for size , dim in zip (value .shape , dims ):
501
507
coord = model .coords .get (dim , None )
502
- if coord is None :
508
+ if coord is None and dim is not None :
503
509
coords [dim ] = range (size )
504
510
505
- return coords
511
+ if dims is None :
512
+ # TODO: Also determine dim names from the index
513
+ dims = [None ] * np .ndim (value )
514
+
515
+ return coords , dims
506
516
507
517
508
518
def ConstantData (
509
519
name : str ,
510
520
value ,
511
521
* ,
512
522
dims : Optional [Sequence [str ]] = None ,
523
+ coords : Optional [Dict [str , Sequence ]] = None ,
513
524
export_index_as_coords = False ,
514
525
** kwargs ,
515
526
) -> TensorConstant :
@@ -522,6 +533,7 @@ def ConstantData(
522
533
name ,
523
534
value ,
524
535
dims = dims ,
536
+ coords = coords ,
525
537
export_index_as_coords = export_index_as_coords ,
526
538
mutable = False ,
527
539
** kwargs ,
@@ -534,6 +546,7 @@ def MutableData(
534
546
value ,
535
547
* ,
536
548
dims : Optional [Sequence [str ]] = None ,
549
+ coords : Optional [Dict [str , Sequence ]] = None ,
537
550
export_index_as_coords = False ,
538
551
** kwargs ,
539
552
) -> SharedVariable :
@@ -546,6 +559,7 @@ def MutableData(
546
559
name ,
547
560
value ,
548
561
dims = dims ,
562
+ coords = coords ,
549
563
export_index_as_coords = export_index_as_coords ,
550
564
mutable = True ,
551
565
** kwargs ,
@@ -558,6 +572,7 @@ def Data(
558
572
value ,
559
573
* ,
560
574
dims : Optional [Sequence [str ]] = None ,
575
+ coords : Optional [Dict [str , Sequence ]] = None ,
561
576
export_index_as_coords = False ,
562
577
mutable : Optional [bool ] = None ,
563
578
** kwargs ,
@@ -588,9 +603,11 @@ def Data(
588
603
:ref:`arviz:quickstart`.
589
604
If this parameter is not specified, the random variables will not have dimension
590
605
names.
606
+ coords : dict, optional
607
+ Coordinate values to set for new dimensions introduced by this ``Data`` variable.
591
608
export_index_as_coords : bool, default=False
592
- If True, the ``Data`` container will try to infer what the coordinates should be
593
- if there is an index in ``value``.
609
+ If True, the ``Data`` container will try to infer what the coordinates
610
+ and dimension names should be if there is an index in ``value``.
594
611
mutable : bool, optional
595
612
Switches between creating a :class:`~aesara.compile.sharedvalue.SharedVariable`
596
613
(``mutable=True``) vs. creating a :class:`~aesara.tensor.TensorConstant`
@@ -624,6 +641,9 @@ def Data(
624
641
... model.set_data('data', data_vals)
625
642
... idatas.append(pm.sample())
626
643
"""
644
+ if coords is None :
645
+ coords = {}
646
+
627
647
if isinstance (value , list ):
628
648
value = np .array (value )
629
649
@@ -665,15 +685,27 @@ def Data(
665
685
expected = x .ndim ,
666
686
)
667
687
668
- coords = determine_coords (model , value , dims )
669
-
688
+ # Optionally infer coords and dims from the input value.
670
689
if export_index_as_coords :
671
- model .add_coords (coords )
672
- elif dims :
690
+ coords , dims = determine_coords (model , value , dims )
691
+
692
+ if dims :
693
+ if not mutable :
694
+ # Use the dimension lengths from before it was tensorified.
695
+ # These can still be tensors, but in many cases they are numeric.
696
+ xshape = np .shape (arr )
697
+ else :
698
+ xshape = x .shape
673
699
# Register new dimension lengths
674
700
for d , dname in enumerate (dims ):
675
701
if not dname in model .dim_lengths :
676
- model .add_coord (dname , values = None , length = x .shape [d ])
702
+ model .add_coord (
703
+ name = dname ,
704
+ # Note: Coordinate values can't be taken from
705
+ # the value, because it could be N-dimensional.
706
+ values = coords .get (dname , None ),
707
+ length = xshape [d ],
708
+ )
677
709
678
710
model .add_random_variable (x , dims = dims )
679
711
0 commit comments