Skip to content

Commit ef2c479

Browse files
committed
Merge pull request #1231 from dotty-staging/fix-equality
Fixes related to equality strawman
2 parents 02f1ec9 + 6b0ae0b commit ef2c479

File tree

9 files changed

+130
-10
lines changed

9 files changed

+130
-10
lines changed

src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,8 @@ class Definitions {
449449
def ContravariantBetweenAnnot(implicit ctx: Context) = ContravariantBetweenAnnotType.symbol.asClass
450450
lazy val DeprecatedAnnotType = ctx.requiredClassRef("scala.deprecated")
451451
def DeprecatedAnnot(implicit ctx: Context) = DeprecatedAnnotType.symbol.asClass
452+
lazy val ImplicitNotFoundAnnotType = ctx.requiredClassRef("scala.annotation.implicitNotFound")
453+
def ImplicitNotFoundAnnot(implicit ctx: Context) = ImplicitNotFoundAnnotType.symbol.asClass
452454
lazy val InvariantBetweenAnnotType = ctx.requiredClassRef("dotty.annotation.internal.InvariantBetween")
453455
def InvariantBetweenAnnot(implicit ctx: Context) = InvariantBetweenAnnotType.symbol.asClass
454456
lazy val MigrationAnnotType = ctx.requiredClassRef("scala.annotation.migration")

src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
678678
isSubType(tp11, tp21) && {
679679
val leftConstraint = constraint
680680
constraint = preConstraint
681-
if (isSubType(tp12, tp22) && !subsumes(leftConstraint, constraint, preConstraint))
681+
if (!(isSubType(tp12, tp22) && subsumes(leftConstraint, constraint, preConstraint)))
682682
constraint = leftConstraint
683683
true
684684
} || isSubType(tp12, tp22)

src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1611,7 +1611,7 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans
16111611
// (otherwise equality is required)
16121612
def compareOp: (Tree, Tree) => Tree =
16131613
if (aligner.isStar) _.select(defn.Int_>=).appliedTo(_)
1614-
else _.select(defn.Int_==).appliedTo(_)
1614+
else _.select(defn.Int_==).appliedTo(_)
16151615

16161616
// `if (binder != null && $checkExpectedLength [== | >=] 0) then else zero`
16171617
(seqTree(binder).select(defn.Any_!=).appliedTo(Literal(Constant(null)))).select(defn.Boolean_&&).appliedTo(compareOp(checkExpectedLength, Literal(Constant(0))))

src/dotty/tools/dotc/typer/ErrorReporting.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,18 @@ object ErrorReporting {
116116
| found : $found
117117
| required: $expected""".stripMargin + whyNoMatchStr(found, expected)
118118
}
119+
120+
/** Format `raw` implicitNotFound argument, replacing all
121+
* occurrences of `${X}` where `X` is in `paramNames` with the
122+
* corresponding shown type in `args`.
123+
*/
124+
def implicitNotFoundString(raw: String, paramNames: List[String], args: List[Type]): String = {
125+
def translate(name: String): Option[String] = {
126+
val idx = paramNames.indexOf(name)
127+
if (idx >= 0) Some(args(idx).show) else None
128+
}
129+
"""\$\{\w*\}""".r.replaceSomeIn(raw, m => translate(m.matched.drop(2).init))
130+
}
119131
}
120132

121133
def err(implicit ctx: Context): Errors = new Errors

src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ object Inferencing {
4646

4747
/** Instantiate selected type variables `tvars` in type `tp` */
4848
def instantiateSelected(tp: Type, tvars: List[Type])(implicit ctx: Context): Unit =
49-
new IsFullyDefinedAccumulator(new ForceDegree.Value(tvars.contains)).process(tp)
49+
new IsFullyDefinedAccumulator(new ForceDegree.Value(tvars.contains, minimizeAll = true)).process(tp)
5050

5151
/** The accumulator which forces type variables using the policy encoded in `force`
5252
* and returns whether the type is fully defined. The direction in which
@@ -86,6 +86,7 @@ object Inferencing {
8686
}
8787
else {
8888
val minimize =
89+
force.minimizeAll ||
8990
variance >= 0 && !(
9091
force == ForceDegree.noBottom &&
9192
defn.isBottomType(ctx.typeComparer.approximation(tvar.origin, fromBelow = true)))
@@ -293,9 +294,9 @@ object Inferencing {
293294

294295
/** An enumeration controlling the degree of forcing in "is-dully-defined" checks. */
295296
@sharable object ForceDegree {
296-
class Value(val appliesTo: TypeVar => Boolean)
297-
val none = new Value(_ => false)
298-
val all = new Value(_ => true)
299-
val noBottom = new Value(_ => true)
297+
class Value(val appliesTo: TypeVar => Boolean, val minimizeAll: Boolean)
298+
val none = new Value(_ => false, minimizeAll = false)
299+
val all = new Value(_ => true, minimizeAll = false)
300+
val noBottom = new Value(_ => true, minimizeAll = false)
300301
}
301302

src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,7 +1529,7 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
15291529
}
15301530
def issueErrors() = {
15311531
for (err <- errors) ctx.error(err(), tree.pos.endPos)
1532-
tree
1532+
tree.withType(wtp.resultType)
15331533
}
15341534
val args = (wtp.paramNames, wtp.paramTypes).zipped map { (pname, formal) =>
15351535
def where = d"parameter $pname of $methodStr"
@@ -1541,7 +1541,16 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
15411541
case failure: SearchFailure =>
15421542
val arg = synthesizedClassTag(formal)
15431543
if (!arg.isEmpty) arg
1544-
else implicitArgError(d"no implicit argument of type $formal found for $where" + failure.postscript)
1544+
else {
1545+
var msg = d"no implicit argument of type $formal found for $where" + failure.postscript
1546+
for (notFound <- formal.typeSymbol.getAnnotation(defn.ImplicitNotFoundAnnot);
1547+
Literal(Constant(raw: String)) <- notFound.argument(0))
1548+
msg = err.implicitNotFoundString(
1549+
raw,
1550+
formal.typeSymbol.typeParams.map(_.name.unexpandedName.toString),
1551+
formal.argInfos)
1552+
implicitArgError(msg)
1553+
}
15451554
}
15461555
}
15471556
if (errors.nonEmpty) {

tests/neg/EqualityStrawman1.scala

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package strawman.equality
2+
import annotation.implicitNotFound
3+
4+
object EqualityStrawman1 {
5+
6+
trait Eq[-T]
7+
8+
@implicitNotFound("cannot compare value of type ${T} with a value outside its equality class")
9+
trait Impossible[T]
10+
11+
object Eq extends Eq[Any]
12+
13+
trait Base {
14+
def === (other: Any): Boolean = this.equals(other)
15+
def === [T <: CondEquals](other: T)(implicit ce: Impossible[T]): Boolean = ???
16+
}
17+
18+
trait CondEquals extends Base {
19+
def === [T >: this.type <: CondEquals](other: T)(implicit ce: Eq[T]): Boolean = this.equals(other)
20+
def === [T](other: T)(implicit ce: Impossible[T]): Boolean = ???
21+
}
22+
23+
trait Equals[-T] extends CondEquals
24+
25+
case class Str(str: String) extends CondEquals
26+
27+
case class Num(x: Int) extends Equals[Num]
28+
29+
case class Other(x: Int) extends Base
30+
31+
trait Option[+T] extends CondEquals
32+
case class Some[+T](x: T) extends Option[T]
33+
case object None extends Option[Nothing]
34+
35+
implicit def eqStr: Eq[Str] = Eq
36+
//implicit def eqNum: Eq[Num] = Eq
37+
implicit def eqOption[T: Eq]: Eq[Option[T]] = Eq
38+
39+
implicit def eqEq[T <: Equals[T]]: Eq[T] = Eq
40+
41+
def main(args: Array[String]): Unit = {
42+
val x = Str("abc")
43+
x === x
44+
45+
val n = Num(2)
46+
val m = Num(3)
47+
n === m
48+
49+
Other(1) === Other(2)
50+
51+
Some(x) === None
52+
Some(x) === Some(Str(""))
53+
val z: Option[Str] = Some(Str("abc"))
54+
z === Some(x)
55+
z === None
56+
Some(x) === z
57+
None === z
58+
59+
60+
def ddistinct[T <: Base: Eq](xs: List[T]): List[T] = xs match {
61+
case Nil => Nil
62+
case x :: xs => x :: xs.filterNot(x === _)
63+
}
64+
65+
ddistinct(List(z, z, z))
66+
67+
x === n // error
68+
n === x // error
69+
x === Other(1) // error
70+
Other(2) === x // error
71+
z === Some(n) // error
72+
z === n // error
73+
Some(n) === z // error
74+
n === z // error
75+
Other(1) === z // error
76+
z === Other(1) // error
77+
ddistinct(List(z, n)) // error
78+
}
79+
}

tests/neg/subtyping.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class A extends B
66
object Test {
77
def test1(): Unit = {
88
implicitly[B#X <:< A#X] // error: no implicit argument
9-
} // error: no implicit argument
9+
}
1010
def test2(): Unit = {
1111
val a : { type T; type U } = ??? // error // error
1212
implicitly[a.T <:< a.U] // error: no implicit argument

tests/pos/i878.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
class X
2+
3+
class Message[-A]
4+
class Seg[+A]
5+
class IChan[A] {
6+
def add[B >: A](x: Seg[B])(implicit ev: Message[B]): IChan[B] = ???
7+
}
8+
9+
class Test {
10+
def test: Unit = {
11+
implicit val mx: Message[X] = ???
12+
val fx: IChan[X] = ???
13+
val sx: Seg[X] = ???
14+
// the implicit `mx` should be used even though the type parameter of Message is contravariant
15+
fx.add(sx)
16+
}
17+
}

0 commit comments

Comments
 (0)