@@ -569,6 +569,18 @@ pub(super) fn thir_abstract_const<'tcx>(
569
569
}
570
570
}
571
571
572
+ /// Tries to unify two abstract constants using structural equality.
573
+ #[ instrument( skip( tcx) , level = "debug" ) ]
574
+ pub ( super ) fn try_unify < ' tcx > (
575
+ tcx : TyCtxt < ' tcx > ,
576
+ a : AbstractConst < ' tcx > ,
577
+ b : AbstractConst < ' tcx > ,
578
+ param_env : ty:: ParamEnv < ' tcx > ,
579
+ ) -> bool {
580
+ let const_unify_ctxt = ConstUnifyCtxt :: new ( tcx, param_env) ;
581
+ const_unify_ctxt. try_unify_inner ( a, b)
582
+ }
583
+
572
584
pub ( super ) fn try_unify_abstract_consts < ' tcx > (
573
585
tcx : TyCtxt < ' tcx > ,
574
586
( a, b) : ( ty:: Unevaluated < ' tcx , ( ) > , ty:: Unevaluated < ' tcx , ( ) > ) ,
@@ -622,115 +634,119 @@ where
622
634
recurse ( tcx, ct, & mut f)
623
635
}
624
636
625
- // Substitutes generics repeatedly to allow AbstractConsts to unify where a
626
- // ConstKind::Unevalated could be turned into an AbstractConst that would unify e.g.
627
- // Param(N) should unify with Param(T), substs: [Unevaluated("T2", [Unevaluated("T3", [Param(N)])])]
628
- #[ inline]
629
- #[ instrument( skip( tcx) , level = "debug" ) ]
630
- fn try_replace_substs_in_root < ' tcx > (
637
+ pub ( super ) struct ConstUnifyCtxt < ' tcx > {
631
638
tcx : TyCtxt < ' tcx > ,
632
- mut abstr_const : AbstractConst < ' tcx > ,
633
- ) -> Option < AbstractConst < ' tcx > > {
634
- while let Node :: Leaf ( ct) = abstr_const. root ( tcx) {
635
- match AbstractConst :: from_const ( tcx, ct) {
636
- Ok ( Some ( act) ) => abstr_const = act,
637
- Ok ( None ) => break ,
638
- Err ( _) => return None ,
639
- }
640
- }
641
-
642
- Some ( abstr_const)
639
+ param_env : ty:: ParamEnv < ' tcx > ,
643
640
}
644
641
645
- /// Tries to unify two abstract constants using structural equality.
646
- #[ instrument( skip( tcx) , level = "debug" ) ]
647
- pub ( super ) fn try_unify < ' tcx > (
648
- tcx : TyCtxt < ' tcx > ,
649
- a : AbstractConst < ' tcx > ,
650
- b : AbstractConst < ' tcx > ,
651
- param_env : ty:: ParamEnv < ' tcx > ,
652
- ) -> bool {
653
- let a = match try_replace_substs_in_root ( tcx, a) {
654
- Some ( a) => a,
655
- None => {
656
- return true ;
642
+ impl < ' tcx > ConstUnifyCtxt < ' tcx > {
643
+ pub ( super ) fn new ( tcx : TyCtxt < ' tcx > , param_env : ty:: ParamEnv < ' tcx > ) -> Self {
644
+ ConstUnifyCtxt { tcx, param_env }
645
+ }
646
+
647
+ // Substitutes generics repeatedly to allow AbstractConsts to unify where a
648
+ // ConstKind::Unevalated could be turned into an AbstractConst that would unify e.g.
649
+ // Param(N) should unify with Param(T), substs: [Unevaluated("T2", [Unevaluated("T3", [Param(N)])])]
650
+ #[ inline]
651
+ #[ instrument( skip( self ) , level = "debug" ) ]
652
+ pub ( super ) fn try_replace_substs_in_root (
653
+ & self ,
654
+ mut abstr_const : AbstractConst < ' tcx > ,
655
+ ) -> Option < AbstractConst < ' tcx > > {
656
+ while let Node :: Leaf ( ct) = abstr_const. root ( self . tcx ) {
657
+ match AbstractConst :: from_const ( self . tcx , ct) {
658
+ Ok ( Some ( act) ) => abstr_const = act,
659
+ Ok ( None ) => break ,
660
+ Err ( _) => return None ,
661
+ }
657
662
}
658
- } ;
659
663
660
- let b = match try_replace_substs_in_root ( tcx, b) {
661
- Some ( b) => b,
662
- None => {
664
+ Some ( abstr_const)
665
+ }
666
+
667
+ /// Tries to unify two abstract constants using structural equality.
668
+ #[ instrument( skip( self ) , level = "debug" ) ]
669
+ fn try_unify_inner ( & self , a : AbstractConst < ' tcx > , b : AbstractConst < ' tcx > ) -> bool {
670
+ let a = if let Some ( a) = self . try_replace_substs_in_root ( a) {
671
+ a
672
+ } else {
663
673
return true ;
664
- }
665
- } ;
674
+ } ;
666
675
667
- let a_root = a. root ( tcx) ;
668
- let b_root = b. root ( tcx) ;
669
- debug ! ( ?a_root, ?b_root) ;
676
+ let b = if let Some ( b) = self . try_replace_substs_in_root ( b) {
677
+ b
678
+ } else {
679
+ return true ;
680
+ } ;
670
681
671
- match ( a_root, b_root) {
672
- ( Node :: Leaf ( a_ct) , Node :: Leaf ( b_ct) ) => {
673
- let a_ct = a_ct. eval ( tcx, param_env) ;
674
- debug ! ( "a_ct evaluated: {:?}" , a_ct) ;
675
- let b_ct = b_ct. eval ( tcx, param_env) ;
676
- debug ! ( "b_ct evaluated: {:?}" , b_ct) ;
682
+ let a_root = a. root ( self . tcx ) ;
683
+ let b_root = b. root ( self . tcx ) ;
684
+ debug ! ( ?a_root, ?b_root) ;
677
685
678
- if a_ct. ty ( ) != b_ct. ty ( ) {
679
- return false ;
680
- }
686
+ match ( a_root, b_root) {
687
+ ( Node :: Leaf ( a_ct) , Node :: Leaf ( b_ct) ) => {
688
+ let a_ct = a_ct. eval ( self . tcx , self . param_env ) ;
689
+ debug ! ( "a_ct evaluated: {:?}" , a_ct) ;
690
+ let b_ct = b_ct. eval ( self . tcx , self . param_env ) ;
691
+ debug ! ( "b_ct evaluated: {:?}" , b_ct) ;
681
692
682
- match ( a_ct. val ( ) , b_ct. val ( ) ) {
683
- // We can just unify errors with everything to reduce the amount of
684
- // emitted errors here.
685
- ( ty:: ConstKind :: Error ( _) , _) | ( _, ty:: ConstKind :: Error ( _) ) => true ,
686
- ( ty:: ConstKind :: Param ( a_param) , ty:: ConstKind :: Param ( b_param) ) => {
687
- a_param == b_param
693
+ if a_ct. ty ( ) != b_ct. ty ( ) {
694
+ return false ;
688
695
}
689
- ( ty:: ConstKind :: Value ( a_val) , ty:: ConstKind :: Value ( b_val) ) => a_val == b_val,
690
- // If we have `fn a<const N: usize>() -> [u8; N + 1]` and `fn b<const M: usize>() -> [u8; 1 + M]`
691
- // we do not want to use `assert_eq!(a(), b())` to infer that `N` and `M` have to be `1`. This
692
- // means that we only allow inference variables if they are equal.
693
- ( ty:: ConstKind :: Infer ( a_val) , ty:: ConstKind :: Infer ( b_val) ) => a_val == b_val,
694
- // We expand generic anonymous constants at the start of this function, so this
695
- // branch should only be taking when dealing with associated constants, at
696
- // which point directly comparing them seems like the desired behavior.
697
- //
698
- // FIXME(generic_const_exprs): This isn't actually the case.
699
- // We also take this branch for concrete anonymous constants and
700
- // expand generic anonymous constants with concrete substs.
701
- ( ty:: ConstKind :: Unevaluated ( a_uv) , ty:: ConstKind :: Unevaluated ( b_uv) ) => {
702
- a_uv == b_uv
696
+
697
+ match ( a_ct. val ( ) , b_ct. val ( ) ) {
698
+ // We can just unify errors with everything to reduce the amount of
699
+ // emitted errors here.
700
+ ( ty:: ConstKind :: Error ( _) , _) | ( _, ty:: ConstKind :: Error ( _) ) => true ,
701
+ ( ty:: ConstKind :: Param ( a_param) , ty:: ConstKind :: Param ( b_param) ) => {
702
+ a_param == b_param
703
+ }
704
+ ( ty:: ConstKind :: Value ( a_val) , ty:: ConstKind :: Value ( b_val) ) => a_val == b_val,
705
+ // If we have `fn a<const N: usize>() -> [u8; N + 1]` and `fn b<const M: usize>() -> [u8; 1 + M]`
706
+ // we do not want to use `assert_eq!(a(), b())` to infer that `N` and `M` have to be `1`. This
707
+ // means that we only allow inference variables if they are equal.
708
+ ( ty:: ConstKind :: Infer ( a_val) , ty:: ConstKind :: Infer ( b_val) ) => a_val == b_val,
709
+ // We expand generic anonymous constants at the start of this function, so this
710
+ // branch should only be taking when dealing with associated constants, at
711
+ // which point directly comparing them seems like the desired behavior.
712
+ //
713
+ // FIXME(generic_const_exprs): This isn't actually the case.
714
+ // We also take this branch for concrete anonymous constants and
715
+ // expand generic anonymous constants with concrete substs.
716
+ ( ty:: ConstKind :: Unevaluated ( a_uv) , ty:: ConstKind :: Unevaluated ( b_uv) ) => {
717
+ a_uv == b_uv
718
+ }
719
+ // FIXME(generic_const_exprs): We may want to either actually try
720
+ // to evaluate `a_ct` and `b_ct` if they are are fully concrete or something like
721
+ // this, for now we just return false here.
722
+ _ => false ,
703
723
}
704
- // FIXME(generic_const_exprs): We may want to either actually try
705
- // to evaluate `a_ct` and `b_ct` if they are are fully concrete or something like
706
- // this, for now we just return false here.
707
- _ => false ,
708
724
}
725
+ ( Node :: Binop ( a_op, al, ar) , Node :: Binop ( b_op, bl, br) ) if a_op == b_op => {
726
+ self . try_unify_inner ( a. subtree ( al) , b. subtree ( bl) )
727
+ && self . try_unify_inner ( a. subtree ( ar) , b. subtree ( br) )
728
+ }
729
+ ( Node :: UnaryOp ( a_op, av) , Node :: UnaryOp ( b_op, bv) ) if a_op == b_op => {
730
+ self . try_unify_inner ( a. subtree ( av) , b. subtree ( bv) )
731
+ }
732
+ ( Node :: FunctionCall ( a_f, a_args) , Node :: FunctionCall ( b_f, b_args) )
733
+ if a_args. len ( ) == b_args. len ( ) =>
734
+ {
735
+ self . try_unify_inner ( a. subtree ( a_f) , b. subtree ( b_f) )
736
+ && iter:: zip ( a_args, b_args)
737
+ . all ( |( & an, & bn) | self . try_unify_inner ( a. subtree ( an) , b. subtree ( bn) ) )
738
+ }
739
+ ( Node :: Cast ( a_kind, a_operand, a_ty) , Node :: Cast ( b_kind, b_operand, b_ty) )
740
+ if ( a_ty == b_ty) && ( a_kind == b_kind) =>
741
+ {
742
+ self . try_unify_inner ( a. subtree ( a_operand) , b. subtree ( b_operand) )
743
+ }
744
+ // use this over `_ => false` to make adding variants to `Node` less error prone
745
+ ( Node :: Cast ( ..) , _)
746
+ | ( Node :: FunctionCall ( ..) , _)
747
+ | ( Node :: UnaryOp ( ..) , _)
748
+ | ( Node :: Binop ( ..) , _)
749
+ | ( Node :: Leaf ( ..) , _) => false ,
709
750
}
710
- ( Node :: Binop ( a_op, al, ar) , Node :: Binop ( b_op, bl, br) ) if a_op == b_op => {
711
- try_unify ( tcx, a. subtree ( al) , b. subtree ( bl) , param_env)
712
- && try_unify ( tcx, a. subtree ( ar) , b. subtree ( br) , param_env)
713
- }
714
- ( Node :: UnaryOp ( a_op, av) , Node :: UnaryOp ( b_op, bv) ) if a_op == b_op => {
715
- try_unify ( tcx, a. subtree ( av) , b. subtree ( bv) , param_env)
716
- }
717
- ( Node :: FunctionCall ( a_f, a_args) , Node :: FunctionCall ( b_f, b_args) )
718
- if a_args. len ( ) == b_args. len ( ) =>
719
- {
720
- try_unify ( tcx, a. subtree ( a_f) , b. subtree ( b_f) , param_env)
721
- && iter:: zip ( a_args, b_args)
722
- . all ( |( & an, & bn) | try_unify ( tcx, a. subtree ( an) , b. subtree ( bn) , param_env) )
723
- }
724
- ( Node :: Cast ( a_kind, a_operand, a_ty) , Node :: Cast ( b_kind, b_operand, b_ty) )
725
- if ( a_ty == b_ty) && ( a_kind == b_kind) =>
726
- {
727
- try_unify ( tcx, a. subtree ( a_operand) , b. subtree ( b_operand) , param_env)
728
- }
729
- // use this over `_ => false` to make adding variants to `Node` less error prone
730
- ( Node :: Cast ( ..) , _)
731
- | ( Node :: FunctionCall ( ..) , _)
732
- | ( Node :: UnaryOp ( ..) , _)
733
- | ( Node :: Binop ( ..) , _)
734
- | ( Node :: Leaf ( ..) , _) => false ,
735
751
}
736
752
}
0 commit comments