Skip to content

Commit 52c0e89

Browse files
ricardoV94Ian Schweer
authored and
Ian Schweer
committed
Split and inverse
1 parent 2f1d25a commit 52c0e89

File tree

3 files changed

+65
-0
lines changed

3 files changed

+65
-0
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Eye,
2020
Join,
2121
MakeVector,
22+
Split,
2223
TensorFromScalar,
2324
)
2425

@@ -185,3 +186,11 @@ def tensorfromscalar(x):
185186
return torch.as_tensor(x)
186187

187188
return tensorfromscalar
189+
190+
191+
@pytorch_funcify.register(Split)
192+
def pytorch_funcify_Split(op, node, **kwargs):
193+
def inner_fn(x, dim, split_amounts):
194+
return x.split(split_amounts.tolist(), dim=dim.item())
195+
196+
return inner_fn

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@
55
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
66
from pytensor.scalar.basic import (
77
Cast,
8+
Invert,
89
ScalarOp,
910
)
1011
from pytensor.scalar.loop import ScalarLoop
1112
from pytensor.scalar.math import Softplus
1213

1314

15+
@pytorch_funcify.register(Invert)
16+
def pytorch_funcify_invert(op, node, **kwargs):
17+
return torch.bitwise_not
18+
19+
1420
@pytorch_funcify.register(ScalarOp)
1521
def pytorch_funcify_ScalarOp(op, node, **kwargs):
1622
"""Return pytorch function that implements the same computation as the Scalar Op.

tests/link/pytorch/test_basic.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,53 @@ def test_ScalarLoop_Elemwise_multi_carries():
471471
compare_pytorch_and_py(
472472
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
473473
)
474+
475+
476+
rng = np.random.default_rng(42849)
477+
478+
479+
@pytest.mark.parametrize(
480+
"n_splits, axis, values, sizes",
481+
[
482+
(
483+
0,
484+
0,
485+
rng.normal(size=20).astype(config.floatX),
486+
[],
487+
),
488+
(
489+
5,
490+
0,
491+
rng.normal(size=5).astype(config.floatX),
492+
rng.multinomial(5, np.ones(5) / 5),
493+
),
494+
(
495+
5,
496+
0,
497+
rng.normal(size=10).astype(config.floatX),
498+
rng.multinomial(10, np.ones(5) / 5),
499+
),
500+
(
501+
5,
502+
-1,
503+
rng.normal(size=(11, 7)).astype(config.floatX),
504+
rng.multinomial(7, np.ones(5) / 5),
505+
),
506+
(
507+
5,
508+
-2,
509+
rng.normal(size=(11, 7)).astype(config.floatX),
510+
rng.multinomial(11, np.ones(5) / 5),
511+
),
512+
],
513+
)
514+
def test_Split(n_splits, axis, values, sizes):
515+
i = pt.tensor("i", shape=values.shape, dtype=config.floatX)
516+
s = pt.vector("s", dtype="int64")
517+
g = pt.split(i, s, n_splits, axis=axis)
518+
assert len(g) == n_splits
519+
if n_splits == 0:
520+
return
521+
g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g)
522+
523+
compare_pytorch_and_py(g_fg, [values, sizes])

0 commit comments

Comments
 (0)