Skip to content

[Proof of concept] Polymorphic function types #4672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
May 30, 2019
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class Compiler {
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses
new TailRec, // Rewrite tail recursion to loops
new Mixin, // Expand trait fields and trait initializers
new LazyVals, // Expand lazy vals
Expand Down
35 changes: 35 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,39 @@ object desugar {
}
}

def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match {
case Function(vargs, res) =>
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
val mods = body match {
case body: FunctionWithMods => body.mods
case _ => untpd.EmptyModifiers
}
val polyFunctionTpt = ref(defn.PolyFunctionType)
val applyTParams = targs.asInstanceOf[List[TypeDef]]
if (ctx.mode.is(Mode.Type)) {
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }

val applyVParams = vargs.zipWithIndex.map { case (p, n) =>
makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags)
}
RefinedTypeTree(polyFunctionTpt, List(
DefDef(nme.apply, applyTParams, List(applyVParams), res, EmptyTree)
))
} else {
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }

val applyVParams = vargs.asInstanceOf[List[ValDef]]
.map(varg => varg.withAddedFlags(mods.flags | Param))
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
List(DefDef(nme.apply, applyTParams, List(applyVParams), TypeTree(), res))
))
}
case _ =>
EmptyTree // may happen for erroneous input. An error will already have been reported.
}

// begin desugar

// Special case for `Parens` desugaring: unlike all the desugarings below,
Expand All @@ -1430,6 +1463,8 @@ object desugar {
}

val desugared = tree match {
case PolyFunction(targs, body) =>
makePolyFunction(targs, body) orElse tree
case SymbolLit(str) =>
Literal(Constant(scala.Symbol(str)))
case InterpolatedString(id, segments) =>
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ object Trees {
}

def withFlags(flags: FlagSet): ThisTree[Untyped] = withMods(untpd.Modifiers(flags))
def withAddedFlags(flags: FlagSet): ThisTree[Untyped] = withMods(rawMods | flags)

def setComment(comment: Option[Comment]): this.type = {
comment.map(putAttachment(DocComment, _))
Expand Down
14 changes: 14 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile)
extends Function(args, body)

/** A polymorphic function type */
case class PolyFunction(targs: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree {
override def isTerm = body.isTerm
override def isType = body.isType
}

/** A function created from a wildcard expression
* @param placeholderParams a list of definitions of synthetic parameters.
* @param body the function body where wildcards are replaced by
Expand Down Expand Up @@ -491,6 +497,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case tree: Function if (args eq tree.args) && (body eq tree.body) => tree
case _ => finalize(tree, untpd.Function(args, body)(tree.source))
}
def PolyFunction(tree: Tree)(targs: List[Tree], body: Tree)(implicit ctx: Context): Tree = tree match {
case tree: PolyFunction if (targs eq tree.targs) && (body eq tree.body) => tree
case _ => finalize(tree, untpd.PolyFunction(targs, body)(tree.source))
}
def InfixOp(tree: Tree)(left: Tree, op: Ident, right: Tree)(implicit ctx: Context): Tree = tree match {
case tree: InfixOp if (left eq tree.left) && (op eq tree.op) && (right eq tree.right) => tree
case _ => finalize(tree, untpd.InfixOp(left, op, right)(tree.source))
Expand Down Expand Up @@ -579,6 +589,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
cpy.InterpolatedString(tree)(id, segments.mapConserve(transform))
case Function(args, body) =>
cpy.Function(tree)(transform(args), transform(body))
case PolyFunction(targs, body) =>
cpy.PolyFunction(tree)(transform(targs), transform(body))
case InfixOp(left, op, right) =>
cpy.InfixOp(tree)(transform(left), op, transform(right))
case PostfixOp(od, op) =>
Expand Down Expand Up @@ -634,6 +646,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
this(x, segments)
case Function(args, body) =>
this(this(x, args), body)
case PolyFunction(targs, body) =>
this(this(x, targs), body)
case InfixOp(left, op, right) =>
this(this(this(x, left), op), right)
case PostfixOp(od, op) =>
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,9 @@ class Definitions {
if (n <= MaxImplementedFunctionArity && (!isContextual || ctx.erasedTypes) && !isErased) ImplementedFunctionType(n)
else FunctionClass(n, isContextual, isErased).typeRef

lazy val PolyFunctionClass = ctx.requiredClass("scala.PolyFunction")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should follow the principle used elsewhere: The TypeRef is computed in the lazy val and the context-dependent symbol follows. This is to make sure that the system keeps functioning if Definition classes are edited and recompiled. If you deviate from this, you create confusion for others.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking nearby, the pattern seems to be the same as here: the symbol is defined as a lazy val and the typeref as a def. Can you point me to an example which is arranged the way you want it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g.

  lazy val StringBuilderType: TypeRef      = ctx.requiredClassRef("scala.collection.mutable.StringBuilder")
  def StringBuilderClass(implicit ctx: Context): ClassSymbol = StringBuilderType.symbol.asClass

But I meant to go over Definitions anyway, trying to avoid the duplication and make it safe by design. The problem with the lazy val pattern as you wrote it is that it would not work in interactive mode if PolyFunction was edited. Then
the system would hang on to the first version computed instead of the edited ones. I agree that's a rather esoteric use case. So we can leave it for now.

def PolyFunctionType = PolyFunctionClass.typeRef

/** If `cls` is a class in the scala package, its name, otherwise EmptyTypeName */
def scalaClassName(cls: Symbol)(implicit ctx: Context): TypeName =
if (cls.isClass && cls.owner == ScalaPackageClass) cls.asClass.name else EmptyTypeName
Expand Down
26 changes: 25 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,27 @@ object TypeErasure {
MethodType(Nil, defn.BoxedUnitType)
else if (sym.isAnonymousFunction && einfo.paramInfos.length > MaxImplementedFunctionArity)
MethodType(nme.ALLARGS :: Nil, JavaArrayType(defn.ObjectType) :: Nil, einfo.resultType)
else if (sym.name == nme.apply && sym.owner.derivesFrom(defn.PolyFunctionClass)) {
// The erasure of `apply` in subclasses of PolyFunction has to match
// the erasure of FunctionN#apply, since after `ElimPolyFunction` we replace
// a `PolyFunction` parent by a `FunctionN` parent.
einfo.derivedLambdaType(
paramInfos = einfo.paramInfos.map(_ => defn.ObjectType),
resType = defn.ObjectType
)
}
else
einfo
case einfo =>
einfo
// Erase the parameters of `apply` in subclasses of PolyFunction
// Preserve PolyFunction argument types to support PolyFunctions with
// PolyFunction arguments
if (sym.is(TermParam) && sym.owner.name == nme.apply
&& sym.owner.owner.derivesFrom(defn.PolyFunctionClass)
&& !(tp <:< defn.PolyFunctionType)) {
defn.ObjectType
} else
einfo
}
}

Expand Down Expand Up @@ -383,6 +400,7 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean
* - otherwise, if T is a type parameter coming from Java, []Object
* - otherwise, Object
* - For a term ref p.x, the type <noprefix> # x.
* - For a refined type scala.PolyFunction { def apply[...](x_1, ..., x_N): R }, scala.FunctionN
* - For a typeref scala.Any, scala.AnyVal, scala.Singleton, scala.Tuple, or scala.*: : |java.lang.Object|
* - For a typeref scala.Unit, |scala.runtime.BoxedUnit|.
* - For a typeref scala.FunctionN, where N > MaxImplementedFunctionArity, scala.FunctionXXL
Expand Down Expand Up @@ -429,6 +447,12 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean
SuperType(this(thistpe), this(supertpe))
case ExprType(rt) =>
defn.FunctionType(0)
case RefinedType(parent, nme.apply, refinedInfo) if parent.typeSymbol eq defn.PolyFunctionClass =>
assert(refinedInfo.isInstanceOf[PolyType])
val res = refinedInfo.resultType
val paramss = res.paramNamess
assert(paramss.length == 1)
this(defn.FunctionType(paramss.head.length, isContextual = res.isImplicitMethod, isErased = res.isErasedMethod))
case tp: TypeProxy =>
this(tp.underlying)
case AndType(tp1, tp2) =>
Expand Down
11 changes: 10 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3194,8 +3194,17 @@ object Types {
companion.eq(ContextualMethodType) ||
companion.eq(ErasedContextualMethodType)


def computeSignature(implicit ctx: Context): Signature = {
val params = if (isErasedMethod) Nil else paramInfos
def polyFunctionSignature(tp: Type): Type = tp match {
case RefinedType(parent, nme.apply, refinedInfo) if parent.typeSymbol eq defn.PolyFunctionClass =>
val res = refinedInfo.resultType
val paramss = res.paramNamess
defn.FunctionType(paramss.head.length)
case _ => tp
}

val params = if (isErasedMethod) Nil else paramInfos.mapConserve(polyFunctionSignature)
resultSignature.prepend(params, isJavaMethod)
}

Expand Down
35 changes: 32 additions & 3 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -850,9 +850,13 @@ object Parsers {
*/
def toplevelTyp(): Tree = rejectWildcardType(typ())

/** Type ::= FunTypeMods FunArgTypes `=>' Type
* | HkTypeParamClause `=>>' Type
/** Type ::= FunType
* | HkTypeParamClause ‘=>>’ Type
* | MatchType
* | InfixType
* FunType ::= { 'erased' | 'given' } (MonoFunType | PolyFunType)
* MonoFunType ::= FunArgTypes ‘=>’ Type
* PolyFunType ::= HKTypeParamClause '=>' Type
* FunArgTypes ::= InfixType
* | `(' [ FunArgType {`,' FunArgType } ] `)'
* | '(' TypedFunParam {',' TypedFunParam } ')'
Expand Down Expand Up @@ -924,7 +928,18 @@ object Parsers {
val tparams = typeParamClause(ParamOwner.TypeParam)
if (in.token == TLARROW)
atSpan(start, in.skipToken())(LambdaTypeTree(tparams, toplevelTyp()))
else { accept(TLARROW); typ() }
else if (in.token == ARROW) {
val arrowOffset = in.skipToken()
val body = toplevelTyp()
atSpan(start, arrowOffset) {
body match {
case _: Function => PolyFunction(tparams, body)
case _ =>
syntaxError("Implementation restriction: polymorphic function types must have a value parameter", arrowOffset)
Ident(nme.ERROR.toTypeName)
}
}
} else { accept(TLARROW); typ() }
}
else infixType()

Expand Down Expand Up @@ -1223,6 +1238,7 @@ object Parsers {
* | `throw' Expr
* | `return' [Expr]
* | ForExpr
* | HkTypeParamClause ‘=>’ Expr
* | [SimpleExpr `.'] id `=' Expr
* | SimpleExpr1 ArgumentExprs `=' Expr
* | Expr2
Expand Down Expand Up @@ -1323,6 +1339,19 @@ object Parsers {
atSpan(in.skipToken()) { Return(if (isExprIntro) expr() else EmptyTree, EmptyTree) }
case FOR =>
forExpr()
case LBRACKET =>
val start = in.offset
val tparams = typeParamClause(ParamOwner.TypeParam)
val arrowOffset = accept(ARROW)
val body = expr()
atSpan(start, arrowOffset) {
body match {
case _: Function => PolyFunction(tparams, body)
case _ =>
syntaxError("Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset)
errorTermTree
}
}
case _ =>
if (isIdent(nme.inline) && !in.inModifierPosition() && in.lookaheadIn(canStartExpressionTokens)) {
val start = in.skipToken()
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
(keywordText("erased ") provided isErased) ~
argsText ~ " => " ~ toText(body)
}
case PolyFunction(targs, body) =>
val targsText = "[" ~ Text(targs.map((arg: Tree) => toText(arg)), ", ") ~ "]"
changePrec(GlobalPrec) {
targsText ~ " => " ~ toText(body)
}
case InfixOp(l, op, r) =>
val opPrec = parsing.precedence(op.name)
changePrec(opPrec) { toText(l) ~ " " ~ toText(op) ~ " " ~ toText(r) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,19 @@ class ElimErasedValueType extends MiniPhase with InfoTransformer { thisPhase =>
override def matches(sym1: Symbol, sym2: Symbol) =
sym1.signature == sym2.signature
}

def checkNoConflict(sym1: Symbol, sym2: Symbol, info: Type)(implicit ctx: Context): Unit = {
val site = root.thisType
val info1 = site.memberInfo(sym1)
val info2 = site.memberInfo(sym2)
if (!info1.matchesLoosely(info2))
// PolyFunction apply methods will be eliminated later during
// ElimPolyFunction, so we let them pass here.
def bothPolyApply =
sym1.name == nme.apply &&
(sym1.owner.derivesFrom(defn.PolyFunctionClass) ||
sym2.owner.derivesFrom(defn.PolyFunctionClass))

if (!info1.matchesLoosely(info2) && !bothPolyApply)
ctx.error(DoubleDefinition(sym1, sym2, root), root.sourcePos)
}
val earlyCtx = ctx.withPhase(ctx.elimRepeatedPhase.next)
Expand Down
68 changes: 68 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/ElimPolyFunction.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package dotty.tools.dotc
package transform

import ast.{Trees, tpd}
import core._, core.Decorators._
import MegaPhase._, Phases.Phase
import Types._, Contexts._, Constants._, Names._, NameOps._, Flags._, DenotTransformers._
import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._, Scopes._, Denotations._
import TypeErasure.ErasedValueType, ValueClasses._

/** This phase rewrite PolyFunction subclasses to FunctionN subclasses
*
* class Foo extends PolyFunction {
* def apply(x_1: P_1, ..., x_N: P_N): R = rhs
* }
* becomes:
* class Foo extends FunctionN {
* def apply(x_1: P_1, ..., x_N: P_N): R = rhs
* }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should allow writing such a class Foo, rather than just allowing closures — we don't for implicit function types and nobody seems to mind the restriction, and this restriction enables transformations such as ShortcutImplicits. https://github.com/lampepfl/dotty/pull/1775/files#diff-71350811180f41d868e7fb3258fd774dR18

Copy link
Contributor

@LPTK LPTK Jun 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that it can be really useful being able to extend/implement a polymorphic function type; it's one of the use cases I mention in https://github.com/lampepfl/dotty/issues/4670#issuecomment-397819801 – making polymorphic type case classes extend the corresponding polymorphic function type.

EDIT – typo: "type class" -> "case class"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right, tho s/type classes/case classes/. But if we follow the approach for implicit function types, val b: Int => B[Int] = B could still work by eta-expansion. In fact, it's not clear why eta-expansion doesn't handle that case today by turning B into B _ or B.apply (and I'm not going to try which ones do work, I'd just ask they all work unless backward compatibility gets in the way).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Blaisorblade AFAIK eta expansion never inserts .apply calls on things that are not functions. It's true that adding this behavior could be an alternative solution to the stated polymorphic case class problem.

*/
class ElimPolyFunction extends MiniPhase with DenotTransformer {

import tpd._

override def phaseName: String = ElimPolyFunction.name

override def runsAfter = Set(Erasure.name)

override def changesParents: Boolean = true // Replaces PolyFunction by FunctionN

override def transform(ref: SingleDenotation)(implicit ctx: Context) = ref match {
case ref: ClassDenotation if ref.symbol != defn.PolyFunctionClass && ref.derivesFrom(defn.PolyFunctionClass) =>
val cinfo = ref.classInfo
val newParent = functionTypeOfPoly(cinfo)
val newParents = cinfo.classParents.map(parent =>
if (parent.typeSymbol == defn.PolyFunctionClass)
newParent
else
parent
)
ref.copySymDenotation(info = cinfo.derivedClassInfo(classParents = newParents))
case _ =>
ref
}

def functionTypeOfPoly(cinfo: ClassInfo)(implicit ctx: Context): Type = {
val applyMeth = cinfo.decls.lookup(nme.apply).info
val arity = applyMeth.paramNamess.head.length
defn.FunctionType(arity)
}

override def transformTemplate(tree: Template)(implicit ctx: Context): Tree = {
val newParents = tree.parents.mapconserve(parent =>
if (parent.tpe.typeSymbol == defn.PolyFunctionClass) {
val cinfo = tree.symbol.owner.asClass.classInfo
tpd.TypeTree(functionTypeOfPoly(cinfo))
}
else
parent
)
cpy.Template(tree)(parents = newParents)
}
}

object ElimPolyFunction {
val name = "elimPolyFunction"
}

19 changes: 15 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -415,22 +415,33 @@ object Erasure {
* e.m -> e.[]m if `m` is an array operation other than `clone`.
*/
override def typedSelect(tree: untpd.Select, pt: Type)(implicit ctx: Context): Tree = {
val qual1 = typed(tree.qualifier, AnySelectionProto)

def mapOwner(sym: Symbol): Symbol = {
def recur(owner: Symbol): Symbol =
// PolyFunction apply Selects will not have a symbol, so deduce the owner
// from the typed qual.
def polyOwner: Symbol =
if (sym.exists || tree.name != nme.apply) NoSymbol
else {
val owner = qual1.tpe.widen.typeSymbol
if (defn.isFunctionClass(owner)) owner else NoSymbol
}

polyOwner orElse {
val owner = sym.owner
if (defn.specialErasure.contains(owner)) {
assert(sym.isConstructor, s"${sym.showLocated}")
defn.specialErasure(owner)
} else if (defn.isSyntheticFunctionClass(owner))
defn.erasedFunctionClass(owner)
else
owner
recur(sym.owner)
}
}

val origSym = tree.symbol
val owner = mapOwner(origSym)
val sym = if (owner eq origSym.owner) origSym else owner.info.decl(origSym.name).symbol
val sym = if (owner eq origSym.maybeOwner) origSym else owner.info.decl(tree.name).symbol
assert(sym.exists, origSym.showLocated)

def select(qual: Tree, sym: Symbol): Tree =
Expand Down Expand Up @@ -474,7 +485,7 @@ object Erasure {
}
}

checkNotErased(recur(typed(tree.qualifier, AnySelectionProto)))
checkNotErased(recur(qual1))
}

override def typedThis(tree: untpd.This)(implicit ctx: Context): Tree =
Expand Down
Loading