Skip to content

Commit 13d9359

Browse files
authored
Merge pull request #15453 from dotty-staging/fix-superType
Don't normalize in AppliedType#superType
2 parents 140693d + a000229 commit 13d9359

File tree

3 files changed

+279
-261
lines changed

3 files changed

+279
-261
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
11121112
if tycon1sym == tycon2sym && tycon1sym.isAliasType then
11131113
val preConstraint = constraint
11141114
isSubArgs(args1, args2, tp1, tparams)
1115-
&& tryAlso(preConstraint, recur(tp1.superType, tp2.superType))
1115+
&& tryAlso(preConstraint, recur(tp1.superTypeNormalized, tp2.superTypeNormalized))
11161116
else
11171117
isSubArgs(args1, args2, tp1, tparams)
11181118
}
@@ -1177,7 +1177,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
11771177
*/
11781178
def compareLower(tycon2bounds: TypeBounds, tyconIsTypeRef: Boolean): Boolean =
11791179
if ((tycon2bounds.lo `eq` tycon2bounds.hi) && !tycon2bounds.isInstanceOf[MatchAlias])
1180-
if (tyconIsTypeRef) recur(tp1, tp2.superType)
1180+
if (tyconIsTypeRef) recur(tp1, tp2.superTypeNormalized)
11811181
else isSubApproxHi(tp1, tycon2bounds.lo.applyIfParameterized(args2))
11821182
else
11831183
fallback(tycon2bounds.lo)
@@ -1249,11 +1249,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
12491249

12501250
!sym.isClass && {
12511251
defn.isCompiletimeAppliedType(sym) && compareCompiletimeAppliedType(tp1, tp2, fromBelow = false) ||
1252-
recur(tp1.superType, tp2) ||
1252+
recur(tp1.superTypeNormalized, tp2) ||
12531253
tryLiftedToThis1
12541254
}|| byGadtBounds
12551255
case tycon1: TypeProxy =>
1256-
recur(tp1.superType, tp2)
1256+
recur(tp1.superTypeNormalized, tp2)
12571257
case _ =>
12581258
false
12591259
}
@@ -2645,9 +2645,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
26452645
!(tp2 <:< tp1)
26462646
&& (provablyDisjoint(tp1, tp2.tp2) || provablyDisjoint(tp1, tp2.tp1))
26472647
case (tp1: NamedType, _) if gadtBounds(tp1.symbol) != null =>
2648-
provablyDisjoint(gadtBounds(tp1.symbol).uncheckedNN.hi, tp2) || provablyDisjoint(tp1.superType, tp2)
2648+
provablyDisjoint(gadtBounds(tp1.symbol).uncheckedNN.hi, tp2)
2649+
|| provablyDisjoint(tp1.superTypeNormalized, tp2)
26492650
case (_, tp2: NamedType) if gadtBounds(tp2.symbol) != null =>
2650-
provablyDisjoint(tp1, gadtBounds(tp2.symbol).uncheckedNN.hi) || provablyDisjoint(tp1, tp2.superType)
2651+
provablyDisjoint(tp1, gadtBounds(tp2.symbol).uncheckedNN.hi)
2652+
|| provablyDisjoint(tp1, tp2.superTypeNormalized)
26512653
case (tp1: TermRef, tp2: TermRef) if isEnumValueOrModule(tp1) && isEnumValueOrModule(tp2) =>
26522654
tp1.termSymbol != tp2.termSymbol
26532655
case (tp1: TermRef, tp2: TypeRef) if isEnumValue(tp1) =>
@@ -2663,11 +2665,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
26632665
case (tp1: Type, tp2: Type) if defn.isTupleNType(tp2) =>
26642666
provablyDisjoint(tp1, tp2.toNestedPairs)
26652667
case (tp1: TypeProxy, tp2: TypeProxy) =>
2666-
provablyDisjoint(tp1.superType, tp2) || provablyDisjoint(tp1, tp2.superType)
2668+
provablyDisjoint(tp1.superTypeNormalized, tp2) || provablyDisjoint(tp1, tp2.superTypeNormalized)
26672669
case (tp1: TypeProxy, _) =>
2668-
provablyDisjoint(tp1.superType, tp2)
2670+
provablyDisjoint(tp1.superTypeNormalized, tp2)
26692671
case (_, tp2: TypeProxy) =>
2670-
provablyDisjoint(tp1, tp2.superType)
2672+
provablyDisjoint(tp1, tp2.superTypeNormalized)
26712673
case _ =>
26722674
false
26732675
}
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
package dotty.tools
2+
package dotc
3+
package core
4+
5+
import Types.*, Contexts.*, Symbols.*, Constants.*, Decorators.*
6+
import config.Printers.typr
7+
import reporting.trace
8+
import StdNames.tpnme
9+
10+
object TypeEval:
11+
12+
def tryCompiletimeConstantFold(tp: AppliedType)(using Context): Type = tp.tycon match
13+
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
14+
extension (tp: Type) def fixForEvaluation: Type =
15+
tp.normalized.dealias match
16+
// enable operations for constant singleton terms. E.g.:
17+
// ```
18+
// final val one = 1
19+
// type Two = one.type + one.type
20+
// ```
21+
case tp: TypeProxy if tp.underlying.isStable => tp.underlying.fixForEvaluation
22+
case tp => tp
23+
24+
def constValue(tp: Type): Option[Any] = tp.fixForEvaluation match
25+
case ConstantType(Constant(n)) => Some(n)
26+
case _ => None
27+
28+
def boolValue(tp: Type): Option[Boolean] = tp.fixForEvaluation match
29+
case ConstantType(Constant(n: Boolean)) => Some(n)
30+
case _ => None
31+
32+
def intValue(tp: Type): Option[Int] = tp.fixForEvaluation match
33+
case ConstantType(Constant(n: Int)) => Some(n)
34+
case _ => None
35+
36+
def longValue(tp: Type): Option[Long] = tp.fixForEvaluation match
37+
case ConstantType(Constant(n: Long)) => Some(n)
38+
case _ => None
39+
40+
def floatValue(tp: Type): Option[Float] = tp.fixForEvaluation match
41+
case ConstantType(Constant(n: Float)) => Some(n)
42+
case _ => None
43+
44+
def doubleValue(tp: Type): Option[Double] = tp.fixForEvaluation match
45+
case ConstantType(Constant(n: Double)) => Some(n)
46+
case _ => None
47+
48+
def stringValue(tp: Type): Option[String] = tp.fixForEvaluation match
49+
case ConstantType(Constant(n: String)) => Some(n)
50+
case _ => None
51+
52+
// Returns Some(true) if the type is a constant.
53+
// Returns Some(false) if the type is not a constant.
54+
// Returns None if there is not enough information to determine if the type is a constant.
55+
// The type is a constant if it is a constant type or a type operation composition of constant types.
56+
// If we get a type reference for an argument, then the result is not yet known.
57+
def isConst(tp: Type): Option[Boolean] = tp.dealias match
58+
// known to be constant
59+
case ConstantType(_) => Some(true)
60+
// currently not a concrete known type
61+
case TypeRef(NoPrefix,_) => None
62+
// currently not a concrete known type
63+
case _: TypeParamRef => None
64+
// constant if the term is constant
65+
case t: TermRef => isConst(t.underlying)
66+
// an operation type => recursively check all argument compositions
67+
case applied: AppliedType if defn.isCompiletimeAppliedType(applied.typeSymbol) =>
68+
val argsConst = applied.args.map(isConst)
69+
if (argsConst.exists(_.isEmpty)) None
70+
else Some(argsConst.forall(_.get))
71+
// all other types are considered not to be constant
72+
case _ => Some(false)
73+
74+
def expectArgsNum(expectedNum: Int): Unit =
75+
// We can use assert instead of a compiler type error because this error should not
76+
// occur since the type signature of the operation enforces the proper number of args.
77+
assert(tp.args.length == expectedNum, s"Type operation expects $expectedNum arguments but found ${tp.args.length}")
78+
79+
def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)
80+
81+
// Runs the op and returns the result as a constant type.
82+
// If the op throws an exception, then this exception is converted into a type error.
83+
def runConstantOp(op: => Any): Type =
84+
val result =
85+
try op
86+
catch case e: Throwable =>
87+
throw new TypeError(e.getMessage.nn)
88+
ConstantType(Constant(result))
89+
90+
def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
91+
expectArgsNum(1)
92+
extractor(tp.args.head).map(a => runConstantOp(op(a)))
93+
94+
def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
95+
constantFold2AB(extractor, extractor, op)
96+
97+
def constantFold2AB[TA, TB](extractorA: Type => Option[TA], extractorB: Type => Option[TB], op: (TA, TB) => Any): Option[Type] =
98+
expectArgsNum(2)
99+
for
100+
a <- extractorA(tp.args(0))
101+
b <- extractorB(tp.args(1))
102+
yield runConstantOp(op(a, b))
103+
104+
def constantFold3[TA, TB, TC](
105+
extractorA: Type => Option[TA],
106+
extractorB: Type => Option[TB],
107+
extractorC: Type => Option[TC],
108+
op: (TA, TB, TC) => Any
109+
): Option[Type] =
110+
expectArgsNum(3)
111+
for
112+
a <- extractorA(tp.args(0))
113+
b <- extractorB(tp.args(1))
114+
c <- extractorC(tp.args(2))
115+
yield runConstantOp(op(a, b, c))
116+
117+
trace(i"compiletime constant fold $tp", typr, show = true) {
118+
val name = tycon.symbol.name
119+
val owner = tycon.symbol.owner
120+
val constantType =
121+
if defn.isCompiletime_S(tycon.symbol) then
122+
constantFold1(natValue, _ + 1)
123+
else if owner == defn.CompiletimeOpsAnyModuleClass then name match
124+
case tpnme.Equals => constantFold2(constValue, _ == _)
125+
case tpnme.NotEquals => constantFold2(constValue, _ != _)
126+
case tpnme.ToString => constantFold1(constValue, _.toString)
127+
case tpnme.IsConst => isConst(tp.args.head).map(b => ConstantType(Constant(b)))
128+
case _ => None
129+
else if owner == defn.CompiletimeOpsIntModuleClass then name match
130+
case tpnme.Abs => constantFold1(intValue, _.abs)
131+
case tpnme.Negate => constantFold1(intValue, x => -x)
132+
// ToString is deprecated for ops.int, and moved to ops.any
133+
case tpnme.ToString => constantFold1(intValue, _.toString)
134+
case tpnme.Plus => constantFold2(intValue, _ + _)
135+
case tpnme.Minus => constantFold2(intValue, _ - _)
136+
case tpnme.Times => constantFold2(intValue, _ * _)
137+
case tpnme.Div => constantFold2(intValue, _ / _)
138+
case tpnme.Mod => constantFold2(intValue, _ % _)
139+
case tpnme.Lt => constantFold2(intValue, _ < _)
140+
case tpnme.Gt => constantFold2(intValue, _ > _)
141+
case tpnme.Ge => constantFold2(intValue, _ >= _)
142+
case tpnme.Le => constantFold2(intValue, _ <= _)
143+
case tpnme.Xor => constantFold2(intValue, _ ^ _)
144+
case tpnme.BitwiseAnd => constantFold2(intValue, _ & _)
145+
case tpnme.BitwiseOr => constantFold2(intValue, _ | _)
146+
case tpnme.ASR => constantFold2(intValue, _ >> _)
147+
case tpnme.LSL => constantFold2(intValue, _ << _)
148+
case tpnme.LSR => constantFold2(intValue, _ >>> _)
149+
case tpnme.Min => constantFold2(intValue, _ min _)
150+
case tpnme.Max => constantFold2(intValue, _ max _)
151+
case tpnme.NumberOfLeadingZeros => constantFold1(intValue, Integer.numberOfLeadingZeros(_))
152+
case tpnme.ToLong => constantFold1(intValue, _.toLong)
153+
case tpnme.ToFloat => constantFold1(intValue, _.toFloat)
154+
case tpnme.ToDouble => constantFold1(intValue, _.toDouble)
155+
case _ => None
156+
else if owner == defn.CompiletimeOpsLongModuleClass then name match
157+
case tpnme.Abs => constantFold1(longValue, _.abs)
158+
case tpnme.Negate => constantFold1(longValue, x => -x)
159+
case tpnme.Plus => constantFold2(longValue, _ + _)
160+
case tpnme.Minus => constantFold2(longValue, _ - _)
161+
case tpnme.Times => constantFold2(longValue, _ * _)
162+
case tpnme.Div => constantFold2(longValue, _ / _)
163+
case tpnme.Mod => constantFold2(longValue, _ % _)
164+
case tpnme.Lt => constantFold2(longValue, _ < _)
165+
case tpnme.Gt => constantFold2(longValue, _ > _)
166+
case tpnme.Ge => constantFold2(longValue, _ >= _)
167+
case tpnme.Le => constantFold2(longValue, _ <= _)
168+
case tpnme.Xor => constantFold2(longValue, _ ^ _)
169+
case tpnme.BitwiseAnd => constantFold2(longValue, _ & _)
170+
case tpnme.BitwiseOr => constantFold2(longValue, _ | _)
171+
case tpnme.ASR => constantFold2(longValue, _ >> _)
172+
case tpnme.LSL => constantFold2(longValue, _ << _)
173+
case tpnme.LSR => constantFold2(longValue, _ >>> _)
174+
case tpnme.Min => constantFold2(longValue, _ min _)
175+
case tpnme.Max => constantFold2(longValue, _ max _)
176+
case tpnme.NumberOfLeadingZeros =>
177+
constantFold1(longValue, java.lang.Long.numberOfLeadingZeros(_))
178+
case tpnme.ToInt => constantFold1(longValue, _.toInt)
179+
case tpnme.ToFloat => constantFold1(longValue, _.toFloat)
180+
case tpnme.ToDouble => constantFold1(longValue, _.toDouble)
181+
case _ => None
182+
else if owner == defn.CompiletimeOpsFloatModuleClass then name match
183+
case tpnme.Abs => constantFold1(floatValue, _.abs)
184+
case tpnme.Negate => constantFold1(floatValue, x => -x)
185+
case tpnme.Plus => constantFold2(floatValue, _ + _)
186+
case tpnme.Minus => constantFold2(floatValue, _ - _)
187+
case tpnme.Times => constantFold2(floatValue, _ * _)
188+
case tpnme.Div => constantFold2(floatValue, _ / _)
189+
case tpnme.Mod => constantFold2(floatValue, _ % _)
190+
case tpnme.Lt => constantFold2(floatValue, _ < _)
191+
case tpnme.Gt => constantFold2(floatValue, _ > _)
192+
case tpnme.Ge => constantFold2(floatValue, _ >= _)
193+
case tpnme.Le => constantFold2(floatValue, _ <= _)
194+
case tpnme.Min => constantFold2(floatValue, _ min _)
195+
case tpnme.Max => constantFold2(floatValue, _ max _)
196+
case tpnme.ToInt => constantFold1(floatValue, _.toInt)
197+
case tpnme.ToLong => constantFold1(floatValue, _.toLong)
198+
case tpnme.ToDouble => constantFold1(floatValue, _.toDouble)
199+
case _ => None
200+
else if owner == defn.CompiletimeOpsDoubleModuleClass then name match
201+
case tpnme.Abs => constantFold1(doubleValue, _.abs)
202+
case tpnme.Negate => constantFold1(doubleValue, x => -x)
203+
case tpnme.Plus => constantFold2(doubleValue, _ + _)
204+
case tpnme.Minus => constantFold2(doubleValue, _ - _)
205+
case tpnme.Times => constantFold2(doubleValue, _ * _)
206+
case tpnme.Div => constantFold2(doubleValue, _ / _)
207+
case tpnme.Mod => constantFold2(doubleValue, _ % _)
208+
case tpnme.Lt => constantFold2(doubleValue, _ < _)
209+
case tpnme.Gt => constantFold2(doubleValue, _ > _)
210+
case tpnme.Ge => constantFold2(doubleValue, _ >= _)
211+
case tpnme.Le => constantFold2(doubleValue, _ <= _)
212+
case tpnme.Min => constantFold2(doubleValue, _ min _)
213+
case tpnme.Max => constantFold2(doubleValue, _ max _)
214+
case tpnme.ToInt => constantFold1(doubleValue, _.toInt)
215+
case tpnme.ToLong => constantFold1(doubleValue, _.toLong)
216+
case tpnme.ToFloat => constantFold1(doubleValue, _.toFloat)
217+
case _ => None
218+
else if owner == defn.CompiletimeOpsStringModuleClass then name match
219+
case tpnme.Plus => constantFold2(stringValue, _ + _)
220+
case tpnme.Length => constantFold1(stringValue, _.length)
221+
case tpnme.Matches => constantFold2(stringValue, _ matches _)
222+
case tpnme.Substring =>
223+
constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
224+
case tpnme.CharAt =>
225+
constantFold2AB(stringValue, intValue, _ charAt _)
226+
case _ => None
227+
else if owner == defn.CompiletimeOpsBooleanModuleClass then name match
228+
case tpnme.Not => constantFold1(boolValue, x => !x)
229+
case tpnme.And => constantFold2(boolValue, _ && _)
230+
case tpnme.Or => constantFold2(boolValue, _ || _)
231+
case tpnme.Xor => constantFold2(boolValue, _ ^ _)
232+
case _ => None
233+
else None
234+
235+
constantType.getOrElse(NoType)
236+
}
237+
238+
case _ => NoType
239+
end tryCompiletimeConstantFold
240+
end TypeEval

0 commit comments

Comments
 (0)