Skip to content

Commit bca9a38

Browse files
ricardoV94 authored and brandonwillard committed
Move test helper softmax_graph to test module
1 parent 4a1010e commit bca9a38

File tree

4 files changed

+6
-7
lines changed

4 files changed

+6
-7
lines changed

aesara/tensor/nnet/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
sigmoid_binary_crossentropy,
3636
softmax,
3737
softmax_grad_legacy,
38-
softmax_graph,
3938
softmax_legacy,
4039
softmax_simplifier,
4140
softmax_with_bias,

aesara/tensor/nnet/basic.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,10 +1121,6 @@ def local_logsoftmax_grad(fgraph, node):
11211121
return [ret]
11221122

11231123

1124-
def softmax_graph(c):
1125-
return exp(c) / exp(c).sum(axis=-1, keepdims=True)
1126-
1127-
11281124
UNSET_AXIS = object()
11291125

11301126

tests/scan/test_basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from aesara.tensor.math import dot, mean, sigmoid
5757
from aesara.tensor.math import sum as aet_sum
5858
from aesara.tensor.math import tanh
59-
from aesara.tensor.nnet import categorical_crossentropy, softmax_graph
59+
from aesara.tensor.nnet import categorical_crossentropy
6060
from aesara.tensor.random.utils import RandomStream
6161
from aesara.tensor.shape import Shape_i, reshape, shape, specify_shape
6262
from aesara.tensor.sharedvar import SharedVariable
@@ -81,6 +81,7 @@
8181
vector,
8282
)
8383
from tests import unittest_tools as utt
84+
from tests.tensor.nnet.test_basic import softmax_graph
8485

8586

8687
if config.mode == "FAST_COMPILE":

tests/tensor/nnet/test_basic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
sigmoid_binary_crossentropy,
5353
softmax,
5454
softmax_grad_legacy,
55-
softmax_graph,
5655
softmax_legacy,
5756
softmax_with_bias,
5857
softsign,
@@ -83,6 +82,10 @@
8382
)
8483

8584

85+
def softmax_graph(c):
86+
return exp(c) / exp(c).sum(axis=-1, keepdims=True)
87+
88+
8689
def valid_axis_tester(Op):
8790
with pytest.raises(TypeError):
8891
Op(1.5)

0 commit comments

Comments (0)