Skip to content

Commit 006ee48

Browse files
committed
Add type refinement for abstract type bindings
1 parent 3e44d53 commit 006ee48

10 files changed

+340
-4
lines changed

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ object StdNames {
488488
val productPrefix: N = "productPrefix"
489489
val raw_ : N = "raw"
490490
val readResolve: N = "readResolve"
491+
val refinedScrutinee: N = "refinedScrutinee"
491492
val reflect : N = "reflect"
492493
val reflectiveSelectable: N = "reflectiveSelectable"
493494
val reify : N = "reify"

compiler/src/dotty/tools/dotc/transform/FirstTransform.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,29 @@ class FirstTransform extends MiniPhase with InfoTransformer { thisPhase =>
108108
cpy.Template(impl)(self = EmptyValDef)
109109
}
110110

111-
override def transformDefDef(ddef: DefDef)(implicit ctx: Context) = {
111+
override def transformDefDef(ddef: DefDef)(implicit ctx: Context): Tree = {
112+
if (ddef.name == nme.unapply && !ddef.symbol.is(Synthetic)) {
113+
val resultType = ddef.tpt.tpe.finalResultType
114+
// println("Checking: ")
115+
// println(ddef.show)
116+
// println()
117+
// println()
118+
// println()
119+
// val refined = resultType.select(tpnme.UncheckedRefinedArgument)
120+
// if (refined.typeSymbol.exists) {
121+
// val arg = ddef.vparamss.head.head
122+
// if (!(refined <:< arg.tpt.tpe)) {
123+
// ctx.error(i"Abstract type refinement $refined (${refined.widenDealias}) is not a subtype of the unapply argument (${arg.tpt.tpe.widenDealias})", ddef.pos)
124+
// } else if (!(refined <:< arg.tpe)) {
125+
// ctx.warning(
126+
// i"""Abstract type refinement $refined (${refined.widenDealias}) should be a subtype of the unapply argument singleton type(${arg.name}.type).
127+
// |
128+
// |This type constraint can be added as follows:
129+
// | def unapply(x: T): U { type ${tpnme.UncheckedRefinedArgument} <: x.type } = ...
130+
// """.stripMargin, ddef.pos)
131+
// }
132+
// }
133+
}
112134
if (ddef.symbol.hasAnnotation(defn.NativeAnnot)) {
113135
ddef.symbol.resetFlag(Deferred)
114136
DefDef(ddef.symbol.asTerm,

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
896896
* whereas overloaded variants need to have a conforming variant.
897897
*/
898898
def trySelectUnapply(qual: untpd.Tree)(fallBack: Tree => Tree): Tree = {
899-
// try first for non-overloaded, then for overloaded ocurrences
899+
// try first for non-overloaded, then for overloaded occurrences
900900
def tryWithName(name: TermName)(fallBack: Tree => Tree)(implicit ctx: Context): Tree = {
901901
def tryWithProto(pt: Type)(implicit ctx: Context) = {
902902
val result = typedExpr(untpd.Select(qual, name), new UnapplyFunProto(pt, this))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,10 +1323,29 @@ class Typer extends Namer
13231323
else {
13241324
// for a singleton pattern like `x @ Nil`, `x` should get the type from the scrutinee
13251325
// see tests/neg/i3200b.scala and SI-1503
1326-
val symTp =
1326+
val symTp0 =
13271327
if (body1.tpe.isInstanceOf[TermRef]) pt1
13281328
else body1.tpe.underlyingIfRepeated(isJava = false)
1329-
val sym = ctx.newPatternBoundSymbol(tree.name, symTp, tree.pos)
1329+
1330+
// If it is name based pattern matching, the type of the argument of the unapply is abstract and
1331+
// the return type has a type member `Refined`, then refine the type of the binding with the type of `Refined`.
1332+
val symTp1 = body1 match {
1333+
case Trees.UnApply(fun, _, _) if symTp0.typeSymbol.is(Deferred) =>
1334+
// TODO check that it is name based pattern matching
1335+
fun.tpe.widen match {
1336+
case mt: MethodType =>
1337+
val resultType = mt.resType.substParam(mt.paramRefs.head, symTp0)
1338+
val refinedType = resultType.select(nme.refinedScrutinee).widen.resultType
1339+
if (refinedType.exists) refinedType
1340+
else symTp0
1341+
case _ =>
1342+
symTp0
1343+
}
1344+
1345+
case _ => symTp0
1346+
}
1347+
1348+
val sym = ctx.newPatternBoundSymbol(tree.name, symTp1, tree.pos)
13301349
if (ctx.mode.is(Mode.InPatternAlternative))
13311350
ctx.error(i"Illegal variable ${sym.name} in pattern alternative", tree.pos)
13321351
assignType(cpy.Bind(tree)(tree.name, body1), sym)

tests/run/refined-binding-nat-safe-upper-bound.check

Whitespace-only changes.
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
2+
3+
object Test {
4+
def main(args: Array[String]): Unit = {
5+
app(ClassNums)
6+
app(IntNums)
7+
}
8+
9+
def app(peano: Peano): Unit = {
10+
import peano._
11+
// FIXME
12+
// def divOpt(m: Nat, n: Nat): Option[(Nat, Nat)] = {
13+
// n match {
14+
// case Zero => None
15+
// case s @ Succ(_) => Some(safeDiv(m, s))
16+
// }
17+
// }
18+
// val two = Succ(Succ(Zero))
19+
// val five = Succ(Succ(Succ(two)))
20+
// println(divOpt(five, two))
21+
// println(divOpt(two, five))
22+
// println(divOpt(two, Zero))
23+
}
24+
}
25+
26+
trait Peano {
27+
type Nat
28+
type Zero <: Nat
29+
type Succ <: Nat
30+
31+
def safeDiv(m: Nat, n: Succ): (Nat, Nat)
32+
33+
implicit def succDeco(succ: Succ): SuccAPI
34+
trait SuccAPI {
35+
def pred: Nat
36+
}
37+
38+
val Zero: Zero
39+
40+
val Succ: SuccExtractor
41+
trait SuccExtractor {
42+
def apply(nat: Nat): Succ
43+
def unapply(nat: Nat): SuccOpt { type UncheckedRefinedArgument <: nat.type }
44+
}
45+
trait SuccOpt {
46+
type UncheckedRefinedArgument <: Succ
47+
def isEmpty: Boolean
48+
def get: Nat
49+
}
50+
}
51+
52+
object IntNums extends Peano {
53+
type Nat = Int
54+
type Zero = Int
55+
type Succ = Int
56+
57+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = (m / n, m % n)
58+
59+
val Zero: Zero = 0
60+
61+
object Succ extends SuccExtractor {
62+
def apply(nat: Nat): Succ = nat + 1
63+
def unapply(nat: Nat) = new SuccOpt {
64+
type UncheckedRefinedArgument = nat.type
65+
def isEmpty: Boolean = nat == 0
66+
def get: Int = nat - 1
67+
}
68+
}
69+
70+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
71+
def pred: Nat = succ - 1
72+
}
73+
}
74+
75+
object ClassNums extends Peano {
76+
trait NatTrait
77+
object ZeroObject extends NatTrait {
78+
override def toString: String = "ZeroObject"
79+
}
80+
case class SuccClass(predecessor: NatTrait) extends NatTrait with SuccOpt {
81+
type UncheckedRefinedArgument = this.type
82+
def isEmpty: Boolean = false
83+
def get: NatTrait = this
84+
}
85+
86+
object SuccNoMatch extends SuccOpt {
87+
type UncheckedRefinedArgument = Nothing
88+
def isEmpty: Boolean = true
89+
def get: NatTrait = throw new NoSuchElementException
90+
}
91+
92+
type Nat = NatTrait
93+
type Zero = ZeroObject.type
94+
type Succ = SuccClass
95+
96+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = {
97+
def intValue(x: Nat, acc: Int): Int = x match {
98+
case nat: SuccClass => intValue(nat.predecessor, acc + 1)
99+
case _ => acc
100+
}
101+
def natValue(x: Int): Nat =
102+
if (x == 0) ZeroObject
103+
else new SuccClass(natValue(x - 1))
104+
val i = intValue(m, 0)
105+
val j = intValue(n, 0)
106+
(natValue(i / j), natValue(i % j))
107+
}
108+
109+
val Zero: Zero = ZeroObject
110+
111+
object Succ extends SuccExtractor {
112+
def apply(nat: Nat): Succ = new SuccClass(nat)
113+
def unapply(nat: Nat) = nat match {
114+
case nat: (SuccClass & nat.type) => nat
115+
case _ => SuccNoMatch
116+
}
117+
}
118+
119+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
120+
def pred: Nat = succ.predecessor
121+
}
122+
123+
}

tests/run/refined-binding-nat.check

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Some((SuccClass(SuccClass(ZeroObject)),SuccClass(ZeroObject)))
2+
Some((ZeroObject,SuccClass(SuccClass(ZeroObject))))
3+
None
4+
Some((2,1))
5+
Some((0,2))
6+
None

tests/run/refined-binding-nat.scala

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
app(ClassNums)
5+
app(IntNums)
6+
}
7+
8+
def app(peano: Peano): Unit = {
9+
import peano._
10+
def divOpt(m: Nat, n: Nat): Option[(Nat, Nat)] = {
11+
n match {
12+
case Zero => None
13+
case s @ Succ(_) => Some(safeDiv(m, s))
14+
}
15+
}
16+
val two = Succ(Succ(Zero))
17+
val five = Succ(Succ(Succ(two)))
18+
println(divOpt(five, two))
19+
println(divOpt(two, five))
20+
println(divOpt(two, Zero))
21+
}
22+
}
23+
24+
trait Peano {
25+
type Nat
26+
type Zero <: Nat
27+
type Succ <: Nat
28+
29+
def safeDiv(m: Nat, n: Succ): (Nat, Nat)
30+
31+
implicit def succDeco(succ: Succ): SuccAPI
32+
trait SuccAPI {
33+
def pred: Nat
34+
}
35+
36+
val Zero: Zero
37+
38+
val Succ: SuccExtractor
39+
trait SuccExtractor {
40+
def apply(nat: Nat): Succ
41+
def unapply(nat: Nat): SuccOpt
42+
}
43+
trait SuccOpt {
44+
type Scrutinee <: Singleton
45+
def isEmpty: Boolean
46+
def refinedScrutinee: Succ & Scrutinee
47+
def get: Nat
48+
}
49+
}
50+
51+
object IntNums extends Peano {
52+
type Nat = Int
53+
type Zero = Int
54+
type Succ = Int
55+
56+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = (m / n, m % n)
57+
58+
val Zero: Zero = 0
59+
60+
object Succ extends SuccExtractor {
61+
def apply(nat: Nat): Succ = nat + 1
62+
def unapply(nat: Nat): SuccOpt = new SuccOpt {
63+
type Scrutinee = nat.type
64+
def isEmpty: Boolean = nat == 0
65+
def refinedScrutinee: Succ & Scrutinee = nat
66+
def get: Int = nat - 1
67+
}
68+
}
69+
70+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
71+
def pred: Nat = succ - 1
72+
}
73+
}
74+
75+
object ClassNums extends Peano {
76+
trait NatTrait
77+
object ZeroObject extends NatTrait {
78+
override def toString: String = "ZeroObject"
79+
}
80+
case class SuccClass(predecessor: NatTrait) extends NatTrait with SuccOpt {
81+
type Scrutinee = this.type
82+
def isEmpty: Boolean = false
83+
def refinedScrutinee: this.type = this
84+
def get: NatTrait = this
85+
}
86+
87+
object SuccNoMatch extends SuccOpt {
88+
type Scrutinee = Nothing
89+
def isEmpty: Boolean = true
90+
def refinedScrutinee: Nothing = throw new NoSuchElementException
91+
def get: NatTrait = throw new NoSuchElementException
92+
}
93+
94+
type Nat = NatTrait
95+
type Zero = ZeroObject.type
96+
type Succ = SuccClass
97+
98+
def safeDiv(m: Nat, n: Succ): (Nat, Nat) = {
99+
def intValue(x: Nat, acc: Int): Int = x match {
100+
case nat: SuccClass => intValue(nat.predecessor, acc + 1)
101+
case _ => acc
102+
}
103+
def natValue(x: Int): Nat =
104+
if (x == 0) ZeroObject
105+
else new SuccClass(natValue(x - 1))
106+
val i = intValue(m, 0)
107+
val j = intValue(n, 0)
108+
(natValue(i / j), natValue(i % j))
109+
}
110+
111+
val Zero: Zero = ZeroObject
112+
113+
object Succ extends SuccExtractor {
114+
def apply(nat: Nat): Succ = new SuccClass(nat)
115+
def unapply(nat: Nat): SuccOpt = nat match {
116+
case nat: SuccClass => nat
117+
case _ => SuccNoMatch
118+
}
119+
}
120+
121+
def succDeco(succ: Succ): SuccAPI = new SuccAPI {
122+
def pred: Nat = succ.predecessor
123+
}
124+
125+
}

tests/run/refined-binding.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ok
2+
9

tests/run/refined-binding.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
2+
sealed trait Foo {
3+
4+
type X
5+
type Y <: X
6+
7+
def x: X
8+
9+
def f(y: Y) = println("ok")
10+
11+
object Z {
12+
def unapply(arg: X) = new Opt {
13+
type Scrutinee = arg.type
14+
def refinedScrutinee: Y & Scrutinee = arg.asInstanceOf[Y & Scrutinee]
15+
}
16+
}
17+
18+
abstract class Opt {
19+
type Scrutinee <: Singleton
20+
def refinedScrutinee: Y & Scrutinee
21+
def get: Int = 9
22+
def isEmpty: Boolean = false
23+
}
24+
}
25+
26+
object Test {
27+
def main(args: Array[String]): Unit = {
28+
test(new Foo { type X = Int; type Y = Int; def x: X = 1 })
29+
}
30+
31+
def test(foo: Foo): Unit = {
32+
foo.x match {
33+
case x @ foo.Z(i) => // `x` is refined to type `Y`
34+
foo.f(x)
35+
println(i)
36+
}
37+
}
38+
}

0 commit comments

Comments
 (0)