@@ -2080,55 +2080,36 @@ def is_super_shape(var1, var2):
2080
2080
2081
2081
2082
2082
class TestMatrixVectorOps :
2083
- def test_vecdot (self ):
2084
- """Test vecdot function with various input shapes."""
2085
- rng = np .random .default_rng (seed = utt .fetch_seed ())
2086
-
2087
- # Test vector-vector
2088
- x = vector ()
2089
- y = vector ()
2090
- z = vecdot (x , y )
2091
- f = function ([x , y ], z )
2092
- x_val = random (5 , rng = rng ).astype (config .floatX )
2093
- y_val = random (5 , rng = rng ).astype (config .floatX )
2094
- np .testing .assert_allclose (f (x_val , y_val ), np .dot (x_val , y_val ))
2095
-
2096
- # Test batched vectors
2097
- x = tensor3 ()
2098
- y = tensor3 ()
2099
- z = vecdot (x , y )
2100
- f = function ([x , y ], z )
2101
-
2102
- x_val = random (2 , 3 , 4 , rng = rng ).astype (config .floatX )
2103
- y_val = random (2 , 3 , 4 , rng = rng ).astype (config .floatX )
2104
- expected = np .sum (x_val * y_val , axis = - 1 )
2105
- np .testing .assert_allclose (f (x_val , y_val ), expected )
2083
+ """Test vecdot, matvec, and vecmat helper functions."""
2106
2084
2107
2085
@pytest .mark .parametrize (
2108
- "func,x_shape,y_shape,make_expected " ,
2086
+ "func,x_shape,y_shape,np_func,batch_axis " ,
2109
2087
[
2110
- # matvec tests - Matrix(M,K) @ Vector(K) -> Vector(M)
2111
- (matvec , (3 , 4 ), (4 ,), lambda x , y : np .dot (x , y )),
2112
- # matvec batch tests - Tensor3(B,M,K) @ Matrix(B,K) -> Matrix(B,M)
2088
+ # vecdot
2089
+ (vecdot , (5 ,), (5 ,), lambda x , y : np .dot (x , y ), None ),
2090
+ (vecdot , (3 , 5 ), (3 , 5 ), lambda x , y : np .sum (x * y , axis = - 1 ), - 1 ),
2091
+ # matvec
2092
+ (matvec , (3 , 4 ), (4 ,), lambda x , y : np .dot (x , y ), None ),
2113
2093
(
2114
2094
matvec ,
2115
2095
(2 , 3 , 4 ),
2116
2096
(2 , 4 ),
2117
2097
lambda x , y : np .array ([np .dot (x [i ], y [i ]) for i in range (len (x ))]),
2098
+ 0 ,
2118
2099
),
2119
- # vecmat tests - Vector(K) @ Matrix(K,N) -> Vector(N)
2120
- (vecmat , (3 ,), (3 , 4 ), lambda x , y : np .dot (x , y )),
2121
- # vecmat batch tests - Matrix(B,K) @ Tensor3(B,K,N) -> Matrix(B,N)
2100
+ # vecmat
2101
+ (vecmat , (3 ,), (3 , 4 ), lambda x , y : np .dot (x , y ), None ),
2122
2102
(
2123
2103
vecmat ,
2124
2104
(2 , 3 ),
2125
2105
(2 , 3 , 4 ),
2126
2106
lambda x , y : np .array ([np .dot (x [i ], y [i ]) for i in range (len (x ))]),
2107
+ 0 ,
2127
2108
),
2128
2109
],
2129
2110
)
2130
- def test_mat_vec_ops (self , func , x_shape , y_shape , make_expected ):
2131
- """Parametrized test for matvec and vecmat functions."""
2111
+ def test_matrix_vector_ops (self , func , x_shape , y_shape , np_func , batch_axis ):
2112
+ """Test all matrix-vector helper functions."""
2132
2113
rng = np .random .default_rng (seed = utt .fetch_seed ())
2133
2114
2134
2115
# Create PyTensor variables with appropriate dimensions
@@ -2146,18 +2127,25 @@ def test_mat_vec_ops(self, func, x_shape, y_shape, make_expected):
2146
2127
else :
2147
2128
y = tensor3 ()
2148
2129
2149
- # Apply the function
2130
+ # Test basic functionality
2150
2131
z = func (x , y )
2151
2132
f = function ([x , y ], z )
2152
2133
2153
- # Create random values
2154
2134
x_val = random (* x_shape , rng = rng ).astype (config .floatX )
2155
2135
y_val = random (* y_shape , rng = rng ).astype (config .floatX )
2156
2136
2157
- # Compare with the expected result
2158
- expected = make_expected (x_val , y_val )
2137
+ expected = np_func (x_val , y_val )
2159
2138
np .testing .assert_allclose (f (x_val , y_val ), expected )
2160
2139
2140
+ # Test with dtype parameter (to improve code coverage)
2141
+ # Use float64 to ensure we can detect the difference
2142
+ z_dtype = func (x , y , dtype = "float64" )
2143
+ f_dtype = function ([x , y ], z_dtype )
2144
+
2145
+ result = f_dtype (x_val , y_val )
2146
+ assert result .dtype == np .float64
2147
+ np .testing .assert_allclose (result , expected )
2148
+
2161
2149
2162
2150
class TestTensordot :
2163
2151
def TensorDot (self , axes ):
0 commit comments