Skip to content

Commit b21c2ce

Browse files
committed
Add regression test for scala#10211
1 parent 0744eab commit b21c2ce

File tree

2 files changed

+120
-0
lines changed

2 files changed

+120
-0
lines changed

tests/pos-macros/i10211/Macro_1.scala

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package x
2+
3+
import scala.quoted._
4+
5+
trait CB[T]:
6+
def map[S](f: T=>S): CB[S] = ???
7+
8+
9+
class MyArr[A]:
10+
def map[B](f: A=>B):MyArr[B] = ???
11+
def mapOut[B](f: A=> CB[B]): CB[MyArr[B]] = ???
12+
def flatMap[B](f: A=>MyArr[B]):MyArr[B] = ???
13+
def flatMapOut[B](f: A=>CB[MyArr[B]]):MyArr[B] = ???
14+
def withFilter(p: A=>Boolean): MyArr[A] = ???
15+
def withFilterOut(p: A=>CB[Boolean]): DelayedWithFilter[A] = ???
16+
def map2[B](f: A=>B):MyArr[B] = ???
17+
18+
class DelayedWithFilter[A]:
19+
def map[B](f: A=>B):MyArr[B] = ???
20+
def mapOut[B](f: A=> CB[B]): CB[MyArr[B]] = ???
21+
def flatMap[B](f: A=>MyArr[B]):MyArr[B] = ???
22+
def flatMapOut[B](f: A=>CB[MyArr[B]]): CB[MyArr[B]] = ???
23+
def map2[B](f: A=>B):CB[MyArr[B]] = ???
24+
25+
26+
def await[T](x:CB[T]):T = ???
27+
28+
object CBM:
29+
def pure[T](t:T):CB[T] = ???
30+
def map[T,S](a:CB[T])(f:T=>S):CB[S] = ???
31+
32+
object X:
33+
34+
inline def process[T](inline f:T) = ${
35+
processImpl[T]('f)
36+
}
37+
38+
def processImpl[T:Type](f:Expr[T])(using qctx: QuoteContext):Expr[CB[T]] =
39+
import qctx.reflect._
40+
41+
def transform(term:Term):Term =
42+
term match
43+
case ap@Apply(TypeApply(Select(obj,name),targs),args)
44+
if (name=="map"||name=="flatMap") =>
45+
obj match
46+
case Apply(Select(obj1,"withFilter"),args1) =>
47+
val nObj = transform(obj)
48+
transform(Apply(TypeApply(Select.unique(nObj,name),targs),args))
49+
case _ =>
50+
val nArgs = args.map(x => shiftLambda(x))
51+
val nSelect = Select.unique(obj, name+"Out")
52+
Apply(TypeApply(nSelect,targs),nArgs)
53+
case ap@Apply(Select(obj,"withFilter"),args) =>
54+
val nArgs = args.map(x => shiftLambda(x))
55+
val nSelect = Select.unique(obj, "withFilterOut")
56+
Apply(nSelect,nArgs)
57+
case ap@Apply(TypeApply(Select(obj,"map2"),targs),args) =>
58+
val nObj = transform(obj)
59+
Apply(TypeApply(
60+
Select.unique(nObj,"map2"),
61+
List(Type[Int].unseal)
62+
),
63+
args
64+
)
65+
case Apply(TypeApply(Ident("await"),targs),args) => args.head
66+
case Apply(Select(obj,"=="),List(b)) =>
67+
val tb = transform(b).seal.cast[CB[Int]]
68+
val mt = MethodType(List("p"))(_ => List(b.tpe.widen), _ => Type[Boolean].unseal.tpe)
69+
val mapLambda = Lambda(mt, x => Select.overloaded(obj,"==",List(),List(x.head.asInstanceOf[Term]))).seal.cast[Int=>Boolean]
70+
'{ CBM.map($tb)($mapLambda) }.unseal
71+
case Block(stats, last) => Block(stats, transform(last))
72+
case Inlined(x,List(),body) => transform(body)
73+
case l@Literal(x) =>
74+
'{ CBM.pure(${term.seal}) }.unseal
75+
case other =>
76+
throw RuntimeException(s"Not supported $other")
77+
78+
def shiftLambda(term:Term): Term =
79+
term match
80+
case lt@Lambda(params, body) =>
81+
val paramTypes = params.map(_.tpt.tpe)
82+
val paramNames = params.map(_.name)
83+
val mt = MethodType(paramNames)(_ => paramTypes, _ => Type[CB].unseal.tpe.appliedTo(body.tpe.widen) )
84+
val r = Lambda(mt, args => changeArgs(params,args,transform(body)) )
85+
r
86+
case _ =>
87+
throw RuntimeException("lambda expected")
88+
89+
def changeArgs(oldArgs:List[Tree], newArgs:List[Tree], body:Term):Term =
90+
val association: Map[Symbol, Term] = (oldArgs zip newArgs).foldLeft(Map.empty){
91+
case (m, (oldParam, newParam: Term)) => m.updated(oldParam.symbol, newParam)
92+
case (m, (oldParam, newParam: Tree)) => throw RuntimeException("Term expected")
93+
}
94+
val changes = new TreeMap() {
95+
override def transformTerm(tree:Term)(using Context): Term =
96+
tree match
97+
case ident@Ident(name) => association.getOrElse(ident.symbol, super.transformTerm(tree))
98+
case _ => super.transformTerm(tree)
99+
}
100+
changes.transformTerm(body)
101+
102+
transform(f.unseal).seal.cast[CB[T]]

tests/pos-macros/i10211/Test_2.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package x
2+
3+
4+
object Main {
5+
6+
def main(args:Array[String]):Unit =
7+
val arr1 = new MyArr[Int]()
8+
val arr2 = new MyArr[Int]()
9+
val r = X.process{
10+
arr1.withFilter(x => x == await(CBM.pure(1)))
11+
.flatMap(x =>
12+
arr2.withFilter( y => y == await(CBM.pure(2)) ).
13+
map2( y => x + y )
14+
)
15+
}
16+
println(r)
17+
18+
}

0 commit comments

Comments
 (0)