diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index 0f61fd2e25fe..91ff89130d4e 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -293,10 +293,10 @@ trait ConstraintHandling[AbstractContext] { /** Widen inferred type `inst` with upper `bound`, according to the following rules: * 1. If `inst` is a singleton type, or a union containing some singleton types, - * widen (all) the singleton type(s), provied the result is a subtype of `bound` + * widen (all) the singleton type(s), provided the result is a subtype of `bound` * (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint) * 2. If `inst` is a union type, approximate the union type from above by an intersection - * of all common base types, provied the result is a subtype of `bound`. + * of all common base types, provided the result is a subtype of `bound`. * * Don't do these widenings if `bound` is a subtype of `scala.Singleton`. * Also, if the result of these widenings is a TypeRef to a module class, @@ -309,15 +309,17 @@ trait ConstraintHandling[AbstractContext] { def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type = { def widenOr(tp: Type) = { val tpw = tp.widenUnion - if ((tpw ne tp) && tpw <:< bound) tpw else tp + if (tpw ne tp) && (tpw <:< bound) then tpw else tp } def widenSingle(tp: Type) = { val tpw = tp.widenSingletons - if ((tpw ne tp) && tpw <:< bound) tpw else tp + if (tpw ne tp) && (tpw <:< bound) then tpw else tp } + def isSingleton(tp: Type): Boolean = tp match + case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi) + case _ => isSubTypeWhenFrozen(tp, defn.SingletonType) val wideInst = - if (isSubTypeWhenFrozen(bound, defn.SingletonType)) inst - else widenOr(widenSingle(inst)) + if isSingleton(bound) then inst else widenOr(widenSingle(inst)) wideInst match case wideInst: TypeRef if wideInst.symbol.is(Module) => TermRef(wideInst.prefix, wideInst.symbol.sourceModule) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 2faf721a9da8..ad370dfec15a 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -116,7 +116,7 @@ final class ProperGadtConstraint private( ) val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) => - val tv = new TypeVar(paramRef, creatorState = null) + val tv = TypeVar(paramRef, creatorState = null) mapping = mapping.updated(sym, tv) reverseMapping = reverseMapping.updated(tv.origin, sym) tv diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index db639d961b23..d6d0b18333d7 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -1450,6 +1450,14 @@ object SymDenotations { else if is(Contravariant) then Contravariant else EmptyFlags + /** The length of the owner chain of this symbol. 1 for _root_, 0 for NoSymbol */ + def nestingLevel(using Context): Int = + @tailrec def recur(d: SymDenotation, n: Int): Int = d match + case NoDenotation => n + case d: ClassDenotation => d.nestingLevel + n // profit from the cache in ClassDenotation + case _ => recur(d.owner, n + 1) + recur(this, 0) + /** The flags to be used for a type parameter owned by this symbol. * Overridden by ClassDenotation. */ @@ -2151,6 +2159,12 @@ object SymDenotations { override def registeredCompanion(implicit ctx: Context) = { ensureCompleted(); myCompanion } override def registeredCompanion_=(c: Symbol) = { myCompanion = c } + + private var myNestingLevel = -1 + + override def nestingLevel(using Context) = + if myNestingLevel == -1 then myNestingLevel = owner.nestingLevel + 1 + myNestingLevel } /** The denotation of a package class. diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index b3629dc2386b..6cb948f12d1b 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4110,18 +4110,17 @@ object Types { * * @param origin The parameter that's tracked by the type variable. * @param creatorState The typer state in which the variable was created. - * - * `owningTree` and `owner` are used to determine whether a type-variable can be instantiated - * at some given point. See `Inferencing#interpolateUndetVars`. */ - final class TypeVar(private var _origin: TypeParamRef, creatorState: TyperState) extends CachedProxyType with ValueType { + final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState, nestingLevel: Int) extends CachedProxyType with ValueType { + + private var currentOrigin = initOrigin - def origin: TypeParamRef = _origin + def origin: TypeParamRef = currentOrigin /** Set origin to new parameter. Called if we merge two conflicting constraints. * See OrderingConstraint#merge, OrderingConstraint#rename */ - def setOrigin(p: TypeParamRef) = _origin = p + def setOrigin(p: TypeParamRef) = currentOrigin = p /** The permanent instance type of the variable, or NoType is none is given yet */ private var myInst: Type = NoType @@ -4150,6 +4149,36 @@ object Types { /** Is the variable already instantiated? */ def isInstantiated(implicit ctx: Context): Boolean = instanceOpt.exists + /** Avoid term references in `tp` to parameters or local variables that + * are nested more deeply than the type variable itself. + */ + private def avoidCaptures(tp: Type)(using Context): Type = + val problemSyms = new TypeAccumulator[Set[Symbol]]: + def apply(syms: Set[Symbol], t: Type): Set[Symbol] = t match + case ref @ TermRef(NoPrefix, _) + // AVOIDANCE TODO: Are there other problematic kinds of references? + // Our current tests only give us these, but we might need to generalize this. + if ref.symbol.maybeOwner.nestingLevel > nestingLevel => + syms + ref.symbol + case _ => + foldOver(syms, t) + val problems = problemSyms(Set.empty, tp) + if problems.isEmpty then tp + else + val atp = ctx.typer.avoid(tp, problems.toList) + def msg = i"Inaccessible variables captured in instantation of type variable $this.\n$tp was fixed to $atp" + typr.println(msg) + val bound = ctx.typeComparer.fullUpperBound(origin) + if !(atp <:< bound) then + throw new TypeError(s"$msg,\nbut the latter type does not conform to the upper bound $bound") + atp + // AVOIDANCE TODO: This really works well only if variables are instantiated from below + // If we hit a problematic symbol while instantiating from above, then avoidance + // will widen the instance type further. This could yield an alias, which would be OK. + // But it also could yield a true super type which would then fail the bounds check + // and throw a TypeError. The right thing to do instead would be to avoid "downwards". + // To do this, we need first test cases for that situation. + /** Instantiate variable with given type */ def instantiateWith(tp: Type)(implicit ctx: Context): Type = { assert(tp ne this, s"self instantiation of ${tp.show}, constraint = ${ctx.typerState.constraint.show}") @@ -4168,7 +4197,7 @@ object Types { * is also a singleton type. */ def instantiate(fromBelow: Boolean)(implicit ctx: Context): Type = - instantiateWith(ctx.typeComparer.instanceType(origin, fromBelow)) + instantiateWith(avoidCaptures(ctx.typeComparer.instanceType(origin, fromBelow))) /** For uninstantiated type variables: Is the lower bound different from Nothing? */ def hasLowerBound(implicit ctx: Context): Boolean = @@ -4200,6 +4229,9 @@ object Types { s"TypeVar($origin$instStr)" } } + object TypeVar: + def apply(initOrigin: TypeParamRef, creatorState: TyperState)(using Context) = + new TypeVar(initOrigin, creatorState, ctx.owner.nestingLevel) type TypeVars = SimpleIdentitySet[TypeVar] diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 9c6e83af0522..95c66a386fd5 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -196,11 +196,11 @@ class Namer { typer: Typer => import untpd._ - val TypedAhead: Property.Key[tpd.Tree] = new Property.Key - val ExpandedTree: Property.Key[untpd.Tree] = new Property.Key + val TypedAhead : Property.Key[tpd.Tree] = new Property.Key + val ExpandedTree : Property.Key[untpd.Tree] = new Property.Key val ExportForwarders: Property.Key[List[tpd.MemberDef]] = new Property.Key - val SymOfTree: Property.Key[Symbol] = new Property.Key - val Deriver: Property.Key[typer.Deriver] = new Property.Key + val SymOfTree : Property.Key[Symbol] = new Property.Key + val Deriver : Property.Key[typer.Deriver] = new Property.Key /** A partial map from unexpanded member and pattern defs and to their expansions. * Populated during enterSyms, emptied during typer. @@ -1439,13 +1439,10 @@ class Namer { typer: Typer => // instead of widening to the underlying module class types. // We also drop the @Repeated annotation here to avoid leaking it in method result types // (see run/inferred-repeated-result). - def widenRhs(tp: Type): Type = { - val tp1 = tp.widenTermRefExpr.simplified match + def widenRhs(tp: Type): Type = + tp.widenTermRefExpr.simplified match case ctp: ConstantType if isInlineVal => ctp - case ref: TypeRef if ref.symbol.is(ModuleClass) => tp - case tp => tp.widenUnion - tp1.dropRepeatedAnnot - } + case tp => ctx.typeComparer.widenInferred(tp, rhsProto) // Replace aliases to Unit by Unit itself. If we leave the alias in // it would be erased to BoxedUnit. @@ -1497,9 +1494,21 @@ class Namer { typer: Typer => if (isFullyDefined(tpe, ForceDegree.none)) tpe else typedAheadExpr(mdef.rhs, tpe).tpe case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) => - val rhsType = typedAheadExpr(mdef.rhs, tpt.tpe).tpe mdef match { case mdef: DefDef if mdef.name == nme.ANON_FUN => + // This case applies if the closure result type contains uninstantiated + // type variables. In this case, constrain the closure result from below + // by the parameter-capture-avoiding type of the body. + val rhsType = typedAheadExpr(mdef.rhs, tpt.tpe).tpe + + // The following part is important since otherwise we might instantiate + // the closure result type with a plain functon type that refers + // to local parameters. An example where this happens in `dependent-closures.scala` + // If the code after `val rhsType` is commented out, this file fails pickling tests. + // AVOIDANCE TODO: Follow up why this happens, and whether there + // are better ways to achieve this. It would be good if we could get rid of this code. + // It seems at least partially redundant with the nesting level checking on TypeVar + // instantiation. val hygienicType = avoid(rhsType, paramss.flatten) if (!hygienicType.isValueType || !(hygienicType <:< tpt.tpe)) ctx.error(i"return type ${tpt.tpe} of lambda cannot be made hygienic;\n" + @@ -1512,10 +1521,10 @@ class Namer { typer: Typer => case _ => WildcardType } - val memTpe = paramFn(checkSimpleKinded(typedAheadType(mdef.tpt, tptProto)).tpe) + val mbrTpe = paramFn(checkSimpleKinded(typedAheadType(mdef.tpt, tptProto)).tpe) if (ctx.explicitNulls && mdef.mods.is(JavaDefined)) - JavaNullInterop.nullifyMember(sym, memTpe, mdef.mods.isAllOf(JavaEnumValue)) - else memTpe + JavaNullInterop.nullifyMember(sym, mbrTpe, mdef.mods.isAllOf(JavaEnumValue)) + else mbrTpe } /** The type signature of a DefDef with given symbol */ diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index b62e1bc7f0e7..7c85d4758a6a 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -501,8 +501,8 @@ object ProtoTypes { def newTypeVars(tl: TypeLambda): List[TypeTree] = for (paramRef <- tl.paramRefs) yield { - val tt = new TypeVarBinder().withSpan(owningTree.span) - val tvar = new TypeVar(paramRef, state) + val tt = TypeVarBinder().withSpan(owningTree.span) + val tvar = TypeVar(paramRef, state) state.ownedVars += tvar tt.withType(tvar) } diff --git a/tests/neg/i8861.scala b/tests/neg/i8861.scala new file mode 100644 index 000000000000..87f1884f6155 --- /dev/null +++ b/tests/neg/i8861.scala @@ -0,0 +1,31 @@ +object Test { + sealed trait Container { s => + type A + def visit[R](int: IntV & s.type => R, str: StrV & s.type => R): R + } + final class IntV extends Container { s => + type A = Int + val i: Int = 42 + def visit[R](int: IntV & s.type => R, str: StrV & s.type => R): R = int(this) + } + final class StrV extends Container { s => + type A = String + val t: String = "hello" + def visit[R](int: IntV & s.type => R, str: StrV & s.type => R): R = str(this) + } + + def minimalOk[R](c: Container { type A = R }): R = c.visit[R]( + int = vi => vi.i : vi.A, + str = vs => vs.t : vs.A + ) + def minimalFail[M](c: Container { type A = M }): M = c.visit( + int = vi => vi.i : vi.A, + str = vs => vs.t : vs.A // error + ) + + def main(args: Array[String]): Unit = { + val e: Container { type A = String } = new StrV + println(minimalOk(e)) // this one prints "hello" + println(minimalFail(e)) // this one fails with ClassCastException: class java.lang.String cannot be cast to class java.lang.Integer + } +} \ No newline at end of file diff --git a/tests/pos/dependent-closures.scala b/tests/pos/dependent-closures.scala new file mode 100644 index 000000000000..687f1d5f0adc --- /dev/null +++ b/tests/pos/dependent-closures.scala @@ -0,0 +1,27 @@ +trait S { type N; def n: N } + +def newS[X](n: X): S { type N = X } = ??? + +def test = + val ss: List[S] = ??? + val cl1 = (s: S) => newS(s.n) + val cl2: (s: S) => S { type N = s.N } = cl1 + def f[R](cl: (s: S) => R) = cl + val x = f(s => newS(s.n)) + val x1: (s: S) => S = x + // If the code in `tptProto` of Namer that refers to this + // file is commented out, we see: + // pickling difference for the result type of the closure argument + // before pickling: S => S { type N = s.N } + // after pickling : (s: S) => S { type N = s.N } + + ss.map(s => newS(s.n)) + // If the code in `tptProto` of Namer that refers to this + // file is commented out, we see a pickling difference like the one above. + + def g[R](cl: (s: S) => (S { type N = s.N }, R)) = ??? + g(s => (newS(s.n), identity(1))) + + def h(cl: (s: S) => S { type N = s.N }) = ??? + h(s => newS(s.n)) +