diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index fb045b8a5f64..10d4fed7f058 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1504,7 +1504,7 @@ object desugar { .withSpan(original.span.withPoint(named.span.start)) /** Main desugaring method */ - def apply(tree: Tree)(using Context): Tree = { + def apply(tree: Tree, pt: Type = NoType)(using Context): Tree = { /** Create tree for for-comprehension `` or * `` where mapName and flatMapName are chosen @@ -1698,11 +1698,11 @@ object desugar { } } - def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match { + def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match { case Parens(body1) => - makePolyFunction(targs, body1) + makePolyFunction(targs, body1, pt) case Block(Nil, body1) => - makePolyFunction(targs, body1) + makePolyFunction(targs, body1, pt) case Function(vargs, res) => assert(targs.nonEmpty) // TODO: Figure out if we need a `PolyFunctionWithMods` instead. @@ -1726,12 +1726,26 @@ object desugar { } else { // Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body - // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body } + // with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R + // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body } + // where R2 is R, with all references to S_1..S_M replaced with T1..T_M. + + def typeTree(tp: Type) = tp match + case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass => + var bail = false + def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match + case tp: TypeRef => ref(tp) + case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name) + case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_))) + case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree } + val mapped = mapper(mt.resultType, topLevel = true) + if bail then TypeTree() else mapped + case _ => TypeTree() val applyVParams = vargs.asInstanceOf[List[ValDef]] .map(varg => varg.withAddedFlags(mods.flags | Param)) New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef, - List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, TypeTree(), res)) + List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res)) )) } case _ => @@ -1753,7 +1767,7 @@ object desugar { val desugared = tree match { case PolyFunction(targs, body) => - makePolyFunction(targs, body) orElse tree + makePolyFunction(targs, body, pt) orElse tree case SymbolLit(str) => Apply( ref(defn.ScalaSymbolClass.companionModule.termRef), diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 71a8872343b4..830131311c12 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2871,7 +2871,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typedTypeOrClassDef case tree: untpd.Labeled => typedLabeled(tree) - case _ => typedUnadapted(desugar(tree), pt, locked) + case _ => typedUnadapted(desugar(tree, pt), pt, locked) } } @@ -2924,7 +2924,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case tree: untpd.Splice => typedSplice(tree, pt) case tree: untpd.MacroTree => report.error("Unexpected macro", tree.srcPos); tpd.nullLiteral // ill-formed code may reach here case tree: untpd.Hole => typedHole(tree, pt) - case _ => typedUnadapted(desugar(tree), pt, locked) + case _ => typedUnadapted(desugar(tree, pt), pt, locked) } try diff --git a/tests/pos/i15554.scala b/tests/pos/i15554.scala new file mode 100644 index 000000000000..8573a5fff549 --- /dev/null +++ b/tests/pos/i15554.scala @@ -0,0 +1,8 @@ +enum PingMessage[Response]: + case Ping(from: String) extends PingMessage[String] + +val pongBehavior: [O] => (Unit, PingMessage[O]) => (Unit, O) = + [P] => + (state: Unit, msg: PingMessage[P]) => + msg match + case PingMessage.Ping(from) => ((), s"Pong from $from")