Skip to content

Commit 940f517

Browse files
committed
Check variances of type lambdas after Typer
1 parent 5934d7b commit 940f517

File tree

8 files changed

+61
-60
lines changed

8 files changed

+61
-60
lines changed

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ package transform
44
import dotty.tools.dotc.ast.{Trees, tpd, untpd}
55
import scala.collection.mutable
66
import core._
7-
import typer.Checking
7+
import typer.{Checking, VarianceChecker}
88
import Types._, Contexts._, Names._, Flags._, DenotTransformers._, Phases._
99
import SymDenotations._, StdNames._, Annotations._, Trees._, Scopes._
1010
import Decorators._
@@ -296,6 +296,9 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
296296
// when trying to typecheck self types which are intersections.
297297
Checking.checkNonCyclicInherited(tree.tpe, tree.left.tpe :: tree.right.tpe :: Nil, EmptyScope, tree.pos)
298298
super.transform(tree)
299+
case tree: LambdaTypeTree =>
300+
VarianceChecker.checkLambda(tree)
301+
super.transform(tree)
299302
case Import(expr, selectors) =>
300303
val exprTpe = expr.tpe
301304
val seen = mutable.Set.empty[Name]

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,8 +1307,7 @@ class Typer extends Namer
13071307
index(tparams)
13081308
val tparams1 = tparams.mapconserve(typed(_).asInstanceOf[TypeDef])
13091309
val body1 = typedType(tree.body)
1310-
VarianceChecker.checkLambda(
1311-
assignType(cpy.LambdaTypeTree(tree)(tparams1, body1), tparams1, body1))
1310+
assignType(cpy.LambdaTypeTree(tree)(tparams1, body1), tparams1, body1)
13121311
}
13131312

13141313
def typedMatchTypeTree(tree: untpd.MatchTypeTree, pt: Type)(implicit ctx: Context): Tree = {
@@ -1501,8 +1500,7 @@ class Typer extends Namer
15011500
case rhs @ LambdaTypeTree(tparams, body) =>
15021501
val tparams1 = tparams.map(typed(_)).asInstanceOf[List[TypeDef]]
15031502
val body1 = typedType(body)
1504-
VarianceChecker.checkLambda(
1505-
assignType(cpy.LambdaTypeTree(rhs)(tparams1, body1), tparams1, body1))
1503+
assignType(cpy.LambdaTypeTree(rhs)(tparams1, body1), tparams1, body1)
15061504
case rhs =>
15071505
typedType(rhs)
15081506
}

compiler/src/dotty/tools/dotc/typer/VarianceChecker.scala

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,38 +25,35 @@ object VarianceChecker {
2525
* Note: this is achieved by a mechanism separate from checking class type parameters.
2626
* Question: Can the two mechanisms be combined in one?
2727
*/
28-
def checkLambda(tree: tpd.LambdaTypeTree)(implicit ctx: Context): tree.type = {
29-
tree.tpe match {
30-
case tl: HKTypeLambda =>
31-
val checkOK = new TypeAccumulator[Boolean] {
32-
def error(tref: TypeParamRef) = {
33-
val VariantName(paramName, v) = tl.paramNames(tref.paramNum).toTermName
34-
val paramVarianceStr = if (v == 0) "contra" else "co"
35-
val occursStr = variance match {
36-
case -1 => "contra"
37-
case 0 => "non"
38-
case 1 => "co"
39-
}
40-
val pos = tree.tparams
41-
.find(_.name.toTermName == paramName)
42-
.map(_.pos)
43-
.getOrElse(tree.pos)
44-
ctx.error(em"${paramVarianceStr}variant type parameter $paramName occurs in ${occursStr}variant position in ${tl.resType}", pos)
28+
def checkLambda(tree: tpd.LambdaTypeTree)(implicit ctx: Context): Unit = tree.tpe match {
29+
case tl: HKTypeLambda =>
30+
val checkOK = new TypeAccumulator[Boolean] {
31+
def error(tref: TypeParamRef) = {
32+
val VariantName(paramName, v) = tl.paramNames(tref.paramNum).toTermName
33+
val paramVarianceStr = if (v == 0) "contra" else "co"
34+
val occursStr = variance match {
35+
case -1 => "contra"
36+
case 0 => "non"
37+
case 1 => "co"
4538
}
46-
def apply(x: Boolean, t: Type) = x && {
47-
t match {
48-
case tref: TypeParamRef if tref.binder `eq` tl =>
49-
val v = tl.typeParams(tref.paramNum).paramVariance
50-
varianceConforms(variance, v) || { error(tref); false }
51-
case _ =>
52-
foldOver(x, t)
53-
}
39+
val pos = tree.tparams
40+
.find(_.name.toTermName == paramName)
41+
.map(_.pos)
42+
.getOrElse(tree.pos)
43+
ctx.error(em"${paramVarianceStr}variant type parameter $paramName occurs in ${occursStr}variant position in ${tl.resType}", pos)
44+
}
45+
def apply(x: Boolean, t: Type) = x && {
46+
t match {
47+
case tref: TypeParamRef if tref.binder `eq` tl =>
48+
val v = tl.typeParams(tref.paramNum).paramVariance
49+
varianceConforms(variance, v) || { error(tref); false }
50+
case _ =>
51+
foldOver(x, t)
5452
}
5553
}
56-
checkOK.apply(true, tl.resType)
57-
case _ =>
58-
}
59-
tree
54+
}
55+
checkOK.apply(true, tl.resType)
56+
case _ =>
6057
}
6158
}
6259

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
object Test extends App {
2+
3+
trait Ord[X]
4+
5+
type TL1 = [X <: Ord[X]] => (X, X)
6+
7+
class C extends Ord[C]
8+
9+
type T1 = TL1[Int] // error: Type argument Int does not conform to upper bound Test.Ord[LazyRef(Int)
10+
type T2 = TL1[C] // OK
11+
12+
class Ref[X](init: X) {
13+
var x: X = init
14+
}
15+
16+
type TL3 = [+X] => Ref[X] // error: covariant type parameter X occurs in nonvariant position in Test.Ref[X]
17+
type TL4[-X] = X => X // error: contravariant type parameter X occurs in covariant position in X => X
18+
19+
def f[F <: [+X] => Any](x: F[String]): F[Any] = x
20+
21+
val sref = new Ref[String]("abc")
22+
val aref: Ref[Any] = f[TL3](sref)
23+
aref.x = 1
24+
val s: String = sref.x
25+
26+
}

tests/neg/type-lambdas.scala

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,4 @@ object Test extends App {
55
type TL1 = [X <: Ord[X]] => (X, X) // OK
66
type TL2 = [X >: Ord[X]] => (X, X) // error: illegal cyclic reference: lower bound Test.Ord[X] of type X refers back to the type itself
77

8-
class C extends Ord[C]
9-
10-
type T1 = TL1[Int] // will be discovered later
11-
type T2 = TL1[C] // OK
12-
13-
class Ref[X](init: X) {
14-
var x: X = init
15-
}
16-
17-
type TL3 = [+X] => Ref[X] // error: covariant type parameter X occurs in nonvariant position in Test.Ref[X]
18-
type TL4[-X] = X => X // error: contravariant type parameter X occurs in covariant position in X => X
19-
20-
def f[F <: [+X] => Any](x: F[String]): F[Any] = x
21-
22-
val sref = new Ref[String]("abc")
23-
val aref: Ref[Any] = f[TL3](sref)
24-
aref.x = 1
25-
val s: String = sref.x
26-
27-
28-
29-
308
}

tests/pos/polytypes.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
object Test {
22

3-
type T = [+X] => (List[X] => List[X])
3+
type T = [X] => (List[X] => List[X])
44

55
def reverse[X](xs: List[X]): List[X] = ???
66

tests/pos/reference/type-lambdas.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,5 @@ object Test {
77
type CTL = [X] => [Y] => (X, Y)
88
type T3 = CTL[Int][String]
99

10-
type T2[+X <: X => X] = Any
11-
class C[+X <: X => Unit]
10+
type T2[+X <: X => X] = Any // OK - variance is not checked in param bounds
1211
}

tests/pos/seqtype-cycle/Test2.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ package object scala {
33
type Throwable = java.lang.Throwable
44
type IndexOutOfBoundsException = java.lang.IndexOutOfBoundsException
55

6-
type Seq[+A] = scala.collection.Seq[A]
6+
type Seq[A] = scala.collection.Seq[A]
77
val Seq = scala.collection.Seq
88
}

0 commit comments

Comments
 (0)