Skip to content

Commit 60e2510

Browse files
ricardoV94michaelosthege
authored andcommitted
Make Constant and Shared variables subclasses of the respective Variables
1 parent 97317a5 commit 60e2510

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

pytensor/sparse/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def pytensor_hash(self):
479479
return hash_from_sparse(d)
480480

481481

482-
class SparseConstant(TensorConstant, _sparse_py_operators):
482+
class SparseConstant(SparseVariable, TensorConstant):
483483
format = property(lambda self: self.type.format)
484484

485485
def signature(self):

pytensor/sparse/sharedvar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import scipy.sparse
44

55
from pytensor.compile import shared_constructor
6-
from pytensor.sparse.basic import SparseTensorType, _sparse_py_operators
6+
from pytensor.sparse.basic import SparseTensorType, SparseVariable
77
from pytensor.tensor.sharedvar import TensorSharedVariable
88

99

10-
class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators):
10+
class SparseTensorSharedVariable(TensorSharedVariable, SparseVariable):
1111
@property
1212
def format(self):
1313
return self.type.format

pytensor/tensor/sharedvar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytensor.misc.safe_asarray import _asarray
77
from pytensor.tensor import _get_vector_length
88
from pytensor.tensor.type import TensorType
9-
from pytensor.tensor.variable import _tensor_py_operators
9+
from pytensor.tensor.variable import TensorVariable
1010

1111

1212
def __getattr__(name):
@@ -31,7 +31,7 @@ def load_shared_variable(val):
3131
return tensor_constructor(val)
3232

3333

34-
class TensorSharedVariable(_tensor_py_operators, SharedVariable):
34+
class TensorSharedVariable(SharedVariable, TensorVariable):
3535
def zero(self, borrow: bool = False):
3636
r"""Set the values of a shared variable to 0.
3737

0 commit comments

Comments
 (0)