Skip to content

Commit 7a683ab

Browse files
authored
Make polymorphic functions more efficient and expressive (#17548)
This PR enhances polymorphic function types in two ways: - Dependent result types can now be inferred from the expected type - polymorphic lambdas are now implemented using JVM lambdas when possible instead of anonymous classes. Additionally, we fix the logic for renaming bound variables when pretty-printing lambdas and fix the handling of `this` in refinements.
2 parents a8e9312 + d7a345f commit 7a683ab

25 files changed

+248
-116
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 42 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,40 @@ object desugar {
10611061
name
10621062
}
10631063

1064+
/** Strip parens and empty blocks around the body of `tree`. */
1065+
def normalizePolyFunction(tree: PolyFunction)(using Context): PolyFunction =
1066+
def stripped(body: Tree): Tree = body match
1067+
case Parens(body1) =>
1068+
stripped(body1)
1069+
case Block(Nil, body1) =>
1070+
stripped(body1)
1071+
case _ => body
1072+
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]
1073+
1074+
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
1075+
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1076+
*/
1077+
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
1078+
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1079+
val funFlags = fun match
1080+
case fun: FunctionWithMods =>
1081+
fun.mods.flags
1082+
case _ => EmptyFlags
1083+
1084+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1085+
// val isImpure = funFlags.is(Impure)
1086+
1087+
// Function flags to be propagated to each parameter in the desugared method type.
1088+
val paramFlags = funFlags.toTermFlags & Given
1089+
val vparams = vparamTypes.zipWithIndex.map:
1090+
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
1091+
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1092+
1093+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1094+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
1095+
)).withSpan(tree.span)
1096+
end makePolyFunctionType
1097+
10641098
/** Invent a name for an anonympus given of type or template `impl`. */
10651099
def inventGivenOrExtensionName(impl: Tree)(using Context): SimpleName =
10661100
val str = impl match
@@ -1454,17 +1488,20 @@ object desugar {
14541488
}
14551489

14561490
/** Make closure corresponding to function.
1457-
* params => body
1491+
* [tparams] => params => body
14581492
* ==>
1459-
* def $anonfun(params) = body
1493+
* def $anonfun[tparams](params) = body
14601494
* Closure($anonfun)
14611495
*/
1462-
def makeClosure(params: List[ValDef], body: Tree, tpt: Tree | Null = null, isContextual: Boolean, span: Span)(using Context): Block =
1496+
def makeClosure(tparams: List[TypeDef], vparams: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
1497+
val paramss: List[ParamClause] =
1498+
if tparams.isEmpty then vparams :: Nil
1499+
else tparams :: vparams :: Nil
14631500
Block(
1464-
DefDef(nme.ANON_FUN, params :: Nil, if (tpt == null) TypeTree() else tpt, body)
1501+
DefDef(nme.ANON_FUN, paramss, if (tpt == null) TypeTree() else tpt, body)
14651502
.withSpan(span)
14661503
.withMods(synthetic | Artifact),
1467-
Closure(Nil, Ident(nme.ANON_FUN), if (isContextual) ContextualEmptyTree else EmptyTree))
1504+
Closure(Nil, Ident(nme.ANON_FUN), EmptyTree))
14681505

14691506
/** If `nparams` == 1, expand partial function
14701507
*
@@ -1753,62 +1790,6 @@ object desugar {
17531790
}
17541791
}
17551792

1756-
def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match {
1757-
case Parens(body1) =>
1758-
makePolyFunction(targs, body1, pt)
1759-
case Block(Nil, body1) =>
1760-
makePolyFunction(targs, body1, pt)
1761-
case Function(vargs, res) =>
1762-
assert(targs.nonEmpty)
1763-
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
1764-
val mods = body match {
1765-
case body: FunctionWithMods => body.mods
1766-
case _ => untpd.EmptyModifiers
1767-
}
1768-
val polyFunctionTpt = ref(defn.PolyFunctionType)
1769-
val applyTParams = targs.asInstanceOf[List[TypeDef]]
1770-
if (ctx.mode.is(Mode.Type)) {
1771-
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1772-
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1773-
1774-
val applyVParams = vargs.zipWithIndex.map {
1775-
case (p: ValDef, _) => p.withAddedFlags(mods.flags)
1776-
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags.toTermFlags)
1777-
}
1778-
RefinedTypeTree(polyFunctionTpt, List(
1779-
DefDef(nme.apply, applyTParams :: applyVParams :: Nil, res, EmptyTree).withFlags(Synthetic)
1780-
))
1781-
}
1782-
else {
1783-
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1784-
// with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
1785-
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
1786-
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
1787-
1788-
def typeTree(tp: Type) = tp match
1789-
case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
1790-
var bail = false
1791-
def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match
1792-
case tp: TypeRef => ref(tp)
1793-
case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name)
1794-
case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_)))
1795-
case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree }
1796-
val mapped = mapper(mt.resultType, topLevel = true)
1797-
if bail then TypeTree() else mapped
1798-
case _ => TypeTree()
1799-
1800-
val applyVParams = vargs.asInstanceOf[List[ValDef]]
1801-
.map(varg => varg.withAddedFlags(mods.flags | Param))
1802-
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
1803-
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res))
1804-
))
1805-
}
1806-
case _ =>
1807-
// may happen for erroneous input. An error will already have been reported.
1808-
assert(ctx.reporter.errorsReported)
1809-
EmptyTree
1810-
}
1811-
18121793
// begin desugar
18131794

18141795
// Special case for `Parens` desugaring: unlike all the desugarings below,
@@ -1821,8 +1802,6 @@ object desugar {
18211802
}
18221803

18231804
val desugared = tree match {
1824-
case PolyFunction(targs, body) =>
1825-
makePolyFunction(targs, body, pt) orElse tree
18261805
case SymbolLit(str) =>
18271806
Apply(
18281807
ref(defn.ScalaSymbolClass.companionModule.termRef),

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,7 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
420420
case Closure(_, meth, _) => true
421421
case Block(Nil, expr) => isContextualClosure(expr)
422422
case Block(DefDef(nme.ANON_FUN, params :: _, _, _) :: Nil, cl: Closure) =>
423-
if params.isEmpty then
424-
cl.tpt.eq(untpd.ContextualEmptyTree) || defn.isContextFunctionType(cl.tpt.typeOpt)
425-
else
426-
isUsingClause(params)
423+
isUsingClause(params)
427424
case _ => false
428425
}
429426

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,6 @@ object Trees {
11921192

11931193
@sharable val EmptyTree: Thicket = genericEmptyTree
11941194
@sharable val EmptyValDef: ValDef = genericEmptyValDef
1195-
@sharable val ContextualEmptyTree: Thicket = new EmptyTree() // an empty tree marking a contextual closure
11961195

11971196
// ----- Auxiliary creation methods ------------------
11981197

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
151151
case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree
152152

153153
/** Short-lived usage in typer, does not need copy/transform/fold infrastructure */
154-
case class DependentTypeTree(tp: List[Symbol] => Type)(implicit @constructorOnly src: SourceFile) extends Tree
154+
case class DependentTypeTree(tp: (List[TypeSymbol], List[TermSymbol]) => Type)(implicit @constructorOnly src: SourceFile) extends Tree
155155

156156
@sharable object EmptyTypeIdent extends Ident(tpnme.EMPTY)(NoSource) with WithoutTypeOrPos[Untyped] {
157157
override def isEmpty: Boolean = true

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,12 @@ object NameOps {
236236
*/
237237
def isPlainFunction(using Context): Boolean = functionArity >= 0
238238

239-
/** Is a function name that contains `mustHave` as a substring */
240-
private def isSpecificFunction(mustHave: String)(using Context): Boolean =
239+
/** Is a function name that contains `mustHave` as a substring
240+
* and has arity `minArity` or greater.
241+
*/
242+
private def isSpecificFunction(mustHave: String, minArity: Int = 0)(using Context): Boolean =
241243
val suffixStart = functionSuffixStart
242-
isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= 0
244+
isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= minArity
243245

244246
def isContextFunction(using Context): Boolean = isSpecificFunction("Context")
245247
def isImpureFunction(using Context): Boolean = isSpecificFunction("Impure")

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,8 @@ object Types {
18721872
if alwaysDependent || mt.isResultDependent then
18731873
RefinedType(funType, nme.apply, mt)
18741874
else funType
1875+
case poly @ PolyType(_, mt: MethodType) if !mt.isParamDependent =>
1876+
RefinedType(defn.PolyFunctionType, nme.apply, poly)
18751877
}
18761878

18771879
/** The signature of this type. This is by default NotAMethod,

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,6 +1511,7 @@ object Parsers {
15111511
TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], resultType)
15121512
else if imods.isOneOf(Given | Impure) || erasedArgs.contains(true) then
15131513
if imods.is(Given) && params.isEmpty then
1514+
imods &~= Given
15141515
syntaxError(em"context function types require at least one parameter", paramSpan)
15151516
FunctionWithMods(params, resultType, imods, erasedArgs.toList)
15161517
else if !ctx.settings.YkindProjector.isDefault then

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,9 @@ class PlainPrinter(_ctx: Context) extends Printer {
297297

298298
protected def paramsText(lam: LambdaType): Text = {
299299
val erasedParams = lam.erasedParams
300-
def paramText(name: Name, tp: Type, erased: Boolean) =
301-
keywordText("erased ").provided(erased) ~ toText(name) ~ lambdaHash(lam) ~ toTextRHS(tp, isParameter = true)
302-
Text(lam.paramNames.lazyZip(lam.paramInfos).lazyZip(erasedParams).map(paramText), ", ")
300+
def paramText(ref: ParamRef, erased: Boolean) =
301+
keywordText("erased ").provided(erased) ~ ParamRefNameString(ref) ~ lambdaHash(lam) ~ toTextRHS(ref.underlying, isParameter = true)
302+
Text(lam.paramRefs.lazyZip(erasedParams).map(paramText), ", ")
303303
}
304304

305305
protected def ParamRefNameString(name: Name): String = nameString(name)
@@ -363,7 +363,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
363363
case tp @ ConstantType(value) =>
364364
toText(value)
365365
case pref: TermParamRef =>
366-
nameString(pref.binder.paramNames(pref.paramNum)) ~ lambdaHash(pref.binder)
366+
ParamRefNameString(pref) ~ lambdaHash(pref.binder)
367367
case tp: RecThis =>
368368
val idx = openRecs.reverse.indexOf(tp.binder)
369369
if (idx >= 0) selfRecName(idx + 1)

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
174174
~ " " ~ argText(args.last)
175175
}
176176

177-
private def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text = Str("")): Text = info match
177+
protected def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text = Str("")): Text = info match
178178
case info: MethodType =>
179179
val capturesRoot = refs == rootSetText
180180
changePrec(GlobalPrec) {

compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
196196
case AmbiguousExtensionMethodID // errorNumber 180
197197
case UnqualifiedCallToAnyRefMethodID // errorNumber: 181
198198
case NotConstantID // errorNumber: 182
199+
case ClosureCannotHaveInternalParameterDependenciesID // errorNumber: 183
199200

200201
def errorNumber = ordinal - 1
201202

compiler/src/dotty/tools/dotc/reporting/Message.scala

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ object Message:
5151
*/
5252
private class Seen(disambiguate: Boolean):
5353

54+
/** The set of lambdas that were opened at some point during printing. */
55+
private val openedLambdas = new collection.mutable.HashSet[LambdaType]
56+
57+
/** Register that `tp` was opened during printing. */
58+
def openLambda(tp: LambdaType): Unit =
59+
openedLambdas += tp
60+
5461
val seen = new collection.mutable.HashMap[SeenKey, List[Recorded]]:
5562
override def default(key: SeenKey) = Nil
5663

@@ -89,8 +96,22 @@ object Message:
8996
val existing = seen(key)
9097
lazy val dealiased = followAlias(entry)
9198

92-
// alts: The alternatives in `existing` that are equal, or follow (an alias of) `entry`
93-
var alts = existing.dropWhile(alt => dealiased ne followAlias(alt))
99+
/** All lambda parameters with the same name are given the same superscript as
100+
* long as their corresponding binder has been printed.
101+
* See tests/neg/lambda-rename.scala for test cases.
102+
*/
103+
def sameSuperscript(cur: Recorded, existing: Recorded) =
104+
(cur eq existing) ||
105+
(cur, existing).match
106+
case (cur: ParamRef, existing: ParamRef) =>
107+
(cur.paramName eq existing.paramName) &&
108+
openedLambdas.contains(cur.binder) &&
109+
openedLambdas.contains(existing.binder)
110+
case _ =>
111+
false
112+
113+
// The length of alts corresponds to the number of superscripts we need to print.
114+
var alts = existing.dropWhile(alt => !sameSuperscript(dealiased, followAlias(alt)))
94115
if alts.isEmpty then
95116
alts = entry :: existing
96117
seen(key) = alts
@@ -208,10 +229,20 @@ object Message:
208229
case tp: SkolemType => seen.record(tp.repr.toString, isType = true, tp)
209230
case _ => super.toTextRef(tp)
210231

232+
override def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text): Text =
233+
info match
234+
case info: LambdaType =>
235+
seen.openLambda(info)
236+
case _ =>
237+
super.toTextMethodAsFunction(info, isPure, refs)
238+
211239
override def toText(tp: Type): Text =
212240
if !tp.exists || tp.isErroneous then seen.nonSensical = true
213241
tp match
214242
case tp: TypeRef if useSourceModule(tp.symbol) => Str("object ") ~ super.toText(tp)
243+
case tp: LambdaType =>
244+
seen.openLambda(tp)
245+
super.toText(tp)
215246
case _ => super.toText(tp)
216247

217248
override def toText(sym: Symbol): Text =

compiler/src/dotty/tools/dotc/reporting/messages.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2920,3 +2920,10 @@ class MatchTypeScrutineeCannotBeHigherKinded(tp: Type)(using Context)
29202920
extends TypeMsg(MatchTypeScrutineeCannotBeHigherKindedID) :
29212921
def msg(using Context) = i"the scrutinee of a match type cannot be higher-kinded"
29222922
def explain(using Context) = ""
2923+
2924+
class ClosureCannotHaveInternalParameterDependencies(mt: Type)(using Context)
2925+
extends TypeMsg(ClosureCannotHaveInternalParameterDependenciesID):
2926+
def msg(using Context) =
2927+
i"""cannot turn method type $mt into closure
2928+
|because it has internal parameter dependencies"""
2929+
def explain(using Context) = ""

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ object Checking {
412412
case tree: RefTree =>
413413
checkRef(tree, tree.symbol)
414414
foldOver(x, tree)
415-
case tree: This =>
415+
case tree: This if tree.tpe.classSymbol == refineCls =>
416416
selfRef(tree)
417417
case tree: TypeTree =>
418418
val checkType = new TypeAccumulator[Unit] {

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1692,15 +1692,17 @@ class Namer { typer: Typer =>
16921692
def valOrDefDefSig(mdef: ValOrDefDef, sym: Symbol, paramss: List[List[Symbol]], paramFn: Type => Type)(using Context): Type = {
16931693

16941694
def inferredType = inferredResultType(mdef, sym, paramss, paramFn, WildcardType)
1695-
lazy val termParamss = paramss.collect { case TermSymbols(vparams) => vparams }
16961695

16971696
val tptProto = mdef.tpt match {
16981697
case _: untpd.DerivedTypeTree =>
16991698
WildcardType
17001699
case TypeTree() =>
17011700
checkMembersOK(inferredType, mdef.srcPos)
17021701
case DependentTypeTree(tpFun) =>
1703-
val tpe = tpFun(termParamss.head)
1702+
// A lambda has at most one type parameter list followed by exactly one term parameter list.
1703+
val tpe = (paramss: @unchecked) match
1704+
case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams)
1705+
case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams)
17041706
if (isFullyDefined(tpe, ForceDegree.none)) tpe
17051707
else typedAheadExpr(mdef.rhs, tpe).tpe
17061708
case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
@@ -1724,7 +1726,8 @@ class Namer { typer: Typer =>
17241726
// So fixing levels at instantiation avoids the soundness problem but apparently leads
17251727
// to type inference problems since it comes too late.
17261728
if !Config.checkLevelsOnConstraints then
1727-
val hygienicType = TypeOps.avoid(rhsType, termParamss.flatten)
1729+
val termParams = paramss.collect { case TermSymbols(vparams) => vparams }.flatten
1730+
val hygienicType = TypeOps.avoid(rhsType, termParams)
17281731
if (!hygienicType.isValueType || !(hygienicType <:< tpt.tpe))
17291732
report.error(
17301733
em"""return type ${tpt.tpe} of lambda cannot be made hygienic

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ trait TypeAssigner {
2222
*/
2323
def qualifyingClass(tree: untpd.Tree, qual: Name, packageOK: Boolean)(using Context): Symbol = {
2424
def qualifies(sym: Symbol) =
25-
sym.isClass && (
25+
sym.isClass &&
26+
// `this` in a polymorphic function type never refers to the desugared refinement.
27+
// In other refinements, `this` does refer to the refinement but is deprecated
28+
// (see `Checking#checkRefinementNonCyclic`).
29+
!(sym.isRefinementClass && sym.derivesFrom(defn.PolyFunctionClass)) && (
2630
qual.isEmpty ||
2731
sym.name == qual ||
2832
sym.is(Module) && sym.name.stripModuleClassSuffix == qual)

0 commit comments

Comments
 (0)