Skip to content

Commit 7f262a1

Browse files
committed
Fix lampepfl#7044: Added GADT recovery with fallback to default error.
1 parent e06eb26 commit 7f262a1

File tree

6 files changed

+80
-69
lines changed

6 files changed

+80
-69
lines changed

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import collection.mutable
1919
import scala.annotation.internal.sharable
2020
import scala.annotation.threadUnsafe
2121

22-
import config.Printers.debug
22+
import config.Printers.gadts
2323

2424
object Inferencing {
2525

@@ -165,14 +165,16 @@ object Inferencing {
165165
}
166166

167167
def approximateGADT(tp: Type)(implicit ctx: Context): Type = {
168-
val map = new IsFullyDefinedAccumulator2
168+
val map = new ApproximateGadtAccumulator
169169
val res = map(tp)
170170
assert(!map.failed)
171-
debug.println(i"approximateGADT( $tp ) = $res // {${tp.toString}}")
172171
res
173172
}
174173

175-
private class IsFullyDefinedAccumulator2(implicit ctx: Context) extends TypeMap {
174+
/** This class is mostly based on IsFullyDefinedAccumulator.
175+
* It tries to approximate the given type based on the available GADT constraints.
176+
*/
177+
private class ApproximateGadtAccumulator(implicit ctx: Context) extends TypeMap {
176178

177179
var failed = false
178180

@@ -200,9 +202,7 @@ object Inferencing {
200202
val sym = tp.symbol
201203
val res =
202204
ctx.gadt.approximation(sym, fromBelow = variance < 0)
203-
204-
debug.println(i"approximated $tp ~~ $res")
205-
205+
gadts.println(i"approximated $tp ~~ $res")
206206
res
207207

208208
case _: WildcardType | _: ProtoType =>
@@ -213,21 +213,8 @@ object Inferencing {
213213
mapOver(tp)
214214
}
215215

216-
// private class UpperInstantiator(implicit ctx: Context) extends TypeAccumulator[Unit] {
217-
// def apply(x: Unit, tp: Type): Unit = {
218-
// tp match {
219-
// case tvar: TypeVar if !tvar.isInstantiated =>
220-
// instantiate(tvar, fromBelow = false)
221-
// case _ =>
222-
// }
223-
// foldOver(x, tp)
224-
// }
225-
// }
226-
227216
def process(tp: Type): Type = {
228-
val res = apply(tp)
229-
// if (res && toMaximize) new UpperInstantiator().apply((), tp)
230-
res
217+
apply(tp)
231218
}
232219
}
233220

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import transform.TypeUtils._
4444
import reporting.trace
4545
import Nullables.{NotNullInfo, given _}
4646
import NullOpsDecorator._
47-
import config.Printers.debug
47+
import config.Printers.gadts
4848

4949
object Typer {
5050

@@ -2764,7 +2764,7 @@ class Typer extends Namer
27642764
*/
27652765
def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean = true)(using Context): Tree = {
27662766
val last = Thread.currentThread.getStackTrace()(2).toString;
2767-
trace/*.force*/(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) {
2767+
trace(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) {
27682768
record("adapt")
27692769
adapt1(tree, pt, locked, tryGadtHealing)
27702770
}
@@ -2774,7 +2774,6 @@ class Typer extends Namer
27742774
adapt(tree, pt, ctx.typerState.ownedVars)
27752775

27762776
private def adapt1(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = {
2777-
// assert(pt.exists && !pt.isInstanceOf[ExprType])
27782777
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported)
27792778
def methodStr = err.refStr(methPart(tree).tpe)
27802779

@@ -3224,19 +3223,15 @@ class Typer extends Namer
32243223
}
32253224

32263225
def adaptToSubType(wtp: Type): Tree = {
3227-
debug.println("adaptToSubType")
3228-
debug.println("// try converting a constant to the target type")
32293226
// try converting a constant to the target type
32303227
val folded = ConstFold(tree, pt)
32313228
if (folded ne tree)
32323229
return adaptConstant(folded, folded.tpe.asInstanceOf[ConstantType])
32333230

3234-
debug.println("// Try to capture wildcards in type")
32353231
val captured = captureWildcards(wtp)
32363232
if (captured `ne` wtp)
32373233
return readapt(tree.cast(captured))
32383234

3239-
debug.println("// drop type if prototype is Unit")
32403235
// drop type if prototype is Unit
32413236
if (pt isRef defn.UnitClass) {
32423237
// local adaptation makes sure every adapted tree conforms to its pt
@@ -3246,7 +3241,6 @@ class Typer extends Namer
32463241
return tpd.Block(tree1 :: Nil, Literal(Constant(())))
32473242
}
32483243

3249-
debug.println("// convert function literal to SAM closure")
32503244
// convert function literal to SAM closure
32513245
tree match {
32523246
case closure(Nil, id @ Ident(nme.ANON_FUN), _)
@@ -3264,28 +3258,28 @@ class Typer extends Namer
32643258
case _ =>
32653259
}
32663260

3267-
debug.println("// try GADT approximation")
3268-
val foo = Inferencing.approximateGADT(wtp)
3269-
debug.println(
3270-
i"""
3271-
foo = $foo
3261+
val approximation = Inferencing.approximateGADT(wtp)
3262+
gadts.println(
3263+
i"""GADT approximation {
3264+
approximation = $approximation
32723265
pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]}
32733266
ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty}
3267+
ctx.gadt = ${ctx.gadt.debugBoundsDescription}
32743268
pt.isMatchedBy = ${
32753269
if (pt.isInstanceOf[SelectionProto])
3276-
pt.asInstanceOf[SelectionProto].isMatchedBy(foo).toString
3270+
pt.asInstanceOf[SelectionProto].isMatchedBy(approximation).toString
32773271
else
32783272
"<not a SelectionProto>"
32793273
}
3274+
}
32803275
"""
32813276
)
32823277
pt match {
3283-
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(foo) =>
3284-
return tpd.Typed(tree, TypeTree(foo))
3278+
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(approximation) =>
3279+
return tpd.Typed(tree, TypeTree(approximation))
32853280
case _ => ;
32863281
}
32873282

3288-
debug.println("// try an extension method in scope")
32893283
// try an extension method in scope
32903284
pt match {
32913285
case SelectionProto(name, mbrType, _, _) =>
@@ -3303,33 +3297,34 @@ class Typer extends Namer
33033297
val app = tryExtension(using nestedCtx)
33043298
if (!app.isEmpty && !nestedCtx.reporter.hasErrors) {
33053299
nestedCtx.typerState.commit()
3306-
debug.println("returning ext meth in scope")
33073300
return ExtMethodApply(app)
33083301
}
33093302
case _ =>
33103303
}
33113304

3312-
debug.println("// try an implicit conversion")
33133305
// try an implicit conversion
33143306
val prevConstraint = ctx.typerState.constraint
3315-
def recover(failure: SearchFailureType) =
3316-
{
3317-
debug.println("recover")
3307+
def recover(failure: SearchFailureType) = {
33183308
if (isFullyDefined(wtp, force = ForceDegree.all) &&
33193309
ctx.typerState.constraint.ne(prevConstraint)) readapt(tree)
3320-
// else if ({
3321-
// debug.println(i"tryGadtHealing=$tryGadtHealing && \n\tctx.gadt.nonEmpty=${ctx.gadt.nonEmpty}")
3322-
// tryGadtHealing && ctx.gadt.nonEmpty
3323-
// })
3324-
// {
3325-
// debug.println("here")
3326-
// readapt(
3327-
// tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
3328-
// shouldTryGadtHealing = false,
3329-
// )
3330-
// }
3331-
else err.typeMismatch(tree, pt, failure)
3332-
}
3310+
else if (tryGadtHealing && ctx.gadt.nonEmpty) {
3311+
// try recovering with a GADT approximation
3312+
val nestedCtx = ctx.fresh.setNewTyperState()
3313+
val res =
3314+
readapt(
3315+
tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
3316+
shouldTryGadtHealing = false,
3317+
)(using nestedCtx)
3318+
if (!nestedCtx.reporter.hasErrors) {
3319+
// GADT recovery successful
3320+
nestedCtx.typerState.commit()
3321+
res
3322+
} else {
3323+
// otherwise fail with the error that would have been reported without the GADT recovery
3324+
err.typeMismatch(tree, pt, failure)
3325+
}
3326+
} else err.typeMismatch(tree, pt, failure)
3327+
}
33333328
if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then
33343329
if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then
33353330
ctx.error(em"the result of an implicit conversion must be more specific than $pt", tree.sourcePos)

tests/neg/boundspropagation.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,6 @@ object test3 {
2525
}
2626
}
2727

28-
// Example contributed by Jason.
29-
object test4 {
30-
class Base {
31-
type N
32-
33-
class Tree[-S, -T >: Option[S]]
34-
35-
def g(x: Any): Tree[_, _ <: Option[N]] = x match {
36-
case y: Tree[_, _] => y // error -- used to work (because of capture conversion?)
37-
}
38-
}
39-
}
40-
4128
class Test5 {
4229
"": ({ type U = this.type })#U // error
4330
}

tests/pos/boundspropagation.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,16 @@ object test2 {
2929
}
3030
}
3131
*/
32+
33+
// Example contributed by Jason.
34+
object test2 {
35+
class Base {
36+
type N
37+
38+
class Tree[-S, -T >: Option[S]]
39+
40+
def g(x: Any): Tree[_, _ <: Option[N]] = x match {
41+
case y: Tree[_, _] => y
42+
}
43+
}
44+
}

tests/pos/gadt-infer-ascription.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// test based on an example code by @Blaisorblade
2+
object GadtAscription {
3+
enum Var[G, A] {
4+
case Z[G, A]() extends Var[(A, G), A]
5+
case S[G, A, B](x: Var[G, A]) extends Var[(B, G), A]
6+
}
7+
8+
import Var._
9+
def evalVar[G, A](x: Var[G, A])(rho: G): A = x match {
10+
case _: Z[g, a] =>
11+
rho(0)
12+
case s: S[g, a, b] =>
13+
evalVar(s.x)(rho(1))
14+
}
15+
}

tests/pos/i7044.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object i7044 {
2+
case class Seg[T](pat:Pat[T], body:T)
3+
4+
trait Pat[T]
5+
object Pat {
6+
case class Expr() extends Pat[Int]
7+
case class Opt[S](el:Pat[S]) extends Pat[Option[S]]
8+
}
9+
10+
def test[T](s:Seg[T]):Int = s match {
11+
case Seg(Pat.Expr(),body) => body + 1
12+
case Seg(Pat.Opt(Pat.Expr()),body) => body.get
13+
}
14+
}

0 commit comments

Comments
 (0)