Skip to content

Commit a0aeea7

Browse files
committed
Add type refinement for abstract type bindings
1 parent 95ec081 commit a0aeea7

10 files changed

+324
-4
lines changed

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ object StdNames {
216216
final val Type : N = "Type"
217217
final val TypeTree: N = "TypeTree"
218218

219+
final val UncheckedRefinedArgument: N = "UncheckedRefinedArgument"
220+
219221
// Annotation simple names, used in Namer
220222
final val BeanPropertyAnnot: N = "BeanProperty"
221223
final val BooleanBeanPropertyAnnot: N = "BooleanBeanProperty"

compiler/src/dotty/tools/dotc/transform/FirstTransform.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,24 @@ class FirstTransform extends MiniPhase with InfoTransformer { thisPhase =>
108108
cpy.Template(impl)(self = EmptyValDef)
109109
}
110110

111-
override def transformDefDef(ddef: DefDef)(implicit ctx: Context) = {
111+
override def transformDefDef(ddef: DefDef)(implicit ctx: Context): Tree = {
112+
if (ddef.name == nme.unapply && !ddef.symbol.is(Synthetic)) {
113+
val resultType = ddef.tpt.tpe.finalResultType
114+
val refined = resultType.select(tpnme.UncheckedRefinedArgument)
115+
if (refined.typeSymbol.exists) {
116+
val arg = ddef.vparamss.head.head
117+
if (!(refined <:< arg.tpt.tpe)) {
118+
ctx.error(i"Abstract type refinement $refined (${refined.widenDealias}) is not a subtype of the unapply argument (${arg.tpt.tpe.widenDealias})", ddef.pos)
119+
} else if (!(refined <:< arg.tpe)) {
120+
ctx.warning(
121+
i"""Abstract type refinement $refined (${refined.widenDealias}) should be a subtype of the unapply argument singleton type(${arg.name}.type).
122+
|
123+
|This type constraint can be added as follows:
124+
| def unapply(x: T): U { type ${tpnme.UncheckedRefinedArgument} <: x.type } = ...
125+
""".stripMargin, ddef.pos)
126+
}
127+
}
128+
}
112129
if (ddef.symbol.hasAnnotation(defn.NativeAnnot)) {
113130
ddef.symbol.resetFlag(Deferred)
114131
DefDef(ddef.symbol.asTerm,

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
896896
* whereas overloaded variants need to have a conforming variant.
897897
*/
898898
def trySelectUnapply(qual: untpd.Tree)(fallBack: Tree => Tree): Tree = {
899-
// try first for non-overloaded, then for overloaded ocurrences
899+
// try first for non-overloaded, then for overloaded occurrences
900900
def tryWithName(name: TermName)(fallBack: Tree => Tree)(implicit ctx: Context): Tree = {
901901
def tryWithProto(pt: Type)(implicit ctx: Context) = {
902902
val result = typedExpr(untpd.Select(qual, name), new UnapplyFunProto(pt, this))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,10 +1322,23 @@ class Typer extends Namer
13221322
else {
13231323
// for a singleton pattern like `x @ Nil`, `x` should get the type from the scrutinee
13241324
// see tests/neg/i3200b.scala and SI-1503
1325-
val symTp =
1325+
val symTp0 =
13261326
if (body1.tpe.isInstanceOf[TermRef]) pt1
13271327
else body1.tpe.underlyingIfRepeated(isJava = false)
1328-
val sym = ctx.newPatternBoundSymbol(tree.name, symTp, tree.pos)
1328+
1329+
// If it is name based pattern matching, the type of the argument of the unapply is abstract and
1330+
// the return type has a type member `Refined`, then refine the type of the binding with the type of `Refined`.
1331+
val symTp1 = body1 match {
1332+
case Trees.UnApply(fun, _, _) if symTp0.typeSymbol.is(Deferred) =>
1333+
// TODO check that it is name based pattern matching
1334+
val resultType = fun.tpe.widen.finalResultType
1335+
val refined = resultType.select(tpnme.UncheckedRefinedArgument)
1336+
if (refined.typeSymbol.exists) refined & symTp0
1337+
else symTp0
1338+
case _ => symTp0
1339+
}
1340+
1341+
val sym = ctx.newPatternBoundSymbol(tree.name, symTp1, tree.pos)
13291342
if (ctx.mode.is(Mode.InPatternAlternative))
13301343
ctx.error(i"Illegal variable ${sym.name} in pattern alternative", tree.pos)
13311344
assignType(cpy.Bind(tree)(tree.name, body1), sym)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Some((SuccClass(SuccClass(ZeroObject)),SuccClass(ZeroObject)))
2+
Some((ZeroObject,SuccClass(SuccClass(ZeroObject))))
3+
None
4+
Some((2,1))
5+
Some((0,2))
6+
None
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
2+
// FIXME
3+
4+
//object Test {
5+
// def main(args: Array[String]): Unit = {
6+
// app(ClassNums)
7+
// app(IntNums)
8+
// }
9+
//
10+
// def app(peano: Peano): Unit = {
11+
// import peano._
12+
// def divOpt(m: Nat, n: Nat): Option[(Nat, Nat)] = {
13+
// n match {
14+
// case Zero => None
15+
// case s @ Succ(_) => Some(safeDiv(m, s))
16+
// }
17+
// }
18+
// val two = Succ(Succ(Zero))
19+
// val five = Succ(Succ(Succ(two)))
20+
// println(divOpt(five, two))
21+
// println(divOpt(two, five))
22+
// println(divOpt(two, Zero))
23+
// }
24+
//}
25+
26+
trait Peano {
27+
type Nat
28+
type Zero <: Nat
29+
type Succ <: Nat
30+
31+
def safeDiv(m: Nat, n: Succ): (Nat, Nat)
32+
33+
implicit def succDeco(succ: Succ): SuccAPI
34+
trait SuccAPI {
35+
def pred: Nat
36+
}
37+
38+
val Zero: Zero
39+
40+
val Succ: SuccExtractor
41+
trait SuccExtractor {
42+
def apply(nat: Nat): Succ
43+
def unapply(nat: Nat): SuccOpt { type UncheckedRefinedArgument <: nat.type }
44+
}
45+
trait SuccOpt {
46+
type UncheckedRefinedArgument <: Succ
47+
def isEmpty: Boolean
48+
def get: Nat
49+
}
50+
}
51+
52+
object IntNums extends Peano {
53+
type Nat = Int
54+
type Zero = Int
55+
type Succ = Int
56+
57+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = (m / n, m % n)
58+
59+
val Zero: Zero = 0
60+
61+
object Succ extends SuccExtractor {
62+
def apply(nat: Nat): Succ = nat + 1
63+
def unapply(nat: Nat) = new SuccOpt {
64+
type UncheckedRefinedArgument = nat.type
65+
def isEmpty: Boolean = nat == 0
66+
def get: Int = nat - 1
67+
}
68+
}
69+
70+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
71+
def pred: Nat = succ - 1
72+
}
73+
}
74+
75+
object ClassNums extends Peano {
76+
trait NatTrait
77+
object ZeroObject extends NatTrait {
78+
override def toString: String = "ZeroObject"
79+
}
80+
case class SuccClass(predecessor: NatTrait) extends NatTrait with SuccOpt {
81+
type UncheckedRefinedArgument = this.type
82+
def isEmpty: Boolean = false
83+
def get: NatTrait = this
84+
}
85+
86+
object SuccNoMatch extends SuccOpt {
87+
type UncheckedRefinedArgument = Nothing
88+
def isEmpty: Boolean = true
89+
def get: NatTrait = throw new NoSuchElementException
90+
}
91+
92+
type Nat = NatTrait
93+
type Zero = ZeroObject.type
94+
type Succ = SuccClass
95+
96+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = {
97+
def intValue(x: Nat, acc: Int): Int = x match {
98+
case nat: SuccClass => intValue(nat.predecessor, acc + 1)
99+
case _ => acc
100+
}
101+
def natValue(x: Int): Nat =
102+
if (x == 0) ZeroObject
103+
else new SuccClass(natValue(x - 1))
104+
val i = intValue(m, 0)
105+
val j = intValue(n, 0)
106+
(natValue(i / j), natValue(i % j))
107+
}
108+
109+
val Zero: Zero = ZeroObject
110+
111+
object Succ extends SuccExtractor {
112+
def apply(nat: Nat): Succ = new SuccClass(nat)
113+
def unapply(nat: Nat) = nat match {
114+
case nat: (SuccClass & nat.type) => nat
115+
case _ => SuccNoMatch
116+
}
117+
}
118+
119+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
120+
def pred: Nat = succ.predecessor
121+
}
122+
123+
}

tests/run/refined-binding-nat.check

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Some((SuccClass(SuccClass(ZeroObject)),SuccClass(ZeroObject)))
2+
Some((ZeroObject,SuccClass(SuccClass(ZeroObject))))
3+
None
4+
Some((2,1))
5+
Some((0,2))
6+
None

tests/run/refined-binding-nat.scala

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
app(ClassNums)
5+
app(IntNums)
6+
}
7+
8+
def app(peano: Peano): Unit = {
9+
import peano._
10+
def divOpt(m: Nat, n: Nat): Option[(Nat, Nat)] = {
11+
n match {
12+
case Zero => None
13+
case s @ Succ(_) => Some(safeDiv(m, s))
14+
}
15+
}
16+
val two = Succ(Succ(Zero))
17+
val five = Succ(Succ(Succ(two)))
18+
println(divOpt(five, two))
19+
println(divOpt(two, five))
20+
println(divOpt(two, Zero))
21+
}
22+
}
23+
24+
trait Peano {
25+
type Nat
26+
type Zero <: Nat
27+
type Succ <: Nat
28+
29+
def safeDiv(m: Nat, n: Succ): (Nat, Nat)
30+
31+
implicit def succDeco(succ: Succ): SuccAPI
32+
trait SuccAPI {
33+
def pred: Nat
34+
}
35+
36+
val Zero: Zero
37+
38+
val Succ: SuccExtractor
39+
trait SuccExtractor {
40+
def apply(nat: Nat): Succ
41+
def unapply(nat: Nat): SuccOpt
42+
}
43+
trait SuccOpt {
44+
type UncheckedRefinedArgument = Succ
45+
def isEmpty: Boolean
46+
def get: Nat
47+
}
48+
}
49+
50+
object IntNums extends Peano {
51+
type Nat = Int
52+
type Zero = Int
53+
type Succ = Int
54+
55+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = (m / n, m % n)
56+
57+
val Zero: Zero = 0
58+
59+
object Succ extends SuccExtractor {
60+
def apply(nat: Nat): Succ = nat + 1
61+
def unapply(nat: Nat): SuccOpt = new SuccOpt {
62+
def isEmpty: Boolean = nat == 0
63+
def get: Int = nat - 1
64+
}
65+
}
66+
67+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
68+
def pred: Nat = succ - 1
69+
}
70+
}
71+
72+
object ClassNums extends Peano {
73+
trait NatTrait
74+
object ZeroObject extends NatTrait {
75+
override def toString: String = "ZeroObject"
76+
}
77+
case class SuccClass(predecessor: NatTrait) extends NatTrait with SuccOpt {
78+
def isEmpty: Boolean = false
79+
def get: NatTrait = this
80+
}
81+
82+
object SuccNoMatch extends SuccOpt {
83+
def isEmpty: Boolean = true
84+
def get: NatTrait = throw new NoSuchElementException
85+
}
86+
87+
type Nat = NatTrait
88+
type Zero = ZeroObject.type
89+
type Succ = SuccClass
90+
91+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = {
92+
def intValue(x: Nat, acc: Int): Int = x match {
93+
case nat: SuccClass => intValue(nat.predecessor, acc + 1)
94+
case _ => acc
95+
}
96+
def natValue(x: Int): Nat =
97+
if (x == 0) ZeroObject
98+
else new SuccClass(natValue(x - 1))
99+
val i = intValue(m, 0)
100+
val j = intValue(n, 0)
101+
(natValue(i / j), natValue(i % j))
102+
}
103+
104+
val Zero: Zero = ZeroObject
105+
106+
object Succ extends SuccExtractor {
107+
def apply(nat: Nat): Succ = new SuccClass(nat)
108+
def unapply(nat: Nat): SuccOpt = nat match {
109+
case nat: SuccClass => nat
110+
case _ => SuccNoMatch
111+
}
112+
}
113+
114+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
115+
def pred: Nat = succ.predecessor
116+
}
117+
118+
}

tests/run/refined-binding.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ok
2+
9

tests/run/refined-binding.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
trait Foo {
3+
4+
type X
5+
type Y <: X
6+
7+
def x: X
8+
9+
def f(y: Y) = println("ok")
10+
object Z {
11+
def unapply(arg: X): Opt = new Opt
12+
}
13+
14+
class Opt {
15+
type UncheckedRefinedArgument = Y
16+
def get: Int = 9
17+
def isEmpty: Boolean = false
18+
}
19+
}
20+
21+
object Test {
22+
def main(args: Array[String]): Unit = {
23+
test(new Foo { type X = Int; type Y = Int; def x: X = 1 })
24+
}
25+
26+
def test(foo: Foo): Unit = {
27+
foo.x match {
28+
case x @ foo.Z(i) => // `x` is refined to type `Y`
29+
foo.f(x)
30+
println(i)
31+
}
32+
}
33+
}

0 commit comments

Comments
 (0)