Skip to content

Commit ecb4207

Browse files
committed
Add some very basic GADT constraints from type cases
1 parent f304a07 commit ecb4207

File tree

8 files changed

+287
-4
lines changed

8 files changed

+287
-4
lines changed

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,11 +489,14 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
489489
case m @ MatchTypeTree(bounds, selector, cases) =>
490490
// Analog to the case above for match types
491491
def transformCase(x: CaseDef): CaseDef =
492-
cpy.CaseDef(tree)(
492+
val gadtCtx = x.pat.removeAttachment(typer.Typer.InferredGadtConstraints) match
493+
case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt))
494+
case None => ctx
495+
inContext(gadtCtx)(cpy.CaseDef(tree)(
493496
withMode(Mode.Pattern)(transform(x.pat)),
494497
transform(x.guard),
495498
transform(x.body),
496-
)
499+
))
497500
cpy.MatchTypeTree(tree)(
498501
super.transform(bounds),
499502
super.transform(selector),

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,10 @@ class Namer { typer: Typer =>
10031003

10041004
override final def typeSig(sym: Symbol): Type =
10051005
val tparamSyms = completerTypeParams(sym)(using ictx)
1006-
given ctx: Context = nestedCtx.nn
1006+
given ctx: Context = if tparamSyms.isEmpty then nestedCtx.nn else
1007+
given ctx: Context = nestedCtx.nn.fresh.setFreshGADTBounds
1008+
ctx.gadtState.addToConstraint(tparamSyms)
1009+
ctx
10071010

10081011
def abstracted(tp: TypeBounds): TypeBounds =
10091012
HKTypeLambda.boundsFromParams(tparamSyms, tp)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2005,13 +2005,19 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
20052005
}
20062006
if !ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef then
20072007
withMode(Mode.GadtConstraintInference) {
2008-
TypeComparer.constrainPatternType(pat1.tpe, selType)
2008+
selType match
2009+
case scr: TypeRef if ctx.gadt.contains(scr.symbol) => pat1.tpe match
2010+
case pat: TypeRef => scr <:< pat
2011+
case _ => TypeComparer.constrainPatternType(pat1.tpe, selType)
2012+
case _ => TypeComparer.constrainPatternType(pat1.tpe, selType)
20092013
}
20102014
val pat2 = indexPattern(cdef).transform(pat1)
20112015
var body1 = typedType(cdef.body, pt)
20122016
if !body1.isType then
20132017
assert(ctx.reporter.errorsReported)
20142018
body1 = TypeTree(errorType(em"<error: not a type>", cdef.srcPos))
2019+
else if ctx.gadt.isNarrowing then
2020+
pat2.putAttachment(InferredGadtConstraints, ctx.gadt)
20152021
assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1)
20162022
}
20172023
caseRest(using ctx.fresh.setFreshGADTBounds.setNewScope)

tests/pos/mini-onnx/Indices.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import scala.compiletime.ops.string.+
2+
import scala.compiletime.ops.any
3+
4+
type Index = Int & Singleton
5+
6+
sealed trait Indices
7+
8+
final case class :::[+H <: Index, +T <: Indices](head: H, tail: T) extends Indices:
9+
override def toString = s"$head ::: $tail"
10+
11+
sealed trait INil extends Indices
12+
case object INil extends INil
13+
14+
object Indices:
15+
type ToString[X <: Indices] <: String = X match
16+
case INil => "INil"
17+
case head ::: tail => any.ToString[head] + " ::: " + ToString[tail]
18+
19+
type Contains[Haystack <: Indices, Needle <: Index] <: Boolean = Haystack match
20+
case head ::: tail => head match
21+
case Needle => true
22+
case _ => Contains[tail, Needle]
23+
case INil => false
24+
25+
type RemoveValue[RemoveFrom <: Indices, Value <: Index] <: Indices = RemoveFrom match
26+
case INil => INil
27+
case head ::: tail => head match
28+
case Value => RemoveValue[tail, Value]
29+
case _ => head ::: RemoveValue[tail, Value]

tests/pos/mini-onnx/Shape.scala

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import scala.compiletime.ops.int.{S, +, <, <=, *}
2+
import scala.compiletime.ops.boolean.&&
3+
4+
type Dimension = Int & Singleton
5+
6+
sealed trait Shape extends Product with Serializable
7+
8+
final case class #:[+H <: Dimension, +T <: Shape](head: H, tail: T) extends Shape:
9+
override def toString = (head: Any) match
10+
case _ #: _ => s"($head) #: $tail"
11+
case _ => s"$head #: $tail"
12+
13+
sealed trait SNil extends Shape
14+
case object SNil extends SNil
15+
16+
object Shape:
17+
def scalar: SNil = SNil
18+
19+
type Concat[X <: Shape, Y <: Shape] <: Shape = X match
20+
case SNil => Y
21+
case head #: tail => head #: Concat[tail, Y]
22+
23+
type Reverse[X <: Shape] <: Shape = X match
24+
case SNil => SNil
25+
case head #: tail => Concat[Reverse[tail], head #: SNil]
26+
27+
type NumElements[X <: Shape] <: Int = X match
28+
case SNil => 1
29+
case head #: tail => head * NumElements[tail]
30+
31+
type Rank[X <: Shape] <: Int = X match
32+
case SNil => 0
33+
case head #: tail => Rank[tail] + 1
34+
35+
type IsEmpty[X <: Shape] <: Boolean = X match
36+
case SNil => true
37+
case _ #: _ => false
38+
39+
type Head[X <: Shape] <: Dimension = X match { case head #: _ => head }
40+
type Tail[X <: Shape] <: Shape = X match { case _ #: tail => tail }
41+
42+
type Reduce[S <: Shape, Axes <: None.type | Indices] <: Shape = Axes match
43+
case None.type => SNil
44+
case Indices => ReduceLoop[S, Axes, 0]
45+
46+
protected type ReduceLoop[RemoveFrom <: Shape, ToRemove <: Indices, I <: Index] <: Shape = RemoveFrom match
47+
case head #: tail => Indices.Contains[ToRemove, I] match
48+
case true => ReduceLoop[tail, Indices.RemoveValue[ToRemove, I], S[I]]
49+
case false => head #: ReduceLoop[tail, ToRemove, S[I]]
50+
case SNil => ToRemove match { case INil => SNil }
51+
52+
type WithinBounds[I <: Index, S <: Shape] = (0 <= I && I < Rank[S])
53+
54+
type RemoveIndex[RemoveFrom <: Shape, I <: Index] <: Shape = WithinBounds[I, RemoveFrom] match
55+
case true => RemoveIndexLoop[RemoveFrom, I, 0]
56+
57+
protected type RemoveIndexLoop[RemoveFrom <: Shape, I <: Index, Current <: Index] <: Shape = RemoveFrom match
58+
case head #: tail => Current match
59+
case I => tail
60+
case _ => head #: RemoveIndexLoop[tail, I, S[Current]]
61+
62+
type Map[X <: Shape, F[_ <: Dimension] <: Dimension] <: Shape = X match
63+
case SNil => SNil
64+
case head #: tail => F[head] #: Map[tail, F]
65+
66+
type FoldLeft[B, X <: Shape, Z <: B, F[_ <: B, _ <: Int] <: B] <: B = X match
67+
case SNil => Z
68+
case head #: tail => FoldLeft[B, tail, F[Z, head], F]
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import scala.compiletime.ops.int.S
2+
3+
type DimensionDenotation = String & Singleton
4+
5+
sealed trait TensorShapeDenotation extends Product with Serializable
6+
7+
final case class ##:[+H <: DimensionDenotation, +T <: TensorShapeDenotation](head: H, tail: T) extends TensorShapeDenotation:
8+
override def toString = (head: Any) match
9+
case _ ##: _ => s"($head) ##: $tail"
10+
case _ => s"$head ##: $tail"
11+
12+
sealed trait TSNil extends TensorShapeDenotation
13+
case object TSNil extends TSNil
14+
15+
object TensorShapeDenotation:
16+
type Reduce[S <: TensorShapeDenotation, Axes <: None.type | Indices] <: TensorShapeDenotation = Axes match
17+
case None.type => TSNil
18+
case Indices => ReduceLoop[S, Axes, 0]
19+
20+
protected type ReduceLoop[RemoveFrom <: TensorShapeDenotation, ToRemove <: Indices, I <: Index] <: TensorShapeDenotation = RemoveFrom match
21+
case head ##: tail => Indices.Contains[ToRemove, I] match
22+
case true => ReduceLoop[tail, Indices.RemoveValue[ToRemove, I], S[I]]
23+
case false => head ##: ReduceLoop[tail, ToRemove, S[I]]
24+
case TSNil => ToRemove match { case INil => TSNil }

tests/pos/mini-onnx/Tensors.scala

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import scala.compiletime.ops.int.*
2+
3+
object Tensors:
4+
import Shape.Reverse
5+
6+
type Supported = Int | Long | Float | Double | Byte | Short | Boolean | String
7+
8+
type TensorTypeDenotation = String & Singleton
9+
10+
type Axes = Tuple3[TensorTypeDenotation, TensorShapeDenotation, Shape]
11+
12+
opaque type Tensor[T <: Supported, +Ax <: Axes] = Tuple2[Array[T], Ax]
13+
14+
type SparseTensor[T <: Supported, A <: Axes] = Tensor[T, A]
15+
16+
type KeepOrReduceDims[S <: Shape, AxisIndices <: None.type | Indices, KeepDims <: (Boolean & Singleton)] <: Shape = KeepDims match
17+
case true => ReduceKeepDims[S, AxisIndices]
18+
case false => Shape.Reduce[S, AxisIndices]
19+
20+
type KeepOrReduceDimDenotations[Td <: TensorShapeDenotation, AxisIndices <: None.type | Indices, KeepDims <: (Boolean & Singleton)] <: TensorShapeDenotation = KeepDims match
21+
case true => Td
22+
case false => TensorShapeDenotation.Reduce[Td, AxisIndices]
23+
24+
type ReduceKeepDims[S <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match
25+
case None.type => SNil
26+
case Indices => ReduceKeepDimsLoop[S, AxisIndices, 0]
27+
28+
protected type ReduceKeepDimsLoop[ReplaceFrom <: Shape, ToReplace <: Indices, I <: Index] <: Shape = ReplaceFrom match
29+
case head #: tail => Indices.Contains[ToReplace, I] match
30+
case true => 1 #: ReduceKeepDimsLoop[tail, Indices.RemoveValue[ToReplace, I], S[I]]
31+
case false => head #: ReduceKeepDimsLoop[tail, ToReplace, S[I]]
32+
case SNil => ToReplace match { case INil => SNil }
33+
34+
type AddGivenAxisSize[S <: Shape, S1 <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match
35+
case None.type => SNil
36+
case Indices => AddGivenAxisSizeLoop[S, S1, AxisIndices, 0]
37+
38+
protected type AddGivenAxisSizeLoop[First <: Shape, Second <: Shape, AxisIndex <: Indices, I <: Index] <: Shape = First match
39+
case head #: tail => Indices.Contains[AxisIndex, I] match
40+
case true => Second match
41+
case secondHead #: secondTail => (head + secondHead) #: AddGivenAxisSizeLoop[tail, secondTail, Indices.RemoveValue[AxisIndex, I], S[I]]
42+
case SNil => AxisIndex match { case INil => SNil }
43+
case false => Second match
44+
case secondHead #: secondTail => (head) #: AddGivenAxisSizeLoop[tail, secondTail, AxisIndex, S[I]]
45+
case SNil => AxisIndex match { case INil => SNil }
46+
47+
type UnsqueezeShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match
48+
case None.type => SNil
49+
case Indices => UnsqueezeShapeLoop[S, AxisIndex, 0]
50+
51+
protected type UnsqueezeShapeLoop[ToUnsqueeze <: Shape, AxisIndex <: Indices, I <: Index] <: Shape = ToUnsqueeze match
52+
case head #: tail => Indices.Contains[AxisIndex, I] match
53+
case true => 1 #: head #: UnsqueezeShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I]]
54+
case false => head #: UnsqueezeShapeLoop[tail, AxisIndex, S[I]]
55+
case SNil => AxisIndex match { case INil => SNil }
56+
57+
type GatheredShape[S <: Shape, AxisIndex <: None.type | Indices, AxisIndices <: Indices] <: Shape = AxisIndex match
58+
case None.type => SNil
59+
case Indices => GatheredShapeLoop[S, AxisIndex, 0, AxisIndices]
60+
61+
protected type GatheredShapeLoop[ToGather <: Shape, AxisIndex <: Indices, I <: Index, AxisIndices <: Indices] <: Shape = ToGather match
62+
case head #: tail => Indices.Contains[AxisIndex, I] match
63+
case true => IndicesSize[AxisIndices] #: GatheredShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], AxisIndices]
64+
case false => head #: GatheredShapeLoop[tail, AxisIndex, S[I], AxisIndices]
65+
case SNil => AxisIndex match { case INil => SNil }
66+
67+
type IndicesSize[AxisIndices <: Indices] = IndicesSizeLoop[AxisIndices, 0]
68+
69+
type IndicesSizeLoop[AxisIndices <: Indices, Acc <: Dimension] <: Dimension = AxisIndices match
70+
case head ::: tail => IndicesSizeLoop[tail, S[Acc]]
71+
case INil => Acc
72+
73+
type FlattenedShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match
74+
case None.type => SNil
75+
case Indices => FlattenedShapeLoop[S, AxisIndex, 0, 1]
76+
77+
protected type FlattenedShapeLoop[ToFlatten <: Shape, AxisIndex <: Indices, I <: Index, Acc <: Index] <: Shape = ToFlatten match
78+
case head #: tail => Indices.Contains[AxisIndex, I] match
79+
case true => Acc #: FlattenedShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], head]
80+
case false => FlattenedShapeLoop[tail, AxisIndex, S[I], head * Acc]
81+
case SNil => AxisIndex match { case INil => Acc #: SNil }
82+
83+
type SlicedShape[AxisIndicesStarts <: None.type | Indices, AxisIndicesEnds <: None.type | Indices] <: Shape = AxisIndicesStarts match
84+
case None.type => SNil
85+
case Indices => AxisIndicesEnds match
86+
case None.type => SNil
87+
case Indices => SlicedShapeLoop[AxisIndicesStarts, AxisIndicesEnds]
88+
89+
protected type SlicedShapeLoop[Starts <: Indices, Ends <: Indices] <: Shape = Starts match
90+
case head ::: tail => Ends match
91+
case endsHead ::: endsTail => (endsHead - head) #: SlicedShapeLoop[tail, endsTail]
92+
case INil => SNil
93+
case INil => Ends match { case INil => SNil }
94+
95+
type PaddedShape[PadFrom <: Shape, AxisBefore <: None.type | Shape, AxisAfter <: None.type | Shape] <: Shape = AxisBefore match
96+
case None.type => PadFrom
97+
case Shape => AxisAfter match
98+
case None.type => PadFrom
99+
case Shape => Reverse[PaddedShapeLoop[Reverse[PadFrom], Reverse[AxisBefore], Reverse[AxisAfter]]]
100+
101+
protected type PaddedShapeLoop[PadFrom <: Shape, Before <: Shape, After <: Shape] <: Shape = Before match
102+
case head #: tail => After match
103+
case afterHead #: afterTail => PadFrom match
104+
case padFromHead #: padFromTail => (head + padFromHead + afterHead) #: PaddedShapeLoop[padFromTail, tail, afterTail]
105+
case SNil => SNil
106+
case SNil => SNil
107+
case SNil => After match
108+
case SNil => PadFrom match
109+
case padFromHead #: padFromTail => padFromHead #: PaddedShapeLoop[padFromTail, SNil, SNil]
110+
case SNil => SNil
111+
112+
type TiledShape[TileFrom <: Shape, AxisRepeats <: None.type | Indices] <: Shape = AxisRepeats match
113+
case None.type => SNil
114+
case Indices => TiledShapeLoop[TileFrom, AxisRepeats]
115+
116+
protected type TiledShapeLoop[TileFrom <: Shape, Repeats <: Indices] <: Shape = Repeats match
117+
case head ::: tail => TileFrom match
118+
case tileFromHead #: tileFromTail => (head * tileFromHead) #: TiledShapeLoop[tileFromTail, tail]
119+
case SNil => SNil
120+
case INil => SNil
121+
122+
type PoolShape[From <: Shape, KernelShape <: None.type | Shape] <: Shape = KernelShape match
123+
case None.type => SNil
124+
case Shape => Reverse[PoolShapeLoop[Reverse[From], Reverse[KernelShape]]]
125+
126+
protected type PoolShapeLoop[From <: Shape, KernelShape <: Shape] <: Shape = KernelShape match
127+
case head #: tail => From match
128+
case fromHead #: fromTail => (fromHead - head + 1) #: PoolShapeLoop[fromTail, tail]
129+
case SNil => SNil
130+
case SNil => From

tests/pos/nano-onnx.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import scala.compiletime.ops.int.*
2+
3+
type Index = Int & Singleton
4+
type Dimension = Int & Singleton
5+
6+
sealed trait Indices extends Product with Serializable
7+
sealed trait Shape extends Product with Serializable
8+
final case class :::[+H <: Index, +T <: Indices](head: H, tail: T) extends Indices
9+
final case class #:[+H <: Dimension, +T <: Shape ](head: H, tail: T) extends Shape
10+
sealed trait INil extends Indices; case object INil extends INil
11+
sealed trait SNil extends Shape; case object SNil extends SNil
12+
13+
object Ts:
14+
type ReduceKeepDims[S <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match
15+
case None.type => SNil
16+
case Indices => ReduceKeepDimsLoop[S, AxisIndices, 0]
17+
18+
protected type ReduceKeepDimsLoop[ReplaceFrom <: Shape, ToReplace <: Indices, I <: Index] <: Shape = ReplaceFrom match
19+
case head #: tail => ReduceKeepDimsLoop[tail, ToReplace, S[I]]
20+
case SNil => ToReplace match { case INil => SNil }

0 commit comments

Comments
 (0)