Skip to content

Commit 9941c2a

Browse files
committed
Fix lampepfl#7044: Added GADT recovery with fallback to default error.
1 parent 89a2104 commit 9941c2a

File tree

7 files changed

+95
-70
lines changed

7 files changed

+95
-70
lines changed

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import collection.mutable
1919
import scala.annotation.internal.sharable
2020
import scala.annotation.threadUnsafe
2121

22-
import config.Printers.debug
22+
import config.Printers.gadts
2323

2424
object Inferencing {
2525

@@ -165,14 +165,16 @@ object Inferencing {
165165
}
166166

167167
def approximateGADT(tp: Type)(implicit ctx: Context): Type = {
168-
val map = new IsFullyDefinedAccumulator2
168+
val map = new ApproximateGadtAccumulator
169169
val res = map(tp)
170170
assert(!map.failed)
171-
debug.println(i"approximateGADT( $tp ) = $res // {${tp.toString}}")
172171
res
173172
}
174173

175-
private class IsFullyDefinedAccumulator2(implicit ctx: Context) extends TypeMap {
174+
/** This class is mostly based on IsFullyDefinedAccumulator.
175+
* It tries to approximate the given type based on the available GADT constraints.
176+
*/
177+
private class ApproximateGadtAccumulator(implicit ctx: Context) extends TypeMap {
176178

177179
var failed = false
178180

@@ -200,9 +202,7 @@ object Inferencing {
200202
val sym = tp.symbol
201203
val res =
202204
ctx.gadt.approximation(sym, fromBelow = variance < 0)
203-
204-
debug.println(i"approximated $tp ~~ $res")
205-
205+
gadts.println(i"approximated $tp ~~ $res")
206206
res
207207

208208
case _: WildcardType | _: ProtoType =>
@@ -213,21 +213,8 @@ object Inferencing {
213213
mapOver(tp)
214214
}
215215

216-
// private class UpperInstantiator(implicit ctx: Context) extends TypeAccumulator[Unit] {
217-
// def apply(x: Unit, tp: Type): Unit = {
218-
// tp match {
219-
// case tvar: TypeVar if !tvar.isInstantiated =>
220-
// instantiate(tvar, fromBelow = false)
221-
// case _ =>
222-
// }
223-
// foldOver(x, tp)
224-
// }
225-
// }
226-
227216
def process(tp: Type): Type = {
228-
val res = apply(tp)
229-
// if (res && toMaximize) new UpperInstantiator().apply((), tp)
230-
res
217+
apply(tp)
231218
}
232219
}
233220

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import Decorators._
1313
import Uniques._
1414
import config.Printers.typr
1515
import util.SourceFile
16+
import util.Property
1617

1718
import scala.annotation.internal.sharable
1819

@@ -684,7 +685,14 @@ object ProtoTypes {
684685

685686
/** Dummy tree to be used as an argument of a FunProto or ViewProto type */
686687
object dummyTreeOfType {
687-
def apply(tp: Type)(implicit src: SourceFile): Tree = untpd.Literal(Constant(null)) withTypeUnchecked tp
688+
/*
689+
* A property indicating that the given tree was created with dummyTreeOfType.
690+
* It is sometimes necessary to detect the dummy trees to avoid unwanted readaptations on them.
691+
*/
692+
val IsDummyTree = new Property.Key[Unit]
693+
694+
def apply(tp: Type)(implicit src: SourceFile): Tree =
695+
(untpd.Literal(Constant(null)) withTypeUnchecked tp).withAttachment(IsDummyTree, ())
688696
def unapply(tree: untpd.Tree): Option[Type] = tree match {
689697
case Literal(Constant(null)) => Some(tree.typeOpt)
690698
case _ => None

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ import transform.TypeUtils._
4444
import reporting.trace
4545
import Nullables.{NotNullInfo, given _}
4646
import NullOpsDecorator._
47-
import config.Printers.debug
4847

4948
object Typer {
5049

@@ -2774,7 +2773,7 @@ class Typer extends Namer
27742773
*/
27752774
def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean = true)(using Context): Tree = {
27762775
val last = Thread.currentThread.getStackTrace()(2).toString;
2777-
trace/*.force*/(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) {
2776+
trace(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) {
27782777
record("adapt")
27792778
adapt1(tree, pt, locked, tryGadtHealing)
27802779
}
@@ -2784,7 +2783,6 @@ class Typer extends Namer
27842783
adapt(tree, pt, ctx.typerState.ownedVars)
27852784

27862785
private def adapt1(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = {
2787-
// assert(pt.exists && !pt.isInstanceOf[ExprType])
27882786
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported)
27892787
def methodStr = err.refStr(methPart(tree).tpe)
27902788

@@ -3243,19 +3241,15 @@ class Typer extends Namer
32433241
}
32443242

32453243
def adaptToSubType(wtp: Type): Tree = {
3246-
debug.println("adaptToSubType")
3247-
debug.println("// try converting a constant to the target type")
32483244
// try converting a constant to the target type
32493245
val folded = ConstFold(tree, pt)
32503246
if (folded ne tree)
32513247
return adaptConstant(folded, folded.tpe.asInstanceOf[ConstantType])
32523248

3253-
debug.println("// Try to capture wildcards in type")
32543249
val captured = captureWildcards(wtp)
32553250
if (captured `ne` wtp)
32563251
return readapt(tree.cast(captured))
32573252

3258-
debug.println("// drop type if prototype is Unit")
32593253
// drop type if prototype is Unit
32603254
if (pt isRef defn.UnitClass) {
32613255
// local adaptation makes sure every adapted tree conforms to its pt
@@ -3265,7 +3259,6 @@ class Typer extends Namer
32653259
return tpd.Block(tree1 :: Nil, Literal(Constant(())))
32663260
}
32673261

3268-
debug.println("// convert function literal to SAM closure")
32693262
// convert function literal to SAM closure
32703263
tree match {
32713264
case closure(Nil, id @ Ident(nme.ANON_FUN), _)
@@ -3283,28 +3276,28 @@ class Typer extends Namer
32833276
case _ =>
32843277
}
32853278

3286-
debug.println("// try GADT approximation")
3287-
val foo = Inferencing.approximateGADT(wtp)
3288-
debug.println(
3289-
i"""
3290-
foo = $foo
3279+
val approximation = Inferencing.approximateGADT(wtp)
3280+
gadts.println(
3281+
i"""GADT approximation {
3282+
approximation = $approximation
32913283
pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]}
32923284
ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty}
3285+
ctx.gadt = ${ctx.gadt.debugBoundsDescription}
32933286
pt.isMatchedBy = ${
32943287
if (pt.isInstanceOf[SelectionProto])
3295-
pt.asInstanceOf[SelectionProto].isMatchedBy(foo).toString
3288+
pt.asInstanceOf[SelectionProto].isMatchedBy(approximation).toString
32963289
else
32973290
"<not a SelectionProto>"
32983291
}
3292+
}
32993293
"""
33003294
)
33013295
pt match {
3302-
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(foo) =>
3303-
return tpd.Typed(tree, TypeTree(foo))
3296+
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(approximation) =>
3297+
return tpd.Typed(tree, TypeTree(approximation))
33043298
case _ => ;
33053299
}
33063300

3307-
debug.println("// try an extension method in scope")
33083301
// try an extension method in scope
33093302
pt match {
33103303
case SelectionProto(name, mbrType, _, _) =>
@@ -3322,33 +3315,41 @@ class Typer extends Namer
33223315
val app = tryExtension(using nestedCtx)
33233316
if (!app.isEmpty && !nestedCtx.reporter.hasErrors) {
33243317
nestedCtx.typerState.commit()
3325-
debug.println("returning ext meth in scope")
33263318
return ExtMethodApply(app)
33273319
}
33283320
case _ =>
33293321
}
33303322

3331-
debug.println("// try an implicit conversion")
33323323
// try an implicit conversion
33333324
val prevConstraint = ctx.typerState.constraint
3334-
def recover(failure: SearchFailureType) =
3335-
{
3336-
debug.println("recover")
3325+
def recover(failure: SearchFailureType) = {
3326+
def canTryGADTHealing: Boolean = {
3327+
val isDummy = tree.hasAttachment(dummyTreeOfType.IsDummyTree)
3328+
tryGadtHealing // allow GADT healing only once to avoid a loop
3329+
&& ctx.gadt.nonEmpty // GADT healing only makes sense if there are GADT constraints present
3330+
&& !isDummy // avoid healing a dummy tree as it can lead to an error in a very specific case
3331+
}
3332+
33373333
if (isFullyDefined(wtp, force = ForceDegree.all) &&
33383334
ctx.typerState.constraint.ne(prevConstraint)) readapt(tree)
3339-
// else if ({
3340-
// debug.println(i"tryGadtHealing=$tryGadtHealing && \n\tctx.gadt.nonEmpty=${ctx.gadt.nonEmpty}")
3341-
// tryGadtHealing && ctx.gadt.nonEmpty
3342-
// })
3343-
// {
3344-
// debug.println("here")
3345-
// readapt(
3346-
// tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
3347-
// shouldTryGadtHealing = false,
3348-
// )
3349-
// }
3350-
else err.typeMismatch(tree, pt, failure)
3351-
}
3335+
else if (canTryGADTHealing) {
3336+
// try recovering with a GADT approximation
3337+
val nestedCtx = ctx.fresh.setNewTyperState()
3338+
val res =
3339+
readapt(
3340+
tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
3341+
shouldTryGadtHealing = false,
3342+
)(using nestedCtx)
3343+
if (!nestedCtx.reporter.hasErrors) {
3344+
// GADT recovery successful
3345+
nestedCtx.typerState.commit()
3346+
res
3347+
} else {
3348+
// otherwise fail with the error that would have been reported without the GADT recovery
3349+
err.typeMismatch(tree, pt, failure)
3350+
}
3351+
} else err.typeMismatch(tree, pt, failure)
3352+
}
33523353
if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then
33533354
if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then
33543355
ctx.error(em"the result of an implicit conversion must be more specific than $pt", tree.sourcePos)

tests/neg/boundspropagation.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,6 @@ object test3 {
2525
}
2626
}
2727

28-
// Example contributed by Jason.
29-
object test4 {
30-
class Base {
31-
type N
32-
33-
class Tree[-S, -T >: Option[S]]
34-
35-
def g(x: Any): Tree[_, _ <: Option[N]] = x match {
36-
case y: Tree[_, _] => y // error -- used to work (because of capture conversion?)
37-
}
38-
}
39-
}
40-
4128
class Test5 {
4229
"": ({ type U = this.type })#U // error
4330
}

tests/pos/boundspropagation.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,16 @@ object test2 {
2929
}
3030
}
3131
*/
32+
33+
// Example contributed by Jason.
34+
object test2 {
35+
class Base {
36+
type N
37+
38+
class Tree[-S, -T >: Option[S]]
39+
40+
def g(x: Any): Tree[_, _ <: Option[N]] = x match {
41+
case y: Tree[_, _] => y
42+
}
43+
}
44+
}

tests/pos/gadt-infer-ascription.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// test based on an example code by @Blaisorblade
2+
object GadtAscription {
3+
enum Var[G, A] {
4+
case Z[G, A]() extends Var[(A, G), A]
5+
case S[G, A, B](x: Var[G, A]) extends Var[(B, G), A]
6+
}
7+
8+
import Var._
9+
def evalVar[G, A](x: Var[G, A])(rho: G): A = x match {
10+
case _: Z[g, a] =>
11+
rho(0)
12+
case s: S[g, a, b] =>
13+
evalVar(s.x)(rho(1))
14+
}
15+
}

tests/pos/i7044.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object i7044 {
2+
case class Seg[T](pat:Pat[T], body:T)
3+
4+
trait Pat[T]
5+
object Pat {
6+
case class Expr() extends Pat[Int]
7+
case class Opt[S](el:Pat[S]) extends Pat[Option[S]]
8+
}
9+
10+
def test[T](s:Seg[T]):Int = s match {
11+
case Seg(Pat.Expr(),body) => body + 1
12+
case Seg(Pat.Opt(Pat.Expr()),body) => body.get
13+
}
14+
}

0 commit comments

Comments
 (0)