@@ -40,49 +40,74 @@ class BetaReduce extends MiniPhase:
40
40
41
41
override def transformApply (app : Apply )(using Context ): Tree = app.fun match
42
42
case Select (fn, nme.apply) if defn.isFunctionType(fn.tpe) =>
43
- val app1 = BetaReduce (app, fn, app.args)
43
+ val app1 = BetaReduce (app, fn, List (app.args))
44
+ if app1 ne app then report.log(i " beta reduce $app -> $app1" )
45
+ app1
46
+ case TypeApply (Select (fn, nme.apply), targs) if fn.tpe.typeSymbol eq defn.PolyFunctionClass =>
47
+ val app1 = BetaReduce (app, fn, List (targs, app.args))
44
48
if app1 ne app then report.log(i " beta reduce $app -> $app1" )
45
49
app1
46
50
case _ =>
47
51
app
48
52
49
-
50
53
object BetaReduce :
51
54
import ast .tpd ._
52
55
53
56
val name : String = " betaReduce"
54
57
val description : String = " reduce closure applications"
55
58
56
59
/** Beta-reduces a call to `fn` with arguments `argSyms` or returns `tree` */
57
- def apply (original : Tree , fn : Tree , args : List [Tree ])(using Context ): Tree =
60
+ def apply (original : Tree , fn : Tree , argss : List [List [ Tree ] ])(using Context ): Tree =
58
61
fn match
59
62
case Typed (expr, _) =>
60
- BetaReduce (original, expr, args )
63
+ BetaReduce (original, expr, argss )
61
64
case Block ((anonFun : DefDef ) :: Nil , closure : Closure ) =>
62
- BetaReduce (anonFun, args)
65
+ BetaReduce (anonFun, argss)
66
+ case Block ((TypeDef (_, template : Template )) :: Nil , Typed (Apply (Select (New (_), _), _), _)) if template.constr.rhs.isEmpty =>
67
+ template.body match
68
+ case (anonFun : DefDef ) :: Nil =>
69
+ BetaReduce (anonFun, argss)
70
+ case _ =>
71
+ original
63
72
case Block (stats, expr) =>
64
- val tree = BetaReduce (original, expr, args )
73
+ val tree = BetaReduce (original, expr, argss )
65
74
if tree eq original then original
66
75
else cpy.Block (fn)(stats, tree)
67
76
case Inlined (call, bindings, expr) =>
68
- val tree = BetaReduce (original, expr, args )
77
+ val tree = BetaReduce (original, expr, argss )
69
78
if tree eq original then original
70
79
else cpy.Inlined (fn)(call, bindings, tree)
71
80
case _ =>
72
81
original
73
82
end apply
74
83
75
84
/** Beta-reduces a call to `ddef` with arguments `args` */
76
- def apply (ddef : DefDef , args : List [Tree ])(using Context ) =
77
- val bindings = new ListBuffer [ValDef ]()
78
- val expansion1 = reduceApplication(ddef, args , bindings)
85
+ def apply (ddef : DefDef , argss : List [List [ Tree ] ])(using Context ) =
86
+ val bindings = new ListBuffer [DefTree ]()
87
+ val expansion1 = reduceApplication(ddef, argss , bindings)
79
88
val bindings1 = bindings.result()
80
89
seq(bindings1, expansion1)
81
90
82
91
/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
83
- def reduceApplication (ddef : DefDef , args : List [Tree ], bindings : ListBuffer [ValDef ])(using Context ): Tree =
84
- val vparams = ddef.termParamss.iterator.flatten.toList
92
+ def reduceApplication (ddef : DefDef , argss : List [List [Tree ]], bindings : ListBuffer [DefTree ])(using Context ): Tree =
93
+ assert(argss.size == 1 || argss.size == 2 )
94
+ val targs = if argss.size == 2 then argss.head else Nil
95
+ val args = argss.last
96
+ val tparams = ddef.leadingTypeParams
97
+ val vparams = ddef.termParamss.flatten
98
+ assert(targs.hasSameLengthAs(tparams))
85
99
assert(args.hasSameLengthAs(vparams))
100
+
101
+ val targSyms =
102
+ for (targ, tparam) <- targs.zip(tparams) yield
103
+ targ.tpe.dealias match
104
+ case ref @ TypeRef (NoPrefix , _) =>
105
+ ref.symbol
106
+ case _ =>
107
+ val binding = TypeDef (newSymbol(ctx.owner, tparam.name, EmptyFlags , targ.tpe, coord = targ.span)).withSpan(targ.span)
108
+ bindings += binding
109
+ binding.symbol
110
+
86
111
val argSyms =
87
112
for (arg, param) <- args.zip(vparams) yield
88
113
arg.tpe.dealias match
@@ -99,8 +124,8 @@ object BetaReduce:
99
124
val expansion = TreeTypeMap (
100
125
oldOwners = ddef.symbol :: Nil ,
101
126
newOwners = ctx.owner :: Nil ,
102
- substFrom = vparams.map(_.symbol),
103
- substTo = argSyms
127
+ substFrom = (tparams ::: vparams) .map(_.symbol),
128
+ substTo = targSyms ::: argSyms
104
129
).transform(ddef.rhs)
105
130
106
131
val expansion1 = new TreeMap {
0 commit comments