Skip to content

Commit 4d0aa3f

Browse files
committed
Return from scalar constants in get_unique_constant_value
1 parent 7a0ea76 commit 4d0aa3f

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

pytensor/tensor/variable.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1045,11 +1045,13 @@ def get_unique_constant_value(x: TensorVariable) -> Number | None:
10451045
if isinstance(x, Constant):
10461046
data = x.data
10471047

1048-
if isinstance(data, np.ndarray) and data.ndim > 0:
1048+
if isinstance(data, np.ndarray) and data.size > 0:
1049+
if data.size == 1:
1050+
return data.squeeze()
1051+
10491052
flat_data = data.ravel()
1050-
if flat_data.shape[0]:
1051-
if (flat_data == flat_data[0]).all():
1052-
return flat_data[0]
1053+
if (flat_data == flat_data[0]).all():
1054+
return flat_data[0]
10531055

10541056
return None
10551057

tests/scan/test_printing.py

+16-18
Original file line numberDiff line numberDiff line change
@@ -654,24 +654,22 @@ def no_shared_fn(n, x_tm1, M):
654654
Inner graphs:
655655
656656
Scan{scan_fn, while_loop=False, inplace=all} [id A]
657-
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
658-
├─ 0 [id J]
659-
├─ Subtensor{i, j, k} [id K]
660-
│ ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
661-
│ ├─ ScalarFromTensor [id M]
662-
│ │ └─ *0-<Scalar(int64, shape=())> [id N] -> [id C] (inner_in_seqs-0)
663-
│ ├─ ScalarFromTensor [id O]
664-
│ │ └─ *1-<Scalar(int64, shape=())> [id P] -> [id D] (inner_in_sit_sot-0)
665-
│ └─ 0 [id Q]
666-
└─ 1 [id R]
667-
668-
Composite{switch(lt(i0, i1), i2, i0)} [id I]
669-
← Switch [id S] 'o0'
670-
├─ LT [id T]
671-
│ ├─ i0 [id U]
672-
│ └─ i1 [id V]
673-
├─ i2 [id W]
674-
└─ i0 [id U]
657+
← Composite{switch(lt(0, i0), 1, 0)} [id I] (inner_out_sit_sot-0)
658+
└─ Subtensor{i, j, k} [id J]
659+
├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id K] -> [id H] (inner_in_non_seqs-0)
660+
├─ ScalarFromTensor [id L]
661+
│ └─ *0-<Scalar(int64, shape=())> [id M] -> [id C] (inner_in_seqs-0)
662+
├─ ScalarFromTensor [id N]
663+
│ └─ *1-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_sit_sot-0)
664+
└─ 0 [id P]
665+
666+
Composite{switch(lt(0, i0), 1, 0)} [id I]
667+
← Switch [id Q] 'o0'
668+
├─ LT [id R]
669+
│ ├─ 0 [id S]
670+
│ └─ i0 [id T]
671+
├─ 1 [id U]
672+
└─ 0 [id S]
675673
"""
676674

677675
output_str = debugprint(out, file="str", print_op_info=True)

0 commit comments

Comments
 (0)