Skip to content

Commit 78a29a4

Browse files
committed
Initial versions of Variances and CheckVariances
Not yet integrated or tested.
1 parent 3065790 commit 78a29a4

File tree

5 files changed

+229
-2
lines changed

5 files changed

+229
-2
lines changed

src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ class Definitions {
196196
lazy val Array_update = ctx.requiredMethod(ArrayClass, nme.update)
197197
lazy val Array_length = ctx.requiredMethod(ArrayClass, nme.length)
198198
lazy val Array_clone = ctx.requiredMethod(ArrayClass, nme.clone_)
199-
lazy val uncheckedStableClass: ClassSymbol = ctx.requiredClass("scala.annotation.unchecked.uncheckedStable")
200199

201200
lazy val UnitClass = valueClassSymbol("scala.Unit", BoxedUnitClass, java.lang.Void.TYPE, UnitEnc)
202201
lazy val BooleanClass = valueClassSymbol("scala.Boolean", BoxedBooleanClass, java.lang.Boolean.TYPE, BooleanEnc)
@@ -296,6 +295,8 @@ class Definitions {
296295
lazy val AnnotationDefaultAnnot = ctx.requiredClass("dotty.annotation.internal.AnnotationDefault")
297296
lazy val ThrowsAnnot = ctx.requiredClass("scala.throws")
298297
lazy val UncheckedAnnot = ctx.requiredClass("scala.unchecked")
298+
lazy val UncheckedStableAnnot = ctx.requiredClass("scala.annotation.unchecked.uncheckedStable")
299+
lazy val UncheckedVarianceAnnot = ctx.requiredClass("scala.annotation.unchecked.uncheckedVariance")
299300
lazy val VolatileAnnot = ctx.requiredClass("scala.volatile")
300301

301302
// convenient one-parameter method types

src/dotty/tools/dotc/core/Flags.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,9 @@ object Flags {
500500
/** A parameter or parameter accessor */
501501
final val ParamOrAccessor = Param | ParamAccessor
502502

503+
/** A type parameter or type parameter accessor */
504+
final val TypeParamOrAccessor = TypeParam | TypeParamAccessor
505+
503506
/** A covariant type parameter instance */
504507
final val LocalCovariant = allOf(Local, Covariant)
505508

src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ object SymDenotations {
402402
final def isStable(implicit ctx: Context) = {
403403
val isUnstable =
404404
(this is UnstableValue) ||
405-
ctx.isVolatile(info) && !hasAnnotation(defn.uncheckedStableClass)
405+
ctx.isVolatile(info) && !hasAnnotation(defn.UncheckedStableAnnot)
406406
(this is Stable) || isType || {
407407
if (isUnstable) false
408408
else { setFlag(Stable); true }
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import dotty.tools.dotc.ast.{ Trees, tpd }
5+
import core._
6+
import Types._, Contexts._, Flags._, Symbols._, Annotations._, Trees._
7+
import Decorators._
8+
import Variances._
9+
10+
object VarianceChecker {
11+
12+
case class VarianceError(tvar: Symbol, required: Variance)
13+
}
14+
15+
/** See comments at scala.reflect.internal.Variance.
16+
*/
17+
class VarianceChecker(implicit ctx: Context) {
18+
import VarianceChecker._
19+
import tpd._
20+
21+
private object Validator extends TypeAccumulator[Option[VarianceError]] {
22+
private var base: Symbol = _
23+
24+
/** The variance of a symbol occurrence of `tvar` seen at the level of the definition of `base`.
25+
* The search proceeds from `base` to the owner of `tvar`.
26+
* Initially the state is covariant, but it might change along the search.
27+
*/
28+
def relativeVariance(tvar: Symbol, base: Symbol, v: Variance = Covariant): Variance = {
29+
if (base.owner == tvar.owner) v
30+
else if ((base is Param) && base.owner.isTerm) relativeVariance(tvar, base.owner.owner, flip(v))
31+
else if (base.isTerm) Bivariant
32+
else if (base.isAliasType) relativeVariance(tvar, base.owner, Invariant)
33+
else relativeVariance(tvar, base.owner, v)
34+
}
35+
36+
def isUncheckedVariance(tp: Type): Boolean = tp match {
37+
case AnnotatedType(annot, tp1) =>
38+
annot.symbol == defn.UncheckedVarianceAnnot || isUncheckedVariance(tp1)
39+
case _ => false
40+
}
41+
42+
private def checkVarianceOfSymbol(tvar: Symbol): Option[VarianceError] = {
43+
val relative = relativeVariance(tvar, base)
44+
val required = Variances.compose(relative, this.variance)
45+
if (relative == Bivariant) None
46+
else {
47+
def tvar_s = s"$tvar (${tvar.variance}${tvar.showLocated})"
48+
def base_s = s"$base in ${base.owner}" + (if (base.owner.isClass) "" else " in " + base.owner.enclosingClass)
49+
ctx.log(s"verifying $tvar_s is $required at $base_s")
50+
if (tvar.variance == required) None
51+
else Some(VarianceError(tvar, required))
52+
}
53+
}
54+
55+
/** For PolyTypes, type parameters are skipped because they are defined
56+
* explicitly (their TypeDefs will be passed here.) For MethodTypes, the
57+
* same is true of the parameters (ValDefs) unless we are inside a
58+
* refinement, in which case they are checked from here.
59+
*/
60+
def apply(status: Option[VarianceError], tp: Type): Option[VarianceError] =
61+
if (status.isDefined) status
62+
else tp match {
63+
case tp: TypeRef =>
64+
val sym = tp.symbol
65+
if (sym.variance != 0 && base.isContainedIn(sym.owner)) checkVarianceOfSymbol(sym)
66+
else if (sym.isAliasType) this(status, sym.info)
67+
else foldOver(status, tp)
68+
case tp: MethodType =>
69+
this(status, tp.resultType) // params will be checked in their TypeDef nodes.
70+
case tp: PolyType =>
71+
this(status, tp.resultType) // params will be checked in their ValDef nodes.
72+
case AnnotatedType(annot, _) if annot.symbol == defn.UncheckedVarianceAnnot =>
73+
status
74+
case tp: ClassInfo =>
75+
???
76+
case _ =>
77+
foldOver(status, tp)
78+
}
79+
80+
def validateDefinition(base: Symbol): Option[VarianceError] = {
81+
val saved = this.base
82+
this.base = base
83+
try apply(None, base.info)
84+
finally this.base = saved
85+
}
86+
}
87+
88+
def varianceString(v: Variance) =
89+
if (v is Covariant) "covariant"
90+
else if (v is Contravariant) "contravariant"
91+
else "invariant"
92+
93+
object Traverser extends TreeTraverser {
94+
def checkVariance(sym: Symbol) = Validator.validateDefinition(sym) match {
95+
case Some(VarianceError(tvar, required)) =>
96+
ctx.error(
97+
i"${varianceString(tvar.flags)} $tvar occurs in ${varianceString(required)} position in type ${sym.info} of $sym",
98+
sym.pos)
99+
case None =>
100+
}
101+
102+
override def traverse(tree: Tree) = {
103+
def sym = tree.symbol
104+
// No variance check for object-private/protected methods/values.
105+
// Or constructors, or case class factory or extractor.
106+
def skip = (
107+
sym == NoSymbol
108+
|| sym.is(Local)
109+
|| sym.owner.isConstructor
110+
//|| sym.owner.isCaseApplyOrUnapply // not clear why needed
111+
)
112+
tree match {
113+
case defn: MemberDef if skip =>
114+
ctx.debuglog(s"Skipping variance check of ${sym.showDcl}")
115+
case tree: TypeDef =>
116+
checkVariance(sym)
117+
super.traverse(tree)
118+
case tree: ValDef =>
119+
checkVariance(sym)
120+
case DefDef(_, _, tparams, vparamss, _, _) =>
121+
checkVariance(sym)
122+
tparams foreach traverse
123+
vparamss foreach (_ foreach traverse)
124+
case Template(_, _, _, body) =>
125+
super.traverse(tree)
126+
case _ =>
127+
}
128+
}
129+
}
130+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import dotty.tools.dotc.ast.{Trees, tpd}
5+
import core._
6+
import Types._, Contexts._, Flags._, Symbols._, Annotations._, Trees._
7+
import Decorators._
8+
9+
object Variances {
10+
import tpd._
11+
12+
type Variance = FlagSet
13+
val Bivariant = VarianceFlags
14+
val Invariant = EmptyFlags
15+
16+
/** Flip between covariant and contravariant */
17+
def flip(v: Variance): Variance = {
18+
if (v == Covariant) Contravariant
19+
else if (v == Contravariant) Covariant
20+
else v
21+
}
22+
23+
/** Map everything below Bivariant to Invariant */
24+
def cut(v: Variance): Variance =
25+
if (v == Bivariant) v else Invariant
26+
27+
def compose(v: Variance, boundsVariance: Int) =
28+
if (boundsVariance == 1) v
29+
else if (boundsVariance == -1) flip(v)
30+
else cut(v)
31+
32+
/** Compute variance of type parameter `tparam' in types of all symbols `sym'. */
33+
def varianceInSyms(syms: List[Symbol])(tparam: Symbol)(implicit ctx: Context): Variance =
34+
(Bivariant /: syms) ((v, sym) => v & varianceInSym(sym)(tparam))
35+
36+
/** Compute variance of type parameter `tparam' in type of symbol `sym'. */
37+
def varianceInSym(sym: Symbol)(tparam: Symbol)(implicit ctx: Context): Variance =
38+
if (sym.isAliasType) cut(varianceInType(sym.info)(tparam))
39+
else varianceInType(sym.info)(tparam)
40+
41+
/** Compute variance of type parameter `tparam' in all types `tps'. */
42+
def varianceInTypes(tps: List[Type])(tparam: Symbol)(implicit ctx: Context): Variance =
43+
(Bivariant /: tps) ((v, tp) => v & varianceInType(tp)(tparam))
44+
45+
/** Compute variance of type parameter `tparam' in all type arguments
46+
* <code>tps</code> which correspond to formal type parameters `tparams1'.
47+
*/
48+
def varianceInArgs(tps: List[Type], tparams1: List[Symbol])(tparam: Symbol)(implicit ctx: Context): Variance = {
49+
var v: Variance = Bivariant;
50+
for ((tp, tparam1) <- tps zip tparams1) {
51+
val v1 = varianceInType(tp)(tparam)
52+
v = v & (if (tparam1.is(Covariant)) v1
53+
else if (tparam1.is(Contravariant)) flip(v1)
54+
else cut(v1))
55+
}
56+
v
57+
}
58+
59+
/** Compute variance of type parameter `tparam' in all type annotations `annots'. */
60+
def varianceInAnnots(annots: List[Annotation])(tparam: Symbol)(implicit ctx: Context): Variance = {
61+
(Bivariant /: annots) ((v, annot) => v & varianceInAnnot(annot)(tparam))
62+
}
63+
64+
/** Compute variance of type parameter `tparam' in type annotation `annot'. */
65+
def varianceInAnnot(annot: Annotation)(tparam: Symbol)(implicit ctx: Context): Variance = {
66+
varianceInType(annot.tree.tpe)(tparam)
67+
}
68+
69+
/** Compute variance of type parameter <code>tparam</code> in type <code>tp</code>. */
70+
def varianceInType(tp: Type)(tparam: Symbol)(implicit ctx: Context): Variance = tp match {
71+
case TermRef(pre, sym) =>
72+
varianceInType(pre)(tparam)
73+
case TypeRef(pre, sym) =>
74+
if (sym == tparam) Covariant else varianceInType(pre)(tparam)
75+
case tp @ TypeBounds(lo, hi) =>
76+
if (lo eq hi) compose(varianceInType(hi)(tparam), tp.variance)
77+
else flip(varianceInType(lo)(tparam)) & varianceInType(hi)(tparam)
78+
case tp @ RefinedType(parent, _) =>
79+
varianceInType(parent)(tparam) & varianceInType(tp.refinedInfo)(tparam)
80+
case tp @ MethodType(_, paramTypes) =>
81+
flip(varianceInTypes(paramTypes)(tparam)) & varianceInType(tp.resultType)(tparam)
82+
case ExprType(restpe) =>
83+
varianceInType(restpe)(tparam)
84+
case tp @ PolyType(_) =>
85+
flip(varianceInTypes(tp.paramBounds)(tparam)) & varianceInType(tp.resultType)(tparam)
86+
case AnnotatedType(annot, tp) =>
87+
varianceInAnnot(annot)(tparam) & varianceInType(tp)(tparam)
88+
case tp: AndOrType =>
89+
varianceInType(tp.tp1)(tparam) & varianceInType(tp.tp2)(tparam)
90+
case _ =>
91+
Bivariant
92+
}
93+
}

0 commit comments

Comments
 (0)