diff --git a/src/dotty/tools/dotc/ast/Trees.scala b/src/dotty/tools/dotc/ast/Trees.scala index 8a36fee3a5ae..e19221841dd0 100644 --- a/src/dotty/tools/dotc/ast/Trees.scala +++ b/src/dotty/tools/dotc/ast/Trees.scala @@ -1109,36 +1109,20 @@ object Trees { cpy.Apply(tree, transform(fun), transform(args)) case TypeApply(fun, args) => cpy.TypeApply(tree, transform(fun), transform(args)) - case Literal(const) => - tree case New(tpt) => cpy.New(tree, transform(tpt)) - case Pair(left, right) => - cpy.Pair(tree, transform(left), transform(right)) case Typed(expr, tpt) => cpy.Typed(tree, transform(expr), transform(tpt)) case NamedArg(name, arg) => cpy.NamedArg(tree, name, transform(arg)) case Assign(lhs, rhs) => cpy.Assign(tree, transform(lhs), transform(rhs)) - case Block(stats, expr) => - cpy.Block(tree, transformStats(stats), transform(expr)) - case If(cond, thenp, elsep) => - cpy.If(tree, transform(cond), transform(thenp), transform(elsep)) case Closure(env, meth, tpt) => cpy.Closure(tree, transform(env), transform(meth), transform(tpt)) - case Match(selector, cases) => - cpy.Match(tree, transform(selector), transformSub(cases)) - case CaseDef(pat, guard, body) => - cpy.CaseDef(tree, transform(pat), transform(guard), transform(body)) case Return(expr, from) => cpy.Return(tree, transform(expr), transformSub(from)) - case Try(block, handler, finalizer) => - cpy.Try(tree, transform(block), transform(handler), transform(finalizer)) case Throw(expr) => cpy.Throw(tree, transform(expr)) - case SeqLiteral(elems) => - cpy.SeqLiteral(tree, transform(elems)) case TypeTree(original) => tree case SingletonTypeTree(ref) => @@ -1177,12 +1161,29 @@ object Trees { cpy.Import(tree, transform(expr), selectors) case PackageDef(pid, stats) => cpy.PackageDef(tree, transformSub(pid), transformStats(stats)) - case Annotated(annot, arg) => - cpy.Annotated(tree, transform(annot), transform(arg)) case Thicket(trees) => val trees1 = transform(trees) if (trees1 eq trees) tree else Thicket(trees1) + case Literal(const) => + tree + case Pair(left, right) => + cpy.Pair(tree, transform(left), transform(right)) + case Block(stats, expr) => + cpy.Block(tree, transformStats(stats), transform(expr)) + case If(cond, thenp, elsep) => + cpy.If(tree, transform(cond), transform(thenp), transform(elsep)) + case Match(selector, cases) => + cpy.Match(tree, transform(selector), transformSub(cases)) + case CaseDef(pat, guard, body) => + cpy.CaseDef(tree, transform(pat), transform(guard), transform(body)) + case Try(block, handler, finalizer) => + cpy.Try(tree, transform(block), transform(handler), transform(finalizer)) + case SeqLiteral(elems) => + cpy.SeqLiteral(tree, transform(elems)) + case Annotated(annot, arg) => + cpy.Annotated(tree, transform(annot), transform(arg)) } + def transformStats(trees: List[Tree])(implicit ctx: Context): List[Tree] = transform(trees) def transform(trees: List[Tree])(implicit ctx: Context): List[Tree] = diff --git a/src/dotty/tools/dotc/ast/tpd.scala b/src/dotty/tools/dotc/ast/tpd.scala index fecfefd37ac4..0ef855de261e 100644 --- a/src/dotty/tools/dotc/ast/tpd.scala +++ b/src/dotty/tools/dotc/ast/tpd.scala @@ -3,12 +3,15 @@ package dotc package ast import core._ +import dotty.tools.dotc.transform.TypeUtils import util.Positions._, Types._, Contexts._, Constants._, Names._, Flags._ import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._, Symbols._ import CheckTrees._, Denotations._, Decorators._ import config.Printers._ import typer.ErrorReporting._ +import scala.annotation.tailrec + /** Some creators for typed trees */ object tpd extends Trees.Instance[Type] with TypedTreeInfo { @@ -413,6 +416,68 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { def tpes: List[Type] = xs map (_.tpe) } + /** RetypingTreeMap is a TreeMap that is able to propagate type changes. + * + * This is required when types can change during transformation, + * for example if `Block(stats, expr)` is being transformed + * and type of `expr` changes from `TypeRef(prefix, name)` to `TypeRef(newPrefix, name)` with different prefix, t + * type of enclosing Block should also change, otherwise the whole tree would not be type-correct anymore. + * see `propagateType` methods for propagation rulles. + * + * TreeMap does not include such logic as it assumes that types of threes do not change during transformation. + */ + class RetypingTreeMap extends TreeMap { + + override def transform(tree: Tree)(implicit ctx: Context): Tree = tree match { + case tree@Select(qualifier, name) => + val tree1 = cpy.Select(tree, transform(qualifier), name) + propagateType(tree, tree1) + case tree@Pair(left, right) => + val left1 = transform(left) + val right1 = transform(right) + val tree1 = cpy.Pair(tree, left1, right1) + propagateType(tree, tree1) + case tree@Block(stats, expr) => + val stats1 = transform(stats) + val expr1 = transform(expr) + val tree1 = cpy.Block(tree, stats1, expr1) + propagateType(tree, tree1) + case tree@If(cond, thenp, elsep) => + val cond1 = transform(cond) + val thenp1 = transform(thenp) + val elsep1 = transform(elsep) + val tree1 = cpy.If(tree, cond1, thenp1, elsep1) + propagateType(tree, tree1) + case tree@Match(selector, cases) => + val selector1 = transform(selector) + val cases1 = transformSub(cases) + val tree1 = cpy.Match(tree, selector1, cases1) + propagateType(tree, tree1) + case tree@CaseDef(pat, guard, body) => + val pat1 = transform(pat) + val guard1 = transform(guard) + val body1 = transform(body) + val tree1 = cpy.CaseDef(tree, pat1, guard1, body1) + propagateType(tree, tree1) + case tree@Try(block, handler, finalizer) => + val expr1 = transform(block) + val handler1 = transform(handler) + val finalizer1 = transform(finalizer) + val tree1 = cpy.Try(tree, expr1, handler1, finalizer1) + propagateType(tree, tree1) + case tree@SeqLiteral(elems) => + val elems1 = transform(elems) + val tree1 = cpy.SeqLiteral(tree, elems1) + propagateType(tree, tree1) + case tree@Annotated(annot, arg) => + val annot1 = transform(annot) + val arg1 = transform(arg) + val tree1 = cpy.Annotated(tree, annot1, arg1) + propagateType(tree, tree1) + case _ => super.transform(tree) + } + } + /** A map that applies three functions together to a tree and makes sure * they are coordinated so that the result is well-typed. The functions are * @param typeMap A function from Type to type that gets applied to the @@ -425,7 +490,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { final class TreeTypeMap( val typeMap: Type => Type = IdentityTypeMap, val ownerMap: Symbol => Symbol = identity _, - val treeMap: Tree => Tree = identity _)(implicit ctx: Context) extends TreeMap { + val treeMap: Tree => Tree = identity _)(implicit ctx: Context) extends RetypingTreeMap { override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = { val tree1 = treeMap(tree) @@ -436,10 +501,16 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { cpy.DefDef(ddef, mods, name, tparams1, vparamss1, tmap2.transform(tpt), tmap2.transform(rhs)) case blk @ Block(stats, expr) => val (tmap1, stats1) = transformDefs(stats) - cpy.Block(blk, stats1, tmap1.transform(expr)) + val expr1 = tmap1.transform(expr) + val tree1 = cpy.Block(blk, stats1, expr1) + propagateType(blk, tree1) case cdef @ CaseDef(pat, guard, rhs) => val tmap = withMappedSyms(patVars(pat)) - cpy.CaseDef(cdef, tmap.transform(pat), tmap.transform(guard), tmap.transform(rhs)) + val pat1 = tmap.transform(pat) + val guard1 = tmap.transform(guard) + val rhs1 = tmap.transform(rhs) + val tree1 = cpy.CaseDef(tree, pat1, guard1, rhs1) + propagateType(cdef, tree1) case tree1 => super.transform(tree1) } @@ -501,6 +572,56 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { acc(Nil, tree) } + def propagateType(origTree: Pair, newTree: Pair)(implicit ctx: Context) = { + if ((newTree eq origTree) || + ((newTree.left.tpe eq origTree.left.tpe) && (newTree.right.tpe eq origTree.right.tpe))) newTree + else ta.assignType(newTree, newTree.left, newTree.right) + } + + def propagateType(origTree: Block, newTree: Block)(implicit ctx: Context) = { + if ((newTree eq origTree) || (newTree.expr.tpe eq origTree.expr.tpe)) newTree + else ta.assignType(newTree, newTree.stats, newTree.expr) + } + + def propagateType(origTree: If, newTree: If)(implicit ctx: Context) = { + if ((newTree eq origTree) || + ((newTree.thenp.tpe eq origTree.thenp.tpe) && (newTree.elsep.tpe eq origTree.elsep.tpe))) newTree + else ta.assignType(newTree, newTree.thenp, newTree.elsep) + } + + def propagateType(origTree: Match, newTree: Match)(implicit ctx: Context) = { + if ((newTree eq origTree) || sameTypes(newTree.cases, origTree.cases)) newTree + else ta.assignType(newTree, newTree.cases) + } + + def propagateType(origTree: CaseDef, newTree: CaseDef)(implicit ctx: Context) = { + if ((newTree eq newTree) || (newTree.body.tpe eq origTree.body.tpe)) newTree + else ta.assignType(newTree, newTree.body) + } + + def propagateType(origTree: Try, newTree: Try)(implicit ctx: Context) = { + if ((newTree eq origTree) || + ((newTree.expr.tpe eq origTree.expr.tpe) && (newTree.handler.tpe eq origTree.handler.tpe))) newTree + else ta.assignType(newTree, newTree.expr, newTree.handler) + } + + def propagateType(origTree: SeqLiteral, newTree: SeqLiteral)(implicit ctx: Context) = { + if ((newTree eq origTree) || sameTypes(newTree.elems, origTree.elems)) newTree + else ta.assignType(newTree, newTree.elems) + } + + def propagateType(origTree: Annotated, newTree: Annotated)(implicit ctx: Context) = { + if ((newTree eq origTree) || ((newTree.arg.tpe eq origTree.arg.tpe) && (newTree.annot eq origTree.annot))) newTree + else ta.assignType(newTree, newTree.annot, newTree.arg) + } + + def propagateType(origTree: Select, newTree: Select)(implicit ctx: Context) = { + if ((origTree eq newTree) || (origTree.qualifier.tpe eq newTree.qualifier.tpe)) newTree + else newTree.tpe match { + case tpe: NamedType => newTree.withType(tpe.derivedSelect(newTree.qualifier.tpe)) + case _ => newTree + } + } // convert a numeric with a toXXX method def primitiveConversion(tree: Tree, numericCls: Symbol)(implicit ctx: Context): Tree = { val mname = ("to" + numericCls.name).toTermName @@ -515,6 +636,13 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { } } + @tailrec + def sameTypes(trees: List[tpd.Tree], trees1: List[tpd.Tree]): Boolean = { + if (trees.isEmpty) trees.isEmpty + else if (trees1.isEmpty) trees.isEmpty + else (trees.head.tpe eq trees1.head.tpe) && sameTypes(trees.tail, trees1.tail) + } + def evalOnce(tree: Tree)(within: Tree => Tree)(implicit ctx: Context) = { if (isIdempotentExpr(tree)) within(tree) else { diff --git a/src/dotty/tools/dotc/transform/FullParameterization.scala b/src/dotty/tools/dotc/transform/FullParameterization.scala index fea0482a0d01..ac8773f10f02 100644 --- a/src/dotty/tools/dotc/transform/FullParameterization.scala +++ b/src/dotty/tools/dotc/transform/FullParameterization.scala @@ -142,7 +142,7 @@ trait FullParameterization { * followed by the class parameters of its enclosing class. */ private def allInstanceTypeParams(originalDef: DefDef)(implicit ctx: Context): List[Symbol] = - originalDef.tparams.map(_.symbol) ::: originalDef.symbol.owner.typeParams + originalDef.tparams.map(_.symbol) ::: originalDef.symbol.enclosingClass.typeParams /** Given an instance method definition `originalDef`, return a * fully parameterized method definition derived from `originalDef`, which @@ -152,7 +152,7 @@ trait FullParameterization { def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef)(implicit ctx: Context): Tree = polyDefDef(derived, trefs => vrefss => { val origMeth = originalDef.symbol - val origClass = origMeth.owner.asClass + val origClass = origMeth.enclosingClass.asClass val origTParams = allInstanceTypeParams(originalDef) val origVParams = originalDef.vparamss.flatten map (_.symbol) val thisRef :: argRefs = vrefss.flatten @@ -219,7 +219,7 @@ trait FullParameterization { def forwarder(derived: TermSymbol, originalDef: DefDef)(implicit ctx: Context): Tree = ref(derived.termRef) .appliedToTypes(allInstanceTypeParams(originalDef).map(_.typeRef)) - .appliedTo(This(originalDef.symbol.owner.asClass)) + .appliedTo(This(originalDef.symbol.enclosingClass.asClass)) .appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol))) .withPos(originalDef.rhs.pos) } \ No newline at end of file diff --git a/src/dotty/tools/dotc/transform/TypeUtils.scala b/src/dotty/tools/dotc/transform/TypeUtils.scala index f11bb980acfc..a266600929e8 100644 --- a/src/dotty/tools/dotc/transform/TypeUtils.scala +++ b/src/dotty/tools/dotc/transform/TypeUtils.scala @@ -1,23 +1,18 @@ package dotty.tools.dotc package transform -import core._ -import Types._ -import Contexts._ -import Symbols._ -import Decorators._ -import StdNames.nme -import NameOps._ -import language.implicitConversions +import dotty.tools.dotc.core.Types._ + +import scala.language.implicitConversions object TypeUtils { implicit def decorateTypeUtils(tpe: Type): TypeUtils = new TypeUtils(tpe) + } /** A decorator that provides methods for type transformations * that are needed in the transofmer pipeline (not needed right now) */ class TypeUtils(val self: Type) extends AnyVal { - import TypeUtils._ } \ No newline at end of file