
Commit 8e7f626

Intermediate changes
1 parent 18085b8 commit 8e7f626

5 files changed, +259 -241 lines changed


pytensor/link/jax/dispatch/nlinalg.py

Lines changed: 63 additions & 8 deletions
@@ -2,7 +2,7 @@
 
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Dot, MaxAndArgmax
+from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -104,18 +104,73 @@ def batched_dot(a, b):
     return batched_dot
 
 
-@jax_funcify.register(MaxAndArgmax)
-def jax_funcify_MaxAndArgmax(op, **kwargs):
+# @jax_funcify.register(Max)
+# @jax_funcify.register(Argmax)
+# def jax_funcify_MaxAndArgmax(op, **kwargs):
+#     axis = op.axis
+
+#     def maxandargmax(x, axis=axis):
+#         if axis is None:
+#             axes = tuple(range(x.ndim))
+#         else:
+#             axes = tuple(int(ax) for ax in axis)
+
+#         max_res = jnp.max(x, axis)
+
+#         # NumPy does not support multiple axes for argmax; this is a
+#         # work-around
+#         keep_axes = jnp.array(
+#             [i for i in range(x.ndim) if i not in axes], dtype="int64"
+#         )
+#         # Not-reduced axes in front
+#         transposed_x = jnp.transpose(
+#             x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
+#         )
+#         kept_shape = transposed_x.shape[: len(keep_axes)]
+#         reduced_shape = transposed_x.shape[len(keep_axes) :]
+
+#         # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
+#         # Otherwise reshape would complain citing float arg
+#         new_shape = (
+#             *kept_shape,
+#             jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
+#         )
+#         reshaped_x = transposed_x.reshape(new_shape)
+
+#         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
+
+#         return max_res, max_idx_res
+
+#     return maxandargmax
+
+
+@jax_funcify.register(Max)
+def jax_funcify_Max(op, **kwargs):
     axis = op.axis
 
-    def maxandargmax(x, axis=axis):
+    def max(x, axis=axis):
+        # if axis is None:
+        #     axes = tuple(range(x.ndim))
+        # else:
+        #     axes = tuple(int(ax) for ax in axis)
+
+        max_res = jnp.max(x, axis)
+
+        return max_res
+
+    return max
+
+
+@jax_funcify.register(Argmax)
+def jax_funcify_Argmax(op, **kwargs):
+    axis = op.axis
+
+    def argmax(x, axis=axis):
         if axis is None:
             axes = tuple(range(x.ndim))
         else:
             axes = tuple(int(ax) for ax in axis)
 
-        max_res = jnp.max(x, axis)
-
         # NumPy does not support multiple axes for argmax; this is a
         # work-around
         keep_axes = jnp.array(
@@ -138,6 +193,6 @@ def maxandargmax(x, axis=axis):
 
         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
 
-        return max_res, max_idx_res
+        return max_idx_res
 
-    return maxandargmax
+    return argmax
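For context on the JAX backend change above: jnp.argmax (like np.argmax) only accepts a single axis, so jax_funcify_Argmax reduces over several axes by moving the kept axes to the front, flattening the reduced axes into one trailing dimension, and taking argmax along that last axis. The sketch below illustrates that strategy in isolation; it is not part of the commit, and argmax_over_axes is a hypothetical helper name.

import jax.numpy as jnp

def argmax_over_axes(x, axis=None):
    # Hypothetical standalone illustration of the transpose-and-flatten
    # strategy used above: kept axes in front, reduced axes collapsed
    # into a single trailing dimension, then a one-axis argmax.
    if axis is None:
        axes = tuple(range(x.ndim))
    else:
        axes = tuple(int(ax) for ax in axis)
    keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
    transposed = jnp.transpose(x, keep_axes + axes)
    kept_shape = transposed.shape[: len(keep_axes)]
    reduced_size = 1
    for s in transposed.shape[len(keep_axes) :]:
        reduced_size *= s
    reshaped = transposed.reshape((*kept_shape, reduced_size))
    return jnp.argmax(reshaped, axis=-1).astype("int64")

x = jnp.arange(24).reshape(2, 3, 4)
print(argmax_over_axes(x, axis=(1, 2)))  # [11 11]: flat indices into each collapsed (3, 4) block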

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 123 additions & 11 deletions
@@ -44,7 +44,7 @@
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
+from pytensor.tensor.math import Argmax, Max, MulWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from pytensor.tensor.type import scalar
 
@@ -985,8 +985,78 @@ def log_softmax_py_fn(x):
     return log_softmax
 
 
-@numba_funcify.register(MaxAndArgmax)
-def numba_funcify_MaxAndArgmax(op, node, **kwargs):
+# @numba_funcify.register(Max)
+# @numba_funcify.register(Argmax)
+# # @numba_funcify.register(MaxandArgmax)
+# def numba_funcify_MaxAndArgmax(op, node, **kwargs):
+#     axis = op.axis
+#     x_at = node.inputs[0]
+#     x_dtype = x_at.type.numpy_dtype
+#     x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
+#     x_ndim = x_at.ndim
+
+#     if x_ndim == 0:
+
+#         @numba_basic.numba_njit(inline="always")
+#         def maxandargmax(x):
+#             return x, 0
+
+#     else:
+#         axes = tuple(int(ax) for ax in axis)
+
+#         # NumPy does not support multiple axes for argmax; this is a
+#         # work-around
+#         keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
+
+#         reduce_max_py_fn = create_multiaxis_reducer(
+#             scalar_maximum,
+#             -np.inf,
+#             axes,
+#             x_ndim,
+#             x_dtype,
+#             return_scalar=False,
+#         )
+#         reduce_max = jit_compile_reducer(
+#             Apply(node.op, node.inputs, [node.outputs[0].clone()]),
+#             reduce_max_py_fn,
+#             reduce_to_scalar=False,
+#         )
+
+#         reduced_x_ndim = x_ndim - len(axes) + 1
+#         argmax_axis = create_axis_apply_fn(
+#             np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
+#         )
+
+#         reaxis_order = keep_axes + axes
+#         sl1 = slice(None, len(keep_axes))
+#         sl2 = slice(len(keep_axes), None)
+
+#         @numba_basic.numba_njit
+#         def maxandargmax(x):
+#             max_res = reduce_max(x)
+
+#             # Not-reduced axes in front
+#             transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
+#             kept_shape = transposed_x.shape[sl1]
+#             reduced_shape = transposed_x.shape[sl2]
+#             reduced_size = 1
+#             for s in reduced_shape:
+#                 reduced_size *= s
+
+#             # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
+#             # Otherwise reshape would complain citing float arg
+#             new_shape = (*kept_shape, reduced_size)
+#             reshaped_x = transposed_x.reshape(new_shape)
+
+#             max_idx_res = argmax_axis(reshaped_x)
+
+#             return max_res, max_idx_res
+
+#         return maxandargmax
+
+
+@numba_funcify.register(Max)
+def numba_funcify_Max(op, node, **kwargs):
     axis = op.axis
     x_at = node.inputs[0]
     x_dtype = x_at.type.numpy_dtype
@@ -996,15 +1066,15 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
     if x_ndim == 0:
 
         @numba_basic.numba_njit(inline="always")
-        def maxandargmax(x):
-            return x, 0
+        def max(x):
+            return x
 
     else:
         axes = tuple(int(ax) for ax in axis)
 
         # NumPy does not support multiple axes for argmax; this is a
         # work-around
-        keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
+        # keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
 
         reduce_max_py_fn = create_multiaxis_reducer(
             scalar_maximum,
@@ -1020,6 +1090,50 @@ def maxandargmax(x):
             reduce_to_scalar=False,
         )
 
+        @numba_basic.numba_njit
+        def max(x):
+            max_res = reduce_max(x)
+
+            return max_res
+
+    return max
+
+
+@numba_funcify.register(Argmax)
+def numba_funcify_Argmax(op, node, **kwargs):
+    axis = op.axis
+    x_at = node.inputs[0]
+    x_dtype = x_at.type.numpy_dtype
+    x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
+    x_ndim = x_at.ndim
+
+    if x_ndim == 0:
+
+        @numba_basic.numba_njit(inline="always")
+        def argmax(x):
+            return 0
+
+    else:
+        axes = tuple(int(ax) for ax in axis)
+
+        # NumPy does not support multiple axes for argmax; this is a
+        # work-around
+        keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
+
+        # reduce_max_py_fn = create_multiaxis_reducer(
+        #     scalar_maximum,
+        #     -np.inf,
+        #     axes,
+        #     x_ndim,
+        #     x_dtype,
+        #     return_scalar=False,
+        # )
+        # reduce_max = jit_compile_reducer(
+        #     Apply(node.op, node.inputs, [node.outputs[0].clone()]),
+        #     reduce_max_py_fn,
+        #     reduce_to_scalar=False,
+        # )
+
         reduced_x_ndim = x_ndim - len(axes) + 1
         argmax_axis = create_axis_apply_fn(
             np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
@@ -1030,9 +1144,7 @@ def maxandargmax(x):
         sl2 = slice(len(keep_axes), None)
 
         @numba_basic.numba_njit
-        def maxandargmax(x):
-            max_res = reduce_max(x)
-
+        def argmax(x):
             # Not-reduced axes in front
             transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
             kept_shape = transposed_x.shape[sl1]
@@ -1048,6 +1160,6 @@ def maxandargmax(x):
 
             max_idx_res = argmax_axis(reshaped_x)
 
-            return max_res, max_idx_res
+            return max_idx_res
 
-    return maxandargmax
+    return argmax
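The Numba version of numba_funcify_Argmax compiles the same work-around: np.argmax inside an njit function also handles only one axis, so the kept axes are moved to the front, the reduced axes are flattened into a single trailing dimension (with the size computed by an explicit loop so the code stays njit-friendly), and a one-axis argmax is applied. A minimal NumPy-only sketch of that strategy follows; it is not part of the commit, and argmax_over_axes_np is a hypothetical name.

import numpy as np

def argmax_over_axes_np(x, axes):
    # Hypothetical NumPy-only sketch of the reduction strategy compiled above.
    keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
    # Kept axes in front, then a contiguous copy so the reshape is well-defined,
    # mirroring np.ascontiguousarray(np.transpose(...)) in the kernel above.
    transposed = np.ascontiguousarray(np.transpose(x, keep_axes + tuple(axes)))
    kept_shape = transposed.shape[: len(keep_axes)]
    reduced_size = 1
    for s in transposed.shape[len(keep_axes) :]:
        reduced_size *= s
    reshaped = transposed.reshape((*kept_shape, reduced_size))
    return np.argmax(reshaped, axis=-1).astype(np.int64)

x = np.arange(24).reshape(2, 3, 4)
print(argmax_over_axes_np(x, axes=(0, 2)))  # [7 7 7]: flat indices into each collapsed (2, 4) block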
