17
17
import os
18
18
import pkgutil
19
19
import urllib .request
20
+ import warnings
20
21
21
22
from copy import copy
22
- from typing import Any , Dict , List , Sequence
23
+ from typing import Any , Dict , List , Optional , Sequence , Union
23
24
24
25
import aesara
25
26
import aesara .tensor as at
26
27
import numpy as np
27
28
import pandas as pd
28
29
30
+ from aesara .compile .sharedvalue import SharedVariable
29
31
from aesara .graph .basic import Apply
30
32
from aesara .tensor .type import TensorType
31
- from aesara .tensor .var import TensorVariable
33
+ from aesara .tensor .var import TensorConstant , TensorVariable
34
+ from packaging import version
32
35
33
36
import pymc as pm
34
37
40
43
"Minibatch" ,
41
44
"align_minibatches" ,
42
45
"Data" ,
46
+ "ConstantData" ,
47
+ "MutableData" ,
43
48
]
44
49
BASE_URL = "https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/data/{filename}"
45
50
@@ -463,16 +468,115 @@ def align_minibatches(batches=None):
463
468
rng .seed ()
464
469
465
470
466
- class Data :
467
- """Data container class that wraps :func:`aesara.shared` and lets
468
- the model be aware of its inputs and outputs.
471
def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict[str, Sequence]:
    """Determine coordinate values from data or the model (via ``dims``).

    Parameters
    ----------
    model
        Object exposing a ``coords`` mapping. It is only consulted when
        ``value`` is a NumPy array and ``dims`` is given.
    value
        The data. ``pd.Series``/``pd.DataFrame`` contribute their index
        (and, for a DataFrame, their columns) as coordinates; ``np.ndarray``
        contributes a ``pd.RangeIndex`` per dimension not already present in
        ``model.coords``. Any other type yields no coordinates.
    dims : sequence of str, optional
        Dimension names. For pandas objects an entry of ``None`` (or
        ``dims=None``) falls back to the name of the index/columns.

    Returns
    -------
    dict
        Mapping of dimension name to coordinate values.

    Raises
    ------
    pm.exceptions.ShapeError
        If ``value`` is a NumPy array whose rank differs from ``len(dims)``.
    """
    coords = {}

    # If value is a df or a series, we interpret the index as coords:
    if isinstance(value, (pd.Series, pd.DataFrame)):
        dim_name = dims[0] if dims is not None else None
        if dim_name is None:
            # Fall back to the index name (which may itself be None).
            dim_name = value.index.name
        if dim_name is not None:
            coords[dim_name] = value.index

    # If value is a df, we also interpret the columns as coords:
    if isinstance(value, pd.DataFrame):
        dim_name = dims[1] if dims is not None else None
        if dim_name is None:
            dim_name = value.columns.name
        if dim_name is not None:
            coords[dim_name] = value.columns

    if isinstance(value, np.ndarray) and dims is not None:
        if len(dims) != value.ndim:
            # Report the two quantities that are actually being compared:
            # the rank of the data versus the number of dimension names.
            raise pm.exceptions.ShapeError(
                "Invalid data shape. The rank of the dataset must match the length of `dims`.",
                actual=value.ndim,
                expected=len(dims),
            )
        for size, dim in zip(value.shape, dims):
            # Only invent a range coordinate for dims the model has no values for.
            if model.coords.get(dim, None) is None:
                coords[dim] = pd.RangeIndex(size, name=dim)

    return coords
508
+
509
+
510
def ConstantData(
    name: str,
    value,
    *,
    dims: Optional[Sequence[str]] = None,
    export_index_as_coords=False,
    **kwargs,
) -> TensorConstant:
    """Alias for ``pm.Data(..., mutable=False)``.

    Registers the ``value`` as a ``TensorConstant`` with the model.

    All arguments are forwarded to ``Data`` unchanged; only ``mutable`` is
    pinned to ``False``.
    """
    constant = Data(
        name,
        value,
        dims=dims,
        export_index_as_coords=export_index_as_coords,
        mutable=False,
        **kwargs,
    )
    return constant
530
+
531
+
532
def MutableData(
    name: str,
    value,
    *,
    dims: Optional[Sequence[str]] = None,
    export_index_as_coords=False,
    **kwargs,
) -> SharedVariable:
    """Alias for ``pm.Data(..., mutable=True)``.

    Registers the ``value`` as a ``SharedVariable`` with the model.

    All arguments are forwarded to ``Data`` unchanged; only ``mutable`` is
    pinned to ``True``.
    """
    shared = Data(
        name,
        value,
        dims=dims,
        export_index_as_coords=export_index_as_coords,
        mutable=True,
        **kwargs,
    )
    return shared
552
+
553
+
554
+ def Data (
555
+ name : str ,
556
+ value ,
557
+ * ,
558
+ dims : Optional [Sequence [str ]] = None ,
559
+ export_index_as_coords = False ,
560
+ mutable : Optional [bool ] = None ,
561
+ ** kwargs ,
562
+ ) -> Union [SharedVariable , TensorConstant ]:
563
+ """Data container that registers a data variable with the model.
564
+
565
+ Depending on the ``mutable`` setting (default: True), the variable
566
+ is registered as a ``SharedVariable``, enabling it to be altered
567
+ in value and shape, but NOT in dimensionality using ``pm.set_data()``.
469
568
470
569
Parameters
471
570
----------
472
571
name: str
473
572
The name for this variable
474
573
value: {List, np.ndarray, pd.Series, pd.DataFrame}
475
574
A value to associate with this variable
575
+ mutable : bool, optional
576
+ Switches between creating a ``SharedVariable`` (``mutable=True``, default)
577
+ vs. creating a ``TensorConstant`` (``mutable=False``).
578
+ Consider using ``pm.ConstantData`` or ``pm.MutableData`` as less verbose
579
+ alternatives to ``pm.Data(..., mutable=...)``.
476
580
dims: {str, tuple of str}, optional, default=None
477
581
Dimension names of the random variables (as opposed to the shapes of these
478
582
random variables). Use this when `value` is a pandas Series or DataFrame. The
@@ -495,7 +599,7 @@ class Data:
495
599
>>> observed_data = [mu + np.random.randn(20) for mu in true_mu]
496
600
497
601
>>> with pm.Model() as model:
498
- ... data = pm.Data ('data', observed_data[0])
602
+ ... data = pm.MutableData ('data', observed_data[0])
499
603
... mu = pm.Normal('mu', 0, 10)
500
604
... pm.Normal('y', mu=mu, sigma=1, observed=data)
501
605
@@ -513,104 +617,58 @@ class Data:
513
617
For more information, take a look at this example notebook
514
618
https://docs.pymc.io/notebooks/data_container.html
515
619
"""
620
+ if isinstance (value , list ):
621
+ value = np .array (value )
516
622
517
- def __new__ (
518
- self ,
519
- name ,
520
- value ,
521
- * ,
522
- dims = None ,
523
- export_index_as_coords = False ,
524
- ** kwargs ,
525
- ):
526
- if isinstance (value , list ):
527
- value = np .array (value )
528
-
529
- # Add data container to the named variables of the model.
530
- try :
531
- model = pm .Model .get_context ()
532
- except TypeError :
533
- raise TypeError (
534
- "No model on context stack, which is needed to instantiate a data container. "
535
- "Add variable inside a 'with model:' block."
536
- )
537
- name = model .name_for (name )
538
-
539
- # `pandas_to_array` takes care of parameter `value` and
540
- # transforms it to something digestible for pymc
541
- shared_object = aesara .shared (pandas_to_array (value ), name , ** kwargs )
542
-
543
- if isinstance (dims , str ):
544
- dims = (dims ,)
545
- if not (dims is None or len (dims ) == shared_object .ndim ):
546
- raise pm .exceptions .ShapeError (
547
- "Length of `dims` must match the dimensions of the dataset." ,
548
- actual = len (dims ),
549
- expected = shared_object .ndim ,
623
+ # Add data container to the named variables of the model.
624
+ try :
625
+ model = pm .Model .get_context ()
626
+ except TypeError :
627
+ raise TypeError (
628
+ "No model on context stack, which is needed to instantiate a data container. "
629
+ "Add variable inside a 'with model:' block."
630
+ )
631
+ name = model .name_for (name )
632
+
633
+ # `pandas_to_array` takes care of parameter `value` and
634
+ # transforms it to something digestible for Aesara.
635
+ arr = pandas_to_array (value )
636
+
637
+ if mutable is None :
638
+ current = version .Version (pm .__version__ )
639
+ mutable = current .major == 4 and current .minor < 1
640
+ if mutable :
641
+ warnings .warn (
642
+ "The `mutable` kwarg was not specified. Currently it defaults to `pm.Data(mutable=True)`,"
643
+ " which is equivalent to using `pm.MutableData()`."
644
+ " In v4.1.0 the default will change to `pm.Data(mutable=False)`, equivalent to `pm.ConstantData`."
645
+ " Set `pm.Data(..., mutable=False/True)`, or use `pm.ConstantData`/`pm.MutableData`." ,
646
+ FutureWarning ,
550
647
)
551
-
552
- coords = self .set_coords (model , value , dims )
553
-
554
- if export_index_as_coords :
555
- model .add_coords (coords )
556
- elif dims :
557
- # Register new dimension lengths
558
- for d , dname in enumerate (dims ):
559
- if not dname in model .dim_lengths :
560
- model .add_coord (dname , values = None , length = shared_object .shape [d ])
561
-
562
- # To draw the node for this variable in the graphviz Digraph we need
563
- # its shape.
564
- # XXX: This needs to be refactored
565
- # shared_object.dshape = tuple(shared_object.shape.eval())
566
- # if dims is not None:
567
- # shape_dims = model.shape_from_dims(dims)
568
- # if shared_object.dshape != shape_dims:
569
- # raise pm.exceptions.ShapeError(
570
- # "Data shape does not match with specified `dims`.",
571
- # actual=shared_object.dshape,
572
- # expected=shape_dims,
573
- # )
574
-
575
- model .add_random_variable (shared_object , dims = dims )
576
-
577
- return shared_object
578
-
579
- @staticmethod
580
- def set_coords (model , value , dims = None ) -> Dict [str , Sequence ]:
581
- coords = {}
582
-
583
- # If value is a df or a series, we interpret the index as coords:
584
- if isinstance (value , (pd .Series , pd .DataFrame )):
585
- dim_name = None
586
- if dims is not None :
587
- dim_name = dims [0 ]
588
- if dim_name is None and value .index .name is not None :
589
- dim_name = value .index .name
590
- if dim_name is not None :
591
- coords [dim_name ] = value .index
592
-
593
- # If value is a df, we also interpret the columns as coords:
594
- if isinstance (value , pd .DataFrame ):
595
- dim_name = None
596
- if dims is not None :
597
- dim_name = dims [1 ]
598
- if dim_name is None and value .columns .name is not None :
599
- dim_name = value .columns .name
600
- if dim_name is not None :
601
- coords [dim_name ] = value .columns
602
-
603
- if isinstance (value , np .ndarray ) and dims is not None :
604
- if len (dims ) != value .ndim :
605
- raise pm .exceptions .ShapeError (
606
- "Invalid data shape. The rank of the dataset must match the "
607
- "length of `dims`." ,
608
- actual = value .shape ,
609
- expected = value .ndim ,
610
- )
611
- for size , dim in zip (value .shape , dims ):
612
- coord = model .coords .get (dim , None )
613
- if coord is None :
614
- coords [dim ] = pd .RangeIndex (size , name = dim )
615
-
616
- return coords
648
+ if mutable :
649
+ x = aesara .shared (arr , name , ** kwargs )
650
+ else :
651
+ x = at .as_tensor_variable (arr , name , ** kwargs )
652
+
653
+ if isinstance (dims , str ):
654
+ dims = (dims ,)
655
+ if not (dims is None or len (dims ) == x .ndim ):
656
+ raise pm .exceptions .ShapeError (
657
+ "Length of `dims` must match the dimensions of the dataset." ,
658
+ actual = len (dims ),
659
+ expected = x .ndim ,
660
+ )
661
+
662
+ coords = determine_coords (model , value , dims )
663
+
664
+ if export_index_as_coords :
665
+ model .add_coords (coords )
666
+ elif dims :
667
+ # Register new dimension lengths
668
+ for d , dname in enumerate (dims ):
669
+ if not dname in model .dim_lengths :
670
+ model .add_coord (dname , values = None , length = x .shape [d ])
671
+
672
+ model .add_random_variable (x , dims = dims )
673
+
674
+ return x
0 commit comments