Commit 398383a

Merge branch 'torch_dot' of github.com:HangenYuu/pytensor into torch_dot

2 parents: a46adc8 + 2fa2f45

25 files changed: +129 −97 lines

pyproject.toml

Lines changed: 3 additions & 2 deletions
@@ -125,8 +125,8 @@ line-length = 88
 exclude = ["doc/", "pytensor/_version.py"]

 [tool.ruff.lint]
-select = ["C", "E", "F", "I", "UP", "W", "RUF"]
-ignore = ["C408", "C901", "E501", "E741", "RUF012"]
+select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH"]
+ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203"]


 [tool.ruff.lint.isort]
@@ -136,6 +136,7 @@ lines-after-imports = 2
 # TODO: Get rid of these:
 "**/__init__.py" = ["F401", "E402", "F403"]
 "pytensor/tensor/linalg.py" = ["F403"]
+"pytensor/link/c/cmodule.py" = ["PTH"]
 # For the tests we skip because `pytest.importorskip` is used:
 "tests/link/jax/test_scalar.py" = ["E402"]
 "tests/link/jax/test_tensor_basic.py" = ["E402"]

pytensor/compile/debugmode.py

Lines changed: 3 additions & 6 deletions
@@ -478,7 +478,7 @@ def _check_inputs(
     """
     destroyed_idx_list = []
     destroy_map = node.op.destroy_map
-    for o_pos, i_pos_list in destroy_map.items():
+    for i_pos_list in destroy_map.values():
         destroyed_idx_list.extend(i_pos_list)
     destroyed_res_list = [node.inputs[i] for i in destroyed_idx_list]

@@ -598,7 +598,7 @@ def _check_viewmap(fgraph, node, storage_map):
             # TODO: make sure this is correct
             # According to OB, duplicate inputs are rejected on build graph time
             # if they cause problems. So if they are here it should be ok.
-            for key, val in good_alias.items():
+            for key in good_alias:
                 bad_alias.pop(key, None)
             if bad_alias:
                 raise BadViewMap(node, oi, outstorage, list(bad_alias.values()))
@@ -756,10 +756,7 @@ def _get_preallocated_maps(
     # TODO: Sparse? Scalar does not really make sense.

     # Do not preallocate memory for outputs that actually work inplace
-    considered_outputs = []
-    for r in node.outputs:
-        if r not in inplace_outs:
-            considered_outputs.append(r)
+    considered_outputs = [r for r in node.outputs if r not in inplace_outs]

     # Output storage that was initially present in the storage_map
     if "initial" in prealloc_modes or "ALL" in prealloc_modes:

pytensor/compile/profiling.py

Lines changed: 2 additions & 2 deletions
@@ -1380,9 +1380,9 @@ def print_stats(stats1, stats2):
         items.sort(key=lambda a: a[1], reverse=True)
         for idx, ((fgraph, node), node_outputs_size) in enumerate(items[:N]):
             code = ["c"] * len(node.outputs)
-            for out, inp in node.op.destroy_map.items():
+            for out in node.op.destroy_map:
                 code[out] = "i"
-            for out, inp in node.op.view_map.items():
+            for out in node.op.view_map:
                 code[out] = "v"
             shapes = str(fct_shapes[fgraph][node])

pytensor/configdefaults.py

Lines changed: 1 addition & 1 deletion
@@ -348,7 +348,7 @@ def add_compile_configvars():
     if sys.platform == "win32":
         mingw_w64_gcc = Path(sys.executable).parent / "Library/mingw-w64/bin/g++"
         try:
-            rc = call_subprocess_Popen([mingw_w64_gcc, "-v"])
+            rc = call_subprocess_Popen([str(mingw_w64_gcc), "-v"])
             if rc == 0:
                 maybe_add_to_os_environ_pathlist("PATH", mingw_w64_gcc.parent)
         except OSError:

pytensor/graph/destroyhandler.py

Lines changed: 2 additions & 2 deletions
@@ -182,7 +182,7 @@ def _build_droot_impact(destroy_handler):
     root_destroyer = {} # root -> destroyer apply

     for app in destroy_handler.destroyers:
-        for output_idx, input_idx_list in app.op.destroy_map.items():
+        for input_idx_list in app.op.destroy_map.values():
             if len(input_idx_list) != 1:
                 raise NotImplementedError()
             input_idx = input_idx_list[0]
@@ -698,7 +698,7 @@ def orderings(self, fgraph, ordered=True):
             # keep track of clients that should run before the current Apply
             root_clients = set_type()
             # for each destroyed input...
-            for output_idx, input_idx_list in app.op.destroy_map.items():
+            for input_idx_list in app.op.destroy_map.values():
                 destroyed_idx = input_idx_list[0]
                 destroyed_variable = app.inputs[destroyed_idx]
                 root = droot[destroyed_variable]

pytensor/link/c/basic.py

Lines changed: 10 additions & 6 deletions
@@ -1450,12 +1450,16 @@ def in_sig(i, topological_pos, i_idx):
         if props:
             version.append(props)

-        for i in node.inputs:
-            if isinstance(i.type, CLinkerObject):
-                version.append(i.type.c_code_cache_version())
-        for o in node.outputs:
-            if isinstance(o.type, CLinkerObject):
-                version.append(o.type.c_code_cache_version())
+        version.extend(
+            i.type.c_code_cache_version()
+            for i in node.inputs
+            if isinstance(i.type, CLinkerObject)
+        )
+        version.extend(
+            o.type.c_code_cache_version()
+            for o in node.outputs
+            if isinstance(o.type, CLinkerObject)
+        )

         # add the signature for this node
         sig.append(

pytensor/link/c/cmodule.py

Lines changed: 5 additions & 3 deletions
@@ -2131,9 +2131,11 @@ def get_lines(cmd, parse=True):
                     or "-march=native" in line
                 ):
                     continue
-                for reg in ("-march=", "-mtune=", "-target-cpu", "-mabi="):
-                    if reg in line:
-                        selected_lines.append(line.strip())
+                selected_lines.extend(
+                    line.strip()
+                    for reg in ("-march=", "-mtune=", "-target-cpu", "-mabi=")
+                    if reg in line
+                )
         lines = list(set(selected_lines)) # to remove duplicate

         return lines

pytensor/link/c/lazylinker_c.py

Lines changed: 4 additions & 23 deletions
@@ -1,6 +1,4 @@
-import errno
 import logging
-import os
 import sys
 import warnings
 from importlib import reload
@@ -43,23 +41,12 @@ def try_reload():
 # compile_str()) but if another lazylinker_ext does exist then it will be
 # imported and compile_str won't get called at all.
 location = config.compiledir / "lazylinker_ext"
-if not location.exists():
-    try:
-        # Try to make the location
-        os.mkdir(location)
-    except OSError as e:
-        # If we get an error, verify that the error was # 17, the
-        # path already exists, and that it is a directory Note: we
-        # can't check if it exists before making it, because we
-        # are not holding the lock right now, so we could race
-        # another process and get error 17 if we lose the race
-        assert e.errno == errno.EEXIST
-        assert location.is_dir()
+location.mkdir(exist_ok=True)

 init_file = location / "__init__.py"
 if not init_file.exists():
     try:
-        with open(init_file, "w"):
+        with init_file.open("w"):
             pass
     except OSError as e:
         if init_file.exists():
@@ -129,12 +116,7 @@ def try_reload():
         code = cfile.read_text("utf-8")

         loc = config.compiledir / dirname
-        if not loc.exists():
-            try:
-                os.mkdir(loc)
-            except OSError as e:
-                assert e.errno == errno.EEXIST
-                assert loc.exists()
+        loc.mkdir(exist_ok=True)

         args = GCC_compiler.compile_args()
         GCC_compiler.compile_str(dirname, code, location=loc, preargs=args)
@@ -147,8 +129,7 @@ def try_reload():
         # imported at the same time: we need to make sure we do not
         # reload the now outdated __init__.pyc below.
         init_pyc = loc / "__init__.pyc"
-        if init_pyc.is_file():
-            os.remove(init_pyc)
+        init_pyc.unlink(missing_ok=True)

 try_import()
 try_reload()
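
The errno-based mkdir handling above is replaced by pathlib calls whose semantics already cover the creation race the old comments worried about. A short sketch of those standard-library behaviors (illustrative paths only):

# Standard-library behavior the new code relies on; paths here are illustrative.
import tempfile
from pathlib import Path

loc = Path(tempfile.mkdtemp()) / "lazylinker_ext"

loc.mkdir(exist_ok=True)  # creates the directory
loc.mkdir(exist_ok=True)  # second call does not raise FileExistsError, so the old
                          # errno.EEXIST handling for losing a creation race is unneeded
                          # (it still raises if the path exists and is not a directory)

pyc = loc / "__init__.pyc"
pyc.unlink(missing_ok=True)  # no FileNotFoundError when the file is absent (Python 3.8+)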

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,5 +7,5 @@
 import pytensor.link.pytorch.dispatch.elemwise
 import pytensor.link.pytorch.dispatch.extra_ops
 import pytensor.link.pytorch.dispatch.math
-
+import pytensor.link.pytorch.dispatch.sort
 # isort: on

pytensor/link/pytorch/dispatch/sort.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+import torch
+
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.tensor.sort import ArgSortOp, SortOp
+
+
+@pytorch_funcify.register(SortOp)
+def pytorch_funcify_Sort(op, **kwargs):
+    stable = op.kind == "stable"
+
+    def sort(arr, axis):
+        sorted, _ = torch.sort(arr, dim=axis, stable=stable)
+        return sorted
+
+    return sort
+
+
+@pytorch_funcify.register(ArgSortOp)
+def pytorch_funcify_ArgSort(op, **kwargs):
+    stable = op.kind == "stable"
+
+    def argsort(arr, axis):
+        return torch.argsort(arr, dim=axis, stable=stable)
+
+    return argsort
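
For reference, the behavior of the torch calls wrapped above (illustrative values, not part of the diff): torch.sort returns a (values, indices) pair, and stable=True preserves the order of equal elements.

import torch

x = torch.tensor([[1.0, 4.0], [5.0, 2.0]])

values, indices = torch.sort(x, dim=0, stable=True)
print(values)   # tensor([[1., 2.], [5., 4.]])
print(indices)  # tensor([[0, 1], [1, 0]])

print(torch.argsort(x, dim=1, stable=True))  # tensor([[0, 1], [1, 0]])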

pytensor/link/vm.py

Lines changed: 7 additions & 11 deletions
@@ -545,7 +545,7 @@ def __call__(self, output_subset=None):
             # Add the outputs that are needed for the in-place updates of the
             # inputs in `self.update_vars`
             output_subset = list(output_subset)
-            for inp, out in self.update_vars.items():
+            for out in self.update_vars.values():
                 out_idx = self.fgraph.outputs.index(out)
                 if out_idx not in output_subset:
                     output_subset.append(out_idx)
@@ -1055,12 +1055,7 @@ def make_vm(
         for v in self.fgraph.inputs + self.fgraph.outputs:
             vars_idx.setdefault(v, len(vars_idx))

-        nodes_idx_inv = {}
-        vars_idx_inv = {}
-        for node, i in nodes_idx.items():
-            nodes_idx_inv[i] = node
-        for var, i in vars_idx.items():
-            vars_idx_inv[i] = var
+        vars_idx_inv = {i: var for var, i in vars_idx.items()}

         # put storage_map and compute_map into a int-based scheme
         storage_map_list = [
@@ -1270,15 +1265,16 @@ def make_all(
         if self.allow_gc:
             post_thunk_clear = []
             for node in order:
-                clear_after_this_thunk = []
-                for input in node.inputs:
+                clear_after_this_thunk = [
+                    storage_map[input]
+                    for input in node.inputs
                     if (
                         input in computed
                         and input not in fgraph.outputs
                         and node == last_user[input]
                         and input not in reallocated_vars
-                    ):
-                        clear_after_this_thunk.append(storage_map[input])
+                    )
+                ]
                 post_thunk_clear.append(clear_after_this_thunk)
         else:
             post_thunk_clear = None

pytensor/scalar/basic.py

Lines changed: 1 addition & 2 deletions
@@ -4434,8 +4434,7 @@ def apply(self, fgraph):
             )
             # make sure we don't produce any float16.
             assert not any(o.dtype == "float16" for o in new_node.outputs)
-            for o, no in zip(node.outputs, new_node.outputs):
-                mapping[o] = no
+            mapping.update(zip(node.outputs, new_node.outputs))

         new_ins = [mapping[inp] for inp in fgraph.inputs]
         new_outs = [mapping[out] for out in fgraph.outputs]

pytensor/scan/op.py

Lines changed: 1 addition & 2 deletions
@@ -2240,8 +2240,7 @@ def infer_shape(self, fgraph, node, input_shapes):
         # Non-sequences have a direct equivalent from self.inner_inputs in
         # node.inputs
         inner_non_sequences = self.inner_inputs[len(seqs_shape) + len(outs_shape) :]
-        for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
-            out_equivalent[in_ns] = out_ns
+        out_equivalent.update(zip(inner_non_sequences, node.inputs[offset:]))

         if info.as_while:
             self_outs = self.inner_outputs[:-1]

pytensor/scan/rewriting.py

Lines changed: 1 addition & 3 deletions
@@ -1623,9 +1623,7 @@ def scan_save_mem(fgraph, node):
             (inps, outs, info, node_ins, compress_map) = compress_outs(
                 op, not_required, nw_inputs
             )
-            inv_compress_map = {}
-            for k, v in compress_map.items():
-                inv_compress_map[v] = k
+            inv_compress_map = {v: k for k, v in compress_map.items()}

             # 3.6 Compose the new scan
             # TODO: currently we don't support scan with 0 step. So

pytensor/tensor/conv/abstract_conv.py

Lines changed: 5 additions & 4 deletions
@@ -585,10 +585,11 @@ def assert_shape(x, expected_shape, msg="Unexpected shape."):
     if expected_shape is None or not config.conv__assert_shape:
         return x
     shape = x.shape
-    tests = []
-    for i in range(x.ndim):
-        if expected_shape[i] is not None:
-            tests.append(pt.eq(shape[i], expected_shape[i]))
+    tests = [
+        pt.eq(shape[i], expected_shape[i])
+        for i in range(x.ndim)
+        if expected_shape[i] is not None
+    ]
     if tests:
         return Assert(msg)(x, *tests)
     else:

pytensor/tensor/subtensor.py

Lines changed: 1 addition & 4 deletions
@@ -2107,10 +2107,7 @@ def perform(self, node, inp, out_):
         out[0] = x.take(i, axis=0, out=o)

     def connection_pattern(self, node):
-        rval = [[True]]
-
-        for ipt in node.inputs[1:]:
-            rval.append([False])
+        rval = [[True], *([False] for _ in node.inputs[1:])]

         return rval

pytensor/utils.py

Lines changed: 5 additions & 5 deletions
@@ -123,19 +123,19 @@ def maybe_add_to_os_environ_pathlist(var: str, newpath: Path | str) -> None:
         pass


-def subprocess_Popen(command, **params):
+def subprocess_Popen(command: str | list[str], **params):
     """
     Utility function to work around windows behavior that open windows.

     :see: call_subprocess_Popen and output_subprocess_Popen
     """
     startupinfo = None
     if os.name == "nt":
-        startupinfo = subprocess.STARTUPINFO()
+        startupinfo = subprocess.STARTUPINFO()  # type: ignore[attr-defined]
         try:
-            startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
+            startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW  # type: ignore[attr-defined]
         except AttributeError:
-            startupinfo.dwFlags |= subprocess._subprocess.STARTF_USESHOWWINDOW
+            startupinfo.dwFlags |= subprocess._subprocess.STARTF_USESHOWWINDOW  # type: ignore[attr-defined]

     # Anaconda for Windows does not always provide .exe files
     # in the PATH, they also have .bat files that call the corresponding
@@ -156,7 +156,7 @@ def subprocess_Popen(command, **params):
     # with the default None values.
     stdin = None
     if "stdin" not in params:
-        stdin = open(os.devnull)
+        stdin = Path(os.devnull).open()
         params["stdin"] = stdin.fileno()

     try:
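
The added "# type: ignore[attr-defined]" comments are needed because STARTUPINFO and STARTF_USESHOWWINDOW exist only in the Windows build of the subprocess module, so type checkers run on other platforms flag them as missing attributes. A guarded, platform-safe sketch of the same access (illustrative only):

import subprocess

if hasattr(subprocess, "STARTUPINFO"):  # present on Windows only
    si = subprocess.STARTUPINFO()  # type: ignore[attr-defined]
    si.dwFlags |= subprocess.STARTF_USESHOWWINDOW  # type: ignore[attr-defined]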

tests/link/pytorch/test_sort.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import numpy as np
+import pytest
+
+from pytensor.graph import FunctionGraph
+from pytensor.tensor import matrix
+from pytensor.tensor.sort import argsort, sort
+from tests.link.pytorch.test_basic import compare_pytorch_and_py
+
+
+@pytest.mark.parametrize("func", (sort, argsort))
+@pytest.mark.parametrize(
+    "axis",
+    [
+        pytest.param(0),
+        pytest.param(1),
+        pytest.param(
+            None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
+        ),
+    ],
+)
+def test_sort(func, axis):
+    x = matrix("x", shape=(2, 2), dtype="float64")
+    out = func(x, axis=axis)
+    fgraph = FunctionGraph([x], [out])
+    arr = np.array([[1.0, 4.0], [5.0, 2.0]])
+    compare_pytorch_and_py(fgraph, [arr])
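
Assuming a source checkout with torch installed, the new test can be run on its own with `pytest tests/link/pytorch/test_sort.py`; the axis=None cases are marked xfail until a PyTorch Reshape Op is implemented.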
