Skip to content

Commit 708a2cf

Browse files
committed
wip
1 parent 2ca1165 commit 708a2cf

File tree

5 files changed

+207
-36
lines changed

5 files changed

+207
-36
lines changed

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import NameKinds.UniqueName
1111
import util.Spans._
1212
import util.{Stats, SimpleIdentityMap}
1313
import Decorators._
14-
import config.Printers.{gadts, typr}
14+
import config.Printers.{gadts, typr, debug}
1515
import annotation.tailrec
1616
import reporting._
1717
import collection.mutable
@@ -171,37 +171,45 @@ object Inferencing {
171171
res
172172
}
173173

174-
/** This class is mostly based on IsFullyDefinedAccumulator.
175-
* It tries to approximate the given type based on the available GADT constraints.
176-
*/
174+
/** Approximates a type to get rid of as many GADT-constrained abstract types as possible. */
177175
private class ApproximateGadtAccumulator(implicit ctx: Context) extends TypeMap {
178176

179177
var failed = false
180178

181-
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = {
182-
val inst = tvar.instantiate(fromBelow)
183-
typr.println(i"forced instantiation of ${tvar.origin} = $inst")
184-
inst
185-
}
186-
187-
private def instDirection2(sym: Symbol)(implicit ctx: Context): Int = {
188-
val constrained = ctx.gadt.fullBounds(sym)
189-
val original = sym.info.bounds
190-
val cmp = ctx.typeComparer
191-
val approxBelow =
192-
if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0
193-
val approxAbove =
194-
if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0
195-
approxAbove - approxBelow
196-
}
197-
198-
private[this] var toMaximize: Boolean = false
199-
179+
/** GADT approximation proceeds differently from type variable approximation.
180+
*
181+
* Essentially, what we're doing is we're inferring a type ascription that
182+
* will remove as many GADT-constrained types as possible. This means that
183+
* we want to approximate type T to type S in such a way that no matter how
184+
* GADT-constrained types are instantiated, T <: S. In other words, the
185+
* relationship _necessarily_ must hold.
186+
*
187+
* We accomplish that by:
188+
* - replacing covariant occurences with upper GADT bound
189+
* - replacing contravariant occurences with lower GADT bound
190+
* - leaving invariant occurences alone
191+
*
192+
* Examples:
193+
* - If we have GADT cstr A <: Int, then for all A <: Int, Option[A] <: Option[Int].
194+
* Therefore, we can approximate Option[A] ~~ Option[Int].
195+
* - If we have A >: S <: T, then for all such A, A => A <: S => T. This
196+
* illustrates that it's fine to differently approximate different
197+
* occurences of same type.
198+
* - If we have A <: Int and F <: [A] => Option[A] (note the invariance),
199+
* then we should approximate F[A] ~~ Option[A]. That is, we should
200+
* respect the invariance of the type constructor.
201+
* - If we have A <: Option[B] and B <: Int, we approximate A ~~ Option[Int].
202+
* That is, we recursively approximate all nested GADT-constrained types.
203+
* This is certain to be sound (because we maintain necessary subtyping),
204+
* but not accurate.
205+
*/
200206
def apply(tp: Type): Type = tp.dealias match {
201-
case tp @ TypeRef(qual, nme) if (qual eq NoPrefix) && ctx.gadt.contains(tp.symbol) =>
207+
case tp @ TypeRef(qual, nme) if (qual eq NoPrefix)
208+
&& variance != 0
209+
&& ctx.gadt.contains(tp.symbol)
210+
=>
202211
val sym = tp.symbol
203-
val res =
204-
ctx.gadt.approximation(sym, fromBelow = variance < 0)
212+
val res = ctx.gadt.approximation(sym, fromBelow = variance < 0)
205213
gadts.println(i"approximated $tp ~~ $res")
206214
res
207215

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import collection.mutable
3333
import annotation.tailrec
3434
import Implicits._
3535
import util.Stats.record
36-
import config.Printers.{gadts, typr}
36+
import config.Printers.{gadts, typr, debug}
3737
import config.Feature._
3838
import config.SourceVersion._
3939
import rewrites.Rewrites.patch
@@ -3410,25 +3410,25 @@ class Typer extends Namer
34103410
case _ =>
34113411
}
34123412

3413-
val approximation = Inferencing.approximateGADT(wtp)
3413+
val gadtApprox = Inferencing.approximateGADT(wtp)
34143414
gadts.println(
34153415
i"""GADT approximation {
3416-
approximation = $approximation
3416+
approximation = $gadtApprox
34173417
pt.isInstanceOf[SelectionProto] = ${pt.isInstanceOf[SelectionProto]}
34183418
ctx.gadt.nonEmpty = ${ctx.gadt.nonEmpty}
34193419
ctx.gadt = ${ctx.gadt.debugBoundsDescription}
34203420
pt.isMatchedBy = ${
34213421
if (pt.isInstanceOf[SelectionProto])
3422-
pt.asInstanceOf[SelectionProto].isMatchedBy(approximation).toString
3422+
pt.asInstanceOf[SelectionProto].isMatchedBy(gadtApprox).toString
34233423
else
34243424
"<not a SelectionProto>"
34253425
}
34263426
}
34273427
"""
34283428
)
34293429
pt match {
3430-
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(approximation) =>
3431-
return tpd.Typed(tree, TypeTree(approximation))
3430+
case pt: SelectionProto if ctx.gadt.nonEmpty && pt.isMatchedBy(gadtApprox) =>
3431+
return tpd.Typed(tree, TypeTree(gadtApprox))
34323432
case _ => ;
34333433
}
34343434

@@ -3458,10 +3458,38 @@ class Typer extends Namer
34583458

34593459
// try an implicit conversion
34603460
val prevConstraint = ctx.typerState.constraint
3461-
def recover(failure: SearchFailureType) =
3461+
def recover(failure: SearchFailureType) = {
3462+
def canTryGADTHealing: Boolean = {
3463+
def isDummy = tree.hasAttachment(dummyTreeOfType.IsDummyTree)
3464+
tryGadtHealing // allow GADT healing only once to avoid a loop
3465+
&& ctx.gadt.nonEmpty // GADT healing only makes sense if there are GADT constraints present
3466+
&& !isDummy // avoid healing a dummy tree as it can lead to an error in a very specific case
3467+
}
3468+
34623469
if (isFullyDefined(wtp, force = ForceDegree.all) &&
34633470
ctx.typerState.constraint.ne(prevConstraint)) readapt(tree)
3464-
else err.typeMismatch(tree, pt, failure)
3471+
else if (canTryGADTHealing) {
3472+
// try recovering with a GADT approximation
3473+
// note: this seems be be important only in a very specific case
3474+
// where we select a member from so
3475+
val nestedCtx = ctx.fresh.setNewTyperState()
3476+
val ascribed = tpd.Typed(tree, TypeTree(gadtApprox))
3477+
val res =
3478+
readapt(
3479+
tree = ascribed,
3480+
shouldTryGadtHealing = false,
3481+
)(using nestedCtx)
3482+
if (!nestedCtx.reporter.hasErrors) {
3483+
// GADT recovery successful
3484+
nestedCtx.typerState.commit()
3485+
res
3486+
} else {
3487+
// otherwise fail with the error that would have been reported without the GADT recovery
3488+
err.typeMismatch(tree, pt, failure)
3489+
}
3490+
} else err.typeMismatch(tree, pt, failure)
3491+
}
3492+
34653493
if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then
34663494
if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then
34673495
ctx.error(em"the result of an implicit conversion must be more specific than $pt", tree.sourcePos)
@@ -3472,14 +3500,13 @@ class Typer extends Namer
34723500
checkImplicitConversionUseOK(found.symbol, tree.posd)
34733501
readapt(found)(using ctx.retractMode(Mode.ImplicitsEnabled))
34743502
case failure: SearchFailure =>
3475-
if (pt.isInstanceOf[ProtoType] && !failure.isAmbiguous) {
3503+
if (pt.isInstanceOf[ProtoType] && !failure.isAmbiguous) then
34763504
// don't report the failure but return the tree unchanged. This
34773505
// will cause a failure at the next level out, which usually gives
34783506
// a better error message. To compensate, store the encountered failure
34793507
// as an attachment, so that it can be reported later as an addendum.
34803508
rememberSearchFailure(tree, failure)
34813509
tree
3482-
}
34833510
else recover(failure.reason)
34843511
}
34853512
else recover(NoMatchingImplicits)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
object MemberHealing {
2+
enum SUB[-A, +B]:
3+
case Refl[S]() extends SUB[S, S]
4+
5+
def foo[T](t: T, ev: T SUB Int) =
6+
ev match { case SUB.Refl() =>
7+
t + 2
8+
}
9+
}
10+
11+
object ImplicitLookup {
12+
enum SUB[-A, +B]:
13+
case Refl[S]() extends SUB[S, S]
14+
15+
class Tag[T]
16+
17+
implicit val ti: Tag[Int] = Tag()
18+
19+
def foo[T](t: T, ev: T SUB Int) =
20+
ev match { case SUB.Refl() =>
21+
implicitly[Tag[Int]]
22+
}
23+
}
24+
25+
object GivenLookup {
26+
enum SUB[-A, +B]:
27+
case Refl[S]() extends SUB[S, S]
28+
29+
class Tag[T]
30+
31+
given ti as Tag[Int]
32+
33+
def foo[T](t: T, ev: T SUB Int) =
34+
ev match { case SUB.Refl() =>
35+
summon[Tag[Int]]
36+
}
37+
}
38+
39+
object ImplicitConversion {
40+
enum SUB[-A, +B]:
41+
case Refl[S]() extends SUB[S, S]
42+
43+
class Pow(self: Int):
44+
def **(other: Int): Int = math.pow(self, other).toInt
45+
46+
implicit def pow(i: Int): Pow = Pow(i)
47+
48+
def foo[T](t: T, ev: T SUB Int) =
49+
ev match { case SUB.Refl() =>
50+
t ** 2 // error // implementation limitation
51+
}
52+
53+
def bar[T](t: T, ev: T SUB Int) =
54+
ev match { case SUB.Refl() =>
55+
(t: Int) ** 2 // sanity check
56+
}
57+
}
58+
59+
object GivenConversion {
60+
enum SUB[-A, +B]:
61+
case Refl[S]() extends SUB[S, S]
62+
63+
class Pow(self: Int):
64+
def **(other: Int): Int = math.pow(self, other).toInt
65+
66+
given as Conversion[Int, Pow] = (i: Int) => Pow(i)
67+
68+
def foo[T](t: T, ev: T SUB Int) =
69+
ev match { case SUB.Refl() =>
70+
t ** 2 // error (implementation limitation)
71+
}
72+
73+
def bar[T](t: T, ev: T SUB Int) =
74+
ev match { case SUB.Refl() =>
75+
(t: Int) ** 2 // sanity check
76+
}
77+
}
78+
79+
object ExtensionMethod {
80+
enum SUB[-A, +B]:
81+
case Refl[S]() extends SUB[S, S]
82+
83+
extension (x: Int):
84+
def **(y: Int) = math.pow(x, y).toInt
85+
86+
def foo[T](t: T, ev: T SUB Int) =
87+
ev match { case SUB.Refl() =>
88+
t ** 2
89+
}
90+
}
91+
92+
object HKFun {
93+
enum SUB[-A, +B]:
94+
case Refl[S]() extends SUB[S, S]
95+
96+
enum HKSUB[-F[_], +G[_]]:
97+
case Refl[H[_]]() extends HKSUB[H, H]
98+
99+
def foo[F[_], T](ft: F[T], hkev: F HKSUB Option, ev: T SUB Int) =
100+
hkev match { case HKSUB.Refl() =>
101+
ev match { case SUB.Refl() =>
102+
// both should typecheck - we should respect invariance of F
103+
// (and not approximate its argument)
104+
// but also T <: Int b/c of ev
105+
val x: T = ft.get
106+
val y: Int = ft.get
107+
}
108+
}
109+
110+
enum COVHKSUB[-F[+_], +G[+_]]:
111+
case Refl[H[_]]() extends COVHKSUB[H, H]
112+
113+
def bar[F[+_], T](ft: F[T], hkev: F COVHKSUB Option, ev: T SUB Int) =
114+
hkev match { case COVHKSUB.Refl() =>
115+
ev match { case SUB.Refl() =>
116+
// Sanity check for `foo`
117+
// this is an error only because we blindly approximate covariant type arguments
118+
// if it stops being an error, `foo` should be re-thought
119+
val x: T = ft.get // error
120+
val y: Int = ft.get
121+
}
122+
}
123+
}
124+
125+
object NestedConstrained {
126+
enum SUB[-A, +B]:
127+
case Refl[S]() extends SUB[S, S]
128+
129+
def foo[A, B](a: A, ev1: A SUB Option[B], ev2: B SUB Int) =
130+
ev1 match { case SUB.Refl() =>
131+
ev2 match { case SUB.Refl() =>
132+
1 + "a"
133+
a.get : Int
134+
}
135+
}
136+
}
File renamed without changes.

0 commit comments

Comments
 (0)