PyTorch Softmax Ops #846
Changes from 47 commits
```
@@ -2,6 +2,7 @@
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


@pytorch_funcify.register(Elemwise)

@@ -34,3 +35,46 @@ def dimshuffle(x):
        return res

    return dimshuffle


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis

    def softmax(x):
        if not torch.is_floating_point(x):
            x = x.to(torch.float32)

        if axis is not None:
            return torch.softmax(x, dim=axis)
        else:
            return torch.softmax(x.ravel(), dim=0).reshape(x.shape)

    return softmax
```

Review comment (on the `x = x.to(torch.float32)` cast): Should probably convert to the output type advertised by PyTensor. The second argument of the funcify function is the `node`.

Reply: Actually there's a bug in PyTensor `Softmax`: it also fails if you try to execute with an integer dtype in the default backend. I'll open an issue for that. For now it's enough for the torch dispatch function to raise a `NotImplementedError`. Same for the other Softmax-related functions.

Reply: Opened an issue: #857
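Following that review thread, here is a minimal sketch of what the requested behavior could look like: rejecting non-float inputs instead of silently upcasting them. This is an illustration of the suggestion, not the code merged in the PR; the error message wording is an assumption.

```python
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.special import Softmax


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis

    def softmax(x):
        if not torch.is_floating_point(x):
            # Per the review: don't upcast integer inputs here; the default
            # backend also rejects them (see issue #857), so raise instead.
            raise NotImplementedError(
                "Softmax is not implemented for non-float inputs in the PyTorch backend"
            )
        if axis is not None:
            return torch.softmax(x, dim=axis)
        return torch.softmax(x.ravel(), dim=0).reshape(x.shape)

    return softmax
```

The diff continues below with the `LogSoftmax` dispatch.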
```
@pytorch_funcify.register(LogSoftmax)
def pytorch_funcify_LogSoftmax(op, **kwargs):
    axis = op.axis

    def log_softmax(x):
        if not torch.is_floating_point(x):
            x = x.to(torch.float32)

        if axis is not None:
            return torch.log_softmax(x, dim=axis)
        else:
            return torch.log_softmax(x.ravel(), dim=0).reshape(x.shape)

    return log_softmax
```
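One reason to map `LogSoftmax` directly to `torch.log_softmax` rather than composing `torch.log` with `torch.softmax` is numerical stability: the fused op stays finite where the naive composition underflows. A small illustration (example values chosen for effect, not from the PR):

```python
import torch

x = torch.tensor([0.0, 1000.0])

# Naive composition: softmax underflows to exactly 0 for the first entry,
# so taking the log gives -inf.
print(torch.log(torch.softmax(x, dim=0)))  # tensor([-inf, 0.])

# Fused log-softmax stays finite.
print(torch.log_softmax(x, dim=0))         # tensor([-1000., 0.])
```

The last hunk adds the `SoftmaxGrad` dispatch.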
```
@pytorch_funcify.register(SoftmaxGrad)
def pytorch_funcify_SoftmaxGrad(op, **kwargs):
    axis = op.axis

    def softmax_grad(dy, sm):
        dy_times_sm = dy * sm
        return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm

    return softmax_grad
```
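As a sanity check on the `softmax_grad` formula above (not part of the PR), the hand-written vector-Jacobian product can be compared against what `torch.autograd` computes for `torch.softmax`:

```python
import torch

x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
dy = torch.randn(3, 5, dtype=torch.float64)

# Reference VJP from autograd.
sm = torch.softmax(x, dim=-1)
(expected,) = torch.autograd.grad(sm, x, grad_outputs=dy)

# Hand-written formula used by softmax_grad above.
dy_times_sm = dy * sm
manual = dy_times_sm - torch.sum(dy_times_sm, dim=-1, keepdim=True) * sm

assert torch.allclose(manual, expected)
```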
Review comment: This change is because of this error: #827
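For context, a hedged sketch of how these dispatch functions get exercised end to end, assuming the branch registers the PyTorch linker under the `"PYTORCH"` compilation mode (the mode name is an assumption here, not confirmed by the diff):

```python
import numpy as np
import pytensor
from pytensor.tensor import matrix
from pytensor.tensor.special import softmax

x = matrix("x")
y = softmax(x, axis=-1)

# Compiling with the PyTorch backend routes Softmax through the
# pytorch_funcify dispatch added in this PR.
f = pytensor.function([x], y, mode="PYTORCH")  # assumed mode name
print(f(np.random.randn(2, 3)))
```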