@@ -10739,115 +10739,110 @@ class AdjointGenerator
10739
10739
}
10740
10740
}
10741
10741
10742
- // Don't erase any store that needs to be preserved for a
10743
- // rematerialization
10744
- {
10745
- auto found = gutils->rematerializableAllocations .find (orig);
10746
- if (found != gutils->rematerializableAllocations .end ()) {
10747
- // If rematerializing (e.g. needed in reverse, but not needing
10748
- // the whole allocation):
10749
- if (primalNeededInReverse && !cacheWholeAllocation) {
10750
- // if rematerialize, don't ever cache and downgrade to stack
10751
- // allocation where possible.
10752
- if (auto MD = hasMetadata (orig, " enzyme_fromstack" )) {
10753
- if (Mode == DerivativeMode::ReverseModeGradient &&
10754
- found->second .LI ) {
10755
- gutils->rematerializedPrimalOrShadowAllocations .push_back (
10756
- newCall);
10757
- } else {
10758
- IRBuilder<> B (newCall);
10759
-
10760
- Value *Size;
10761
- if (funcName == " malloc" )
10762
- Size = orig->getArgOperand (0 );
10763
- else if (funcName == " julia.gc_alloc_obj" ||
10764
- funcName == " jl_gc_alloc_typed" ||
10765
- funcName == " ijl_gc_alloc_typed" )
10766
- Size = orig->getArgOperand (1 );
10767
- else
10768
- llvm_unreachable (" Unknown allocation to upgrade" );
10769
- Size = gutils->getNewFromOriginal (Size);
10770
-
10771
- if (auto CI = dyn_cast<ConstantInt>(Size)) {
10772
- B.SetInsertPoint (gutils->inversionAllocs );
10773
- }
10742
+ std::function<void (MDNode *)> restoreFromStack = [&](MDNode *MD) {
10743
+ IRBuilder<> B (newCall);
10744
+ Value *Size;
10745
+ if (funcName == " malloc" )
10746
+ Size = orig->getArgOperand (0 );
10747
+ else if (funcName == " julia.gc_alloc_obj" ||
10748
+ funcName == " jl_gc_alloc_typed" ||
10749
+ funcName == " ijl_gc_alloc_typed" )
10750
+ Size = orig->getArgOperand (1 );
10751
+ else
10752
+ llvm_unreachable (" Unknown allocation to upgrade" );
10753
+ Size = gutils->getNewFromOriginal (Size);
10774
10754
10775
- Type *elTy = Type::getInt8Ty (orig->getContext ());
10776
- Instruction *I = nullptr ;
10755
+ if (auto CI = dyn_cast<ConstantInt>(Size)) {
10756
+ B.SetInsertPoint (gutils->inversionAllocs );
10757
+ }
10758
+ Type *elTy = Type::getInt8Ty (orig->getContext ());
10759
+ Instruction *I = nullptr ;
10777
10760
#if LLVM_VERSION_MAJOR >= 15
10778
- if (orig->getContext ().supportsTypedPointers ()) {
10779
- #endif
10780
- for (auto U : orig->users ()) {
10781
- if (hasMetadata (cast<Instruction>(U), " enzyme_caststack" )) {
10782
- elTy = U->getType ()->getPointerElementType ();
10783
- Value *tsize = ConstantInt::get (
10784
- Size->getType (), (gutils->newFunc ->getParent ()
10785
- ->getDataLayout ()
10786
- .getTypeAllocSizeInBits (elTy) +
10787
- 7 ) /
10788
- 8 );
10789
- Size = B.CreateUDiv (Size, tsize, " " , /* exact*/ true );
10790
- I = gutils->getNewFromOriginal (cast<Instruction>(U));
10791
- break ;
10792
- }
10793
- }
10761
+ if (orig->getContext ().supportsTypedPointers ()) {
10762
+ #endif
10763
+ for (auto U : orig->users ()) {
10764
+ if (hasMetadata (cast<Instruction>(U), " enzyme_caststack" )) {
10765
+ elTy = U->getType ()->getPointerElementType ();
10766
+ Value *tsize = ConstantInt::get (
10767
+ Size->getType (), (gutils->newFunc ->getParent ()
10768
+ ->getDataLayout ()
10769
+ .getTypeAllocSizeInBits (elTy) +
10770
+ 7 ) /
10771
+ 8 );
10772
+ Size = B.CreateUDiv (Size, tsize, " " , /* exact*/ true );
10773
+ I = gutils->getNewFromOriginal (cast<Instruction>(U));
10774
+ break ;
10775
+ }
10776
+ }
10794
10777
#if LLVM_VERSION_MAJOR >= 15
10795
- }
10778
+ }
10796
10779
#endif
10797
-
10798
- Value *replacement = B.CreateAlloca (elTy, Size);
10799
- if (I)
10800
- replacement->takeName (I);
10801
- else
10802
- replacement->takeName (newCall);
10803
-
10804
- auto Alignment =
10805
- cast<ConstantInt>(
10806
- cast<ConstantAsMetadata>(MD->getOperand (0 ))->getValue ())
10807
- ->getLimitedValue ();
10808
- // Don't set zero alignment
10809
- if (Alignment) {
10780
+ Value *replacement = B.CreateAlloca (elTy, Size);
10781
+ if (I)
10782
+ replacement->takeName (I);
10783
+ else
10784
+ replacement->takeName (newCall);
10785
+ auto Alignment =
10786
+ cast<ConstantInt>(
10787
+ cast<ConstantAsMetadata>(MD->getOperand (0 ))->getValue ())
10788
+ ->getLimitedValue ();
10789
+ // Don't set zero alignment
10790
+ if (Alignment) {
10810
10791
#if LLVM_VERSION_MAJOR >= 10
10811
- cast<AllocaInst>(replacement)->setAlignment (Align (Alignment));
10792
+ cast<AllocaInst>(replacement)->setAlignment (Align (Alignment));
10812
10793
#else
10813
- cast<AllocaInst>(replacement)->setAlignment (Alignment);
10794
+ cast<AllocaInst>(replacement)->setAlignment (Alignment);
10814
10795
#endif
10815
- }
10796
+ }
10816
10797
#if LLVM_VERSION_MAJOR >= 15
10817
- if (orig->getContext ().supportsTypedPointers ()) {
10798
+ if (orig->getContext ().supportsTypedPointers ()) {
10818
10799
#endif
10819
- if (orig->getType ()->getPointerElementType () != elTy)
10820
- replacement = B.CreatePointerCast (
10821
- replacement,
10822
- PointerType::getUnqual (
10823
- orig->getType ()->getPointerElementType ()));
10800
+ if (orig->getType ()->getPointerElementType () != elTy)
10801
+ replacement = B.CreatePointerCast (
10802
+ replacement, PointerType::getUnqual (
10803
+ orig->getType ()->getPointerElementType ()));
10824
10804
10825
10805
#if LLVM_VERSION_MAJOR >= 15
10826
- }
10806
+ }
10827
10807
#endif
10808
+ if (int AS = cast<PointerType>(orig->getType ())->getAddressSpace ()) {
10828
10809
10829
- if (int AS =
10830
- cast<PointerType>(orig->getType ())->getAddressSpace ()) {
10831
-
10832
- llvm::PointerType *PT;
10810
+ llvm::PointerType *PT;
10833
10811
#if LLVM_VERSION_MAJOR >= 15
10834
- if (orig->getContext ().supportsTypedPointers ()) {
10812
+ if (orig->getContext ().supportsTypedPointers ()) {
10835
10813
#endif
10836
- PT = PointerType::get (
10837
- orig->getType ()->getPointerElementType (), AS);
10814
+ PT = PointerType::get (orig->getType ()->getPointerElementType (), AS);
10838
10815
#if LLVM_VERSION_MAJOR >= 15
10839
- } else {
10840
- PT = PointerType::get (orig->getContext (), AS);
10841
- }
10816
+ } else {
10817
+ PT = PointerType::get (orig->getContext (), AS);
10818
+ }
10842
10819
#endif
10843
- replacement = B.CreateAddrSpaceCast (replacement, PT);
10844
- cast<Instruction>(replacement)
10845
- ->setMetadata (" enzyme_backstack" ,
10846
- MDNode::get (replacement->getContext (), {}));
10847
- }
10820
+ replacement = B.CreateAddrSpaceCast (replacement, PT);
10821
+ cast<Instruction>(replacement)
10822
+ ->setMetadata (" enzyme_backstack" ,
10823
+ MDNode::get (replacement->getContext (), {}));
10824
+ }
10825
+ gutils->replaceAWithB (newCall, replacement);
10826
+ gutils->erase (newCall);
10827
+ };
10848
10828
10849
- gutils->replaceAWithB (newCall, replacement);
10850
- gutils->erase (newCall);
10829
+ // Don't erase any store that needs to be preserved for a
10830
+ // rematerialization
10831
+ {
10832
+ auto found = gutils->rematerializableAllocations .find (orig);
10833
+ if (found != gutils->rematerializableAllocations .end ()) {
10834
+ // If rematerializing (e.g. needed in reverse, but not needing
10835
+ // the whole allocation):
10836
+ if (primalNeededInReverse && !cacheWholeAllocation) {
10837
+ // if rematerialize, don't ever cache and downgrade to stack
10838
+ // allocation where possible.
10839
+ if (auto MD = hasMetadata (orig, " enzyme_fromstack" )) {
10840
+ if (Mode == DerivativeMode::ReverseModeGradient &&
10841
+ found->second .LI ) {
10842
+ gutils->rematerializedPrimalOrShadowAllocations .push_back (
10843
+ newCall);
10844
+ } else {
10845
+ restoreFromStack (MD);
10851
10846
}
10852
10847
return ;
10853
10848
}
@@ -10896,97 +10891,7 @@ class AdjointGenerator
10896
10891
if (Mode == DerivativeMode::ReverseModeGradient)
10897
10892
eraseIfUnused (*orig, /* erase*/ true , /* check*/ false );
10898
10893
else if (auto MD = hasMetadata (orig, " enzyme_fromstack" )) {
10899
- IRBuilder<> B (newCall);
10900
-
10901
- Value *Size;
10902
- if (funcName == " malloc" )
10903
- Size = orig->getArgOperand (0 );
10904
- else if (funcName == " julia.gc_alloc_obj" ||
10905
- funcName == " jl_gc_alloc_typed" ||
10906
- funcName == " ijl_gc_alloc_typed" )
10907
- Size = orig->getArgOperand (1 );
10908
- else
10909
- llvm_unreachable (" Unknown allocation to upgrade" );
10910
- Size = gutils->getNewFromOriginal (Size);
10911
-
10912
- if (auto CI = dyn_cast<ConstantInt>(Size)) {
10913
- B.SetInsertPoint (gutils->inversionAllocs );
10914
- }
10915
-
10916
- Type *elTy = Type::getInt8Ty (orig->getContext ());
10917
- Instruction *I = nullptr ;
10918
- #if LLVM_VERSION_MAJOR >= 15
10919
- if (orig->getContext ().supportsTypedPointers ()) {
10920
- #endif
10921
- for (auto U : orig->users ()) {
10922
- if (hasMetadata (cast<Instruction>(U), " enzyme_caststack" )) {
10923
- elTy = U->getType ()->getPointerElementType ();
10924
- Value *tsize = ConstantInt::get (
10925
- Size->getType (), (gutils->newFunc ->getParent ()
10926
- ->getDataLayout ()
10927
- .getTypeAllocSizeInBits (elTy) +
10928
- 7 ) /
10929
- 8 );
10930
- Size = B.CreateUDiv (Size, tsize, " " , /* exact*/ true );
10931
- I = gutils->getNewFromOriginal (cast<Instruction>(U));
10932
- break ;
10933
- }
10934
- }
10935
- #if LLVM_VERSION_MAJOR >= 15
10936
- }
10937
- #endif
10938
-
10939
- Value *replacement = B.CreateAlloca (elTy, Size);
10940
- if (I)
10941
- replacement->takeName (I);
10942
- else
10943
- replacement->takeName (newCall);
10944
- auto Alignment =
10945
- cast<ConstantInt>(
10946
- cast<ConstantAsMetadata>(MD->getOperand (0 ))->getValue ())
10947
- ->getLimitedValue ();
10948
- // Don't set zero alignment
10949
- if (Alignment) {
10950
- #if LLVM_VERSION_MAJOR >= 10
10951
- cast<AllocaInst>(replacement)->setAlignment (Align (Alignment));
10952
- #else
10953
- cast<AllocaInst>(replacement)->setAlignment (Alignment);
10954
- #endif
10955
- }
10956
-
10957
- #if LLVM_VERSION_MAJOR >= 15
10958
- if (orig->getContext ().supportsTypedPointers ()) {
10959
- #endif
10960
- if (orig->getType ()->getPointerElementType () != elTy)
10961
- replacement = B.CreatePointerCast (
10962
- replacement,
10963
- PointerType::getUnqual (
10964
- orig->getType ()->getPointerElementType ()));
10965
-
10966
- #if LLVM_VERSION_MAJOR >= 15
10967
- }
10968
- #endif
10969
- if (int AS =
10970
- cast<PointerType>(orig->getType ())->getAddressSpace ()) {
10971
- llvm::PointerType *PT;
10972
- #if LLVM_VERSION_MAJOR >= 15
10973
- if (orig->getContext ().supportsTypedPointers ()) {
10974
- #endif
10975
- PT = PointerType::get (
10976
- orig->getType ()->getPointerElementType (), AS);
10977
- #if LLVM_VERSION_MAJOR >= 15
10978
- } else {
10979
- PT = PointerType::get (orig->getContext (), AS);
10980
- }
10981
- #endif
10982
- replacement = B.CreateAddrSpaceCast (replacement, PT);
10983
- cast<Instruction>(replacement)
10984
- ->setMetadata (" enzyme_backstack" ,
10985
- MDNode::get (replacement->getContext (), {}));
10986
- }
10987
-
10988
- gutils->replaceAWithB (newCall, replacement);
10989
- gutils->erase (newCall);
10894
+ restoreFromStack (MD);
10990
10895
}
10991
10896
return ;
10992
10897
}
@@ -11004,92 +10909,7 @@ class AdjointGenerator
11004
10909
eraseIfUnused (*orig, /* erase*/ true , /* check*/ false );
11005
10910
} else {
11006
10911
if (auto MD = hasMetadata (orig, " enzyme_fromstack" )) {
11007
- IRBuilder<> B (newCall);
11008
- Value *Size;
11009
- if (funcName == " malloc" )
11010
- Size = orig->getArgOperand (0 );
11011
- else if (funcName == " julia.gc_alloc_obj" ||
11012
- funcName == " jl_gc_alloc_typed" ||
11013
- funcName == " ijl_gc_alloc_typed" )
11014
- Size = orig->getArgOperand (1 );
11015
- else
11016
- llvm_unreachable (" Unknown allocation to upgrade" );
11017
- Size = gutils->getNewFromOriginal (Size);
11018
-
11019
- if (auto CI = dyn_cast<ConstantInt>(Size)) {
11020
- B.SetInsertPoint (gutils->inversionAllocs );
11021
- }
11022
- Type *elTy = Type::getInt8Ty (orig->getContext ());
11023
- Instruction *I = nullptr ;
11024
- #if LLVM_VERSION_MAJOR >= 15
11025
- if (orig->getContext ().supportsTypedPointers ()) {
11026
- #endif
11027
- for (auto U : orig->users ()) {
11028
- if (hasMetadata (cast<Instruction>(U), " enzyme_caststack" )) {
11029
- elTy = U->getType ()->getPointerElementType ();
11030
- Value *tsize = ConstantInt::get (
11031
- Size->getType (), (gutils->newFunc ->getParent ()
11032
- ->getDataLayout ()
11033
- .getTypeAllocSizeInBits (elTy) +
11034
- 7 ) /
11035
- 8 );
11036
- Size = B.CreateUDiv (Size, tsize, " " , /* exact*/ true );
11037
- I = gutils->getNewFromOriginal (cast<Instruction>(U));
11038
- break ;
11039
- }
11040
- }
11041
- #if LLVM_VERSION_MAJOR >= 15
11042
- }
11043
- #endif
11044
- Value *replacement = B.CreateAlloca (elTy, Size);
11045
- if (I)
11046
- replacement->takeName (I);
11047
- else
11048
- replacement->takeName (newCall);
11049
- auto Alignment =
11050
- cast<ConstantInt>(
11051
- cast<ConstantAsMetadata>(MD->getOperand (0 ))->getValue ())
11052
- ->getLimitedValue ();
11053
- // Don't set zero alignment
11054
- if (Alignment) {
11055
- #if LLVM_VERSION_MAJOR >= 10
11056
- cast<AllocaInst>(replacement)->setAlignment (Align (Alignment));
11057
- #else
11058
- cast<AllocaInst>(replacement)->setAlignment (Alignment);
11059
- #endif
11060
- }
11061
- #if LLVM_VERSION_MAJOR >= 15
11062
- if (orig->getContext ().supportsTypedPointers ()) {
11063
- #endif
11064
- if (orig->getType ()->getPointerElementType () != elTy)
11065
- replacement = B.CreatePointerCast (
11066
- replacement, PointerType::getUnqual (
11067
- orig->getType ()->getPointerElementType ()));
11068
-
11069
- #if LLVM_VERSION_MAJOR >= 15
11070
- }
11071
- #endif
11072
- if (int AS =
11073
- cast<PointerType>(orig->getType ())->getAddressSpace ()) {
11074
-
11075
- llvm::PointerType *PT;
11076
- #if LLVM_VERSION_MAJOR >= 15
11077
- if (orig->getContext ().supportsTypedPointers ()) {
11078
- #endif
11079
- PT = PointerType::get (orig->getType ()->getPointerElementType (),
11080
- AS);
11081
- #if LLVM_VERSION_MAJOR >= 15
11082
- } else {
11083
- PT = PointerType::get (orig->getContext (), AS);
11084
- }
11085
- #endif
11086
- replacement = B.CreateAddrSpaceCast (replacement, PT);
11087
- cast<Instruction>(replacement)
11088
- ->setMetadata (" enzyme_backstack" ,
11089
- MDNode::get (replacement->getContext (), {}));
11090
- }
11091
- gutils->replaceAWithB (newCall, replacement);
11092
- gutils->erase (newCall);
10912
+ restoreFromStack (MD);
11093
10913
}
11094
10914
}
11095
10915
return ;
0 commit comments