From 5a1d2d3bc7e79148e23148c0b49cb3dedf84daa2 Mon Sep 17 00:00:00 2001 From: Natsu Kagami Date: Wed, 19 Apr 2023 15:38:00 +0200 Subject: [PATCH] Allow contextual functions with erased parameters to be integrated - also add test from #17147 --- .../dotty/tools/dotc/core/Definitions.scala | 1 - .../transform/ContextFunctionResults.scala | 12 ++++++++--- tests/pos-custom-args/erased/tailrec.scala | 20 +++++++++++++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) create mode 100644 tests/pos-custom-args/erased/tailrec.scala diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 148b314220a8..85aac4ce601c 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1509,7 +1509,6 @@ class Definitions { /** Is an context function class. * - ContextFunctionN for N >= 0 - * - ErasedContextFunctionN for N > 0 */ def isContextFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isContextFunction diff --git a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala index 5863c360e728..b4eb71c541d3 100644 --- a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala +++ b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala @@ -20,7 +20,7 @@ object ContextFunctionResults: */ def annotateContextResults(mdef: DefDef)(using Context): Unit = def contextResultCount(rhs: Tree, tp: Type): Int = tp match - case defn.ContextFunctionType(_, resTpe, erasedParams) if !erasedParams.contains(true) /* Only enable for non-erased functions */ => + case defn.ContextFunctionType(_, resTpe, _) => rhs match case closureDef(meth) => 1 + contextResultCount(meth.rhs, resTpe) case _ => 0 @@ -116,8 +116,14 @@ object ContextFunctionResults: atPhase(erasurePhase)(integrateSelect(tree, n)) else tree match case Select(qual, name) => - if name == nme.apply && defn.isContextFunctionClass(tree.symbol.maybeOwner) then - integrateSelect(qual, n + 1) + if name == nme.apply then + qual.tpe match + case defn.ContextFunctionType(_, _, _) => + integrateSelect(qual, n + 1) + case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs + integrateSelect(qual, n + 1) + case _ => + n > 0 && contextResultCount(tree.symbol) >= n else n > 0 && contextResultCount(tree.symbol) >= n case Ident(name) => diff --git a/tests/pos-custom-args/erased/tailrec.scala b/tests/pos-custom-args/erased/tailrec.scala new file mode 100644 index 000000000000..cebcf4785c7a --- /dev/null +++ b/tests/pos-custom-args/erased/tailrec.scala @@ -0,0 +1,20 @@ +import scala.annotation.tailrec + +erased class Foo1 +class Foo2 + +@tailrec +final def test1(n: Int, acc: Int): (Foo1, Foo2) ?=> Int = + if n <= 0 then acc + else test1(n - 1, acc * n) + +@tailrec +final def test2(n: Int, acc: Int): Foo1 ?=> Int = + if n <= 0 then acc + else test2(n - 1, acc * n) + +@main def Test() = + given Foo1 = Foo1() + given Foo2 = Foo2() + test1(10, 0) + test2(10, 0)