@@ -8,6 +8,7 @@ import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._, Symbols.
8
8
import CheckTrees ._ , Denotations ._ , Decorators ._
9
9
import config .Printers ._
10
10
import typer .ErrorReporting ._
11
+ import scala .annotation .tailrec
11
12
12
13
/** Some creators for typed trees */
13
14
object tpd extends Trees .Instance [Type ] with TypedTreeInfo {
@@ -413,6 +414,92 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
413
414
def tpes : List [Type ] = xs map (_.tpe)
414
415
}
415
416
417
+ /** A tree map that retypes some nodes if their element types have changed,
418
+ * instead of simply copying the original type. The potential retyped nodes
419
+ * are those nodes where the element type may be part of the parent type.
420
+ */
421
+ class RetypingTreeMap extends TreeMap {
422
+ def retypeSelect (tree : Select , qualifier : Tree , name : Name )(implicit ctx : Context ) = {
423
+ val tree1 = cpy.Select (tree, qualifier, name)
424
+ if ((tree1 eq tree) || (qualifier.tpe eq tree.qualifier.tpe)) tree1
425
+ else (tree1.tpe match {
426
+ case tpe : NamedType => tree1.withType(tpe.derivedSelect(qualifier.tpe))
427
+ case _ => tree1
428
+ })
429
+ }
430
+ def retypePair (tree : Pair , left : Tree , right : Tree )(implicit ctx : Context ) = {
431
+ val tree1 = cpy.Pair (tree, left, right)
432
+ if ((tree1 eq tree) || ((left.tpe eq tree.left.tpe) && (right.tpe eq tree.right.tpe))) tree1
433
+ else ta.assignType(tree1, left, right)
434
+ }
435
+ def retypeBlock (tree : Block , stats : List [Tree ], expr : Tree )(implicit ctx : Context ) = {
436
+ val tree1 = cpy.Block (tree, stats, expr)
437
+ if ((tree1 eq tree) || (expr.tpe eq tree.expr.tpe)) tree1
438
+ else ta.assignType(tree1, stats, expr)
439
+ }
440
+ def retypeIf (tree : If , cond : Tree , thenp : Tree , elsep : Tree )(implicit ctx : Context ) = {
441
+ val tree1 = cpy.If (tree, cond, thenp, elsep)
442
+ if ((tree1 eq tree) || (thenp.tpe eq tree.thenp.tpe) && (elsep.tpe eq tree.elsep.tpe)) tree1
443
+ else ta.assignType(tree1, thenp, elsep)
444
+ }
445
+ def retypeMatch (tree : Match , selector : Tree , cases : List [CaseDef ])(implicit ctx : Context ) = {
446
+ val tree1 = cpy.Match (tree, selector, cases)
447
+ if ((tree1 eq tree) || sameTypes(cases, tree.cases)) tree1
448
+ else ta.assignType(tree1, cases)
449
+ }
450
+ def retypeCaseDef (tree : CaseDef , pat : Tree , guard : Tree , body : Tree )(implicit ctx : Context ) = {
451
+ val tree1 = cpy.CaseDef (tree, pat, guard, body)
452
+ if ((tree eq tree1) || (body.tpe eq tree.body.tpe)) tree1
453
+ else ta.assignType(tree1, body)
454
+ }
455
+ def retypeTry (tree : Try , expr : Tree , handler : Tree , finalizer : Tree )(implicit ctx : Context ) = {
456
+ val tree1 = cpy.Try (tree, expr, handler, finalizer)
457
+ if ((tree1 eq tree) || ((expr.tpe eq tree.expr.tpe) && (handler.tpe eq tree.handler.tpe))) tree
458
+ else ta.assignType(tree1, expr, handler)
459
+ }
460
+ def retypeSeqLiteral (tree : SeqLiteral , elems : List [Tree ])(implicit ctx : Context ) = {
461
+ val tree1 = cpy.SeqLiteral (tree, elems)
462
+ if ((tree1 eq tree) || sameTypes(elems, tree.elems)) tree1
463
+ else ta.assignType(tree1, elems)
464
+ }
465
+ def retypeAnnotated (tree : Annotated , annot : Tree , arg : Tree )(implicit ctx : Context ) = {
466
+ val tree1 = cpy.Annotated (tree, annot, arg)
467
+ if ((tree1 eq tree) || (arg.tpe eq tree.arg.tpe) && (annot eq tree.annot)) tree1
468
+ else ta.assignType(tree1, annot, arg)
469
+ }
470
+ override def transform (tree : Tree )(implicit ctx : Context ): Tree = tree match {
471
+ case tree : Ident => // left here for performance
472
+ super .transform(tree)
473
+ case tree @ Select (qualifier, name) =>
474
+ retypeSelect(tree, transform(qualifier), name)
475
+ case tree @ Pair (left, right) =>
476
+ retypePair(tree, transform(left), transform(right))
477
+ case tree @ Block (stats, expr) =>
478
+ retypeBlock(tree, transformStats(stats), transform(expr))
479
+ case tree @ If (cond, thenp, elsep) =>
480
+ retypeIf(tree, transform(cond), transform(thenp), transform(elsep))
481
+ case tree @ Match (selector, cases) =>
482
+ retypeMatch(tree, transform(selector), transformSub(cases))
483
+ case tree @ CaseDef (pat, guard, body) =>
484
+ retypeCaseDef(tree, transform(pat), transform(guard), transform(body))
485
+ case tree @ Try (block, handler, finalizer) =>
486
+ retypeTry(tree, transform(block), transform(handler), transform(finalizer))
487
+ case tree @ SeqLiteral (elems) =>
488
+ retypeSeqLiteral(tree, transform(elems))
489
+ case tree @ Annotated (annot, arg) =>
490
+ retypeAnnotated(tree, transform(annot), transform(arg))
491
+ case _ =>
492
+ super .transform(tree)
493
+ }
494
+
495
+ @ tailrec
496
+ final def sameTypes (trees : List [tpd.Tree ], trees1 : List [tpd.Tree ]): Boolean = {
497
+ if (trees.isEmpty) trees.isEmpty
498
+ else if (trees1.isEmpty) trees.isEmpty
499
+ else (trees.head.tpe eq trees1.head.tpe) && sameTypes(trees.tail, trees1.tail)
500
+ }
501
+ }
502
+
416
503
/** A map that applies three functions together to a tree and makes sure
417
504
* they are coordinated so that the result is well-typed. The functions are
418
505
* @param typeMap A function from Type to type that gets applied to the
@@ -425,7 +512,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
425
512
final class TreeTypeMap (
426
513
val typeMap : Type => Type = IdentityTypeMap ,
427
514
val ownerMap : Symbol => Symbol = identity _,
428
- val treeMap : Tree => Tree = identity _)(implicit ctx : Context ) extends TreeMap {
515
+ val treeMap : Tree => Tree = identity _)(implicit ctx : Context ) extends RetypingTreeMap {
429
516
430
517
override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = {
431
518
val tree1 = treeMap(tree)
@@ -436,10 +523,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
436
523
cpy.DefDef (ddef, mods, name, tparams1, vparamss1, tmap2.transform(tpt), tmap2.transform(rhs))
437
524
case blk @ Block (stats, expr) =>
438
525
val (tmap1, stats1) = transformDefs(stats)
439
- cpy. Block (blk, stats1, tmap1.transform(expr))
526
+ retypeBlock (blk, stats1, tmap1.transform(expr))
440
527
case cdef @ CaseDef (pat, guard, rhs) =>
441
528
val tmap = withMappedSyms(patVars(pat))
442
- cpy. CaseDef (cdef, tmap.transform(pat), tmap.transform(guard), tmap.transform(rhs))
529
+ retypeCaseDef (cdef, tmap.transform(pat), tmap.transform(guard), tmap.transform(rhs))
443
530
case tree1 =>
444
531
super .transform(tree1)
445
532
}
0 commit comments