Skip to content

Commit 89a2104

Browse files
abgruszeckiradeusgd
authored andcommitted
Add GADT approximation in adaptToSubType
1 parent 6f51dbb commit 89a2104

File tree

5 files changed

+124
-8
lines changed

5 files changed

+124
-8
lines changed

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ sealed abstract class GadtConstraint extends Showable {
4444
def contains(sym: Symbol)(implicit ctx: Context): Boolean
4545

4646
def isEmpty: Boolean
47+
final def nonEmpty: Boolean = !isEmpty
4748

4849
/** See [[ConstraintHandling.approximation]] */
4950
def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ object Erasure {
954954
(stats2.filter(!_.isEmpty), finalCtx)
955955
}
956956

957-
override def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
957+
override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree =
958958
trace(i"adapting ${tree.showSummary}: ${tree.tpe} to $pt", show = true) {
959959
if ctx.phase != ctx.erasurePhase && ctx.phase != ctx.erasurePhase.next then
960960
// this can happen when reading annotations loaded during erasure,

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ class TreeChecker extends Phase with SymTransformer {
511511
override def ensureNoLocalRefs(tree: Tree, pt: Type, localSyms: => List[Symbol])(using Context): Tree =
512512
tree
513513

514-
override def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
514+
override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = {
515515
def isPrimaryConstructorReturn =
516516
ctx.owner.isPrimaryConstructor && pt.isRef(ctx.owner.owner) && tree.tpe.isRef(defn.UnitClass)
517517
def infoStr(tp: Type) = tp match {

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import collection.mutable
1919
import scala.annotation.internal.sharable
2020
import scala.annotation.threadUnsafe
2121

22+
import config.Printers.debug
23+
2224
object Inferencing {
2325

2426
import tpd._
@@ -162,6 +164,73 @@ object Inferencing {
162164
)
163165
}
164166

167+
def approximateGADT(tp: Type)(implicit ctx: Context): Type = {
168+
val map = new IsFullyDefinedAccumulator2
169+
val res = map(tp)
170+
assert(!map.failed)
171+
debug.println(i"approximateGADT( $tp ) = $res // {${tp.toString}}")
172+
res
173+
}
174+
175+
private class IsFullyDefinedAccumulator2(implicit ctx: Context) extends TypeMap {
176+
177+
var failed = false
178+
179+
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = {
180+
val inst = tvar.instantiate(fromBelow)
181+
typr.println(i"forced instantiation of ${tvar.origin} = $inst")
182+
inst
183+
}
184+
185+
private def instDirection2(sym: Symbol)(implicit ctx: Context): Int = {
186+
val constrained = ctx.gadt.fullBounds(sym)
187+
val original = sym.info.bounds
188+
val cmp = ctx.typeComparer
189+
val approxBelow =
190+
if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0
191+
val approxAbove =
192+
if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0
193+
approxAbove - approxBelow
194+
}
195+
196+
private[this] var toMaximize: Boolean = false
197+
198+
def apply(tp: Type): Type = tp.dealias match {
199+
case tp @ TypeRef(qual, nme) if (qual eq NoPrefix) && ctx.gadt.contains(tp.symbol) =>
200+
val sym = tp.symbol
201+
val res =
202+
ctx.gadt.approximation(sym, fromBelow = variance < 0)
203+
204+
debug.println(i"approximated $tp ~~ $res")
205+
206+
res
207+
208+
case _: WildcardType | _: ProtoType =>
209+
failed = true
210+
NoType
211+
212+
case tp =>
213+
mapOver(tp)
214+
}
215+
216+
// private class UpperInstantiator(implicit ctx: Context) extends TypeAccumulator[Unit] {
217+
// def apply(x: Unit, tp: Type): Unit = {
218+
// tp match {
219+
// case tvar: TypeVar if !tvar.isInstantiated =>
220+
// instantiate(tvar, fromBelow = false)
221+
// case _ =>
222+
// }
223+
// foldOver(x, tp)
224+
// }
225+
// }
226+
227+
def process(tp: Type): Type = {
228+
val res = apply(tp)
229+
// if (res && toMaximize) new UpperInstantiator().apply((), tp)
230+
res
231+
}
232+
}
233+
165234
/** For all type parameters occurring in `tp`:
166235
* If the bounds of `tp` in the current constraint are equal wrt =:=,
167236
* instantiate the type parameter to the lower bound's approximation

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import transform.TypeUtils._
4444
import reporting.trace
4545
import Nullables.{NotNullInfo, given _}
4646
import NullOpsDecorator._
47+
import config.Printers.debug
4748

4849
object Typer {
4950

@@ -2771,20 +2772,23 @@ class Typer extends Namer
27712772
* If all this fails, error
27722773
* Parameters as for `typedUnadapted`.
27732774
*/
2774-
def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
2775-
trace(i"adapting $tree to $pt", typr, show = true) {
2775+
def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean = true)(using Context): Tree = {
2776+
val last = Thread.currentThread.getStackTrace()(2).toString;
2777+
trace/*.force*/(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) {
27762778
record("adapt")
2777-
adapt1(tree, pt, locked)
2779+
adapt1(tree, pt, locked, tryGadtHealing)
27782780
}
2781+
}
27792782

27802783
final def adapt(tree: Tree, pt: Type)(using Context): Tree =
27812784
adapt(tree, pt, ctx.typerState.ownedVars)
27822785

2783-
private def adapt1(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
2786+
private def adapt1(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = {
2787+
// assert(pt.exists && !pt.isInstanceOf[ExprType])
27842788
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported)
27852789
def methodStr = err.refStr(methPart(tree).tpe)
27862790

2787-
def readapt(tree: Tree)(using Context) = adapt(tree, pt, locked)
2791+
def readapt(tree: Tree, shouldTryGadtHealing: Boolean = tryGadtHealing)(using Context) = adapt(tree, pt, locked, shouldTryGadtHealing)
27882792
def readaptSimplified(tree: Tree)(using Context) = readapt(simplify(tree, pt, locked))
27892793

27902794
def missingArgs(mt: MethodType) = {
@@ -3239,16 +3243,19 @@ class Typer extends Namer
32393243
}
32403244

32413245
def adaptToSubType(wtp: Type): Tree = {
3246+
debug.println("adaptToSubType")
3247+
debug.println("// try converting a constant to the target type")
32423248
// try converting a constant to the target type
32433249
val folded = ConstFold(tree, pt)
32443250
if (folded ne tree)
32453251
return adaptConstant(folded, folded.tpe.asInstanceOf[ConstantType])
32463252

3247-
// Try to capture wildcards in type
3253+
debug.println("// Try to capture wildcards in type")
32483254
val captured = captureWildcards(wtp)
32493255
if (captured `ne` wtp)
32503256
return readapt(tree.cast(captured))
32513257

3258+
debug.println("// drop type if prototype is Unit")
32523259
// drop type if prototype is Unit
32533260
if (pt isRef defn.UnitClass) {
32543261
// local adaptation makes sure every adapted tree conforms to its pt
@@ -3258,6 +3265,7 @@ class Typer extends Namer
32583265
return tpd.Block(tree1 :: Nil, Literal(Constant(())))
32593266
}
32603267

3268+
debug.println("// convert function literal to SAM closure")
32613269
// convert function literal to SAM closure
32623270
tree match {
32633271
case closure(Nil, id @ Ident(nme.ANON_FUN), _)
@@ -3275,6 +3283,28 @@ class Typer extends Namer
32753283
case _ =>
32763284
}
32773285

3286+
debug.println("// try GADT approximation")
3287+
val foo = Inferencing.approximateGADT(wtp)
3288+
debug.println(
3289+
i"""
3290+
foo = $foo
3291+
pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]}
3292+
ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty}
3293+
pt.isMatchedBy = ${
3294+
if (pt.isInstanceOf[SelectionProto])
3295+
pt.asInstanceOf[SelectionProto].isMatchedBy(foo).toString
3296+
else
3297+
"<not a SelectionProto>"
3298+
}
3299+
"""
3300+
)
3301+
pt match {
3302+
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(foo) =>
3303+
return tpd.Typed(tree, TypeTree(foo))
3304+
case _ => ;
3305+
}
3306+
3307+
debug.println("// try an extension method in scope")
32783308
// try an extension method in scope
32793309
pt match {
32803310
case SelectionProto(name, mbrType, _, _) =>
@@ -3292,17 +3322,33 @@ class Typer extends Namer
32923322
val app = tryExtension(using nestedCtx)
32933323
if (!app.isEmpty && !nestedCtx.reporter.hasErrors) {
32943324
nestedCtx.typerState.commit()
3325+
debug.println("returning ext meth in scope")
32953326
return ExtMethodApply(app)
32963327
}
32973328
case _ =>
32983329
}
32993330

3331+
debug.println("// try an implicit conversion")
33003332
// try an implicit conversion
33013333
val prevConstraint = ctx.typerState.constraint
33023334
def recover(failure: SearchFailureType) =
3335+
{
3336+
debug.println("recover")
33033337
if (isFullyDefined(wtp, force = ForceDegree.all) &&
33043338
ctx.typerState.constraint.ne(prevConstraint)) readapt(tree)
3339+
// else if ({
3340+
// debug.println(i"tryGadtHealing=$tryGadtHealing && \n\tctx.gadt.nonEmpty=${ctx.gadt.nonEmpty}")
3341+
// tryGadtHealing && ctx.gadt.nonEmpty
3342+
// })
3343+
// {
3344+
// debug.println("here")
3345+
// readapt(
3346+
// tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
3347+
// shouldTryGadtHealing = false,
3348+
// )
3349+
// }
33053350
else err.typeMismatch(tree, pt, failure)
3351+
}
33063352
if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then
33073353
if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then
33083354
ctx.error(em"the result of an implicit conversion must be more specific than $pt", tree.sourcePos)

0 commit comments

Comments
 (0)