1
1
package dotty .tools .dotc
2
2
package transform
3
3
4
+ import java .io .{PrintWriter , StringWriter }
5
+ import java .lang .reflect .Method
6
+ import java .net .URLClassLoader
7
+
4
8
import dotty .tools .dotc .ast .tpd
5
9
import dotty .tools .dotc .core .Contexts ._
6
10
import dotty .tools .dotc .core .Decorators ._
11
+ import dotty .tools .dotc .core .Flags .Package
12
+ import dotty .tools .dotc .core .NameKinds .FlatName
13
+ import dotty .tools .dotc .core .Names .Name
7
14
import dotty .tools .dotc .core .quoted ._
8
15
import dotty .tools .dotc .core .Types ._
9
16
import dotty .tools .dotc .core .Symbols ._
10
- import dotty .tools .dotc .interpreter ._
11
17
12
18
import scala .util .control .NonFatal
13
-
14
19
import dotty .tools .dotc .util .Positions .Position
15
20
21
+ import scala .reflect .ClassTag
22
+
16
23
/** Utility class to splice quoted expressions */
17
24
object Splicer {
18
25
import tpd ._
@@ -27,38 +34,39 @@ object Splicer {
27
34
}
28
35
29
36
private def reflectiveSplice (tree : Tree , call : Tree , bindings : List [Tree ], pos : Position )(implicit ctx : Context ): Tree = {
30
- val liftedArgs = {
31
- val bindMap = bindings.map {
32
- case vdef : ValDef => (vdef.rhs, ref(vdef.symbol))
33
- }.toMap
34
- def allArgs (call : Tree , acc : List [List [Tree ]]): List [List [Tree ]] = call match {
35
- case call : Apply => allArgs(call.fun, call.args :: acc)
36
- case call : TypeApply => allArgs(call.fun, call.args :: acc)
37
- case _ => acc
38
- }
39
- def liftArgs (tpe : Type , args : List [List [Tree ]]): List [Any ] = tpe match {
40
- case tp : MethodType =>
41
- val args1 = args.head.zip(tp.paramInfos).map {
42
- case (arg : Literal , tp) if tp.hasAnnotation(defn.InlineParamAnnot ) => arg.const.value
43
- case (arg, tp) =>
44
- assert(! tp.hasAnnotation(defn.InlineParamAnnot ))
45
- new scala.quoted.Exprs .TreeExpr (bindMap.getOrElse(arg, arg))
46
- }
47
- args1 ::: liftArgs(tp.resType, args.tail)
48
- case tp : PolyType =>
49
- val args1 = args.head.map(tp => new scala.quoted.Types .TreeType (tp))
50
- args1 ::: liftArgs(tp.resType, args.tail)
51
- case _ => Nil
52
- }
53
-
54
- liftArgs(call.symbol.info, allArgs(call, Nil ))
55
- }
56
-
37
+ val liftedArgs = getLiftedArgs(call, bindings)
57
38
val interpreter = new Interpreter (pos)
58
39
val interpreted = interpreter.interpretCallToSymbol[Seq [Any ] => Object ](call.symbol)
59
40
interpreted.flatMap(lambda => evaluateLambda(lambda, liftedArgs, pos)).fold(tree)(PickledQuotes .quotedExprToTree)
60
41
}
61
42
43
+ private def getLiftedArgs (call : Tree , bindings : List [Tree ])(implicit ctx : Context ) = {
44
+ val bindMap = bindings.map {
45
+ case vdef : ValDef => (vdef.rhs, ref(vdef.symbol))
46
+ }.toMap
47
+ def allArgs (call : Tree , acc : List [List [Tree ]]): List [List [Tree ]] = call match {
48
+ case call : Apply => allArgs(call.fun, call.args :: acc)
49
+ case call : TypeApply => allArgs(call.fun, call.args :: acc)
50
+ case _ => acc
51
+ }
52
+ def liftArgs (tpe : Type , args : List [List [Tree ]]): List [Any ] = tpe match {
53
+ case tp : MethodType =>
54
+ val args1 = args.head.zip(tp.paramInfos).map {
55
+ case (arg : Literal , tp) if tp.hasAnnotation(defn.InlineParamAnnot ) => arg.const.value
56
+ case (arg, tp) =>
57
+ assert(! tp.hasAnnotation(defn.InlineParamAnnot ))
58
+ new scala.quoted.Exprs .TreeExpr (bindMap.getOrElse(arg, arg))
59
+ }
60
+ args1 ::: liftArgs(tp.resType, args.tail)
61
+ case tp : PolyType =>
62
+ val args1 = args.head.map(tp => new scala.quoted.Types .TreeType (tp))
63
+ args1 ::: liftArgs(tp.resType, args.tail)
64
+ case _ => Nil
65
+ }
66
+
67
+ liftArgs(call.symbol.info, allArgs(call, Nil ))
68
+ }
69
+
62
70
private def evaluateLambda (lambda : Seq [Any ] => Object , args : Seq [Any ], pos : Position )(implicit ctx : Context ): Option [scala.quoted.Expr [_]] = {
63
71
try Some (lambda(args).asInstanceOf [scala.quoted.Expr [_]])
64
72
catch {
@@ -77,4 +85,110 @@ object Splicer {
77
85
}
78
86
}
79
87
88
+ /** Tree interpreter that can interpret calls to static methods with it's default arguments
89
+ *
90
+ * The interpreter assumes that all calls in the trees are to code that was
91
+ * previously compiled and is present in the classpath of the current context.
92
+ */
93
+ private class Interpreter (pos : Position )(implicit ctx : Context ) {
94
+
95
+ private [this ] val classLoader = {
96
+ val urls = ctx.settings.classpath.value.split(':' ).map(cp => java.nio.file.Paths .get(cp).toUri.toURL)
97
+ new URLClassLoader (urls, getClass.getClassLoader)
98
+ }
99
+
100
+ /** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
101
+ * Return Some of the result or None if some error happen during the interpretation.
102
+ */
103
+ def interpretCallToSymbol [T ](sym : Symbol )(implicit ct : ClassTag [T ]): Option [T ] = {
104
+ try {
105
+ val (clazz, instance) = loadModule(sym.owner)
106
+ val paramClasses = paramsSig(sym)
107
+ val interpretedArgs = paramClasses.map(defaultValue)
108
+ val method = getMethod(clazz, sym.name, paramClasses)
109
+ stopIfRuntimeException(method.invoke(instance, interpretedArgs : _* )) match {
110
+ case obj : T => Some (obj)
111
+ case obj =>
112
+ // TODO upgrade to a full type tag check or something similar
113
+ ctx.error(s " Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was ${obj.getClass}" , pos)
114
+ None
115
+ }
116
+ } catch {
117
+ case ex : StopInterpretation =>
118
+ ctx.error(ex.msg, ex.pos)
119
+ None
120
+ }
121
+ }
122
+
123
+ private def loadModule (sym : Symbol ): (Class [_], Object ) = {
124
+ if (sym.owner.is(Package )) {
125
+ // is top level object
126
+ (loadClass(sym.companionModule.fullName), null )
127
+ } else {
128
+ // nested object in an object
129
+ val clazz = loadClass(sym.fullNameSeparated(FlatName ))
130
+ (clazz, clazz.newInstance().asInstanceOf [Object ])
131
+ }
132
+ }
133
+
134
+ private def loadClass (name : Name ): Class [_] = {
135
+ try classLoader.loadClass(name.toString)
136
+ catch {
137
+ case _ : ClassNotFoundException =>
138
+ val msg = s " Could not find interpreted class $name in classpath "
139
+ throw new StopInterpretation (msg, pos)
140
+ }
141
+ }
142
+
143
+ private def getMethod (clazz : Class [_], name : Name , paramClasses : List [Class [_]]): Method = {
144
+ try clazz.getMethod(name.toString, paramClasses : _* )
145
+ catch {
146
+ case _ : NoSuchMethodException =>
147
+ val msg = s " Could not find interpreted method ${clazz.getCanonicalName}. $name with parameters $paramClasses"
148
+ throw new StopInterpretation (msg, pos)
149
+ }
150
+ }
151
+
152
+ private def stopIfRuntimeException [T ](thunk : => T ): T = {
153
+ try thunk
154
+ catch {
155
+ case ex : RuntimeException =>
156
+ val sw = new StringWriter ()
157
+ sw.write(" A runtime exception occurred while interpreting\n " )
158
+ sw.write(ex.getMessage)
159
+ sw.write(" \n " )
160
+ ex.printStackTrace(new PrintWriter (sw))
161
+ sw.write(" \n " )
162
+ throw new StopInterpretation (sw.toString, pos)
163
+ }
164
+ }
165
+
166
+ /** List of classes of the parameters of the signature of `sym` */
167
+ private def paramsSig (sym : Symbol ): List [Class [_]] = {
168
+ sym.signature.paramsSig.map { param =>
169
+ defn.valueTypeNameToJavaType(param) match {
170
+ case Some (clazz) => clazz
171
+ case None => classLoader.loadClass(param.toString)
172
+ }
173
+ }
174
+ }
175
+
176
+ /** Get the default value for the given class */
177
+ private def defaultValue (clazz : Class [_]): Object = {
178
+ if (clazz == classOf [Boolean ]) false .asInstanceOf [Object ]
179
+ else if (clazz == classOf [Byte ]) 0 .toByte.asInstanceOf [Object ]
180
+ else if (clazz == classOf [Char ]) 0 .toChar.asInstanceOf [Object ]
181
+ else if (clazz == classOf [Short ]) 0 .asInstanceOf [Object ]
182
+ else if (clazz == classOf [Int ]) 0 .asInstanceOf [Object ]
183
+ else if (clazz == classOf [Long ]) 0L .asInstanceOf [Object ]
184
+ else if (clazz == classOf [Float ]) 0f .asInstanceOf [Object ]
185
+ else if (clazz == classOf [Double ]) 0d .asInstanceOf [Object ]
186
+ else null
187
+ }
188
+
189
+ /** Exception that stops interpretation if some issue is found */
190
+ private class StopInterpretation (val msg : String , val pos : Position ) extends Exception
191
+
192
+ }
193
+
80
194
}
0 commit comments