Skip to content

Commit 1040923

Browse files
committed
Store a list of erased flags in FunctionWithMods, allow parsing more erased function types
`(Int, erased Int) => Int` is now allowed as a function type.
1 parent 292103a commit 1040923

File tree

6 files changed

+81
-57
lines changed

6 files changed

+81
-57
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,10 +1493,10 @@ object desugar {
14931493
case vd: ValDef => vd
14941494
}
14951495

1496-
def makeContextualFunction(formals: List[Tree], body: Tree, isErased: Boolean)(using Context): Function = {
1497-
val mods = if (isErased) Given | Erased else Given
1496+
def makeContextualFunction(formals: List[Tree], body: Tree, isErased: List[Boolean])(using Context): Function = {
1497+
val mods = Given
14981498
val params = makeImplicitParameters(formals, mods)
1499-
FunctionWithMods(params, body, Modifiers(mods))
1499+
FunctionWithMods(params, body, Modifiers(mods), isErased)
15001500
}
15011501

15021502
private def derivedValDef(original: Tree, named: NameTree, tpt: Tree, rhs: Tree, mods: Modifiers)(using Context) = {
@@ -1829,6 +1829,7 @@ object desugar {
18291829
cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt))
18301830
case _ =>
18311831
annotate(tpnme.retains, parent)
1832+
case f: FunctionWithMods if f.isErased.contains(true) => makeErasedFunctionValDefs(f, pt)
18321833
}
18331834
desugared.withSpan(tree.span)
18341835
}
@@ -1904,6 +1905,22 @@ object desugar {
19041905
TypeDef(tpnme.REFINE_CLASS, impl).withFlags(Trait)
19051906
}
19061907

1908+
/** Make erased function definitions use only ValDefs */
1909+
def makeErasedFunctionValDefs(tree: FunctionWithMods, pt: Type)(using Context): Function = {
1910+
val isErased = tree.isErased
1911+
val Function(args, result) = tree
1912+
args match {
1913+
case (_ : ValDef) :: _ => tree // ValDef case can be easily handled
1914+
case _ if !isErased.contains(true) => tree // not erased
1915+
case _ if !ctx.mode.is(Mode.Type) => tree
1916+
case _ =>
1917+
val applyVParams = args.zipWithIndex.map {
1918+
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(tree.mods.flags)
1919+
}
1920+
untpd.FunctionWithMods(applyVParams, tree.body, tree.mods, isErased)
1921+
}
1922+
}
1923+
19071924
/** Returns list of all pattern variables, possibly with their types,
19081925
* without duplicates
19091926
*/

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
7676
}
7777

7878
/** A function type or closure with `implicit`, `erased`, or `given` modifiers */
79-
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile)
80-
extends Function(args, body)
79+
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers, val isErased: List[Boolean])(implicit @constructorOnly src: SourceFile)
80+
extends Function(args, body) {
81+
assert(args.length == isErased.length)
82+
}
8183

8284
/** A polymorphic function type */
8385
case class PolyFunction(targs: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree {

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,6 +1464,7 @@ object Parsers {
14641464
def typ(): Tree =
14651465
val start = in.offset
14661466
var imods = Modifiers()
1467+
var erasedArgs: ListBuffer[Boolean] = ListBuffer()
14671468
def functionRest(params: List[Tree]): Tree =
14681469
val paramSpan = Span(start, in.lastOffset)
14691470
atSpan(start, in.offset) {
@@ -1495,10 +1496,10 @@ object Parsers {
14951496
if isByNameType(tpt) then
14961497
syntaxError(em"parameter of type lambda may not be call-by-name", tpt.span)
14971498
TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], resultType)
1498-
else if imods.isOneOf(Given | Impure) then
1499+
else if imods.isOneOf(Given | Impure) || erasedArgs.exists(_ == true) then
14991500
if imods.is(Given) && params.isEmpty then
15001501
syntaxError(em"context function types require at least one parameter", paramSpan)
1501-
FunctionWithMods(params, resultType, imods)
1502+
FunctionWithMods(params, resultType, imods, erasedArgs.toList)
15021503
else if !ctx.settings.YkindProjector.isDefault then
15031504
val (newParams :+ newResultType, tparams) = replaceKindProjectorPlaceholders(params :+ resultType): @unchecked
15041505
lambdaAbstract(tparams, Function(newParams, newResultType))
@@ -1517,25 +1518,29 @@ object Parsers {
15171518
}
15181519
else {
15191520
val paramStart = in.offset
1520-
val firstMods = if isErased then addModifier(imods) else imods
1521+
def addErased() =
1522+
erasedArgs.addOne(isErased)
1523+
if isErased then { in.skipToken(); }
1524+
addErased()
15211525
val ts = in.currentRegion.withCommasExpected {
15221526
funArgType() match
15231527
case Ident(name) if name != tpnme.WILDCARD && in.isColon =>
15241528
isValParamList = true
15251529
def funParam(start: Offset, mods: Modifiers) = {
15261530
atSpan(start) {
1527-
val mods1 = if isErased then addModifier(mods) else mods
1528-
typedFunParam(in.offset, ident(), mods1)
1531+
addErased()
1532+
typedFunParam(in.offset, ident(), imods)
15291533
}
15301534
}
15311535
commaSeparatedRest(
1532-
typedFunParam(paramStart, name.toTermName, firstMods),
1536+
typedFunParam(paramStart, name.toTermName, imods),
15331537
() => funParam(in.offset, imods))
15341538
case t =>
1535-
// For now we just reject `erased` in (T, U) => V definitions (i.e. you need parameter names)
1536-
if firstMods.is(Erased) then
1537-
syntaxError(em"Implementation restriction: erased parameters are not allowed without parameter names", paramStart)
1538-
commaSeparatedRest(t, funArgType)
1539+
def funParam() = {
1540+
addErased()
1541+
funArgType()
1542+
}
1543+
commaSeparatedRest(t, funParam)
15391544
}
15401545
accept(RPAREN)
15411546
if isValParamList || in.isArrow || isPureArrow then
@@ -1589,7 +1594,7 @@ object Parsers {
15891594
if isPureArrow then
15901595
functionRest(t :: Nil)
15911596
else
1592-
if (imods.is(Erased) && !t.isInstanceOf[FunctionWithMods])
1597+
if (erasedArgs.exists(_ == true) && !t.isInstanceOf[FunctionWithMods])
15931598
syntaxError(ErasedTypesCanOnlyBeFunctionTypes(), implicitKwPos(start))
15941599
t
15951600
end typ

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -647,27 +647,29 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
647647
case str: Literal => strText(str)
648648
}
649649
toText(id) ~ "\"" ~ Text(segments map segmentText, "") ~ "\""
650-
case Function(args, body) =>
650+
case fn @ Function(args, body) =>
651651
var implicitSeen: Boolean = false
652652
var isGiven: Boolean = false
653-
var isErased: Boolean = false
654-
def argToText(arg: Tree) = arg match {
653+
val isErased = fn match {
654+
case fn: FunctionWithMods => fn.isErased
655+
case _ => fn.args.map(_ => false)
656+
}
657+
def argToText(arg: Tree, isErased: Boolean) = arg match {
655658
case arg @ ValDef(name, tpt, _) =>
656659
val implicitText =
657660
if ((arg.mods.is(Given))) { isGiven = true; "" }
658-
else if ((arg.mods.is(Erased))) { isErased = true; "" }
659661
else if ((arg.mods.is(Implicit)) && !implicitSeen) { implicitSeen = true; keywordStr("implicit ") }
660662
else ""
661-
implicitText ~ toText(name) ~ optAscription(tpt)
663+
val erasedText = if isErased then keywordStr("erased ") else ""
664+
implicitText ~ erasedText ~ toText(name) ~ optAscription(tpt)
662665
case _ =>
663666
toText(arg)
664667
}
665668
val argsText = args match {
666-
case (arg @ ValDef(_, tpt, _)) :: Nil if tpt.isEmpty => argToText(arg)
669+
case (arg @ ValDef(_, tpt, _)) :: Nil if tpt.isEmpty => argToText(arg, isErased(0))
667670
case _ =>
668671
"("
669-
~ keywordText("erased ").provided(isErased)
670-
~ Text(args.map(argToText), ", ")
672+
~ Text(args.zip(isErased).map(argToText), ", ")
671673
~ ")"
672674
}
673675
val isPure =

compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,8 @@ object EtaExpansion extends LiftImpure {
285285
val body = Apply(lifted, ids)
286286
if (mt.isContextualMethod) body.setApplyKind(ApplyKind.Using)
287287
val fn =
288-
if (mt.isContextualMethod) new untpd.FunctionWithMods(params, body, Modifiers(Given))
289-
else if (mt.isImplicitMethod) new untpd.FunctionWithMods(params, body, Modifiers(Implicit))
288+
if (mt.isContextualMethod) new untpd.FunctionWithMods(params, body, Modifiers(Given), params.map(_ => false))
289+
else if (mt.isImplicitMethod) new untpd.FunctionWithMods(params, body, Modifiers(Implicit), params.map(_ => false))
290290
else untpd.Function(params, body)
291291
if (defs.nonEmpty) untpd.Block(defs.toList map (untpd.TypedSplice(_)), fn) else fn
292292
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,32 +1290,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
12901290

12911291
def typedFunctionType(tree: untpd.Function, pt: Type)(using Context): Tree = {
12921292
val untpd.Function(args, body) = tree
1293-
var funFlags = tree match {
1294-
case tree: untpd.FunctionWithMods => tree.mods.flags
1295-
case _ => EmptyFlags
1293+
var (funFlags, isErased) = tree match {
1294+
case tree: untpd.FunctionWithMods => (tree.mods.flags, tree.isErased)
1295+
case _ => (EmptyFlags, args.map(_ => false))
12961296
}
12971297

1298-
assert(!funFlags.is(Erased) || !args.isEmpty, "An empty function cannot not be erased")
1299-
13001298
val numArgs = args.length
13011299
val isContextual = funFlags.is(Given)
1302-
val isErased = args.collect({ case x: untpd.ValDef => x }).exists(_.mods.flags.is(Erased))
13031300
val isImpure = funFlags.is(Impure)
1304-
val funSym = defn.FunctionSymbol(numArgs, isContextual, isErased, isImpure)
1305-
1306-
/** If `app` is a function type with arguments that are all erased classes,
1307-
* turn it into an erased function type.
1308-
*/
1309-
def propagateErased(app: Tree): Tree = app match
1310-
case AppliedTypeTree(tycon: TypeTree, args)
1311-
if !isErased
1312-
&& numArgs > 0
1313-
&& args.indexWhere(!_.tpe.isErasedClass) == numArgs =>
1314-
val tycon1 = TypeTree(defn.FunctionSymbol(numArgs, isContextual, true, isImpure).typeRef)
1315-
.withSpan(tycon.span)
1316-
assignType(cpy.AppliedTypeTree(app)(tycon1, args), tycon1, args)
1317-
case _ =>
1318-
app
1301+
val funSym = defn.FunctionSymbol(numArgs, isContextual, isErased.contains(true), isImpure)
13191302

13201303
/** Typechecks dependent function type with given parameters `params` */
13211304
def typedDependent(params: List[untpd.ValDef])(using Context): Tree =
@@ -1330,7 +1313,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13301313
if funFlags.is(Given) then params.map(_.withAddedFlags(Given))
13311314
else params
13321315
val params2 = params1.map(fixThis.transformSub)
1333-
val appDef0 = untpd.DefDef(nme.apply, List(params2), body, EmptyTree).withSpan(tree.span)
1316+
val params3 = params2.zipWithConserve(isErased) { (arg, isErased) =>
1317+
if isErased then arg.withAddedFlags(Erased) else arg
1318+
}
1319+
val appDef0 = untpd.DefDef(nme.apply, List(params3), body, EmptyTree).withSpan(tree.span)
13341320
index(appDef0 :: Nil)
13351321
val appDef = typed(appDef0).asInstanceOf[DefDef]
13361322
val mt = appDef.symbol.info.asInstanceOf[MethodType]
@@ -1339,7 +1325,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13391325
val resTpt = TypeTree(mt.nonDependentResultApprox).withSpan(body.span)
13401326
val typeArgs = appDef.termParamss.head.map(_.tpt) :+ resTpt
13411327
val tycon = TypeTree(funSym.typeRef)
1342-
val core = propagateErased(AppliedTypeTree(tycon, typeArgs))
1328+
val core = AppliedTypeTree(tycon, typeArgs)
13431329
RefinedTypeTree(core, List(appDef), ctx.owner.asClass)
13441330
end typedDependent
13451331

@@ -1348,17 +1334,22 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13481334
typedDependent(args.asInstanceOf[List[untpd.ValDef]])(
13491335
using ctx.fresh.setOwner(newRefinedClassSymbol(tree.span)).setNewScope)
13501336
case _ =>
1351-
propagateErased(
1352-
typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt))
1337+
if isErased.contains(true) then
1338+
typedFunctionType(
1339+
desugar.makeErasedFunctionValDefs(tree.asInstanceOf[untpd.FunctionWithMods], pt),
1340+
pt
1341+
)
1342+
else
1343+
typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt)
13531344
}
13541345
}
13551346

13561347
def typedFunctionValue(tree: untpd.Function, pt: Type)(using Context): Tree = {
13571348
val untpd.Function(params: List[untpd.ValDef] @unchecked, _) = tree: @unchecked
13581349

1359-
val isContextual = tree match {
1360-
case tree: untpd.FunctionWithMods => tree.mods.is(Given)
1361-
case _ => false
1350+
val (isContextual, isDefinedErased) = tree match {
1351+
case tree: untpd.FunctionWithMods => (tree.mods.is(Given), tree.isErased)
1352+
case _ => (false, tree.args.map(_ => false))
13621353
}
13631354

13641355
/** The function body to be returned in the closure. Can become a TypedSplice
@@ -1467,8 +1458,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
14671458

14681459
/** Returns the type and whether the parameter is erased */
14691460
def protoFormal(i: Int): (Type, Boolean) =
1470-
if (protoFormals.length == params.length) (protoFormals(i), protoIsErased(i))
1471-
else (errorType(WrongNumberOfParameters(protoFormals.length), tree.srcPos), false)
1461+
if (protoFormals.length == params.length) (protoFormals(i), protoIsErased(i) || isDefinedErased(i))
1462+
else (errorType(WrongNumberOfParameters(protoFormals.length), tree.srcPos), isDefinedErased(i))
14721463

14731464
/** Is `formal` a product type which is elementwise compatible with `params`? */
14741465
def ptIsCorrectProduct(formal: Type) =
@@ -3034,7 +3025,14 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
30343025
else formals.map(untpd.TypeTree)
30353026
}
30363027

3037-
val ifun = desugar.makeContextualFunction(paramTypes, tree, defn.isErasedFunctionType(pt))
3028+
val erasedParams = pt match {
3029+
case RefinedType(parent, nme.apply, mt) if defn.isErasedFunctionType(parent) =>
3030+
val companion = mt.asInstanceOf[MethodType].companion.asInstanceOf[ErasedMethodCompanion]
3031+
companion.isErased
3032+
case _ => paramTypes.map(_ => false)
3033+
}
3034+
3035+
val ifun = desugar.makeContextualFunction(paramTypes, tree, erasedParams)
30383036
typr.println(i"make contextual function $tree / $pt ---> $ifun")
30393037
typedFunctionValue(ifun, pt)
30403038
}

0 commit comments

Comments
 (0)