|
| 1 | +package scala.quoted.util |
| 2 | + |
| 3 | +import scala.quoted._ |
| 4 | + |
| 5 | +trait ExprMap { |
| 6 | + |
| 7 | + /** Map an expression `e` with a type `tpe` */ |
| 8 | + def transform[T](e: Expr[T])(given qctx: QuoteContext, tpe: Type[T]): Expr[T] |
| 9 | + |
| 10 | + /** Map subexpressions an expression `e` with a type `tpe` */ |
| 11 | + def transformChildren[T](e: Expr[T])(given qctx: QuoteContext, tpe: Type[T]): Expr[T] = { |
| 12 | + import qctx.tasty.{_, given} |
| 13 | + final class MapChildren() { |
| 14 | + |
| 15 | + def transformStatement(tree: Statement)(given ctx: Context): Statement = { |
| 16 | + def localCtx(definition: Definition): Context = definition.symbol.localContext |
| 17 | + tree match { |
| 18 | + case tree: Term => |
| 19 | + transformTerm(tree, defn.AnyType) |
| 20 | + case tree: Definition => |
| 21 | + transformDefinition(tree) |
| 22 | + case tree: Import => |
| 23 | + tree |
| 24 | + } |
| 25 | + } |
| 26 | + |
| 27 | + def transformDefinition(tree: Definition)(given ctx: Context): Definition = { |
| 28 | + def localCtx(definition: Definition): Context = definition.symbol.localContext |
| 29 | + tree match { |
| 30 | + case tree: ValDef => |
| 31 | + implicit val ctx = localCtx(tree) |
| 32 | + val rhs1 = tree.rhs.map(x => transformTerm(x, tree.tpt.tpe)) |
| 33 | + ValDef.copy(tree)(tree.name, tree.tpt, rhs1) |
| 34 | + case tree: DefDef => |
| 35 | + implicit val ctx = localCtx(tree) |
| 36 | + DefDef.copy(tree)(tree.name, tree.typeParams, tree.paramss, tree.returnTpt, tree.rhs.map(x => transformTerm(x, tree.returnTpt.tpe))) |
| 37 | + case tree: TypeDef => |
| 38 | + tree |
| 39 | + case tree: ClassDef => |
| 40 | + val newBody = transformStats(tree.body) |
| 41 | + ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, newBody) |
| 42 | + } |
| 43 | + } |
| 44 | + |
| 45 | + def transformTermChildren(tree: Term, tpe: Type)(given ctx: Context): Term = tree match { |
| 46 | + case Ident(name) => |
| 47 | + tree |
| 48 | + case Select(qualifier, name) => |
| 49 | + Select.copy(tree)(transformTerm(qualifier, qualifier.tpe), name) |
| 50 | + case This(qual) => |
| 51 | + tree |
| 52 | + case Super(qual, mix) => |
| 53 | + tree |
| 54 | + case tree @ Apply(fun, args) => |
| 55 | + val MethodType(_, tpes, _) = fun.tpe.widen |
| 56 | + Apply.copy(tree)(transformTerm(fun, defn.AnyType), transformTerms(args, tpes)) |
| 57 | + case TypeApply(fun, args) => |
| 58 | + TypeApply.copy(tree)(transformTerm(fun, defn.AnyType), args) |
| 59 | + case _: Literal => |
| 60 | + tree |
| 61 | + case New(tpt) => |
| 62 | + New.copy(tree)(transformTypeTree(tpt)) |
| 63 | + case Typed(expr, tpt) => |
| 64 | + val tp = tpt.tpe match |
| 65 | + // TODO improve code |
| 66 | + case AppliedType(TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "<repeated>"), List(tp0: Type)) => |
| 67 | + type T |
| 68 | + val a = tp0.seal.asInstanceOf[quoted.Type[T]] |
| 69 | + '[Seq[$a]].unseal.tpe |
| 70 | + case tp => tp |
| 71 | + Typed.copy(tree)(transformTerm(expr, tp), transformTypeTree(tpt)) |
| 72 | + case tree: NamedArg => |
| 73 | + NamedArg.copy(tree)(tree.name, transformTerm(tree.value, tpe)) |
| 74 | + case Assign(lhs, rhs) => |
| 75 | + Assign.copy(tree)(lhs, transformTerm(rhs, lhs.tpe.widen)) |
| 76 | + case Block(stats, expr) => |
| 77 | + Block.copy(tree)(transformStats(stats), transformTerm(expr, tpe)) |
| 78 | + case If(cond, thenp, elsep) => |
| 79 | + If.copy(tree)( |
| 80 | + transformTerm(cond, defn.BooleanType), |
| 81 | + transformTerm(thenp, tpe), |
| 82 | + transformTerm(elsep, tpe)) |
| 83 | + case _: Closure => |
| 84 | + tree |
| 85 | + case Match(selector, cases) => |
| 86 | + Match.copy(tree)(transformTerm(selector, selector.tpe), transformCaseDefs(cases, tpe)) |
| 87 | + case Return(expr) => |
| 88 | + // FIXME |
| 89 | + // ctx.owner seems to be set to the wrong symbol |
| 90 | + // Return.copy(tree)(transformTerm(expr, expr.tpe)) |
| 91 | + tree |
| 92 | + case While(cond, body) => |
| 93 | + While.copy(tree)(transformTerm(cond, defn.BooleanType), transformTerm(body, defn.AnyType)) |
| 94 | + case Try(block, cases, finalizer) => |
| 95 | + Try.copy(tree)(transformTerm(block, tpe), transformCaseDefs(cases, defn.AnyType), finalizer.map(x => transformTerm(x, defn.AnyType))) |
| 96 | + case Repeated(elems, elemtpt) => |
| 97 | + Repeated.copy(tree)(transformTerms(elems, elemtpt.tpe), elemtpt) |
| 98 | + case Inlined(call, bindings, expansion) => |
| 99 | + Inlined.copy(tree)(call, transformDefinitions(bindings), transformTerm(expansion, tpe)/*()call.symbol.localContext)*/) |
| 100 | + } |
| 101 | + |
| 102 | + def transformTerm(tree: Term, tpe: Type)(given ctx: Context): Term = |
| 103 | + tree match { |
| 104 | + case _: Closure => |
| 105 | + tree |
| 106 | + case _: Inlined => |
| 107 | + transformTermChildren(tree, tpe) |
| 108 | + case _ => |
| 109 | + tree.tpe.widen match { |
| 110 | + case _: MethodType | _: PolyType => |
| 111 | + transformTermChildren(tree, tpe) |
| 112 | + case _ => |
| 113 | + type X |
| 114 | + val expr = tree.seal.asInstanceOf[Expr[X]] |
| 115 | + val t = tpe.seal.asInstanceOf[quoted.Type[X]] |
| 116 | + transform(expr)(given qctx, t).unseal |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = tree |
| 121 | + |
| 122 | + def transformCaseDef(tree: CaseDef, tpe: Type)(given ctx: Context): CaseDef = |
| 123 | + CaseDef.copy(tree)(tree.pattern, tree.guard.map(x => transformTerm(x, defn.BooleanType)), transformTerm(tree.rhs, tpe)) |
| 124 | + |
| 125 | + def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = { |
| 126 | + TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs)) |
| 127 | + } |
| 128 | + |
| 129 | + def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] = |
| 130 | + trees mapConserve (transformStatement(_)) |
| 131 | + |
| 132 | + def transformDefinitions(trees: List[Definition])(given ctx: Context): List[Definition] = |
| 133 | + trees mapConserve (transformDefinition(_)) |
| 134 | + |
| 135 | + def transformTerms(trees: List[Term], tpes: List[Type])(given ctx: Context): List[Term] = |
| 136 | + var tpes2 = tpes // TODO use proper zipConserve |
| 137 | + trees mapConserve { x => |
| 138 | + val tpe :: tail = tpes2 |
| 139 | + tpes2 = tail |
| 140 | + transformTerm(x, tpe) |
| 141 | + } |
| 142 | + |
| 143 | + def transformTerms(trees: List[Term], tpe: Type)(given ctx: Context): List[Term] = |
| 144 | + trees.mapConserve(x => transformTerm(x, tpe)) |
| 145 | + |
| 146 | + def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] = |
| 147 | + trees mapConserve (transformTypeTree(_)) |
| 148 | + |
| 149 | + def transformCaseDefs(trees: List[CaseDef], tpe: Type)(given ctx: Context): List[CaseDef] = |
| 150 | + trees mapConserve (x => transformCaseDef(x, tpe)) |
| 151 | + |
| 152 | + def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] = |
| 153 | + trees mapConserve (transformTypeCaseDef(_)) |
| 154 | + |
| 155 | + } |
| 156 | + new MapChildren().transformTermChildren(e.unseal, tpe.unseal.tpe).seal.cast[T] // Cast will only fail if this implementation has a bug |
| 157 | + } |
| 158 | + |
| 159 | +} |
0 commit comments