@@ -716,6 +716,32 @@ def test_masked_array_not_implemented(
716
716
ptb .as_tensor (x )
717
717
718
718
719
+ def check_alloc_runtime_broadcast (mode ):
720
+ """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
721
+ floatX = config .floatX
722
+ x_v = vector ("x" , shape = (None ,))
723
+
724
+ out = alloc (x_v , 5 , 3 )
725
+ f = pytensor .function ([x_v ], out , mode = mode )
726
+ TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
727
+
728
+ np .testing .assert_array_equal (
729
+ f (x = np .zeros ((3 ,), dtype = floatX )),
730
+ np .zeros ((5 , 3 ), dtype = floatX ),
731
+ )
732
+ with pytest .raises (ValueError , match = "Runtime broadcasting not allowed" ):
733
+ f (x = np .zeros ((1 ,), dtype = floatX ))
734
+
735
+ out = alloc (specify_shape (x_v , (1 ,)), 5 , 3 )
736
+ f = pytensor .function ([x_v ], out , mode = mode )
737
+ TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
738
+
739
+ np .testing .assert_array_equal (
740
+ f (x = np .zeros ((1 ,), dtype = floatX )),
741
+ np .zeros ((5 , 3 ), dtype = floatX ),
742
+ )
743
+
744
+
719
745
class TestAlloc :
720
746
dtype = config .floatX
721
747
mode = mode_opt
@@ -729,32 +755,6 @@ def check_allocs_in_fgraph(fgraph, n):
729
755
== n
730
756
)
731
757
732
- @staticmethod
733
- def check_runtime_broadcast (mode ):
734
- """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
735
- floatX = config .floatX
736
- x_v = vector ("x" , shape = (None ,))
737
-
738
- out = alloc (x_v , 5 , 3 )
739
- f = pytensor .function ([x_v ], out , mode = mode )
740
- TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
741
-
742
- np .testing .assert_array_equal (
743
- f (x = np .zeros ((3 ,), dtype = floatX )),
744
- np .zeros ((5 , 3 ), dtype = floatX ),
745
- )
746
- with pytest .raises (ValueError , match = "Runtime broadcasting not allowed" ):
747
- f (x = np .zeros ((1 ,), dtype = floatX ))
748
-
749
- out = alloc (specify_shape (x_v , (1 ,)), 5 , 3 )
750
- f = pytensor .function ([x_v ], out , mode = mode )
751
- TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
752
-
753
- np .testing .assert_array_equal (
754
- f (x = np .zeros ((1 ,), dtype = floatX )),
755
- np .zeros ((5 , 3 ), dtype = floatX ),
756
- )
757
-
758
758
def setup_method (self ):
759
759
self .rng = np .random .default_rng (seed = utt .fetch_seed ())
760
760
@@ -911,7 +911,7 @@ def test_alloc_of_view_linker(self):
911
911
912
912
@pytest .mark .parametrize ("mode" , (Mode ("py" ), Mode ("c" )))
913
913
def test_runtime_broadcast (self , mode ):
914
- self . check_runtime_broadcast (mode )
914
+ check_alloc_runtime_broadcast (mode )
915
915
916
916
917
917
def test_infer_static_shape ():
0 commit comments