Skip to content

Commit 6337c96

Browse files
committed
Add PolyFunctionOf
1 parent 3577cf1 commit 6337c96

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,18 @@ class Definitions {
11381138
case _ => None
11391139
}
11401140

1141+
object PolyFunctionOf {
1142+
/** Matches a refined `PolyFunction` type and extracts the apply info.
1143+
*
1144+
* Pattern: `PolyFunction { def apply: $pt }`
1145+
*/
1146+
def unapply(ft: Type)(using Context): Option[PolyType] = ft.dealias match
1147+
case RefinedType(parent, nme.apply, pt: PolyType)
1148+
if parent.derivesFrom(defn.PolyFunctionClass) =>
1149+
Some(pt)
1150+
case _ => None
1151+
}
1152+
11411153
object ErasedFunctionOf {
11421154
/** Matches a refined `ErasedFunction` type and extracts the apply info.
11431155
*

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,11 +1650,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16501650
// If the expected type is a polymorphic function with the same number of
16511651
// type and value parameters, then infer the types of value parameters from the expected type.
16521652
val inferredVParams = pt match
1653-
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType))
1654-
if (parent.typeSymbol eq defn.PolyFunctionClass)
1655-
&& tparams.lengthCompare(poly.paramNames) == 0
1656-
&& vparams.lengthCompare(mt.paramNames) == 0
1657-
=>
1653+
case PolyFunctionOf(poly @ PolyType(_, mt: MethodType))
1654+
if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 =>
16581655
vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
16591656
// Unlike in typedFunctionValue, `formal` cannot be a TypeBounds since
16601657
// it must be a valid method parameter type.
@@ -1669,7 +1666,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16691666
vparams
16701667

16711668
val resultTpt = pt.dealias match
1672-
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1669+
case PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) =>
16731670
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
16741671
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
16751672
case _ => untpd.TypeTree()

0 commit comments

Comments
 (0)