Skip to content

Commit 7e2a1ad

Browse files
committed
Add type refinement for abstract type bindings
1 parent 8b2500e commit 7e2a1ad

File tree

6 files changed

+175
-3
lines changed

6 files changed

+175
-3
lines changed

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
896896
* whereas overloaded variants need to have a conforming variant.
897897
*/
898898
def trySelectUnapply(qual: untpd.Tree)(fallBack: Tree => Tree): Tree = {
899-
// try first for non-overloaded, then for overloaded ocurrences
899+
// try first for non-overloaded, then for overloaded occurrences
900900
def tryWithName(name: TermName)(fallBack: Tree => Tree)(implicit ctx: Context): Tree = {
901901
def tryWithProto(pt: Type)(implicit ctx: Context) = {
902902
val result = typedExpr(untpd.Select(qual, name), new UnapplyFunProto(pt, this))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,10 +1322,23 @@ class Typer extends Namer
13221322
else {
13231323
// for a singleton pattern like `x @ Nil`, `x` should get the type from the scrutinee
13241324
// see tests/neg/i3200b.scala and SI-1503
1325-
val symTp =
1325+
val symTp0 =
13261326
if (body1.tpe.isInstanceOf[TermRef]) pt1
13271327
else body1.tpe.underlyingIfRepeated(isJava = false)
1328-
val sym = ctx.newPatternBoundSymbol(tree.name, symTp, tree.pos)
1328+
1329+
// If it is name based pattern matching, the type of the argument of the unapply is abstract and
1330+
// the return type has a type member `Refined`, then refine the type of the binding with the type of `Refined`.
1331+
val symTp1 = body1 match {
1332+
case Trees.UnApply(fun, _, _) if symTp0.typeSymbol.is(Deferred) =>
1333+
// TODO check that it is name based pattern matching
1334+
val resultType = fun.tpe.widen.finalResultType
1335+
val refined = resultType.select("Refined".toTypeName)
1336+
if (refined.exists) refined & symTp0
1337+
else symTp0
1338+
case _ => symTp0
1339+
}
1340+
1341+
val sym = ctx.newPatternBoundSymbol(tree.name, symTp1, tree.pos)
13291342
if (ctx.mode.is(Mode.InPatternAlternative))
13301343
ctx.error(i"Illegal variable ${sym.name} in pattern alternative", tree.pos)
13311344
assignType(cpy.Bind(tree)(tree.name, body1), sym)

tests/run/refined-binding-nat.check

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Some((SuccClass(SuccClass(ZeroObject)),SuccClass(ZeroObject)))
2+
Some((ZeroObject,SuccClass(SuccClass(ZeroObject))))
3+
None
4+
Some((2,1))
5+
Some((0,2))
6+
None

tests/run/refined-binding-nat.scala

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
app(ClassNums)
5+
app(IntNums)
6+
}
7+
8+
def app(peano: Peano): Unit = {
9+
import peano._
10+
def divOpt(m: Nat, n: Nat): Option[(Nat, Nat)] = {
11+
n match {
12+
case Zero => None
13+
case s @ Succ(_) => Some(safeDiv(m, s))
14+
}
15+
}
16+
val two = Succ(Succ(Zero))
17+
val five = Succ(Succ(Succ(two)))
18+
println(divOpt(five, two))
19+
println(divOpt(two, five))
20+
println(divOpt(two, Zero))
21+
}
22+
}
23+
24+
trait Peano {
25+
type Nat
26+
type Zero <: Nat
27+
type Succ <: Nat
28+
29+
def safeDiv(m: Nat, n: Succ): (Nat, Nat)
30+
31+
implicit def succDeco(succ: Succ): SuccAPI
32+
trait SuccAPI {
33+
def pred: Nat
34+
}
35+
36+
val Zero: Zero
37+
38+
val Succ: SuccExtractor
39+
trait SuccExtractor {
40+
def apply(nat: Nat): Succ
41+
def unapply(nat: Nat): SuccOpt
42+
}
43+
trait SuccOpt {
44+
type Refined = Succ
45+
def isEmpty: Boolean
46+
def get: Nat
47+
}
48+
}
49+
50+
object IntNums extends Peano {
51+
type Nat = Int
52+
type Zero = Int
53+
type Succ = Int
54+
55+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = (m / n, m % n)
56+
57+
val Zero: Zero = 0
58+
59+
object Succ extends SuccExtractor {
60+
def apply(nat: Nat): Succ = nat + 1
61+
def unapply(nat: Nat): SuccOpt = new SuccOpt {
62+
def isEmpty: Boolean = nat == 0
63+
def get: Int = nat - 1
64+
}
65+
}
66+
67+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
68+
def pred: Nat = succ - 1
69+
}
70+
}
71+
72+
object ClassNums extends Peano {
73+
trait NatTrait
74+
object ZeroObject extends NatTrait {
75+
override def toString: String = "ZeroObject"
76+
}
77+
case class SuccClass(predecessor: NatTrait) extends NatTrait with SuccOpt {
78+
def isEmpty: Boolean = false
79+
def get: NatTrait = this
80+
}
81+
82+
object SuccNoMatch extends SuccOpt {
83+
def isEmpty: Boolean = true
84+
def get: NatTrait = throw new NoSuchElementException
85+
}
86+
87+
type Nat = NatTrait
88+
type Zero = ZeroObject.type
89+
type Succ = SuccClass
90+
91+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = {
92+
def intValue(x: Nat, acc: Int): Int = x match {
93+
case nat: SuccClass => intValue(nat.predecessor, acc + 1)
94+
case _ => acc
95+
}
96+
def natValue(x: Int): Nat =
97+
if (x == 0) ZeroObject
98+
else new SuccClass(natValue(x - 1))
99+
val i = intValue(m, 0)
100+
val j = intValue(n, 0)
101+
(natValue(i / j), natValue(i % j))
102+
}
103+
104+
val Zero: Zero = ZeroObject
105+
106+
object Succ extends SuccExtractor {
107+
def apply(nat: Nat): Succ = new SuccClass(nat)
108+
def unapply(nat: Nat): SuccOpt = nat match {
109+
case nat: SuccClass => nat
110+
case _ => SuccNoMatch
111+
}
112+
}
113+
114+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
115+
def pred: Nat = succ.predecessor
116+
}
117+
118+
}

tests/run/refined-binding.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ok
2+
9

tests/run/refined-binding.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
trait Foo {
3+
4+
type X
5+
type Y <: X
6+
7+
def x: X
8+
9+
def f(y: Y) = println("ok")
10+
object Z {
11+
def unapply(arg: X): Opt = new Opt
12+
}
13+
14+
class Opt {
15+
type Refined = Y
16+
def get: Int = 9
17+
def isEmpty: Boolean = false
18+
}
19+
}
20+
21+
object Test {
22+
def main(args: Array[String]): Unit = {
23+
test(new Foo { type X = Int; type Y = Int; def x: X = 1 })
24+
}
25+
26+
def test(foo: Foo): Unit = {
27+
foo.x match {
28+
case x @ foo.Z(i) => // `x` is refined to type `Y`
29+
foo.f(x)
30+
println(i)
31+
}
32+
}
33+
}

0 commit comments

Comments
 (0)