Skip to content

Commit ff1441c

Browse files
committed
Ban classes that incompatibly refine type params
In upickle there was a misuse of Any in a contravariant position.
1 parent 49337a0 commit ff1441c

17 files changed

+292
-119
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,9 @@ class Definitions {
17651765
Set[Symbol](ComparableClass, ProductClass, SerializableClass,
17661766
// add these for now, until we had a chance to retrofit 2.13 stdlib
17671767
// we should do a more through sweep through it then.
1768+
requiredClass("scala.collection.IterableFactoryDefaults"),
17681769
requiredClass("scala.collection.SortedOps"),
1770+
requiredClass("scala.collection.StrictOptimizedSetOps"),
17691771
requiredClass("scala.collection.StrictOptimizedSortedSetOps"),
17701772
requiredClass("scala.collection.generic.DefaultSerializable"),
17711773
requiredClass("scala.collection.generic.IsIterable"),

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
7373
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
7474
* in which case the subtyping relationship "heals" the type.
7575
*/
76-
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
76+
def constrainPatternType(pat: Type, scrut: Type): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
7777

7878
def classesMayBeCompatible: Boolean = {
7979
import Flags._
@@ -98,7 +98,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
9898
val scrCls = scrut.classSymbol
9999
patCls.exists && scrCls.exists
100100
&& (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls))
101-
&& constrainSimplePatternType(pat, scrut, forceInvariantRefinement)
101+
&& constrainSimplePatternType(pat, scrut)
102102
}
103103

104104
def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) {
@@ -209,8 +209,9 @@ trait PatternTypeConstrainer { self: TypeComparer =>
209209
* are used to infer type arguments to Unapply trees.
210210
*
211211
* ## Invariant refinement
212-
* Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if
213-
* when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`. This is violated if `A` is variant:
212+
* Essentially, we say that `D[B] extends C[B]` refines parameter `A` of `trait C[A]` invariantly if
213+
* when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`.
214+
* This is violated if `A` is variant and `C` is mixed in with an incompatible type argument:
214215
*
215216
* trait C[+A]
216217
* trait D[+B](val b: B) extends C[B]
@@ -224,42 +225,31 @@ trait PatternTypeConstrainer { self: TypeComparer =>
224225
* }
225226
*
226227
* It'd be unsound for us to say that `t <: T`, even though that follows from `D[t] <: C[T]`.
227-
* Note, however, that if `D` was a final class, we *could* rely on that relationship.
228-
* To support typical case classes, we also assume that this relationship holds for them and their parent traits.
229-
* This is enforced by checking that classes inheriting from case classes do not extend the parent traits of those
230-
* case classes without also appropriately extending the relevant case class
231-
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
228+
* Note, however, that if `D` was a concrete class, we can rely on that relationship.
229+
* We can assume this relationship holds for them and their parent traits
230+
* by checking that classes inheriting from those classes do not mix-in any parent traits
231+
* with a type parameter that isn't the same type, a subtype, or a super type, depending on if the
232+
* trait's parameter is invariant, covariant or contravariant, respectively
233+
* (see `RefChecks#checkClassInheritanceInvariant`).
232234
*/
233-
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = {
235+
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type): Boolean = {
234236
def refinementIsInvariant(tp: Type): Boolean = tp match {
235237
case tp: SingletonType => true
236-
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
238+
case tp: ClassInfo => !tp.cls.isOneOf(AbstractOrTrait) || tp.cls.isOneOf(Private | Sealed)
237239
case tp: TypeProxy => refinementIsInvariant(tp.superType)
238240
case _ => false
239241
}
240242

241-
def widenVariantParams(tp: Type) = tp match {
242-
case tp @ AppliedType(tycon, args) =>
243-
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
244-
if (tparam.paramVarianceSign != 0) TypeBounds.empty else arg
245-
)
246-
tp.derivedAppliedType(tycon, args1)
247-
case tp =>
248-
tp
249-
}
250-
251243
val patternCls = patternTp.classSymbol
252244
val scrutineeCls = scrutineeTp.classSymbol
253245

254246
// NOTE: we already know that there is a derives-from relationship in either direction
255-
val upcastPattern =
256-
patternCls.derivesFrom(scrutineeCls)
247+
val upcastPattern = patternCls.derivesFrom(scrutineeCls)
257248

258249
val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
259250
val tp = if !upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
260251

261-
val assumeInvariantRefinement =
262-
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
252+
val assumeInvariantRefinement = migrateTo3 || refinementIsInvariant(patternTp)
263253

264254
trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
265255
(tp, pt) match {

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2074,9 +2074,9 @@ object SymDenotations {
20742074
required: FlagSet = EmptyFlags, excluded: FlagSet = EmptyFlags)(using Context): Denotation =
20752075
membersNamedNoShadowingBasedOnFlags(name, required, excluded).asSeenFrom(pre).toDenot(pre)
20762076

2077-
/** Compute tp.baseType(this) */
2078-
final def baseTypeOf(tp: Type)(using Context): Type = {
2079-
val btrCache = baseTypeCache
2077+
/** Compute tp.baseType(this) or tp.baseType(this, without) */
2078+
final def baseTypeOf(tp: Type, without: Option[Symbol] = None)(using Context): Type = {
2079+
val btrCache = if without.isEmpty then baseTypeCache else new BaseTypeMap()
20802080
def inCache(tp: Type) = tp match
20812081
case tp: CachedType => btrCache.contains(tp)
20822082
case _ => false
@@ -2130,6 +2130,8 @@ object SymDenotations {
21302130
val baseTp =
21312131
if (tpSym eq symbol)
21322132
tp
2133+
else if without.exists(tpSym eq _) then
2134+
defn.AnyType
21332135
else if (isOwnThis)
21342136
if (clsd.baseClassSet.contains(symbol))
21352137
if (symbol.isStatic && symbol.typeParams.isEmpty) symbol.typeRef
@@ -2156,6 +2158,7 @@ object SymDenotations {
21562158
btrCache(tp) = NoPrefix
21572159
val baseTp =
21582160
if (tycon.typeSymbol eq symbol) tp
2161+
else if without.exists(tycon.typeSymbol eq _) then defn.AnyType
21592162
else (tycon.typeParams: @unchecked) match {
21602163
case LambdaParam(_, _) :: _ =>
21612164
recur(tp.superType)

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,7 +1735,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
17351735
*
17361736
* then the necessary constraint is { A = Int }, but correctly inferring that is, as far as we know, too expensive.
17371737
*
1738-
* This method is also used in ConstrainResult mode
1738+
* This method is also used with useNecessaryEither
17391739
* to avoid inference getting stuck due to lack of backtracking,
17401740
* see or-inf.scala and and-inf.scala for examples.
17411741
*
@@ -2889,8 +2889,8 @@ object TypeComparer {
28892889
def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
28902890
comparing(_.dropTransparentTraits(tp, bound))
28912891

2892-
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean =
2893-
comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement))
2892+
def constrainPatternType(pat: Type, scrut: Type)(using Context): Boolean =
2893+
comparing(_.constrainPatternType(pat, scrut))
28942894

28952895
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:")(using Context): String =
28962896
comparing(_.explained(op, header))

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,12 @@ object Types {
11291129
}
11301130
}
11311131

1132+
/** `basetype`, but ignoring any base classes that have the given `without` class symbol. */
1133+
final def baseTypeWithout(base: Symbol, without: Symbol)(using Context): Type =
1134+
base.denot match
1135+
case classd: ClassDenotation => classd.baseTypeOf(this, Some(without))
1136+
case _ => NoType
1137+
11321138
def & (that: Type)(using Context): Type = {
11331139
record("&")
11341140
TypeComparer.glb(this, that)

compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,21 +88,9 @@ object TypeTestsCasts {
8888
// if P1 <: P, which means the type arguments in P are trivial,
8989
// thus no runtime checks are needed for them.
9090
withMode(Mode.GadtConstraintInference) {
91-
// Why not widen type arguments here? Given the following program
92-
//
93-
// trait Tree[-T] class Ident[-T] extends Tree[T]
94-
//
95-
// def foo1(tree: Tree[Int]) = tree.isInstanceOf[Ident[Int]]
96-
//
97-
// In checking whether the test tree.isInstanceOf[Ident[Int]]
98-
// is realizable, we want to constrain Ident[X] <: Tree[Int],
99-
// such that we can infer X = Int and Ident[X] <:< Ident[Int].
100-
//
101-
// If we perform widening, we will get X = Nothing, and we don't have
102-
// Ident[X] <:< Ident[Int] any more.
103-
TypeComparer.constrainPatternType(P1, X, forceInvariantRefinement = true)
91+
TypeComparer.constrainPatternType(P1, X)
10492
debug.println(
105-
TypeComparer.explained(_.constrainPatternType(P1, X, forceInvariantRefinement = true))
93+
TypeComparer.explained(_.constrainPatternType(P1, X))
10694
)
10795
}
10896

compiler/src/dotty/tools/dotc/typer/RefChecks.scala

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -770,16 +770,18 @@ object RefChecks {
770770
}
771771
}
772772

773-
/** Check that inheriting a case class does not constitute a variant refinement
774-
* of a base type of the case class. It is because of this restriction that we
775-
* can assume invariant refinement for case classes in `constrainPatternType`.
773+
/** Check that inheriting a class does not constitute a variant refinement
774+
* of a base type of the class. It is because of this restriction that we
775+
* can assume invariant refinement for concrete classes in `constrainPatternType`.
776776
*/
777-
def checkCaseClassInheritanceInvariant() =
778-
for (caseCls <- clazz.info.baseClasses.tail.find(_.is(Case)))
779-
for (baseCls <- caseCls.info.baseClasses.tail)
777+
def checkClassInheritanceInvariant() =
778+
for (middle <- clazz.info.baseClasses.tail.filter(!_.isTransparentTrait))
779+
for (baseCls <- middle.info.baseClasses.tail)
780780
if (baseCls.typeParams.exists(_.paramVarianceSign != 0))
781-
for (problem <- variantInheritanceProblems(baseCls, caseCls, "non-variant", "case "))
781+
val middleStr = if middle.is(Case) then "case " else ""
782+
for (problem <- variantInheritanceProblems(baseCls, middle, "variant", middleStr))
782783
report.errorOrMigrationWarning(problem(), clazz.srcPos, from = `3.0`)
784+
783785
checkNoAbstractMembers()
784786
if (abstractErrors.isEmpty)
785787
checkNoAbstractDecls(clazz)
@@ -788,7 +790,7 @@ object RefChecks {
788790
report.error(abstractErrorMessage, clazz.srcPos)
789791

790792
checkMemberTypesOK()
791-
checkCaseClassInheritanceInvariant()
793+
checkClassInheritanceInvariant()
792794
}
793795

794796
if (!clazz.is(Trait)) {
@@ -825,16 +827,29 @@ object RefChecks {
825827
*/
826828
def variantInheritanceProblems(
827829
baseCls: Symbol, middle: Symbol, baseStr: String, middleStr: String): Option[() => String] = {
830+
if baseCls == middle then return None
828831
val superBT = self.baseType(middle)
829-
val thisBT = self.baseType(baseCls)
830832
val combinedBT = superBT.baseType(baseCls)
831-
if (combinedBT =:= thisBT) None // ok
833+
val withoutMiddleBT = self.baseTypeWithout(baseCls, middle)
834+
val allOk = (combinedBT, withoutMiddleBT) match
835+
case (AppliedType(tycon, args1), AppliedType(_, args2)) =>
836+
val superBTArgs = superBT.argInfos.toSet
837+
tycon.typeParams.lazyZip(args1).lazyZip(args2).forall { (param, arg1, arg2) =>
838+
if superBTArgs.contains(arg1) then
839+
val variance = param.paramVarianceSign
840+
(variance > 0 || (arg2 <:< arg1)) &&
841+
(variance < 0 || (arg1 <:< arg2))
842+
else true // e.g. CovBoth in neg/i11834
843+
}
844+
case _ => combinedBT =:= self.baseType(baseCls)
845+
if allOk then None // ok
832846
else
833847
Some(() =>
834848
em"""illegal inheritance: $clazz inherits conflicting instances of $baseStr base $baseCls.
835849
|
836-
| Direct basetype: $thisBT
837-
| Basetype via $middleStr$middle: $combinedBT""")
850+
| Basetype via $middleStr$middle: $combinedBT
851+
| Basetype without $middleStr$middle: $withoutMiddleBT""")
852+
//Some(() => em"""$clazz: $withoutMiddleBT vs $combinedBT via $superBT""")
838853
}
839854

840855
/* Returns whether there is a symbol declared in class `inclazz`

tests/neg/3324g.scala

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
11
// scalac: -Werror
22
class Test {
3-
trait A[+T]
4-
class B[T] extends A[T]
5-
class C[T] extends B[Any] with A[T]
6-
7-
def foo[T](c: C[T]): Unit = c match {
8-
case _: B[T] => // error
9-
}
3+
trait A[+X]
4+
class B[Y] extends A[Y]
5+
class C[Z] extends B[Any] with A[Z] // error
106

117
def bar[T](b: B[T]): Unit = b match {
128
case _: A[T] =>
139
}
1410

1511
def quux[T](a: A[T]): Unit = a match {
16-
case _: B[T] => // error!!
12+
case _: B[T] => // is-error!!
13+
// superseded by refcheck error above
14+
// test covered by neg/gadt
1715
}
18-
19-
quux(new C[Int])
2016
}

tests/neg/JavaSeqLiteral.scala

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,18 @@
11
// scalac: -Werror
2-
object Test1 {
3-
trait Tree[-T]
2+
trait Tree[-T]
3+
trait Type
44

5+
object Test1 {
56
class JavaSeqLiteral[T] extends Tree[T]
6-
7-
trait Type
8-
97
class DummyTree extends JavaSeqLiteral[Any]
10-
118
def foo1(tree: Tree[Type]) =
12-
tree.isInstanceOf[JavaSeqLiteral[Type]] // error
13-
9+
tree.isInstanceOf[JavaSeqLiteral[Type]] // error: unchecked
1410
foo1(new DummyTree)
1511
}
1612

1713
object Test2 {
18-
trait Tree[-T]
19-
2014
class JavaSeqLiteral[-T] extends Tree[T]
21-
22-
trait Type
23-
2415
class DummyTree extends JavaSeqLiteral[Any]
25-
26-
def foo1(tree: Tree[Type]) =
27-
tree.isInstanceOf[JavaSeqLiteral[Type]]
28-
16+
def foo1(tree: Tree[Type]) = tree.isInstanceOf[JavaSeqLiteral[Type]]
2917
foo1(new DummyTree)
3018
}

tests/neg/gadt.scala

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
// scalac: -Werror
2-
class Test {
3-
trait A[+T]
4-
class B[T] extends A[T]
5-
6-
class C
7-
class D extends C
2+
trait A[+X]
3+
class B[Y](val put: Y => Int) extends A[Y]
84

9-
def quux(a: A[C]): Unit = a match {
10-
case _: B[C] => // error!!
11-
}
5+
class C
6+
class D extends C { def int = 4 }
127

13-
quux(new B[D])
8+
class Test {
9+
def quux(a: A[C], c: C) = a match
10+
case b: B[C] => // error: unchecked
11+
b.put(c)
1412
}
13+
14+
object Test extends Test:
15+
def main(args: Array[String]): Unit =
16+
val a: A[C] = new B[D]((d: D) => d.int)
17+
val c: C = new C
18+
quux(a, c)

tests/neg/i11018.scala

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,19 @@ trait CTrait[+A](val a: A) {
1313
trait DTrait[+B] extends CTrait[B]
1414
trait DClass[+B] extends CClass[B]
1515

16-
final class F1 extends DTrait[Foo] with CTrait[Bar](new Bar) // error: illegal parameter
17-
final class F2 extends CTrait[Bar](new Bar) with DTrait[Foo] // error: illegal parameter
18-
final class F3 extends DClass[Foo] with CClass[Bar](new Bar) // error: illegal parameter
19-
final class F4 extends CClass[Bar](new Bar) with DClass[Foo] // error: illegal parameter
16+
final class F1 // error: illegal inheritance
17+
extends DTrait[Foo]
18+
with CTrait[Bar](new Bar) // error: illegal parameter
19+
final class F2 // error: illegal inheritance
20+
extends CTrait[Bar](new Bar) // error: illegal parameter
21+
with DTrait[Foo]
22+
final class F3 // error: illegal inheritance
23+
extends DClass[Foo]
24+
with CClass[Bar](new Bar) // error: illegal parameter
25+
final class F4 // error: illegal inheritance
26+
extends CClass[Bar](new Bar) // error: illegal parameter
27+
with DClass[Foo]
2028

21-
final class F5 extends DTrait[Foo] with CTrait[Foo & Bar](new Bar with Foo { def name = "hello"}) // ok
29+
final class F5 // error: illegal inheritance
30+
extends DTrait[Foo]
31+
with CTrait[Foo & Bar](new Bar with Foo { def name = "hello"})

0 commit comments

Comments
 (0)