diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 85dd5e6665c6..c1dd78451bae 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1498,10 +1498,10 @@ object desugar { case vd: ValDef => vd } - def makeContextualFunction(formals: List[Tree], body: Tree, isErased: Boolean)(using Context): Function = { - val mods = if (isErased) Given | Erased else Given + def makeContextualFunction(formals: List[Tree], body: Tree, erasedParams: List[Boolean])(using Context): Function = { + val mods = Given val params = makeImplicitParameters(formals, mods) - FunctionWithMods(params, body, Modifiers(mods)) + FunctionWithMods(params, body, Modifiers(mods), erasedParams) } private def derivedValDef(original: Tree, named: NameTree, tpt: Tree, rhs: Tree, mods: Modifiers)(using Context) = { @@ -1834,6 +1834,7 @@ object desugar { cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt)) case _ => annotate(tpnme.retains, parent) + case f: FunctionWithMods if f.hasErasedParams => makeFunctionWithValDefs(f, pt) } desugared.withSpan(tree.span) } @@ -1909,6 +1910,28 @@ object desugar { TypeDef(tpnme.REFINE_CLASS, impl).withFlags(Trait) } + /** Ensure the given function tree use only ValDefs for parameters. + * For example, + * FunctionWithMods(List(TypeTree(A), TypeTree(B)), body, mods, erasedParams) + * gets converted to + * FunctionWithMods(List(ValDef(x$1, A), ValDef(x$2, B)), body, mods, erasedParams) + */ + def makeFunctionWithValDefs(tree: Function, pt: Type)(using Context): Function = { + val Function(args, result) = tree + args match { + case (_ : ValDef) :: _ => tree // ValDef case can be easily handled + case _ if !ctx.mode.is(Mode.Type) => tree + case _ => + val applyVParams = args.zipWithIndex.map { + case (p, n) => makeSyntheticParameter(n + 1, p) + } + tree match + case tree: FunctionWithMods => + untpd.FunctionWithMods(applyVParams, tree.body, tree.mods, tree.erasedParams) + case _ => untpd.Function(applyVParams, result) + } + } + /** Returns list of all pattern variables, possibly with their types, * without duplicates */ diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 9b55db600d3d..73f45bd7369b 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -960,7 +960,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => && tree.isTerm && { val qualType = tree.qualifier.tpe - hasRefinement(qualType) && !qualType.derivesFrom(defn.PolyFunctionClass) + hasRefinement(qualType) && !defn.isRefinedFunctionType(qualType) } def loop(tree: Tree): Boolean = tree match case TypeApply(fun, _) => diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 5b0061b5d036..d1b1cdf607b5 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -260,12 +260,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { // If `isParamDependent == false`, the value of `previousParamRefs` is not used. if isParamDependent then mutable.ListBuffer[TermRef]() else (null: ListBuffer[TermRef] | Null).uncheckedNN - def valueParam(name: TermName, origInfo: Type): TermSymbol = + def valueParam(name: TermName, origInfo: Type, isErased: Boolean): TermSymbol = val maybeImplicit = if tp.isContextualMethod then Given else if tp.isImplicitMethod then Implicit else EmptyFlags - val maybeErased = if tp.isErasedMethod then Erased else EmptyFlags + val maybeErased = if isErased then Erased else EmptyFlags def makeSym(info: Type) = newSymbol(sym, name, TermParam | maybeImplicit | maybeErased, info, coord = sym.coord) @@ -283,7 +283,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm) (vparams.asInstanceOf[List[TermSymbol]], remaining1) case nil => - (tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil) + (tp.paramNames.lazyZip(tp.paramInfos).lazyZip(tp.erasedParams).map(valueParam), Nil) val (rtp, paramss) = recur(tp.instantiate(vparams.map(_.termRef)), remaining1) (rtp, vparams :: paramss) case _ => @@ -1140,10 +1140,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { def etaExpandCFT(using Context): Tree = def expand(target: Tree, tp: Type)(using Context): Tree = tp match - case defn.ContextFunctionType(argTypes, resType, isErased) => + case defn.ContextFunctionType(argTypes, resType, _) => val anonFun = newAnonFun( ctx.owner, - MethodType.companion(isContextual = true, isErased = isErased)(argTypes, resType), + MethodType.companion(isContextual = true)(argTypes, resType), coord = ctx.owner.coord) def lambdaBody(refss: List[List[Tree]]) = expand(target.select(nme.apply).appliedToArgss(refss), resType)( diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index aeebb1f203e8..a262c3658399 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -76,9 +76,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { override def isType: Boolean = body.isType } - /** A function type or closure with `implicit`, `erased`, or `given` modifiers */ - class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile) - extends Function(args, body) + /** A function type or closure with `implicit` or `given` modifiers and information on which parameters are `erased` */ + class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers, val erasedParams: List[Boolean])(implicit @constructorOnly src: SourceFile) + extends Function(args, body) { + assert(args.length == erasedParams.length) + + def hasErasedParams = erasedParams.contains(true) + } /** A polymorphic function type */ case class PolyFunction(targs: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree { diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index e4533aa73ce0..decd428f5365 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -146,7 +146,6 @@ extension (tp: Type) defn.FunctionType( fname.functionArity, isContextual = fname.isContextFunction, - isErased = fname.isErasedFunction, isImpure = true).appliedTo(args) case _ => tp diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index 77363a165f64..f9401a0c509f 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -336,8 +336,8 @@ class CheckCaptures extends Recheck, SymTransformer: mapArgUsing(_.forceBoxStatus(false)) else if meth == defn.Caps_unsafeBoxFunArg then mapArgUsing { - case defn.FunctionOf(paramtpe :: Nil, restpe, isContectual, isErased) => - defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContectual, isErased) + case defn.FunctionOf(paramtpe :: Nil, restpe, isContectual) => + defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContectual) } else super.recheckApply(tree, pt) match @@ -430,7 +430,7 @@ class CheckCaptures extends Recheck, SymTransformer: block match case closureDef(mdef) => pt.dealias match - case defn.FunctionOf(ptformals, _, _, _) + case defn.FunctionOf(ptformals, _, _) if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) => // Redo setup of the anonymous function so that formal parameters don't // get capture sets. This is important to avoid false widenings to `*` @@ -598,8 +598,8 @@ class CheckCaptures extends Recheck, SymTransformer: //println(i"check conforms $actual1 <<< $expected1") super.checkConformsExpr(actual1, expected1, tree) - private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean, isErased: Boolean)(using Context): Type = - MethodType.companion(isContextual = isContextual, isErased = isErased)(args, resultType) + private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean)(using Context): Type = + MethodType.companion(isContextual = isContextual)(args, resultType) .toFunctionType(isJava = false, alwaysDependent = true) /** Turn `expected` into a dependent function when `actual` is dependent. */ @@ -607,9 +607,9 @@ class CheckCaptures extends Recheck, SymTransformer: def recur(expected: Type): Type = expected.dealias match case expected @ CapturingType(eparent, refs) => CapturingType(recur(eparent), refs, boxed = expected.isBoxed) - case expected @ defn.FunctionOf(args, resultType, isContextual, isErased) + case expected @ defn.FunctionOf(args, resultType, isContextual) if defn.isNonRefinedFunction(expected) && defn.isFunctionType(actual) && !defn.isNonRefinedFunction(actual) => - val expected1 = toDepFun(args, resultType, isContextual, isErased) + val expected1 = toDepFun(args, resultType, isContextual) expected1 case _ => expected @@ -675,7 +675,7 @@ class CheckCaptures extends Recheck, SymTransformer: try val (eargs, eres) = expected.dealias.stripCapturing match - case defn.FunctionOf(eargs, eres, _, _) => (eargs, eres) + case defn.FunctionOf(eargs, eres, _) => (eargs, eres) case expected: MethodType => (expected.paramInfos, expected.resType) case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(expected) => (rinfo.paramInfos, rinfo.resType) case _ => (aargs.map(_ => WildcardType), WildcardType) @@ -739,7 +739,7 @@ class CheckCaptures extends Recheck, SymTransformer: case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) => adaptFun(actual, args.init, args.last, expected, covariant, insertBox, (aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1)) - case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) => + case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(actual) => // TODO Find a way to combine handling of generic and dependent function types (here and elsewhere) adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox, (aargs1, ares1) => @@ -962,7 +962,7 @@ class CheckCaptures extends Recheck, SymTransformer: case CapturingType(parent, refs) => healCaptureSet(refs) traverse(parent) - case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) => + case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) => traverse(rinfo) case tp: TermLambda => val saved = allowed diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index 461c18ea0980..5642ea99de1a 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -12,6 +12,7 @@ import transform.Recheck.* import CaptureSet.IdentityCaptRefMap import Synthetics.isExcluded import util.Property +import dotty.tools.dotc.core.Annotations.Annotation /** A tree traverser that prepares a compilation unit to be capture checked. * It does the following: @@ -38,7 +39,6 @@ extends tpd.TreeTraverser: private def depFun(tycon: Type, argTypes: List[Type], resType: Type)(using Context): Type = MethodType.companion( isContextual = defn.isContextFunctionClass(tycon.classSymbol), - isErased = defn.isErasedFunctionClass(tycon.classSymbol) )(argTypes, resType) .toFunctionType(isJava = false, alwaysDependent = true) @@ -54,7 +54,7 @@ extends tpd.TreeTraverser: val boxedRes = recur(res) if boxedRes eq res then tp else tp1.derivedAppliedType(tycon, args.init :+ boxedRes) - case tp1 @ RefinedType(_, _, rinfo) if defn.isFunctionType(tp1) => + case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(tp1) => val boxedRinfo = recur(rinfo) if boxedRinfo eq rinfo then tp else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true) @@ -231,7 +231,7 @@ extends tpd.TreeTraverser: tp.derivedAppliedType(tycon1, args1 :+ res1) else tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg))) - case tp @ RefinedType(core, rname, rinfo) if defn.isFunctionType(tp) => + case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) => val rinfo1 = apply(rinfo) if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true) else tp @@ -260,7 +260,13 @@ extends tpd.TreeTraverser: private def expandThrowsAlias(tp: Type)(using Context) = tp match case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias => // hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->` - defn.FunctionOf(defn.CanThrowClass.typeRef.appliedTo(exc) :: Nil, res, isContextual = true, isErased = true) + defn.FunctionOf( + AnnotatedType( + defn.CanThrowClass.typeRef.appliedTo(exc), + Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span)) :: Nil, + res, + isContextual = true + ) case _ => tp private def expandThrowsAliases(using Context) = new TypeMap: @@ -323,7 +329,7 @@ extends tpd.TreeTraverser: args.last, CaptureSet.empty, currentCs ++ outerCs) tp.derivedAppliedType(tycon1, args1 :+ resType1) tp1.capturing(outerCs) - case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionType(tp) => + case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) => propagateDepFunctionResult(mapOver(tp), currentCs ++ outerCs) .capturing(outerCs) case _ => diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 8c3f2ad89ca1..20a1d59316eb 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -86,7 +86,7 @@ class Definitions { newPermanentClassSymbol(ScalaPackageClass, name, Artifact, completer).entered } - /** The trait FunctionN, ContextFunctionN, ErasedFunctionN or ErasedContextFunction, for some N + /** The trait FunctionN and ContextFunctionN for some N * @param name The name of the trait to be created * * FunctionN traits follow this template: @@ -104,21 +104,6 @@ class Definitions { * trait ContextFunctionN[-T0,...,-T{N-1}, +R] extends Object { * def apply(using $x0: T0, ..., $x{N_1}: T{N-1}): R * } - * - * ErasedFunctionN traits follow this template: - * - * trait ErasedFunctionN[-T0,...,-T{N-1}, +R] extends Object { - * def apply(erased $x0: T0, ..., $x{N_1}: T{N-1}): R - * } - * - * ErasedContextFunctionN traits follow this template: - * - * trait ErasedContextFunctionN[-T0,...,-T{N-1}, +R] extends Object { - * def apply(using erased $x0: T0, ..., $x{N_1}: T{N-1}): R - * } - * - * ErasedFunctionN and ErasedContextFunctionN erase to Function0. - * * ImpureXYZFunctionN follow this template: * * type ImpureXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R] @@ -149,8 +134,7 @@ class Definitions { val resParamRef = enterTypeParam(cls, paramNamePrefix ++ "R", Covariant, decls).typeRef val methodType = MethodType.companion( isContextual = name.isContextFunction, - isImplicit = false, - isErased = name.isErasedFunction) + isImplicit = false) decls.enter(newMethod(cls, nme.apply, methodType(argParamRefs, resParamRef), Deferred)) denot.info = ClassInfo(ScalaPackageClass.thisType, cls, ObjectType :: Nil, decls) @@ -1109,15 +1093,23 @@ class Definitions { sym.owner.linkedClass.typeRef object FunctionOf { - def apply(args: List[Type], resultType: Type, isContextual: Boolean = false, isErased: Boolean = false)(using Context): Type = - FunctionType(args.length, isContextual, isErased).appliedTo(args ::: resultType :: Nil) - def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean, Boolean)] = { - val tsym = ft.typeSymbol - if isFunctionClass(tsym) && ft.isRef(tsym) then - val targs = ft.dealias.argInfos - if (targs.isEmpty) None - else Some(targs.init, targs.last, tsym.name.isContextFunction, tsym.name.isErasedFunction) - else None + def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type = + val mt = MethodType.companion(isContextual, false)(args, resultType) + if mt.hasErasedParams then + RefinedType(ErasedFunctionClass.typeRef, nme.apply, mt) + else + FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil) + def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean)] = { + ft.dealias match + case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) => + Some(mt.paramInfos, mt.resType, mt.isContextualMethod) + case _ => + val tsym = ft.dealias.typeSymbol + if isFunctionSymbol(tsym) && ft.isRef(tsym) then + val targs = ft.dealias.argInfos + if (targs.isEmpty) None + else Some(targs.init, targs.last, tsym.name.isContextFunction) + else None } } @@ -1436,24 +1428,22 @@ class Definitions { classRefs(n).nn end FunType - private def funTypeIdx(isContextual: Boolean, isErased: Boolean, isImpure: Boolean): Int = + private def funTypeIdx(isContextual: Boolean, isImpure: Boolean): Int = (if isContextual then 1 else 0) - + (if isErased then 2 else 0) - + (if isImpure then 4 else 0) + + (if isImpure then 2 else 0) private val funTypeArray: IArray[FunType] = val arr = Array.ofDim[FunType](8) val choices = List(false, true) - for contxt <- choices; erasd <- choices; impure <- choices do + for contxt <- choices; impure <- choices do var str = "Function" if contxt then str = "Context" + str - if erasd then str = "Erased" + str if impure then str = "Impure" + str - arr(funTypeIdx(contxt, erasd, impure)) = FunType(str) + arr(funTypeIdx(contxt, impure)) = FunType(str) IArray.unsafeFromArray(arr) - def FunctionSymbol(n: Int, isContextual: Boolean = false, isErased: Boolean = false, isImpure: Boolean = false)(using Context): Symbol = - funTypeArray(funTypeIdx(isContextual, isErased, isImpure))(n).symbol + def FunctionSymbol(n: Int, isContextual: Boolean = false, isImpure: Boolean = false)(using Context): Symbol = + funTypeArray(funTypeIdx(isContextual, isImpure))(n).symbol @tu lazy val Function0_apply: Symbol = Function0.requiredMethod(nme.apply) @tu lazy val ContextFunction0_apply: Symbol = ContextFunction0.requiredMethod(nme.apply) @@ -1463,12 +1453,14 @@ class Definitions { @tu lazy val Function2: Symbol = FunctionSymbol(2) @tu lazy val ContextFunction0: Symbol = FunctionSymbol(0, isContextual = true) - def FunctionType(n: Int, isContextual: Boolean = false, isErased: Boolean = false, isImpure: Boolean = false)(using Context): TypeRef = - FunctionSymbol(n, isContextual && !ctx.erasedTypes, isErased, isImpure).typeRef + def FunctionType(n: Int, isContextual: Boolean = false, isImpure: Boolean = false)(using Context): TypeRef = + FunctionSymbol(n, isContextual && !ctx.erasedTypes, isImpure).typeRef lazy val PolyFunctionClass = requiredClass("scala.PolyFunction") def PolyFunctionType = PolyFunctionClass.typeRef + lazy val ErasedFunctionClass = requiredClass("scala.runtime.ErasedFunction") + /** If `cls` is a class in the scala package, its name, otherwise EmptyTypeName */ def scalaClassName(cls: Symbol)(using Context): TypeName = cls.denot match case clsd: ClassDenotation if clsd.owner eq ScalaPackageClass => @@ -1501,8 +1493,6 @@ class Definitions { * - FunctionXXL * - FunctionN for N >= 0 * - ContextFunctionN for N >= 0 - * - ErasedFunctionN for N > 0 - * - ErasedContextFunctionN for N > 0 */ def isFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isFunction @@ -1521,12 +1511,6 @@ class Definitions { */ def isContextFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isContextFunction - /** Is an erased function class. - * - ErasedFunctionN for N > 0 - * - ErasedContextFunctionN for N > 0 - */ - def isErasedFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isErasedFunction - /** Is either FunctionXXL or a class that will be erased to FunctionXXL * - FunctionXXL * - FunctionN for N >= 22 @@ -1563,8 +1547,7 @@ class Definitions { */ def functionTypeErasure(cls: Symbol): Type = val arity = scalaClassName(cls).functionArity - if cls.name.isErasedFunction then FunctionType(0) - else if arity > 22 then FunctionXXLClass.typeRef + if arity > 22 then FunctionXXLClass.typeRef else if arity >= 0 then FunctionType(arity) else NoType @@ -1704,16 +1687,29 @@ class Definitions { arity >= 0 && isFunctionClass(sym) && tp.isRef( - FunctionType(arity, sym.name.isContextFunction, sym.name.isErasedFunction).typeSymbol, + FunctionType(arity, sym.name.isContextFunction).typeSymbol, skipRefined = false) end isNonRefinedFunction - /** Is `tp` a representation of a (possibly dependent) function type or an alias of such? */ + /** Returns whether `tp` is an instance or a refined instance of: + * - scala.FunctionN + * - scala.ContextFunctionN + */ def isFunctionType(tp: Type)(using Context): Boolean = isNonRefinedFunction(tp.dropDependentRefinement) + /** Is `tp` a specialized, refined function type? Either an `ErasedFunction` or a `PolyFunction`. */ + def isRefinedFunctionType(tp: Type)(using Context): Boolean = + tp.derivesFrom(defn.PolyFunctionClass) || isErasedFunctionType(tp) + + /** Returns whether `tp` is an instance or a refined instance of: + * - scala.FunctionN + * - scala.ContextFunctionN + * - ErasedFunction + * - PolyFunction + */ def isFunctionOrPolyType(tp: Type)(using Context): Boolean = - isFunctionType(tp) || (tp.typeSymbol eq defn.PolyFunctionClass) + isFunctionType(tp) || isRefinedFunctionType(tp) private def withSpecMethods(cls: ClassSymbol, bases: List[Name], paramTypes: Set[TypeRef]) = for base <- bases; tp <- paramTypes do @@ -1802,7 +1798,7 @@ class Definitions { @tu lazy val FunctionSpecializedApplyNames: collection.Set[Name] = Function0SpecializedApplyNames ++ Function1SpecializedApplyNames ++ Function2SpecializedApplyNames - def functionArity(tp: Type)(using Context): Int = tp.dropDependentRefinement.dealias.argInfos.length - 1 + def functionArity(tp: Type)(using Context): Int = tp.functionArgInfos.length - 1 /** Return underlying context function type (i.e. instance of an ContextFunctionN class) * or NoType if none exists. The following types are considered as underlying types: @@ -1814,6 +1810,8 @@ class Definitions { tp.stripTypeVar.dealias match case tp1: TypeParamRef if ctx.typerState.constraint.contains(tp1) => asContextFunctionType(TypeComparer.bounds(tp1).hiBound) + case tp1 @ RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) && mt.isContextualMethod => + tp1 case tp1 => if tp1.typeSymbol.name.isContextFunction && isFunctionType(tp1) then tp1 else NoType @@ -1827,18 +1825,28 @@ class Definitions { * types `As`, the result type `B` and a whether the type is an erased context function. */ object ContextFunctionType: - def unapply(tp: Type)(using Context): Option[(List[Type], Type, Boolean)] = + def unapply(tp: Type)(using Context): Option[(List[Type], Type, List[Boolean])] = if ctx.erasedTypes then atPhase(erasurePhase)(unapply(tp)) else - val tp1 = asContextFunctionType(tp) - if tp1.exists then - val args = tp1.dropDependentRefinement.argInfos - Some((args.init, args.last, tp1.typeSymbol.name.isErasedFunction)) - else None + asContextFunctionType(tp) match + case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) => + Some((mt.paramInfos, mt.resType, mt.erasedParams)) + case tp1 if tp1.exists => + val args = tp1.functionArgInfos + val erasedParams = erasedFunctionParameters(tp1) + Some((args.init, args.last, erasedParams)) + case _ => None + + /* Returns a list of erased booleans marking whether parameters are erased, for a function type. */ + def erasedFunctionParameters(tp: Type)(using Context): List[Boolean] = tp.dealias match { + case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams + case tp if isFunctionType(tp) => List.fill(functionArity(tp)) { false } + case _ => Nil + } def isErasedFunctionType(tp: Type)(using Context): Boolean = - tp.dealias.typeSymbol.name.isErasedFunction && isFunctionType(tp) + tp.derivesFrom(defn.ErasedFunctionClass) /** A whitelist of Scala-2 classes that are known to be pure */ def isAssuredNoInits(sym: Symbol): Boolean = diff --git a/compiler/src/dotty/tools/dotc/core/Denotations.scala b/compiler/src/dotty/tools/dotc/core/Denotations.scala index 723f9408d805..82368fd4dbf5 100644 --- a/compiler/src/dotty/tools/dotc/core/Denotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Denotations.scala @@ -545,8 +545,7 @@ object Denotations { tp2 match case tp2: MethodType if TypeComparer.matchingMethodParams(tp1, tp2) - && tp1.isImplicitMethod == tp2.isImplicitMethod - && tp1.isErasedMethod == tp2.isErasedMethod => + && tp1.isImplicitMethod == tp2.isImplicitMethod => val resType = infoMeet(tp1.resType, tp2.resType.subst(tp2, tp1), safeIntersection) if resType.exists then tp1.derivedLambdaType(mergeParamNames(tp1, tp2), tp1.paramInfos, resType) diff --git a/compiler/src/dotty/tools/dotc/core/NameOps.scala b/compiler/src/dotty/tools/dotc/core/NameOps.scala index 4e075953d7fa..04440c9e9b39 100644 --- a/compiler/src/dotty/tools/dotc/core/NameOps.scala +++ b/compiler/src/dotty/tools/dotc/core/NameOps.scala @@ -214,7 +214,7 @@ object NameOps { if str == mustHave then found = true idx + str.length else idx - skip(skip(skip(0, "Impure"), "Erased"), "Context") == suffixStart + skip(skip(0, "Impure"), "Context") == suffixStart && found } @@ -225,10 +225,11 @@ object NameOps { private def checkedFunArity(suffixStart: Int)(using Context): Int = if isFunctionPrefix(suffixStart) then funArity(suffixStart) else -1 - /** Is a function name, i.e one of FunctionXXL, FunctionN, ContextFunctionN, ErasedFunctionN, ErasedContextFunctionN for N >= 0 + /** Is a function name, i.e one of FunctionXXL, FunctionN, ContextFunctionN, ImpureFunctionN, ImpureContextFunctionN for N >= 0 */ def isFunction(using Context): Boolean = - (name eq tpnme.FunctionXXL) || checkedFunArity(functionSuffixStart) >= 0 + (name eq tpnme.FunctionXXL) + || checkedFunArity(functionSuffixStart) >= 0 /** Is a function name * - FunctionN for N >= 0 @@ -241,14 +242,11 @@ object NameOps { isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= 0 def isContextFunction(using Context): Boolean = isSpecificFunction("Context") - def isErasedFunction(using Context): Boolean = isSpecificFunction("Erased") def isImpureFunction(using Context): Boolean = isSpecificFunction("Impure") /** Is a synthetic function name, i.e. one of * - FunctionN for N > 22 * - ContextFunctionN for N >= 0 - * - ErasedFunctionN for N >= 0 - * - ErasedContextFunctionN for N >= 0 */ def isSyntheticFunction(using Context): Boolean = val suffixStart = functionSuffixStart diff --git a/compiler/src/dotty/tools/dotc/core/NamerOps.scala b/compiler/src/dotty/tools/dotc/core/NamerOps.scala index db6f72590818..dc09edd79781 100644 --- a/compiler/src/dotty/tools/dotc/core/NamerOps.scala +++ b/compiler/src/dotty/tools/dotc/core/NamerOps.scala @@ -42,10 +42,10 @@ object NamerOps: case Nil => resultType case TermSymbols(params) :: paramss1 => - val (isContextual, isImplicit, isErased) = - if params.isEmpty then (false, false, false) - else (params.head.is(Given), params.head.is(Implicit), params.head.is(Erased)) - val make = MethodType.companion(isContextual = isContextual, isImplicit = isImplicit, isErased = isErased) + val (isContextual, isImplicit) = + if params.isEmpty then (false, false) + else (params.head.is(Given), params.head.is(Implicit)) + val make = MethodType.companion(isContextual = isContextual, isImplicit = isImplicit) if isJava then for param <- params do if param.info.isDirectRef(defn.ObjectClass) then param.info = defn.AnyType diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 92f2e55a49bf..27e97a92b48e 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -212,6 +212,7 @@ object StdNames { final val Throwable: N = "Throwable" final val IOOBException: N = "IndexOutOfBoundsException" final val FunctionXXL: N = "FunctionXXL" + final val ErasedFunction: N = "ErasedFunction" final val Abs: N = "Abs" final val And: N = "&&" diff --git a/compiler/src/dotty/tools/dotc/core/TypeApplications.scala b/compiler/src/dotty/tools/dotc/core/TypeApplications.scala index 7c25ecd21ebf..2e8aee4df96c 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeApplications.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeApplications.scala @@ -9,9 +9,11 @@ import SymDenotations.LazyType import Decorators._ import util.Stats._ import Names._ +import StdNames.nme import Flags.{Module, Provisional} import dotty.tools.dotc.config.Config import cc.boxedUnlessFun +import dotty.tools.dotc.transform.TypeUtils.isErasedValueType object TypeApplications { @@ -503,6 +505,14 @@ class TypeApplications(val self: Type) extends AnyVal { case AppliedType(tycon, args) => args.boxedUnlessFun(tycon) case _ => Nil + /** If this is an encoding of a function type, return its arguments, otherwise return Nil. + * Handles `ErasedFunction`s and poly functions gracefully. + */ + final def functionArgInfos(using Context): List[Type] = self.dealias match + case RefinedType(parent, nme.apply, mt: MethodType) if defn.isErasedFunctionType(parent) => (mt.paramInfos :+ mt.resultType) + case RefinedType(parent, nme.apply, mt: MethodType) if parent.typeSymbol eq defn.PolyFunctionClass => (mt.paramInfos :+ mt.resultType) + case _ => self.dropDependentRefinement.dealias.argInfos + /** Argument types where existential types in arguments are disallowed */ def argTypes(using Context): List[Type] = argInfos mapConserve noBounds diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 9e8d18765352..f097bd160fdd 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2119,7 +2119,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case nil => formals2.isEmpty } - loop(tp1.paramInfos, tp2.paramInfos) + // If methods have erased parameters, then the erased parameters must match + val erasedValid = (!tp1.hasErasedParams && !tp2.hasErasedParams) || (tp1.erasedParams == tp2.erasedParams) + + erasedValid && loop(tp1.paramInfos, tp2.paramInfos) } /** Do the parameter types of `tp1` and `tp2` match in a way that allows `tp1` diff --git a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala index 67839d10c8cd..9bcb3eca36bb 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala @@ -536,7 +536,14 @@ object TypeErasure { val paramss = res.paramNamess assert(paramss.length == 1) erasure(defn.FunctionType(paramss.head.length, - isContextual = res.isImplicitMethod, isErased = res.isErasedMethod)) + isContextual = res.isImplicitMethod)) + + def eraseErasedFunctionApply(erasedFn: MethodType)(using Context): Type = + val fnType = defn.FunctionType( + n = erasedFn.erasedParams.count(_ == false), + isContextual = erasedFn.isContextualMethod, + ) + erasure(fnType) } import TypeErasure._ @@ -613,6 +620,8 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst defn.FunctionType(0) case RefinedType(parent, nme.apply, refinedInfo) if parent.typeSymbol eq defn.PolyFunctionClass => erasePolyFunctionApply(refinedInfo) + case RefinedType(parent, nme.apply, refinedInfo: MethodType) if defn.isErasedFunctionType(parent) => + eraseErasedFunctionApply(refinedInfo) case tp: TypeProxy => this(tp.underlying) case tp @ AndType(tp1, tp2) => @@ -639,7 +648,13 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst case tp: MethodType => def paramErasure(tpToErase: Type) = erasureFn(sourceLanguage, semiEraseVCs, isConstructor, isSymbol, wildcardOK)(tpToErase) - val (names, formals0) = if (tp.isErasedMethod) (Nil, Nil) else (tp.paramNames, tp.paramInfos) + val (names, formals0) = if tp.hasErasedParams then + tp.paramNames + .zip(tp.paramInfos) + .zip(tp.erasedParams) + .collect{ case (param, isErased) if !isErased => param } + .unzip + else (tp.paramNames, tp.paramInfos) val formals = formals0.mapConserve(paramErasure) eraseResult(tp.resultType) match { case rt: MethodType => @@ -871,6 +886,8 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst // because RefinedTypes <: TypeProxy and it would be caught by // the case immediately below sigName(this(tp)) + case tp @ RefinedType(parent, nme.apply, refinedInfo) if defn.isErasedFunctionType(parent) => + sigName(this(tp)) case tp: TypeProxy => sigName(tp.underlying) case tp: WildcardType => diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index c05bb164834d..c660ef657b13 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -43,6 +43,7 @@ import scala.annotation.internal.sharable import scala.annotation.threadUnsafe import dotty.tools.dotc.transform.SymUtils._ +import dotty.tools.dotc.transform.TypeUtils.isErasedClass object Types { @@ -425,7 +426,7 @@ object Types { def isContextualMethod: Boolean = false /** Is this a MethodType for which the parameters will not be used? */ - def isErasedMethod: Boolean = false + def hasErasedParams(using Context): Boolean = false /** Is this a match type or a higher-kinded abstraction of one? */ @@ -1180,7 +1181,8 @@ object Types { /** Remove all AnnotatedTypes wrapping this type. */ - def stripAnnots(using Context): Type = this + def stripAnnots(keep: Annotation => Context ?=> Boolean)(using Context): Type = this + final def stripAnnots(using Context): Type = stripAnnots(_ => false) /** Strip TypeVars and Annotation and CapturingType wrappers */ def stripped(using Context): Type = this @@ -1470,7 +1472,7 @@ object Types { /** Dealias, and if result is a dependent function type, drop the `apply` refinement. */ final def dropDependentRefinement(using Context): Type = dealias match { - case RefinedType(parent, nme.apply, _) => parent + case RefinedType(parent, nme.apply, mt) if defn.isNonRefinedFunction(parent) => parent case tp => tp } @@ -1712,6 +1714,8 @@ object Types { else NoType case t if defn.isNonRefinedFunction(t) => t + case t if defn.isErasedFunctionType(t) => + t case t @ SAMType(_) => t case _ => @@ -1839,15 +1843,15 @@ object Types { case mt: MethodType if !mt.isParamDependent => val formals1 = if (dropLast == 0) mt.paramInfos else mt.paramInfos dropRight dropLast val isContextual = mt.isContextualMethod && !ctx.erasedTypes - val isErased = mt.isErasedMethod && !ctx.erasedTypes val result1 = mt.nonDependentResultApprox match { case res: MethodType => res.toFunctionType(isJava) case res => res } val funType = defn.FunctionOf( formals1 mapConserve (_.translateFromRepeated(toArray = isJava)), - result1, isContextual, isErased) - if alwaysDependent || mt.isResultDependent then RefinedType(funType, nme.apply, mt) + result1, isContextual) + if alwaysDependent || mt.isResultDependent then + RefinedType(funType, nme.apply, mt) else funType } @@ -3648,6 +3652,8 @@ object Types { def companion: LambdaTypeCompanion[ThisName, PInfo, This] + def erasedParams(using Context) = List.fill(paramInfos.size)(false) + /** The type `[tparams := paramRefs] tp`, where `tparams` can be * either a list of type parameter symbols or a list of lambda parameters * @@ -3725,7 +3731,11 @@ object Types { else Signature(tp, sourceLanguage) this match case tp: MethodType => - val params = if (isErasedMethod) Nil else tp.paramInfos + val params = if (hasErasedParams) + tp.paramInfos + .zip(tp.erasedParams) + .collect { case (param, isErased) if !isErased => param } + else tp.paramInfos resultSignature.prependTermParams(params, sourceLanguage) case tp: PolyType => resultSignature.prependTypeParams(tp.paramNames.length) @@ -3932,16 +3942,14 @@ object Types { def companion: MethodTypeCompanion final override def isImplicitMethod: Boolean = - companion.eq(ImplicitMethodType) || - companion.eq(ErasedImplicitMethodType) || - isContextualMethod - final override def isErasedMethod: Boolean = - companion.eq(ErasedMethodType) || - companion.eq(ErasedImplicitMethodType) || - companion.eq(ErasedContextualMethodType) + companion.eq(ImplicitMethodType) || isContextualMethod + final override def hasErasedParams(using Context): Boolean = + erasedParams.contains(true) final override def isContextualMethod: Boolean = - companion.eq(ContextualMethodType) || - companion.eq(ErasedContextualMethodType) + companion.eq(ContextualMethodType) + + override def erasedParams(using Context): List[Boolean] = + paramInfos.map(p => p.hasAnnotation(defn.ErasedParamAnnot)) protected def prefixString: String = companion.prefixString } @@ -4038,7 +4046,7 @@ object Types { tl => tl.integrate(params, resultType)) end fromSymbols - final def apply(paramNames: List[TermName])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type)(using Context): MethodType = + def apply(paramNames: List[TermName])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type)(using Context): MethodType = checkValid(unique(new CachedMethodType(paramNames)(paramInfosExp, resultTypeExp, self))) def checkValid(mt: MethodType)(using Context): mt.type = { @@ -4053,19 +4061,14 @@ object Types { } object MethodType extends MethodTypeCompanion("MethodType") { - def companion(isContextual: Boolean = false, isImplicit: Boolean = false, isErased: Boolean = false): MethodTypeCompanion = - if (isContextual) - if (isErased) ErasedContextualMethodType else ContextualMethodType - else if (isImplicit) - if (isErased) ErasedImplicitMethodType else ImplicitMethodType - else - if (isErased) ErasedMethodType else MethodType + def companion(isContextual: Boolean = false, isImplicit: Boolean = false): MethodTypeCompanion = + if (isContextual) ContextualMethodType + else if (isImplicit) ImplicitMethodType + else MethodType } - object ErasedMethodType extends MethodTypeCompanion("ErasedMethodType") + object ContextualMethodType extends MethodTypeCompanion("ContextualMethodType") - object ErasedContextualMethodType extends MethodTypeCompanion("ErasedContextualMethodType") object ImplicitMethodType extends MethodTypeCompanion("ImplicitMethodType") - object ErasedImplicitMethodType extends MethodTypeCompanion("ErasedImplicitMethodType") /** A ternary extractor for MethodType */ object MethodTpe { @@ -5280,7 +5283,10 @@ object Types { override def stripTypeVar(using Context): Type = derivedAnnotatedType(parent.stripTypeVar, annot) - override def stripAnnots(using Context): Type = parent.stripAnnots + override def stripAnnots(keep: Annotation => (Context) ?=> Boolean)(using Context): Type = + val p = parent.stripAnnots(keep) + if keep(annot) then derivedAnnotatedType(p, annot) + else p override def stripped(using Context): Type = parent.stripped diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala index bef28545592a..8a396921f32b 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala @@ -287,7 +287,6 @@ class TreePickler(pickler: TastyPickler) { var mods = EmptyFlags if tpe.isContextualMethod then mods |= Given else if tpe.isImplicitMethod then mods |= Implicit - if tpe.isErasedMethod then mods |= Erased pickleMethodic(METHODtype, tpe, mods) case tpe: ParamRef => assert(pickleParamRef(tpe), s"orphan parameter reference: $tpe") diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index dfe04dbe6d2b..9078a8959112 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -249,7 +249,6 @@ class TreeUnpickler(reader: TastyReader, while currentAddr != end do // avoid boxing the mods readByte() match case IMPLICIT => mods |= Implicit - case ERASED => mods |= Erased case GIVEN => mods |= Given (names, mods) @@ -406,9 +405,7 @@ class TreeUnpickler(reader: TastyReader, case METHODtype => def methodTypeCompanion(mods: FlagSet): MethodTypeCompanion = if mods.is(Implicit) then ImplicitMethodType - else if mods.isAllOf(Erased | Given) then ErasedContextualMethodType else if mods.is(Given) then ContextualMethodType - else if mods.is(Erased) then ErasedMethodType else MethodType readMethodic(methodTypeCompanion, _.toTermName) case TYPELAMBDAtype => diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 479ae1fa9095..d1b0c6cba097 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -190,6 +190,8 @@ object Parsers { def isPureArrow(name: Name): Boolean = isIdent(name) && Feature.pureFunsEnabled def isPureArrow: Boolean = isPureArrow(nme.PUREARROW) || isPureArrow(nme.PURECTXARROW) def isErased = isIdent(nme.erased) && in.erasedEnabled + // Are we seeing an `erased` soft keyword that will not be an identifier? + def isErasedKw = isErased && in.isSoftModifierInParamModifierPosition def isSimpleLiteral = simpleLiteralTokens.contains(in.token) || isIdent(nme.raw.MINUS) && numericLitTokens.contains(in.lookahead.token) @@ -463,6 +465,15 @@ object Parsers { case _ => fail() + /** Checks that tuples don't contain a parameter. */ + def checkNonParamTuple(t: Tree) = t match + case Tuple(ts) => ts.collectFirst { + case param: ValDef => + syntaxError(em"invalid parameter definition syntax in tuple value", param.span) + } + case _ => + + /** Convert (qual)ident to type identifier */ def convertToTypeId(tree: Tree): Tree = tree match { @@ -1425,13 +1436,30 @@ object Parsers { */ def toplevelTyp(): Tree = rejectWildcardType(typ()) - private def isFunction(tree: Tree): Boolean = tree match { - case Parens(tree1) => isFunction(tree1) - case Block(Nil, tree1) => isFunction(tree1) - case _: Function => true - case _ => false + private def getFunction(tree: Tree): Option[Function] = tree match { + case Parens(tree1) => getFunction(tree1) + case Block(Nil, tree1) => getFunction(tree1) + case t: Function => Some(t) + case _ => None } + private def checkFunctionNotErased(f: Function, context: String) = + def fail(span: Span) = + syntaxError(em"Implementation restriction: erased parameters are not supported in $context", span) + // erased parameter in type + val hasErasedParam = f match + case f: FunctionWithMods => f.hasErasedParams + case _ => false + if hasErasedParam then + fail(f.span) + // erased parameter in term + val hasErasedMods = f.args.collectFirst { + case v: ValDef if v.mods.is(Flags.Erased) => v + } + hasErasedMods match + case Some(param) => fail(param.span) + case _ => + /** CaptureRef ::= ident | `this` */ def captureRef(): Tree = @@ -1464,6 +1492,7 @@ object Parsers { def typ(): Tree = val start = in.offset var imods = Modifiers() + var erasedArgs: ListBuffer[Boolean] = ListBuffer() def functionRest(params: List[Tree]): Tree = val paramSpan = Span(start, in.lastOffset) atSpan(start, in.offset) { @@ -1495,10 +1524,10 @@ object Parsers { if isByNameType(tpt) then syntaxError(em"parameter of type lambda may not be call-by-name", tpt.span) TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], resultType) - else if imods.isOneOf(Given | Erased | Impure) then + else if imods.isOneOf(Given | Impure) || erasedArgs.contains(true) then if imods.is(Given) && params.isEmpty then syntaxError(em"context function types require at least one parameter", paramSpan) - FunctionWithMods(params, resultType, imods) + FunctionWithMods(params, resultType, imods, erasedArgs.toList) else if !ctx.settings.YkindProjector.isDefault then val (newParams :+ newResultType, tparams) = replaceKindProjectorPlaceholders(params :+ resultType): @unchecked lambdaAbstract(tparams, Function(newParams, newResultType)) @@ -1516,17 +1545,30 @@ object Parsers { functionRest(Nil) } else { - if isErased then imods = addModifier(imods) val paramStart = in.offset + def addErased() = + erasedArgs.addOne(isErasedKw) + if isErasedKw then { in.skipToken(); } + addErased() val ts = in.currentRegion.withCommasExpected { funArgType() match case Ident(name) if name != tpnme.WILDCARD && in.isColon => isValParamList = true + def funParam(start: Offset, mods: Modifiers) = { + atSpan(start) { + addErased() + typedFunParam(in.offset, ident(), imods) + } + } commaSeparatedRest( typedFunParam(paramStart, name.toTermName, imods), - () => typedFunParam(in.offset, ident(), imods)) + () => funParam(in.offset, imods)) case t => - commaSeparatedRest(t, funArgType) + def funParam() = { + addErased() + funArgType() + } + commaSeparatedRest(t, funParam) } accept(RPAREN) if isValParamList || in.isArrow || isPureArrow then @@ -1557,11 +1599,13 @@ object Parsers { val arrowOffset = in.skipToken() val body = toplevelTyp() atSpan(start, arrowOffset) { - if (isFunction(body)) - PolyFunction(tparams, body) - else { - syntaxError(em"Implementation restriction: polymorphic function types must have a value parameter", arrowOffset) - Ident(nme.ERROR.toTypeName) + getFunction(body) match { + case Some(f) => + checkFunctionNotErased(f, "poly function") + PolyFunction(tparams, body) + case None => + syntaxError(em"Implementation restriction: polymorphic function types must have a value parameter", arrowOffset) + Ident(nme.ERROR.toTypeName) } } } @@ -1573,14 +1617,17 @@ object Parsers { else infixType() in.token match - case ARROW | CTXARROW => functionRest(t :: Nil) + case ARROW | CTXARROW => + erasedArgs.addOne(false) + functionRest(t :: Nil) case MATCH => matchType(t) case FORSOME => syntaxError(ExistentialTypesNoLongerSupported()); t case _ => if isPureArrow then + erasedArgs.addOne(false) functionRest(t :: Nil) else - if (imods.is(Erased) && !t.isInstanceOf[FunctionWithMods]) + if (erasedArgs.contains(true) && !t.isInstanceOf[FunctionWithMods]) syntaxError(ErasedTypesCanOnlyBeFunctionTypes(), implicitKwPos(start)) t end typ @@ -2078,24 +2125,22 @@ object Parsers { def expr(location: Location): Tree = { val start = in.offset - def isSpecialClosureStart = in.lookahead.isIdent(nme.erased) && in.erasedEnabled in.token match case IMPLICIT => closure(start, location, modifiers(BitSet(IMPLICIT))) - case LPAREN if isSpecialClosureStart => - closure(start, location, Modifiers()) case LBRACKET => val start = in.offset val tparams = typeParamClause(ParamOwner.TypeParam) val arrowOffset = accept(ARROW) val body = expr(location) atSpan(start, arrowOffset) { - if (isFunction(body)) - PolyFunction(tparams, body) - else { - syntaxError(em"Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset) - errorTermTree(arrowOffset) - } + getFunction(body) match + case Some(f) => + checkFunctionNotErased(f, "poly function") + PolyFunction(tparams, f) + case None => + syntaxError(em"Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset) + errorTermTree(arrowOffset) } case _ => val saved = placeholderParams @@ -2113,7 +2158,9 @@ object Parsers { else if isWildcard(t) then placeholderParams = placeholderParams ::: saved t - else wrapPlaceholders(t) + else + checkNonParamTuple(t) + wrapPlaceholders(t) } def expr1(location: Location = Location.ElseWhere): Tree = in.token match @@ -2305,10 +2352,8 @@ object Parsers { if in.token == RPAREN then Nil else - var mods1 = mods - if isErased then mods1 = addModifier(mods1) try - commaSeparated(() => binding(mods1)) + commaSeparated(() => binding(mods)) finally accept(RPAREN) else { @@ -2332,10 +2377,13 @@ object Parsers { (atSpan(start) { makeParameter(name, t, mods) }) :: Nil } - /** Binding ::= (id | `_') [`:' Type] + /** Binding ::= [`erased`] (id | `_') [`:' Type] */ def binding(mods: Modifiers): Tree = - atSpan(in.offset) { makeParameter(bindingName(), typedOpt(), mods) } + atSpan(in.offset) { + val mods1 = if isErasedKw then addModifier(mods) else mods + makeParameter(bindingName(), typedOpt(), mods1) + } def bindingName(): TermName = if (in.token == USCORE) { @@ -2532,6 +2580,7 @@ object Parsers { else in.currentRegion.withCommasExpected { var isFormalParams = false def exprOrBinding() = + if isErasedKw then isFormalParams = true if isFormalParams then binding(Modifiers()) else val t = exprInParens() @@ -3175,7 +3224,7 @@ object Parsers { val paramFlags = if ofClass then LocalParamAccessor else Param tps.map(makeSyntheticParameter(nextIdx, _, paramFlags | Synthetic | impliedMods.flags)) - /** ClsTermParamClause ::= ‘(’ [‘erased’] ClsParams ‘)’ | UsingClsTermParamClause + /** ClsTermParamClause ::= ‘(’ ClsParams ‘)’ | UsingClsTermParamClause * UsingClsTermParamClause::= ‘(’ ‘using’ [‘erased’] (ClsParams | ContextTypes) ‘)’ * ClsParams ::= ClsParam {‘,’ ClsParam} * ClsParam ::= {Annotation} @@ -3184,10 +3233,10 @@ object Parsers { * | UsingParamClause * * DefTermParamClause::= [nl] ‘(’ [DefTermParams] ‘)’ - * UsingParamClause ::= ‘(’ ‘using’ [‘erased’] (DefTermParams | ContextTypes) ‘)’ + * UsingParamClause ::= ‘(’ ‘using’ (DefTermParams | ContextTypes) ‘)’ * DefImplicitClause ::= [nl] ‘(’ ‘implicit’ DefTermParams ‘)’ * DefTermParams ::= DefTermParam {‘,’ DefTermParam} - * DefTermParam ::= {Annotation} [‘inline’] Param + * DefTermParam ::= {Annotation} [‘erased’] [‘inline’] Param * * Param ::= id `:' ParamType [`=' Expr] * @@ -3211,12 +3260,12 @@ object Parsers { else if isIdent(nme.using) then addParamMod(() => Mod.Given()) - if isErased then - addParamMod(() => Mod.Erased()) def param(): ValDef = { val start = in.offset var mods = impliedMods.withAnnotations(annotations()) + if isErasedKw then + mods = addModifier(mods) if (ofClass) { mods = addFlag(modifiers(start = mods), ParamAccessor) mods = @@ -3227,7 +3276,7 @@ object Parsers { val mod = atSpan(in.skipToken()) { Mod.Var() } addMod(mods, mod) else - if (!(mods.flags &~ (ParamAccessor | Inline | impliedMods.flags)).isEmpty) + if (!(mods.flags &~ (ParamAccessor | Inline | Erased | impliedMods.flags)).isEmpty) syntaxError(em"`val` or `var` expected") if (firstClause && ofCaseClass) mods else mods | PrivateLocal @@ -3275,12 +3324,22 @@ object Parsers { paramMods() if givenOnly && !impliedMods.is(Given) then syntaxError(em"`using` expected") - val isParams = - !impliedMods.is(Given) - || startParamTokens.contains(in.token) - || isIdent && (in.name == nme.inline || in.lookahead.isColon) - if isParams then commaSeparated(() => param()) - else contextTypes(ofClass, numLeadParams, impliedMods) + val (firstParamMod, isParams) = + var mods = EmptyModifiers + if in.lookahead.isColon then + (mods, true) + else + if isErased then mods = addModifier(mods) + val isParams = + !impliedMods.is(Given) + || startParamTokens.contains(in.token) + || isIdent && (in.name == nme.inline || in.lookahead.isColon) + (mods, isParams) + (if isParams then commaSeparated(() => param()) + else contextTypes(ofClass, numLeadParams, impliedMods)) match { + case Nil => Nil + case (h :: t) => h.withAddedFlags(firstParamMod.flags) :: t + } checkVarArgsRules(clause) clause } diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 0da1993310c6..ee0062f77dcd 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -120,10 +120,10 @@ class PlainPrinter(_ctx: Context) extends Printer { } (keyword ~ refinementNameString(rt) ~ toTextRHS(rt.refinedInfo)).close - protected def argText(arg: Type): Text = homogenizeArg(arg) match { + protected def argText(arg: Type, isErased: Boolean = false): Text = keywordText("erased ").provided(isErased) ~ (homogenizeArg(arg) match { case arg: TypeBounds => "?" ~ toText(arg) case arg => toText(arg) - } + }) /** Pretty-print comma-separated type arguments for a constructor to be inserted among parentheses or brackets * (hence with `GlobalPrec` precedence). @@ -235,7 +235,6 @@ class PlainPrinter(_ctx: Context) extends Printer { changePrec(GlobalPrec) { "(" ~ keywordText("using ").provided(tp.isContextualMethod) - ~ keywordText("erased ").provided(tp.isErasedMethod) ~ keywordText("implicit ").provided(tp.isImplicitMethod && !tp.isContextualMethod) ~ paramsText(tp) ~ ")" @@ -296,9 +295,10 @@ class PlainPrinter(_ctx: Context) extends Printer { "(" ~ toTextRef(tp) ~ " : " ~ toTextGlobal(tp.underlying) ~ ")" protected def paramsText(lam: LambdaType): Text = { - def paramText(name: Name, tp: Type) = - toText(name) ~ lambdaHash(lam) ~ toTextRHS(tp, isParameter = true) - Text(lam.paramNames.lazyZip(lam.paramInfos).map(paramText), ", ") + val erasedParams = lam.erasedParams + def paramText(name: Name, tp: Type, erased: Boolean) = + keywordText("erased ").provided(erased) ~ toText(name) ~ lambdaHash(lam) ~ toTextRHS(tp, isParameter = true) + Text(lam.paramNames.lazyZip(lam.paramInfos).lazyZip(erasedParams).map(paramText), ", ") } protected def ParamRefNameString(name: Name): String = nameString(name) diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 8ffb99f073fb..014e5ddf0d66 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -148,17 +148,16 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { def toTextTuple(args: List[Type]): Text = "(" ~ argsText(args) ~ ")" - def toTextFunction(args: List[Type], isGiven: Boolean, isErased: Boolean, isPure: Boolean): Text = + def toTextFunction(args: List[Type], isGiven: Boolean, isPure: Boolean): Text = changePrec(GlobalPrec) { val argStr: Text = if args.length == 2 && !defn.isTupleNType(args.head) - && !isGiven && !isErased + && !isGiven then atPrec(InfixPrec) { argText(args.head) } else "(" - ~ keywordText("erased ").provided(isErased) ~ argsText(args.init) ~ ")" argStr ~ " " ~ arrow(isGiven, isPure) ~ " " ~ argText(args.last) @@ -168,7 +167,6 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { case info: MethodType => changePrec(GlobalPrec) { "(" - ~ keywordText("erased ").provided(info.isErasedMethod) ~ paramsText(info) ~ ") " ~ arrow(info.isImplicitMethod, isPure) @@ -226,7 +224,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { if tycon.isRepeatedParam then toTextLocal(args.head) ~ "*" else if tp.isConvertibleParam then "into " ~ toText(args.head) else if defn.isFunctionSymbol(tsym) then - toTextFunction(args, tsym.name.isContextFunction, tsym.name.isErasedFunction, + toTextFunction(args, tsym.name.isContextFunction, isPure = Feature.pureFunsEnabled && !tsym.name.isImpureFunction) else if isInfixType(tp) then val l :: r :: Nil = args: @unchecked @@ -289,7 +287,6 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { case tp @ FunProto(args, resultType) => "[applied to (" ~ keywordText("using ").provided(tp.isContextualMethod) - ~ keywordText("erased ").provided(tp.isErasedMethod) ~ argsTreeText(args) ~ ") returning " ~ toText(resultType) @@ -650,27 +647,29 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { case str: Literal => strText(str) } toText(id) ~ "\"" ~ Text(segments map segmentText, "") ~ "\"" - case Function(args, body) => + case fn @ Function(args, body) => var implicitSeen: Boolean = false var isGiven: Boolean = false - var isErased: Boolean = false - def argToText(arg: Tree) = arg match { + val erasedParams = fn match { + case fn: FunctionWithMods => fn.erasedParams + case _ => fn.args.map(_ => false) + } + def argToText(arg: Tree, isErased: Boolean) = arg match { case arg @ ValDef(name, tpt, _) => val implicitText = if ((arg.mods.is(Given))) { isGiven = true; "" } - else if ((arg.mods.is(Erased))) { isErased = true; "" } else if ((arg.mods.is(Implicit)) && !implicitSeen) { implicitSeen = true; keywordStr("implicit ") } else "" - implicitText ~ toText(name) ~ optAscription(tpt) + val erasedText = if isErased then keywordStr("erased ") else "" + implicitText ~ erasedText ~ toText(name) ~ optAscription(tpt) case _ => toText(arg) } val argsText = args match { - case (arg @ ValDef(_, tpt, _)) :: Nil if tpt.isEmpty => argToText(arg) + case (arg @ ValDef(_, tpt, _)) :: Nil if tpt.isEmpty => argToText(arg, erasedParams(0)) case _ => "(" - ~ keywordText("erased ").provided(isErased) - ~ Text(args.map(argToText), ", ") + ~ Text(args.zip(erasedParams).map(argToText), ", ") ~ ")" } val isPure = @@ -870,7 +869,6 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { "()" case untpd.ValDefs(vparams @ (vparam :: _)) => "(" ~ keywordText("using ").provided(vparam.mods.is(Given)) - ~ keywordText("erased ").provided(vparam.mods.is(Erased)) ~ toText(vparams, ", ") ~ ")" case untpd.TypeDefs(tparams) => "[" ~ toText(tparams, ", ") ~ "]" @@ -1032,7 +1030,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { else PrintableFlags(isType) if (homogenizedView && mods.flags.isTypeFlags) flagMask &~= GivenOrImplicit // drop implicit/given from classes val rawFlags = if (sym.exists) sym.flagsUNSAFE else mods.flags - if (rawFlags.is(Param)) flagMask = flagMask &~ Given &~ Erased + if (rawFlags.is(Param)) flagMask = flagMask &~ Given val flags = rawFlags & flagMask var flagsText = toTextFlags(sym, flags) val annotTexts = diff --git a/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala b/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala index 851e3c422460..aadfedd2417c 100644 --- a/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala +++ b/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala @@ -128,7 +128,7 @@ class Interpreter(pos: SrcPos, classLoader0: ClassLoader)(using Context): view.toList fnType.dealias match - case fnType: MethodType if fnType.isErasedMethod => interpretArgs(argss, fnType.resType) + case fnType: MethodType if fnType.hasErasedParams => interpretArgs(argss, fnType.resType) case fnType: MethodType => val argTypes = fnType.paramInfos assert(argss.head.size == argTypes.size) @@ -342,7 +342,7 @@ object Interpreter: case fn: Ident => Some((tpd.desugarIdent(fn).withSpan(fn.span), Nil)) case fn: Select => Some((fn, Nil)) case Apply(f @ Call0(fn, args1), args2) => - if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1)) + if (f.tpe.widenDealias.hasErasedParams) Some((fn, args1)) else Some((fn, args2 :: args1)) case TypeApply(Call0(fn, args), _) => Some((fn, args)) case _ => None diff --git a/compiler/src/dotty/tools/dotc/transform/Bridges.scala b/compiler/src/dotty/tools/dotc/transform/Bridges.scala index e302170991f9..569b16681cde 100644 --- a/compiler/src/dotty/tools/dotc/transform/Bridges.scala +++ b/compiler/src/dotty/tools/dotc/transform/Bridges.scala @@ -129,9 +129,12 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) { assert(ctx.typer.isInstanceOf[Erasure.Typer]) ctx.typer.typed(untpd.cpy.Apply(ref)(ref, args), member.info.finalResultType) else - val defn.ContextFunctionType(argTypes, resType, isErased) = tp: @unchecked + val defn.ContextFunctionType(argTypes, resType, erasedParams) = tp: @unchecked val anonFun = newAnonFun(ctx.owner, - MethodType(if isErased then Nil else argTypes, resType), + MethodType( + argTypes.zip(erasedParams.padTo(argTypes.length, false)) + .flatMap((t, e) => if e then None else Some(t)), + resType), coord = ctx.owner.coord) anonFun.info = transformInfo(anonFun, anonFun.info) diff --git a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala index 2ab910f6d06e..5863c360e728 100644 --- a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala +++ b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala @@ -20,7 +20,7 @@ object ContextFunctionResults: */ def annotateContextResults(mdef: DefDef)(using Context): Unit = def contextResultCount(rhs: Tree, tp: Type): Int = tp match - case defn.ContextFunctionType(_, resTpe, _) => + case defn.ContextFunctionType(_, resTpe, erasedParams) if !erasedParams.contains(true) /* Only enable for non-erased functions */ => rhs match case closureDef(meth) => 1 + contextResultCount(meth.rhs, resTpe) case _ => 0 @@ -58,7 +58,7 @@ object ContextFunctionResults: */ def contextResultsAreErased(sym: Symbol)(using Context): Boolean = def allErased(tp: Type): Boolean = tp.dealias match - case defn.ContextFunctionType(_, resTpe, isErased) => isErased && allErased(resTpe) + case defn.ContextFunctionType(_, resTpe, erasedParams) => !erasedParams.contains(false) && allErased(resTpe) case _ => true contextResultCount(sym) > 0 && allErased(sym.info.finalResultType) @@ -72,10 +72,8 @@ object ContextFunctionResults: integrateContextResults(rt, crCount) case tp: MethodOrPoly => tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount)) - case defn.ContextFunctionType(argTypes, resType, isErased) => - val methodType: MethodTypeCompanion = - if isErased then ErasedMethodType else MethodType - methodType(argTypes, integrateContextResults(resType, crCount - 1)) + case defn.ContextFunctionType(argTypes, resType, erasedParams) => + MethodType(argTypes, integrateContextResults(resType, crCount - 1)) /** The total number of parameters of method `sym`, not counting * erased parameters, but including context result parameters. @@ -85,14 +83,16 @@ object ContextFunctionResults: def contextParamCount(tp: Type, crCount: Int): Int = if crCount == 0 then 0 else - val defn.ContextFunctionType(params, resTpe, isErased) = tp: @unchecked + val defn.ContextFunctionType(params, resTpe, erasedParams) = tp: @unchecked val rest = contextParamCount(resTpe, crCount - 1) - if isErased then rest else params.length + rest + if erasedParams.contains(true) then erasedParams.count(_ == false) + rest else params.length + rest def normalParamCount(tp: Type): Int = tp.widenExpr.stripPoly match case mt @ MethodType(pnames) => val rest = normalParamCount(mt.resType) - if mt.isErasedMethod then rest else pnames.length + rest + if mt.hasErasedParams then + mt.erasedParams.count(_ == false) + rest + else pnames.length + rest case _ => contextParamCount(tp, contextResultCount(sym)) normalParamCount(sym.info) @@ -133,4 +133,4 @@ object ContextFunctionResults: case _ => false -end ContextFunctionResults \ No newline at end of file +end ContextFunctionResults diff --git a/compiler/src/dotty/tools/dotc/transform/ElimByName.scala b/compiler/src/dotty/tools/dotc/transform/ElimByName.scala index 479a455b4aea..151e841f0e48 100644 --- a/compiler/src/dotty/tools/dotc/transform/ElimByName.scala +++ b/compiler/src/dotty/tools/dotc/transform/ElimByName.scala @@ -15,6 +15,7 @@ import MegaPhase.* import Decorators.* import typer.RefChecks import reporting.trace +import dotty.tools.dotc.core.Names.Name /** This phase implements the following transformations: * @@ -79,11 +80,14 @@ class ElimByName extends MiniPhase, InfoTransformer: case ExprType(rt) if exprBecomesFunction(sym) => defn.ByNameFunction(rt) case tp: MethodType => - def exprToFun(tp: Type) = tp match - case ExprType(rt) => defn.ByNameFunction(rt) + def exprToFun(tp: Type, name: Name) = tp match + case ExprType(rt) => + if rt.hasAnnotation(defn.ErasedParamAnnot) then + report.error(em"By-name parameter cannot be erased: $name", sym.srcPos) + defn.ByNameFunction(rt) case tp => tp tp.derivedLambdaType( - paramInfos = tp.paramInfos.mapConserve(exprToFun), + paramInfos = tp.paramInfos.zipWithConserve(tp.paramNames)(exprToFun), resType = transformInfo(tp.resType, sym)) case tp: PolyType => tp.derivedLambdaType(resType = transformInfo(tp.resType, sym)) diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala index 60affae38ad5..981dd5f60aea 100644 --- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala +++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala @@ -500,7 +500,7 @@ object Erasure { if isFunction && !ctx.settings.scalajs.value then val arity = implParamTypes.length val specializedFunctionalInterface = - if defn.isSpecializableFunctionSAM(implParamTypes, implResultType) then + if !implType.hasErasedParams && defn.isSpecializableFunctionSAM(implParamTypes, implResultType) then // Using these subclasses is critical to avoid boxing since their // SAM is a specialized method `apply$mc*$sp` whose default // implementation in FunctionN boxes. @@ -679,6 +679,8 @@ object Erasure { val qualTp = tree.qualifier.typeOpt.widen if qualTp.derivesFrom(defn.PolyFunctionClass) then erasePolyFunctionApply(qualTp.select(nme.apply).widen).classSymbol + else if defn.isErasedFunctionType(qualTp) then + eraseErasedFunctionApply(qualTp.select(nme.apply).widen.asInstanceOf[MethodType]).classSymbol else NoSymbol } @@ -827,7 +829,10 @@ object Erasure { val Apply(fun, args) = tree val origFun = fun.asInstanceOf[tpd.Tree] val origFunType = origFun.tpe.widen(using preErasureCtx) - val ownArgs = if origFunType.isErasedMethod then Nil else args + val ownArgs = origFunType match + case mt: MethodType if mt.hasErasedParams => + args.zip(mt.erasedParams).collect { case (arg, false) => arg } + case _ => args val fun1 = typedExpr(fun, AnyFunctionProto) fun1.tpe.widen match case mt: MethodType => diff --git a/compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala b/compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala index 050abf7f3cb7..a1baeac272b9 100644 --- a/compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala +++ b/compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala @@ -311,7 +311,9 @@ object GenericSignatures { case mtpe: MethodType => // erased method parameters do not make it to the bytecode. def effectiveParamInfoss(t: Type)(using Context): List[List[Type]] = t match { - case t: MethodType if t.isErasedMethod => effectiveParamInfoss(t.resType) + case t: MethodType if t.hasErasedParams => + t.paramInfos.zip(t.erasedParams).collect{ case (i, false) => i } + :: effectiveParamInfoss(t.resType) case t: MethodType => t.paramInfos :: effectiveParamInfoss(t.resType) case _ => Nil } diff --git a/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala b/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala index c56bac4d66af..87c6e294c104 100644 --- a/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala +++ b/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala @@ -315,7 +315,7 @@ object PickleQuotes { defn.QuotedExprClass.typeRef.appliedTo(defn.AnyType)), args => val cases = termSplices.map { case (splice, idx) => - val defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _, _), _, _) = splice.tpe: @unchecked + val defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _), _) = splice.tpe: @unchecked val rhs = { val spliceArgs = argTypes.zipWithIndex.map { (argType, i) => args(1).select(nme.apply).appliedTo(Literal(Constant(i))).asInstance(argType) diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 2039a8f19558..574db18c9c7f 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -296,19 +296,21 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase case tree: Apply => val methType = tree.fun.tpe.widen.asInstanceOf[MethodType] val app = - if (methType.isErasedMethod) + if (methType.hasErasedParams) tpd.cpy.Apply(tree)( tree.fun, - tree.args.mapConserve(arg => - if methType.isResultDependent then - Checking.checkRealizable(arg.tpe, arg.srcPos, "erased argument") - if (methType.isImplicitMethod && arg.span.isSynthetic) - arg match - case _: RefTree | _: Apply | _: TypeApply if arg.symbol.is(Erased) => - dropInlines.transform(arg) - case _ => - PruneErasedDefs.trivialErasedTree(arg) - else dropInlines.transform(arg))) + tree.args.zip(methType.erasedParams).map((arg, isErased) => + if !isErased then arg + else + if methType.isResultDependent then + Checking.checkRealizable(arg.tpe, arg.srcPos, "erased argument") + if (methType.isImplicitMethod && arg.span.isSynthetic) + arg match + case _: RefTree | _: Apply | _: TypeApply if arg.symbol.is(Erased) => + dropInlines.transform(arg) + case _ => + PruneErasedDefs.trivialErasedTree(arg) + else dropInlines.transform(arg))) else tree def app1 = diff --git a/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala b/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala index 568512207fde..17f2d11ccfec 100644 --- a/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala +++ b/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala @@ -13,6 +13,7 @@ import ast.tpd import SymUtils._ import config.Feature import Decorators.* +import dotty.tools.dotc.core.Types.MethodType /** This phase makes all erased term members of classes private so that they cannot * conflict with non-erased members. This is needed so that subsequent phases like @@ -38,8 +39,11 @@ class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform => else sym.copySymDenotation(initFlags = sym.flags | Private) override def transformApply(tree: Apply)(using Context): Tree = - if !tree.fun.tpe.widen.isErasedMethod then tree - else cpy.Apply(tree)(tree.fun, tree.args.map(trivialErasedTree)) + tree.fun.tpe.widen match + case mt: MethodType if mt.hasErasedParams => + cpy.Apply(tree)(tree.fun, tree.args.zip(mt.erasedParams).map((a, e) => if e then trivialErasedTree(a) else a)) + case _ => + tree override def transformValDef(tree: ValDef)(using Context): Tree = checkErasedInExperimental(tree.symbol) diff --git a/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala b/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala index c1f891d6293a..2248fbc8d570 100644 --- a/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala +++ b/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala @@ -70,7 +70,7 @@ class SpecializeFunctions extends MiniPhase { /** Dispatch to specialized `apply`s in user code when available */ override def transformApply(tree: Apply)(using Context) = tree match { - case Apply(fun: NameTree, args) if fun.name == nme.apply && args.size <= 3 && fun.symbol.owner.isType => + case Apply(fun: NameTree, args) if fun.name == nme.apply && args.size <= 3 && fun.symbol.maybeOwner.isType => val argTypes = fun.tpe.widen.firstParamTypes.map(_.widenSingleton.dealias) val retType = tree.tpe.widenSingleton.dealias val isSpecializable = diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index 133651fe08b2..52942ea719f9 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -445,10 +445,12 @@ object TreeChecker { // Polymorphic apply methods stay structural until Erasure val isPolyFunctionApply = (tree.name eq nme.apply) && tree.qualifier.typeOpt.derivesFrom(defn.PolyFunctionClass) + // Erased functions stay structural until Erasure + val isErasedFunctionApply = (tree.name eq nme.apply) && tree.qualifier.typeOpt.derivesFrom(defn.ErasedFunctionClass) // Outer selects are pickled specially so don't require a symbol val isOuterSelect = tree.name.is(OuterSelectName) val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name) - if !(tree.isType || isPolyFunctionApply || isOuterSelect || isPrimitiveArrayOp) then + if !(tree.isType || isPolyFunctionApply || isErasedFunctionApply || isOuterSelect || isPrimitiveArrayOp) then val denot = tree.denot assert(denot.exists, i"Selection $tree with type $tpe does not have a denotation") assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol, qualifier type = ${tree.qualifier.typeOpt}") diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index a8abd868fdef..345a8693063b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1978,7 +1978,7 @@ trait Applications extends Compatibility { val formals = ref.widen.firstParamTypes if formals.length > idx then formals(idx) match - case defn.FunctionOf(args, _, _, _) => args.length + case defn.FunctionOf(args, _, _) => args.length case _ => -1 else -1 @@ -2062,7 +2062,7 @@ trait Applications extends Compatibility { if isDetermined(alts2) then alts2 else resolveMapped(alts1, _.widen.appliedTo(targs1.tpes), pt1) - case defn.FunctionOf(args, resultType, _, _) => + case defn.FunctionOf(args, resultType, _) => narrowByTypes(alts, args, resultType) case pt => @@ -2225,7 +2225,7 @@ trait Applications extends Compatibility { val formalsForArg: List[Type] = altFormals.map(_.head) def argTypesOfFormal(formal: Type): List[Type] = formal.dealias match { - case defn.FunctionOf(args, result, isImplicit, isErased) => args + case defn.FunctionOf(args, result, isImplicit) => args case defn.PartialFunctionOf(arg, result) => arg :: Nil case _ => Nil } diff --git a/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala b/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala index 46725f0fa6b2..b1513df777ec 100644 --- a/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala +++ b/compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala @@ -285,8 +285,9 @@ object EtaExpansion extends LiftImpure { val body = Apply(lifted, ids) if (mt.isContextualMethod) body.setApplyKind(ApplyKind.Using) val fn = - if (mt.isContextualMethod) new untpd.FunctionWithMods(params, body, Modifiers(Given)) - else if (mt.isImplicitMethod) new untpd.FunctionWithMods(params, body, Modifiers(Implicit)) + if (mt.isContextualMethod) new untpd.FunctionWithMods(params, body, Modifiers(Given), mt.erasedParams) + else if (mt.isImplicitMethod) new untpd.FunctionWithMods(params, body, Modifiers(Implicit), mt.erasedParams) + else if (mt.hasErasedParams) new untpd.FunctionWithMods(params, body, Modifiers(), mt.erasedParams) else untpd.Function(params, body) if (defs.nonEmpty) untpd.Block(defs.toList map (untpd.TypedSplice(_)), fn) else fn } diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index da785e32865a..bde279c582e6 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -372,7 +372,7 @@ object ProtoTypes { private def isUndefined(tp: Type): Boolean = tp match { case _: WildcardType => true - case defn.FunctionOf(args, result, _, _) => args.exists(isUndefined) || isUndefined(result) + case defn.FunctionOf(args, result, _) => args.exists(isUndefined) || isUndefined(result) case _ => false } @@ -842,7 +842,7 @@ object ProtoTypes { normalize(et.resultType, pt) case wtp => val iftp = defn.asContextFunctionType(wtp) - if iftp.exists && followIFT then normalize(iftp.dropDependentRefinement.argInfos.last, pt) + if iftp.exists && followIFT then normalize(iftp.functionArgInfos.last, pt) else tp } } diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 1a7a4b97855b..103961b68c29 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -103,12 +103,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): def functionTypeEqual(baseFun: Type, actualArgs: List[Type], actualRet: Type, expected: Type) = expected =:= defn.FunctionOf(actualArgs, actualRet, - defn.isContextFunctionType(baseFun), defn.isErasedFunctionType(baseFun)) + defn.isContextFunctionType(baseFun)) val arity: Int = - if defn.isErasedFunctionType(fun) || defn.isErasedFunctionType(fun) then -1 // TODO support? + if defn.isErasedFunctionType(fun) then -1 // TODO support? else if defn.isFunctionType(fun) then // TupledFunction[(...) => R, ?] - fun.dropDependentRefinement.dealias.argInfos match + fun.functionArgInfos match case funArgs :+ funRet if functionTypeEqual(fun, defn.tupleType(funArgs) :: Nil, funRet, tupled) => // TupledFunction[(...funArgs...) => funRet, ?] @@ -116,7 +116,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): case _ => -1 else if defn.isFunctionType(tupled) then // TupledFunction[?, (...) => R] - tupled.dropDependentRefinement.dealias.argInfos match + tupled.functionArgInfos match case tupledArgs :: funRet :: Nil => defn.tupleTypes(tupledArgs.dealias) match case Some(funArgs) if functionTypeEqual(tupled, funArgs, funRet, fun) => diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 9d8fdcc006c9..4c6ec98ba9ba 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1262,7 +1262,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer (pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound))) case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe)) - if defn.isNonRefinedFunction(parent) && formals.length == defaultArity => + if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity => (formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))) case SAMType(mt @ MethodTpe(_, formals, restpe)) => (formals, @@ -1293,16 +1293,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer * If both attempts fail, return `NoType`. */ def inferredFromTarget( - param: untpd.ValDef, formal: Type, calleeType: Type, paramIndex: Name => Int)(using Context): Type = + param: untpd.ValDef, formal: Type, calleeType: Type, isErased: Boolean, paramIndex: Name => Int)(using Context): Type = val target = calleeType.widen match case mtpe: MethodType => val pos = paramIndex(param.name) if pos < mtpe.paramInfos.length then - mtpe.paramInfos(pos) + val tp = mtpe.paramInfos(pos) // This works only if vararg annotations match up. // See neg/i14367.scala for an example where the inferred type is mispredicted. // Nevertheless, the alternative would be to give up completely, so this is // defensible. + // Strip inferred erased annotation, to avoid accidentally inferring erasedness + if !isErased then tp.stripAnnots(_.symbol != defn.ErasedParamAnnot) else tp else NoType case _ => NoType if target.exists then formal <:< target @@ -1316,32 +1318,14 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedFunctionType(tree: untpd.Function, pt: Type)(using Context): Tree = { val untpd.Function(args, body) = tree - var funFlags = tree match { - case tree: untpd.FunctionWithMods => tree.mods.flags - case _ => EmptyFlags + var (funFlags, erasedParams) = tree match { + case tree: untpd.FunctionWithMods => (tree.mods.flags, tree.erasedParams) + case _ => (EmptyFlags, args.map(_ => false)) } - assert(!funFlags.is(Erased) || !args.isEmpty, "An empty function cannot not be erased") - val numArgs = args.length val isContextual = funFlags.is(Given) - val isErased = funFlags.is(Erased) val isImpure = funFlags.is(Impure) - val funSym = defn.FunctionSymbol(numArgs, isContextual, isErased, isImpure) - - /** If `app` is a function type with arguments that are all erased classes, - * turn it into an erased function type. - */ - def propagateErased(app: Tree): Tree = app match - case AppliedTypeTree(tycon: TypeTree, args) - if !isErased - && numArgs > 0 - && args.indexWhere(!_.tpe.isErasedClass) == numArgs => - val tycon1 = TypeTree(defn.FunctionSymbol(numArgs, isContextual, true, isImpure).typeRef) - .withSpan(tycon.span) - assignType(cpy.AppliedTypeTree(app)(tycon1, args), tycon1, args) - case _ => - app /** Typechecks dependent function type with given parameters `params` */ def typedDependent(params: List[untpd.ValDef])(using Context): Tree = @@ -1356,16 +1340,29 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if funFlags.is(Given) then params.map(_.withAddedFlags(Given)) else params val params2 = params1.map(fixThis.transformSub) - val appDef0 = untpd.DefDef(nme.apply, List(params2), body, EmptyTree).withSpan(tree.span) + val params3 = params2.zipWithConserve(erasedParams) { (arg, isErased) => + if isErased then arg.withAddedFlags(Erased) else arg + } + val appDef0 = untpd.DefDef(nme.apply, List(params3), body, EmptyTree).withSpan(tree.span) index(appDef0 :: Nil) val appDef = typed(appDef0).asInstanceOf[DefDef] val mt = appDef.symbol.info.asInstanceOf[MethodType] if (mt.isParamDependent) report.error(em"$mt is an illegal function type because it has inter-parameter dependencies", tree.srcPos) + // Restart typechecking if there are erased classes that we want to mark erased + if mt.erasedParams.zip(mt.paramInfos.map(_.isErasedClass)).exists((paramErased, classErased) => classErased && !paramErased) then + val newParams = params3.zipWithConserve(mt.paramInfos.map(_.isErasedClass)) { (arg, isErasedClass) => + if isErasedClass then arg.withAddedFlags(Erased) else arg + } + return typedDependent(newParams) val resTpt = TypeTree(mt.nonDependentResultApprox).withSpan(body.span) val typeArgs = appDef.termParamss.head.map(_.tpt) :+ resTpt - val tycon = TypeTree(funSym.typeRef) - val core = propagateErased(AppliedTypeTree(tycon, typeArgs)) + val core = + if mt.hasErasedParams then TypeTree(defn.ErasedFunctionClass.typeRef) + else + val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure) + val tycon = TypeTree(funSym.typeRef) + AppliedTypeTree(tycon, typeArgs) RefinedTypeTree(core, List(appDef), ctx.owner.asClass) end typedDependent @@ -1374,17 +1371,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typedDependent(args.asInstanceOf[List[untpd.ValDef]])( using ctx.fresh.setOwner(newRefinedClassSymbol(tree.span)).setNewScope) case _ => - propagateErased( - typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt)) + if erasedParams.contains(true) then + typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt) + else + val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure) + val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt) + // if there are any erased classes, we need to re-do the typecheck. + result match + case r: AppliedTypeTree if r.args.exists(_.tpe.isErasedClass) => + typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt) + case _ => result } } def typedFunctionValue(tree: untpd.Function, pt: Type)(using Context): Tree = { val untpd.Function(params: List[untpd.ValDef] @unchecked, _) = tree: @unchecked - val isContextual = tree match { - case tree: untpd.FunctionWithMods => tree.mods.is(Given) - case _ => false + val (isContextual, isDefinedErased) = tree match { + case tree: untpd.FunctionWithMods => (tree.mods.is(Given), tree.erasedParams) + case _ => (false, tree.args.map(_ => false)) } /** The function body to be returned in the closure. Can become a TypedSplice @@ -1485,9 +1490,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos) - def protoFormal(i: Int): Type = - if (protoFormals.length == params.length) protoFormals(i) - else errorType(WrongNumberOfParameters(protoFormals.length), tree.srcPos) + /** Returns the type and whether the parameter is erased */ + def protoFormal(i: Int): (Type, Boolean) = + if (protoFormals.length == params.length) (protoFormals(i), isDefinedErased(i)) + else (errorType(WrongNumberOfParameters(protoFormals.length), tree.srcPos), false) /** Is `formal` a product type which is elementwise compatible with `params`? */ def ptIsCorrectProduct(formal: Type) = @@ -1525,28 +1531,32 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if desugared.isEmpty then val inferredParams: List[untpd.ValDef] = for ((param, i) <- params.zipWithIndex) yield - if (!param.tpt.isEmpty) param - else - val formalBounds = protoFormal(i) - val formal = formalBounds.loBound - val isBottomFromWildcard = (formalBounds ne formal) && formal.isExactlyNothing - val knownFormal = isFullyDefined(formal, ForceDegree.failBottom) - // If the expected formal is a TypeBounds wildcard argument with Nothing as lower bound, - // try to prioritize inferring from target. See issue 16405 (tests/run/16405.scala) - val paramType = - if knownFormal && !isBottomFromWildcard then - formal - else - inferredFromTarget(param, formal, calleeType, paramIndex).orElse( - if knownFormal then formal - else errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos) + val (formalBounds, isErased) = protoFormal(i) + val param0 = + if (!param.tpt.isEmpty) param + else + val formal = formalBounds.loBound + val isBottomFromWildcard = (formalBounds ne formal) && formal.isExactlyNothing + val knownFormal = isFullyDefined(formal, ForceDegree.failBottom) + // If the expected formal is a TypeBounds wildcard argument with Nothing as lower bound, + // try to prioritize inferring from target. See issue 16405 (tests/run/16405.scala) + val paramType = + // Strip inferred erased annotation, to avoid accidentally inferring erasedness + val formal0 = if !isErased then formal.stripAnnots(_.symbol != defn.ErasedParamAnnot) else formal + if knownFormal && !isBottomFromWildcard then + formal0 + else + inferredFromTarget(param, formal, calleeType, isErased, paramIndex).orElse( + if knownFormal then formal0 + else errorType(AnonymousFunctionMissingParamType(param, tree, formal), param.srcPos) + ) + val paramTpt = untpd.TypedSplice( + (if knownFormal then InferredTypeTree() else untpd.TypeTree()) + .withType(paramType.translateFromRepeated(toArray = false)) + .withSpan(param.span.endPos) ) - val paramTpt = untpd.TypedSplice( - (if knownFormal then InferredTypeTree() else untpd.TypeTree()) - .withType(paramType.translateFromRepeated(toArray = false)) - .withSpan(param.span.endPos) - ) - cpy.ValDef(param)(tpt = paramTpt) + cpy.ValDef(param)(tpt = paramTpt) + if isErased then param0.withAddedFlags(Flags.Erased) else param0 desugared = desugar.makeClosure(inferredParams, fnBody, resultTpt, isContextual, tree.span) typed(desugared, pt) @@ -1585,7 +1595,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer |because it has internal parameter dependencies""") else if ((tree.tpt `eq` untpd.ContextualEmptyTree) && mt.paramNames.isEmpty) // Note implicitness of function in target type since there are no method parameters that indicate it. - TypeTree(defn.FunctionOf(Nil, mt.resType, isContextual = true, isErased = false)) + TypeTree(defn.FunctionOf(Nil, mt.resType, isContextual = true)) else if hasCaptureConversionArg(mt.resType) then errorTree(tree, em"""cannot turn method type $mt into closure @@ -2605,6 +2615,19 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer // check value class constraints checkDerivedValueClass(cls, body1) + // check PolyFunction constraints (no erased functions!) + if parents1.exists(_.tpe.classSymbol eq defn.PolyFunctionClass) then + body1.foreach { + case ddef: DefDef => + ddef.paramss.foreach { params => + val erasedParam = params.collectFirst { case vdef: ValDef if vdef.symbol.is(Erased) => vdef } + erasedParam.foreach { p => + report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", p.srcPos) + } + } + case _ => + } + val effectiveOwner = cls.owner.skipWeakOwner if !cls.isRefinementClass && !cls.isAllOf(PrivateLocal) @@ -3031,7 +3054,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer tree protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = { - val defn.FunctionOf(formals, _, true, _) = pt.dropDependentRefinement: @unchecked + val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked // The getter of default parameters may reach here. // Given the code below @@ -3059,7 +3082,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer else formals.map(untpd.TypeTree) } - val ifun = desugar.makeContextualFunction(paramTypes, tree, defn.isErasedFunctionType(pt)) + val erasedParams = pt.dealias match { + case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams + case _ => paramTypes.map(_ => false) + } + + val ifun = desugar.makeContextualFunction(paramTypes, tree, erasedParams) typr.println(i"make contextual function $tree / $pt ---> $ifun") typedFunctionValue(ifun, pt) } diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index 6688d6e81a89..f43c051ebf60 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -1581,8 +1581,12 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler self.nonEmpty && self.head.symbol.is(dotc.core.Flags.Implicit) def isGiven: Boolean = self.nonEmpty && self.head.symbol.is(dotc.core.Flags.Given) - def isErased: Boolean = - self.nonEmpty && self.head.symbol.is(dotc.core.Flags.Erased) + def isErased: Boolean = false + + def erasedArgs: List[Boolean] = + self.map(_.symbol.is(dotc.core.Flags.Erased)) + def hasErasedArgs: Boolean = + self.exists(_.symbol.is(dotc.core.Flags.Erased)) end TermParamClauseMethods type TypeParamClause = List[tpd.TypeDef] @@ -2139,9 +2143,12 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler given MethodTypeMethods: MethodTypeMethods with extension (self: MethodType) - def isErased: Boolean = self.isErasedMethod + def isErased: Boolean = false def isImplicit: Boolean = self.isImplicitMethod def param(idx: Int): TypeRepr = self.newParamRef(idx) + + def erasedParams: List[Boolean] = self.erasedParams + def hasErasedParams: Boolean = self.hasErasedParams end extension end MethodTypeMethods @@ -2768,11 +2775,14 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler def ProductClass: Symbol = dotc.core.Symbols.defn.ProductClass def FunctionClass(arity: Int, isImplicit: Boolean = false, isErased: Boolean = false): Symbol = if arity < 0 then throw IllegalArgumentException(s"arity: $arity") - dotc.core.Symbols.defn.FunctionSymbol(arity, isImplicit, isErased) + if isErased then + throw new Exception("Erased function classes are not supported. Use a refined `scala.runtime.ErasedFunction`") + else dotc.core.Symbols.defn.FunctionSymbol(arity, isImplicit) def FunctionClass(arity: Int): Symbol = FunctionClass(arity, false, false) def FunctionClass(arity: Int, isContextual: Boolean): Symbol = FunctionClass(arity, isContextual, false) + def ErasedFunctionClass = dotc.core.Symbols.defn.ErasedFunctionClass def TupleClass(arity: Int): Symbol = dotc.core.Symbols.defn.TupleType(arity).nn.classSymbol.asClass def isTupleClass(sym: Symbol): Boolean = diff --git a/docs/_docs/internals/syntax.md b/docs/_docs/internals/syntax.md index 1fbb7a34b078..445e86ee2408 100644 --- a/docs/_docs/internals/syntax.md +++ b/docs/_docs/internals/syntax.md @@ -140,7 +140,7 @@ type val var while with yield ### Soft keywords ``` -as derives end extension infix inline opaque open throws transparent using | * + - +as derives end erased extension infix inline opaque open throws transparent using | * + - ``` See the [separate section on soft keywords](../reference/soft-modifier.md) for additional @@ -180,13 +180,13 @@ Type ::= FunType | FunParamClause ‘=>>’ Type TermLambdaTypeTree(ps, t) | MatchType | InfixType -FunType ::= FunTypeArgs (‘=>’ | ‘?=>’) Type Function(ts, t) +FunType ::= FunTypeArgs (‘=>’ | ‘?=>’) Type Function(ts, t) | FunctionWithMods(ts, t, mods, erasedParams) | HKTypeParamClause '=>' Type PolyFunction(ps, t) FunTypeArgs ::= InfixType | ‘(’ [ FunArgTypes ] ‘)’ | FunParamClause FunParamClause ::= ‘(’ TypedFunParam {‘,’ TypedFunParam } ‘)’ -TypedFunParam ::= id ‘:’ Type +TypedFunParam ::= [`erased`] id ‘:’ Type MatchType ::= InfixType `match` <<< TypeCaseClauses >>> InfixType ::= RefinedType {id [nl] RefinedType} InfixOp(t1, op, t2) RefinedType ::= AnnotType {[nl] Refinement} RefinedTypeTree(t, ds) @@ -207,8 +207,8 @@ Singleton ::= SimpleRef | SimpleLiteral | Singleton ‘.’ id Singletons ::= Singleton { ‘,’ Singleton } -FunArgType ::= Type - | ‘=>’ Type PrefixOp(=>, t) +FunArgType ::= [`erased`] Type + | [`erased`] ‘=>’ Type PrefixOp(=>, t) FunArgTypes ::= FunArgType { ‘,’ FunArgType } ParamType ::= [‘=>’] ParamValueType ParamValueType ::= [‘into’] ExactParamType Into(t) @@ -229,7 +229,7 @@ BlockResult ::= FunParams (‘=>’ | ‘?=>’) Block | HkTypeParamClause ‘=>’ Block | Expr1 FunParams ::= Bindings - | id + | [`erased`] id | ‘_’ Expr1 ::= [‘inline’] ‘if’ ‘(’ Expr ‘)’ {nl} Expr [[semi] ‘else’ Expr] If(Parens(cond), thenp, elsep?) | [‘inline’] ‘if’ Expr ‘then’ Expr [[semi] ‘else’ Expr] If(cond, thenp, elsep?) @@ -376,13 +376,13 @@ UsingParamClause ::= [nl] ‘(’ ‘using’ (DefTermParams | FunArgTypes) DefImplicitClause ::= [nl] ‘(’ ‘implicit’ DefTermParams ‘)’ DefTermParams ::= DefTermParam {‘,’ DefTermParam} -DefTermParam ::= {Annotation} [‘inline’] Param ValDef(mods, id, tpe, expr) -- point of mods at id. +DefTermParam ::= {Annotation} [`erased`] [‘inline’] Param ValDef(mods, id, tpe, expr) -- point of mods at id. Param ::= id ‘:’ ParamType [‘=’ Expr] ``` ### Bindings and Imports ```ebnf -Bindings ::= ‘(’ [Binding {‘,’ Binding}] ‘)’ +Bindings ::= ‘(’[`erased`] [Binding {‘,’ [`erased`] Binding}] ‘)’ Binding ::= (id | ‘_’) [‘:’ Type] ValDef(_, id, tpe, EmptyTree) Modifier ::= LocalModifier diff --git a/docs/_docs/reference/experimental/erased-defs-spec.md b/docs/_docs/reference/experimental/erased-defs-spec.md index 24ae89c7e28b..59dfed92da2a 100644 --- a/docs/_docs/reference/experimental/erased-defs-spec.md +++ b/docs/_docs/reference/experimental/erased-defs-spec.md @@ -19,8 +19,8 @@ TODO: complete def g(erased x: Int) = ... - (erased x: Int) => ... - def h(x: (erased Int) => Int) = ... + (erased x: Int, y: Int) => ... + def h(x: (Int, erased Int) => Int) = ... class K(erased x: Int) { ... } erased class E {} @@ -34,12 +34,12 @@ TODO: complete 3. Functions * `(erased x1: T1, x2: T2, ..., xN: TN) => y : (erased T1, T2, ..., TN) => R` - * `(given erased x1: T1, x2: T2, ..., xN: TN) => y: (given erased T1, T2, ..., TN) => R` + * `(given x1: T1, erased x2: T2, ..., xN: TN) => y: (given T1, erased T2, ..., TN) => R` * `(given erased T1) => R <:< erased T1 => R` - * `(given erased T1, T2) => R <:< (erased T1, T2) => R` + * `(given T1, erased T2) => R <:< (T1, erased T2) => R` * ... - Note that there is no subtype relation between `(erased T) => R` and `T => R` (or `(given erased T) => R` and `(given T) => R`) + Note that there is no subtype relation between `(erased T) => R` and `T => R` (or `(given erased T) => R` and `(given T) => R`). The `erased` parameters must match exactly in their respective positions. 4. Eta expansion @@ -51,7 +51,8 @@ TODO: complete * All `erased` parameters are removed from the function * All argument to `erased` parameters are not passed to the function * All `erased` definitions are removed - * All `(erased T1, T2, ..., TN) => R` and `(given erased T1, T2, ..., TN) => R` become `() => R` + * `(erased ET1, erased ET2, T1, ..., erased ETN, TM) => R` are erased to `(T1, ..., TM) => R`. + * `(given erased ET1, erased ET2, T1, ..., erased ETN, TM) => R` are erased to `(given T1, ..., TM) => R`. 6. Overloading @@ -60,11 +61,10 @@ TODO: complete 7. Overriding - * Member definitions overriding each other must both be `erased` or not be `erased` - * `def foo(x: T): U` cannot be overridden by `def foo(erased x: T): U` and vice-versa - * - + * Member definitions overriding each other must both be `erased` or not be `erased`. + * `def foo(x: T): U` cannot be overridden by `def foo(erased x: T): U` and vice-versa. 8. Type Restrictions * For dependent functions, `erased` parameters are limited to realizable types, that is, types that are inhabited by non-null values. This restriction stops us from using a bad bound introduced by an erased value, which leads to unsoundness (see #4060). + * Polymorphic functions with erased parameters are currently not supported, and will be rejected by the compiler. This is purely an implementation restriction, and might be lifted in the future. diff --git a/docs/_docs/reference/experimental/erased-defs.md b/docs/_docs/reference/experimental/erased-defs.md index 28455f26cdc0..ef4f02e33dd4 100644 --- a/docs/_docs/reference/experimental/erased-defs.md +++ b/docs/_docs/reference/experimental/erased-defs.md @@ -54,13 +54,13 @@ semantics and they are completely erased. ## How to define erased terms? Parameters of methods and functions can be declared as erased, placing `erased` -in front of a parameter list (like `given`). +in front of each erased parameter (like `inline`). ```scala -def methodWithErasedEv(erased ev: Ev): Int = 42 +def methodWithErasedEv(erased ev: Ev, x: Int): Int = x + 2 -val lambdaWithErasedEv: erased Ev => Int = - (erased ev: Ev) => 42 +val lambdaWithErasedEv: (erased Ev, Int) => Int = + (erased ev, x) => x + 2 ``` `erased` parameters will not be usable for computations, though they can be used @@ -80,7 +80,7 @@ parameters. ```scala erased val erasedEvidence: Ev = ... -methodWithErasedEv(erasedEvidence) +methodWithErasedEv(erasedEvidence, 40) // 42 ``` ## What happens with erased values at runtime? @@ -89,15 +89,15 @@ As `erased` are guaranteed not to be used in computations, they can and will be erased. ```scala -// becomes def methodWithErasedEv(): Int at runtime -def methodWithErasedEv(erased ev: Ev): Int = ... +// becomes def methodWithErasedEv(x: Int): Int at runtime +def methodWithErasedEv(x: Int, erased ev: Ev): Int = ... def evidence1: Ev = ... erased def erasedEvidence2: Ev = ... // does not exist at runtime erased val erasedEvidence3: Ev = ... // does not exist at runtime -// evidence1 is not evaluated and no value is passed to methodWithErasedEv -methodWithErasedEv(evidence1) +// evidence1 is not evaluated and only `x` is passed to methodWithErasedEv +methodWithErasedEv(x, evidence1) ``` ## State machine with erased evidence example diff --git a/library/src/scala/quoted/Quotes.scala b/library/src/scala/quoted/Quotes.scala index 321725d4e84c..79a4ea039ab5 100644 --- a/library/src/scala/quoted/Quotes.scala +++ b/library/src/scala/quoted/Quotes.scala @@ -2374,7 +2374,16 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => /** Is this a given parameter clause `(using X1, ..., Xn)` or `(using x1: X1, ..., xn: Xn)` */ def isGiven: Boolean /** Is this a erased parameter clause `(erased x1: X1, ..., xn: Xn)` */ + // TODO:deprecate in 3.4 and stabilize `erasedArgs` and `hasErasedArgs`. + // @deprecated("Use `hasErasedArgs`","3.4") def isErased: Boolean + + /** List of `erased` flags for each parameter of the clause */ + @experimental + def erasedArgs: List[Boolean] + /** Whether the clause has any erased parameters */ + @experimental + def hasErasedArgs: Boolean end TermParamClauseMethods /** A type parameter clause `[X1, ..., Xn]` */ @@ -2650,7 +2659,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => */ def isContextFunctionType: Boolean - /** Is this type an erased function type? + /** Is this type a function type with erased parameters? * * @see `isFunctionType` */ @@ -3145,7 +3154,17 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => extension (self: MethodType) /** Is this the type of using parameter clause `(implicit X1, ..., Xn)`, `(using X1, ..., Xn)` or `(using x1: X1, ..., xn: Xn)` */ def isImplicit: Boolean + /** Is this the type of erased parameter clause `(erased x1: X1, ..., xn: Xn)` */ + // TODO:deprecate in 3.4 and stabilize `erasedParams` and `hasErasedParams`. + // @deprecated("Use `hasErasedParams`","3.4") def isErased: Boolean + + /** List of `erased` flags for each parameters of the clause */ + @experimental + def erasedParams: List[Boolean] + /** Whether the clause has any erased parameters */ + @experimental + def hasErasedParams: Boolean def param(idx: Int): TypeRepr end extension end MethodTypeMethods @@ -4275,6 +4294,10 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => @experimental def FunctionClass(arity: Int, isContextual: Boolean): Symbol + /** The `scala.runtime.ErasedFunction` built-in trait. */ + @experimental + def ErasedFunctionClass: Symbol + /** Function-like object that maps arity to symbols for classes `scala.TupleX`. * - 0th element is `NoSymbol` * - 1st element is `NoSymbol` diff --git a/library/src/scala/runtime/ErasedFunction.scala b/library/src/scala/runtime/ErasedFunction.scala new file mode 100644 index 000000000000..7e9211bba75a --- /dev/null +++ b/library/src/scala/runtime/ErasedFunction.scala @@ -0,0 +1,11 @@ +package scala.runtime + +import scala.annotation.experimental + +/** Marker trait for function types with erased parameters. + * + * This trait will be refined with an `apply` method with erased parameters: + * ErasedFunction { def apply([erased] x_1: P_1, ..., [erased] x_N: P_N): R } + * This type will be erased to FunctionL, where L = N - count(erased). + */ +@experimental trait ErasedFunction diff --git a/tests/neg-custom-args/erased/by-name.scala b/tests/neg-custom-args/erased/by-name.scala new file mode 100644 index 000000000000..707cfd96734b --- /dev/null +++ b/tests/neg-custom-args/erased/by-name.scala @@ -0,0 +1,4 @@ +def f(x: => Int, erased y: => Int) = x // error +def g(erased x: => Int, y: => Int) = y // error + +val h: (erased => Int, Int) => Int = (erased x, y) => y // error diff --git a/tests/neg-custom-args/erased/erased-in-tuples.scala b/tests/neg-custom-args/erased/erased-in-tuples.scala new file mode 100644 index 000000000000..11a251c3bd4d --- /dev/null +++ b/tests/neg-custom-args/erased/erased-in-tuples.scala @@ -0,0 +1,16 @@ +@main def Test() = + val x = 5 + val y = 7 + + val t1 = (x, erased y) // error + val t2 = (erased x, y) // error + val t1a = (x: Int, erased y: Int) // error + val t2a = (erased x: Int, y: Int) // error + + val nest = (x, (x, erased y)) // error + + def use(f: (Int, Int) => Any) = f(5, 6) + + use((_, erased _)) // error + + (x, erased y) // error diff --git a/tests/neg-custom-args/erased/lambda-infer.scala b/tests/neg-custom-args/erased/lambda-infer.scala new file mode 100644 index 000000000000..2eebf8186b0d --- /dev/null +++ b/tests/neg-custom-args/erased/lambda-infer.scala @@ -0,0 +1,23 @@ +type F = (Int, erased Int) => Int + +erased class A + +@main def Test() = + val a: F = (x, y) => x + 1 // error: Expected F got (Int, Int) => Int + val b: F = (x, erased y) => x + 1 // ok + val c: F = (_, _) => 5 // error: Expected F got (Int, Int) => Int + val d: F = (_, erased _) => 5 // ok + + def use(f: F) = f(5, 6) + + use { (x, y) => x } // error: Expected F got (Int, Int) => Int + + def singleParam(f: (erased Int) => Int) = f(5) + + singleParam(x => 5) // error: Expected (erased Int) => Int got Int => Int + singleParam((erased x) => 5) // ok + + def erasedClass(f: A => Int) = f(new A) + + erasedClass(_ => 5) // ok since A is implicitly erased + diff --git a/tests/neg-custom-args/erased/multiple-args-consume.scala b/tests/neg-custom-args/erased/multiple-args-consume.scala new file mode 100644 index 000000000000..e4aaacca8969 --- /dev/null +++ b/tests/neg-custom-args/erased/multiple-args-consume.scala @@ -0,0 +1,13 @@ +def foo(erased x: Int, y: Int) = y +def bar(x: Int, erased y: Int) = x + +def consumeFoo(f: (erased x: Int, y: Int) => Int) = f(0, 1) + +val fooF: (erased x: Int, y: Int) => Int = foo +val barF: (x: Int, erased y: Int) => Int = bar + +val a = consumeFoo(foo) // ok +val b = consumeFoo(bar) // error + +val c = consumeFoo(fooF) // ok +val d = consumeFoo(barF) // error diff --git a/tests/neg-custom-args/erased/multiple-args.scala b/tests/neg-custom-args/erased/multiple-args.scala new file mode 100644 index 000000000000..fb9bce8e4573 --- /dev/null +++ b/tests/neg-custom-args/erased/multiple-args.scala @@ -0,0 +1,11 @@ +def foo(x: Int, erased y: Int): Int = x +def bar(erased x: Int, y: Int): Int = y + +val fooF: (x: Int, erased y: Int) => Int = foo + +val fooG: (erased x: Int, y: Int) => Int = foo // error + +val barF: (x: Int, erased y: Int) => Int = bar // error + +val barG: (erased x: Int, y: Int) => Int = bar + diff --git a/tests/neg-custom-args/erased/poly-functions.scala b/tests/neg-custom-args/erased/poly-functions.scala new file mode 100644 index 000000000000..000a2ca49cc9 --- /dev/null +++ b/tests/neg-custom-args/erased/poly-functions.scala @@ -0,0 +1,16 @@ +object Test: + // Poly functions with erased parameters are disallowed as an implementation restriction + + type T1 = [X] => (erased x: X, y: Int) => Int // error + type T2 = [X] => (x: X, erased y: Int) => X // error + + val t1 = [X] => (erased x: X, y: Int) => y // error + val t2 = [X] => (x: X, erased y: Int) => x // error + + // Erased classes should be detected too + erased class A + + type T3 = [X] => (x: A, y: X) => X // error + + val t3 = [X] => (x: A, y: X) => y // error + diff --git a/tests/neg/safeThrowsStrawman2.scala b/tests/neg/safeThrowsStrawman2.scala index 7d87baad6fa4..8d95494e30e0 100644 --- a/tests/neg/safeThrowsStrawman2.scala +++ b/tests/neg/safeThrowsStrawman2.scala @@ -24,7 +24,7 @@ def bar(x: Boolean)(using CanThrow[Fail]): Int = val x = new CanThrow[Fail]() // OK, x is erased val y: Any = new CanThrow[Fail]() // error: illegal reference to erased class CanThrow val y2: Any = new CTF() // error: illegal reference to erased class CanThrow - println(foo(true, ctf)) // error: ctf is declared as erased, but is in fact used + println(foo(true, ctf)) // not error: ctf will be erased at erasure val a = (1, new CanThrow[Fail]()) // error: illegal reference to erased class CanThrow def b: (Int, CanThrow[Fail]) = ??? def c = b._2 // ok; we only check creation sites diff --git a/tests/pos-custom-args/erased/erased-class-as-args.scala b/tests/pos-custom-args/erased/erased-class-as-args.scala new file mode 100644 index 000000000000..74c827fbd54b --- /dev/null +++ b/tests/pos-custom-args/erased/erased-class-as-args.scala @@ -0,0 +1,22 @@ +erased class A + +erased class B(val x: Int) extends A + +type T = (x: A, y: Int) => Int + +type TSub[-T <: A] = (erased x: T, y: Int) => Int + +def useT(f: T) = f(new A, 5) + +def useTSub(f: TSub[B]) = f(new B(5), 5) + +@main def Test() = + val tInfer = (x: A, y: Int) => y + 1 + val tExpl: T = (x, y) => y + 1 + assert(useT((erased x, y) => y + 1) == 6) + assert(useT(tInfer) == 6) + assert(useT(tExpl) == 6) + + val tSub: TSub[A] = (x, y) => y + 1 + assert(useT(tSub) == 6) + assert(useTSub(tSub) == 6) diff --git a/tests/pos-custom-args/erased/erased-soft-keyword.scala b/tests/pos-custom-args/erased/erased-soft-keyword.scala new file mode 100644 index 000000000000..fdb884628c7d --- /dev/null +++ b/tests/pos-custom-args/erased/erased-soft-keyword.scala @@ -0,0 +1,18 @@ +def f1(x: Int, erased y: Int) = 0 +def f2(x: Int, erased: Int) = 0 +inline def f3(x: Int, inline erased: Int) = 0 +def f4(x: Int, erased inline: Int) = 0 +// inline def f5(x: Int, erased inline y: Int) = 0 // should parse but rejected later + +def f6(using erased y: Int) = 0 +def f7(using erased: Int) = 0 +inline def f8(using inline erased: Int) = 0 +def f9(using erased inline: Int) = 0 +// inline def f10(using erased inline x: Int) = 0 // should parse but rejected later +def f11(using erased Int) = 0 + +val v1 = (erased: Int) => 0 +val v2: Int => Int = erased => 0 +val v3 = (erased x: Int) => 0 +val v4: (erased Int) => Int = (erased x) => 0 +val v5: (erased: Int) => Int = x => 0 diff --git a/tests/run-custom-args/erased/erased-15.scala b/tests/run-custom-args/erased/erased-15.scala index b879ee4c54d8..02b70f9125d6 100644 --- a/tests/run-custom-args/erased/erased-15.scala +++ b/tests/run-custom-args/erased/erased-15.scala @@ -1,3 +1,5 @@ +import scala.runtime.ErasedFunction + object Test { def main(args: Array[String]): Unit = { @@ -10,7 +12,7 @@ object Test { } } -class Foo extends ErasedFunction1[Int, Int] { +class Foo extends ErasedFunction { def apply(erased x: Int): Int = { println("Foo.apply") 42 diff --git a/tests/run-custom-args/erased/erased-27.check b/tests/run-custom-args/erased/erased-27.check index 4413863feead..1c255dd5419f 100644 --- a/tests/run-custom-args/erased/erased-27.check +++ b/tests/run-custom-args/erased/erased-27.check @@ -1,3 +1,2 @@ block -x foo diff --git a/tests/run-custom-args/erased/erased-28.check b/tests/run-custom-args/erased/erased-28.check index 85733f6db2d7..3bd1f0e29744 100644 --- a/tests/run-custom-args/erased/erased-28.check +++ b/tests/run-custom-args/erased/erased-28.check @@ -1,4 +1,2 @@ -x foo -x bar diff --git a/tests/run-custom-args/erased/erased-class-are-erased.check b/tests/run-custom-args/erased/erased-class-are-erased.check new file mode 100644 index 000000000000..f64f5d8d85ac --- /dev/null +++ b/tests/run-custom-args/erased/erased-class-are-erased.check @@ -0,0 +1 @@ +27 diff --git a/tests/run-custom-args/erased/erased-class-are-erased.scala b/tests/run-custom-args/erased/erased-class-are-erased.scala new file mode 100644 index 000000000000..b48e0265c521 --- /dev/null +++ b/tests/run-custom-args/erased/erased-class-are-erased.scala @@ -0,0 +1,14 @@ +object Test { + erased class Erased() { + println("Oh no!!!") + } + + def f(x: Erased, y: Int = 0): Int = y + 5 + + def g() = Erased() + + def main(args: Array[String]) = + val y = Erased() + val z = 10 + println(f(Erased()) + z + f(g(), 7)) +} diff --git a/tests/run-custom-args/erased/lambdas.scala b/tests/run-custom-args/erased/lambdas.scala new file mode 100644 index 000000000000..4c1746283099 --- /dev/null +++ b/tests/run-custom-args/erased/lambdas.scala @@ -0,0 +1,38 @@ +// lambdas should parse and work + +type F = (erased Int, String) => String +type S = (Int, erased String) => Int + +def useF(f: F) = f(5, "a") +def useS(f: S) = f(5, "a") + +val ff: F = (erased x, y) => y + +val fs: S = (x, erased y) => x +val fsExpl = (x: Int, erased y: String) => x + +// contextual lambdas should work + +type FC = (Int, erased String) ?=> Int + +def useCtx(f: FC) = f(using 5, "a") + +val fCv: FC = (x, erased y) ?=> x +val fCvExpl = (x: Int, erased y: String) ?=> x + +// nested lambdas should work + +val nested: Int => (String, erased Int) => FC = a => (_, erased _) => (c, erased d) ?=> a + c + +@main def Test() = + assert("a" == useF(ff)) + + assert(5 == useS(fs)) + assert(5 == useS(fsExpl)) + assert(5 == useS { (x, erased y) => x }) + + assert(5 == useCtx(fCv)) + assert(5 == useCtx(fCvExpl)) + assert(5 == useCtx { (x, erased y) ?=> x }) + + assert(6 == useCtx(nested(1)("b", 2))) diff --git a/tests/run-custom-args/erased/quotes-add-erased.check b/tests/run-custom-args/erased/quotes-add-erased.check new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/tests/run-custom-args/erased/quotes-add-erased.check @@ -0,0 +1 @@ +1 diff --git a/tests/run-custom-args/erased/quotes-add-erased/Macro_1.scala b/tests/run-custom-args/erased/quotes-add-erased/Macro_1.scala new file mode 100644 index 000000000000..66f8475da96d --- /dev/null +++ b/tests/run-custom-args/erased/quotes-add-erased/Macro_1.scala @@ -0,0 +1,26 @@ +import scala.annotation.MacroAnnotation +import scala.annotation.internal.ErasedParam +import scala.quoted._ + +class NewAnnotation extends scala.annotation.Annotation + +class erasedParamsMethod extends MacroAnnotation: + def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case ClassDef(name, ctr, parents, self, body) => + val erasedInt = AnnotatedType(TypeRepr.of[Int], '{ new ErasedParam }.asTerm) + val methType = MethodType(List("x", "y"))(_ => List(erasedInt, TypeRepr.of[Int]), _ => TypeRepr.of[Int]) + + assert(methType.hasErasedParams) + assert(methType.erasedParams == List(true, false)) + + val methSym = Symbol.newMethod(tree.symbol, "takesErased", methType, Flags.EmptyFlags, Symbol.noSymbol) + val methDef = DefDef(methSym, _ => Some(Literal(IntConstant(1)))) + + val clsDef = ClassDef.copy(tree)(name, ctr, parents, self, methDef :: body) + + List(clsDef) + case _ => + report.error("Annotation only supports `class`") + List(tree) diff --git a/tests/run-custom-args/erased/quotes-add-erased/Test_2.scala b/tests/run-custom-args/erased/quotes-add-erased/Test_2.scala new file mode 100644 index 000000000000..107fa0833e95 --- /dev/null +++ b/tests/run-custom-args/erased/quotes-add-erased/Test_2.scala @@ -0,0 +1,12 @@ +import scala.language.experimental.erasedDefinitions + +class TakesErased { + def takesErased(erased x: Int, y: Int): Int = ??? +} + +@erasedParamsMethod class Foo extends TakesErased + +@main def Test() = + val foo = Foo() + val v = foo.takesErased(1, 2) + println(v) diff --git a/tests/run-custom-args/erased/quotes-reflection.check b/tests/run-custom-args/erased/quotes-reflection.check new file mode 100644 index 000000000000..838479e0b7af --- /dev/null +++ b/tests/run-custom-args/erased/quotes-reflection.check @@ -0,0 +1,10 @@ +method : () isGiven=false isImplicit=false erasedArgs=List() +method m1: (i: scala.Int) isGiven=true isImplicit=false erasedArgs=List(false) +method m2: (i: scala.Int) isGiven=false isImplicit=false erasedArgs=List(true) +method m3: (i: scala.Int, j: scala.Int) isGiven=false isImplicit=false erasedArgs=List(false, true) +method m4: (i: EC) isGiven=false isImplicit=false erasedArgs=List(true) +val l1: scala.ContextFunction1[scala.Int, scala.Int] +val l2: scala.runtime.ErasedFunction with apply: (x$0: scala.Int @scala.annotation.internal.ErasedParam) isImplicit=false erasedParams=List(true) +val l3: scala.runtime.ErasedFunction with apply: (x$0: scala.Int @scala.annotation.internal.ErasedParam) isImplicit=true erasedParams=List(true) +val l4: scala.runtime.ErasedFunction with apply: (x$0: scala.Int, x$1: scala.Int @scala.annotation.internal.ErasedParam) isImplicit=false erasedParams=List(false, true) +val l5: scala.runtime.ErasedFunction with apply: (x$0: EC @scala.annotation.internal.ErasedParam) isImplicit=false erasedParams=List(true) diff --git a/tests/run-custom-args/erased/quotes-reflection/Macros_1.scala b/tests/run-custom-args/erased/quotes-reflection/Macros_1.scala new file mode 100644 index 000000000000..f7b1187433f0 --- /dev/null +++ b/tests/run-custom-args/erased/quotes-reflection/Macros_1.scala @@ -0,0 +1,35 @@ +import scala.quoted.* + +inline def inspect[A]: String = + ${ inspect2[A] } + +def inspect2[A: Type](using Quotes): Expr[String] = { + import quotes.reflect.* + + val methods = TypeRepr.of[A].typeSymbol.declarations + val names = methods.map { m => + m.tree match + case dd @ DefDef(name, params, r, body) => + val paramStr = + params.map { + case ps: TermParamClause => + val params = ps.params.map(p => s"${p.name}: ${p.tpt.show}").mkString("(", ", ", ")") + s"$params isGiven=${ps.isGiven} isImplicit=${ps.isImplicit} erasedArgs=${ps.erasedArgs}" + case ps: TypeParamClause => ps.params.map(_.show).mkString("[", ", ", "]") + }.mkString("") + s"method $name: $paramStr" + case vd @ ValDef(name, tpt, body) => + tpt.tpe match + case Refinement(parent, "apply", tpe: MethodType) if parent == defn.ErasedFunctionClass.typeRef => + assert(tpt.tpe.isErasedFunctionType) + + val params = tpe.paramNames.zip(tpe.paramTypes).map((n, t) => s"$n: ${t.show}").mkString("(", ", ", ")") + s"val $name: ${parent.show} with apply: ${params} isImplicit=${tpe.isImplicit} erasedParams=${tpe.erasedParams}" + case _ => + s"val $name: ${tpt.show}" + case td @ TypeDef(name, tpt) => s"type $name: ${tpt.show}" + case _ => s"something else: $m" + } + + Expr(names.mkString("\n")) +} diff --git a/tests/run-custom-args/erased/quotes-reflection/Test_2.scala b/tests/run-custom-args/erased/quotes-reflection/Test_2.scala new file mode 100644 index 000000000000..ce1cc8d3dff1 --- /dev/null +++ b/tests/run-custom-args/erased/quotes-reflection/Test_2.scala @@ -0,0 +1,20 @@ +import scala.language.experimental.erasedDefinitions + +erased class EC + +trait X { + def m1(using i: Int): Int + def m2(erased i: Int): Int + def m3(i: Int, erased j: Int): Int + def m4(i: EC): Int + + val l1 = (x: Int) ?=> 5 + val l2 = (erased x: Int) => 5 + val l3 = (erased x: Int) ?=> 5 + val l4 = (x: Int, erased y: Int) => 5 + val l5 = (x: EC) => 5 +} + +@main def Test = { + println(inspect[X]) +} diff --git a/tests/run-custom-args/run-macros-erased/macro-erased/1.scala b/tests/run-custom-args/run-macros-erased/macro-erased/1.scala index 567ef57b1c06..36f583a7dc91 100644 --- a/tests/run-custom-args/run-macros-erased/macro-erased/1.scala +++ b/tests/run-custom-args/run-macros-erased/macro-erased/1.scala @@ -13,8 +13,8 @@ object Macro { def case1(erased i: Expr[Int])(using Quotes): Expr[Int] = '{ 0 } def case2 (i: Int)(erased j: Expr[Int])(using Quotes): Expr[Int] = '{ 0 } def case3(erased i: Expr[Int]) (j: Int)(using Quotes): Expr[Int] = '{ 0 } - def case4 (h: Int)(erased i: Expr[Int], j: Expr[Int])(using Quotes): Expr[Int] = '{ 0 } - def case5(erased i: Expr[Int], j: Expr[Int]) (h: Int)(using Quotes): Expr[Int] = '{ 0 } + def case4 (h: Int)(erased i: Expr[Int], erased j: Expr[Int])(using Quotes): Expr[Int] = '{ 0 } + def case5(erased i: Expr[Int], erased j: Expr[Int]) (h: Int)(using Quotes): Expr[Int] = '{ 0 } def case6 (h: Int)(erased i: Expr[Int])(erased j: Expr[Int])(using Quotes): Expr[Int] = '{ 0 } def case7(erased i: Expr[Int]) (h: Int)(erased j: Expr[Int])(using Quotes): Expr[Int] = '{ 0 } def case8(erased i: Expr[Int])(erased j: Expr[Int]) (h: Int)(using Quotes): Expr[Int] = '{ 0 } diff --git a/tests/run-custom-args/tasty-inspector/stdlibExperimentalDefinitions.scala b/tests/run-custom-args/tasty-inspector/stdlibExperimentalDefinitions.scala index eff76720a7e2..2c49ca46349e 100644 --- a/tests/run-custom-args/tasty-inspector/stdlibExperimentalDefinitions.scala +++ b/tests/run-custom-args/tasty-inspector/stdlibExperimentalDefinitions.scala @@ -1,5 +1,6 @@ import scala.quoted.* import scala.tasty.inspector.* +import scala.language.experimental.erasedDefinitions val experimentalDefinitionInLibrary = Set( @@ -80,6 +81,21 @@ val experimentalDefinitionInLibrary = Set( "scala.quoted.Quotes.reflectModule.SymbolModule.newModule", "scala.quoted.Quotes.reflectModule.SymbolModule.freshName", "scala.quoted.Quotes.reflectModule.SymbolMethods.info", + // Quotes for functions with erased parameters. + "scala.quoted.Quotes.reflectModule.MethodTypeMethods.erasedParams", + "scala.quoted.Quotes.reflectModule.MethodTypeMethods.hasErasedParams", + "scala.quoted.Quotes.reflectModule.TermParamClauseMethods.erasedArgs", + "scala.quoted.Quotes.reflectModule.TermParamClauseMethods.hasErasedArgs", + "scala.quoted.Quotes.reflectModule.defnModule.ErasedFunctionClass", + + // New feature: functions with erased parameters. + // Need erasedDefinitions enabled. + "scala.runtime.ErasedFunction", + "scala.quoted.Quotes.reflectModule.MethodTypeMethods.erasedParams", + "scala.quoted.Quotes.reflectModule.MethodTypeMethods.hasErasedParams", + "scala.quoted.Quotes.reflectModule.TermParamClauseMethods.erasedArgs", + "scala.quoted.Quotes.reflectModule.TermParamClauseMethods.hasErasedArgs", + "scala.quoted.Quotes.reflectModule.defnModule.ErasedFunctionClass" ) diff --git a/tests/run-macros/i12021.check b/tests/run-macros/i12021.check index b244dce80b34..ef998c725209 100644 --- a/tests/run-macros/i12021.check +++ b/tests/run-macros/i12021.check @@ -1,3 +1,5 @@ -X1: (i: scala.Int) isImplicit=true, isGiven=false, isErased=false -X2: (i: scala.Int) isImplicit=false, isGiven=true, isErased=false -X3: (i: scala.Int) isImplicit=false, isGiven=false, isErased=true +X1: (i: scala.Int) isImplicit=true, isGiven=false, erasedArgs=List(false) +X2: (i: scala.Int) isImplicit=false, isGiven=true, erasedArgs=List(false) +X3: (i: scala.Int) isImplicit=false, isGiven=false, erasedArgs=List(true) +X4: (i: scala.Int, j: scala.Int) isImplicit=false, isGiven=false, erasedArgs=List(false, true) +X5: (i: EC) isImplicit=false, isGiven=false, erasedArgs=List(true) diff --git a/tests/run-macros/i12021/Macro_1.scala b/tests/run-macros/i12021/Macro_1.scala index 81703dfbab3d..25cab1786146 100644 --- a/tests/run-macros/i12021/Macro_1.scala +++ b/tests/run-macros/i12021/Macro_1.scala @@ -14,5 +14,5 @@ def inspect2[A: Type](using Quotes): Expr[String] = { val names = ps.params.map(p => s"${p.name}: ${p.tpt.show}").mkString("(", ", ", ")") - Expr(s"${Type.show[A]}: $names isImplicit=${ps.isImplicit}, isGiven=${ps.isGiven}, isErased=${ps.isErased}") + Expr(s"${Type.show[A]}: $names isImplicit=${ps.isImplicit}, isGiven=${ps.isGiven}, erasedArgs=${ps.erasedArgs}") } diff --git a/tests/run-macros/i12021/Test_2.scala b/tests/run-macros/i12021/Test_2.scala index 17f74792ece4..a542b14f1175 100644 --- a/tests/run-macros/i12021/Test_2.scala +++ b/tests/run-macros/i12021/Test_2.scala @@ -1,11 +1,17 @@ import scala.language.experimental.erasedDefinitions +erased class EC + class X1(implicit i: Int) class X2(using i: Int) class X3(erased i: Int) +class X4(i: Int, erased j: Int) +class X5(i: EC) @main def Test = { println(inspect[X1]) println(inspect[X2]) println(inspect[X3]) -} \ No newline at end of file + println(inspect[X4]) + println(inspect[X5]) +} diff --git a/tests/run-macros/tasty-definitions-1.check b/tests/run-macros/tasty-definitions-1.check index c2674c1b7a06..ce7251d7d3ee 100644 --- a/tests/run-macros/tasty-definitions-1.check +++ b/tests/run-macros/tasty-definitions-1.check @@ -108,56 +108,8 @@ ContextFunction24 ContextFunction24 ContextFunction25 ContextFunction25 -ErasedFunction1 -ErasedFunction2 -ErasedFunction3 -ErasedFunction4 -ErasedFunction5 -ErasedFunction6 -ErasedFunction7 -ErasedFunction8 -ErasedFunction9 -ErasedFunction10 -ErasedFunction11 -ErasedFunction12 -ErasedFunction13 -ErasedFunction14 -ErasedFunction15 -ErasedFunction16 -ErasedFunction17 -ErasedFunction18 -ErasedFunction19 -ErasedFunction20 -ErasedFunction21 -ErasedFunction22 -ErasedFunction23 -ErasedFunction24 -ErasedFunction25 -ErasedContextFunction1 -ErasedContextFunction2 -ErasedContextFunction3 -ErasedContextFunction4 -ErasedContextFunction5 -ErasedContextFunction6 -ErasedContextFunction7 -ErasedContextFunction8 -ErasedContextFunction9 -ErasedContextFunction10 -ErasedContextFunction11 -ErasedContextFunction12 -ErasedContextFunction13 -ErasedContextFunction14 -ErasedContextFunction15 -ErasedContextFunction16 -ErasedContextFunction17 -ErasedContextFunction18 -ErasedContextFunction19 -ErasedContextFunction20 -ErasedContextFunction21 -ErasedContextFunction22 -ErasedContextFunction23 -ErasedContextFunction24 -ErasedContextFunction25 +class java.lang.Exception: Erased function classes are not supported. Use a refined `scala.runtime.ErasedFunction` +ErasedFunction Tuple2 Tuple3 Tuple4 diff --git a/tests/run-macros/tasty-definitions-1/quoted_1.scala b/tests/run-macros/tasty-definitions-1/quoted_1.scala index baf0b929f202..bf9e28288486 100644 --- a/tests/run-macros/tasty-definitions-1/quoted_1.scala +++ b/tests/run-macros/tasty-definitions-1/quoted_1.scala @@ -63,11 +63,10 @@ object Macros { printout(defn.FunctionClass(i, isContextual = true).name) printout(defn.FunctionClass(i, isImplicit = true).name) - for (i <- 1 to 25) - printout(defn.FunctionClass(i, isErased = true).name) + // should fail + printout(defn.FunctionClass(1, isErased = true).name) - for (i <- 1 to 25) - printout(defn.FunctionClass(i, isImplicit = true, isErased = true).name) + printout(defn.ErasedFunctionClass.name) for (i <- 2 to 22) printout(defn.TupleClass(i).name)