Skip to content

Commit 03e34a2

Browse files
abgruszeckiradeusgd
authored andcommitted
Add GADT approximation in adaptToSubType
1 parent 624cc47 commit 03e34a2

File tree

5 files changed

+124
-8
lines changed

5 files changed

+124
-8
lines changed

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ sealed abstract class GadtConstraint extends Showable {
4444
def contains(sym: Symbol)(implicit ctx: Context): Boolean
4545

4646
def isEmpty: Boolean
47+
final def nonEmpty: Boolean = !isEmpty
4748

4849
/** See [[ConstraintHandling.approximation]] */
4950
def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ object Erasure {
954954
(stats2.filter(!_.isEmpty), finalCtx)
955955
}
956956

957-
override def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
957+
override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree =
958958
trace(i"adapting ${tree.showSummary}: ${tree.tpe} to $pt", show = true) {
959959
if ctx.phase != ctx.erasurePhase && ctx.phase != ctx.erasurePhase.next then
960960
// this can happen when reading annotations loaded during erasure,

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ class TreeChecker extends Phase with SymTransformer {
511511
override def ensureNoLocalRefs(tree: Tree, pt: Type, localSyms: => List[Symbol])(using Context): Tree =
512512
tree
513513

514-
override def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
514+
override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = {
515515
def isPrimaryConstructorReturn =
516516
ctx.owner.isPrimaryConstructor && pt.isRef(ctx.owner.owner) && tree.tpe.isRef(defn.UnitClass)
517517
def infoStr(tp: Type) = tp match {

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import collection.mutable
1919
import scala.annotation.internal.sharable
2020
import scala.annotation.threadUnsafe
2121

22+
import config.Printers.debug
23+
2224
object Inferencing {
2325

2426
import tpd._
@@ -162,6 +164,73 @@ object Inferencing {
162164
)
163165
}
164166

167+
def approximateGADT(tp: Type)(implicit ctx: Context): Type = {
168+
val map = new IsFullyDefinedAccumulator2
169+
val res = map(tp)
170+
assert(!map.failed)
171+
debug.println(i"approximateGADT( $tp ) = $res // {${tp.toString}}")
172+
res
173+
}
174+
175+
private class IsFullyDefinedAccumulator2(implicit ctx: Context) extends TypeMap {
176+
177+
var failed = false
178+
179+
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = {
180+
val inst = tvar.instantiate(fromBelow)
181+
typr.println(i"forced instantiation of ${tvar.origin} = $inst")
182+
inst
183+
}
184+
185+
private def instDirection2(sym: Symbol)(implicit ctx: Context): Int = {
186+
val constrained = ctx.gadt.fullBounds(sym)
187+
val original = sym.info.bounds
188+
val cmp = ctx.typeComparer
189+
val approxBelow =
190+
if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0
191+
val approxAbove =
192+
if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0
193+
approxAbove - approxBelow
194+
}
195+
196+
private[this] var toMaximize: Boolean = false
197+
198+
def apply(tp: Type): Type = tp.dealias match {
199+
case tp @ TypeRef(qual, nme) if (qual eq NoPrefix) && ctx.gadt.contains(tp.symbol) =>
200+
val sym = tp.symbol
201+
val res =
202+
ctx.gadt.approximation(sym, fromBelow = variance < 0)
203+
204+
debug.println(i"approximated $tp ~~ $res")
205+
206+
res
207+
208+
case _: WildcardType | _: ProtoType =>
209+
failed = true
210+
NoType
211+
212+
case tp =>
213+
mapOver(tp)
214+
}
215+
216+
// private class UpperInstantiator(implicit ctx: Context) extends TypeAccumulator[Unit] {
217+
// def apply(x: Unit, tp: Type): Unit = {
218+
// tp match {
219+
// case tvar: TypeVar if !tvar.isInstantiated =>
220+
// instantiate(tvar, fromBelow = false)
221+
// case _ =>
222+
// }
223+
// foldOver(x, tp)
224+
// }
225+
// }
226+
227+
def process(tp: Type): Type = {
228+
val res = apply(tp)
229+
// if (res && toMaximize) new UpperInstantiator().apply((), tp)
230+
res
231+
}
232+
}
233+
165234
/** For all type parameters occurring in `tp`:
166235
* If the bounds of `tp` in the current constraint are equal wrt =:=,
167236
* instantiate the type parameter to the lower bound's approximation

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import transform.TypeUtils._
4444
import reporting.trace
4545
import Nullables.{NotNullInfo, given _}
4646
import NullOpsDecorator._
47+
import config.Printers.debug
4748

4849
object Typer {
4950

@@ -2762,20 +2763,23 @@ class Typer extends Namer
27622763
* If all this fails, error
27632764
* Parameters as for `typedUnadapted`.
27642765
*/
2765-
def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
2766-
trace(i"adapting $tree to $pt", typr, show = true) {
2766+
def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean = true)(using Context): Tree = {
2767+
val last = Thread.currentThread.getStackTrace()(2).toString;
2768+
trace/*.force*/(i"adapting (tryGadtHealing=$tryGadtHealing) $tree to $pt\n{callsite: $last}", typr, show = true) {
27672769
record("adapt")
2768-
adapt1(tree, pt, locked)
2770+
adapt1(tree, pt, locked, tryGadtHealing)
27692771
}
2772+
}
27702773

27712774
final def adapt(tree: Tree, pt: Type)(using Context): Tree =
27722775
adapt(tree, pt, ctx.typerState.ownedVars)
27732776

2774-
private def adapt1(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
2777+
private def adapt1(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = {
2778+
// assert(pt.exists && !pt.isInstanceOf[ExprType])
27752779
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported)
27762780
def methodStr = err.refStr(methPart(tree).tpe)
27772781

2778-
def readapt(tree: Tree)(using Context) = adapt(tree, pt, locked)
2782+
def readapt(tree: Tree, shouldTryGadtHealing: Boolean = tryGadtHealing)(using Context) = adapt(tree, pt, locked, shouldTryGadtHealing)
27792783
def readaptSimplified(tree: Tree)(using Context) = readapt(simplify(tree, pt, locked))
27802784

27812785
def missingArgs(mt: MethodType) = {
@@ -3221,16 +3225,19 @@ class Typer extends Namer
32213225
}
32223226

32233227
def adaptToSubType(wtp: Type): Tree = {
3228+
debug.println("adaptToSubType")
3229+
debug.println("// try converting a constant to the target type")
32243230
// try converting a constant to the target type
32253231
val folded = ConstFold(tree, pt)
32263232
if (folded ne tree)
32273233
return adaptConstant(folded, folded.tpe.asInstanceOf[ConstantType])
32283234

3229-
// Try to capture wildcards in type
3235+
debug.println("// Try to capture wildcards in type")
32303236
val captured = captureWildcards(wtp)
32313237
if (captured `ne` wtp)
32323238
return readapt(tree.cast(captured))
32333239

3240+
debug.println("// drop type if prototype is Unit")
32343241
// drop type if prototype is Unit
32353242
if (pt isRef defn.UnitClass) {
32363243
// local adaptation makes sure every adapted tree conforms to its pt
@@ -3240,6 +3247,7 @@ class Typer extends Namer
32403247
return tpd.Block(tree1 :: Nil, Literal(Constant(())))
32413248
}
32423249

3250+
debug.println("// convert function literal to SAM closure")
32433251
// convert function literal to SAM closure
32443252
tree match {
32453253
case closure(Nil, id @ Ident(nme.ANON_FUN), _)
@@ -3257,6 +3265,28 @@ class Typer extends Namer
32573265
case _ =>
32583266
}
32593267

3268+
debug.println("// try GADT approximation")
3269+
val foo = Inferencing.approximateGADT(wtp)
3270+
debug.println(
3271+
i"""
3272+
foo = $foo
3273+
pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]}
3274+
ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty}
3275+
pt.isMatchedBy = ${
3276+
if (pt.isInstanceOf[SelectionProto])
3277+
pt.asInstanceOf[SelectionProto].isMatchedBy(foo).toString
3278+
else
3279+
"<not a SelectionProto>"
3280+
}
3281+
"""
3282+
)
3283+
pt match {
3284+
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(foo) =>
3285+
return tpd.Typed(tree, TypeTree(foo))
3286+
case _ => ;
3287+
}
3288+
3289+
debug.println("// try an extension method in scope")
32603290
// try an extension method in scope
32613291
pt match {
32623292
case SelectionProto(name, mbrType, _, _) =>
@@ -3274,17 +3304,33 @@ class Typer extends Namer
32743304
val app = tryExtension(using nestedCtx)
32753305
if (!app.isEmpty && !nestedCtx.reporter.hasErrors) {
32763306
nestedCtx.typerState.commit()
3307+
debug.println("returning ext meth in scope")
32773308
return ExtMethodApply(app)
32783309
}
32793310
case _ =>
32803311
}
32813312

3313+
debug.println("// try an implicit conversion")
32823314
// try an implicit conversion
32833315
val prevConstraint = ctx.typerState.constraint
32843316
def recover(failure: SearchFailureType) =
3317+
{
3318+
debug.println("recover")
32853319
if (isFullyDefined(wtp, force = ForceDegree.all) &&
32863320
ctx.typerState.constraint.ne(prevConstraint)) readapt(tree)
3321+
// else if ({
3322+
// debug.println(i"tryGadtHealing=$tryGadtHealing && \n\tctx.gadt.nonEmpty=${ctx.gadt.nonEmpty}")
3323+
// tryGadtHealing && ctx.gadt.nonEmpty
3324+
// })
3325+
// {
3326+
// debug.println("here")
3327+
// readapt(
3328+
// tree = tpd.Typed(tree, TypeTree(Inferencing.approximateGADT(wtp))),
3329+
// shouldTryGadtHealing = false,
3330+
// )
3331+
// }
32873332
else err.typeMismatch(tree, pt, failure)
3333+
}
32883334
if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then
32893335
if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then
32903336
ctx.error(em"the result of an implicit conversion must be more specific than $pt", tree.sourcePos)

0 commit comments

Comments
 (0)