@@ -3,12 +3,15 @@ package dotc
3
3
package ast
4
4
5
5
import core ._
6
+ import dotty .tools .dotc .transform .TypeUtils
6
7
import util .Positions ._ , Types ._ , Contexts ._ , Constants ._ , Names ._ , Flags ._
7
8
import SymDenotations ._ , Symbols ._ , StdNames ._ , Annotations ._ , Trees ._ , Symbols ._
8
9
import CheckTrees ._ , Denotations ._ , Decorators ._
9
10
import config .Printers ._
10
11
import typer .ErrorReporting ._
11
12
13
+ import scala .annotation .tailrec
14
+
12
15
/** Some creators for typed trees */
13
16
object tpd extends Trees .Instance [Type ] with TypedTreeInfo {
14
17
@@ -413,6 +416,68 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
413
416
def tpes : List [Type ] = xs map (_.tpe)
414
417
}
415
418
419
+ /** RetypingTreeMap is a TreeMap that is able to propagate type changes.
420
+ *
421
+ * This is required when types can change during transformation,
422
+ * for example if `Block(stats, expr)` is being transformed
423
+ * and type of `expr` changes from `TypeRef(prefix, name)` to `TypeRef(newPrefix, name)` with different prefix, t
424
+ * type of enclosing Block should also change, otherwise the whole tree would not be type-correct anymore.
425
+ * see `propagateType` methods for propagation rulles.
426
+ *
427
+ * TreeMap does not include such logic as it assumes that types of threes do not change during transformation.
428
+ */
429
+ class RetypingTreeMap extends TreeMap {
430
+
431
+ override def transform (tree : Tree )(implicit ctx : Context ): Tree = tree match {
432
+ case tree@ Select (qualifier, name) =>
433
+ val tree1 = cpy.Select (tree, transform(qualifier), name)
434
+ propagateType(tree, tree1)
435
+ case tree@ Pair (left, right) =>
436
+ val left1 = transform(left)
437
+ val right1 = transform(right)
438
+ val tree1 = cpy.Pair (tree, left1, right1)
439
+ propagateType(tree, tree1)
440
+ case tree@ Block (stats, expr) =>
441
+ val stats1 = transform(stats)
442
+ val expr1 = transform(expr)
443
+ val tree1 = cpy.Block (tree, stats1, expr1)
444
+ propagateType(tree, tree1)
445
+ case tree@ If (cond, thenp, elsep) =>
446
+ val cond1 = transform(cond)
447
+ val thenp1 = transform(thenp)
448
+ val elsep1 = transform(elsep)
449
+ val tree1 = cpy.If (tree, cond1, thenp1, elsep1)
450
+ propagateType(tree, tree1)
451
+ case tree@ Match (selector, cases) =>
452
+ val selector1 = transform(selector)
453
+ val cases1 = transformSub(cases)
454
+ val tree1 = cpy.Match (tree, selector1, cases1)
455
+ propagateType(tree, tree1)
456
+ case tree@ CaseDef (pat, guard, body) =>
457
+ val pat1 = transform(pat)
458
+ val guard1 = transform(guard)
459
+ val body1 = transform(body)
460
+ val tree1 = cpy.CaseDef (tree, pat1, guard1, body1)
461
+ propagateType(tree, tree1)
462
+ case tree@ Try (block, handler, finalizer) =>
463
+ val expr1 = transform(block)
464
+ val handler1 = transform(handler)
465
+ val finalizer1 = transform(finalizer)
466
+ val tree1 = cpy.Try (tree, expr1, handler1, finalizer1)
467
+ propagateType(tree, tree1)
468
+ case tree@ SeqLiteral (elems) =>
469
+ val elems1 = transform(elems)
470
+ val tree1 = cpy.SeqLiteral (tree, elems1)
471
+ propagateType(tree, tree1)
472
+ case tree@ Annotated (annot, arg) =>
473
+ val annot1 = transform(annot)
474
+ val arg1 = transform(arg)
475
+ val tree1 = cpy.Annotated (tree, annot1, arg1)
476
+ propagateType(tree, tree1)
477
+ case _ => super .transform(tree)
478
+ }
479
+ }
480
+
416
481
/** A map that applies three functions together to a tree and makes sure
417
482
* they are coordinated so that the result is well-typed. The functions are
418
483
* @param typeMap A function from Type to type that gets applied to the
@@ -425,7 +490,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
425
490
final class TreeTypeMap (
426
491
val typeMap : Type => Type = IdentityTypeMap ,
427
492
val ownerMap : Symbol => Symbol = identity _,
428
- val treeMap : Tree => Tree = identity _)(implicit ctx : Context ) extends TreeMap {
493
+ val treeMap : Tree => Tree = identity _)(implicit ctx : Context ) extends RetypingTreeMap {
429
494
430
495
override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = {
431
496
val tree1 = treeMap(tree)
@@ -436,10 +501,16 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
436
501
cpy.DefDef (ddef, mods, name, tparams1, vparamss1, tmap2.transform(tpt), tmap2.transform(rhs))
437
502
case blk @ Block (stats, expr) =>
438
503
val (tmap1, stats1) = transformDefs(stats)
439
- cpy.Block (blk, stats1, tmap1.transform(expr))
504
+ val expr1 = tmap1.transform(expr)
505
+ val tree1 = cpy.Block (blk, stats1, expr1)
506
+ propagateType(blk, tree1)
440
507
case cdef @ CaseDef (pat, guard, rhs) =>
441
508
val tmap = withMappedSyms(patVars(pat))
442
- cpy.CaseDef (cdef, tmap.transform(pat), tmap.transform(guard), tmap.transform(rhs))
509
+ val pat1 = tmap.transform(pat)
510
+ val guard1 = tmap.transform(guard)
511
+ val rhs1 = tmap.transform(rhs)
512
+ val tree1 = cpy.CaseDef (tree, pat1, guard1, rhs1)
513
+ propagateType(cdef, tree1)
443
514
case tree1 =>
444
515
super .transform(tree1)
445
516
}
@@ -501,6 +572,56 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
501
572
acc(Nil , tree)
502
573
}
503
574
575
+ def propagateType (origTree : Pair , newTree : Pair )(implicit ctx : Context ) = {
576
+ if ((newTree eq origTree) ||
577
+ ((newTree.left.tpe eq origTree.left.tpe) && (newTree.right.tpe eq origTree.right.tpe))) newTree
578
+ else ta.assignType(newTree, newTree.left, newTree.right)
579
+ }
580
+
581
+ def propagateType (origTree : Block , newTree : Block )(implicit ctx : Context ) = {
582
+ if ((newTree eq origTree) || (newTree.expr.tpe eq origTree.expr.tpe)) newTree
583
+ else ta.assignType(newTree, newTree.stats, newTree.expr)
584
+ }
585
+
586
+ def propagateType (origTree : If , newTree : If )(implicit ctx : Context ) = {
587
+ if ((newTree eq origTree) ||
588
+ ((newTree.thenp.tpe eq origTree.thenp.tpe) && (newTree.elsep.tpe eq origTree.elsep.tpe))) newTree
589
+ else ta.assignType(newTree, newTree.thenp, newTree.elsep)
590
+ }
591
+
592
+ def propagateType (origTree : Match , newTree : Match )(implicit ctx : Context ) = {
593
+ if ((newTree eq origTree) || sameTypes(newTree.cases, origTree.cases)) newTree
594
+ else ta.assignType(newTree, newTree.cases)
595
+ }
596
+
597
+ def propagateType (origTree : CaseDef , newTree : CaseDef )(implicit ctx : Context ) = {
598
+ if ((newTree eq newTree) || (newTree.body.tpe eq origTree.body.tpe)) newTree
599
+ else ta.assignType(newTree, newTree.body)
600
+ }
601
+
602
+ def propagateType (origTree : Try , newTree : Try )(implicit ctx : Context ) = {
603
+ if ((newTree eq origTree) ||
604
+ ((newTree.expr.tpe eq origTree.expr.tpe) && (newTree.handler.tpe eq origTree.handler.tpe))) newTree
605
+ else ta.assignType(newTree, newTree.expr, newTree.handler)
606
+ }
607
+
608
+ def propagateType (origTree : SeqLiteral , newTree : SeqLiteral )(implicit ctx : Context ) = {
609
+ if ((newTree eq origTree) || sameTypes(newTree.elems, origTree.elems)) newTree
610
+ else ta.assignType(newTree, newTree.elems)
611
+ }
612
+
613
+ def propagateType (origTree : Annotated , newTree : Annotated )(implicit ctx : Context ) = {
614
+ if ((newTree eq origTree) || ((newTree.arg.tpe eq origTree.arg.tpe) && (newTree.annot eq origTree.annot))) newTree
615
+ else ta.assignType(newTree, newTree.annot, newTree.arg)
616
+ }
617
+
618
+ def propagateType (origTree : Select , newTree : Select )(implicit ctx : Context ) = {
619
+ if ((origTree eq newTree) || (origTree.qualifier.tpe eq newTree.qualifier.tpe)) newTree
620
+ else newTree.tpe match {
621
+ case tpe : NamedType => newTree.withType(tpe.derivedSelect(newTree.qualifier.tpe))
622
+ case _ => newTree
623
+ }
624
+ }
504
625
// convert a numeric with a toXXX method
505
626
def primitiveConversion (tree : Tree , numericCls : Symbol )(implicit ctx : Context ): Tree = {
506
627
val mname = (" to" + numericCls.name).toTermName
@@ -515,6 +636,13 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
515
636
}
516
637
}
517
638
639
+ @ tailrec
640
+ def sameTypes (trees : List [tpd.Tree ], trees1 : List [tpd.Tree ]): Boolean = {
641
+ if (trees.isEmpty) trees.isEmpty
642
+ else if (trees1.isEmpty) trees.isEmpty
643
+ else (trees.head.tpe eq trees1.head.tpe) && sameTypes(trees.tail, trees1.tail)
644
+ }
645
+
518
646
def evalOnce (tree : Tree )(within : Tree => Tree )(implicit ctx : Context ) = {
519
647
if (isIdempotentExpr(tree)) within(tree)
520
648
else {
0 commit comments