File tree Expand file tree Collapse file tree 3 files changed +49
-0
lines changed
pytensor/link/pytorch/dispatch Expand file tree Collapse file tree 3 files changed +49
-0
lines changed Original file line number Diff line number Diff line change 5
5
import pytensor .link .pytorch .dispatch .scalar
6
6
import pytensor .link .pytorch .dispatch .elemwise
7
7
import pytensor .link .pytorch .dispatch .extra_ops
8
+ import pytensor .link .pytorch .dispatch .sort
8
9
# isort: on
Original file line number Diff line number Diff line change
1
+ import torch
2
+
3
+ from pytensor .link .pytorch .dispatch .basic import pytorch_funcify
4
+ from pytensor .tensor .sort import ArgSortOp , SortOp
5
+
6
+
7
+ @pytorch_funcify .register (SortOp )
8
+ def pytorch_funcify_Sort (op , ** kwargs ):
9
+ stable = op .kind == "stable"
10
+
11
+ def sort (arr , axis ):
12
+ sorted , _ = torch .sort (arr , dim = axis , stable = stable )
13
+ return sorted
14
+
15
+ return sort
16
+
17
+
18
+ @pytorch_funcify .register (ArgSortOp )
19
+ def pytorch_funcify_ArgSort (op , ** kwargs ):
20
+ stable = op .kind == "stable"
21
+
22
+ def argsort (arr , axis ):
23
+ return torch .argsort (arr , dim = axis , stable = stable )
24
+
25
+ return argsort
Original file line number Diff line number Diff line change
1
+ import numpy as np
2
+ import pytest
3
+
4
+ from pytensor .graph import FunctionGraph
5
+ from pytensor .tensor import matrix
6
+ from pytensor .tensor .sort import argsort , sort
7
+ from tests .link .pytorch .test_basic import compare_pytorch_and_py
8
+
9
+
10
+ @pytest .mark .parametrize ("axis" , [0 , 1 , None ])
11
+ @pytest .mark .parametrize ("func" , (sort , argsort ))
12
+ def test_sort (func , axis ):
13
+ x = matrix ("x" , shape = (2 , 2 ), dtype = "float64" )
14
+ out = func (x , axis = axis )
15
+ fgraph = FunctionGraph ([x ], [out ])
16
+ arr = np .array ([[1.0 , 4.0 ], [5.0 , 2.0 ]])
17
+
18
+ # TODO: remove condition once Reshape is implemented
19
+ if axis is None :
20
+ with pytest .raises (NotImplementedError ):
21
+ compare_pytorch_and_py (fgraph , [arr ])
22
+ else :
23
+ compare_pytorch_and_py (fgraph , [arr ])
You can’t perform that action at this time.
0 commit comments