
Commit 4a1010e

ricardoV94 authored and brandonwillard committed

Add SoftmaxGrad numba dispatch

1 parent f06146a

File tree: 2 files changed, +71 −1 lines changed

aesara/link/numba/dispatch/elemwise.py
Lines changed: 26 additions & 1 deletion

@@ -21,7 +21,7 @@
 )
 from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from aesara.tensor.math import MaxAndArgmax
-from aesara.tensor.nnet.basic import LogSoftmax, Softmax
+from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
 from aesara.tensor.type import tensor


@@ -424,6 +424,31 @@ def softmax(x):
     return softmax


+@numba_funcify.register(SoftmaxGrad)
+def numba_funcify_SoftmaxGrad(op, node, **kwargs):
+
+    sm_at = node.inputs[1]
+    sm_dtype = sm_at.type.numpy_dtype
+    sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
+
+    axis = op.axis
+    if axis is not None:
+        reduce_sum = create_axis_reducer(
+            np.add, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
+        )
+    else:
+        reduce_sum = np.sum
+
+    @numba.njit
+    def softmax_grad(dy, sm):
+        dy_times_sm = dy * sm
+        sum_dy_times_sm = reduce_sum(dy_times_sm)
+        dx = dy_times_sm - sum_dy_times_sm * sm
+        return dx
+
+    return softmax_grad
+
+
 @numba_funcify.register(LogSoftmax)
 def numba_funcify_LogSoftmax(op, node, **kwargs):
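As an aside (not part of the commit), the jitted `softmax_grad` closure above computes the standard softmax vector-Jacobian product: with `sm` the softmax output and `dy` the incoming gradient, `dx = dy * sm - sum(dy * sm, axis, keepdims=True) * sm`. A minimal plain-NumPy sketch of the same computation, assuming a 2-D input and `axis=-1`; the function name and sample values here are illustrative only:

# Illustrative NumPy-only sketch, not part of this commit.
import numpy as np

def softmax_grad_reference(dy, sm, axis=-1):
    # Vector-Jacobian product of softmax: sm * (dy - sum(dy * sm, axis))
    dy_times_sm = dy * sm
    sum_dy_times_sm = np.sum(dy_times_sm, axis=axis, keepdims=True)
    return dy_times_sm - sum_dy_times_sm * sm

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3))
sm = np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)  # rows sum to 1
dy = rng.normal(size=(2, 3))
dx = softmax_grad_reference(dy, sm)  # same shape as sm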

tests/link/test_numba.py
Lines changed: 45 additions & 0 deletions

@@ -1893,6 +1893,51 @@ def test_Dot(x, y, exc):
     )


+@pytest.mark.parametrize(
+    "dy, sm, axis, exc",
+    [
+        (
+            set_test_value(
+                aet.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+            ),
+            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            None,
+            None,
+        ),
+        (
+            set_test_value(
+                aet.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+            ),
+            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            0,
+            None,
+        ),
+        (
+            set_test_value(
+                aet.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+            ),
+            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            1,
+            None,
+        ),
+    ],
+)
+def test_SoftmaxGrad(dy, sm, axis, exc):
+    g = nnetb.SoftmaxGrad(axis=axis)(dy, sm)
+    g_fg = FunctionGraph(outputs=[g])
+
+    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
+    with cm:
+        compare_numba_and_py(
+            g_fg,
+            [
+                i.tag.test_value
+                for i in g_fg.inputs
+                if not isinstance(i, (SharedVariable, Constant))
+            ],
+        )
+
+
 @pytest.mark.parametrize(
     "x, axis, exc",
     [
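To run only the new test locally, the usual pytest name filter should work (illustrative invocation, not part of the commit):

pytest tests/link/test_numba.py -k test_SoftmaxGrad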
