diff --git a/compiler/test/dotty/tools/CheckTypesTests.scala b/compiler/test/dotty/tools/CheckTypesTests.scala new file mode 100644 index 000000000000..0b68fd5e5072 --- /dev/null +++ b/compiler/test/dotty/tools/CheckTypesTests.scala @@ -0,0 +1,69 @@ +package dotty.tools + +import org.junit.Test +import org.junit.Assert.{ assertFalse, assertTrue, fail } + +import dotc.ast.Trees._ +import dotc.core.Decorators._ + +class CheckTypeTest extends DottyTest { + @Test + def checkTypesTest: Unit = { + val source = """ + |class A + |class B extends A + """.stripMargin + + val types = List( + "A", + "B", + "List[_]", + "List[Int]", + "List[AnyRef]", + "List[String]", + "List[A]", + "List[B]" + ) + + checkTypes(source, types: _*) { + case (List(a, b, lu, li, lr, ls, la, lb), context) => + implicit val ctx = context + + assertTrue ( b <:< a) + assertTrue (li <:< lu) + assertFalse (li <:< lr) + assertTrue (ls <:< lr) + assertTrue (lb <:< la) + assertFalse (la <:< lb) + + case _ => fail + } + } + + @Test + def checkTypessTest: Unit = { + val source = """ + |class A + |class B extends A + """.stripMargin + + val typesA = List( + "A", + "List[A]" + ) + + val typesB = List( + "B", + "List[B]" + ) + + checkTypes(source, List(typesA, typesB)) { + case (List(sups, subs), context) => + implicit val ctx = context + + (sups, subs).zipped.foreach { (sup, sub) => assertTrue(sub <:< sup) } + + case _ => fail + } + } +} diff --git a/compiler/test/dotty/tools/DottyTest.scala b/compiler/test/dotty/tools/DottyTest.scala index 54c8f4e1ad60..2304187527a4 100644 --- a/compiler/test/dotty/tools/DottyTest.scala +++ b/compiler/test/dotty/tools/DottyTest.scala @@ -75,6 +75,35 @@ trait DottyTest extends ContextEscapeDetection { run.runContext } + def checkTypes(source: String, typeStrings: String*)(assertion: (List[Type], Context) => Unit): Unit = + checkTypes(source, List(typeStrings.toList)) { (tpess, ctx) => (tpess: @unchecked) match { + case List(tpes) => assertion(tpes, ctx) + }} + + def checkTypes(source: String, typeStringss: List[List[String]])(assertion: (List[List[Type]], Context) => Unit): Unit = { + val dummyName = "x_x_x" + val vals = typeStringss.flatten.zipWithIndex.map{case (s, x)=> s"val ${dummyName}$x: $s = ???"}.mkString("\n") + val gatheredSource = s" ${source}\n object A$dummyName {$vals}" + checkCompile("frontend", gatheredSource) { + (tree, context) => + implicit val ctx = context + val findValDef: (List[tpd.ValDef], tpd.Tree) => List[tpd.ValDef] = + (acc , tree) => { tree match { + case t: tpd.ValDef if t.name.startsWith(dummyName) => t :: acc + case _ => acc + } + } + val d = new tpd.DeepFolder[List[tpd.ValDef]](findValDef).foldOver(Nil, tree) + val tpes = d.map(_.tpe.widen).reverse + val tpess = typeStringss.foldLeft[(List[Type], List[List[Type]])]((tpes, Nil)) { + case ((rest, result), typeStrings) => + val (prefix, suffix) = rest.splitAt(typeStrings.length) + (suffix, prefix :: result) + }._2.reverse + assertion(tpess, context) + } + } + def methType(names: String*)(paramTypes: Type*)(resultType: Type = defn.UnitType) = MethodType(names.toList map (_.toTermName), paramTypes.toList, resultType) }