@@ -393,6 +393,8 @@ def __init__(
393
393
assert len (self .input_storage ) == len (self .maker .fgraph .inputs )
394
394
assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
395
395
396
+ self .has_defaults = any (refeed for _ , refeed , _ in self .defaults )
397
+
396
398
# Group indexes of inputs that are potentially aliased to each other
397
399
# Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
398
400
# even though there could be two distinct types that use the same kinds of underlying objects.
@@ -540,14 +542,40 @@ def __contains__(self, item):
540
542
self ._value = ValueAttribute ()
541
543
self ._container = ContainerAttribute ()
542
544
543
- # TODO: Get rid of all this `expanded_inputs` nonsense
544
- assert len (self .maker .expanded_inputs ) == len (self .input_storage )
545
+ update_storage = [
546
+ container
547
+ for inp , container in zip (
548
+ self .maker .expanded_inputs , input_storage , strict = True
549
+ )
550
+ if inp .update is not None
551
+ ]
552
+ # Updates are the last inner outputs that are not returned by Function.__call__
553
+ self .n_returned_outputs = len (self .output_storage ) - len (update_storage )
554
+
555
+ # Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
556
+ self .update_input_storage : tuple [int , Container ] = ()
557
+ if getattr (vm , "need_update_inputs" , True ):
558
+ self .update_input_storage = tuple (
559
+ zip (
560
+ range (self .n_returned_outputs , len (output_storage )),
561
+ update_storage ,
562
+ strict = True ,
563
+ )
564
+ )
545
565
546
- # This is used only when `vm.need_update_inputs` is `False`, because
547
- # we're using one of the VM objects and it is putting updates back into
548
- # the input containers all by itself.
549
- self .n_returned_outputs = len (self .output_storage ) - sum (
550
- inp .update is not None for inp in self .maker .expanded_inputs
566
+ # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
567
+ # After the call, we want to erase (some of) these references, to allow Python to GC them if unused
568
+ # Required input containers are the non-default inputs, must always be provided again, so we GC them
569
+ self .clear_input_storage_data = tuple (
570
+ container .storage for container in input_storage if container .required
571
+ )
572
+ # This is only done when `vm.allow_gc` is True, which can change at runtime.
573
+ self .clear_output_storage_data = tuple (
574
+ container .storage
575
+ for container , variable in zip (
576
+ self .output_storage , self .maker .fgraph .outputs , strict = True
577
+ )
578
+ if variable .owner is not None # Not a constant output
551
579
)
552
580
553
581
for node in self .maker .fgraph .apply_nodes :
@@ -747,7 +775,7 @@ def checkSV(sv_ori, sv_rpl):
747
775
elif isinstance (profile , str ):
748
776
profile = pytensor .compile .profiling .ProfileStats (message = profile )
749
777
750
- f_cpy = maker . __class__ (
778
+ f_cpy = type ( maker ) (
751
779
inputs = ins ,
752
780
outputs = outs ,
753
781
fgraph = fg_cpy ,
@@ -765,6 +793,8 @@ def checkSV(sv_ori, sv_rpl):
765
793
# check that.
766
794
accept_inplace = True ,
767
795
no_fgraph_prep = True ,
796
+ output_keys = maker .output_keys ,
797
+ name = name ,
768
798
).create (input_storage , storage_map = new_storage_map )
769
799
770
800
for in_ori , in_cpy , ori , cpy in zip (
@@ -797,8 +827,6 @@ def checkSV(sv_ori, sv_rpl):
797
827
798
828
f_cpy .trust_input = self .trust_input
799
829
f_cpy .unpack_single = self .unpack_single
800
- f_cpy .name = name
801
- f_cpy .maker .fgraph .name = name
802
830
return f_cpy
803
831
804
832
def _restore_defaults (self ):
@@ -808,7 +836,7 @@ def _restore_defaults(self):
808
836
value = value .storage [0 ]
809
837
self [i ] = value
810
838
811
- def __call__ (self , * args , ** kwargs ):
839
+ def __call__ (self , * args , output_subset = None , ** kwargs ):
812
840
"""
813
841
Evaluates value of a function on given arguments.
814
842
@@ -836,20 +864,21 @@ def __call__(self, *args, **kwargs):
836
864
List of outputs on indices/keys from ``output_subset`` or all of them,
837
865
if ``output_subset`` is not passed.
838
866
"""
867
+ trust_input = self .trust_input
839
868
input_storage = self .input_storage
869
+ vm = self .vm
840
870
profile = self .profile
841
871
842
872
if profile :
843
873
t0 = time .perf_counter ()
844
874
845
- output_subset = kwargs .pop ("output_subset" , None )
846
875
if output_subset is not None :
847
876
warnings .warn ("output_subset is deprecated." , FutureWarning )
848
877
if self .output_keys is not None :
849
878
output_subset = [self .output_keys .index (key ) for key in output_subset ]
850
879
851
880
# Reinitialize each container's 'provided' counter
852
- if self . trust_input :
881
+ if trust_input :
853
882
for arg_container , arg in zip (input_storage , args , strict = False ):
854
883
arg_container .storage [0 ] = arg
855
884
else :
@@ -908,7 +937,7 @@ def __call__(self, *args, **kwargs):
908
937
for k , arg in kwargs .items ():
909
938
self [k ] = arg
910
939
911
- if not self . trust_input :
940
+ if not trust_input :
912
941
# Collect aliased inputs among the storage space
913
942
for potential_group in self ._potential_aliased_input_groups :
914
943
args_share_memory : list [list [int ]] = []
@@ -960,11 +989,7 @@ def __call__(self, *args, **kwargs):
960
989
if profile :
961
990
t0_fn = time .perf_counter ()
962
991
try :
963
- outputs = (
964
- self .vm ()
965
- if output_subset is None
966
- else self .vm (output_subset = output_subset )
967
- )
992
+ outputs = vm () if output_subset is None else vm (output_subset = output_subset )
968
993
except Exception :
969
994
self ._restore_defaults ()
970
995
if hasattr (self .vm , "position_of_error" ):
@@ -991,73 +1016,53 @@ def __call__(self, *args, **kwargs):
991
1016
992
1017
# Retrieve the values that were computed
993
1018
if outputs is None :
994
- outputs = [x .data for x in self .output_storage ]
995
-
996
- # Remove internal references to required inputs.
997
- # These cannot be re-used anyway.
998
- for arg_container in input_storage :
999
- if arg_container .required :
1000
- arg_container .storage [0 ] = None
1001
-
1002
- # if we are allowing garbage collection, remove the
1003
- # output reference from the internal storage cells
1004
- if getattr (self .vm , "allow_gc" , False ):
1005
- # strict=False because we are in a hot loop
1006
- for o_container , o_variable in zip (
1007
- self .output_storage , self .maker .fgraph .outputs , strict = False
1008
- ):
1009
- if o_variable .owner is not None :
1010
- # this node is the variable of computation
1011
- # WARNING: This circumvents the 'readonly' attribute in x
1012
- o_container .storage [0 ] = None
1013
-
1014
- if getattr (self .vm , "need_update_inputs" , True ):
1015
- # Update the inputs that have an update function
1016
- # strict=False because we are in a hot loop
1017
- for input , storage in reversed (
1018
- list (zip (self .maker .expanded_inputs , input_storage , strict = False ))
1019
- ):
1020
- if input .update is not None :
1021
- storage .data = outputs .pop ()
1022
- else :
1023
- outputs = outputs [: self .n_returned_outputs ]
1019
+ outputs = [x .storage [0 ] for x in self .output_storage ]
1020
+
1021
+ # Set updates and filter them out from the returned outputs
1022
+ for i , input_storage in self .update_input_storage :
1023
+ input_storage .storage [0 ] = outputs [i ]
1024
+ outputs = outputs [: self .n_returned_outputs ]
1025
+
1026
+ # Remove input and output values from storage data
1027
+ for storage_data in self .clear_input_storage_data :
1028
+ storage_data [0 ] = None
1029
+ if getattr (vm , "allow_gc" , False ):
1030
+ for storage_data in self .clear_output_storage_data :
1031
+ storage_data [0 ] = None
1024
1032
1025
1033
# Put default values back in the storage
1026
- self ._restore_defaults ()
1034
+ if self .has_defaults :
1035
+ self ._restore_defaults ()
1027
1036
1028
1037
if profile :
1029
1038
dt_call = time .perf_counter () - t0
1030
1039
pytensor .compile .profiling .total_fct_exec_time += dt_call
1031
1040
self .maker .mode .call_time += dt_call
1032
1041
profile .fct_callcount += 1
1033
1042
profile .fct_call_time += dt_call
1034
- if hasattr (self . vm , "update_profile" ):
1035
- self . vm .update_profile (profile )
1043
+ if hasattr (vm , "update_profile" ):
1044
+ vm .update_profile (profile )
1036
1045
if profile .ignore_first_call :
1037
1046
profile .reset ()
1038
1047
profile .ignore_first_call = False
1039
1048
1040
1049
if self .return_none :
1041
1050
return None
1042
- elif self .unpack_single and len (outputs ) == 1 and output_subset is None :
1043
- return outputs [0 ]
1044
- else :
1045
- if self .output_keys is not None :
1046
- assert len (self .output_keys ) == len (outputs )
1047
1051
1048
- if output_subset is None :
1049
- # strict=False because we are in a hot loop
1050
- return dict (zip (self .output_keys , outputs , strict = False ))
1051
- else :
1052
- return {
1053
- self .output_keys [index ]: outputs [index ]
1054
- for index in output_subset
1055
- }
1052
+ if output_subset is not None :
1053
+ outputs = [outputs [i ] for i in output_subset ]
1056
1054
1057
- if output_subset is None :
1058
- return outputs
1055
+ if self .output_keys is None :
1056
+ if self .unpack_single :
1057
+ [out ] = outputs
1058
+ return out
1059
1059
else :
1060
- return [outputs [i ] for i in output_subset ]
1060
+ return outputs
1061
+ else :
1062
+ output_keys = self .output_keys
1063
+ if output_subset is not None :
1064
+ output_keys = [output_keys [i ] for i in output_subset ]
1065
+ return dict (zip (output_keys , outputs , strict = True ))
1061
1066
1062
1067
value = property (
1063
1068
lambda self : self ._value ,
@@ -1077,9 +1082,10 @@ def free(self):
1077
1082
# 1.no allow_gc return False
1078
1083
# 2.has allow_gc, if allow_gc is False, return True
1079
1084
if not getattr (self .vm , "allow_gc" , True ):
1080
- for key in self .vm .storage_map :
1081
- if not isinstance (key , Constant ):
1082
- self .vm .storage_map [key ][0 ] = None
1085
+ storage_map = self .vm .storage_map
1086
+ for key , value in storage_map .items ():
1087
+ if key .owner is not None : # Not a constant
1088
+ value [0 ] = None
1083
1089
1084
1090
for node in self .nodes_with_inner_function :
1085
1091
if hasattr (node .fn , "free" ):
@@ -1091,10 +1097,6 @@ def get_shared(self):
1091
1097
"""
1092
1098
return [i .variable for i in self .maker .inputs if i .implicit ]
1093
1099
1094
- def sync_shared (self ):
1095
- # NOTE: sync was needed on old gpu backend
1096
- pass
1097
-
1098
1100
def dprint (self , ** kwargs ):
1099
1101
"""Debug print itself
1100
1102
0 commit comments