Skip to content

Commit 6e59705

Browse files
authored
Merge pull request #5539 from dotty-staging/fix-#5533
Fix #5533: Check conformance of subtype directly
2 parents 75cbbd6 + c308a56 commit 6e59705

File tree

10 files changed

+135
-24
lines changed

10 files changed

+135
-24
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/QuotedOpsImpl.scala

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package dotty.tools.dotc.tastyreflect
22

3-
import dotty.tools.dotc.core.Contexts.FreshContext
43
import dotty.tools.dotc.core.quoted.PickledQuotes
5-
import dotty.tools.dotc.reporting.Reporter
6-
import dotty.tools.dotc.reporting.diagnostic.MessageContainer
74

85
trait QuotedOpsImpl extends scala.tasty.reflect.QuotedOps with CoreImpl {
96

@@ -18,30 +15,17 @@ trait QuotedOpsImpl extends scala.tasty.reflect.QuotedOps with CoreImpl {
1815
def TermToQuoteDeco(term: Term): TermToQuotedAPI = new TermToQuotedAPI {
1916

2017
def seal[T: scala.quoted.Type](implicit ctx: Context): scala.quoted.Expr[T] = {
21-
typecheck(ctx)
18+
typecheck()
2219
new scala.quoted.Exprs.TastyTreeExpr(term).asInstanceOf[scala.quoted.Expr[T]]
2320
}
2421

25-
private def typecheck[T: scala.quoted.Type](ctx: Context): Unit = {
26-
implicit val ctx0: FreshContext = ctx.fresh
27-
ctx0.setTyperState(ctx0.typerState.fresh())
28-
ctx0.typerState.setReporter(new Reporter {
29-
def doReport(m: MessageContainer)(implicit ctx: Context): Unit = ()
30-
})
31-
val tp = QuotedTypeDeco(implicitly[scala.quoted.Type[T]]).unseal
32-
ctx0.typer.typed(term, tp.tpe)
33-
if (ctx0.reporter.hasErrors) {
34-
val stack = new Exception().getStackTrace
35-
def filter(elem: StackTraceElement) =
36-
elem.getClassName.startsWith("dotty.tools.dotc.tasty.ReflectionImpl") ||
37-
!elem.getClassName.startsWith("dotty.tools.dotc")
22+
private def typecheck[T: scala.quoted.Type]()(implicit ctx: Context): Unit = {
23+
val tpt = QuotedTypeDeco(implicitly[scala.quoted.Type[T]]).unseal
24+
if (!(term.tpe <:< tpt.tpe)) {
3825
throw new scala.tasty.TastyTypecheckError(
39-
s"""Error during tasty reflection while typing term
40-
|term: ${term.show}
41-
|with expected type: ${tp.tpe.show}
42-
|
43-
| ${stack.takeWhile(filter).mkString("\n ")}
44-
""".stripMargin
26+
s"""Term: ${term.show}
27+
|did not conform to type: ${tpt.tpe.show}
28+
|""".stripMargin
4529
)
4630
}
4731
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
package scala.tasty
22

3-
class TastyTypecheckError(msg: String) extends Throwable
3+
class TastyTypecheckError(msg: String) extends Throwable(msg)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
true
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import scala.quoted._
2+
import scala.tasty._
3+
4+
object scalatest {
5+
6+
def f(x: Int): Boolean = false
7+
def f(x: String): Boolean = true
8+
9+
inline def assert(condition: => Boolean): Unit = ~assertImpl('(condition))
10+
11+
def assertImpl(condition: Expr[Boolean])(implicit refl: Reflection): Expr[Unit] = {
12+
import refl._
13+
14+
val tree = condition.unseal
15+
16+
val expr = tree.seal[Boolean]
17+
18+
'(println(~expr))
19+
}
20+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
object Test {
2+
3+
def main(args: Array[String]): Unit = {
4+
import scalatest._
5+
val x = "String"
6+
assert(f("abc"))
7+
}
8+
9+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
String
2+
String
3+
()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import scala.quoted._
2+
import scala.tasty._
3+
4+
object scalatest {
5+
def f(x: Int): Int = x
6+
def f(x: String): String = x
7+
8+
inline def assert(condition: => Boolean): Unit = ~assertImpl('(condition))
9+
10+
def assertImpl(condition: Expr[Boolean])(implicit refl: Reflection): Expr[Unit] = {
11+
import refl._
12+
import quoted.Toolbox.Default._
13+
14+
val tree = condition.unseal
15+
def exprStr: String = condition.show
16+
17+
tree.underlyingArgument match {
18+
case Term.Apply(Term.Select(lhs, op, _), rhs :: Nil) =>
19+
val left = lhs.seal[Any]
20+
val right = rhs.seal[Any]
21+
op match {
22+
case "==" =>
23+
'{
24+
val _left = ~left
25+
val _right = ~right
26+
val _result = _left == _right
27+
println(_left)
28+
println(_right)
29+
scala.Predef.assert(_result)
30+
}
31+
}
32+
}
33+
}
34+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
object Test {
2+
3+
def main(args: Array[String]): Unit = {
4+
import scalatest._
5+
val x = "String"
6+
println(assert(f(x) == "String"))
7+
}
8+
9+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import scala.quoted._
2+
import scala.tasty._
3+
4+
object scalatest {
5+
inline def assert(condition: => Boolean): Unit = ~assertImpl('(condition))
6+
7+
def assertImpl(condition: Expr[Boolean])(implicit refl: Reflection): Expr[Unit] = {
8+
import refl._
9+
import quoted.Toolbox.Default._
10+
11+
val tree = condition.unseal
12+
def exprStr: String = condition.show
13+
14+
tree.underlyingArgument match {
15+
case Term.Apply(Term.Select(lhs, op, _), rhs :: Nil) =>
16+
val left = lhs.seal[Any]
17+
val right = rhs.seal[Any]
18+
op match {
19+
case "===" =>
20+
'{
21+
val _left = ~left
22+
val _right = ~right
23+
val _result = _left == _right
24+
scala.Predef.assert(_result)
25+
}
26+
}
27+
}
28+
}
29+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
class Equalizer[L](val leftSide: L) {
3+
def ===(literalNull: Null): Boolean = leftSide == null
4+
}
5+
6+
object Equality {
7+
implicit def toEqualizer[T](x: T): Equalizer[T] = new Equalizer(x)
8+
}
9+
10+
11+
object Test {
12+
import scalatest._
13+
import Equality._
14+
15+
def main(args: Array[String]): Unit = {
16+
val x = "String"
17+
try assert(x === null)
18+
catch {
19+
case _: AssertionError => // OK
20+
}
21+
}
22+
}

0 commit comments

Comments
 (0)