44
44
)
45
45
from pytensor .scalar .basic import add as add_as
46
46
from pytensor .tensor .elemwise import CAReduce , DimShuffle , Elemwise
47
- from pytensor .tensor .math import MaxAndArgmax , MulWithoutZeros , Sum
47
+ from pytensor .tensor .math import Argmax , Max , MulWithoutZeros , Sum
48
48
from pytensor .tensor .special import LogSoftmax , Softmax , SoftmaxGrad
49
49
from pytensor .tensor .type import scalar
50
50
@@ -985,8 +985,78 @@ def log_softmax_py_fn(x):
985
985
return log_softmax
986
986
987
987
988
- @numba_funcify .register (MaxAndArgmax )
989
- def numba_funcify_MaxAndArgmax (op , node , ** kwargs ):
988
+ # @numba_funcify.register(Max)
989
+ # @numba_funcify.register(Argmax)
990
+ # # @numba_funcify.register(MaxandArgmax)
991
+ # def numba_funcify_MaxAndArgmax(op, node, **kwargs):
992
+ # axis = op.axis
993
+ # x_at = node.inputs[0]
994
+ # x_dtype = x_at.type.numpy_dtype
995
+ # x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
996
+ # x_ndim = x_at.ndim
997
+
998
+ # if x_ndim == 0:
999
+
1000
+ # @numba_basic.numba_njit(inline="always")
1001
+ # def maxandargmax(x):
1002
+ # return x, 0
1003
+
1004
+ # else:
1005
+ # axes = tuple(int(ax) for ax in axis)
1006
+
1007
+ # # NumPy does not support multiple axes for argmax; this is a
1008
+ # # work-around
1009
+ # keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
1010
+
1011
+ # reduce_max_py_fn = create_multiaxis_reducer(
1012
+ # scalar_maximum,
1013
+ # -np.inf,
1014
+ # axes,
1015
+ # x_ndim,
1016
+ # x_dtype,
1017
+ # return_scalar=False,
1018
+ # )
1019
+ # reduce_max = jit_compile_reducer(
1020
+ # Apply(node.op, node.inputs, [node.outputs[0].clone()]),
1021
+ # reduce_max_py_fn,
1022
+ # reduce_to_scalar=False,
1023
+ # )
1024
+
1025
+ # reduced_x_ndim = x_ndim - len(axes) + 1
1026
+ # argmax_axis = create_axis_apply_fn(
1027
+ # np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
1028
+ # )
1029
+
1030
+ # reaxis_order = keep_axes + axes
1031
+ # sl1 = slice(None, len(keep_axes))
1032
+ # sl2 = slice(len(keep_axes), None)
1033
+
1034
+ # @numba_basic.numba_njit
1035
+ # def maxandargmax(x):
1036
+ # max_res = reduce_max(x)
1037
+
1038
+ # # Not-reduced axes in front
1039
+ # transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
1040
+ # kept_shape = transposed_x.shape[sl1]
1041
+ # reduced_shape = transposed_x.shape[sl2]
1042
+ # reduced_size = 1
1043
+ # for s in reduced_shape:
1044
+ # reduced_size *= s
1045
+
1046
+ # # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
1047
+ # # Otherwise reshape would complain citing float arg
1048
+ # new_shape = (*kept_shape, reduced_size)
1049
+ # reshaped_x = transposed_x.reshape(new_shape)
1050
+
1051
+ # max_idx_res = argmax_axis(reshaped_x)
1052
+
1053
+ # return max_res, max_idx_res
1054
+
1055
+ # return maxandargmax
1056
+
1057
+
1058
+ @numba_funcify .register (Max )
1059
+ def numba_funcify_Max (op , node , ** kwargs ):
990
1060
axis = op .axis
991
1061
x_at = node .inputs [0 ]
992
1062
x_dtype = x_at .type .numpy_dtype
@@ -996,15 +1066,15 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
996
1066
if x_ndim == 0 :
997
1067
998
1068
@numba_basic .numba_njit (inline = "always" )
999
- def maxandargmax (x ):
1000
- return x , 0
1069
+ def max (x ):
1070
+ return x
1001
1071
1002
1072
else :
1003
1073
axes = tuple (int (ax ) for ax in axis )
1004
1074
1005
1075
# NumPy does not support multiple axes for argmax; this is a
1006
1076
# work-around
1007
- keep_axes = tuple (i for i in range (x_ndim ) if i not in axes )
1077
+ # keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
1008
1078
1009
1079
reduce_max_py_fn = create_multiaxis_reducer (
1010
1080
scalar_maximum ,
@@ -1020,6 +1090,50 @@ def maxandargmax(x):
1020
1090
reduce_to_scalar = False ,
1021
1091
)
1022
1092
1093
+ @numba_basic .numba_njit
1094
+ def max (x ):
1095
+ max_res = reduce_max (x )
1096
+
1097
+ return max_res
1098
+
1099
+ return max
1100
+
1101
+
1102
+ @numba_funcify .register (Argmax )
1103
+ def numba_funcify_Argmax (op , node , ** kwargs ):
1104
+ axis = op .axis
1105
+ x_at = node .inputs [0 ]
1106
+ x_dtype = x_at .type .numpy_dtype
1107
+ x_dtype = numba .np .numpy_support .from_dtype (x_dtype )
1108
+ x_ndim = x_at .ndim
1109
+
1110
+ if x_ndim == 0 :
1111
+
1112
+ @numba_basic .numba_njit (inline = "always" )
1113
+ def argmax (x ):
1114
+ return 0
1115
+
1116
+ else :
1117
+ axes = tuple (int (ax ) for ax in axis )
1118
+
1119
+ # NumPy does not support multiple axes for argmax; this is a
1120
+ # work-around
1121
+ keep_axes = tuple (i for i in range (x_ndim ) if i not in axes )
1122
+
1123
+ # reduce_max_py_fn = create_multiaxis_reducer(
1124
+ # scalar_maximum,
1125
+ # -np.inf,
1126
+ # axes,
1127
+ # x_ndim,
1128
+ # x_dtype,
1129
+ # return_scalar=False,
1130
+ # )
1131
+ # reduce_max = jit_compile_reducer(
1132
+ # Apply(node.op, node.inputs, [node.outputs[0].clone()]),
1133
+ # reduce_max_py_fn,
1134
+ # reduce_to_scalar=False,
1135
+ # )
1136
+
1023
1137
reduced_x_ndim = x_ndim - len (axes ) + 1
1024
1138
argmax_axis = create_axis_apply_fn (
1025
1139
np .argmax , reduced_x_ndim - 1 , reduced_x_ndim , np .int64
@@ -1030,9 +1144,7 @@ def maxandargmax(x):
1030
1144
sl2 = slice (len (keep_axes ), None )
1031
1145
1032
1146
@numba_basic .numba_njit
1033
- def maxandargmax (x ):
1034
- max_res = reduce_max (x )
1035
-
1147
+ def argmax (x ):
1036
1148
# Not-reduced axes in front
1037
1149
transposed_x = np .ascontiguousarray (np .transpose (x , reaxis_order ))
1038
1150
kept_shape = transposed_x .shape [sl1 ]
@@ -1048,6 +1160,6 @@ def maxandargmax(x):
1048
1160
1049
1161
max_idx_res = argmax_axis (reshaped_x )
1050
1162
1051
- return max_res , max_idx_res
1163
+ return max_idx_res
1052
1164
1053
- return maxandargmax
1165
+ return argmax
0 commit comments