diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 25b67921fc44..5bb3b4ccbcc9 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -699,10 +699,12 @@ object Trees { s"TypeTree${if (hasType) s"[$typeOpt]" else ""}" } - /** A type tree that defines a new type variable. Its type is always a TypeVar. - * Every TypeVar is created as the type of one TypeVarBinder. + /** A type tree whose type is inferred. These trees appear in two contexts + * - as an argument of a TypeApply. In that case its type is always a TypeVar + * - as a (result-)type of an inferred ValDef or DefDef. + * Every TypeVar is created as the type of one InferredTypeTree. */ - class TypeVarBinder[-T >: Untyped](implicit @constructorOnly src: SourceFile) extends TypeTree[T] + class InferredTypeTree[-T >: Untyped](implicit @constructorOnly src: SourceFile) extends TypeTree[T] /** ref.type */ case class SingletonTypeTree[-T >: Untyped] private[ast] (ref: Tree[T])(implicit @constructorOnly src: SourceFile) @@ -1079,6 +1081,7 @@ object Trees { type JavaSeqLiteral = Trees.JavaSeqLiteral[T] type Inlined = Trees.Inlined[T] type TypeTree = Trees.TypeTree[T] + type InferredTypeTree = Trees.InferredTypeTree[T] type SingletonTypeTree = Trees.SingletonTypeTree[T] type RefinedTypeTree = Trees.RefinedTypeTree[T] type AppliedTypeTree = Trees.AppliedTypeTree[T] diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 6555e2b3b1fb..cb68717f36cb 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -981,11 +981,13 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { } /** cast tree to `tp`, assuming no exception is raised, i.e the operation is pure */ - def cast(tp: Type)(using Context): Tree = { - assert(tp.isValueType, i"bad cast: $tree.asInstanceOf[$tp]") + def cast(tp: Type)(using Context): Tree = cast(TypeTree(tp)) + + /** cast tree to `tp`, assuming no exception is raised, i.e the operation is pure */ + def cast(tpt: TypeTree)(using Context): Tree = + assert(tpt.tpe.isValueType, i"bad cast: $tree.asInstanceOf[$tpt]") tree.select(if (ctx.erasedTypes) defn.Any_asInstanceOf else defn.Any_typeCast) - .appliedToType(tp) - } + .appliedToTypeTree(tpt) /** cast `tree` to `tp` (or its box/unbox/cast equivalent when after * erasure and value and non-value types are mixed), diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index f357a4d2441d..edaa1a7a3dbb 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -159,14 +159,29 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { argStr ~ " " ~ arrow(isGiven) ~ " " ~ argText(args.last) } - def toTextDependentFunction(appType: MethodType): Text = - "(" - ~ keywordText("erased ").provided(appType.isErasedMethod) - ~ paramsText(appType) - ~ ") " - ~ arrow(appType.isImplicitMethod) - ~ " " - ~ toText(appType.resultType) + def toTextMethodAsFunction(info: Type): Text = info match + case info: MethodType => + changePrec(GlobalPrec) { + "(" + ~ keywordText("erased ").provided(info.isErasedMethod) + ~ ( if info.isParamDependent || info.isResultDependent + then paramsText(info) + else argsText(info.paramInfos) + ) + ~ ") " + ~ arrow(info.isImplicitMethod) + ~ " " + ~ toTextMethodAsFunction(info.resultType) + } + case info: PolyType => + changePrec(GlobalPrec) { + "[" + ~ paramsText(info) + ~ "] => " + ~ toTextMethodAsFunction(info.resultType) + } + case _ => + toText(info) def isInfixType(tp: Type): Boolean = tp match case AppliedType(tycon, args) => @@ -230,8 +245,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty => // don't eta contract if the application would be printed specially toText(tycon) - case tp: RefinedType if defn.isFunctionType(tp) && !printDebug => - toTextDependentFunction(tp.refinedInfo.asInstanceOf[MethodType]) + case tp: RefinedType + if (defn.isFunctionType(tp) || (tp.parent.typeSymbol eq defn.PolyFunctionClass)) + && !printDebug => + toTextMethodAsFunction(tp.refinedInfo) case tp: TypeRef => if (tp.symbol.isAnonymousClass && !showUniqueIds) toText(tp.info) @@ -245,6 +262,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { case ErasedValueType(tycon, underlying) => "ErasedValueType(" ~ toText(tycon) ~ ", " ~ toText(underlying) ~ ")" case tp: ClassInfo => + if tp.cls.derivesFrom(defn.PolyFunctionClass) then + tp.member(nme.apply).info match + case info: PolyType => return toTextMethodAsFunction(info) + case _ => toTextParents(tp.parents) ~~ "{...}" case JavaArrayType(elemtp) => toText(elemtp) ~ "[]" @@ -501,6 +522,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { "" case TypeTree() => typeText(toText(tree.typeOpt)) + ~ Str("(inf)").provided(tree.isInstanceOf[InferredTypeTree] && printDebug) case SingletonTypeTree(ref) => toTextLocal(ref) ~ "." ~ keywordStr("type") case RefinedTypeTree(tpt, refines) => @@ -510,6 +532,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { changePrec(OrTypePrec) { toText(args(0)) ~ " | " ~ atPrec(OrTypePrec + 1) { toText(args(1)) } } else if (tpt.symbol == defn.andType && args.length == 2) changePrec(AndTypePrec) { toText(args(0)) ~ " & " ~ atPrec(AndTypePrec + 1) { toText(args(1)) } } + else if defn.isFunctionClass(tpt.symbol) + && tpt.isInstanceOf[TypeTree] && tree.hasType && !printDebug + then changePrec(GlobalPrec) { toText(tree.typeOpt) } else args match case arg :: _ if arg.isTerm => toTextLocal(tpt) ~ "(" ~ Text(args.map(argText), ", ") ~ ")" diff --git a/compiler/src/dotty/tools/dotc/reporting/messages.scala b/compiler/src/dotty/tools/dotc/reporting/messages.scala index 8a7852e56f9a..fb7c6d41e885 100644 --- a/compiler/src/dotty/tools/dotc/reporting/messages.scala +++ b/compiler/src/dotty/tools/dotc/reporting/messages.scala @@ -149,7 +149,6 @@ import transform.SymUtils._ } class AnonymousFunctionMissingParamType(param: untpd.ValDef, - args: List[untpd.Tree], tree: untpd.Function, pt: Type) (using Context) @@ -157,7 +156,7 @@ import transform.SymUtils._ def msg = { val ofFun = if param.name.is(WildcardParamName) - || (MethodType.syntheticParamNames(args.length + 1) contains param.name) + || (MethodType.syntheticParamNames(tree.args.length + 1) contains param.name) then i" of expanded function:\n$tree" else "" diff --git a/compiler/src/dotty/tools/dotc/transform/Dependencies.scala b/compiler/src/dotty/tools/dotc/transform/Dependencies.scala index 0503dd71601c..c5c6c5baaa7b 100644 --- a/compiler/src/dotty/tools/dotc/transform/Dependencies.scala +++ b/compiler/src/dotty/tools/dotc/transform/Dependencies.scala @@ -194,20 +194,18 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co if isExpr(sym) && isLocal(sym) then markCalled(sym, enclosure) case tree: This => narrowTo(tree.symbol.asClass) - case tree: DefDef => - if sym.owner.isTerm then - logicOwner(sym) = sym.enclosingPackageClass - // this will make methods in supercall constructors of top-level classes owned - // by the enclosing package, which means they will be static. - // On the other hand, all other methods will be indirectly owned by their - // top-level class. This avoids possible deadlocks when a static method - // has to access its enclosing object from the outside. - else if sym.isConstructor then - if sym.isPrimaryConstructor && isLocal(sym.owner) && !sym.owner.is(Trait) then - // add a call edge from the constructor of a local non-trait class to - // the class itself. This is done so that the constructor inherits - // the free variables of the class. - symSet(called, sym) += sym.owner + case tree: MemberDef if isExpr(sym) && sym.owner.isTerm => + logicOwner(sym) = sym.enclosingPackageClass + // this will make methods in supercall constructors of top-level classes owned + // by the enclosing package, which means they will be static. + // On the other hand, all other methods will be indirectly owned by their + // top-level class. This avoids possible deadlocks when a static method + // has to access its enclosing object from the outside. + case tree: DefDef if sym.isPrimaryConstructor && isLocal(sym.owner) && !sym.owner.is(Trait) => + // add a call edge from the constructor of a local non-trait class to + // the class itself. This is done so that the constructor inherits + // the free variables of the class. + symSet(called, sym) += sym.owner case tree: TypeDef => if sym.owner.isTerm then logicOwner(sym) = sym.topLevelClass.owner case _ => diff --git a/compiler/src/dotty/tools/dotc/transform/SymUtils.scala b/compiler/src/dotty/tools/dotc/transform/SymUtils.scala index f8b276d52088..4df1d75e93fa 100644 --- a/compiler/src/dotty/tools/dotc/transform/SymUtils.scala +++ b/compiler/src/dotty/tools/dotc/transform/SymUtils.scala @@ -287,5 +287,50 @@ object SymUtils: self.addAnnotation( Annotation(defn.TargetNameAnnot, Literal(Constant(nameFn(original.targetName).toString)).withSpan(original.span))) + + /** The return type as seen from the body of this definition. It is + * computed from the symbol's type by replacing param refs by param symbols. + */ + def localReturnType(using Context): Type = + if self.isConstructor then defn.UnitType + else + def instantiateRT(info: Type, psymss: List[List[Symbol]]): Type = info match + case info: PolyType => + instantiateRT(info.instantiate(psymss.head.map(_.typeRef)), psymss.tail) + case info: MethodType => + instantiateRT(info.instantiate(psymss.head.map(_.termRef)), psymss.tail) + case info => + info.widenExpr + instantiateRT(self.info, self.paramSymss) + + /** The expected type of a return to `self` at the place indicated by the context. + * This is the local return type instantiated by the symbols of any context function + * closures that enclose the site of the return + */ + def returnProto(using Context): Type = + + /** If `pt` is a context function type, its return type. If the CFT + * is dependent, instantiate with the parameters of the associated + * anonymous function. + * @param paramss the parameters of the anonymous functions + * enclosing the return expression + */ + def instantiateCFT(pt: Type, paramss: => List[List[Symbol]]): Type = + val ift = defn.asContextFunctionType(pt) + if ift.exists then + ift.nonPrivateMember(nme.apply).info match + case appType: MethodType => + instantiateCFT(appType.instantiate(paramss.head.map(_.termRef)), paramss.tail) + else pt + + def iftParamss = ctx.owner.ownersIterator + .filter(_.is(Method, butNot = Accessor)) + .takeWhile(_.isAnonymousFunction) + .toList + .reverse + .map(_.paramSymss.head) + + instantiateCFT(self.localReturnType, iftParamss) + end returnProto end extension end SymUtils diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index aa2d071cafba..cc1012993bd6 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -333,7 +333,7 @@ object Inferencing { @tailrec def boundVars(tree: Tree, acc: List[TypeVar]): List[TypeVar] = tree match { case Apply(fn, _) => boundVars(fn, acc) case TypeApply(fn, targs) => - val tvars = targs.filter(_.isInstanceOf[TypeVarBinder[?]]).tpes.collect { + val tvars = targs.filter(_.isInstanceOf[InferredTypeTree]).tpes.collect { case tvar: TypeVar if !tvar.isInstantiated && ctx.typerState.ownedVars.contains(tvar) && diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 6a47cf143d62..230ffe752685 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1431,168 +1431,7 @@ class Namer { typer: Typer => */ def valOrDefDefSig(mdef: ValOrDefDef, sym: Symbol, paramss: List[List[Symbol]], paramFn: Type => Type)(using Context): Type = { - def inferredType = { - /** A type for this definition that might be inherited from elsewhere: - * If this is a setter parameter, the corresponding getter type. - * If this is a class member, the conjunction of all result types - * of overridden methods. - * NoType if neither case holds. - */ - val inherited = - if (sym.owner.isTerm) NoType - else - // TODO: Look only at member of supertype instead? - lazy val schema = paramFn(WildcardType) - val site = sym.owner.thisType - val bcs = sym.owner.info.baseClasses - if bcs.isEmpty then - assert(ctx.reporter.errorsReported) - NoType - else bcs.tail.foldLeft(NoType: Type) { (tp, cls) => - def instantiatedResType(info: Type, paramss: List[List[Symbol]]): Type = info match - case info: PolyType => - paramss match - case TypeSymbols(tparams) :: paramss1 if info.paramNames.length == tparams.length => - instantiatedResType(info.instantiate(tparams.map(_.typeRef)), paramss1) - case _ => - NoType - case info: MethodType => - paramss match - case TermSymbols(vparams) :: paramss1 if info.paramNames.length == vparams.length => - instantiatedResType(info.instantiate(vparams.map(_.termRef)), paramss1) - case _ => - NoType - case _ => - if paramss.isEmpty then info.widenExpr - else NoType - - val iRawInfo = - cls.info.nonPrivateDecl(sym.name).matchingDenotation(site, schema, sym.targetName).info - val iResType = instantiatedResType(iRawInfo, paramss).asSeenFrom(site, cls) - if (iResType.exists) - typr.println(i"using inherited type for ${mdef.name}; raw: $iRawInfo, inherited: $iResType") - tp & iResType - } - end inherited - - /** If this is a default getter, the type of the corresponding method parameter, - * otherwise NoType. - */ - def defaultParamType = sym.name match - case DefaultGetterName(original, idx) => - val meth: Denotation = - if (original.isConstructorName && (sym.owner.is(ModuleClass))) - sym.owner.companionClass.info.decl(nme.CONSTRUCTOR) - else - ctx.defContext(sym).denotNamed(original) - def paramProto(paramss: List[List[Type]], idx: Int): Type = paramss match { - case params :: paramss1 => - if (idx < params.length) params(idx) - else paramProto(paramss1, idx - params.length) - case nil => - NoType - } - val defaultAlts = meth.altsWith(_.hasDefaultParams) - if (defaultAlts.length == 1) - paramProto(defaultAlts.head.info.widen.paramInfoss, idx) - else - NoType - case _ => - NoType - - /** The expected type for a default argument. This is normally the `defaultParamType` - * with references to internal parameters replaced by wildcards. This replacement - * makes it possible that the default argument can have a more specific type than the - * parameter. For instance, we allow - * - * class C[A](a: A) { def copy[B](x: B = a): C[B] = C(x) } - * - * However, if the default parameter type is a context function type, we - * have to make sure that wildcard types do not leak into the implicitly - * generated closure's result type. Test case is pos/i12019.scala. If there - * would be a leakage with the wildcard approximation, we pick the original - * default parameter type as expected type. - */ - def expectedDefaultArgType = - val originalTp = defaultParamType - val approxTp = wildApprox(originalTp) - approxTp.stripPoly match - case atp @ defn.ContextFunctionType(_, resType, _) - if !defn.isNonRefinedFunction(atp) // in this case `resType` is lying, gives us only the non-dependent upper bound - || resType.existsPart(_.isInstanceOf[WildcardType], StopAt.Static, forceLazy = false) => - originalTp - case _ => - approxTp - - // println(s"final inherited for $sym: ${inherited.toString}") !!! - // println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}") - // TODO Scala 3.1: only check for inline vals (no final ones) - def isInlineVal = sym.isOneOf(FinalOrInline, butNot = Method | Mutable) - - var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType) - if sym.isInlineMethod then rhsCtx = rhsCtx.addMode(Mode.InlineableBody) - if sym.is(ExtensionMethod) then rhsCtx = rhsCtx.addMode(Mode.InExtensionMethod) - val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten - if (typeParams.nonEmpty) { - // we'll be typing an expression from a polymorphic definition's body, - // so we must allow constraining its type parameters - // compare with typedDefDef, see tests/pos/gadt-inference.scala - rhsCtx.setFreshGADTBounds - rhsCtx.gadt.addToConstraint(typeParams) - } - - def typedAheadRhs(pt: Type) = - PrepareInlineable.dropInlineIfError(sym, - typedAheadExpr(mdef.rhs, pt)(using rhsCtx)) - - def rhsType = - // For default getters, we use the corresponding parameter type as an - // expected type but we run it through `wildApprox` to allow default - // parameters like in `def mkList[T](value: T = 1): List[T]`. - val defaultTp = defaultParamType - val pt = inherited.orElse(expectedDefaultArgType).orElse(WildcardType).widenExpr - val tp = typedAheadRhs(pt).tpe - if (defaultTp eq pt) && (tp frozen_<:< defaultTp) then - // When possible, widen to the default getter parameter type to permit a - // larger choice of overrides (see `default-getter.scala`). - // For justification on the use of `@uncheckedVariance`, see - // `default-getter-variance.scala`. - AnnotatedType(defaultTp, Annotation(defn.UncheckedVarianceAnnot)) - else - // don't strip @uncheckedVariance annot for default getters - TypeOps.simplify(tp.widenTermRefExpr, - if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match - case ctp: ConstantType if isInlineVal => ctp - case tp => TypeComparer.widenInferred(tp, pt) - - // Replace aliases to Unit by Unit itself. If we leave the alias in - // it would be erased to BoxedUnit. - def dealiasIfUnit(tp: Type) = if (tp.isRef(defn.UnitClass)) defn.UnitType else tp - - // Approximate a type `tp` with a type that does not contain skolem types. - val deskolemize = new ApproximatingTypeMap { - def apply(tp: Type) = /*trace(i"deskolemize($tp) at $variance", show = true)*/ - tp match { - case tp: SkolemType => range(defn.NothingType, atVariance(1)(apply(tp.info))) - case _ => mapOver(tp) - } - } - - def cookedRhsType = deskolemize(dealiasIfUnit(rhsType)) - def lhsType = fullyDefinedType(cookedRhsType, "right-hand side", mdef.span) - //if (sym.name.toString == "y") println(i"rhs = $rhsType, cooked = $cookedRhsType") - if (inherited.exists) - if (isInlineVal) lhsType else inherited - else { - if (sym.is(Implicit)) - mdef match { - case _: DefDef => missingType(sym, "result ") - case _: ValDef if sym.owner.isType => missingType(sym, "") - case _ => - } - lhsType orElse WildcardType - } - } + def inferredType = inferredResultType(mdef, sym, paramss, paramFn, WildcardType) lazy val termParamss = paramss.collect { case TermSymbols(vparams) => vparams } val tptProto = mdef.tpt match { @@ -1673,15 +1512,184 @@ class Namer { typer: Typer => ddef.trailingParamss.foreach(completeParams) val paramSymss = normalizeIfConstructor(ddef.paramss.nestedMap(symbolOfTree), isConstructor) sym.setParamss(paramSymss) - def wrapMethType(restpe: Type): Type = { + def wrapMethType(restpe: Type): Type = instantiateDependent(restpe, paramSymss) - methodType(paramSymss, restpe, isJava = ddef.mods.is(JavaDefined)) - } - if (isConstructor) { + methodType(paramSymss, restpe, ddef.mods.is(JavaDefined)) + if isConstructor then // set result type tree to unit, but take the current class as result type of the symbol typedAheadType(ddef.tpt, defn.UnitType) wrapMethType(effectiveResultType(sym, paramSymss)) - } - else valOrDefDefSig(ddef, sym, paramSymss, wrapMethType) + else + valOrDefDefSig(ddef, sym, paramSymss, wrapMethType) } + + def inferredResultType( + mdef: ValOrDefDef, + sym: Symbol, + paramss: List[List[Symbol]], + paramFn: Type => Type, + fallbackProto: Type + )(using Context): Type = + + /** A type for this definition that might be inherited from elsewhere: + * If this is a setter parameter, the corresponding getter type. + * If this is a class member, the conjunction of all result types + * of overridden methods. + * NoType if neither case holds. + */ + val inherited = + if (sym.owner.isTerm) NoType + else + // TODO: Look only at member of supertype instead? + lazy val schema = paramFn(WildcardType) + val site = sym.owner.thisType + val bcs = sym.owner.info.baseClasses + if bcs.isEmpty then + assert(ctx.reporter.errorsReported) + NoType + else bcs.tail.foldLeft(NoType: Type) { (tp, cls) => + def instantiatedResType(info: Type, paramss: List[List[Symbol]]): Type = info match + case info: PolyType => + paramss match + case TypeSymbols(tparams) :: paramss1 if info.paramNames.length == tparams.length => + instantiatedResType(info.instantiate(tparams.map(_.typeRef)), paramss1) + case _ => + NoType + case info: MethodType => + paramss match + case TermSymbols(vparams) :: paramss1 if info.paramNames.length == vparams.length => + instantiatedResType(info.instantiate(vparams.map(_.termRef)), paramss1) + case _ => + NoType + case _ => + if paramss.isEmpty then info.widenExpr + else NoType + + val iRawInfo = + cls.info.nonPrivateDecl(sym.name).matchingDenotation(site, schema, sym.targetName).info + val iResType = instantiatedResType(iRawInfo, paramss).asSeenFrom(site, cls) + if (iResType.exists) + typr.println(i"using inherited type for ${mdef.name}; raw: $iRawInfo, inherited: $iResType") + tp & iResType + } + end inherited + + /** If this is a default getter, the type of the corresponding method parameter, + * otherwise NoType. + */ + def defaultParamType = sym.name match + case DefaultGetterName(original, idx) => + val meth: Denotation = + if (original.isConstructorName && (sym.owner.is(ModuleClass))) + sym.owner.companionClass.info.decl(nme.CONSTRUCTOR) + else + ctx.defContext(sym).denotNamed(original) + def paramProto(paramss: List[List[Type]], idx: Int): Type = paramss match { + case params :: paramss1 => + if (idx < params.length) params(idx) + else paramProto(paramss1, idx - params.length) + case nil => + NoType + } + val defaultAlts = meth.altsWith(_.hasDefaultParams) + if (defaultAlts.length == 1) + paramProto(defaultAlts.head.info.widen.paramInfoss, idx) + else + NoType + case _ => + NoType + + /** The expected type for a default argument. This is normally the `defaultParamType` + * with references to internal parameters replaced by wildcards. This replacement + * makes it possible that the default argument can have a more specific type than the + * parameter. For instance, we allow + * + * class C[A](a: A) { def copy[B](x: B = a): C[B] = C(x) } + * + * However, if the default parameter type is a context function type, we + * have to make sure that wildcard types do not leak into the implicitly + * generated closure's result type. Test case is pos/i12019.scala. If there + * would be a leakage with the wildcard approximation, we pick the original + * default parameter type as expected type. + */ + def expectedDefaultArgType = + val originalTp = defaultParamType + val approxTp = wildApprox(originalTp) + approxTp.stripPoly match + case atp @ defn.ContextFunctionType(_, resType, _) + if !defn.isNonRefinedFunction(atp) // in this case `resType` is lying, gives us only the non-dependent upper bound + || resType.existsPart(_.isInstanceOf[WildcardType], StopAt.Static, forceLazy = false) => + originalTp + case _ => + approxTp + + // println(s"final inherited for $sym: ${inherited.toString}") !!! + // println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}") + // TODO Scala 3.1: only check for inline vals (no final ones) + def isInlineVal = sym.isOneOf(FinalOrInline, butNot = Method | Mutable) + + var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType) + if sym.isInlineMethod then rhsCtx = rhsCtx.addMode(Mode.InlineableBody) + if sym.is(ExtensionMethod) then rhsCtx = rhsCtx.addMode(Mode.InExtensionMethod) + val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten + if (typeParams.nonEmpty) { + // we'll be typing an expression from a polymorphic definition's body, + // so we must allow constraining its type parameters + // compare with typedDefDef, see tests/pos/gadt-inference.scala + rhsCtx.setFreshGADTBounds + rhsCtx.gadt.addToConstraint(typeParams) + } + + def typedAheadRhs(pt: Type) = + PrepareInlineable.dropInlineIfError(sym, + typedAheadExpr(mdef.rhs, pt)(using rhsCtx)) + + def rhsType = + // For default getters, we use the corresponding parameter type as an + // expected type but we run it through `wildApprox` to allow default + // parameters like in `def mkList[T](value: T = 1): List[T]`. + val defaultTp = defaultParamType + val pt = inherited.orElse(expectedDefaultArgType).orElse(fallbackProto).widenExpr + val tp = typedAheadRhs(pt).tpe + if (defaultTp eq pt) && (tp frozen_<:< defaultTp) then + // When possible, widen to the default getter parameter type to permit a + // larger choice of overrides (see `default-getter.scala`). + // For justification on the use of `@uncheckedVariance`, see + // `default-getter-variance.scala`. + AnnotatedType(defaultTp, Annotation(defn.UncheckedVarianceAnnot)) + else + // don't strip @uncheckedVariance annot for default getters + TypeOps.simplify(tp.widenTermRefExpr, + if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match + case ctp: ConstantType if isInlineVal => ctp + case tp => TypeComparer.widenInferred(tp, pt) + + // Replace aliases to Unit by Unit itself. If we leave the alias in + // it would be erased to BoxedUnit. + def dealiasIfUnit(tp: Type) = if (tp.isRef(defn.UnitClass)) defn.UnitType else tp + + // Approximate a type `tp` with a type that does not contain skolem types. + val deskolemize = new ApproximatingTypeMap { + def apply(tp: Type) = /*trace(i"deskolemize($tp) at $variance", show = true)*/ + tp match { + case tp: SkolemType => range(defn.NothingType, atVariance(1)(apply(tp.info))) + case _ => mapOver(tp) + } + } + + def cookedRhsType = deskolemize(dealiasIfUnit(rhsType)) + def lhsType = fullyDefinedType(cookedRhsType, "right-hand side", mdef.span) + //if (sym.name.toString == "y") println(i"rhs = $rhsType, cooked = $cookedRhsType") + if (inherited.exists) + if (isInlineVal) lhsType else inherited + else { + if (sym.is(Implicit)) + mdef match { + case _: DefDef => missingType(sym, "result ") + case _: ValDef if sym.owner.isType => missingType(sym, "") + case _ => + } + lhsType orElse WildcardType + } + end inferredResultType } diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index 64bf32c49f91..f18843a1a4fd 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -636,7 +636,7 @@ object ProtoTypes { def newTypeVars(tl: TypeLambda): List[TypeTree] = for (paramRef <- tl.paramRefs) yield { - val tt = TypeVarBinder().withSpan(owningTree.span) + val tt = InferredTypeTree().withSpan(owningTree.span) val tvar = TypeVar(paramRef, state) state.ownedVars += tvar tt.withType(tvar) diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index dfba38d4efa5..8ed463d2eea9 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -235,19 +235,18 @@ trait TypeAssigner { else errorType("not a legal qualifying class for this", tree.srcPos)) } - def assignType(tree: untpd.Super, qual: Tree, mixinClass: Symbol = NoSymbol)(using Context): Super = { - val mix = tree.mix - qual.tpe match { - case err: ErrorType => untpd.cpy.Super(tree)(qual, mix).withType(err) + def superType(qualType: Type, mix: untpd.Ident, mixinClass: Symbol, pos: SrcPos)(using Context) = + qualType match + case err: ErrorType => err case qtype @ ThisType(_) => val cls = qtype.cls def findMixinSuper(site: Type): Type = site.parents filter (_.typeSymbol.name == mix.name) match { case p :: Nil => p.typeConstructor case Nil => - errorType(SuperQualMustBeParent(mix, cls), tree.srcPos) + errorType(SuperQualMustBeParent(mix, cls), pos) case p :: q :: _ => - errorType("ambiguous parent class qualifier", tree.srcPos) + errorType("ambiguous parent class qualifier", pos) } val owntype = if (mixinClass.exists) mixinClass.appliedRef @@ -257,9 +256,11 @@ trait TypeAssigner { val ps = cls.classInfo.parents if (ps.isEmpty) defn.AnyType else ps.reduceLeft((x: Type, y: Type) => x & y) } - tree.withType(SuperType(cls.thisType, owntype)) - } - } + SuperType(cls.thisType, owntype) + + def assignType(tree: untpd.Super, qual: Tree, mixinClass: Symbol = NoSymbol)(using Context): Super = + untpd.cpy.Super(tree)(qual, tree.mix) + .withType(superType(qual.tpe, tree.mix, mixinClass, tree.srcPos)) /** Substitute argument type `argType` for parameter `pref` in type `tp`, * skolemizing the argument type if it is not stable and `pref` occurs in `tp`. diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 989f65215377..24e0c1abf5c7 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1058,13 +1058,14 @@ class Typer extends Namer cpy.Block(block)(stats, expr1) withType expr1.tpe // no assignType here because avoid is redundant case _ => val target = pt.simplified - if tree.tpe <:< target then Typed(tree, TypeTree(pt.simplified)) + val targetTpt = InferredTypeTree().withType(target) + if tree.tpe <:< target then Typed(tree, targetTpt) else // This case should not normally arise. It currently does arise in test cases // pos/t4080b.scala and pos/i7067.scala. In that case, a type ascription is wrong // and would not pass Ycheck. We have to use a cast instead. TODO: follow-up why // the cases arise and eliminate them, if possible. - tree.cast(target) + tree.cast(targetTpt) } def noLeaks(t: Tree): Boolean = escapingRefs(t, localSyms).isEmpty if (noLeaks(tree)) tree @@ -1116,7 +1117,7 @@ class Typer extends Namer * def double(x: Char): String = s"$x$x" * "abc" flatMap double */ - private def decomposeProtoFunction(pt: Type, defaultArity: Int, tree: untpd.Tree)(using Context): (List[Type], untpd.Tree) = { + private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], untpd.Tree) = { def typeTree(tp: Type) = tp match { case _: WildcardType => untpd.TypeTree() case _ => untpd.TypeTree(tp) @@ -1135,29 +1136,62 @@ class Typer extends Namer report.error( i"""Implementation restriction: Expected result type $pt1 |is a curried dependent context function type. Such types are not yet supported.""", - tree.srcPos) - + pos) pt1 match { case tp: TypeParamRef => - decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree) + decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, pos) case _ => pt1.findFunctionTypeInUnion match { case pt1 if defn.isNonRefinedFunction(pt1) => // if expected parameter type(s) are wildcards, approximate from below. // if expected result type is a wildcard, approximate from above. // this can type the greatest set of admissible closures. (pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last))) - case SAMType(sam @ MethodTpe(_, formals, restpe)) => + case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe)) + if defn.isNonRefinedFunction(parent) && formals.length == defaultArity => + (formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))) + case SAMType(mt @ MethodTpe(_, formals, restpe)) => (formals, - if sam.isResultDependent then - untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef))) - else - typeTree(restpe)) + if (mt.isResultDependent) + untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))) + else + typeTree(restpe)) case _ => (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree()) } } } + /** The parameter type for a parameter in a lambda that does + * not have an explicit type given, and where the type is not known from the context. + * In this case the parameter type needs to be inferred the "target type" T known + * from the callee `f` if the lambda is of a form like `x => f(x)`. + * If `T` exists, we know that `S <: I <: T`. + * + * The inference makes two attempts: + * + * 1. Compute the target type `T` and make it known that `S <: T`. + * If the expected type `S` can be fully defined under ForceDegree.flipBottom, + * pick this one (this might use the fact that S <: T for an upper approximation). + * 2. Otherwise, if the target type `T` can be fully defined under ForceDegree.flipBottom, + * pick this one. + * + * If both attempts fail, return `NoType`. + */ + def inferredFromTarget( + param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int)(using Context): Type = + val target = calleeType.widen match + case mtpe: MethodType => + val pos = paramIndex(param.name) + if pos < mtpe.paramInfos.length then + val ptype = mtpe.paramInfos(pos) + if ptype.isRepeatedParam then NoType else ptype + else NoType + case _ => NoType + if target.exists then formal <:< target + if isFullyDefined(formal, ForceDegree.flipBottom) then formal + else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target + else NoType + def typedFunction(tree: untpd.Function, pt: Type)(using Context): Tree = if (ctx.mode is Mode.Type) typedFunctionType(tree, pt) else typedFunctionValue(tree, pt) @@ -1330,41 +1364,7 @@ class Typer extends Namer case _ => } - val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree) - - /** The inferred parameter type for a parameter in a lambda that does - * not have an explicit type given. - * An inferred parameter type I has two possible sources: - * - the type S known from the context - * - the "target type" T known from the callee `f` if the lambda is of a form like `x => f(x)` - * If `T` exists, we know that `S <: I <: T`. - * - * The inference makes three attempts: - * - * 1. If the expected type `S` is already fully defined under ForceDegree.failBottom - * pick this one. - * 2. Compute the target type `T` and make it known that `S <: T`. - * If the expected type `S` can be fully defined under ForceDegree.flipBottom, - * pick this one (this might use the fact that S <: T for an upper approximation). - * 3. Otherwise, if the target type `T` can be fully defined under ForceDegree.flipBottom, - * pick this one. - * - * If all attempts fail, issue a "missing parameter type" error. - */ - def inferredParamType(param: untpd.ValDef, formal: Type): Type = - if isFullyDefined(formal, ForceDegree.failBottom) then return formal - val target = calleeType.widen match - case mtpe: MethodType => - val pos = paramIndex(param.name) - if pos < mtpe.paramInfos.length then - val ptype = mtpe.paramInfos(pos) - if ptype.isRepeatedParam then NoType else ptype - else NoType - case _ => NoType - if target.exists then formal <:< target - if isFullyDefined(formal, ForceDegree.flipBottom) then formal - else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target - else errorType(AnonymousFunctionMissingParamType(param, params, tree, formal), param.srcPos) + val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos) def protoFormal(i: Int): Type = if (protoFormals.length == params.length) protoFormals(i) @@ -1390,9 +1390,19 @@ class Typer extends Namer val inferredParams: List[untpd.ValDef] = for ((param, i) <- params.zipWithIndex) yield if (!param.tpt.isEmpty) param - else cpy.ValDef(param)( - tpt = untpd.TypeTree( - inferredParamType(param, protoFormal(i)).translateFromRepeated(toArray = false))) + else + val formal = protoFormal(i) + val knownFormal = isFullyDefined(formal, ForceDegree.failBottom) + val paramType = + if knownFormal then formal + else inferredFromTarget(param, formal, calleeType, paramIndex) + .orElse(errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos)) + val paramTpt = untpd.TypedSplice( + (if knownFormal then InferredTypeTree() else untpd.TypeTree()) + .withType(paramType.translateFromRepeated(toArray = false)) + .withSpan(param.span.endPos) + ) + cpy.ValDef(param)(tpt = paramTpt) desugar.makeClosure(inferredParams, fnBody, resultTpt, isContextual, tree.span) } typed(desugared, pt) @@ -1458,7 +1468,7 @@ class Typer extends Namer typedMatchFinish(tree, tpd.EmptyTree, defn.ImplicitScrutineeTypeRef, cases1, pt) } else { - val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree) + val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree.srcPos) val checkMode = if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None else desugar.MatchCheck.Exhaustive @@ -1659,66 +1669,25 @@ class Typer extends Namer caseRest(using ctx.fresh.setFreshGADTBounds.setNewScope) } - def typedReturn(tree: untpd.Return)(using Context): Return = { + def typedReturn(tree: untpd.Return)(using Context): Return = - /** If `pt` is a context function type, its return type. If the CFT - * is dependent, instantiate with the parameters of the associated - * anonymous function. - * @param paramss the parameters of the anonymous functions - * enclosing the return expression - */ - def instantiateCFT(pt: Type, paramss: => List[List[Symbol]]): Type = - val ift = defn.asContextFunctionType(pt) - if ift.exists then - ift.nonPrivateMember(nme.apply).info match - case appType: MethodType => - instantiateCFT(appType.instantiate(paramss.head.map(_.termRef)), paramss.tail) - else pt - - def returnProto(owner: Symbol): Type = - if (owner.isConstructor) defn.UnitType - else - // We need to get the return type of the enclosing function, with all parameters replaced - // by the local type and value parameters. It would be nice if we could look up that - // type simply in the tpt field of the enclosing function. But the tree argument in - // a context is an untyped tree, so we cannot extract its type. - def instantiateRT(info: Type, psymss: List[List[Symbol]]): Type = info match - case info: PolyType => - instantiateRT(info.instantiate(psymss.head.map(_.typeRef)), psymss.tail) - case info: MethodType => - instantiateRT(info.instantiate(psymss.head.map(_.termRef)), psymss.tail) - case info => - info.widenExpr - val rt = instantiateRT(owner.info, owner.paramSymss) - def iftParamss = ctx.owner.ownersIterator - .filter(_.is(Method, butNot = Accessor)) - .takeWhile(_.isAnonymousFunction) - .toList - .reverse - .map(_.paramSymss.head) - instantiateCFT(rt, iftParamss) - - def enclMethInfo(cx: Context): (Tree, Type) = { + def enclMethInfo(cx: Context): (Tree, Type) = val owner = cx.owner - if (owner.isType) { + if owner.isType then report.error(ReturnOutsideMethodDefinition(owner), tree.srcPos) (EmptyTree, WildcardType) - } - else if (owner != cx.outer.owner && owner.isRealMethod) - if (owner.isInlineMethod) + else if owner != cx.outer.owner && owner.isRealMethod then + if owner.isInlineMethod then (EmptyTree, errorType(NoReturnFromInlineable(owner), tree.srcPos)) - else if (!owner.isCompleted) + else if !owner.isCompleted then (EmptyTree, errorType(MissingReturnTypeWithReturnStatement(owner), tree.srcPos)) - else { - val from = Ident(TermRef(NoPrefix, owner.asTerm)) - val proto = returnProto(owner) - (from, proto) - } + else + (Ident(TermRef(NoPrefix, owner.asTerm)), owner.returnProto) else enclMethInfo(cx.outer) - } + val (from, proto) = - if (tree.from.isEmpty) enclMethInfo(ctx) - else { + if tree.from.isEmpty then enclMethInfo(ctx) + else val from = tree.from.asInstanceOf[tpd.Tree] val proto = if (ctx.erasedTypes) from.symbol.info.finalResultType @@ -1726,10 +1695,9 @@ class Typer extends Namer // because we do not know the internal type params and method params. // Hence no adaptation is possible, and we assume WildcardType as prototype. (from, proto) - } val expr1 = typedExpr(tree.expr orElse untpd.unitLiteral.withSpan(tree.span), proto) assignType(cpy.Return(tree)(expr1, from)) - } + end typedReturn def typedWhileDo(tree: untpd.WhileDo)(using Context): Tree = inContext(Nullables.whileContext(tree.span)) { @@ -1802,8 +1770,15 @@ class Typer extends Namer bindings1, expansion1) } + def completeTypeTree(tree: untpd.TypeTree, pt: Type, original: untpd.Tree)(using Context): TypeTree = + tree.withSpan(original.span).withAttachmentsFrom(original) + .withType( + if isFullyDefined(pt, ForceDegree.flipBottom) then pt + else if ctx.reporter.errorsReported then UnspecifiedErrorType + else errorType(i"cannot infer type; expected type $pt is not fully defined", tree.srcPos)) + def typedTypeTree(tree: untpd.TypeTree, pt: Type)(using Context): Tree = - tree match { + tree match case tree: untpd.DerivedTypeTree => tree.ensureCompletions tree.getAttachment(untpd.OriginalSymbol) match { @@ -1817,11 +1792,7 @@ class Typer extends Namer errorTree(tree, "Something's wrong: missing original symbol for type tree") } case _ => - tree.withType( - if (isFullyDefined(pt, ForceDegree.flipBottom)) pt - else if (ctx.reporter.errorsReported) UnspecifiedErrorType - else errorType(i"cannot infer type; expected type $pt is not fully defined", tree.srcPos)) - } + completeTypeTree(InferredTypeTree(), pt, tree) def typedSingletonTypeTree(tree: untpd.SingletonTypeTree)(using Context): SingletonTypeTree = { val ref1 = typedExpr(tree.ref) @@ -2736,7 +2707,7 @@ class Typer extends Namer case tree: untpd.TypedSplice => typedTypedSplice(tree) case tree: untpd.UnApply => typedUnApply(tree, pt) case tree: untpd.Tuple => typedTuple(tree, pt) - case tree: untpd.DependentTypeTree => typed(untpd.TypeTree().withSpan(tree.span), pt) + case tree: untpd.DependentTypeTree => completeTypeTree(untpd.TypeTree(), pt, tree) case tree: untpd.InfixOp => typedInfixOp(tree, pt) case tree: untpd.ParsedTry => typedTry(tree, pt) case tree @ untpd.PostfixOp(qual, Ident(nme.WILDCARD)) => typedAsFunction(tree, pt) @@ -3063,7 +3034,7 @@ class Typer extends Namer else Some(adapt(tree1, pt, locked)) } { (_, _) => None } - case TypeApply(fn, args) if args.forall(_.isInstanceOf[TypeVarBinder[_]]) => + case TypeApply(fn, args) if args.forall(_.isInstanceOf[untpd.InferredTypeTree]) => tryInsertImplicitOnQualifier(fn, pt, locked) case _ => None } diff --git a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala index 5b3544b894c4..ffca320d53d3 100644 --- a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala +++ b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala @@ -7,7 +7,6 @@ import collection.mutable */ abstract class SimpleIdentitySet[+Elem <: AnyRef] { def size: Int - final def isEmpty: Boolean = size == 0 def + [E >: Elem <: AnyRef](x: E): SimpleIdentitySet[E] def - [E >: Elem <: AnyRef](x: E): SimpleIdentitySet[Elem] def contains[E >: Elem <: AnyRef](x: E): Boolean @@ -15,20 +14,38 @@ abstract class SimpleIdentitySet[+Elem <: AnyRef] { def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A def toList: List[Elem] + + final def isEmpty: Boolean = size == 0 + + def forall[E >: Elem <: AnyRef](p: E => Boolean): Boolean = !exists(!p(_)) + + def filter(p: Elem => Boolean): SimpleIdentitySet[Elem] = + val z: SimpleIdentitySet[Elem] = SimpleIdentitySet.empty + (z /: this)((s, x) => if p(x) then s + x else s) + def ++ [E >: Elem <: AnyRef](that: SimpleIdentitySet[E]): SimpleIdentitySet[E] = if (this.size == 0) that else if (that.size == 0) this else ((this: SimpleIdentitySet[E]) /: that)(_ + _) + def -- [E >: Elem <: AnyRef](that: SimpleIdentitySet[E]): SimpleIdentitySet[E] = if (that.size == 0) this else ((SimpleIdentitySet.empty: SimpleIdentitySet[E]) /: this) { (s, x) => if (that.contains(x)) s else s + x } - override def toString: String = toList.mkString("(", ", ", ")") + override def toString: String = toList.mkString("{", ", ", "}") } object SimpleIdentitySet { + + def apply[Elem <: AnyRef](elems: Elem*): SimpleIdentitySet[Elem] = + elems.foldLeft(empty: SimpleIdentitySet[Elem])(_ + _) + + extension [E <: AnyRef](xs: SimpleIdentitySet[E]) + def intersect(ys: SimpleIdentitySet[E]): SimpleIdentitySet[E] = + xs.filter(ys.contains) + object empty extends SimpleIdentitySet[Nothing] { def size: Int = 0 def + [E <: AnyRef](x: E): SimpleIdentitySet[E] =