Skip to content

Commit 7fa5107

Browse files
committed
Follow GADT bounds when computing members
When computing the member denotation of a selection of a TypeRef `T`, if the normal scheme fails and `T` has GADT bounds, compute the member in the upper bound instead. This is needed to make the opaque-probability test work. Add this test, as well as some others coming from SIP 35.
1 parent 2384aed commit 7fa5107

File tree

9 files changed

+237
-6
lines changed

9 files changed

+237
-6
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ object Types {
263263
}
264264

265265
/** Is some part of this type produced as a repair for an error? */
266-
final def isErroneous(implicit ctx: Context): Boolean = existsPart(_.isError, forceLazy = false)
266+
def isErroneous(implicit ctx: Context): Boolean = existsPart(_.isError, forceLazy = false)
267267

268268
/** Does the type carry an annotation that is an instance of `cls`? */
269269
@tailrec final def hasAnnotation(cls: ClassSymbol)(implicit ctx: Context): Boolean = stripTypeVar match {
@@ -1283,6 +1283,23 @@ object Types {
12831283
*/
12841284
def deepenProto(implicit ctx: Context): Type = this
12851285

1286+
/** If this is a TypeRef or an Application of a GADT-bound type, replace the
1287+
* GADT reference by its upper GADT bound. Otherwise NoType.
1288+
*/
1289+
def followGADT(implicit ctx: Context): Type = widenDealias match {
1290+
case site: TypeRef if site.symbol.is(Opaque) =>
1291+
ctx.gadt.bounds(site.symbol) match {
1292+
case TypeBounds(_, hi) => hi
1293+
case _ => NoType
1294+
}
1295+
case AppliedType(tycon, args) =>
1296+
val tycon1 = tycon.followGADT
1297+
if (tycon1.exists) tycon1.appliedTo(args)
1298+
else NoType
1299+
case _ =>
1300+
NoType
1301+
}
1302+
12861303
// ----- Substitutions -----------------------------------------------------
12871304

12881305
/** Substitute all types that refer in their symbol attribute to

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import core._
66
import ast._
77
import Contexts._, Types._, Flags._, Denotations._, Names._, StdNames._, NameOps._, Symbols._
88
import NameKinds.DepParamName
9+
import SymDenotations.NoDenotation
910
import Trees._
1011
import Constants._
1112
import Scopes._
@@ -104,8 +105,9 @@ object ProtoTypes {
104105
memberProto.isRef(defn.UnitClass) ||
105106
compat.normalizedCompatible(NamedType(tp1, name, m), memberProto)
106107
// Note: can't use `m.info` here because if `m` is a method, `m.info`
107-
// loses knowledge about `m`'s default arguments.
108+
// loses knowledge about `m`'s default arguments. ||
108109
mbr match { // hasAltWith inlined for performance
110+
case NoDenotation => tp1.exists && isMatchedBy(tp1.followGADT)
109111
case mbr: SingleDenotation => mbr.exists && qualifies(mbr)
110112
case _ => mbr hasAltWith qualifies
111113
}
@@ -282,6 +284,9 @@ object ProtoTypes {
282284

283285
def isDropped: Boolean = toDrop
284286

287+
override def isErroneous(implicit ctx: Context): Boolean =
288+
myTypedArgs.tpes.exists(_.widen.isErroneous)
289+
285290
override def toString = s"FunProto(${args mkString ","} => $resultType)"
286291

287292
def map(tm: TypeMap)(implicit ctx: Context): FunProto =

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,12 @@ trait TypeAssigner {
236236
if (reallyExists(mbr)) site.select(name, mbr)
237237
else if (site.derivesFrom(defn.DynamicClass) && !Dynamic.isDynamicMethod(name)) {
238238
TryDynamicCallType
239-
} else {
240-
if (site.isErroneous || name.toTermName == nme.ERROR) UnspecifiedErrorType
239+
}
240+
else {
241+
val site1 = site.followGADT
242+
if (site1.exists) selectionType(site1, name, pos)
243+
else if (site.isErroneous || name.toTermName == nme.ERROR)
244+
UnspecifiedErrorType
241245
else {
242246
def kind = if (name.isTypeName) "type" else "value"
243247
def addendum =

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2087,8 +2087,11 @@ class Typer extends Namer
20872087
noMatches
20882088
}
20892089
case alts =>
2090-
val remainingDenots = alts map (_.denot.asInstanceOf[SingleDenotation])
2091-
errorTree(tree, AmbiguousOverload(tree, remainingDenots, pt)(err))
2090+
if (tree.tpe.isErroneous || pt.isErroneous) tree.withType(UnspecifiedErrorType)
2091+
else {
2092+
val remainingDenots = alts map (_.denot.asInstanceOf[SingleDenotation])
2093+
errorTree(tree, AmbiguousOverload(tree, remainingDenots, pt)(err))
2094+
}
20922095
}
20932096
}
20942097

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
object opaquetypes {
2+
3+
opaque type Fix[F[_]] = F[Fix2[F]]
4+
5+
opaque type Fix2[F[_]] = Fix[F]
6+
7+
object Fix {
8+
def unfold[F[_]](x: Fix[F]): F[Fix]
9+
}
10+
11+
object Fix2 {
12+
def unfold[F[_]](x: Fix2[F]: Fix[F] = x
13+
def fold[F[_]](x: Fix[F]: Fix2[F] = x
14+
}
15+
16+
}

tests/pos/opaque-digits.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
object pkg {
2+
3+
import Character.{isAlphabetic, isDigit}
4+
5+
class Alphabetic private[pkg] (val value: String) extends AnyVal
6+
7+
object Alphabetic {
8+
def fromString(s: String): Option[Alphabetic] =
9+
if (s.forall(isAlphabetic(_))) Some(new Alphabetic(s))
10+
else None
11+
}
12+
13+
opaque type Digits = String
14+
15+
object Digits {
16+
def fromString(s: String): Option[Digits] =
17+
if (s.forall(isDigit(_))) Some(s)
18+
else None
19+
20+
def asString(d: Digits): String = d
21+
}
22+
}

tests/pos/opaque-goups.scala

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package object groups {
2+
trait Semigroup[A] {
3+
def combine(x: A, y: A): A
4+
}
5+
6+
object Semigroup {
7+
def instance[A](f: (A, A) => A): Semigroup[A] =
8+
new Semigroup[A] {
9+
def combine(x: A, y: A): A = f(x, y)
10+
}
11+
}
12+
13+
type Id[A] = A
14+
15+
trait Wrapping[F[_]] {
16+
def wraps[G[_], A](ga: G[A]): G[F[A]]
17+
def unwrap[G[_], A](ga: G[F[A]]): G[A]
18+
}
19+
20+
abstract class Wrapper[F[_]] { self =>
21+
def wraps[G[_], A](ga: G[A]): G[F[A]]
22+
def unwrap[G[_], A](gfa: G[F[A]]): G[A]
23+
24+
final def apply[A](a: A): F[A] = wraps[Id, A](a)
25+
26+
implicit object WrapperWrapping extends Wrapping[F] {
27+
def wraps[G[_], A](ga: G[A]): G[F[A]] = self.wraps(ga)
28+
def unwrap[G[_], A](ga: G[F[A]]): G[A] = self.unwrap(ga)
29+
}
30+
}
31+
32+
opaque type First[A] = A
33+
object First extends Wrapper[First] {
34+
def wraps[G[_], A](ga: G[A]): G[First[A]] = ga
35+
def unwrap[G[_], A](gfa: G[First[A]]): G[A] = gfa
36+
implicit def firstSemigroup[A]: Semigroup[First[A]] =
37+
Semigroup.instance((x, y) => x)
38+
}
39+
40+
opaque type Last[A] = A
41+
object Last extends Wrapper[Last] {
42+
def wraps[G[_], A](ga: G[A]): G[Last[A]] = ga
43+
def unwrap[G[_], A](gfa: G[Last[A]]): G[A] = gfa
44+
implicit def lastSemigroup[A]: Semigroup[Last[A]] =
45+
Semigroup.instance((x, y) => y)
46+
}
47+
48+
opaque type Min[A] = A
49+
object Min extends Wrapper[Min] {
50+
def wraps[G[_], A](ga: G[A]): G[Min[A]] = ga
51+
def unwrap[G[_], A](gfa: G[Min[A]]): G[A] = gfa
52+
implicit def minSemigroup[A](implicit o: Ordering[A]): Semigroup[Min[A]] =
53+
Semigroup.instance(o.min)
54+
}
55+
56+
opaque type Max[A] = A
57+
object Max extends Wrapper[Max] {
58+
def wraps[G[_], A](ga: G[A]): G[Max[A]] = ga
59+
def unwrap[G[_], A](gfa: G[Max[A]]): G[A] = gfa
60+
implicit def maxSemigroup[A](implicit o: Ordering[A]): Semigroup[Max[A]] =
61+
Semigroup.instance(o.max)
62+
}
63+
64+
opaque type Plus[A] = A
65+
object Plus extends Wrapper[Plus] {
66+
def wraps[G[_], A](ga: G[A]): G[Plus[A]] = ga
67+
def unwrap[G[_], A](gfa: G[Plus[A]]): G[A] = gfa
68+
implicit def plusSemigroup[A](implicit n: Numeric[A]): Semigroup[Plus[A]] =
69+
Semigroup.instance(n.plus)
70+
}
71+
72+
opaque type Times[A] = A
73+
object Times extends Wrapper[Times] {
74+
def wraps[G[_], A](ga: G[A]): G[Times[A]] = ga
75+
def unwrap[G[_], A](gfa: G[Times[A]]): G[A] = gfa
76+
implicit def timesSemigroup[A](implicit n: Numeric[A]): Semigroup[Times[A]] =
77+
Semigroup.instance(n.times)
78+
}
79+
80+
opaque type Reversed[A] = A
81+
object Reversed extends Wrapper[Reversed] {
82+
def wraps[G[_], A](ga: G[A]): G[Reversed[A]] = ga
83+
def unwrap[G[_], A](gfa: G[Reversed[A]]): G[A] = gfa
84+
implicit def reversedOrdering[A](implicit o: Ordering[A]): Ordering[Reversed[A]] =
85+
o.reverse
86+
}
87+
88+
opaque type Unordered[A] = A
89+
object Unordered extends Wrapper[Unordered] {
90+
def wraps[G[_], A](ga: G[A]): G[Unordered[A]] = ga
91+
def unwrap[G[_], A](gfa: G[Unordered[A]]): G[A] = gfa
92+
implicit def unorderedOrdering[A]: Ordering[Unordered[A]] =
93+
Ordering.by(_ => ())
94+
}
95+
}

tests/pos/opaque-nullable.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
object nullable {
2+
opaque type Nullable[A >: Null <: AnyRef] = A
3+
4+
object Nullable {
5+
def apply[A >: Null <: AnyRef](a: A): Nullable[A] = a
6+
7+
implicit class NullableOps[A >: Null <: AnyRef](na: Nullable[A]) {
8+
def exists(p: A => Boolean): Boolean =
9+
na != null && p(na)
10+
def filter(p: A => Boolean): Nullable[A] =
11+
if (na != null && p(na)) na else null
12+
def flatMap[B >: Null <: AnyRef](f: A => Nullable[B]): Nullable[B] =
13+
if (na == null) null else f(na)
14+
def forall(p: A => Boolean): Boolean =
15+
na == null || p(na)
16+
def getOrElse(a: => A): A =
17+
if (na == null) a else na
18+
def map[B >: Null <: AnyRef](f: A => B): Nullable[B] =
19+
if (na == null) null else f(na)
20+
def orElse(na2: => Nullable[A]): Nullable[A] =
21+
if (na == null) na2 else na
22+
def toOption: Option[A] =
23+
Option(na)
24+
}
25+
}
26+
}

tests/pos/opaque-propability.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
object prob {
2+
opaque type Probability = Double
3+
4+
object Probability {
5+
def apply(n: Double): Option[Probability] =
6+
if (0.0 <= n && n <= 1.0) Some(n) else None
7+
8+
def unsafe(p: Double): Probability = {
9+
require(0.0 <= p && p <= 1.0, s"probabilities lie in [0, 1] (got $p)")
10+
p
11+
}
12+
13+
def asDouble(p: Probability): Double = p
14+
15+
val Never: Probability = 0.0
16+
val CoinToss: Probability = 0.5
17+
val Certain: Probability = 1.0
18+
19+
implicit val ordering: Ordering[Probability] =
20+
implicitly[Ordering[Double]]
21+
22+
implicit class ProbabilityOps(p1: Probability) extends AnyVal {
23+
def unary_~ : Probability = Certain - p1
24+
def &(p2: Probability): Probability = p1 * p2
25+
def |(p2: Probability): Probability = p1 + p2 - (p1 * p2)
26+
27+
def isImpossible: Boolean = p1 == Never
28+
def isCertain: Boolean = p1 == Certain
29+
30+
import scala.util.Random
31+
32+
def sample(r: Random = Random): Boolean = r.nextDouble <= p1
33+
def toDouble: Double = p1
34+
}
35+
36+
val caughtTrain = Probability.unsafe(0.3)
37+
val missedTrain = ~caughtTrain
38+
val caughtCab = Probability.CoinToss
39+
val arrived = caughtTrain | (missedTrain & caughtCab)
40+
41+
println((1 to 5).map(_ => arrived.sample()).toList)
42+
}
43+
}

0 commit comments

Comments
 (0)