Skip to content

Add inline case classes as argument #5200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 6, 2018
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 29 additions & 28 deletions compiler/src/dotty/tools/dotc/transform/Splicer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,14 @@ object Splicer {
}

protected def interpretStaticMethodCall(fn: Tree, args: => List[Object])(implicit env: Env): Object = {
if (fn.symbol == defn.NoneModuleRef.termSymbol) {
// TODO generalize
None
} else {
val (clazz, instance) = loadModule(fn.symbol.owner)
val method = getMethod(clazz, fn.symbol.name, paramsSig(fn.symbol))
stopIfRuntimeException(method.invoke(instance, args: _*))
}
val instance = loadModule(fn.symbol.owner)
val method = getMethod(instance.getClass, fn.symbol.name, paramsSig(fn.symbol))
stopIfRuntimeException(method.invoke(instance, args: _*))
}

protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Object =
loadModule(fn.symbol.moduleClass)

protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Object = {
val clazz = loadClass(fn.symbol.owner.fullName)
val constr = clazz.getConstructor(paramsSig(fn.symbol): _*)
Expand All @@ -130,24 +128,23 @@ object Splicer {
protected def unexpectedTree(tree: Tree)(implicit env: Env): Object =
throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.pos)

private def loadModule(sym: Symbol): (Class[_], Object) = {
private def loadModule(sym: Symbol): Object = {
if (sym.owner.is(Package)) {
// is top level object
val moduleClass = loadClass(sym.fullName)
val moduleInstance = moduleClass.getField(MODULE_INSTANCE_FIELD).get(null)
(moduleClass, moduleInstance)
moduleClass.getField(MODULE_INSTANCE_FIELD).get(null)
} else {
// nested object in an object
val clazz = loadClass(sym.fullNameSeparated(FlatName))
(clazz, clazz.getConstructor().newInstance().asInstanceOf[Object])
clazz.getConstructor().newInstance().asInstanceOf[Object]
}
}

private def loadClass(name: Name): Class[_] = {
try classLoader.loadClass(name.toString)
catch {
case _: ClassNotFoundException =>
val msg = s"Could not find macro class $name in classpath$extraMsg"
val msg = s"Could not find class $name in classpath$extraMsg"
throw new StopInterpretation(msg, pos)
}
}
Expand All @@ -156,7 +153,7 @@ object Splicer {
try clazz.getMethod(name.toString, paramClasses: _*)
catch {
case _: NoSuchMethodException =>
val msg = em"Could not find macro method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)$extraMsg"
val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)$extraMsg"
throw new StopInterpretation(msg, pos)
}
}
Expand Down Expand Up @@ -257,7 +254,8 @@ object Splicer {
protected def interpretVarargs(args: List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
protected def interpretTastyContext()(implicit env: Env): Boolean = true
protected def interpretStaticMethodCall(fn: tpd.Tree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Boolean = args.forall(identity)
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Boolean = true
protected def interpretNew(fn: RefTree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)

def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Boolean = {
// Assuming that top-level splices can only be in inline methods
Expand All @@ -278,6 +276,7 @@ object Splicer {
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
protected def interpretTastyContext()(implicit env: Env): Result
protected def interpretStaticMethodCall(fn: Tree, args: => List[Result])(implicit env: Env): Result
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Result
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Result
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result

Expand All @@ -294,8 +293,17 @@ object Splicer {
case _ if tree.symbol == defn.TastyTasty_macroContext =>
interpretTastyContext()

case StaticMethodCall(fn, args) =>
interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
case Call(fn, args) =>
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
interpretNew(fn, args.map(interpretTree))
} else if (fn.symbol.isStatic) {
if (fn.symbol.is(Module)) interpretModuleAccess(fn)
else interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
} else if (env.contains(fn.name)) {
env(fn.name)
} else {
unexpectedTree(tree)
}

// Interpret `foo(j = x, i = y)` which it is expanded to
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
Expand All @@ -307,28 +315,21 @@ object Splicer {
})
interpretTree(expr)(newEnv)
case NamedArg(_, arg) => interpretTree(arg)
case Ident(name) if env.contains(name) => env(name)

case Inlined(EmptyTree, Nil, expansion) => interpretTree(expansion)

case Apply(TypeApply(fun: RefTree, _), args) if fun.symbol.isConstructor && fun.symbol.owner.owner.is(Package) =>
interpretNew(fun, args.map(interpretTree))

case Apply(fun: RefTree, args) if fun.symbol.isConstructor && fun.symbol.owner.owner.is(Package)=>
interpretNew(fun, args.map(interpretTree))

case Typed(SeqLiteral(elems, _), _) =>
interpretVarargs(elems.map(e => interpretTree(e)))

case _ =>
unexpectedTree(tree)
}

object StaticMethodCall {
object Call {
def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match {
case fn: RefTree if fn.symbol.isStatic => Some((fn, Nil))
case Apply(StaticMethodCall(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
case TypeApply(StaticMethodCall(fn, args), _) => Some((fn, args))
case fn: RefTree => Some((fn, Nil))
case Apply(Call(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
case TypeApply(Call(fn, args), _) => Some((fn, args))
case _ => None
}
}
Expand Down
30 changes: 19 additions & 11 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -679,20 +679,28 @@ trait Checking {
tree.tpe.widenTermRefExpr match {
case tp: ConstantType if exprPurity(tree) >= purityLevel => // ok
case tp =>
def isCaseClassApply(sym: Symbol): Boolean =
sym.name == nme.apply && sym.owner.is(Module) && sym.owner.companionClass.is(Case)
def isCaseClassNew(sym: Symbol): Boolean =
sym.isPrimaryConstructor && sym.owner.is(Case) && sym.owner.isStatic
def isCaseObject(sym: Symbol): Boolean = {
// TODO add alias to Nil in scala package
sym.is(Case) && sym.is(Module)
}
val allow =
ctx.erasedTypes ||
ctx.inInlineMethod ||
// TODO: Make None and Some constant types?
tree.symbol.eq(defn.NoneModuleRef.termSymbol) ||
tree.symbol.eq(defn.SomeClass.primaryConstructor) ||
(tree.symbol.name == nme.apply && tree.symbol.owner == defn.SomeClass.companionModule.moduleClass)
if (!allow) ctx.error(em"$what must be a constant expression", tree.pos)
else tree match {
// TODO: add cases for type apply and multiple applies
case Apply(_, args) =>
for (arg <- args)
checkInlineConformant(arg, isFinal, what)
case _ =>
(tree.symbol.isStatic && isCaseObject(tree.symbol) || isCaseClassApply(tree.symbol)) ||
(tree.symbol.owner.isStatic && isCaseClassNew(tree.symbol))
if (!allow) ctx.error(em"$what must be a known value", tree.pos)
else {
def checkArgs(tree: Tree): Unit = tree match {
case Apply(fn, args) =>
args.foreach(arg => checkInlineConformant(arg, isFinal, what))
checkArgs(fn)
case _ =>
}
checkArgs(tree)
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions compiler/test/dotc/run-test-pickling.blacklist
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ i5119
i5119b
inline-varargs-1
implicitShortcut
inline-case-objects
inline-option
inline-macro-staged-interpreter
inline-tuples-1
inline-tuples-2
lazy-implicit-lists.scala
lazy-implicit-nums.scala
lazy-traits.scala
Expand Down
16 changes: 16 additions & 0 deletions tests/neg/inline-case-objects/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

import scala.quoted._

object Macros {
def impl(foo: Any): Expr[String] = foo.getClass.getCanonicalName.toExpr
}

class Bar {
case object Baz
}

package foo {
class Bar {
case object Baz
}
}
11 changes: 11 additions & 0 deletions tests/neg/inline-case-objects/Main_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

object Test {

def main(args: Array[String]): Unit = {
val bar = new Bar
println(fooString(bar.Baz)) // error
}

inline def fooString(inline x: Any): String = ~Macros.impl(x)

}
33 changes: 33 additions & 0 deletions tests/neg/inline-macro-staged-interpreter/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

import scala.quoted._

object E {

inline def eval[T](inline x: E[T]): T = ~impl(x)

def impl[T](x: E[T]): Expr[T] = x.lift

}

trait E[T] {
def lift: Expr[T]
}

case class I(n: Int) extends E[Int] {
def lift: Expr[Int] = n.toExpr
}

case class Plus[T](x: E[T], y: E[T])(implicit op: Plus2[T]) extends E[T] {
def lift: Expr[T] = op(x.lift, y.lift)
}

trait Op2[T] {
def apply(x: Expr[T], y: Expr[T]): Expr[T]
}

trait Plus2[T] extends Op2[T]
object Plus2 {
implicit case object IPlus extends Plus2[Int] {
def apply(x: Expr[Int], y: Expr[Int]): Expr[Int] = '(~x + ~y)
}
}
20 changes: 20 additions & 0 deletions tests/neg/inline-macro-staged-interpreter/Main_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

object Test {

def main(args: Array[String]): Unit = {
val i = I(2)
E.eval(
i // error
)

E.eval(Plus(
i, // error
I(4)))

val plus = Plus2.IPlus
E.eval(Plus(I(2), I(4))(
plus // error
))
}

}
27 changes: 27 additions & 0 deletions tests/neg/inline-tuples-1/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

import scala.quoted._

object Macros {
def tup1(tup: Tuple1[Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup2(tup: Tuple2[Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup3(tup: Tuple3[Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup4(tup: Tuple4[Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup5(tup: Tuple5[Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup6(tup: Tuple6[Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup7(tup: Tuple7[Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup8(tup: Tuple8[Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup9(tup: Tuple9[Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup10(tup: Tuple10[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup11(tup: Tuple11[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup12(tup: Tuple12[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup13(tup: Tuple13[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup14(tup: Tuple14[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup15(tup: Tuple15[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup16(tup: Tuple16[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup17(tup: Tuple17[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup18(tup: Tuple18[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup19(tup: Tuple19[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup20(tup: Tuple20[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup21(tup: Tuple21[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
def tup22(tup: Tuple22[Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int]): Expr[Int] = tup.productIterator.map(_.asInstanceOf[Int]).sum.toExpr
}
Loading