
Commit b6ce485

Author: Ian Schweer (committed)
Commit message: Fix linker
Merge of 2 parents (5bd100e + ef97287); commit b6ce485

16 files changed (+292 −559 lines)


pytensor/compile/function/types.py

Lines changed: 77 additions & 75 deletions
@@ -393,6 +393,8 @@ def __init__(
         assert len(self.input_storage) == len(self.maker.fgraph.inputs)
         assert len(self.output_storage) == len(self.maker.fgraph.outputs)
 
+        self.has_defaults = any(refeed for _, refeed, _ in self.defaults)
+
         # Group indexes of inputs that are potentially aliased to each other
         # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
         # even though there could be two distinct types that use the same kinds of underlying objects.
@@ -540,14 +542,40 @@ def __contains__(self, item):
         self._value = ValueAttribute()
         self._container = ContainerAttribute()
 
-        # TODO: Get rid of all this `expanded_inputs` nonsense
-        assert len(self.maker.expanded_inputs) == len(self.input_storage)
+        update_storage = [
+            container
+            for inp, container in zip(
+                self.maker.expanded_inputs, input_storage, strict=True
+            )
+            if inp.update is not None
+        ]
+        # Updates are the last inner outputs that are not returned by Function.__call__
+        self.n_returned_outputs = len(self.output_storage) - len(update_storage)
+
+        # Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
+        self.update_input_storage: tuple[int, Container] = ()
+        if getattr(vm, "need_update_inputs", True):
+            self.update_input_storage = tuple(
+                zip(
+                    range(self.n_returned_outputs, len(output_storage)),
+                    update_storage,
+                    strict=True,
+                )
+            )
 
-        # This is used only when `vm.need_update_inputs` is `False`, because
-        # we're using one of the VM objects and it is putting updates back into
-        # the input containers all by itself.
-        self.n_returned_outputs = len(self.output_storage) - sum(
-            inp.update is not None for inp in self.maker.expanded_inputs
+        # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
+        # After the call, we want to erase (some of) these references, to allow Python to GC them if unused
+        # Required input containers are the non-default inputs, must always be provided again, so we GC them
+        self.clear_input_storage_data = tuple(
+            container.storage for container in input_storage if container.required
+        )
+        # This is only done when `vm.allow_gc` is True, which can change at runtime.
+        self.clear_output_storage_data = tuple(
+            container.storage
+            for container, variable in zip(
+                self.output_storage, self.maker.fgraph.outputs, strict=True
+            )
+            if variable.owner is not None  # Not a constant output
         )
 
         for node in self.maker.fgraph.apply_nodes:
@@ -747,7 +775,7 @@ def checkSV(sv_ori, sv_rpl):
         elif isinstance(profile, str):
             profile = pytensor.compile.profiling.ProfileStats(message=profile)
 
-        f_cpy = maker.__class__(
+        f_cpy = type(maker)(
             inputs=ins,
             outputs=outs,
             fgraph=fg_cpy,
@@ -765,6 +793,8 @@ def checkSV(sv_ori, sv_rpl):
             # check that.
             accept_inplace=True,
             no_fgraph_prep=True,
+            output_keys=maker.output_keys,
+            name=name,
         ).create(input_storage, storage_map=new_storage_map)
 
         for in_ori, in_cpy, ori, cpy in zip(
@@ -797,8 +827,6 @@ def checkSV(sv_ori, sv_rpl):
 
         f_cpy.trust_input = self.trust_input
         f_cpy.unpack_single = self.unpack_single
-        f_cpy.name = name
-        f_cpy.maker.fgraph.name = name
         return f_cpy
 
     def _restore_defaults(self):
@@ -808,7 +836,7 @@ def _restore_defaults(self):
                     value = value.storage[0]
                 self[i] = value
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, output_subset=None, **kwargs):
         """
         Evaluates value of a function on given arguments.
 
@@ -836,20 +864,21 @@ def __call__(self, *args, **kwargs):
         List of outputs on indices/keys from ``output_subset`` or all of them,
         if ``output_subset`` is not passed.
         """
+        trust_input = self.trust_input
         input_storage = self.input_storage
+        vm = self.vm
         profile = self.profile
 
         if profile:
             t0 = time.perf_counter()
 
-        output_subset = kwargs.pop("output_subset", None)
         if output_subset is not None:
             warnings.warn("output_subset is deprecated.", FutureWarning)
             if self.output_keys is not None:
                 output_subset = [self.output_keys.index(key) for key in output_subset]
 
         # Reinitialize each container's 'provided' counter
-        if self.trust_input:
+        if trust_input:
             for arg_container, arg in zip(input_storage, args, strict=False):
                 arg_container.storage[0] = arg
         else:
@@ -908,7 +937,7 @@ def __call__(self, *args, **kwargs):
             for k, arg in kwargs.items():
                 self[k] = arg
 
-        if not self.trust_input:
+        if not trust_input:
             # Collect aliased inputs among the storage space
             for potential_group in self._potential_aliased_input_groups:
                 args_share_memory: list[list[int]] = []
@@ -960,11 +989,7 @@ def __call__(self, *args, **kwargs):
         if profile:
             t0_fn = time.perf_counter()
         try:
-            outputs = (
-                self.vm()
-                if output_subset is None
-                else self.vm(output_subset=output_subset)
-            )
+            outputs = vm() if output_subset is None else vm(output_subset=output_subset)
         except Exception:
             self._restore_defaults()
             if hasattr(self.vm, "position_of_error"):
@@ -991,73 +1016,53 @@ def __call__(self, *args, **kwargs):
 
         # Retrieve the values that were computed
         if outputs is None:
-            outputs = [x.data for x in self.output_storage]
-
-        # Remove internal references to required inputs.
-        # These cannot be re-used anyway.
-        for arg_container in input_storage:
-            if arg_container.required:
-                arg_container.storage[0] = None
-
-        # if we are allowing garbage collection, remove the
-        # output reference from the internal storage cells
-        if getattr(self.vm, "allow_gc", False):
-            # strict=False because we are in a hot loop
-            for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs, strict=False
-            ):
-                if o_variable.owner is not None:
-                    # this node is the variable of computation
-                    # WARNING: This circumvents the 'readonly' attribute in x
-                    o_container.storage[0] = None
-
-        if getattr(self.vm, "need_update_inputs", True):
-            # Update the inputs that have an update function
-            # strict=False because we are in a hot loop
-            for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, input_storage, strict=False))
-            ):
-                if input.update is not None:
-                    storage.data = outputs.pop()
-        else:
-            outputs = outputs[: self.n_returned_outputs]
+            outputs = [x.storage[0] for x in self.output_storage]
+
+        # Set updates and filter them out from the returned outputs
+        for i, input_storage in self.update_input_storage:
+            input_storage.storage[0] = outputs[i]
+        outputs = outputs[: self.n_returned_outputs]
+
+        # Remove input and output values from storage data
+        for storage_data in self.clear_input_storage_data:
+            storage_data[0] = None
+        if getattr(vm, "allow_gc", False):
+            for storage_data in self.clear_output_storage_data:
+                storage_data[0] = None
 
         # Put default values back in the storage
-        self._restore_defaults()
+        if self.has_defaults:
+            self._restore_defaults()
 
         if profile:
             dt_call = time.perf_counter() - t0
             pytensor.compile.profiling.total_fct_exec_time += dt_call
             self.maker.mode.call_time += dt_call
             profile.fct_callcount += 1
             profile.fct_call_time += dt_call
-            if hasattr(self.vm, "update_profile"):
-                self.vm.update_profile(profile)
+            if hasattr(vm, "update_profile"):
+                vm.update_profile(profile)
             if profile.ignore_first_call:
                 profile.reset()
                 profile.ignore_first_call = False
 
         if self.return_none:
             return None
-        elif self.unpack_single and len(outputs) == 1 and output_subset is None:
-            return outputs[0]
-        else:
-            if self.output_keys is not None:
-                assert len(self.output_keys) == len(outputs)
 
-                if output_subset is None:
-                    # strict=False because we are in a hot loop
-                    return dict(zip(self.output_keys, outputs, strict=False))
-                else:
-                    return {
-                        self.output_keys[index]: outputs[index]
-                        for index in output_subset
-                    }
+        if output_subset is not None:
+            outputs = [outputs[i] for i in output_subset]
 
-            if output_subset is None:
-                return outputs
+        if self.output_keys is None:
+            if self.unpack_single:
+                [out] = outputs
+                return out
             else:
-                return [outputs[i] for i in output_subset]
+                return outputs
+        else:
+            output_keys = self.output_keys
+            if output_subset is not None:
+                output_keys = [output_keys[i] for i in output_subset]
+            return dict(zip(output_keys, outputs, strict=True))
 
     value = property(
         lambda self: self._value,
@@ -1077,9 +1082,10 @@ def free(self):
         # 1.no allow_gc return False
         # 2.has allow_gc, if allow_gc is False, return True
         if not getattr(self.vm, "allow_gc", True):
-            for key in self.vm.storage_map:
-                if not isinstance(key, Constant):
-                    self.vm.storage_map[key][0] = None
+            storage_map = self.vm.storage_map
+            for key, value in storage_map.items():
+                if key.owner is not None:  # Not a constant
+                    value[0] = None
 
         for node in self.nodes_with_inner_function:
             if hasattr(node.fn, "free"):
@@ -1091,10 +1097,6 @@ def get_shared(self):
         """
         return [i.variable for i in self.maker.inputs if i.implicit]
 
-    def sync_shared(self):
-        # NOTE: sync was needed on old gpu backend
-        pass
-
     def dprint(self, **kwargs):
         """Debug print itself
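The restructuring above moves per-call bookkeeping out of `Function.__call__` and into `__init__`: which trailing outputs feed shared-variable updates, and which storage cells should be cleared after each call. Below is a minimal, self-contained sketch of that precomputation pattern; the names (`TinyFunction`, `vm`, the one-element storage lists) are illustrative only and are not PyTensor's API.

class TinyFunction:
    def __init__(self, vm, input_storage, output_storage, n_updates):
        self.vm = vm                          # callable that fills output_storage
        self.input_storage = input_storage    # one-element lists acting as containers
        self.output_storage = output_storage
        # Updates are the trailing outputs; pair their indices with the input
        # cells they are written back to (here: the last n_updates inputs).
        self.n_returned_outputs = len(output_storage) - n_updates
        self.update_input_storage = tuple(
            zip(
                range(self.n_returned_outputs, len(output_storage)),
                input_storage[len(input_storage) - n_updates :],
            )
        )
        # Cells holding explicit arguments; cleared after each call so the
        # values can be garbage collected.
        self.clear_input_storage_data = tuple(
            input_storage[: len(input_storage) - n_updates]
        )

    def __call__(self, *args):
        for cell, arg in zip(self.input_storage, args):
            cell[0] = arg
        self.vm()  # run the "compiled" computation
        outputs = [cell[0] for cell in self.output_storage]
        for i, cell in self.update_input_storage:   # write updates back to inputs
            cell[0] = outputs[i]
        for cell in self.clear_input_storage_data:  # release argument references
            cell[0] = None
        return outputs[: self.n_returned_outputs]


# Example: one returned output (x + 1) and one update that accumulates x into `state`.
x_cell, state = [None], [0.0]
out_cell, upd_cell = [None], [None]

def vm():
    out_cell[0] = x_cell[0] + 1.0
    upd_cell[0] = state[0] + x_cell[0]

f = TinyFunction(vm, [x_cell, state], [out_cell, upd_cell], n_updates=1)
print(f(3.0))  # [4.0]; state[0] is now 3.0
print(f(2.0))  # [3.0]; state[0] is now 5.0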

pytensor/link/basic.py

Lines changed: 10 additions & 42 deletions
@@ -653,41 +653,36 @@ def create_jitable_thunk(
         )
 
         thunk_inputs = self.create_thunk_inputs(storage_map)
-
-        thunks = []
-
         thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
-
         fgraph_jit = self.jit_compile(converted_fgraph)
 
         def thunk(
-            fgraph=self.fgraph,
             fgraph_jit=fgraph_jit,
             thunk_inputs=thunk_inputs,
             thunk_outputs=thunk_outputs,
         ):
-            outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
+            try:
+                outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
+            except Exception:
+                # TODO: Should we add a fake node that combines all outputs,
+                # since the error may come from any of them?
+                raise_with_op(self.fgraph, output_nodes[0], thunk)
 
             # strict=False because we are in a hot loop
-            for o_var, o_storage, o_val in zip(
-                fgraph.outputs, thunk_outputs, outputs, strict=False
-            ):
-                compute_map[o_var][0] = True
-                o_storage[0] = self.output_filter(o_var, o_val)
-            return outputs
+            for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
+                o_storage[0] = o_val
 
         thunk.inputs = thunk_inputs
         thunk.outputs = thunk_outputs
         thunk.lazy = False
 
-        thunks.append(thunk)
+        thunks = [thunk]
 
         return thunks, output_nodes, fgraph_jit
 
     def make_all(self, input_storage=None, output_storage=None, storage_map=None):
         fgraph = self.fgraph
         nodes = self.schedule(fgraph)
-        no_recycling = self.no_recycling
 
         input_storage, output_storage, storage_map = map_storage(
             fgraph, nodes, input_storage, output_storage, storage_map
@@ -701,34 +696,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
             compute_map, nodes, input_storage, output_storage, storage_map
         )
 
-        computed, last_user = gc_helper(nodes)
-
-        if self.allow_gc:
-            post_thunk_old_storage = [
-                [
-                    storage_map[input]
-                    for input in node.inputs
-                    if (input in computed)
-                    and (input not in fgraph.outputs)
-                    and (node == last_user[input])
-                ]
-                for node in nodes
-            ]
-        else:
-            post_thunk_old_storage = None
-
-        if no_recycling is True:
-            no_recycling = list(storage_map.values())
-            no_recycling = difference(no_recycling, input_storage)
-        else:
-            no_recycling = [
-                storage_map[r] for r in no_recycling if r not in fgraph.inputs
-            ]
-
-        fn = streamline(
-            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
-        )
-
+        [fn] = thunks
         fn.jit_fn = jit_fn
         fn.allow_gc = self.allow_gc
         fn.storage_map = storage_map
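After this change the JIT linker's `make_all` no longer builds the `streamline` and garbage-collection machinery; the whole graph runs as the single thunk produced by `create_jitable_thunk`, so `[fn] = thunks` just unpacks it. A rough sketch of that single-thunk shape, with hypothetical names rather than PyTensor's actual linker API:

def make_single_thunk(jitted_fn, thunk_inputs, thunk_outputs):
    """Wrap a jitted callable so it reads and writes one-element storage cells."""
    def thunk():
        results = jitted_fn(*(cell[0] for cell in thunk_inputs))
        for cell, value in zip(thunk_outputs, results):
            cell[0] = value
    return thunk


# Storage "cells" are one-element lists, as in the linker code above.
x_cell, y_cell = [2.0], [3.0]
sum_cell, prod_cell = [None], [None]

thunk = make_single_thunk(
    lambda a, b: (a + b, a * b), [x_cell, y_cell], [sum_cell, prod_cell]
)
thunk()
print(sum_cell[0], prod_cell[0])  # 5.0 6.0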

pytensor/link/jax/linker.py

Lines changed: 1 addition & 7 deletions
@@ -3,7 +3,6 @@
 from numpy.random import Generator, RandomState
 
 from pytensor.compile.sharedvalue import SharedVariable, shared
-from pytensor.graph.basic import Constant
 from pytensor.link.basic import JITLinker
 
 
@@ -72,12 +71,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
     def jit_compile(self, fn):
         import jax
 
-        # I suppose we can consider `Constant`s to be "static" according to
-        # JAX.
-        static_argnums = [
-            n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
-        ]
-        return jax.jit(fn, static_argnums=static_argnums)
+        return jax.jit(fn)
 
     def create_thunk_inputs(self, storage_map):
         from pytensor.link.jax.dispatch import jax_typify
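With `static_argnums` dropped, graph constants are passed to the jitted function like any other argument and traced by JAX. A small, generic `jax.jit` usage example (not PyTensor code) showing that no arguments need to be marked static:

import jax
import jax.numpy as jnp

def fn(x, c):
    return x * c

jitted = jax.jit(fn)                 # no static_argnums needed
print(jitted(jnp.arange(3.0), 2.0))  # [0. 2. 4.]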
