Skip to content

Commit d81ca0c

Browse files
committed
Rewrite FunctionOf to deal with erased parameters
`apply` and `unapply` now returns and consumes a list of erased parameter flags, just like other function/method constructing functions.
1 parent 11cfc37 commit d81ca0c

File tree

8 files changed

+48
-22
lines changed

8 files changed

+48
-22
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,12 +254,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
254254
// If `isParamDependent == false`, the value of `previousParamRefs` is not used.
255255
if isParamDependent then mutable.ListBuffer[TermRef]() else (null: ListBuffer[TermRef] | Null).uncheckedNN
256256

257-
def valueParam(name: TermName, origInfo: Type): TermSymbol =
257+
def valueParam(name: TermName, origInfo: Type, isErased: Boolean): TermSymbol =
258258
val maybeImplicit =
259259
if tp.isContextualMethod then Given
260260
else if tp.isImplicitMethod then Implicit
261261
else EmptyFlags
262-
val maybeErased = if tp.isErasedMethod then Erased else EmptyFlags // TODO @natsukagami change this
262+
val maybeErased = if isErased then Erased else EmptyFlags
263263

264264
def makeSym(info: Type) = newSymbol(sym, name, TermParam | maybeImplicit | maybeErased, info, coord = sym.coord)
265265

@@ -270,14 +270,18 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
270270
else makeSym(origInfo)
271271
end valueParam
272272

273+
val erasedParams = if tp.isErasedMethod
274+
then tp.companion.asInstanceOf[ErasedMethodCompanion].isErased
275+
else List.fill(tp.paramNames.size) { false }
276+
273277
val (vparams: List[TermSymbol], remaining1) =
274278
if tp.paramNames.isEmpty then (Nil, remaining)
275279
else remaining match
276280
case vparams :: remaining1 =>
277281
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
278282
(vparams.asInstanceOf[List[TermSymbol]], remaining1)
279283
case nil =>
280-
(tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil)
284+
(tp.paramNames.lazyZip(tp.paramInfos).lazyZip(erasedParams).map(valueParam), Nil)
281285
val (rtp, paramss) = recur(tp.instantiate(vparams.map(_.termRef)), remaining1)
282286
(rtp, vparams :: paramss)
283287
case _ =>

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -611,9 +611,8 @@ class CheckCaptures extends Recheck, SymTransformer:
611611
//println(i"check conforms $actual1 <<< $expected1")
612612
super.checkConformsExpr(actual1, expected1, tree)
613613

614-
private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean, isErased: Boolean)(using Context): Type =
615-
val erasedParams = args.map(_ => isErased) // TODO @natsukagami fix
616-
MethodType.companion(isContextual = isContextual, isErased = erasedParams)(args, resultType)
614+
private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean, isErased: List[Boolean])(using Context): Type =
615+
MethodType.companion(isContextual = isContextual, isErased = isErased)(args, resultType)
617616
.toFunctionType(isJava = false, alwaysDependent = true)
618617

619618
/** Turn `expected` into a dependent function when `actual` is dependent. */

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,9 @@ extends tpd.TreeTraverser:
3535
* arguments `argTypes` and result `resType`.
3636
*/
3737
private def depFun(tycon: Type, argTypes: List[Type], resType: Type)(using Context): Type =
38-
assert(!defn.isErasedFunctionType(tycon)) // TODO @natsukagami not sure how to know
39-
// erased parameter info here.
4038
MethodType.companion(
4139
isContextual = defn.isContextFunctionClass(tycon.classSymbol),
42-
isErased = List()
40+
isErased = defn.erasedFunctionParameters(tycon)
4341
)(argTypes, resType)
4442
.toFunctionType(isJava = false, alwaysDependent = true)
4543

@@ -256,7 +254,7 @@ extends tpd.TreeTraverser:
256254
private def expandThrowsAlias(tp: Type)(using Context) = tp match
257255
case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias =>
258256
// hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->`
259-
defn.FunctionOf(defn.CanThrowClass.typeRef.appliedTo(exc) :: Nil, res, isContextual = true, isErased = true)
257+
defn.FunctionOf(defn.CanThrowClass.typeRef.appliedTo(exc) :: Nil, res, isContextual = true, isErased = List(true))
260258
case _ => tp
261259

262260
private def expandThrowsAliases(using Context) = new TypeMap:

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,14 +1086,23 @@ class Definitions {
10861086
sym.owner.linkedClass.typeRef
10871087

10881088
object FunctionOf {
1089-
def apply(args: List[Type], resultType: Type, isContextual: Boolean = false, isErased: Boolean = false)(using Context): Type =
1090-
FunctionType(args.length, isContextual, isErased).appliedTo(args ::: resultType :: Nil)
1091-
def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean, Boolean)] = {
1089+
def apply(args: List[Type], resultType: Type, isContextual: Boolean = false, isErased: List[Boolean] = List())(using Context): Type =
1090+
assert(isErased.size == 0 || args.size == isErased.size)
1091+
val erasedParams = isErased.padTo(args.size, false)
1092+
val isErasedFn = erasedParams.contains(true)
1093+
val fnType =
1094+
FunctionType(args.length, isContextual, isErased = isErasedFn).appliedTo(args ::: resultType :: Nil)
1095+
if isErasedFn then
1096+
val mt = MethodType.companion(isContextual, false, isErased)(args, resultType)
1097+
RefinedType(fnType, nme.apply, mt)
1098+
else fnType
1099+
def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean, List[Boolean])] = {
10921100
val tsym = ft.typeSymbol
10931101
if isFunctionClass(tsym) && ft.isRef(tsym) then
10941102
val targs = ft.dealias.argInfos
1103+
val isErased = erasedFunctionParameters(ft)
10951104
if (targs.isEmpty) None
1096-
else Some(targs.init, targs.last, tsym.name.isContextFunction, tsym.name.isErasedFunction)
1105+
else Some(targs.init, targs.last, tsym.name.isContextFunction, isErased)
10971106
else None
10981107
}
10991108
}
@@ -1503,9 +1512,12 @@ class Definitions {
15031512

15041513
/** Is an erased function class.
15051514
* - ErasedFunctionN for N > 0
1506-
* - ErasedContextFunctionN for N > 0
1515+
*
1516+
* This should not be used for checking if a type is an erased function type. Use `isErasedFunctionType`.
1517+
* Only helpful if you have a `RefinedType(parent, nme.apply, mt)` and
1518+
* want to check if `parent` is an erased function class.
15071519
*/
1508-
def isErasedFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isErasedFunction // TODO @natsukagami fix this
1520+
def isErasedFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isErasedFunction
15091521

15101522
/** Is either FunctionXXL or a class that will be erased to FunctionXXL
15111523
* - FunctionXXL
@@ -1804,9 +1816,18 @@ class Definitions {
18041816
val tp1 = asContextFunctionType(tp)
18051817
if tp1.exists then
18061818
val args = tp1.dropDependentRefinement.argInfos
1807-
Some((args.init, args.last, List(tp1.typeSymbol.name.isErasedFunction))) // TODO @natsukagami fix this
1819+
val isErased = erasedFunctionParameters(tp1)
1820+
Some((args.init, args.last, isErased))
18081821
else None
18091822

1823+
/* Returns a list of erased booleans marking whether parameters are erased, for a function type. */
1824+
def erasedFunctionParameters(tp: Type)(using Context): List[Boolean] = tp.dealias match {
1825+
case RefinedType(parent, nme.apply, mt) if defn.isErasedFunctionType(parent) =>
1826+
mt.asInstanceOf[MethodType].companion.asInstanceOf[ErasedMethodCompanion].isErased
1827+
case tp if isFunctionType(tp) => List.fill(functionArity(tp)) { false }
1828+
case _ => List()
1829+
}
1830+
18101831
def isErasedFunctionType(tp: Type)(using Context): Boolean =
18111832
tp.derivesFrom(defn.ErasedFunctionClass)
18121833

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ object NameOps {
208208
if str == mustHave then found = true
209209
idx + str.length
210210
else idx
211-
212211
skip(skip(skip(0, "Impure"), "Erased"), "Context") == suffixStart
213212
&& found
214213
}

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1839,15 +1839,20 @@ object Types {
18391839
case mt: MethodType if !mt.isParamDependent =>
18401840
val formals1 = if (dropLast == 0) mt.paramInfos else mt.paramInfos dropRight dropLast
18411841
val isContextual = mt.isContextualMethod && !ctx.erasedTypes
1842-
val isErased = mt.isErasedMethod && !ctx.erasedTypes // TODO @natsukagami fix this
1842+
val isErased =
1843+
(if mt.isErasedMethod && !ctx.erasedTypes then
1844+
mt.companion.asInstanceOf[ErasedMethodCompanion].isErased
1845+
else List.fill(mt.paramInfos.size) { false }) dropRight dropLast
1846+
// println(s"$mt => $isErased")
18431847
val result1 = mt.nonDependentResultApprox match {
18441848
case res: MethodType => res.toFunctionType(isJava)
18451849
case res => res
18461850
}
18471851
val funType = defn.FunctionOf(
18481852
formals1 mapConserve (_.translateFromRepeated(toArray = isJava)),
18491853
result1, isContextual, isErased)
1850-
if alwaysDependent || isErased || mt.isResultDependent then RefinedType(funType, nme.apply, mt)
1854+
if alwaysDependent || mt.isResultDependent then
1855+
RefinedType(funType, nme.apply, mt)
18511856
else funType
18521857
}
18531858

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
110110
def functionTypeEqual(baseFun: Type, actualArgs: List[Type],
111111
actualRet: Type, expected: Type) =
112112
expected =:= defn.FunctionOf(actualArgs, actualRet,
113-
defn.isContextFunctionType(baseFun), defn.isErasedFunctionType(baseFun))
113+
defn.isContextFunctionType(baseFun), defn.erasedFunctionParameters(baseFun))
114114
val arity: Int =
115115
if defn.isErasedFunctionType(fun) || defn.isErasedFunctionType(fun) then -1 // TODO support?
116116
else if defn.isFunctionType(fun) then

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1554,7 +1554,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
15541554
|because it has internal parameter dependencies""")
15551555
else if ((tree.tpt `eq` untpd.ContextualEmptyTree) && mt.paramNames.isEmpty)
15561556
// Note implicitness of function in target type since there are no method parameters that indicate it.
1557-
TypeTree(defn.FunctionOf(Nil, mt.resType, isContextual = true, isErased = false))
1557+
TypeTree(defn.FunctionOf(Nil, mt.resType, isContextual = true, isErased = List()))
15581558
else
15591559
EmptyTree
15601560
}

0 commit comments

Comments
 (0)