Skip to content

Commit 3af47d0

Browse files
nicolasstuckigriggt
authored andcommitted
Use typed quotes API
1 parent 1aab38f commit 3af47d0

File tree

1 file changed

+24
-43
lines changed

1 file changed

+24
-43
lines changed

shared/src/main/scala-3/verify/asserts/RecorderMacro.scala

Lines changed: 24 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,13 @@ class RecorderMacro(using qctx0: Quotes) {
2727
recording: Expr[A],
2828
message: Expr[String],
2929
listener: Expr[RecorderListener[A, R]]): Expr[R] = {
30-
val termArg: Term = recording.asTerm.underlyingArgument // TODO remove use of underlyingArgument
30+
val termArg = recording.asTerm.underlyingArgument.asExprOf[A] // TODO remove use of underlyingArgument
3131

3232
'{
3333
val recorderRuntime: RecorderRuntime[A, R] = new RecorderRuntime($listener)
3434
recorderRuntime.recordMessage($message)
35-
${
36-
Block(
37-
recordExpressions('{ recorderRuntime }.asTerm, termArg),
38-
'{ recorderRuntime.completeRecording() }.asTerm
39-
).asExprOf[R]
40-
}
35+
${recordExpressions('recorderRuntime, termArg)}
36+
recorderRuntime.completeRecording()
4137
}
4238
}
4339

@@ -46,33 +42,22 @@ class RecorderMacro(using qctx0: Quotes) {
4642
found: Expr[A],
4743
message: Expr[String],
4844
listener: Expr[RecorderListener[A, R]]): Expr[R] = {
49-
val expectedArg: Term = expected.asTerm
50-
val foundArg: Term = found.asTerm
51-
5245
'{
5346
val recorderRuntime: RecorderRuntime[A, R] = new RecorderRuntime($listener)
5447
recorderRuntime.recordMessage($message)
55-
${
56-
Block(
57-
recordExpressions('{ recorderRuntime }.asTerm, expectedArg) :::
58-
recordExpressions('{ recorderRuntime }.asTerm, foundArg),
59-
'{ recorderRuntime.completeRecording() }.asTerm
60-
).asExprOf[R]
61-
}
48+
${recordExpressions('recorderRuntime, expected)}
49+
${recordExpressions('recorderRuntime, found)}
50+
recorderRuntime.completeRecording()
6251
}
6352
}
6453

65-
private[this] def recordExpressions(runtime: Term, recording: Term): List[Term] = {
54+
private[this] def recordExpressions[A: Type, R: Type](runtime: Expr[RecorderRuntime[A, R]], recording: Expr[A]): Expr[Any] = {
6655
val source = getSourceCode(recording)
67-
val ast = recording.show(using Printer.TreeStructure)
56+
val ast = recording.asTerm.show(using Printer.TreeStructure)
6857

69-
val resetValuesSel: Term = {
70-
val m = runtimeSym.memberMethod("resetValues").head
71-
runtime.select(m)
72-
}
7358
try {
74-
List(
75-
Apply(resetValuesSel, List()),
59+
Expr.block(
60+
List('{ $runtime.resetValues() }),
7661
recordExpression(runtime, source, ast, recording)
7762
)
7863
} catch {
@@ -82,21 +67,14 @@ class RecorderMacro(using qctx0: Quotes) {
8267
}
8368

8469
// emit recorderRuntime.recordExpression(<source>, <tree>, instrumented)
85-
private[this] def recordExpression(runtime: Term, source: String, ast: String, expr: Term): Term = {
86-
val instrumented = recordAllValues(runtime, expr)
87-
val recordExpressionSel: Term = {
88-
val m = runtimeSym.memberMethod("recordExpression").head
89-
runtime.select(m)
90-
}
91-
Apply(recordExpressionSel,
92-
List(
93-
Literal(StringConstant(source)),
94-
Literal(StringConstant(ast)),
95-
instrumented
96-
))
70+
private[this] def recordExpression[R: Type, A: Type](runtime: Expr[RecorderRuntime[A, R]], source: String, ast: String, expr: Expr[A]): Expr[Any] = {
71+
val instrumented = recordAllValues(runtime, expr.asTerm).asExprOf[A]
72+
val sourceExpr = Expr(source)
73+
val astExpr = Expr(ast)
74+
'{ $runtime.recordExpression($sourceExpr, $astExpr, $instrumented) }
9775
}
9876

99-
private[this] def recordAllValues(runtime: Term, expr: Term): Term =
77+
private[this] def recordAllValues[R, A](runtime: Expr[RecorderRuntime[A, R]], expr: Term): Term =
10078
// TODO use an TreeMap or an ExprMap
10179
expr match {
10280
case New(_) => expr
@@ -109,7 +87,7 @@ class RecorderMacro(using qctx0: Quotes) {
10987
case _ => recordValue(runtime, recordSubValues(runtime, expr), expr)
11088
}
11189

112-
private[this] def recordSubValues(runtime: Term, expr: Term): Term =
90+
private[this] def recordSubValues[R, A](runtime: Expr[RecorderRuntime[A, R]], expr: Term): Term =
11391
expr match {
11492
case Apply(x, ys) =>
11593
try {
@@ -125,12 +103,12 @@ class RecorderMacro(using qctx0: Quotes) {
125103
case _ => expr
126104
}
127105

128-
private[this] def recordValue(runtime: Term, expr: Term, origExpr: Term): Term = {
106+
private[this] def recordValue[R, A](runtime: Expr[RecorderRuntime[A, R]], expr: Term, origExpr: Term): Term = {
129107
// debug
130108
// println("recording " + expr.showExtractors + " at " + getAnchor(expr))
131109
val recordValueSel: Term = {
132110
val m = runtimeSym.memberMethod("recordValue").head
133-
runtime.select(m)
111+
runtime.asTerm.select(m)
134112
}
135113
def skipIdent(sym: Symbol): Boolean =
136114
sym match {
@@ -156,6 +134,9 @@ class RecorderMacro(using qctx0: Quotes) {
156134
case TypeApply(_, _) => expr
157135
case Ident(_) if skipIdent(expr.symbol) => expr
158136
case _ =>
137+
// TODO:
138+
// expr.asExpr match { case '{ $e: t } => '{ $runtime.recordValue[t]($e, ${Expr(getAnchor(expr))}) } }
139+
// Then remove `runtimeSym`
159140
val tapply = recordValueSel.appliedToType(expr.tpe)
160141
Apply.copy(expr)(
161142
tapply,
@@ -167,8 +148,8 @@ class RecorderMacro(using qctx0: Quotes) {
167148
}
168149
}
169150

170-
private[this] def getSourceCode(expr: Tree): String = {
171-
val pos = expr.pos
151+
private[this] def getSourceCode(expr: Expr[Any]): String = {
152+
val pos = expr.asTerm.pos
172153
(" " * pos.startColumn) + pos.sourceCode.get
173154
}
174155

0 commit comments

Comments
 (0)