diff --git a/src/dotty/tools/dotc/typer/TypeAssigner.scala b/src/dotty/tools/dotc/typer/TypeAssigner.scala index ba8d44110b10..7225ede143ae 100644 --- a/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -31,35 +31,40 @@ trait TypeAssigner { /** An upper approximation of the given type `tp` that does not refer to any symbol in `symsToAvoid`. * Approximation steps are: * - * - follow aliases if the original refers to a forbidden symbol + * - follow aliases and upper bounds if the original refers to a forbidden symbol * - widen termrefs that refer to a forbidden symbol * - replace ClassInfos of forbidden classes by the intersection of their parents, refined by all * non-private fields, methods, and type members. + * - if the prefix of a class refers to a forbidden symbol, first try to replace the prefix, + * if this is not possible, replace the ClassInfo as above. * - drop refinements referring to a forbidden symbol. */ def avoid(tp: Type, symsToAvoid: => List[Symbol])(implicit ctx: Context): Type = { val widenMap = new TypeMap { lazy val forbidden = symsToAvoid.toSet - def toAvoid(tp: Type): Boolean = tp match { - case tp: TermRef => - val sym = tp.symbol - sym.exists && ( - sym.owner.isTerm && (forbidden contains sym) - || !(sym.owner is Package) && toAvoid(tp.prefix) - ) - case tp: TypeRef => - forbidden contains tp.symbol - case _ => - false - } - def apply(tp: Type) = tp match { + def toAvoid(tp: Type): Boolean = + // TODO: measure the cost of using `existsPart`, and if necessary replace it + // by a `TypeAccumulator` where we have set `stopAtStatic = true`. + tp existsPart { + case tp: NamedType => + forbidden contains tp.symbol + case _ => + false + } + def apply(tp: Type): Type = tp match { case tp: TermRef if toAvoid(tp) && variance > 0 => apply(tp.info.widenExpr) - case tp: TypeRef if (forbidden contains tp.symbol) || toAvoid(tp.prefix) => + case tp: TypeRef if toAvoid(tp) => tp.info match { case TypeAlias(ref) => apply(ref) case info: ClassInfo if variance > 0 => + if (!(forbidden contains tp.symbol)) { + val prefix = apply(tp.prefix) + val tp1 = tp.derivedSelect(prefix) + if (tp1.typeSymbol.exists) + return tp1 + } val parentType = info.instantiatedParents.reduceLeft(ctx.typeComparer.andType(_, _)) def addRefinement(parent: Type, decl: Symbol) = { val inherited = @@ -82,13 +87,28 @@ trait TypeAssigner { case _ => mapOver(tp) } - case tp: RefinedType => - val tp1 @ RefinedType(parent1, _) = mapOver(tp) - if (tp1.refinedInfo.existsPart(toAvoid) && variance > 0) { - typr.println(s"dropping refinement from $tp1") - parent1 + case tp @ RefinedType(parent, name) if variance > 0 => + // The naive approach here would be to first approximate the parent, + // but if the base type of the approximated parent is different from + // the current base type, then the current refinement won't be valid + // if it's a type parameter refinement. + // Therefore we first approximate the base type, then use `baseArgInfos` + // to get correct refinements for the approximated base type, then + // recursively approximate the resulting type. + val base = tp.unrefine + if (toAvoid(base)) { + val base1 = apply(base) + apply(base1.appliedTo(tp.baseArgInfos(base1.typeSymbol))) + } else { + val parent1 = apply(tp.parent) + val refinedInfo1 = apply(tp.refinedInfo) + if (toAvoid(refinedInfo1)) { + typr.println(s"dropping refinement from $tp") + parent1 + } else { + tp.derivedRefinedType(parent1, name, refinedInfo1) + } } - else tp1 case tp: TypeVar if ctx.typerState.constraint.contains(tp) => val lo = ctx.typerState.constraint.fullLowerBound(tp.origin) val lo1 = avoid(lo, symsToAvoid) diff --git a/tests/pos/escapingRefs.scala b/tests/pos/escapingRefs.scala new file mode 100644 index 000000000000..1b1deb8dec75 --- /dev/null +++ b/tests/pos/escapingRefs.scala @@ -0,0 +1,42 @@ +class Outer { + class Inner { + class Inner2 + } +} + +class HasA { type A } + +class Foo[A] + +object Test { + def test = { + val a: Outer#Inner = { + val o = new Outer + new o.Inner + } + + val b: Outer#Inner#Inner2 = { + val o = new Outer + val i = new o.Inner + new i.Inner2 + } + + val c: HasA { type A = Int } = { + val h = new HasA { + type A = Int + } + val x: HasA { type A = h.A } = h + x + } + + val d: Foo[Int] = { + class Bar[B] extends Foo[B] + new Bar[Int] + } + + val e: Foo[_] = { + class Bar[B] extends Foo[B] + new Bar[Int]: Bar[_ <: Int] + } + } +}