20
20
import warnings
21
21
22
22
from functools import singledispatch
23
- from typing import Any , Optional , Sequence , Tuple , Union
23
+ from typing import Any , Optional , Sequence , Tuple , Union , cast
24
24
25
25
import numpy as np
26
26
@@ -671,7 +671,7 @@ def get_support_shape(
671
671
observed : Optional [Any ] = None ,
672
672
support_shape_offset : Sequence [int ] = None ,
673
673
ndim_supp : int = 1 ,
674
- ):
674
+ ) -> Optional [ TensorVariable ] :
675
675
"""Extract the support shapes from shape / dims / observed information
676
676
677
677
Parameters
@@ -702,46 +702,61 @@ def get_support_shape(
702
702
raise NotImplementedError ("ndim_supp must be bigger than 0" )
703
703
if support_shape_offset is None :
704
704
support_shape_offset = [0 ] * ndim_supp
705
- inferred_support_shape = None
705
+ elif isinstance (support_shape_offset , int ):
706
+ support_shape_offset = [support_shape_offset ] * ndim_supp
707
+ inferred_support_shape : Optional [Sequence [Union [int , np .ndarray , Variable ]]] = None
706
708
707
709
if shape is not None :
708
710
shape = to_tuple (shape )
709
711
assert isinstance (shape , tuple )
710
- inferred_support_shape = at .stack (
711
- [shape [i ] - support_shape_offset [i ] for i in np .arange (- ndim_supp , 0 )]
712
- )
712
+ if len (shape ) < ndim_supp :
713
+ raise ValueError (
714
+ f"Number of shape dimensions is too small for ndim_supp of { ndim_supp } "
715
+ )
716
+ inferred_support_shape = [
717
+ shape [i ] - support_shape_offset [i ] for i in np .arange (- ndim_supp , 0 )
718
+ ]
713
719
714
720
if inferred_support_shape is None and dims is not None :
715
721
dims = convert_dims (dims )
716
722
assert isinstance (dims , tuple )
723
+ if len (dims ) < ndim_supp :
724
+ raise ValueError (f"Number of dims is too small for ndim_supp of { ndim_supp } " )
717
725
model = modelcontext (None )
718
- inferred_support_shape = at .stack (
719
- [
720
- model .dim_lengths [dims [i ]] - support_shape_offset [i ] # type: ignore
721
- for i in np .arange (- ndim_supp , 0 )
722
- ]
723
- )
726
+ inferred_support_shape = [
727
+ model .dim_lengths [dims [i ]] - support_shape_offset [i ] # type: ignore
728
+ for i in np .arange (- ndim_supp , 0 )
729
+ ]
724
730
725
731
if inferred_support_shape is None and observed is not None :
726
732
observed = convert_observed_data (observed )
727
- inferred_support_shape = at .stack (
728
- [observed .shape [i ] - support_shape_offset [i ] for i in np .arange (- ndim_supp , 0 )]
729
- )
733
+ if observed .ndim < ndim_supp :
734
+ raise ValueError (
735
+ f"Number of observed dimensions is too small for ndim_supp of { ndim_supp } "
736
+ )
737
+ inferred_support_shape = [
738
+ observed .shape [i ] - support_shape_offset [i ] for i in np .arange (- ndim_supp , 0 )
739
+ ]
730
740
731
- if inferred_support_shape is None :
741
+ # We did not learn anything
742
+ if inferred_support_shape is None and support_shape is None :
743
+ return None
744
+ # Only source of information was the originally provided support_shape
745
+ elif inferred_support_shape is None :
732
746
inferred_support_shape = support_shape
733
- # If there are two sources of information for the support shapes, assert they are consistent:
747
+ # There were two sources of support_shape, make sure they are consistent
734
748
elif support_shape is not None :
735
- inferred_support_shape = at .stack (
736
- [
749
+ inferred_support_shape = [
750
+ cast (
751
+ Variable ,
737
752
Assert (msg = "support_shape does not match respective shape dimension" )(
738
753
inferred , at .eq (inferred , explicit )
739
- )
740
- for inferred , explicit in zip ( inferred_support_shape , support_shape )
741
- ]
742
- )
754
+ ),
755
+ )
756
+ for inferred , explicit in zip ( inferred_support_shape , support_shape )
757
+ ]
743
758
744
- return inferred_support_shape
759
+ return at . stack ( inferred_support_shape )
745
760
746
761
747
762
def get_support_shape_1d (
@@ -751,21 +766,18 @@ def get_support_shape_1d(
751
766
dims : Optional [Dims ] = None ,
752
767
observed : Optional [Any ] = None ,
753
768
support_shape_offset : int = 0 ,
754
- ):
769
+ ) -> Optional [ TensorVariable ] :
755
770
"""Helper function for cases when you just care about one dimension."""
756
- if support_shape is not None :
757
- support_shape_tuple = (support_shape ,)
758
- else :
759
- support_shape_tuple = None
760
-
761
771
support_shape_tuple = get_support_shape (
762
- support_shape_tuple ,
772
+ support_shape = ( support_shape ,) if support_shape is not None else None ,
763
773
shape = shape ,
764
774
dims = dims ,
765
775
observed = observed ,
766
776
support_shape_offset = (support_shape_offset ,),
767
777
)
768
- if support_shape_tuple is not None :
769
- (support_shape ,) = support_shape_tuple
770
778
771
- return support_shape
779
+ if support_shape_tuple is not None :
780
+ (support_shape_ ,) = support_shape_tuple
781
+ return support_shape_
782
+ else :
783
+ return None
0 commit comments