Skip to content

Commit 918d56a

Browse files
committed
Coerce dtype __props__ to string due to invalid hash of np.dtype() objects
numpy/numpy#17864
1 parent 92eef5e commit 918d56a

File tree

6 files changed

+25
-6
lines changed

6 files changed

+25
-6
lines changed

pytensor/sparse/sandbox/sp2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class Binomial(Op):
9696

9797
def __init__(self, format, dtype):
9898
self.format = format
99-
self.dtype = dtype
99+
self.dtype = np.dtype(dtype).name
100100

101101
def make_node(self, n, p, shape):
102102
n = pt.as_tensor_variable(n)

pytensor/tensor/basic.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,8 @@ class Tri(Op):
10901090
def __init__(self, dtype=None):
10911091
if dtype is None:
10921092
dtype = config.floatX
1093+
else:
1094+
dtype = np.dtype(dtype).name
10931095
self.dtype = dtype
10941096

10951097
def make_node(self, N, M, k):
@@ -1368,6 +1370,8 @@ class Eye(Op):
13681370
def __init__(self, dtype=None):
13691371
if dtype is None:
13701372
dtype = config.floatX
1373+
else:
1374+
dtype = np.dtype(dtype).name
13711375
self.dtype = dtype
13721376

13731377
def make_node(self, n, m, k):
@@ -3225,7 +3229,7 @@ class ARange(COp):
32253229
__props__ = ("dtype",)
32263230

32273231
def __init__(self, dtype):
3228-
self.dtype = dtype
3232+
self.dtype = np.dtype(dtype).name
32293233

32303234
def make_node(self, start, stop, step):
32313235
from math import ceil
@@ -3407,7 +3411,8 @@ def arange(start, stop=None, step=1, dtype=None):
34073411
# We use the same dtype as numpy instead of the result of
34083412
# the upcast.
34093413
dtype = str(numpy_dtype)
3410-
3414+
else:
3415+
dtype = np.dtype(dtype).name
34113416
if dtype not in _arange:
34123417
_arange[dtype] = ARange(dtype)
34133418
return _arange[dtype](start, stop, step)

pytensor/tensor/elemwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,8 +1234,8 @@ def __init__(
12341234
else:
12351235
self.axis = tuple(axis)
12361236

1237-
self.dtype = dtype
1238-
self.acc_dtype = acc_dtype
1237+
self.dtype = np.dtype(dtype).name
1238+
self.acc_dtype = np.dytpe(acc_dtype).name
12391239
self.upcast_discrete_output = upcast_discrete_output
12401240

12411241
@property

pytensor/tensor/io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class LoadFromDisk(Op):
2525
__props__ = ("dtype", "shape", "mmap_mode")
2626

2727
def __init__(self, dtype, shape, mmap_mode=None):
28-
self.dtype = np.dtype(dtype) # turn "float64" into np.float64
28+
self.dtype = np.dtype(dtype).name
2929
self.shape = shape
3030
if mmap_mode not in (None, "c"):
3131
raise ValueError(

pytensor/tensor/random/op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def __init__(
112112
else:
113113
self.signature = safe_signature(self.ndims_params, [self.ndim_supp])
114114

115+
if dtype is not None:
116+
dtype = np.dtype(dtype).name
115117
self.dtype = dtype or getattr(self, "dtype", None)
116118

117119
self.inplace = (

tests/tensor/test_basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2869,6 +2869,18 @@ def test_static_shape(self):
28692869
assert np.arange(1.3, 17.48, 2.67).shape == arange(1.3, 17.48, 2.67).type.shape
28702870
assert np.arange(-64, 64).shape == arange(-64, 64).type.shape
28712871

2872+
def test_c_cache_bug(self):
2873+
# Regression test for bug caused by issues in hash of `np.dtype()` objects
2874+
# https://github.com/numpy/numpy/issues/17864
2875+
end = iscalar("end")
2876+
arange1 = ARange(np.dtype("float64"))(0, end, 1)
2877+
arange2 = ARange("float64")(0, end + 1, 1)
2878+
assert arange1.owner.op == arange2.owner.op
2879+
assert hash(arange1.owner.op) == hash(arange2.owner.op)
2880+
fn = function([end], [arange1, arange2])
2881+
res1, res2 = fn(10)
2882+
np.testing.assert_array_equal(res1, res2[:-1], strict=True)
2883+
28722884

28732885
class TestNdGrid:
28742886
def setup_method(self):

0 commit comments

Comments
 (0)