Skip to content

Replace is{Poly|Erased}FunctionType with {PolyOrErased,Poly,Erased}FunctionOf #18207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,8 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
def isStructuralTermSelect(tree: Select) =
def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
case defn.PolyOrErasedFunctionOf(_) =>
false
case RefinedType(parent, rname, rinfo) =>
rname == tree.name || hasRefinement(parent)
case tp: TypeProxy =>
Expand All @@ -966,10 +968,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
false
!tree.symbol.exists
&& tree.isTerm
&& {
val qualType = tree.qualifier.tpe
hasRefinement(qualType) && !defn.isPolyOrErasedFunctionType(qualType)
}
&& hasRefinement(tree.qualifier.tpe)
def loop(tree: Tree): Boolean = tree match
case TypeApply(fun, _) =>
loop(fun)
Expand Down
60 changes: 43 additions & 17 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,7 @@ class Definitions {
FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil)
def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean)] = {
ft.dealias match
case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) =>
case ErasedFunctionOf(mt) =>
Some(mt.paramInfos, mt.resType, mt.isContextualMethod)
case _ =>
val tsym = ft.dealias.typeSymbol
Expand All @@ -1126,6 +1126,42 @@ class Definitions {
}
}

object PolyOrErasedFunctionOf {
/** Matches a refined `PolyFunction` or `ErasedFunction` type and extracts the apply info.
*
* Pattern: `(PolyFunction | ErasedFunction) { def apply: $mt }`
*/
def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match
case RefinedType(parent, nme.apply, mt: MethodicType)
if parent.derivesFrom(defn.PolyFunctionClass) || parent.derivesFrom(defn.ErasedFunctionClass) =>
Some(mt)
case _ => None
}

object PolyFunctionOf {
/** Matches a refined `PolyFunction` type and extracts the apply info.
*
* Pattern: `PolyFunction { def apply: $pt }`
*/
def unapply(ft: Type)(using Context): Option[PolyType] = ft.dealias match
case RefinedType(parent, nme.apply, pt: PolyType)
if parent.derivesFrom(defn.PolyFunctionClass) =>
Some(pt)
case _ => None
}

object ErasedFunctionOf {
/** Matches a refined `ErasedFunction` type and extracts the apply info.
*
* Pattern: `ErasedFunction { def apply: $mt }`
*/
def unapply(ft: Type)(using Context): Option[MethodType] = ft.dealias match
case RefinedType(parent, nme.apply, mt: MethodType)
if parent.derivesFrom(defn.ErasedFunctionClass) =>
Some(mt)
case _ => None
}

object PartialFunctionOf {
def apply(arg: Type, result: Type)(using Context): Type =
PartialFunctionClass.typeRef.appliedTo(arg :: result :: Nil)
Expand Down Expand Up @@ -1713,26 +1749,16 @@ class Definitions {
def isFunctionNType(tp: Type)(using Context): Boolean =
isNonRefinedFunction(tp.dropDependentRefinement)

/** Does `tp` derive from `PolyFunction` or `ErasedFunction`? */
def isPolyOrErasedFunctionType(tp: Type)(using Context): Boolean =
isPolyFunctionType(tp) || isErasedFunctionType(tp)

/** Does `tp` derive from `PolyFunction`? */
def isPolyFunctionType(tp: Type)(using Context): Boolean =
tp.derivesFrom(defn.PolyFunctionClass)

/** Does `tp` derive from `ErasedFunction`? */
def isErasedFunctionType(tp: Type)(using Context): Boolean =
tp.derivesFrom(defn.ErasedFunctionClass)

/** Returns whether `tp` is an instance or a refined instance of:
* - scala.FunctionN
* - scala.ContextFunctionN
* - ErasedFunction
* - PolyFunction
*/
def isFunctionType(tp: Type)(using Context): Boolean =
isFunctionNType(tp) || isPolyOrErasedFunctionType(tp)
isFunctionNType(tp)
|| tp.derivesFrom(defn.PolyFunctionClass) // TODO check for refinement?
|| tp.derivesFrom(defn.ErasedFunctionClass) // TODO check for refinement?

private def withSpecMethods(cls: ClassSymbol, bases: List[Name], paramTypes: Set[TypeRef]) =
if !ctx.settings.Yscala2Stdlib.value then
Expand Down Expand Up @@ -1836,7 +1862,7 @@ class Definitions {
tp.stripTypeVar.dealias match
case tp1: TypeParamRef if ctx.typerState.constraint.contains(tp1) =>
asContextFunctionType(TypeComparer.bounds(tp1).hiBound)
case tp1 @ RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) && mt.isContextualMethod =>
case tp1 @ ErasedFunctionOf(mt) if mt.isContextualMethod =>
tp1
case tp1 =>
if tp1.typeSymbol.name.isContextFunction && isFunctionNType(tp1) then tp1
Expand All @@ -1856,7 +1882,7 @@ class Definitions {
atPhase(erasurePhase)(unapply(tp))
else
asContextFunctionType(tp) match
case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) =>
case ErasedFunctionOf(mt) =>
Some((mt.paramInfos, mt.resType, mt.erasedParams))
case tp1 if tp1.exists =>
val args = tp1.functionArgInfos
Expand All @@ -1866,7 +1892,7 @@ class Definitions {

/* Returns a list of erased booleans marking whether parameters are erased, for a function type. */
def erasedFunctionParameters(tp: Type)(using Context): List[Boolean] = tp.dealias match {
case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams
case ErasedFunctionOf(mt) => mt.erasedParams
case tp if isFunctionNType(tp) => List.fill(functionArity(tp)) { false }
case _ => Nil
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ class TypeApplications(val self: Type) extends AnyVal {
* Handles `ErasedFunction`s and poly functions gracefully.
*/
final def functionArgInfos(using Context): List[Type] = self.dealias match
case RefinedType(parent, nme.apply, mt: MethodType) if defn.isPolyOrErasedFunctionType(parent) => (mt.paramInfos :+ mt.resultType)
case defn.ErasedFunctionOf(mt) => (mt.paramInfos :+ mt.resultType)
case _ => self.dropDependentRefinement.dealias.argInfos

/** Argument types where existential types in arguments are disallowed */
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
isSubType(info1, info2)

if defn.isFunctionType(tp2) then
if defn.isPolyFunctionType(tp2) then
if tp2.derivesFrom(defn.PolyFunctionClass) then
// TODO should we handle ErasedFunction is this same way?
tp1.member(nme.apply).info match
case info1: PolyType =>
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,8 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
else SuperType(eThis, eSuper)
case ExprType(rt) =>
defn.FunctionType(0)
case RefinedType(parent, nme.apply, refinedInfo) if defn.isPolyOrErasedFunctionType(parent) =>
eraseRefinedFunctionApply(refinedInfo)
case defn.PolyOrErasedFunctionOf(mt) =>
eraseRefinedFunctionApply(mt)
case tp: TypeVar if !tp.isInstantiated =>
assert(inSigName, i"Cannot erase uninstantiated type variable $tp")
WildcardType
Expand Down Expand Up @@ -936,7 +936,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
sigName(defn.FunctionOf(Nil, rt))
case tp: TypeVar if !tp.isInstantiated =>
tpnme.Uninstantiated
case tp @ RefinedType(parent, nme.apply, _) if defn.isPolyOrErasedFunctionType(parent) =>
case tp @ defn.PolyOrErasedFunctionOf(_) =>
// we need this case rather than falling through to the default
// because RefinedTypes <: TypeProxy and it would be caught by
// the case immediately below
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1749,7 +1749,7 @@ object Types {
else NoType
case t if defn.isNonRefinedFunction(t) =>
t
case t if defn.isErasedFunctionType(t) =>
case t @ defn.PolyOrErasedFunctionOf(_) =>
t
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this could be just case t if defn.isFunctionType(t) =>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would make findFunctionType identify dependent function refinements. It looks like a desirable change.

case t @ SAMType(_) =>
t
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ object Erasure {
// Instead, we manually lookup the type of `apply` in the qualifier.
inContext(preErasureCtx) {
val qualTp = tree.qualifier.typeOpt.widen
if defn.isPolyOrErasedFunctionType(qualTp) then
if qualTp.derivesFrom(defn.PolyFunctionClass) || qualTp.derivesFrom(defn.ErasedFunctionClass) then
eraseRefinedFunctionApply(qualTp.select(nme.apply).widen).classSymbol
else
NoSymbol
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,11 @@ object TreeChecker {
val tpe = tree.typeOpt

// PolyFunction and ErasedFunction apply methods stay structural until Erasure
val isRefinedFunctionApply = (tree.name eq nme.apply) && defn.isPolyOrErasedFunctionType(tree.qualifier.typeOpt)
val isRefinedFunctionApply = (tree.name eq nme.apply) && {
val qualTpe = tree.qualifier.typeOpt
qualTpe.derivesFrom(defn.PolyFunctionClass) || qualTpe.derivesFrom(defn.ErasedFunctionClass)
}

// Outer selects are pickled specially so don't require a symbol
val isOuterSelect = tree.name.is(OuterSelectName)
val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
expected =:= defn.FunctionOf(actualArgs, actualRet,
defn.isContextFunctionType(baseFun))
val arity: Int =
if defn.isErasedFunctionType(fun) then -1 // TODO support?
if fun.derivesFrom(defn.ErasedFunctionClass) then -1 // TODO support?
else if defn.isFunctionNType(fun) then
// TupledFunction[(...) => R, ?]
fun.functionArgInfos match
Expand Down
17 changes: 8 additions & 9 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1326,7 +1326,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

(pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound)))
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity =>
if defn.isNonRefinedFunction(parent) && formals.length == defaultArity =>
(formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
case defn.ErasedFunctionOf(mt @ MethodTpe(_, formals, restpe)) if formals.length == defaultArity =>
(formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
case pt1 @ SAMType(mt @ MethodTpe(_, formals, _)) =>
val restpe = mt.resultType match
Expand Down Expand Up @@ -1648,11 +1650,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// If the expected type is a polymorphic function with the same number of
// type and value parameters, then infer the types of value parameters from the expected type.
val inferredVParams = pt match
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType))
if (parent.typeSymbol eq defn.PolyFunctionClass)
&& tparams.lengthCompare(poly.paramNames) == 0
&& vparams.lengthCompare(mt.paramNames) == 0
=>
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType))
if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 =>
vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
// Unlike in typedFunctionValue, `formal` cannot be a TypeBounds since
// it must be a valid method parameter type.
Expand All @@ -1667,7 +1666,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
vparams

val resultTpt = pt.dealias match
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) =>
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
case _ => untpd.TypeTree()
Expand Down Expand Up @@ -3234,8 +3233,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
else formals.map(untpd.TypeTree)
}

val erasedParams = pt.dealias match {
case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams
val erasedParams = pt match {
case defn.ErasedFunctionOf(mt: MethodType) => mt.erasedParams
case _ => paramTypes.map(_ => false)
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1788,7 +1788,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
def isContextFunctionType: Boolean =
dotc.core.Symbols.defn.isContextFunctionType(self)
def isErasedFunctionType: Boolean =
dotc.core.Symbols.defn.isErasedFunctionType(self)
self.derivesFrom(dotc.core.Symbols.defn.ErasedFunctionClass)
def isDependentFunctionType: Boolean =
val tpNoRefinement = self.dropDependentRefinement
tpNoRefinement != self
Expand Down