Skip to content

Commit b0c010e

Browse files
authored
Merge pull request #8728 from radeusgd/gadt-ascription3
Fix #7044: Approximate GADT bounds to avoid explicit type ascriptions
2 parents 6c489aa + b98d94b commit b0c010e

10 files changed

+164
-24
lines changed

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ sealed abstract class GadtConstraint extends Showable {
4444
def contains(sym: Symbol)(implicit ctx: Context): Boolean
4545

4646
def isEmpty: Boolean
47+
final def nonEmpty: Boolean = !isEmpty
4748

4849
/** See [[ConstraintHandling.approximation]] */
4950
def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ object Erasure {
953953
(stats2.filter(!_.isEmpty), finalCtx)
954954
}
955955

956-
override def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
956+
override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree =
957957
trace(i"adapting ${tree.showSummary}: ${tree.tpe} to $pt", show = true) {
958958
if ctx.phase != ctx.erasurePhase && ctx.phase != ctx.erasurePhase.next then
959959
// this can happen when reading annotations loaded during erasure,

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ class TreeChecker extends Phase with SymTransformer {
511511
override def ensureNoLocalRefs(tree: Tree, pt: Type, localSyms: => List[Symbol])(using Context): Tree =
512512
tree
513513

514-
override def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
514+
override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = {
515515
def isPrimaryConstructorReturn =
516516
ctx.owner.isPrimaryConstructor && pt.isRef(ctx.owner.owner) && tree.tpe.isRef(defn.UnitClass)
517517
def infoStr(tp: Type) = tp match {

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import collection.mutable
1919
import scala.annotation.internal.sharable
2020
import scala.annotation.threadUnsafe
2121

22+
import config.Printers.gadts
23+
2224
object Inferencing {
2325

2426
import tpd._
@@ -162,6 +164,60 @@ object Inferencing {
162164
)
163165
}
164166

167+
def approximateGADT(tp: Type)(implicit ctx: Context): Type = {
168+
val map = new ApproximateGadtAccumulator
169+
val res = map(tp)
170+
assert(!map.failed)
171+
res
172+
}
173+
174+
/** This class is mostly based on IsFullyDefinedAccumulator.
175+
* It tries to approximate the given type based on the available GADT constraints.
176+
*/
177+
private class ApproximateGadtAccumulator(implicit ctx: Context) extends TypeMap {
178+
179+
var failed = false
180+
181+
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = {
182+
val inst = tvar.instantiate(fromBelow)
183+
typr.println(i"forced instantiation of ${tvar.origin} = $inst")
184+
inst
185+
}
186+
187+
private def instDirection2(sym: Symbol)(implicit ctx: Context): Int = {
188+
val constrained = ctx.gadt.fullBounds(sym)
189+
val original = sym.info.bounds
190+
val cmp = ctx.typeComparer
191+
val approxBelow =
192+
if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0
193+
val approxAbove =
194+
if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0
195+
approxAbove - approxBelow
196+
}
197+
198+
private[this] var toMaximize: Boolean = false
199+
200+
def apply(tp: Type): Type = tp.dealias match {
201+
case tp @ TypeRef(qual, nme) if (qual eq NoPrefix) && ctx.gadt.contains(tp.symbol) =>
202+
val sym = tp.symbol
203+
val res =
204+
ctx.gadt.approximation(sym, fromBelow = variance < 0)
205+
gadts.println(i"approximated $tp ~~ $res")
206+
res
207+
208+
case _: WildcardType | _: ProtoType =>
209+
failed = true
210+
NoType
211+
212+
case tp =>
213+
mapOver(tp)
214+
}
215+
216+
def process(tp: Type): Type = {
217+
apply(tp)
218+
}
219+
}
220+
165221
/** For all type parameters occurring in `tp`:
166222
* If the bounds of `tp` in the current constraint are equal wrt =:=,
167223
* instantiate the type parameter to the lower bound's approximation

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import Decorators._
1313
import Uniques._
1414
import config.Printers.typr
1515
import util.SourceFile
16+
import util.Property
1617

1718
import scala.annotation.internal.sharable
1819

@@ -684,7 +685,14 @@ object ProtoTypes {
684685

685686
/** Dummy tree to be used as an argument of a FunProto or ViewProto type */
686687
object dummyTreeOfType {
687-
def apply(tp: Type)(implicit src: SourceFile): Tree = untpd.Literal(Constant(null)) withTypeUnchecked tp
688+
/*
689+
* A property indicating that the given tree was created with dummyTreeOfType.
690+
* It is sometimes necessary to detect the dummy trees to avoid unwanted readaptations on them.
691+
*/
692+
val IsDummyTree = new Property.Key[Unit]
693+
694+
def apply(tp: Type)(implicit src: SourceFile): Tree =
695+
(untpd.Literal(Constant(null)) withTypeUnchecked tp).withAttachment(IsDummyTree, ())
688696
def unapply(tree: untpd.Tree): Option[Type] = tree match {
689697
case Literal(Constant(null)) => Some(tree.typeOpt)
690698
case _ => None

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2772,20 +2772,21 @@ class Typer extends Namer
27722772
* If all this fails, error
27732773
* Parameters as for `typedUnadapted`.
27742774
*/
2775-
def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
2776-
trace(i"adapting $tree to $pt", typr, show = true) {
2775+
def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean = true)(using Context): Tree = {
2776+
trace(i"adapting $tree to $pt ${if (tryGadtHealing) "" else "(tryGadtHealing=false)" }\n", typr, show = true) {
27772777
record("adapt")
2778-
adapt1(tree, pt, locked)
2778+
adapt1(tree, pt, locked, tryGadtHealing)
27792779
}
2780+
}
27802781

27812782
final def adapt(tree: Tree, pt: Type)(using Context): Tree =
27822783
adapt(tree, pt, ctx.typerState.ownedVars)
27832784

2784-
private def adapt1(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
2785+
private def adapt1(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = {
27852786
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported)
27862787
def methodStr = err.refStr(methPart(tree).tpe)
27872788

2788-
def readapt(tree: Tree)(using Context) = adapt(tree, pt, locked)
2789+
def readapt(tree: Tree, shouldTryGadtHealing: Boolean = tryGadtHealing)(using Context) = adapt(tree, pt, locked, shouldTryGadtHealing)
27892790
def readaptSimplified(tree: Tree)(using Context) = readapt(simplify(tree, pt, locked))
27902791

27912792
def missingArgs(mt: MethodType) = {
@@ -3245,7 +3246,6 @@ class Typer extends Namer
32453246
if (folded ne tree)
32463247
return adaptConstant(folded, folded.tpe.asInstanceOf[ConstantType])
32473248

3248-
// Try to capture wildcards in type
32493249
val captured = captureWildcards(wtp)
32503250
if (captured `ne` wtp)
32513251
return readapt(tree.cast(captured))
@@ -3276,6 +3276,28 @@ class Typer extends Namer
32763276
case _ =>
32773277
}
32783278

3279+
val approximation = Inferencing.approximateGADT(wtp)
3280+
gadts.println(
3281+
i"""GADT approximation {
3282+
approximation = $approximation
3283+
pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]}
3284+
ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty}
3285+
ctx.gadt = ${ctx.gadt.debugBoundsDescription}
3286+
pt.isMatchedBy = ${
3287+
if (pt.isInstanceOf[SelectionProto])
3288+
pt.asInstanceOf[SelectionProto].isMatchedBy(approximation).toString
3289+
else
3290+
"<not a SelectionProto>"
3291+
}
3292+
}
3293+
"""
3294+
)
3295+
pt match {
3296+
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(approximation) =>
3297+
return tpd.Typed(tree, TypeTree(approximation))
3298+
case _ => ;
3299+
}
3300+
32793301
// try an extension method in scope
32803302
pt match {
32813303
case SelectionProto(name, mbrType, _, _) =>
@@ -3300,10 +3322,34 @@ class Typer extends Namer
33003322

33013323
// try an implicit conversion
33023324
val prevConstraint = ctx.typerState.constraint
3303-
def recover(failure: SearchFailureType) =
3325+
def recover(failure: SearchFailureType) = {
3326+
def canTryGADTHealing: Boolean = {
3327+
val isDummy = tree.hasAttachment(dummyTreeOfType.IsDummyTree)
3328+
tryGadtHealing // allow GADT healing only once to avoid a loop
3329+
&& ctx.gadt.nonEmpty // GADT healing only makes sense if there are GADT constraints present
3330+
&& !isDummy // avoid healing a dummy tree as it can lead to an error in a very specific case
3331+
}
3332+
33043333
if (isFullyDefined(wtp, force = ForceDegree.all) &&
33053334
ctx.typerState.constraint.ne(prevConstraint)) readapt(tree)
3306-
else err.typeMismatch(tree, pt, failure)
3335+
else if (canTryGADTHealing) {
3336+
// try recovering with a GADT approximation
3337+
val nestedCtx = ctx.fresh.setNewTyperState()
3338+
val res =
3339+
readapt(
3340+
tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
3341+
shouldTryGadtHealing = false,
3342+
)(using nestedCtx)
3343+
if (!nestedCtx.reporter.hasErrors) {
3344+
// GADT recovery successful
3345+
nestedCtx.typerState.commit()
3346+
res
3347+
} else {
3348+
// otherwise fail with the error that would have been reported without the GADT recovery
3349+
err.typeMismatch(tree, pt, failure)
3350+
}
3351+
} else err.typeMismatch(tree, pt, failure)
3352+
}
33073353
if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then
33083354
if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then
33093355
ctx.error(em"the result of an implicit conversion must be more specific than $pt", tree.sourcePos)

tests/neg/boundspropagation.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,6 @@ object test3 {
2525
}
2626
}
2727

28-
// Example contributed by Jason.
29-
object test4 {
30-
class Base {
31-
type N
32-
33-
class Tree[-S, -T >: Option[S]]
34-
35-
def g(x: Any): Tree[_, _ <: Option[N]] = x match {
36-
case y: Tree[_, _] => y // error -- used to work (because of capture conversion?)
37-
}
38-
}
39-
}
40-
4128
class Test5 {
4229
"": ({ type U = this.type })#U // error
4330
}

tests/pos/boundspropagation.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,16 @@ object test2 {
2929
}
3030
}
3131
*/
32+
33+
// Example contributed by Jason.
34+
object test2 {
35+
class Base {
36+
type N
37+
38+
class Tree[-S, -T >: Option[S]]
39+
40+
def g(x: Any): Tree[_, _ <: Option[N]] = x match {
41+
case y: Tree[_, _] => y
42+
}
43+
}
44+
}

tests/pos/gadt-infer-ascription.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// test based on an example code by @Blaisorblade
2+
object GadtAscription {
3+
enum Var[G, A] {
4+
case Z[G, A]() extends Var[(A, G), A]
5+
case S[G, A, B](x: Var[G, A]) extends Var[(B, G), A]
6+
}
7+
8+
import Var._
9+
def evalVar[G, A](x: Var[G, A])(rho: G): A = x match {
10+
case _: Z[g, a] =>
11+
rho(0)
12+
case s: S[g, a, b] =>
13+
evalVar(s.x)(rho(1))
14+
}
15+
}

tests/pos/i7044.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object i7044 {
2+
case class Seg[T](pat:Pat[T], body:T)
3+
4+
trait Pat[T]
5+
object Pat {
6+
case class Expr() extends Pat[Int]
7+
case class Opt[S](el:Pat[S]) extends Pat[Option[S]]
8+
}
9+
10+
def test[T](s:Seg[T]):Int = s match {
11+
case Seg(Pat.Expr(),body) => body + 1
12+
case Seg(Pat.Opt(Pat.Expr()),body) => body.get
13+
}
14+
}

0 commit comments

Comments
 (0)