
WIP: Add tt.nnet.softmax to math (#4226) #4229


Closed
wants to merge 3 commits

Conversation

ricardoV94
Member

This PR is a straightforward proposal to expose tt.nnet.softmax via the pm.math module.

This PR does not address the current deprecation warning when using the softmax function. Perhaps this should be solved before adding it?

I assumed that no new tests would be needed, since this is just a wrapper around the Theano function. Is there anything else that should be done?

@codecov

codecov bot commented Nov 16, 2020

Codecov Report

Merging #4229 (3b8db0f) into master (83b91d8) will decrease coverage by 0.00%.
The diff coverage is 50.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #4229      +/-   ##
==========================================
- Coverage   88.95%   88.94%   -0.01%     
==========================================
  Files          92       92              
  Lines       14806    14808       +2     
==========================================
+ Hits        13170    13171       +1     
- Misses       1636     1637       +1     
Impacted Files Coverage Δ
pymc3/math.py 68.27% <50.00%> (-0.20%) ⬇️

Contributor

@AlexAndorra AlexAndorra left a comment


Looks like a good start, thanks @ricardoV94 !

This PR does not address the current deprecation warning when using the softmax function. Perhaps this should be solved before adding it?

Yeah probably, but first: what do you mean by "the current deprecation warning"? And what version of Theano (or even Theano-PyMC) are you on?

I assumed that no new tests would be needed, since this is just a wrapper to the theano function. Is there anything else that should be done?

I think a few tests would be good to have: maybe just spot some tests where the theano softmax is used and replace it with the new pm.math.softmax?

@ricardoV94
Member Author

ricardoV94 commented Nov 17, 2020

Thanks for the feedback!

Yeah probably, but first: what do you mean by "the current deprecation warning"? And what version of Theano (or even Theano-PyMC) are you on?

I am using Theano-PyMC 1.0.11. The softmax usually gives this warning:

UserWarning: DEPRECATION: If x is a vector, Softmax will not automatically pad x anymore in next releases. If you need it, please do it manually. The vector case is gonna be supported soon and the output will be a vector.

It is raised in these two places:
https://github.com/pymc-devs/Theano-PyMC/blob/a9275c3dcc998c8cca5719037e493809b23422ff/theano/tensor/nnet/nnet.py#L443
https://github.com/pymc-devs/Theano-PyMC/blob/a9275c3dcc998c8cca5719037e493809b23422ff/theano/tensor/nnet/nnet.py#L641

In fact, a custom softmax is currently implemented in the StickBreaking transform to work around this limitation and avoid the warning: https://github.com/pymc-devs/pymc3/blob/a05684b9588208882db164be113954eb21604ea1/pymc3/distributions/transforms.py#L463

    def backward(self, y_):
        y = y_.T
        y = tt.concatenate([y, -tt.sum(y, 0, keepdims=True)])
        # "softmax" with vector support and no deprication warning:
        e_y = tt.exp(y - tt.max(y, 0, keepdims=True))
        x = e_y / tt.sum(e_y, 0, keepdims=True)
        return floatX(x.T)
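The max-subtraction trick in the snippet above is the standard way to make softmax numerically stable: subtracting the maximum leaves the result unchanged but keeps exp() from overflowing. A minimal NumPy sketch of the same idea (illustrative only, not PyMC3 API):

```python
import numpy as np

def stable_softmax(y, axis=0):
    """Numerically stable softmax along the given axis.

    Subtracting the max shifts the largest exponent to exp(0) = 1,
    so exp() never overflows, while the ratio is unchanged.
    """
    e_y = np.exp(y - np.max(y, axis=axis, keepdims=True))
    return e_y / np.sum(e_y, axis=axis, keepdims=True)

# Works directly on a plain vector, with no padding required:
probs = stable_softmax(np.array([1.0, 2.0, 3.0]))

# Large inputs that would overflow a naive exp() stay finite:
big = stable_softmax(np.array([1000.0, 1001.0]))
```

A naive `np.exp(y) / np.exp(y).sum()` would return NaN for the second example, since exp(1000) overflows to infinity in float64.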

I think a few tests would be good to have: maybe just spot some tests where the theano softmax is used and replace it with the new pm.math.softmax?

I couldn't find any PyMC3 tests that are using the theano softmax. In fact, it seems that the theano function is not used anywhere in the library. As in the example above, it is always implemented on the spot:

https://github.com/pymc-devs/pymc3/blob/b707791d1cf36fc4d0b6ec83c7edee13d401dcbf/pymc3/step_methods/metropolis.py#L876

https://github.com/pymc-devs/pymc3/blob/e51b9d3d7dbfef713ec39414c6ae7671ba4aa817/pymc3/step_methods/pgbart.py#L165

In any case, I can add new tests, similar to what is already being done with the Theano log1pexp function. The only doubt I have is whether the warning is a big issue or not.

@AlexAndorra
Contributor

Thanks for this detailed review @ricardoV94 ! This raises several points, but the main one is probably that we should replace the current Theano implementation of the softmax with the custom one you spotted in the StickBreaking transform (assuming they indeed do the same thing). That way, the warning will be taken care of.
Do you feel like opening this PR on the Theano-PyMC repo? Sorry for the ping-pong, but it seems to be the best thing to do here 🏓

Once this part is done, giving access to softmax through PyMC will be super straightforward: in the pymc3/math.py file, just add the line from theano.tensor.nnet import softmax; then, in a user session, pm.math.softmax should be available 🤩 (thanks for the tip @brandonwillard)
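The suggestion above is the plain re-export pattern: importing a name into a module makes it part of that module's public namespace. The sketch below mimics the mechanism with stand-in modules, since the real one-liner (from theano.tensor.nnet import softmax inside pymc3/math.py) needs Theano installed; all names here are illustrative.

```python
import types

# Stand-in for theano.tensor.nnet, with a dummy softmax implementation:
nnet = types.ModuleType("nnet")
nnet.softmax = lambda x: x  # placeholder; the real one computes softmax

# Stand-in for pymc3.math; the re-export is a single attribute binding,
# which is exactly what `from theano.tensor.nnet import softmax` does
# at the top of the module:
pm_math = types.ModuleType("pm_math")
pm_math.softmax = nnet.softmax

# The very same function object is now reachable through the math module:
same_object = pm_math.softmax is nnet.softmax
```

Because the re-exported name is the same object, no wrapper code or extra indirection is involved; any fix to the underlying Theano function is picked up automatically.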

@AlexAndorra
Contributor

FYI, I opened an issue on Theano-PyMC: aesara-devs/aesara#183

@ricardoV94
Member Author

Thanks @AlexAndorra!

@michaelosthege
Member

@ricardoV94 given the upcoming v4 change, we should somehow get this PR "out of the way": either merge it or close it.

In my opinion pm.math shouldn't be a shortcut to the Aesara API. We had this with ArviZ and recently got rid of it too.

What do you think?

@ricardoV94
Member Author

We can close this in favor of the issue that is open in Aesara.

@ricardoV94 ricardoV94 closed this Feb 27, 2021
@ricardoV94 ricardoV94 mentioned this pull request Jan 10, 2022
@ricardoV94 ricardoV94 deleted the add-softmax-math branch January 31, 2022 09:23
Labels
None yet
Projects
None yet
3 participants