@@ -4263,7 +4263,7 @@ def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None)
4263
4263
if not matrices :
4264
4264
raise ValueError ("no matrices to allocate" )
4265
4265
dtype = largest_common_dtype (matrices )
4266
- matrices = list (map (pytensor . tensor . as_tensor , matrices ))
4266
+ matrices = list (map (as_sparse_or_tensor_variable , matrices ))
4267
4267
4268
4268
if any (mat .type .ndim != 2 for mat in matrices ):
4269
4269
raise TypeError ("all data arguments must be matrices" )
@@ -4273,7 +4273,7 @@ def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None)
4273
4273
4274
4274
def perform (self , node , inputs , output_storage , params = None ):
4275
4275
format = node .outputs [0 ].type .format
4276
- dtype = largest_common_dtype ( inputs )
4276
+ dtype = node . outputs [ 0 ]. type . dtype
4277
4277
output_storage [0 ][0 ] = scipy .sparse .block_diag (inputs , format = format ).astype (
4278
4278
dtype
4279
4279
)
@@ -4296,9 +4296,12 @@ def block_diag(
4296
4296
4297
4297
Parameters
4298
4298
----------
4299
- A, B, C ... : tensors
4300
- Input sparse matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions,
4299
+ A, B, C ... : tensors or array-like
4300
+ Inputs to form the block diagonal matrix. Each input should have the same number of dimensions,
4301
4301
and the block diagonal matrix will be formed using the right-most two dimensions of each input matrix.
4302
+
4303
+ Note that the input matrices need not be sparse themselves, and will be automatically converted to the
4304
+ requested format if they are not.
4302
4305
format: str, optional
4303
4306
The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csr'. Ignored if sparse=False.
4304
4307
name: str, optional
@@ -4321,9 +4324,15 @@ def block_diag(
4321
4324
A = csr_matrix([[1, 2], [3, 4]])
4322
4325
B = csr_matrix([[5, 6], [7, 8]])
4323
4326
result_sparse = block_diag(A, B, format='csr', name='X')
4324
- print(result_sparse.eval())
4325
4327
4326
- The resulting sparse block diagonal matrix `result_sparse` is in CSR format.
4328
+ print(result_sparse)
4329
+ >>> SparseVariable{csr,int32}
4330
+
4331
+ print(result_sparse.toarray().eval())
4332
+ >>> array([[1, 2, 0, 0],
4333
+ >>> [3, 4, 0, 0],
4334
+ >>> [0, 0, 5, 6],
4335
+ >>> [0, 0, 7, 8]])
4327
4336
"""
4328
4337
if len (matrices ) == 1 :
4329
4338
return matrices
0 commit comments