Skip to content

Commit 37b86e3

Browse files
committed
WIP
1 parent 67d67ff commit 37b86e3

File tree

3 files changed

+143
-158
lines changed

3 files changed

+143
-158
lines changed

compiler/src/dotty/tools/dotc/interpreter/Interpreter.scala

Lines changed: 0 additions & 128 deletions
This file was deleted.

compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import typer.Implicits.SearchFailureType
2424
import scala.collection.mutable
2525
import dotty.tools.dotc.core.StdNames._
2626
import dotty.tools.dotc.core.quoted._
27-
import dotty.tools.dotc.interpreter.Interpreter
2827

2928

3029
/** Translates quoted terms and types to `unpickle` method calls.
Lines changed: 143 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
package dotty.tools.dotc
22
package transform
33

4+
import java.io.{PrintWriter, StringWriter}
5+
import java.lang.reflect.Method
6+
import java.net.URLClassLoader
7+
48
import dotty.tools.dotc.ast.tpd
59
import dotty.tools.dotc.core.Contexts._
610
import dotty.tools.dotc.core.Decorators._
11+
import dotty.tools.dotc.core.Flags.Package
12+
import dotty.tools.dotc.core.NameKinds.FlatName
13+
import dotty.tools.dotc.core.Names.Name
714
import dotty.tools.dotc.core.quoted._
815
import dotty.tools.dotc.core.Types._
916
import dotty.tools.dotc.core.Symbols._
10-
import dotty.tools.dotc.interpreter._
1117

1218
import scala.util.control.NonFatal
13-
1419
import dotty.tools.dotc.util.Positions.Position
1520

21+
import scala.reflect.ClassTag
22+
1623
/** Utility class to splice quoted expressions */
1724
object Splicer {
1825
import tpd._
@@ -27,38 +34,39 @@ object Splicer {
2734
}
2835

2936
private def reflectiveSplice(tree: Tree, call: Tree, bindings: List[Tree], pos: Position)(implicit ctx: Context): Tree = {
30-
val liftedArgs = {
31-
val bindMap = bindings.map {
32-
case vdef: ValDef => (vdef.rhs, ref(vdef.symbol))
33-
}.toMap
34-
def allArgs(call: Tree, acc: List[List[Tree]]): List[List[Tree]] = call match {
35-
case call: Apply => allArgs(call.fun, call.args :: acc)
36-
case call: TypeApply => allArgs(call.fun, call.args :: acc)
37-
case _ => acc
38-
}
39-
def liftArgs(tpe: Type, args: List[List[Tree]]): List[Any] = tpe match {
40-
case tp: MethodType =>
41-
val args1 = args.head.zip(tp.paramInfos).map {
42-
case (arg: Literal, tp) if tp.hasAnnotation(defn.InlineParamAnnot) => arg.const.value
43-
case (arg, tp) =>
44-
assert(!tp.hasAnnotation(defn.InlineParamAnnot))
45-
new scala.quoted.Exprs.TreeExpr(bindMap.getOrElse(arg, arg))
46-
}
47-
args1 ::: liftArgs(tp.resType, args.tail)
48-
case tp: PolyType =>
49-
val args1 = args.head.map(tp => new scala.quoted.Types.TreeType(tp))
50-
args1 ::: liftArgs(tp.resType, args.tail)
51-
case _ => Nil
52-
}
53-
54-
liftArgs(call.symbol.info, allArgs(call, Nil))
55-
}
56-
37+
val liftedArgs = getLiftedArgs(call, bindings)
5738
val interpreter = new Interpreter(pos)
5839
val interpreted = interpreter.interpretCallToSymbol[Seq[Any] => Object](call.symbol)
5940
interpreted.flatMap(lambda => evaluateLambda(lambda, liftedArgs, pos)).fold(tree)(PickledQuotes.quotedExprToTree)
6041
}
6142

43+
private def getLiftedArgs(call: Tree, bindings: List[Tree])(implicit ctx: Context) = {
44+
val bindMap = bindings.map {
45+
case vdef: ValDef => (vdef.rhs, ref(vdef.symbol))
46+
}.toMap
47+
def allArgs(call: Tree, acc: List[List[Tree]]): List[List[Tree]] = call match {
48+
case call: Apply => allArgs(call.fun, call.args :: acc)
49+
case call: TypeApply => allArgs(call.fun, call.args :: acc)
50+
case _ => acc
51+
}
52+
def liftArgs(tpe: Type, args: List[List[Tree]]): List[Any] = tpe match {
53+
case tp: MethodType =>
54+
val args1 = args.head.zip(tp.paramInfos).map {
55+
case (arg: Literal, tp) if tp.hasAnnotation(defn.InlineParamAnnot) => arg.const.value
56+
case (arg, tp) =>
57+
assert(!tp.hasAnnotation(defn.InlineParamAnnot))
58+
new scala.quoted.Exprs.TreeExpr(bindMap.getOrElse(arg, arg))
59+
}
60+
args1 ::: liftArgs(tp.resType, args.tail)
61+
case tp: PolyType =>
62+
val args1 = args.head.map(tp => new scala.quoted.Types.TreeType(tp))
63+
args1 ::: liftArgs(tp.resType, args.tail)
64+
case _ => Nil
65+
}
66+
67+
liftArgs(call.symbol.info, allArgs(call, Nil))
68+
}
69+
6270
private def evaluateLambda(lambda: Seq[Any] => Object, args: Seq[Any], pos: Position)(implicit ctx: Context): Option[scala.quoted.Expr[_]] = {
6371
try Some(lambda(args).asInstanceOf[scala.quoted.Expr[_]])
6472
catch {
@@ -77,4 +85,110 @@ object Splicer {
7785
}
7886
}
7987

88+
/** Tree interpreter that can interpret calls to static methods with it's default arguments
89+
*
90+
* The interpreter assumes that all calls in the trees are to code that was
91+
* previously compiled and is present in the classpath of the current context.
92+
*/
93+
private class Interpreter(pos: Position)(implicit ctx: Context) {
94+
95+
private[this] val classLoader = {
96+
val urls = ctx.settings.classpath.value.split(':').map(cp => java.nio.file.Paths.get(cp).toUri.toURL)
97+
new URLClassLoader(urls, getClass.getClassLoader)
98+
}
99+
100+
/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
101+
* Return Some of the result or None if some error happen during the interpretation.
102+
*/
103+
def interpretCallToSymbol[T](sym: Symbol)(implicit ct: ClassTag[T]): Option[T] = {
104+
try {
105+
val (clazz, instance) = loadModule(sym.owner)
106+
val paramClasses = paramsSig(sym)
107+
val interpretedArgs = paramClasses.map(defaultValue)
108+
val method = getMethod(clazz, sym.name, paramClasses)
109+
stopIfRuntimeException(method.invoke(instance, interpretedArgs: _*)) match {
110+
case obj: T => Some(obj)
111+
case obj =>
112+
// TODO upgrade to a full type tag check or something similar
113+
ctx.error(s"Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was ${obj.getClass}", pos)
114+
None
115+
}
116+
} catch {
117+
case ex: StopInterpretation =>
118+
ctx.error(ex.msg, ex.pos)
119+
None
120+
}
121+
}
122+
123+
private def loadModule(sym: Symbol): (Class[_], Object) = {
124+
if (sym.owner.is(Package)) {
125+
// is top level object
126+
(loadClass(sym.companionModule.fullName), null)
127+
} else {
128+
// nested object in an object
129+
val clazz = loadClass(sym.fullNameSeparated(FlatName))
130+
(clazz, clazz.newInstance().asInstanceOf[Object])
131+
}
132+
}
133+
134+
private def loadClass(name: Name): Class[_] = {
135+
try classLoader.loadClass(name.toString)
136+
catch {
137+
case _: ClassNotFoundException =>
138+
val msg = s"Could not find interpreted class $name in classpath"
139+
throw new StopInterpretation(msg, pos)
140+
}
141+
}
142+
143+
private def getMethod(clazz: Class[_], name: Name, paramClasses: List[Class[_]]): Method = {
144+
try clazz.getMethod(name.toString, paramClasses: _*)
145+
catch {
146+
case _: NoSuchMethodException =>
147+
val msg = s"Could not find interpreted method ${clazz.getCanonicalName}.$name with parameters $paramClasses"
148+
throw new StopInterpretation(msg, pos)
149+
}
150+
}
151+
152+
private def stopIfRuntimeException[T](thunk: => T): T = {
153+
try thunk
154+
catch {
155+
case ex: RuntimeException =>
156+
val sw = new StringWriter()
157+
sw.write("A runtime exception occurred while interpreting\n")
158+
sw.write(ex.getMessage)
159+
sw.write("\n")
160+
ex.printStackTrace(new PrintWriter(sw))
161+
sw.write("\n")
162+
throw new StopInterpretation(sw.toString, pos)
163+
}
164+
}
165+
166+
/** List of classes of the parameters of the signature of `sym` */
167+
private def paramsSig(sym: Symbol): List[Class[_]] = {
168+
sym.signature.paramsSig.map { param =>
169+
defn.valueTypeNameToJavaType(param) match {
170+
case Some(clazz) => clazz
171+
case None => classLoader.loadClass(param.toString)
172+
}
173+
}
174+
}
175+
176+
/** Get the default value for the given class */
177+
private def defaultValue(clazz: Class[_]): Object = {
178+
if (clazz == classOf[Boolean]) false.asInstanceOf[Object]
179+
else if (clazz == classOf[Byte]) 0.toByte.asInstanceOf[Object]
180+
else if (clazz == classOf[Char]) 0.toChar.asInstanceOf[Object]
181+
else if (clazz == classOf[Short]) 0.asInstanceOf[Object]
182+
else if (clazz == classOf[Int]) 0.asInstanceOf[Object]
183+
else if (clazz == classOf[Long]) 0L.asInstanceOf[Object]
184+
else if (clazz == classOf[Float]) 0f.asInstanceOf[Object]
185+
else if (clazz == classOf[Double]) 0d.asInstanceOf[Object]
186+
else null
187+
}
188+
189+
/** Exception that stops interpretation if some issue is found */
190+
private class StopInterpretation(val msg: String, val pos: Position) extends Exception
191+
192+
}
193+
80194
}

0 commit comments

Comments
 (0)