Skip to content

Commit 1582173

Browse files
committed
Add type refinement for abstract type bindings
1 parent 6240daa commit 1582173

9 files changed

+264
-4
lines changed

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ object StdNames {
488488
val productPrefix: N = "productPrefix"
489489
val raw_ : N = "raw"
490490
val readResolve: N = "readResolve"
491+
val refinedScrutinee: N = "refinedScrutinee"
491492
val reflect : N = "reflect"
492493
val reflectiveSelectable: N = "reflectiveSelectable"
493494
val reify : N = "reify"

compiler/src/dotty/tools/dotc/transform/FirstTransform.scala

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,25 @@ class FirstTransform extends MiniPhase with InfoTransformer { thisPhase =>
108108
cpy.Template(impl)(self = EmptyValDef)
109109
}
110110

111-
override def transformDefDef(ddef: DefDef)(implicit ctx: Context) = {
111+
override def transformDefDef(ddef: DefDef)(implicit ctx: Context): Tree = {
112+
if (ddef.name == nme.unapply && !ddef.symbol.is(Synthetic)) {
113+
ddef.tpe.widen match {
114+
case mt: MethodType if !mt.resType.widen.isInstanceOf[MethodicType] =>
115+
val resultType = mt.resType.substParam(mt.paramRefs.head, mt.paramRefs.head)
116+
val refinedType = resultType.select(nme.refinedScrutinee).widen.resultType
117+
if (refinedType.exists && !(refinedType <:< mt.paramRefs.head)) {
118+
val paramName = mt.paramNames.head
119+
val paramTpe = mt.paramRefs.head
120+
val paramInfo = mt.paramInfos.head
121+
ctx.error(
122+
i"""Extractor with ${nme.refinedScrutinee} should refine the result type of that member.
123+
|The result type of ${nme.refinedScrutinee} should be a subtype of $paramTpe:
124+
| def unapply($paramName: $paramInfo): ${resultType.widenDealias.classSymbol.name} { def ${nme.refinedScrutinee}: $refinedType & $paramTpe }
125+
""".stripMargin, ddef.tpt.pos)
126+
}
127+
case _ =>
128+
}
129+
}
112130
if (ddef.symbol.hasAnnotation(defn.NativeAnnot)) {
113131
ddef.symbol.resetFlag(Deferred)
114132
DefDef(ddef.symbol.asTerm,

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
896896
* whereas overloaded variants need to have a conforming variant.
897897
*/
898898
def trySelectUnapply(qual: untpd.Tree)(fallBack: Tree => Tree): Tree = {
899-
// try first for non-overloaded, then for overloaded ocurrences
899+
// try first for non-overloaded, then for overloaded occurrences
900900
def tryWithName(name: TermName)(fallBack: Tree => Tree)(implicit ctx: Context): Tree = {
901901
def tryWithProto(pt: Type)(implicit ctx: Context) = {
902902
val result = typedExpr(untpd.Select(qual, name), new UnapplyFunProto(pt, this))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,10 +1323,29 @@ class Typer extends Namer
13231323
else {
13241324
// for a singleton pattern like `x @ Nil`, `x` should get the type from the scrutinee
13251325
// see tests/neg/i3200b.scala and SI-1503
1326-
val symTp =
1326+
val symTp0 =
13271327
if (body1.tpe.isInstanceOf[TermRef]) pt1
13281328
else body1.tpe.underlyingIfRepeated(isJava = false)
1329-
val sym = ctx.newPatternBoundSymbol(tree.name, symTp, tree.pos)
1329+
1330+
// If it is name based pattern matching, the type of the argument of the unapply is abstract and
1331+
// the return type has a type member `Refined`, then refine the type of the binding with the type of `Refined`.
1332+
val symTp1 = body1 match {
1333+
case Trees.UnApply(fun, _, _) if symTp0.typeSymbol.is(Deferred) =>
1334+
// TODO check that it is name based pattern matching
1335+
fun.tpe.widen match {
1336+
case mt: MethodType if !mt.resType.isInstanceOf[MethodType] =>
1337+
val resultType = mt.resType.substParam(mt.paramRefs.head, symTp0)
1338+
val refinedType = resultType.select(nme.refinedScrutinee).widen.resultType
1339+
if (refinedType.exists) refinedType
1340+
else symTp0
1341+
case _ =>
1342+
symTp0
1343+
}
1344+
1345+
case _ => symTp0
1346+
}
1347+
1348+
val sym = ctx.newPatternBoundSymbol(tree.name, symTp1, tree.pos)
13301349
if (ctx.mode.is(Mode.InPatternAlternative))
13311350
ctx.error(i"Illegal variable ${sym.name} in pattern alternative", tree.pos)
13321351
assignType(cpy.Bind(tree)(tree.name, body1), sym)

tests/neg/refined-binding-nat.scala

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
2+
trait Peano {
3+
type Nat
4+
type Zero <: Nat
5+
type Succ <: Nat
6+
7+
val Zero: Zero
8+
9+
val Succ: SuccExtractor
10+
trait SuccExtractor {
11+
def apply(nat: Nat): Succ
12+
def unapply(nat: Nat): SuccOpt // error: missing { def refinedScrutinee: Succ & nat.type }
13+
}
14+
trait SuccOpt {
15+
def isEmpty: Boolean
16+
def refinedScrutinee: Succ
17+
def get: Nat
18+
}
19+
}
20+
21+
object IntNums extends Peano {
22+
type Nat = Int
23+
type Zero = Int
24+
type Succ = Int
25+
26+
val Zero: Zero = 0
27+
28+
object Succ extends SuccExtractor {
29+
def apply(nat: Nat): Succ = nat + 1
30+
def unapply(nat: Nat) = new SuccOpt { // error: missing { def refinedScrutinee: Succ & nat.type }
31+
def isEmpty: Boolean = nat == 0
32+
def refinedScrutinee: Succ & nat.type = nat
33+
def get: Int = nat - 1
34+
}
35+
}
36+
37+
}
38+
39+
object IntNums2 extends Peano {
40+
type Nat = Int
41+
type Zero = Int
42+
type Succ = Int
43+
44+
val Zero: Zero = 0
45+
46+
object Succ extends SuccExtractor {
47+
def apply(nat: Nat): Succ = nat + 1
48+
def unapply(nat: Nat): SuccOpt { def refinedScrutinee: Succ & nat.type } = new SuccOpt {
49+
def isEmpty: Boolean = nat == 0
50+
def refinedScrutinee: Succ & nat.type = nat
51+
def get: Int = nat - 1
52+
}
53+
}
54+
55+
}

tests/run/refined-binding-nat.check

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Some((SuccClass(SuccClass(ZeroObject)),SuccClass(ZeroObject)))
2+
Some((ZeroObject,SuccClass(SuccClass(ZeroObject))))
3+
None
4+
Some((2,1))
5+
Some((0,2))
6+
None

tests/run/refined-binding-nat.scala

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
app(ClassNums)
5+
app(IntNums)
6+
}
7+
8+
def app(peano: Peano): Unit = {
9+
import peano._
10+
def divOpt(m: Nat, n: Nat): Option[(Nat, Nat)] = {
11+
n match {
12+
case Zero => None
13+
case s @ Succ(_) => Some(safeDiv(m, s))
14+
}
15+
}
16+
val two = Succ(Succ(Zero))
17+
val five = Succ(Succ(Succ(two)))
18+
println(divOpt(five, two))
19+
println(divOpt(two, five))
20+
println(divOpt(two, Zero))
21+
}
22+
}
23+
24+
trait Peano {
25+
type Nat
26+
type Zero <: Nat
27+
type Succ <: Nat
28+
29+
def safeDiv(m: Nat, n: Succ): (Nat, Nat)
30+
31+
implicit def succDeco(succ: Succ): SuccAPI
32+
trait SuccAPI {
33+
def pred: Nat
34+
}
35+
36+
val Zero: Zero
37+
38+
val Succ: SuccExtractor
39+
trait SuccExtractor {
40+
def apply(nat: Nat): Succ
41+
def unapply(nat: Nat): SuccOpt { def refinedScrutinee: Succ & nat.type }
42+
}
43+
trait SuccOpt {
44+
def isEmpty: Boolean
45+
def refinedScrutinee: Succ
46+
def get: Nat
47+
}
48+
}
49+
50+
object IntNums extends Peano {
51+
type Nat = Int
52+
type Zero = Int
53+
type Succ = Int
54+
55+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = (m / n, m % n)
56+
57+
val Zero: Zero = 0
58+
59+
object Succ extends SuccExtractor {
60+
def apply(nat: Nat): Succ = nat + 1
61+
def unapply(nat: Nat) = new SuccOpt {
62+
def isEmpty: Boolean = nat == 0
63+
def refinedScrutinee: Succ & nat.type = nat
64+
def get: Int = nat - 1
65+
}
66+
}
67+
68+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
69+
def pred: Nat = succ - 1
70+
}
71+
}
72+
73+
object ClassNums extends Peano {
74+
trait NatTrait
75+
object ZeroObject extends NatTrait {
76+
override def toString: String = "ZeroObject"
77+
}
78+
case class SuccClass(predecessor: NatTrait) extends NatTrait with SuccOpt {
79+
def isEmpty: Boolean = false
80+
def refinedScrutinee: this.type = this
81+
def get: NatTrait = this
82+
}
83+
84+
object SuccNoMatch extends SuccOpt {
85+
def isEmpty: Boolean = true
86+
def refinedScrutinee: Nothing = throw new NoSuchElementException
87+
def get: NatTrait = throw new NoSuchElementException
88+
}
89+
90+
type Nat = NatTrait
91+
type Zero = ZeroObject.type
92+
type Succ = SuccClass
93+
94+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = {
95+
def intValue(x: Nat, acc: Int): Int = x match {
96+
case nat: SuccClass => intValue(nat.predecessor, acc + 1)
97+
case _ => acc
98+
}
99+
def natValue(x: Int): Nat =
100+
if (x == 0) ZeroObject
101+
else new SuccClass(natValue(x - 1))
102+
val i = intValue(m, 0)
103+
val j = intValue(n, 0)
104+
(natValue(i / j), natValue(i % j))
105+
}
106+
107+
val Zero: Zero = ZeroObject
108+
109+
object Succ extends SuccExtractor {
110+
def apply(nat: Nat): Succ = new SuccClass(nat)
111+
def unapply(nat: Nat) = nat match {
112+
case nat: (SuccClass & nat.type) => nat
113+
case _ => SuccNoMatch
114+
}
115+
}
116+
117+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
118+
def pred: Nat = succ.predecessor
119+
}
120+
121+
}

tests/run/refined-binding.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ok
2+
9

tests/run/refined-binding.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
2+
sealed trait Foo {
3+
4+
type X
5+
type Y <: X
6+
7+
def x: X
8+
9+
def f(y: Y) = println("ok")
10+
11+
object Z {
12+
def unapply(arg: X) = new Opt {
13+
type Scrutinee = arg.type
14+
def refinedScrutinee: Y & Scrutinee = arg.asInstanceOf[Y & Scrutinee]
15+
}
16+
}
17+
18+
abstract class Opt {
19+
type Scrutinee <: Singleton
20+
def refinedScrutinee: Y & Scrutinee
21+
def get: Int = 9
22+
def isEmpty: Boolean = false
23+
}
24+
}
25+
26+
object Test {
27+
def main(args: Array[String]): Unit = {
28+
test(new Foo { type X = Int; type Y = Int; def x: X = 1 })
29+
}
30+
31+
def test(foo: Foo): Unit = {
32+
foo.x match {
33+
case x @ foo.Z(i) => // `x` is refined to type `Y`
34+
foo.f(x)
35+
println(i)
36+
}
37+
}
38+
}

0 commit comments

Comments
 (0)