Skip to content

Commit e0e1c80

Browse files
authored
Merge pull request #6767 from dotty-staging/lambda-invariant-check
Lambda invariant check
2 parents 33666e9 + 1f8650c commit e0e1c80

File tree

7 files changed

+44
-20
lines changed

7 files changed

+44
-20
lines changed

compiler/src/dotty/tools/dotc/core/Phases.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ object Phases {
353353
private[this] var myErasedTypes = false
354354
private[this] var myFlatClasses = false
355355
private[this] var myRefChecked = false
356+
private[this] var myLambdaLifted = false
356357

357358
private[this] var mySameMembersStartId = NoPhaseId
358359
private[this] var mySameParentsStartId = NoPhaseId
@@ -371,6 +372,7 @@ object Phases {
371372
final def erasedTypes: Boolean = myErasedTypes // Phase is after erasure
372373
final def flatClasses: Boolean = myFlatClasses // Phase is after flatten
373374
final def refChecked: Boolean = myRefChecked // Phase is after RefChecks
375+
final def lambdaLifted: Boolean = myLambdaLifted // Phase is after LambdaLift
374376

375377
final def sameMembersStartId: Int = mySameMembersStartId
376378
// id of first phase where all symbols are guaranteed to have the same members as in this phase
@@ -385,9 +387,10 @@ object Phases {
385387
assert(start <= Periods.MaxPossiblePhaseId, s"Too many phases, Period bits overflow")
386388
myBase = base
387389
myPeriod = Period(NoRunId, start, end)
388-
myErasedTypes = prev.getClass == classOf[Erasure] || prev.erasedTypes
389-
myFlatClasses = prev.getClass == classOf[Flatten] || prev.flatClasses
390-
myRefChecked = prev.getClass == classOf[RefChecks] || prev.refChecked
390+
myErasedTypes = prev.getClass == classOf[Erasure] || prev.erasedTypes
391+
myFlatClasses = prev.getClass == classOf[Flatten] || prev.flatClasses
392+
myRefChecked = prev.getClass == classOf[RefChecks] || prev.refChecked
393+
myLambdaLifted = prev.getClass == classOf[LambdaLift] || prev.lambdaLifted
391394
mySameMembersStartId = if (changesMembers) id else prev.sameMembersStartId
392395
mySameParentsStartId = if (changesParents) id else prev.sameParentsStartId
393396
mySameBaseTypesStartId = if (changesBaseTypes) id else prev.sameBaseTypesStartId

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,16 @@ class TreeChecker extends Phase with SymTransformer {
178178
res
179179
}
180180

181+
// used to check invariant of lambda encoding
182+
var nestingBlock: untpd.Block | Null = null
183+
private def withBlock[T](block: untpd.Block)(op: => T): T = {
184+
val outerBlock = nestingBlock
185+
nestingBlock = block
186+
val res = op
187+
nestingBlock = outerBlock
188+
res
189+
}
190+
181191
def assertDefined(tree: untpd.Tree)(implicit ctx: Context): Unit =
182192
if (tree.symbol.maybeOwner.isTerm)
183193
assert(nowDefinedSyms contains tree.symbol, i"undefined symbol ${tree.symbol} at line " + tree.sourcePos.line)
@@ -407,8 +417,22 @@ class TreeChecker extends Phase with SymTransformer {
407417
}
408418
}
409419

420+
override def typedClosure(tree: untpd.Closure, pt: Type)(implicit ctx: Context): Tree = {
421+
if (!ctx.phase.lambdaLifted) nestingBlock match {
422+
case block @ Block((meth : DefDef) :: Nil, closure: Closure) =>
423+
assert(meth.symbol == closure.meth.symbol, "closure.meth symbol not equal to method symbol. Block: " + block.show)
424+
425+
case block: untpd.Block =>
426+
assert(false, "function literal are not properly formed as a block of DefDef and Closure. Found: " + tree.show + " Nesting block: " + block.show)
427+
428+
case null =>
429+
assert(false, "function literal are not properly formed as a block of DefDef and Closure. Found: " + tree.show + " Nesting block: null")
430+
}
431+
super.typedClosure(tree, pt)
432+
}
433+
410434
override def typedBlock(tree: untpd.Block, pt: Type)(implicit ctx: Context): Tree =
411-
withDefinedSyms(tree.stats) { super.typedBlock(tree, pt) }
435+
withBlock(tree) { withDefinedSyms(tree.stats) { super.typedBlock(tree, pt) } }
412436

413437
override def typedInlined(tree: untpd.Inlined, pt: Type)(implicit ctx: Context): Tree =
414438
withDefinedSyms(tree.bindings) { super.typedInlined(tree, pt) }

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ class Typer extends Namer
730730
*/
731731
protected def ensureNoLocalRefs(tree: Tree, pt: Type, localSyms: => List[Symbol])(implicit ctx: Context): Tree = {
732732
def ascribeType(tree: Tree, pt: Type): Tree = tree match {
733-
case block @ Block(stats, expr) =>
733+
case block @ Block(stats, expr) if !expr.isInstanceOf[Closure] =>
734734
val expr1 = ascribeType(expr, pt)
735735
cpy.Block(block)(stats, expr1) withType expr1.tpe // no assignType here because avoid is redundant
736736
case _ =>
@@ -3081,7 +3081,7 @@ class Typer extends Namer
30813081
}
30823082

30833083
tree match {
3084-
case _: MemberDef | _: PackageDef | _: Import | _: WithoutTypeOrPos[_] => tree
3084+
case _: MemberDef | _: PackageDef | _: Import | _: WithoutTypeOrPos[_] | _: Closure => tree
30853085
case _ => tree.tpe.widen match {
30863086
case tp: FlexType =>
30873087
ensureReported(tp)

tests/neg/erased-5.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ object Test {
33
type UU[T] = erased T => Int
44

55
def main(args: Array[String]): Unit = {
6-
fun { x =>
7-
x // error: Cannot use `erased` value in a context that is not `erased`
6+
fun { x => // error: `Int => Int` not compatible with `erased Int => Int`
7+
x
88
}
99

1010
fun {

tests/neg/i2146.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
object Test {
2-
case class A()
3-
case class B()
4-
5-
def foo[A, B]: given A => given B => Int = { given b: B =>
6-
42 // error: found Int, required: given A => given B => Int
1+
class Test {
2+
def foo[A, B]: given A => given B => Int = { given b: B => // error: found Int, required: given A => given B => Int
3+
42
74
}
85
}

tests/neg/i5311.check

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
-- [E007] Type Mismatch Error: tests/neg/i5311.scala:11:27 -------------------------------------------------------------
1+
-- [E007] Type Mismatch Error: tests/neg/i5311.scala:11:9 --------------------------------------------------------------
22
11 | baz((x : s.T[Int]) => x) // error
3-
| ^
4-
| Found: s.T[Int] => s.T[Int]
5-
| Required: m.Foo
3+
| ^^^^^^^^^^^^^^^^^^
4+
| Found: s.T[Int] => s.T[Int]
5+
| Required: m.Foo

tests/neg/i5592.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ object Test {
2020
}
2121

2222
val eqSymmetric2: Forall[[x] =>> (y: Obj) => (EQ[x, y.type]) => (EQ[y.type, x])] = {
23-
{ x: Obj => { y: Obj => { xEqy: EQ[x.type, y.type] => xEqy.commute } } } // error // error
23+
{ x: Obj => { y: Obj => { xEqy: EQ[x.type, y.type] => xEqy.commute } } } // error
2424
}
2525

2626
val eqSymmetric3: Forall[[x] =>> Forall[[y] =>> EQ[x, y] => EQ[y, x]]] = {
27-
{ x: Obj => { y: Obj => { xEqy: EQ[x.type, y.type] => xEqy.commute } } } // error // error
27+
{ x: Obj => { y: Obj => { xEqy: EQ[x.type, y.type] => xEqy.commute } } } // error
2828
}
2929
}

0 commit comments

Comments
 (0)