Skip to content

Fix #10151: Fix #10211: Fix changeOwnity for trees, assembled from multiple parts. #10218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -799,13 +799,29 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* owner to `to`, and continue until a non-weak owner is reached.
*/
def changeOwner(from: Symbol, to: Symbol)(using Context): ThisTree = {
@tailrec def loop(from: Symbol, froms: List[Symbol], tos: List[Symbol]): ThisTree =
if (from.isWeakOwner && !from.owner.isClass)
loop(from.owner, from :: froms, to :: tos)
else
//println(i"change owner ${from :: froms}%, % ==> $tos of $tree")
TreeTypeMap(oldOwners = from :: froms, newOwners = tos).apply(tree)
if (from == to) tree else loop(from, Nil, to :: Nil)
changeOwners(List(from),to)
}

/** Change owner from all `froms` to `to`. If `from` is a weak owner, also change its
* owner to `to`, and continue until a non-weak owner is reached.
*/
def changeOwners(froms: List[Symbol], to: Symbol)(using Context): ThisTree = {
@tailrec def loop(froms: List[Symbol], processedFroms: List[Symbol], tos: List[Symbol]): ThisTree =
froms match
case from::rest =>
if (from == to)
loop(rest, processedFroms, tos)
else
if (from.isWeakOwner && !from.owner.isClass)
loop(from.owner::rest, from :: processedFroms, to :: tos)
else
loop(rest, from::processedFroms, to :: tos)
case Nil =>
if (processedFroms.isEmpty)
tree
else
TreeTypeMap(oldOwners = processedFroms, newOwners = tos).apply(tree)
loop(froms, Nil, Nil)
}

/**
Expand Down
23 changes: 11 additions & 12 deletions compiler/src/dotty/tools/dotc/quoted/QuoteUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,23 @@ import dotty.tools.dotc.core.Symbols._
object QuoteUtils:
import tpd._

/** Get the owner of a tree if it has one */
def treeOwner(tree: Tree)(using Context): Option[Symbol] = {
val getCurrentOwner = new TreeAccumulator[Option[Symbol]] {
def apply(x: Option[Symbol], tree: tpd.Tree)(using Context): Option[Symbol] =
if (x.isDefined) x
else tree match {
case tree: DefTree => Some(tree.symbol.owner)
case _ => foldOver(x, tree)
/** Get the list of owners of a tree if it has one */
def treeOwners(tree: Tree)(using Context): List[Symbol] = {
val getOwners = new TreeAccumulator[Map[Int,Symbol]] {
def apply(x: Map[Int,Symbol], tree: tpd.Tree)(using Context): Map[Int,Symbol] =
tree match {
case tree: DefTree => val owner = tree.symbol.owner
x.updated(owner.id, owner)
case _ => foldOver(x,tree)
}
}
getCurrentOwner(None, tree)
getOwners(Map.empty,tree).values.toList
}


/** Changes the owner of the tree based on the current owner of the tree */
def changeOwnerOfTree(tree: Tree, owner: Symbol)(using Context): Tree = {
treeOwner(tree) match
case Some(oldOwner) if oldOwner != owner => tree.changeOwner(oldOwner, owner)
case _ => tree
tree.changeOwners(treeOwners(tree), owner)
}

end QuoteUtils
93 changes: 93 additions & 0 deletions tests/pos-macros/i10151/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package x

import scala.quoted._

trait CB[T]:
def map[S](f: T=>S): CB[S] = ???
def flatMap[S](f: T=>CB[S]): CB[S] = ???

class MyArr[AK,AV]:
def map1[BK,BV](f: ((AK,AV)) => (BK, BV)):MyArr[BK,BV] = ???
def map1Out[BK, BV](f: ((AK,AV)) => CB[(BK,BV)]): CB[MyArr[BK,BV]] = ???

def await[T](x:CB[T]):T = ???

object CBM:
def pure[T](t:T):CB[T] = ???

object X:

inline def process[T](inline f:T) = ${
processImpl[T]('f)
}

def processImpl[T:Type](f:Expr[T])(using qctx: QuoteContext):Expr[CB[T]] =
import qctx.reflect._

def transform(term:Term):Term =
term match
case Apply(TypeApply(Select(obj,"map1"),targs),args) =>
val nArgs = args.map(x => shiftLambda(x))
val nSelect = Select.unique(obj, "map1Out")
Apply(TypeApply(nSelect,targs),nArgs)
case Apply(TypeApply(Ident("await"),targs),args) => args.head
case a@Apply(x,List(y,z)) =>
val mty=MethodType(List("y1"))( _ => List(y.tpe.widen), _ => Type[CB].unseal.tpe.appliedTo(a.tpe.widen))
val mtz=MethodType(List("z1"))( _ => List(z.tpe.widen), _ => a.tpe.widen)
Apply(
TypeApply(Select.unique(transform(y),"flatMap"),
List(Inferred(a.tpe.widen))
),
List(
Lambda(mty, yArgs =>
Apply(
TypeApply(Select.unique(transform(z),"map"),
List(Inferred(a.tpe.widen))
),
List(
Lambda(mtz, zArgs => {
val termYArgs = yArgs.asInstanceOf[List[Term]]
val termZArgs = zArgs.asInstanceOf[List[Term]]
Apply(x,List(termYArgs.head,termZArgs.head))
})
)
)
)
)
)
case Block(stats, last) => Block(stats, transform(last))
case Inlined(x,List(),body) => transform(body)
case l@Literal(x) =>
l.seal match
case '{ $l: $L } =>
'{ CBM.pure(${term.seal.cast[L]}) }.unseal
case other =>
throw RuntimeException(s"Not supported $other")

def shiftLambda(term:Term): Term =
term match
case lt@Lambda(params, body) =>
val paramTypes = params.map(_.tpt.tpe)
val paramNames = params.map(_.name)
val mt = MethodType(paramNames)(_ => paramTypes, _ => Type[CB].unseal.tpe.appliedTo(body.tpe.widen) )
Lambda(mt, args => changeArgs(params,args,transform(body)) )
case Block(stats, last) =>
Block(stats, shiftLambda(last))
case _ =>
throw RuntimeException("lambda expected")

def changeArgs(oldArgs:List[Tree], newArgs:List[Tree], body:Term):Term =
val association: Map[Symbol, Term] = (oldArgs zip newArgs).foldLeft(Map.empty){
case (m, (oldParam, newParam: Term)) => m.updated(oldParam.symbol, newParam)
case (m, (oldParam, newParam: Tree)) => throw RuntimeException("Term expected")
}
val changes = new TreeMap() {
override def transformTerm(tree:Term)(using Context): Term =
tree match
case ident@Ident(name) => association.getOrElse(ident.symbol, super.transformTerm(tree))
case _ => super.transformTerm(tree)
}
changes.transformTerm(body)

val r = transform(f.unseal).seal.cast[CB[T]]
r
15 changes: 15 additions & 0 deletions tests/pos-macros/i10151/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package x


object Main {

def main(args:Array[String]):Unit =
val arr = new MyArr[Int,Int]()
val r = X.process{
arr.map1( (x,y) =>
( 1, await(CBM.pure(x)) )
)
}
println("r")

}
102 changes: 102 additions & 0 deletions tests/pos-macros/i10211/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package x

import scala.quoted._

trait CB[T]:
def map[S](f: T=>S): CB[S] = ???


class MyArr[A]:
def map[B](f: A=>B):MyArr[B] = ???
def mapOut[B](f: A=> CB[B]): CB[MyArr[B]] = ???
def flatMap[B](f: A=>MyArr[B]):MyArr[B] = ???
def flatMapOut[B](f: A=>CB[MyArr[B]]):MyArr[B] = ???
def withFilter(p: A=>Boolean): MyArr[A] = ???
def withFilterOut(p: A=>CB[Boolean]): DelayedWithFilter[A] = ???
def map2[B](f: A=>B):MyArr[B] = ???

class DelayedWithFilter[A]:
def map[B](f: A=>B):MyArr[B] = ???
def mapOut[B](f: A=> CB[B]): CB[MyArr[B]] = ???
def flatMap[B](f: A=>MyArr[B]):MyArr[B] = ???
def flatMapOut[B](f: A=>CB[MyArr[B]]): CB[MyArr[B]] = ???
def map2[B](f: A=>B):CB[MyArr[B]] = ???


def await[T](x:CB[T]):T = ???

object CBM:
def pure[T](t:T):CB[T] = ???
def map[T,S](a:CB[T])(f:T=>S):CB[S] = ???

object X:

inline def process[T](inline f:T) = ${
processImpl[T]('f)
}

def processImpl[T:Type](f:Expr[T])(using qctx: QuoteContext):Expr[CB[T]] =
import qctx.reflect._

def transform(term:Term):Term =
term match
case ap@Apply(TypeApply(Select(obj,name),targs),args)
if (name=="map"||name=="flatMap") =>
obj match
case Apply(Select(obj1,"withFilter"),args1) =>
val nObj = transform(obj)
transform(Apply(TypeApply(Select.unique(nObj,name),targs),args))
case _ =>
val nArgs = args.map(x => shiftLambda(x))
val nSelect = Select.unique(obj, name+"Out")
Apply(TypeApply(nSelect,targs),nArgs)
case ap@Apply(Select(obj,"withFilter"),args) =>
val nArgs = args.map(x => shiftLambda(x))
val nSelect = Select.unique(obj, "withFilterOut")
Apply(nSelect,nArgs)
case ap@Apply(TypeApply(Select(obj,"map2"),targs),args) =>
val nObj = transform(obj)
Apply(TypeApply(
Select.unique(nObj,"map2"),
List(Type[Int].unseal)
),
args
)
case Apply(TypeApply(Ident("await"),targs),args) => args.head
case Apply(Select(obj,"=="),List(b)) =>
val tb = transform(b).seal.cast[CB[Int]]
val mt = MethodType(List("p"))(_ => List(b.tpe.widen), _ => Type[Boolean].unseal.tpe)
val mapLambda = Lambda(mt, x => Select.overloaded(obj,"==",List(),List(x.head.asInstanceOf[Term]))).seal.cast[Int=>Boolean]
'{ CBM.map($tb)($mapLambda) }.unseal
case Block(stats, last) => Block(stats, transform(last))
case Inlined(x,List(),body) => transform(body)
case l@Literal(x) =>
'{ CBM.pure(${term.seal}) }.unseal
case other =>
throw RuntimeException(s"Not supported $other")

def shiftLambda(term:Term): Term =
term match
case lt@Lambda(params, body) =>
val paramTypes = params.map(_.tpt.tpe)
val paramNames = params.map(_.name)
val mt = MethodType(paramNames)(_ => paramTypes, _ => Type[CB].unseal.tpe.appliedTo(body.tpe.widen) )
val r = Lambda(mt, args => changeArgs(params,args,transform(body)) )
r
case _ =>
throw RuntimeException("lambda expected")

def changeArgs(oldArgs:List[Tree], newArgs:List[Tree], body:Term):Term =
val association: Map[Symbol, Term] = (oldArgs zip newArgs).foldLeft(Map.empty){
case (m, (oldParam, newParam: Term)) => m.updated(oldParam.symbol, newParam)
case (m, (oldParam, newParam: Tree)) => throw RuntimeException("Term expected")
}
val changes = new TreeMap() {
override def transformTerm(tree:Term)(using Context): Term =
tree match
case ident@Ident(name) => association.getOrElse(ident.symbol, super.transformTerm(tree))
case _ => super.transformTerm(tree)
}
changes.transformTerm(body)

transform(f.unseal).seal.cast[CB[T]]
18 changes: 18 additions & 0 deletions tests/pos-macros/i10211/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package x


object Main {

def main(args:Array[String]):Unit =
val arr1 = new MyArr[Int]()
val arr2 = new MyArr[Int]()
val r = X.process{
arr1.withFilter(x => x == await(CBM.pure(1)))
.flatMap(x =>
arr2.withFilter( y => y == await(CBM.pure(2)) ).
map2( y => x + y )
)
}
println(r)

}