diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 2faf721a9da8..a72f55506aa3 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -44,6 +44,7 @@ sealed abstract class GadtConstraint extends Showable { def contains(sym: Symbol)(implicit ctx: Context): Boolean def isEmpty: Boolean + final def nonEmpty: Boolean = !isEmpty /** See [[ConstraintHandling.approximation]] */ def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala index 24ab78f4df7a..5fbaf39797d8 100644 --- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala +++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala @@ -954,7 +954,7 @@ object Erasure { (stats2.filter(!_.isEmpty), finalCtx) } - override def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = + override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = trace(i"adapting ${tree.showSummary}: ${tree.tpe} to $pt", show = true) { if ctx.phase != ctx.erasurePhase && ctx.phase != ctx.erasurePhase.next then // this can happen when reading annotations loaded during erasure, diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index d1cd9f870d37..3dc3b3eba1cc 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -511,7 +511,7 @@ class TreeChecker extends Phase with SymTransformer { override def ensureNoLocalRefs(tree: Tree, pt: Type, localSyms: => List[Symbol])(using Context): Tree = tree - override def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = { + override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = { def isPrimaryConstructorReturn = ctx.owner.isPrimaryConstructor && pt.isRef(ctx.owner.owner) && tree.tpe.isRef(defn.UnitClass) def infoStr(tp: Type) = tp match { diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index dfc38f8006a2..c6b241458ee1 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -19,6 +19,8 @@ import collection.mutable import scala.annotation.internal.sharable import scala.annotation.threadUnsafe +import config.Printers.gadts + object Inferencing { import tpd._ @@ -162,6 +164,60 @@ object Inferencing { ) } + def approximateGADT(tp: Type)(implicit ctx: Context): Type = { + val map = new ApproximateGadtAccumulator + val res = map(tp) + assert(!map.failed) + res + } + + /** This class is mostly based on IsFullyDefinedAccumulator. + * It tries to approximate the given type based on the available GADT constraints. + */ + private class ApproximateGadtAccumulator(implicit ctx: Context) extends TypeMap { + + var failed = false + + private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = { + val inst = tvar.instantiate(fromBelow) + typr.println(i"forced instantiation of ${tvar.origin} = $inst") + inst + } + + private def instDirection2(sym: Symbol)(implicit ctx: Context): Int = { + val constrained = ctx.gadt.fullBounds(sym) + val original = sym.info.bounds + val cmp = ctx.typeComparer + val approxBelow = + if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0 + val approxAbove = + if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0 + approxAbove - approxBelow + } + + private[this] var toMaximize: Boolean = false + + def apply(tp: Type): Type = tp.dealias match { + case tp @ TypeRef(qual, nme) if (qual eq NoPrefix) && ctx.gadt.contains(tp.symbol) => + val sym = tp.symbol + val res = + ctx.gadt.approximation(sym, fromBelow = variance < 0) + gadts.println(i"approximated $tp ~~ $res") + res + + case _: WildcardType | _: ProtoType => + failed = true + NoType + + case tp => + mapOver(tp) + } + + def process(tp: Type): Type = { + apply(tp) + } + } + /** For all type parameters occurring in `tp`: * If the bounds of `tp` in the current constraint are equal wrt =:=, * instantiate the type parameter to the lower bound's approximation diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index b62e1bc7f0e7..19f15caa7223 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -13,6 +13,7 @@ import Decorators._ import Uniques._ import config.Printers.typr import util.SourceFile +import util.Property import scala.annotation.internal.sharable @@ -684,7 +685,14 @@ object ProtoTypes { /** Dummy tree to be used as an argument of a FunProto or ViewProto type */ object dummyTreeOfType { - def apply(tp: Type)(implicit src: SourceFile): Tree = untpd.Literal(Constant(null)) withTypeUnchecked tp + /* + * A property indicating that the given tree was created with dummyTreeOfType. + * It is sometimes necessary to detect the dummy trees to avoid unwanted readaptations on them. + */ + val IsDummyTree = new Property.Key[Unit] + + def apply(tp: Type)(implicit src: SourceFile): Tree = + (untpd.Literal(Constant(null)) withTypeUnchecked tp).withAttachment(IsDummyTree, ()) def unapply(tree: untpd.Tree): Option[Type] = tree match { case Literal(Constant(null)) => Some(tree.typeOpt) case _ => None diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 9c4ba0b47128..4e785b62f8e0 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2771,20 +2771,21 @@ class Typer extends Namer * If all this fails, error * Parameters as for `typedUnadapted`. */ - def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = - trace(i"adapting $tree to $pt", typr, show = true) { + def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean = true)(using Context): Tree = { + trace(i"adapting $tree to $pt ${if (tryGadtHealing) "" else "(tryGadtHealing=false)" }\n", typr, show = true) { record("adapt") - adapt1(tree, pt, locked) + adapt1(tree, pt, locked, tryGadtHealing) } + } final def adapt(tree: Tree, pt: Type)(using Context): Tree = adapt(tree, pt, ctx.typerState.ownedVars) - private def adapt1(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = { + private def adapt1(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = { assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported) def methodStr = err.refStr(methPart(tree).tpe) - def readapt(tree: Tree)(using Context) = adapt(tree, pt, locked) + def readapt(tree: Tree, shouldTryGadtHealing: Boolean = tryGadtHealing)(using Context) = adapt(tree, pt, locked, shouldTryGadtHealing) def readaptSimplified(tree: Tree)(using Context) = readapt(simplify(tree, pt, locked)) def missingArgs(mt: MethodType) = { @@ -3244,7 +3245,6 @@ class Typer extends Namer if (folded ne tree) return adaptConstant(folded, folded.tpe.asInstanceOf[ConstantType]) - // Try to capture wildcards in type val captured = captureWildcards(wtp) if (captured `ne` wtp) return readapt(tree.cast(captured)) @@ -3275,6 +3275,28 @@ class Typer extends Namer case _ => } + val approximation = Inferencing.approximateGADT(wtp) + gadts.println( + i"""GADT approximation { + approximation = $approximation + pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]} + ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty} + ctx.gadt = ${ctx.gadt.debugBoundsDescription} + pt.isMatchedBy = ${ + if (pt.isInstanceOf[SelectionProto]) + pt.asInstanceOf[SelectionProto].isMatchedBy(approximation).toString + else + "" + } + } + """ + ) + pt match { + case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(approximation) => + return tpd.Typed(tree, TypeTree(approximation)) + case _ => ; + } + // try an extension method in scope pt match { case SelectionProto(name, mbrType, _, _) => @@ -3299,10 +3321,34 @@ class Typer extends Namer // try an implicit conversion val prevConstraint = ctx.typerState.constraint - def recover(failure: SearchFailureType) = + def recover(failure: SearchFailureType) = { + def canTryGADTHealing: Boolean = { + val isDummy = tree.hasAttachment(dummyTreeOfType.IsDummyTree) + tryGadtHealing // allow GADT healing only once to avoid a loop + && ctx.gadt.nonEmpty // GADT healing only makes sense if there are GADT constraints present + && !isDummy // avoid healing a dummy tree as it can lead to an error in a very specific case + } + if (isFullyDefined(wtp, force = ForceDegree.all) && ctx.typerState.constraint.ne(prevConstraint)) readapt(tree) - else err.typeMismatch(tree, pt, failure) + else if (canTryGADTHealing) { + // try recovering with a GADT approximation + val nestedCtx = ctx.fresh.setNewTyperState() + val res = + readapt( + tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))), + shouldTryGadtHealing = false, + )(using nestedCtx) + if (!nestedCtx.reporter.hasErrors) { + // GADT recovery successful + nestedCtx.typerState.commit() + res + } else { + // otherwise fail with the error that would have been reported without the GADT recovery + err.typeMismatch(tree, pt, failure) + } + } else err.typeMismatch(tree, pt, failure) + } if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then ctx.error(em"the result of an implicit conversion must be more specific than $pt", tree.sourcePos) diff --git a/tests/neg/boundspropagation.scala b/tests/neg/boundspropagation.scala index 63a0d1c359ba..8adf73d33453 100644 --- a/tests/neg/boundspropagation.scala +++ b/tests/neg/boundspropagation.scala @@ -25,19 +25,6 @@ object test3 { } } -// Example contributed by Jason. -object test4 { - class Base { - type N - - class Tree[-S, -T >: Option[S]] - - def g(x: Any): Tree[_, _ <: Option[N]] = x match { - case y: Tree[_, _] => y // error -- used to work (because of capture conversion?) - } - } -} - class Test5 { "": ({ type U = this.type })#U // error } diff --git a/tests/pos/boundspropagation.scala b/tests/pos/boundspropagation.scala index 78366c3a1196..e3a42711128c 100644 --- a/tests/pos/boundspropagation.scala +++ b/tests/pos/boundspropagation.scala @@ -29,3 +29,16 @@ object test2 { } } */ + +// Example contributed by Jason. +object test2 { + class Base { + type N + + class Tree[-S, -T >: Option[S]] + + def g(x: Any): Tree[_, _ <: Option[N]] = x match { + case y: Tree[_, _] => y + } + } +} diff --git a/tests/pos/gadt-infer-ascription.scala b/tests/pos/gadt-infer-ascription.scala new file mode 100644 index 000000000000..773b1ee17eca --- /dev/null +++ b/tests/pos/gadt-infer-ascription.scala @@ -0,0 +1,15 @@ +// test based on an example code by @Blaisorblade +object GadtAscription { + enum Var[G, A] { + case Z[G, A]() extends Var[(A, G), A] + case S[G, A, B](x: Var[G, A]) extends Var[(B, G), A] + } + + import Var._ + def evalVar[G, A](x: Var[G, A])(rho: G): A = x match { + case _: Z[g, a] => + rho(0) + case s: S[g, a, b] => + evalVar(s.x)(rho(1)) + } +} diff --git a/tests/pos/i7044.scala b/tests/pos/i7044.scala new file mode 100644 index 000000000000..a18d87244643 --- /dev/null +++ b/tests/pos/i7044.scala @@ -0,0 +1,14 @@ +object i7044 { + case class Seg[T](pat:Pat[T], body:T) + + trait Pat[T] + object Pat { + case class Expr() extends Pat[Int] + case class Opt[S](el:Pat[S]) extends Pat[Option[S]] + } + + def test[T](s:Seg[T]):Int = s match { + case Seg(Pat.Expr(),body) => body + 1 + case Seg(Pat.Opt(Pat.Expr()),body) => body.get + } +}