Skip to content

Commit 9b694ae

Browse files
committed
Implement diagram macros
1 parent a173bbd commit 9b694ae

File tree

3 files changed

+116
-26
lines changed

3 files changed

+116
-26
lines changed

scalatest/src/main/scala/org/scalatest/DiagrammedExpr.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ private[org] case class AnchorValue(anchor: Int, value: Any)
3030
* so that the generated code can be compiled. It is expected that ScalaTest users would ever need to use <code>DiagrammedExpr</code>
3131
* directly.
3232
*/
33-
trait DiagrammedExpr[T] {
33+
trait DiagrammedExpr[+T] {
3434
val anchor: Int
3535
def anchorValues: List[AnchorValue]
3636
def value: T

scalatest/src/main/scala/org/scalatest/DiagrammedExprMacro.scala

Lines changed: 98 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -422,41 +422,121 @@ import scala.quoted._
422422

423423

424424
object DiagrammedExprMacro {
425+
def let[S: Type, T](expr: Expr[S])(body: Expr[S] => Expr[T]): Expr[T] =
426+
'{
427+
val x = ~expr
428+
~body('(x))
429+
}
430+
431+
def lets[S: Type, T](xs: List[Expr[S]])(body: List[Expr[S]] => Expr[T]): Expr[T] = {
432+
def rec(xs: List[Expr[S]], acc: List[Expr[S]]): Expr[T] = xs match {
433+
case Nil => body(acc)
434+
case x :: xs => let(x) { (x: Expr[S]) => rec(xs, x :: acc) }
435+
}
436+
rec(xs, Nil)
437+
}
425438

426439
// Transform the input expression by parsing out the anchor and generate expression that can support diagram rendering
427-
def parse(condition: Expr[Boolean])(implicit refl: Reflection): Expr[DiagrammedExpr[Boolean]] = {
440+
def parse[T:Type](expr: Expr[T])(implicit refl: Reflection): Expr[DiagrammedExpr[T]] = {
428441
import refl._
429442
import quoted.Toolbox.Default._
430443
import Term._
431444

432445
def isXmlSugar(apply: Apply): Boolean = apply.tpe <:< typeOf[scala.xml.Elem]
433446
def isJavaStatic(tree: Tree): Boolean = tree.symbol.flags.is(Flags.Static)
434447

435-
condition.unseal match {
436-
case Apply(Select(New(_), _), _) => simpleExpr(condition) // delegate to simpleExpr if it is a New expression
437-
case IsApply(apply) if isXmlSugar(apply) => simpleExpr(condition)
438-
case IsApply(apply) if isJavaStatic(apply) => simpleExpr(condition)
439-
case IsApply(apply) => applyExpr(condition) // delegate to applyExpr if it is Apply
440-
case IsTypeApply(apply) => applyExpr(condition) // delegate to applyExpr if it is Apply
441-
case Select(This(_), _) => simpleExpr(condition) // delegate to simpleExpr if it is a Select for this, e.g. referring a to instance member.
442-
case IsSelect(x) if x.symbol.flags.is(Flags.Object) => simpleExpr(condition) // don't traverse packages
443-
case IsSelect(x) if isJavaStatic(x) => simpleExpr(condition)
444-
case IsSelect(select) => selectExpr(condition) // delegate to selectExpr if it is a Select
448+
expr.unseal match {
449+
case Apply(Select(New(_), _), _) => simpleExpr(expr) // delegate to simpleExpr if it is a New expression
450+
case IsApply(apply) if isXmlSugar(apply) => simpleExpr(expr)
451+
case IsApply(apply) if isJavaStatic(apply) => simpleExpr(expr)
452+
case IsApply(apply) => applyExpr(expr) // delegate to applyExpr if it is Apply
453+
case IsTypeApply(apply) => applyExpr(expr) // delegate to applyExpr if it is Apply
454+
case Select(This(_), _) => simpleExpr(expr) // delegate to simpleExpr if it is a Select for this, e.g. referring a to instance member.
455+
case IsSelect(x) if x.symbol.flags.is(Flags.Object) => simpleExpr(expr) // don't traverse packages
456+
case IsSelect(x) if isJavaStatic(x) => simpleExpr(expr)
457+
case IsSelect(select) => selectExpr(expr) // delegate to selectExpr if it is a Select
445458
case Block(stats, expr) =>
446-
Block(stats, parse(expr.seal[Boolean]).unseal).seal[DiagrammedExpr[Boolean]] // call parse recursively using the expr argument if it is a block
447-
case _ => simpleExpr(condition) // for others, just delegate to simpleExpr
459+
Block(stats, parse(expr.seal[T]).unseal).seal[DiagrammedExpr[T]] // call parse recursively using the expr argument if it is a block
460+
case _ => simpleExpr(expr) // for others, just delegate to simpleExpr
448461
}
449462
}
450463

451-
def applyExpr(condition: Expr[Boolean])(implicit refl: Reflection): Expr[DiagrammedExpr[Boolean]] = ???
452-
def selectExpr(condition: Expr[Boolean])(implicit refl: Reflection): Expr[DiagrammedExpr[Boolean]] = ???
464+
def applyExpr[T:Type](expr: Expr[T])(implicit refl: Reflection): Expr[DiagrammedExpr[T]] = {
465+
import refl._
466+
import quoted.Toolbox.Default._
467+
import Term._
468+
469+
def apply(l: Expr[_], name: String, r: List[Expr[_]]): Expr[T] = ???
470+
471+
expr.unseal.underlyingArgument match {
472+
case Term.Apply(Term.Select(lhs, op), rhs :: Nil) =>
473+
op match {
474+
case "||" | "|" =>
475+
val left = parse(lhs.seal[T & Boolean])
476+
val right = parse(rhs.seal[T & Boolean])
477+
'{
478+
val l = ~left
479+
val r = ~right
480+
if (l.value) l
481+
else DiagrammedExpr.applyExpr(l, r :: Nil, r.value, ~getAnchor(expr))
482+
}
483+
case "&&" | "&" =>
484+
val left = parse(lhs.seal[T & Boolean])
485+
val right = parse(rhs.seal[T & Boolean])
486+
'{
487+
val l = ~left
488+
val r = ~right
489+
if (l.value) DiagrammedExpr.applyExpr(l, r :: Nil, r.value, ~getAnchor(expr))
490+
else l
491+
}
492+
case _ =>
493+
val left = parse(lhs.seal[Any])
494+
val right = parse(rhs.seal[Any])
495+
'{
496+
val l = ~left
497+
val r = ~right
498+
val res = ~apply('(l.value), op, '(r.value) :: Nil)
499+
DiagrammedExpr.applyExpr(l, r :: Nil, res, ~getAnchor(expr))
500+
}
501+
}
502+
case Term.Apply(Term.Select(lhs, op), args) =>
503+
val left = parse(lhs.seal[Any])
504+
val rights = args.map(arg => parse(arg.seal[Any]))
505+
506+
let(left) { (l: Expr[DiagrammedExpr[_]]) =>
507+
lets(rights) { (rs: List[Expr[DiagrammedExpr[_]]]) =>
508+
val res = apply('((~l).value), op, rs)
509+
'{ DiagrammedExpr.applyExpr(~l, ~rs.toExprOfList, ~res, ~getAnchor(expr)) }
510+
}
511+
}
512+
case _ =>
513+
simpleExpr(expr)
514+
}
515+
}
516+
517+
def selectExpr[T:Type](expr: Expr[T])(implicit refl: Reflection): Expr[DiagrammedExpr[T]] = {
518+
import refl._
519+
import quoted.Toolbox.Default._
520+
import Term._
521+
522+
def selectField(o: Expr[_], name: String): Expr[T] = ???
523+
524+
expr.unseal match {
525+
case Select(qual, name) =>
526+
val obj = parse(qual.seal[Any])
527+
528+
'{
529+
val o = ~obj
530+
DiagrammedExpr.selectExpr(o, ~selectField('(o.value), name), ~getAnchor(expr))
531+
}
532+
}
533+
}
453534

454535
def transform(
455-
helper:Expr[(DiagrammedExpr[Boolean], Any, String, source.Position) => Assertion],
536+
helper: Expr[(DiagrammedExpr[Boolean], Any, String, source.Position) => Assertion],
456537
condition: Expr[Boolean], prettifier: Expr[Prettifier],
457538
pos: Expr[source.Position], clue: Expr[Any], sourceText: String
458-
)
459-
(implicit refl: Reflection): Expr[Assertion] = ???
539+
)(implicit refl: Reflection): Expr[Assertion] = ???
460540

461541

462542
/**

scalatest/src/main/scala/org/scalatest/matchers/TypeMatcherMacro.scala

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,15 +182,20 @@ private[scalatest] object TypeMatcherMacro {
182182
// Do checking on type parameter and generate AST to call TypeMatcherHelper.checkAType, used by 'shouldBe a [type]' syntax
183183
def shouldBeATypeImpl(self: Expr[org.scalatest.Matchers#AnyShouldWrapper[_]], aType: Expr[ResultOfATypeInvocation[_]])(implicit refl: Reflection): Expr[org.scalatest.Assertion] = {
184184
import refl._
185-
checkTypeParameter(refl)(aType.unseal, "an")
185+
checkTypeParameter(refl)(aType.unseal, "a")
186186
'{
187187
TypeMatcherHelper.assertAType((~self).leftSideValue, ~aType, (~self).prettifier, (~self).pos)
188188
}
189189
}
190190

191-
// // Do checking on type parameter and generate AST to call TypeMatcherHelper.checkAType, used by 'mustBe a [type]' syntax
192-
// def mustBeATypeImpl(context: Context)(aType: context.Expr[ResultOfATypeInvocation[_]]): context.Expr[org.scalatest.Assertion] =
193-
// assertTypeImpl(context)(aType.tree, "mustBe a", "assertAType")
191+
// Do checking on type parameter and generate AST to call TypeMatcherHelper.checkAType, used by 'mustBe a [type]' syntax
192+
def mustBeATypeImpl(self: Expr[org.scalatest.MustMatchers#AnyMustWrapper[_]], aType: Expr[ResultOfATypeInvocation[_]])(implicit refl: Reflection): Expr[org.scalatest.Assertion] = {
193+
import refl._
194+
checkTypeParameter(refl)(aType.unseal, "a")
195+
'{
196+
TypeMatcherHelper.assertAType((~self).leftSideValue, ~aType, (~self).prettifier, (~self).pos)
197+
}
198+
}
194199

195200
// Do checking on type parameter and generate AST to call TypeMatcherHelper.checkAType, used by 'shouldBe an [type]' syntax
196201
def shouldBeAnTypeImpl(self: Expr[org.scalatest.Matchers#AnyShouldWrapper[_]], anType: Expr[ResultOfAnTypeInvocation[_]])(implicit refl: Reflection): Expr[org.scalatest.Assertion] = {
@@ -201,9 +206,14 @@ private[scalatest] object TypeMatcherMacro {
201206
}
202207
}
203208

204-
// // Do checking on type parameter and generate AST to call TypeMatcherHelper.checkAnType, used by 'mustBe an [type]' syntax
205-
// def mustBeAnTypeImpl(context: Context)(anType: context.Expr[ResultOfAnTypeInvocation[_]]): context.Expr[org.scalatest.Assertion] =
206-
// assertTypeImpl(context)(anType.tree, "mustBe an", "assertAnType")
209+
// Do checking on type parameter and generate AST to call TypeMatcherHelper.checkAnType, used by 'mustBe an [type]' syntax
210+
def mustBeAnTypeImpl(self: Expr[org.scalatest.MustMatchers#AnyMustWrapper[_]], anType: Expr[ResultOfAnTypeInvocation[_]])(implicit refl: Reflection): Expr[org.scalatest.Assertion] = {
211+
import refl._
212+
checkTypeParameter(refl)(anType.unseal, "an")
213+
'{
214+
TypeMatcherHelper.assertAnType((~self).leftSideValue, ~anType, (~self).prettifier, (~self).pos)
215+
}
216+
}
207217

208218
// /*def expectTypeImpl(context: Context)(tree: context.Tree, beMethodName: String, assertMethodName: String): context.Expr[org.scalatest.Fact] = {
209219
// import context.universe._

0 commit comments

Comments
 (0)