1
+ package dotty .tools
2
+ package dotc
3
+ package core
4
+
5
+ import Types .* , Contexts .* , Symbols .* , Constants .* , Decorators .*
6
+ import config .Printers .typr
7
+ import reporting .trace
8
+ import StdNames .tpnme
9
+
10
+ object TypeEval :
11
+
12
+ def tryCompiletimeConstantFold (tp : AppliedType )(using Context ): Type = tp.tycon match
13
+ case tycon : TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
14
+ extension (tp : Type ) def fixForEvaluation : Type =
15
+ tp.normalized.dealias match
16
+ // enable operations for constant singleton terms. E.g.:
17
+ // ```
18
+ // final val one = 1
19
+ // type Two = one.type + one.type
20
+ // ```
21
+ case tp : TypeProxy if tp.underlying.isStable => tp.underlying.fixForEvaluation
22
+ case tp => tp
23
+
24
+ def constValue (tp : Type ): Option [Any ] = tp.fixForEvaluation match
25
+ case ConstantType (Constant (n)) => Some (n)
26
+ case _ => None
27
+
28
+ def boolValue (tp : Type ): Option [Boolean ] = tp.fixForEvaluation match
29
+ case ConstantType (Constant (n : Boolean )) => Some (n)
30
+ case _ => None
31
+
32
+ def intValue (tp : Type ): Option [Int ] = tp.fixForEvaluation match
33
+ case ConstantType (Constant (n : Int )) => Some (n)
34
+ case _ => None
35
+
36
+ def longValue (tp : Type ): Option [Long ] = tp.fixForEvaluation match
37
+ case ConstantType (Constant (n : Long )) => Some (n)
38
+ case _ => None
39
+
40
+ def floatValue (tp : Type ): Option [Float ] = tp.fixForEvaluation match
41
+ case ConstantType (Constant (n : Float )) => Some (n)
42
+ case _ => None
43
+
44
+ def doubleValue (tp : Type ): Option [Double ] = tp.fixForEvaluation match
45
+ case ConstantType (Constant (n : Double )) => Some (n)
46
+ case _ => None
47
+
48
+ def stringValue (tp : Type ): Option [String ] = tp.fixForEvaluation match
49
+ case ConstantType (Constant (n : String )) => Some (n)
50
+ case _ => None
51
+
52
+ // Returns Some(true) if the type is a constant.
53
+ // Returns Some(false) if the type is not a constant.
54
+ // Returns None if there is not enough information to determine if the type is a constant.
55
+ // The type is a constant if it is a constant type or a type operation composition of constant types.
56
+ // If we get a type reference for an argument, then the result is not yet known.
57
+ def isConst (tp : Type ): Option [Boolean ] = tp.dealias match
58
+ // known to be constant
59
+ case ConstantType (_) => Some (true )
60
+ // currently not a concrete known type
61
+ case TypeRef (NoPrefix ,_) => None
62
+ // currently not a concrete known type
63
+ case _ : TypeParamRef => None
64
+ // constant if the term is constant
65
+ case t : TermRef => isConst(t.underlying)
66
+ // an operation type => recursively check all argument compositions
67
+ case applied : AppliedType if defn.isCompiletimeAppliedType(applied.typeSymbol) =>
68
+ val argsConst = applied.args.map(isConst)
69
+ if (argsConst.exists(_.isEmpty)) None
70
+ else Some (argsConst.forall(_.get))
71
+ // all other types are considered not to be constant
72
+ case _ => Some (false )
73
+
74
+ def expectArgsNum (expectedNum : Int ): Unit =
75
+ // We can use assert instead of a compiler type error because this error should not
76
+ // occur since the type signature of the operation enforces the proper number of args.
77
+ assert(tp.args.length == expectedNum, s " Type operation expects $expectedNum arguments but found ${tp.args.length}" )
78
+
79
+ def natValue (tp : Type ): Option [Int ] = intValue(tp).filter(n => n >= 0 && n < Int .MaxValue )
80
+
81
+ // Runs the op and returns the result as a constant type.
82
+ // If the op throws an exception, then this exception is converted into a type error.
83
+ def runConstantOp (op : => Any ): Type =
84
+ val result =
85
+ try op
86
+ catch case e : Throwable =>
87
+ throw new TypeError (e.getMessage.nn)
88
+ ConstantType (Constant (result))
89
+
90
+ def constantFold1 [T ](extractor : Type => Option [T ], op : T => Any ): Option [Type ] =
91
+ expectArgsNum(1 )
92
+ extractor(tp.args.head).map(a => runConstantOp(op(a)))
93
+
94
+ def constantFold2 [T ](extractor : Type => Option [T ], op : (T , T ) => Any ): Option [Type ] =
95
+ constantFold2AB(extractor, extractor, op)
96
+
97
+ def constantFold2AB [TA , TB ](extractorA : Type => Option [TA ], extractorB : Type => Option [TB ], op : (TA , TB ) => Any ): Option [Type ] =
98
+ expectArgsNum(2 )
99
+ for
100
+ a <- extractorA(tp.args(0 ))
101
+ b <- extractorB(tp.args(1 ))
102
+ yield runConstantOp(op(a, b))
103
+
104
+ def constantFold3 [TA , TB , TC ](
105
+ extractorA : Type => Option [TA ],
106
+ extractorB : Type => Option [TB ],
107
+ extractorC : Type => Option [TC ],
108
+ op : (TA , TB , TC ) => Any
109
+ ): Option [Type ] =
110
+ expectArgsNum(3 )
111
+ for
112
+ a <- extractorA(tp.args(0 ))
113
+ b <- extractorB(tp.args(1 ))
114
+ c <- extractorC(tp.args(2 ))
115
+ yield runConstantOp(op(a, b, c))
116
+
117
+ trace(i " compiletime constant fold $tp" , typr, show = true ) {
118
+ val name = tycon.symbol.name
119
+ val owner = tycon.symbol.owner
120
+ val constantType =
121
+ if defn.isCompiletime_S(tycon.symbol) then
122
+ constantFold1(natValue, _ + 1 )
123
+ else if owner == defn.CompiletimeOpsAnyModuleClass then name match
124
+ case tpnme.Equals => constantFold2(constValue, _ == _)
125
+ case tpnme.NotEquals => constantFold2(constValue, _ != _)
126
+ case tpnme.ToString => constantFold1(constValue, _.toString)
127
+ case tpnme.IsConst => isConst(tp.args.head).map(b => ConstantType (Constant (b)))
128
+ case _ => None
129
+ else if owner == defn.CompiletimeOpsIntModuleClass then name match
130
+ case tpnme.Abs => constantFold1(intValue, _.abs)
131
+ case tpnme.Negate => constantFold1(intValue, x => - x)
132
+ // ToString is deprecated for ops.int, and moved to ops.any
133
+ case tpnme.ToString => constantFold1(intValue, _.toString)
134
+ case tpnme.Plus => constantFold2(intValue, _ + _)
135
+ case tpnme.Minus => constantFold2(intValue, _ - _)
136
+ case tpnme.Times => constantFold2(intValue, _ * _)
137
+ case tpnme.Div => constantFold2(intValue, _ / _)
138
+ case tpnme.Mod => constantFold2(intValue, _ % _)
139
+ case tpnme.Lt => constantFold2(intValue, _ < _)
140
+ case tpnme.Gt => constantFold2(intValue, _ > _)
141
+ case tpnme.Ge => constantFold2(intValue, _ >= _)
142
+ case tpnme.Le => constantFold2(intValue, _ <= _)
143
+ case tpnme.Xor => constantFold2(intValue, _ ^ _)
144
+ case tpnme.BitwiseAnd => constantFold2(intValue, _ & _)
145
+ case tpnme.BitwiseOr => constantFold2(intValue, _ | _)
146
+ case tpnme.ASR => constantFold2(intValue, _ >> _)
147
+ case tpnme.LSL => constantFold2(intValue, _ << _)
148
+ case tpnme.LSR => constantFold2(intValue, _ >>> _)
149
+ case tpnme.Min => constantFold2(intValue, _ min _)
150
+ case tpnme.Max => constantFold2(intValue, _ max _)
151
+ case tpnme.NumberOfLeadingZeros => constantFold1(intValue, Integer .numberOfLeadingZeros(_))
152
+ case tpnme.ToLong => constantFold1(intValue, _.toLong)
153
+ case tpnme.ToFloat => constantFold1(intValue, _.toFloat)
154
+ case tpnme.ToDouble => constantFold1(intValue, _.toDouble)
155
+ case _ => None
156
+ else if owner == defn.CompiletimeOpsLongModuleClass then name match
157
+ case tpnme.Abs => constantFold1(longValue, _.abs)
158
+ case tpnme.Negate => constantFold1(longValue, x => - x)
159
+ case tpnme.Plus => constantFold2(longValue, _ + _)
160
+ case tpnme.Minus => constantFold2(longValue, _ - _)
161
+ case tpnme.Times => constantFold2(longValue, _ * _)
162
+ case tpnme.Div => constantFold2(longValue, _ / _)
163
+ case tpnme.Mod => constantFold2(longValue, _ % _)
164
+ case tpnme.Lt => constantFold2(longValue, _ < _)
165
+ case tpnme.Gt => constantFold2(longValue, _ > _)
166
+ case tpnme.Ge => constantFold2(longValue, _ >= _)
167
+ case tpnme.Le => constantFold2(longValue, _ <= _)
168
+ case tpnme.Xor => constantFold2(longValue, _ ^ _)
169
+ case tpnme.BitwiseAnd => constantFold2(longValue, _ & _)
170
+ case tpnme.BitwiseOr => constantFold2(longValue, _ | _)
171
+ case tpnme.ASR => constantFold2(longValue, _ >> _)
172
+ case tpnme.LSL => constantFold2(longValue, _ << _)
173
+ case tpnme.LSR => constantFold2(longValue, _ >>> _)
174
+ case tpnme.Min => constantFold2(longValue, _ min _)
175
+ case tpnme.Max => constantFold2(longValue, _ max _)
176
+ case tpnme.NumberOfLeadingZeros =>
177
+ constantFold1(longValue, java.lang.Long .numberOfLeadingZeros(_))
178
+ case tpnme.ToInt => constantFold1(longValue, _.toInt)
179
+ case tpnme.ToFloat => constantFold1(longValue, _.toFloat)
180
+ case tpnme.ToDouble => constantFold1(longValue, _.toDouble)
181
+ case _ => None
182
+ else if owner == defn.CompiletimeOpsFloatModuleClass then name match
183
+ case tpnme.Abs => constantFold1(floatValue, _.abs)
184
+ case tpnme.Negate => constantFold1(floatValue, x => - x)
185
+ case tpnme.Plus => constantFold2(floatValue, _ + _)
186
+ case tpnme.Minus => constantFold2(floatValue, _ - _)
187
+ case tpnme.Times => constantFold2(floatValue, _ * _)
188
+ case tpnme.Div => constantFold2(floatValue, _ / _)
189
+ case tpnme.Mod => constantFold2(floatValue, _ % _)
190
+ case tpnme.Lt => constantFold2(floatValue, _ < _)
191
+ case tpnme.Gt => constantFold2(floatValue, _ > _)
192
+ case tpnme.Ge => constantFold2(floatValue, _ >= _)
193
+ case tpnme.Le => constantFold2(floatValue, _ <= _)
194
+ case tpnme.Min => constantFold2(floatValue, _ min _)
195
+ case tpnme.Max => constantFold2(floatValue, _ max _)
196
+ case tpnme.ToInt => constantFold1(floatValue, _.toInt)
197
+ case tpnme.ToLong => constantFold1(floatValue, _.toLong)
198
+ case tpnme.ToDouble => constantFold1(floatValue, _.toDouble)
199
+ case _ => None
200
+ else if owner == defn.CompiletimeOpsDoubleModuleClass then name match
201
+ case tpnme.Abs => constantFold1(doubleValue, _.abs)
202
+ case tpnme.Negate => constantFold1(doubleValue, x => - x)
203
+ case tpnme.Plus => constantFold2(doubleValue, _ + _)
204
+ case tpnme.Minus => constantFold2(doubleValue, _ - _)
205
+ case tpnme.Times => constantFold2(doubleValue, _ * _)
206
+ case tpnme.Div => constantFold2(doubleValue, _ / _)
207
+ case tpnme.Mod => constantFold2(doubleValue, _ % _)
208
+ case tpnme.Lt => constantFold2(doubleValue, _ < _)
209
+ case tpnme.Gt => constantFold2(doubleValue, _ > _)
210
+ case tpnme.Ge => constantFold2(doubleValue, _ >= _)
211
+ case tpnme.Le => constantFold2(doubleValue, _ <= _)
212
+ case tpnme.Min => constantFold2(doubleValue, _ min _)
213
+ case tpnme.Max => constantFold2(doubleValue, _ max _)
214
+ case tpnme.ToInt => constantFold1(doubleValue, _.toInt)
215
+ case tpnme.ToLong => constantFold1(doubleValue, _.toLong)
216
+ case tpnme.ToFloat => constantFold1(doubleValue, _.toFloat)
217
+ case _ => None
218
+ else if owner == defn.CompiletimeOpsStringModuleClass then name match
219
+ case tpnme.Plus => constantFold2(stringValue, _ + _)
220
+ case tpnme.Length => constantFold1(stringValue, _.length)
221
+ case tpnme.Matches => constantFold2(stringValue, _ matches _)
222
+ case tpnme.Substring =>
223
+ constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
224
+ case tpnme.CharAt =>
225
+ constantFold2AB(stringValue, intValue, _ charAt _)
226
+ case _ => None
227
+ else if owner == defn.CompiletimeOpsBooleanModuleClass then name match
228
+ case tpnme.Not => constantFold1(boolValue, x => ! x)
229
+ case tpnme.And => constantFold2(boolValue, _ && _)
230
+ case tpnme.Or => constantFold2(boolValue, _ || _)
231
+ case tpnme.Xor => constantFold2(boolValue, _ ^ _)
232
+ case _ => None
233
+ else None
234
+
235
+ constantType.getOrElse(NoType )
236
+ }
237
+
238
+ case _ => NoType
239
+ end tryCompiletimeConstantFold
240
+ end TypeEval
0 commit comments