Skip to content

Commit 96253d5

Browse files
Make as_symbolic work with sparse matrices
1 parent a1739f6 commit 96253d5

File tree

1 file changed

+7
-20
lines changed

1 file changed

+7
-20
lines changed

aesara/sparse/basic.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from numpy.lib.stride_tricks import as_strided
1515

1616
import aesara
17+
from aesara import _as_symbolic, as_symbolic
1718
from aesara import scalar as aes
1819
from aesara.configdefaults import config
1920
from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined
@@ -128,6 +129,11 @@ def _is_dense(x):
128129
return isinstance(x, np.ndarray)
129130

130131

132+
@_as_symbolic.register(scipy.sparse.base.spmatrix)
133+
def as_symbolic_sparse(x, **kwargs):
134+
return as_sparse_variable(x, **kwargs)
135+
136+
131137
def as_sparse_variable(x, name=None, ndim=None, **kwargs):
132138
"""
133139
Wrapper around SparseVariable constructor to construct
@@ -174,26 +180,7 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
174180
as_sparse = as_sparse_variable
175181

176182

177-
def as_sparse_or_tensor_variable(x, name=None):
178-
"""
179-
Same as `as_sparse_variable` but if we can't make a
180-
sparse variable, we try to make a tensor variable.
181-
182-
Parameters
183-
----------
184-
x
185-
A sparse matrix.
186-
187-
Returns
188-
-------
189-
SparseVariable or TensorVariable version of `x`
190-
191-
"""
192-
193-
try:
194-
return as_sparse_variable(x, name)
195-
except (ValueError, TypeError):
196-
return at.as_tensor_variable(x, name)
183+
as_sparse_or_tensor_variable = as_symbolic
197184

198185

199186
def constant(x, name=None):

0 commit comments

Comments
 (0)