Skip to content

Commit 60a3f85

Browse files
committed
Replace is{Poly|Erased}FunctionType with {Poly|Erased}FunctionOf
1 parent 9ffd44e commit 60a3f85

File tree

11 files changed

+68
-40
lines changed

11 files changed

+68
-40
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,8 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
954954
def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
955955
def isStructuralTermSelect(tree: Select) =
956956
def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
957+
case defn.PolyOrErasedFunctionOf(_) =>
958+
false
957959
case RefinedType(parent, rname, rinfo) =>
958960
rname == tree.name || hasRefinement(parent)
959961
case tp: TypeProxy =>
@@ -966,10 +968,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
966968
false
967969
!tree.symbol.exists
968970
&& tree.isTerm
969-
&& {
970-
val qualType = tree.qualifier.tpe
971-
hasRefinement(qualType) && !defn.isPolyOrErasedFunctionType(qualType)
972-
}
971+
&& hasRefinement(tree.qualifier.tpe)
973972
def loop(tree: Tree): Boolean = tree match
974973
case TypeApply(fun, _) =>
975974
loop(fun)

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,7 +1114,7 @@ class Definitions {
11141114
FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil)
11151115
def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean)] = {
11161116
ft.dealias match
1117-
case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) =>
1117+
case ErasedFunctionOf(mt) =>
11181118
Some(mt.paramInfos, mt.resType, mt.isContextualMethod)
11191119
case _ =>
11201120
val tsym = ft.dealias.typeSymbol
@@ -1126,6 +1126,42 @@ class Definitions {
11261126
}
11271127
}
11281128

1129+
object PolyOrErasedFunctionOf {
1130+
/** Matches a refined `PolyFunction` or `ErasedFunction` type and extracts the apply info.
1131+
*
1132+
* Pattern: `(PolyFunction | ErasedFunction) { def apply: $mt }`
1133+
*/
1134+
def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match
1135+
case RefinedType(parent, nme.apply, mt: MethodicType)
1136+
if parent.derivesFrom(defn.PolyFunctionClass) || parent.derivesFrom(defn.ErasedFunctionClass) =>
1137+
Some(mt)
1138+
case _ => None
1139+
}
1140+
1141+
object PolyFunctionOf {
1142+
/** Matches a refined `PolyFunction` type and extracts the apply info.
1143+
*
1144+
* Pattern: `PolyFunction { def apply: $pt }`
1145+
*/
1146+
def unapply(ft: Type)(using Context): Option[PolyType] = ft.dealias match
1147+
case RefinedType(parent, nme.apply, pt: PolyType)
1148+
if parent.derivesFrom(defn.PolyFunctionClass) =>
1149+
Some(pt)
1150+
case _ => None
1151+
}
1152+
1153+
object ErasedFunctionOf {
1154+
/** Matches a refined `ErasedFunction` type and extracts the apply info.
1155+
*
1156+
* Pattern: `ErasedFunction { def apply: $mt }`
1157+
*/
1158+
def unapply(ft: Type)(using Context): Option[MethodType] = ft.dealias match
1159+
case RefinedType(parent, nme.apply, mt: MethodType)
1160+
if parent.derivesFrom(defn.ErasedFunctionClass) =>
1161+
Some(mt)
1162+
case _ => None
1163+
}
1164+
11291165
object PartialFunctionOf {
11301166
def apply(arg: Type, result: Type)(using Context): Type =
11311167
PartialFunctionClass.typeRef.appliedTo(arg :: result :: Nil)
@@ -1713,26 +1749,16 @@ class Definitions {
17131749
def isFunctionNType(tp: Type)(using Context): Boolean =
17141750
isNonRefinedFunction(tp.dropDependentRefinement)
17151751

1716-
/** Does `tp` derive from `PolyFunction` or `ErasedFunction`? */
1717-
def isPolyOrErasedFunctionType(tp: Type)(using Context): Boolean =
1718-
isPolyFunctionType(tp) || isErasedFunctionType(tp)
1719-
1720-
/** Does `tp` derive from `PolyFunction`? */
1721-
def isPolyFunctionType(tp: Type)(using Context): Boolean =
1722-
tp.derivesFrom(defn.PolyFunctionClass)
1723-
1724-
/** Does `tp` derive from `ErasedFunction`? */
1725-
def isErasedFunctionType(tp: Type)(using Context): Boolean =
1726-
tp.derivesFrom(defn.ErasedFunctionClass)
1727-
17281752
/** Returns whether `tp` is an instance or a refined instance of:
17291753
* - scala.FunctionN
17301754
* - scala.ContextFunctionN
17311755
* - ErasedFunction
17321756
* - PolyFunction
17331757
*/
17341758
def isFunctionType(tp: Type)(using Context): Boolean =
1735-
isFunctionNType(tp) || isPolyOrErasedFunctionType(tp)
1759+
isFunctionNType(tp)
1760+
|| tp.derivesFrom(defn.PolyFunctionClass) // TODO check for refinement?
1761+
|| tp.derivesFrom(defn.ErasedFunctionClass) // TODO check for refinement?
17361762

17371763
private def withSpecMethods(cls: ClassSymbol, bases: List[Name], paramTypes: Set[TypeRef]) =
17381764
if !ctx.settings.Yscala2Stdlib.value then
@@ -1836,7 +1862,7 @@ class Definitions {
18361862
tp.stripTypeVar.dealias match
18371863
case tp1: TypeParamRef if ctx.typerState.constraint.contains(tp1) =>
18381864
asContextFunctionType(TypeComparer.bounds(tp1).hiBound)
1839-
case tp1 @ RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) && mt.isContextualMethod =>
1865+
case tp1 @ ErasedFunctionOf(mt) if mt.isContextualMethod =>
18401866
tp1
18411867
case tp1 =>
18421868
if tp1.typeSymbol.name.isContextFunction && isFunctionNType(tp1) then tp1
@@ -1856,7 +1882,7 @@ class Definitions {
18561882
atPhase(erasurePhase)(unapply(tp))
18571883
else
18581884
asContextFunctionType(tp) match
1859-
case RefinedType(parent, nme.apply, mt: MethodType) if isErasedFunctionType(parent) =>
1885+
case ErasedFunctionOf(mt) =>
18601886
Some((mt.paramInfos, mt.resType, mt.erasedParams))
18611887
case tp1 if tp1.exists =>
18621888
val args = tp1.functionArgInfos
@@ -1866,7 +1892,7 @@ class Definitions {
18661892

18671893
/* Returns a list of erased booleans marking whether parameters are erased, for a function type. */
18681894
def erasedFunctionParameters(tp: Type)(using Context): List[Boolean] = tp.dealias match {
1869-
case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams
1895+
case ErasedFunctionOf(mt) => mt.erasedParams
18701896
case tp if isFunctionNType(tp) => List.fill(functionArity(tp)) { false }
18711897
case _ => Nil
18721898
}

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ class TypeApplications(val self: Type) extends AnyVal {
509509
* Handles `ErasedFunction`s and poly functions gracefully.
510510
*/
511511
final def functionArgInfos(using Context): List[Type] = self.dealias match
512-
case RefinedType(parent, nme.apply, mt: MethodType) if defn.isPolyOrErasedFunctionType(parent) => (mt.paramInfos :+ mt.resultType)
512+
case defn.ErasedFunctionOf(mt) => (mt.paramInfos :+ mt.resultType)
513513
case _ => self.dropDependentRefinement.dealias.argInfos
514514

515515
/** Argument types where existential types in arguments are disallowed */

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
666666
isSubType(info1, info2)
667667

668668
if defn.isFunctionType(tp2) then
669-
if defn.isPolyFunctionType(tp2) then
669+
if tp2.derivesFrom(defn.PolyFunctionClass) then
670670
// TODO should we handle ErasedFunction is this same way?
671671
tp1.member(nme.apply).info match
672672
case info1: PolyType =>

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -654,8 +654,8 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
654654
else SuperType(eThis, eSuper)
655655
case ExprType(rt) =>
656656
defn.FunctionType(0)
657-
case RefinedType(parent, nme.apply, refinedInfo) if defn.isPolyOrErasedFunctionType(parent) =>
658-
eraseRefinedFunctionApply(refinedInfo)
657+
case defn.PolyOrErasedFunctionOf(mt) =>
658+
eraseRefinedFunctionApply(mt)
659659
case tp: TypeVar if !tp.isInstantiated =>
660660
assert(inSigName, i"Cannot erase uninstantiated type variable $tp")
661661
WildcardType
@@ -936,7 +936,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
936936
sigName(defn.FunctionOf(Nil, rt))
937937
case tp: TypeVar if !tp.isInstantiated =>
938938
tpnme.Uninstantiated
939-
case tp @ RefinedType(parent, nme.apply, _) if defn.isPolyOrErasedFunctionType(parent) =>
939+
case tp @ defn.PolyOrErasedFunctionOf(_) =>
940940
// we need this case rather than falling through to the default
941941
// because RefinedTypes <: TypeProxy and it would be caught by
942942
// the case immediately below

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1749,7 +1749,7 @@ object Types {
17491749
else NoType
17501750
case t if defn.isNonRefinedFunction(t) =>
17511751
t
1752-
case t if defn.isErasedFunctionType(t) =>
1752+
case t @ defn.PolyOrErasedFunctionOf(_: MethodType) =>
17531753
t
17541754
case t @ SAMType(_) =>
17551755
t

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ object Erasure {
679679
// Instead, we manually lookup the type of `apply` in the qualifier.
680680
inContext(preErasureCtx) {
681681
val qualTp = tree.qualifier.typeOpt.widen
682-
if defn.isPolyOrErasedFunctionType(qualTp) then
682+
if qualTp.derivesFrom(defn.PolyFunctionClass) || qualTp.derivesFrom(defn.ErasedFunctionClass) then
683683
eraseRefinedFunctionApply(qualTp.select(nme.apply).widen).classSymbol
684684
else
685685
NoSymbol

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,11 @@ object TreeChecker {
447447
val tpe = tree.typeOpt
448448

449449
// PolyFunction and ErasedFunction apply methods stay structural until Erasure
450-
val isRefinedFunctionApply = (tree.name eq nme.apply) && defn.isPolyOrErasedFunctionType(tree.qualifier.typeOpt)
450+
val isRefinedFunctionApply = (tree.name eq nme.apply) && {
451+
val qualTpe = tree.qualifier.typeOpt
452+
qualTpe.derivesFrom(defn.PolyFunctionClass) || qualTpe.derivesFrom(defn.ErasedFunctionClass)
453+
}
454+
451455
// Outer selects are pickled specially so don't require a symbol
452456
val isOuterSelect = tree.name.is(OuterSelectName)
453457
val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name)

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
105105
expected =:= defn.FunctionOf(actualArgs, actualRet,
106106
defn.isContextFunctionType(baseFun))
107107
val arity: Int =
108-
if defn.isErasedFunctionType(fun) then -1 // TODO support?
108+
if fun.derivesFrom(defn.ErasedFunctionClass) then -1 // TODO support?
109109
else if defn.isFunctionNType(fun) then
110110
// TupledFunction[(...) => R, ?]
111111
fun.functionArgInfos match

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,7 +1326,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13261326

13271327
(pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound)))
13281328
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
1329-
if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity =>
1329+
if defn.isNonRefinedFunction(parent) && formals.length == defaultArity =>
1330+
(formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
1331+
case defn.ErasedFunctionOf(mt @ MethodTpe(_, formals, restpe)) if formals.length == defaultArity =>
13301332
(formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
13311333
case pt1 @ SAMType(mt @ MethodTpe(_, formals, _)) =>
13321334
val restpe = mt.resultType match
@@ -1648,11 +1650,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16481650
// If the expected type is a polymorphic function with the same number of
16491651
// type and value parameters, then infer the types of value parameters from the expected type.
16501652
val inferredVParams = pt match
1651-
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType))
1652-
if (parent.typeSymbol eq defn.PolyFunctionClass)
1653-
&& tparams.lengthCompare(poly.paramNames) == 0
1654-
&& vparams.lengthCompare(mt.paramNames) == 0
1655-
=>
1653+
case PolyFunctionOf(poly @ PolyType(_, mt: MethodType))
1654+
if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 =>
16561655
vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
16571656
// Unlike in typedFunctionValue, `formal` cannot be a TypeBounds since
16581657
// it must be a valid method parameter type.
@@ -1667,7 +1666,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16671666
vparams
16681667

16691668
val resultTpt = pt.dealias match
1670-
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1669+
case PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) =>
16711670
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
16721671
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
16731672
case _ => untpd.TypeTree()
@@ -3234,8 +3233,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
32343233
else formals.map(untpd.TypeTree)
32353234
}
32363235

3237-
val erasedParams = pt.dealias match {
3238-
case RefinedType(parent, nme.apply, mt: MethodType) => mt.erasedParams
3236+
val erasedParams = pt match {
3237+
case defn.ErasedFunctionOf(mt: MethodType) => mt.erasedParams
32393238
case _ => paramTypes.map(_ => false)
32403239
}
32413240

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1788,7 +1788,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
17881788
def isContextFunctionType: Boolean =
17891789
dotc.core.Symbols.defn.isContextFunctionType(self)
17901790
def isErasedFunctionType: Boolean =
1791-
dotc.core.Symbols.defn.isErasedFunctionType(self)
1791+
self.derivesFrom(dotc.core.Symbols.defn.ErasedFunctionClass)
17921792
def isDependentFunctionType: Boolean =
17931793
val tpNoRefinement = self.dropDependentRefinement
17941794
tpNoRefinement != self

0 commit comments

Comments
 (0)