Skip to content

Commit e5ca0c4

Browse files
committed
Check user defined PolyFunction refinements
`PolyFunction` must be refined with an `apply` method that has a single parameter list with no by-name nor varargs parameters. It may optionally have type parameters. Some of these restrictions could be lifted later, but for now these features are not properly handled by the compiler. Fixes #8299 Fixes #18302
1 parent a37dac6 commit e5ca0c4

17 files changed

+112
-2
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,10 +1161,12 @@ class Definitions {
11611161
Some(mt)
11621162
case _ => None
11631163

1164-
private def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
1164+
def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
11651165
def isValidMethodType(info: Type) = info match
11661166
case info: MethodType =>
1167-
!info.resType.isInstanceOf[MethodOrPoly] // Has only one parameter list
1167+
!info.resType.isInstanceOf[MethodOrPoly] && // Has only one parameter list
1168+
!info.isVarArgsMethod &&
1169+
!info.paramInfos.exists(_.isInstanceOf[ExprType]) // No by-name parameters
11681170
case _ => false
11691171
info match
11701172
case info: PolyType => isValidMethodType(info.resType)

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,13 +381,15 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
381381
case tree: ValDef =>
382382
registerIfHasMacroAnnotations(tree)
383383
checkErasedDef(tree)
384+
Checking.checkPolyFunctionType(tree.tpt)
384385
val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
385386
if tree1.removeAttachment(desugar.UntupledParam).isDefined then
386387
checkStableSelection(tree.rhs)
387388
processValOrDefDef(super.transform(tree1))
388389
case tree: DefDef =>
389390
registerIfHasMacroAnnotations(tree)
390391
checkErasedDef(tree)
392+
Checking.checkPolyFunctionType(tree.tpt)
391393
annotateContextResults(tree)
392394
val tree1 = cpy.DefDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
393395
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
@@ -492,6 +494,9 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
492494
)
493495
case Block(_, Closure(_, _, tpt)) if ExpandSAMs.needsWrapperClass(tpt.tpe) =>
494496
superAcc.withInvalidCurrentClass(super.transform(tree))
497+
case tree: RefinedTypeTree =>
498+
Checking.checkPolyFunctionType(tree)
499+
super.transform(tree)
495500
case _: Quote | _: QuotePattern =>
496501
ctx.compilationUnit.needsStaging = true
497502
super.transform(tree)

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,31 @@ object Checking {
816816
else Feature.checkExperimentalFeature("features", imp.srcPos)
817817
case _ =>
818818
end checkExperimentalImports
819+
820+
/** Checks that PolyFunction only have valid refinements.
821+
*
822+
* It only supports `apply` methods with one parameter list and optional type arguments.
823+
*/
824+
def checkPolyFunctionType(tree: Tree)(using Context): Unit = new TreeTraverser {
825+
def traverse(tree: Tree)(using Context): Unit = tree match
826+
case tree: RefinedTypeTree if tree.tpe.derivesFrom(defn.PolyFunctionClass) =>
827+
if tree.refinements.isEmpty then
828+
reportNoRefinements(tree.srcPos)
829+
tree.refinements.foreach {
830+
case refinement: DefDef if refinement.name != nme.apply =>
831+
report.error("PolyFunction only supports apply method refinements", refinement.srcPos)
832+
case refinement: DefDef if !defn.PolyFunctionOf.isValidPolyFunctionInfo(refinement.tpe.widen) =>
833+
report.error("Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.", refinement.srcPos)
834+
case _ =>
835+
}
836+
case _: RefTree if tree.symbol == defn.PolyFunctionClass =>
837+
reportNoRefinements(tree.srcPos)
838+
case _ =>
839+
traverseChildren(tree)
840+
841+
def reportNoRefinements(pos: SrcPos) =
842+
report.error("PolyFunction subtypes must refine the apply method", pos)
843+
}.traverse(tree)
819844
}
820845

821846
trait Checking {

tests/neg/i18302b.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
-- Error: tests/neg/i18302b.scala:3:32 ---------------------------------------------------------------------------------
2+
3 |def polyFun: PolyFunction { def apply(x: Int)(y: Int): Int } = // error
3+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4+
|Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.

tests/neg/i18302b.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def test = polyFun(1)(2)
2+
3+
def polyFun: PolyFunction { def apply(x: Int)(y: Int): Int } = // error
4+
new PolyFunction:
5+
def apply(x: Int)(y: Int): Int = x + y

tests/neg/i18302c.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
-- Error: tests/neg/i18302c.scala:4:32 ---------------------------------------------------------------------------------
2+
4 |def polyFun: PolyFunction { def foo(x: Int): Int } = // error
3+
| ^^^^^^^^^^^^^^^^^^^^
4+
| PolyFunction only supports apply method refinements

tests/neg/i18302c.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import scala.reflect.Selectable.reflectiveSelectable
2+
3+
def test = polyFun.foo(1)
4+
def polyFun: PolyFunction { def foo(x: Int): Int } = // error
5+
new PolyFunction { def foo(x: Int): Int = x + 1 }

tests/neg/i18302d.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
-- Error: tests/neg/i18302d.scala:1:32 ---------------------------------------------------------------------------------
2+
1 |def polyFun: PolyFunction { def apply: Int } = // error
3+
| ^^^^^^^^^^^^^^
4+
|Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.

tests/neg/i18302d.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def polyFun: PolyFunction { def apply: Int } = // error
2+
new PolyFunction { def apply: Int = 1 }

tests/neg/i18302e.check

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- Error: tests/neg/i18302e.scala:1:13 ---------------------------------------------------------------------------------
2+
1 |def polyFun: PolyFunction { } = // error
3+
| ^^^^^^^^^^^^^^^^^
4+
| PolyFunction subtypes must refine the apply method
5+
-- Error: tests/neg/i18302e.scala:4:15 ---------------------------------------------------------------------------------
6+
4 |def polyFun(f: PolyFunction { }) = () // error
7+
| ^^^^^^^^^^^^^^^^^
8+
| PolyFunction subtypes must refine the apply method

tests/neg/i18302e.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def polyFun: PolyFunction { } = // error
2+
new PolyFunction { }
3+
4+
def polyFun(f: PolyFunction { }) = () // error

tests/neg/i18302f.check

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
-- Error: tests/neg/i18302f.scala:1:13 ---------------------------------------------------------------------------------
2+
1 |def polyFun: PolyFunction = // error
3+
| ^^^^^^^^^^^^
4+
| PolyFunction subtypes must refine the apply method
5+
-- Error: tests/neg/i18302f.scala:4:16 ---------------------------------------------------------------------------------
6+
4 |def polyFun2(a: PolyFunction) = () // error
7+
| ^^^^^^^^^^^^
8+
| PolyFunction subtypes must refine the apply method
9+
-- Error: tests/neg/i18302f.scala:6:14 ---------------------------------------------------------------------------------
10+
6 |val polyFun3: PolyFunction = // error
11+
| ^^^^^^^^^^^^
12+
| PolyFunction subtypes must refine the apply method

tests/neg/i18302f.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
def polyFun: PolyFunction = // error
2+
new PolyFunction { }
3+
4+
def polyFun2(a: PolyFunction) = () // error
5+
6+
val polyFun3: PolyFunction = // error
7+
new PolyFunction { }

tests/neg/i18302i.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def polyFun1: Option[PolyFunction] = ??? // error
2+
def polyFun2: PolyFunction & Any = ??? // error
3+
def polyFun3: Any & PolyFunction = ??? // error
4+
def polyFun4: PolyFunction | Any = ??? // error
5+
def polyFun5: Any | PolyFunction = ??? // error
6+
def polyFun6(a: Any | PolyFunction) = ??? // error

tests/neg/i18302j.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def polyFunByName: PolyFunction { def apply(thunk: => Int): Int } = // error
2+
new PolyFunction { def apply(thunk: => Int): Int = 1 }
3+
4+
def polyFunVarArgs: PolyFunction { def apply(args: Int*): Int } = // error
5+
new PolyFunction { def apply(thunk: Int*): Int = 1 }

tests/neg/i8299.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package example
2+
3+
object Main {
4+
def main(a: Array[String]): Unit = {
5+
val p: PolyFunction = // error: PolyFunction subtypes must refine the apply method
6+
[A] => (xs: List[A]) => xs.headOption
7+
}
8+
}

tests/pos/i18302a.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = polyFun(1)
2+
3+
def polyFun: PolyFunction { def apply(x: Int): Int } =
4+
new PolyFunction { def apply(x: Int): Int = x + 1 }

0 commit comments

Comments
 (0)