Skip to content

Commit b98d94b

Browse files
committed
Fix lampepfl#7044: Added GADT recovery with fallback to default error.
1 parent 89a2104 commit b98d94b

File tree

7 files changed

+95
-71
lines changed

7 files changed

+95
-71
lines changed

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import collection.mutable
1919
import scala.annotation.internal.sharable
2020
import scala.annotation.threadUnsafe
2121

22-
import config.Printers.debug
22+
import config.Printers.gadts
2323

2424
object Inferencing {
2525

@@ -165,14 +165,16 @@ object Inferencing {
165165
}
166166

167167
def approximateGADT(tp: Type)(implicit ctx: Context): Type = {
168-
val map = new IsFullyDefinedAccumulator2
168+
val map = new ApproximateGadtAccumulator
169169
val res = map(tp)
170170
assert(!map.failed)
171-
debug.println(i"approximateGADT( $tp ) = $res // {${tp.toString}}")
172171
res
173172
}
174173

175-
private class IsFullyDefinedAccumulator2(implicit ctx: Context) extends TypeMap {
174+
/** This class is mostly based on IsFullyDefinedAccumulator.
175+
* It tries to approximate the given type based on the available GADT constraints.
176+
*/
177+
private class ApproximateGadtAccumulator(implicit ctx: Context) extends TypeMap {
176178

177179
var failed = false
178180

@@ -200,9 +202,7 @@ object Inferencing {
200202
val sym = tp.symbol
201203
val res =
202204
ctx.gadt.approximation(sym, fromBelow = variance < 0)
203-
204-
debug.println(i"approximated $tp ~~ $res")
205-
205+
gadts.println(i"approximated $tp ~~ $res")
206206
res
207207

208208
case _: WildcardType | _: ProtoType =>
@@ -213,21 +213,8 @@ object Inferencing {
213213
mapOver(tp)
214214
}
215215

216-
// private class UpperInstantiator(implicit ctx: Context) extends TypeAccumulator[Unit] {
217-
// def apply(x: Unit, tp: Type): Unit = {
218-
// tp match {
219-
// case tvar: TypeVar if !tvar.isInstantiated =>
220-
// instantiate(tvar, fromBelow = false)
221-
// case _ =>
222-
// }
223-
// foldOver(x, tp)
224-
// }
225-
// }
226-
227216
def process(tp: Type): Type = {
228-
val res = apply(tp)
229-
// if (res && toMaximize) new UpperInstantiator().apply((), tp)
230-
res
217+
apply(tp)
231218
}
232219
}
233220

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import Decorators._
1313
import Uniques._
1414
import config.Printers.typr
1515
import util.SourceFile
16+
import util.Property
1617

1718
import scala.annotation.internal.sharable
1819

@@ -684,7 +685,14 @@ object ProtoTypes {
684685

685686
/** Dummy tree to be used as an argument of a FunProto or ViewProto type */
686687
object dummyTreeOfType {
687-
def apply(tp: Type)(implicit src: SourceFile): Tree = untpd.Literal(Constant(null)) withTypeUnchecked tp
688+
/*
689+
* A property indicating that the given tree was created with dummyTreeOfType.
690+
* It is sometimes necessary to detect the dummy trees to avoid unwanted readaptations on them.
691+
*/
692+
val IsDummyTree = new Property.Key[Unit]
693+
694+
def apply(tp: Type)(implicit src: SourceFile): Tree =
695+
(untpd.Literal(Constant(null)) withTypeUnchecked tp).withAttachment(IsDummyTree, ())
688696
def unapply(tree: untpd.Tree): Option[Type] = tree match {
689697
case Literal(Constant(null)) => Some(tree.typeOpt)
690698
case _ => None

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ import transform.TypeUtils._
4444
import reporting.trace
4545
import Nullables.{NotNullInfo, given _}
4646
import NullOpsDecorator._
47-
import config.Printers.debug
4847

4948
object Typer {
5049

@@ -2773,8 +2772,7 @@ class Typer extends Namer
27732772
* Parameters as for `typedUnadapted`.
27742773
*/
27752774
def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean = true)(using Context): Tree = {
2776-
val last = Thread.currentThread.getStackTrace()(2).toString;
2777-
trace/*.force*/(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) {
2775+
trace(i"adapting $tree to $pt ${if (tryGadtHealing) "" else "(tryGadtHealing=false)" }\n", typr, show = true) {
27782776
record("adapt")
27792777
adapt1(tree, pt, locked, tryGadtHealing)
27802778
}
@@ -2784,7 +2782,6 @@ class Typer extends Namer
27842782
adapt(tree, pt, ctx.typerState.ownedVars)
27852783

27862784
private def adapt1(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = {
2787-
// assert(pt.exists && !pt.isInstanceOf[ExprType])
27882785
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported)
27892786
def methodStr = err.refStr(methPart(tree).tpe)
27902787

@@ -3243,19 +3240,15 @@ class Typer extends Namer
32433240
}
32443241

32453242
def adaptToSubType(wtp: Type): Tree = {
3246-
debug.println("adaptToSubType")
3247-
debug.println("// try converting a constant to the target type")
32483243
// try converting a constant to the target type
32493244
val folded = ConstFold(tree, pt)
32503245
if (folded ne tree)
32513246
return adaptConstant(folded, folded.tpe.asInstanceOf[ConstantType])
32523247

3253-
debug.println("// Try to capture wildcards in type")
32543248
val captured = captureWildcards(wtp)
32553249
if (captured `ne` wtp)
32563250
return readapt(tree.cast(captured))
32573251

3258-
debug.println("// drop type if prototype is Unit")
32593252
// drop type if prototype is Unit
32603253
if (pt isRef defn.UnitClass) {
32613254
// local adaptation makes sure every adapted tree conforms to its pt
@@ -3265,7 +3258,6 @@ class Typer extends Namer
32653258
return tpd.Block(tree1 :: Nil, Literal(Constant(())))
32663259
}
32673260

3268-
debug.println("// convert function literal to SAM closure")
32693261
// convert function literal to SAM closure
32703262
tree match {
32713263
case closure(Nil, id @ Ident(nme.ANON_FUN), _)
@@ -3283,28 +3275,28 @@ class Typer extends Namer
32833275
case _ =>
32843276
}
32853277

3286-
debug.println("// try GADT approximation")
3287-
val foo = Inferencing.approximateGADT(wtp)
3288-
debug.println(
3289-
i"""
3290-
foo = $foo
3278+
val approximation = Inferencing.approximateGADT(wtp)
3279+
gadts.println(
3280+
i"""GADT approximation {
3281+
approximation = $approximation
32913282
pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]}
32923283
ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty}
3284+
ctx.gadt = ${ctx.gadt.debugBoundsDescription}
32933285
pt.isMatchedBy = ${
32943286
if (pt.isInstanceOf[SelectionProto])
3295-
pt.asInstanceOf[SelectionProto].isMatchedBy(foo).toString
3287+
pt.asInstanceOf[SelectionProto].isMatchedBy(approximation).toString
32963288
else
32973289
"<not a SelectionProto>"
32983290
}
3291+
}
32993292
"""
33003293
)
33013294
pt match {
3302-
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(foo) =>
3303-
return tpd.Typed(tree, TypeTree(foo))
3295+
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(approximation) =>
3296+
return tpd.Typed(tree, TypeTree(approximation))
33043297
case _ => ;
33053298
}
33063299

3307-
debug.println("// try an extension method in scope")
33083300
// try an extension method in scope
33093301
pt match {
33103302
case SelectionProto(name, mbrType, _, _) =>
@@ -3322,33 +3314,41 @@ class Typer extends Namer
33223314
val app = tryExtension(using nestedCtx)
33233315
if (!app.isEmpty && !nestedCtx.reporter.hasErrors) {
33243316
nestedCtx.typerState.commit()
3325-
debug.println("returning ext meth in scope")
33263317
return ExtMethodApply(app)
33273318
}
33283319
case _ =>
33293320
}
33303321

3331-
debug.println("// try an implicit conversion")
33323322
// try an implicit conversion
33333323
val prevConstraint = ctx.typerState.constraint
3334-
def recover(failure: SearchFailureType) =
3335-
{
3336-
debug.println("recover")
3324+
def recover(failure: SearchFailureType) = {
3325+
def canTryGADTHealing: Boolean = {
3326+
val isDummy = tree.hasAttachment(dummyTreeOfType.IsDummyTree)
3327+
tryGadtHealing // allow GADT healing only once to avoid a loop
3328+
&& ctx.gadt.nonEmpty // GADT healing only makes sense if there are GADT constraints present
3329+
&& !isDummy // avoid healing a dummy tree as it can lead to an error in a very specific case
3330+
}
3331+
33373332
if (isFullyDefined(wtp, force = ForceDegree.all) &&
33383333
ctx.typerState.constraint.ne(prevConstraint)) readapt(tree)
3339-
// else if ({
3340-
// debug.println(i"tryGadtHealing=$tryGadtHealing && \n\tctx.gadt.nonEmpty=${ctx.gadt.nonEmpty}")
3341-
// tryGadtHealing && ctx.gadt.nonEmpty
3342-
// })
3343-
// {
3344-
// debug.println("here")
3345-
// readapt(
3346-
// tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
3347-
// shouldTryGadtHealing = false,
3348-
// )
3349-
// }
3350-
else err.typeMismatch(tree, pt, failure)
3351-
}
3334+
else if (canTryGADTHealing) {
3335+
// try recovering with a GADT approximation
3336+
val nestedCtx = ctx.fresh.setNewTyperState()
3337+
val res =
3338+
readapt(
3339+
tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
3340+
shouldTryGadtHealing = false,
3341+
)(using nestedCtx)
3342+
if (!nestedCtx.reporter.hasErrors) {
3343+
// GADT recovery successful
3344+
nestedCtx.typerState.commit()
3345+
res
3346+
} else {
3347+
// otherwise fail with the error that would have been reported without the GADT recovery
3348+
err.typeMismatch(tree, pt, failure)
3349+
}
3350+
} else err.typeMismatch(tree, pt, failure)
3351+
}
33523352
if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then
33533353
if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then
33543354
ctx.error(em"the result of an implicit conversion must be more specific than $pt", tree.sourcePos)

tests/neg/boundspropagation.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,6 @@ object test3 {
2525
}
2626
}
2727

28-
// Example contributed by Jason.
29-
object test4 {
30-
class Base {
31-
type N
32-
33-
class Tree[-S, -T >: Option[S]]
34-
35-
def g(x: Any): Tree[_, _ <: Option[N]] = x match {
36-
case y: Tree[_, _] => y // error -- used to work (because of capture conversion?)
37-
}
38-
}
39-
}
40-
4128
class Test5 {
4229
"": ({ type U = this.type })#U // error
4330
}

tests/pos/boundspropagation.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,16 @@ object test2 {
2929
}
3030
}
3131
*/
32+
33+
// Example contributed by Jason.
34+
object test2 {
35+
class Base {
36+
type N
37+
38+
class Tree[-S, -T >: Option[S]]
39+
40+
def g(x: Any): Tree[_, _ <: Option[N]] = x match {
41+
case y: Tree[_, _] => y
42+
}
43+
}
44+
}

tests/pos/gadt-infer-ascription.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// test based on an example code by @Blaisorblade
2+
object GadtAscription {
3+
enum Var[G, A] {
4+
case Z[G, A]() extends Var[(A, G), A]
5+
case S[G, A, B](x: Var[G, A]) extends Var[(B, G), A]
6+
}
7+
8+
import Var._
9+
def evalVar[G, A](x: Var[G, A])(rho: G): A = x match {
10+
case _: Z[g, a] =>
11+
rho(0)
12+
case s: S[g, a, b] =>
13+
evalVar(s.x)(rho(1))
14+
}
15+
}

tests/pos/i7044.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object i7044 {
2+
case class Seg[T](pat:Pat[T], body:T)
3+
4+
trait Pat[T]
5+
object Pat {
6+
case class Expr() extends Pat[Int]
7+
case class Opt[S](el:Pat[S]) extends Pat[Option[S]]
8+
}
9+
10+
def test[T](s:Seg[T]):Int = s match {
11+
case Seg(Pat.Expr(),body) => body + 1
12+
case Seg(Pat.Opt(Pat.Expr()),body) => body.get
13+
}
14+
}

0 commit comments

Comments
 (0)