Skip to content

Commit 2aeb5ae

Browse files
Merge pull request #7628 from MaximeKjaer/singleton-arithmetic
Add primitive compiletime operations on singleton types
2 parents 5dc232a + 7e0a1db commit 2aeb5ae

16 files changed

+558
-19
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,11 @@ class Definitions {
236236
@tu lazy val CompiletimeTesting_ErrorKind: Symbol = ctx.requiredModule("scala.compiletime.testing.ErrorKind")
237237
@tu lazy val CompiletimeTesting_ErrorKind_Parser: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Parser")
238238
@tu lazy val CompiletimeTesting_ErrorKind_Typer: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Typer")
239+
@tu lazy val CompiletimeOpsPackageObject: Symbol = ctx.requiredModule("scala.compiletime.ops.package")
240+
@tu lazy val CompiletimeOpsPackageObjectAny: Symbol = ctx.requiredModule("scala.compiletime.ops.package.any")
241+
@tu lazy val CompiletimeOpsPackageObjectInt: Symbol = ctx.requiredModule("scala.compiletime.ops.package.int")
242+
@tu lazy val CompiletimeOpsPackageObjectString: Symbol = ctx.requiredModule("scala.compiletime.ops.package.string")
243+
@tu lazy val CompiletimeOpsPackageObjectBoolean: Symbol = ctx.requiredModule("scala.compiletime.ops.package.boolean")
239244

240245
/** The `scalaShadowing` package is used to safely modify classes and
241246
* objects in scala so that they can be used from dotty. They will
@@ -946,6 +951,26 @@ class Definitions {
946951
final def isCompiletime_S(sym: Symbol)(implicit ctx: Context): Boolean =
947952
sym.name == tpnme.S && sym.owner == CompiletimePackageObject.moduleClass
948953

954+
private val compiletimePackageAnyTypes: Set[Name] = Set(tpnme.Equals, tpnme.NotEquals)
955+
private val compiletimePackageIntTypes: Set[Name] = Set(
956+
tpnme.Plus, tpnme.Minus, tpnme.Times, tpnme.Div, tpnme.Mod,
957+
tpnme.Lt, tpnme.Gt, tpnme.Ge, tpnme.Le,
958+
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max, tpnme.ToString,
959+
)
960+
private val compiletimePackageBooleanTypes: Set[Name] = Set(tpnme.Not, tpnme.Xor, tpnme.And, tpnme.Or)
961+
private val compiletimePackageStringTypes: Set[Name] = Set(tpnme.Plus)
962+
963+
final def isCompiletimeAppliedType(sym: Symbol)(implicit ctx: Context): Boolean = {
964+
def isOpsPackageObjectAppliedType: Boolean =
965+
sym.owner == CompiletimeOpsPackageObjectAny.moduleClass && compiletimePackageAnyTypes.contains(sym.name) ||
966+
sym.owner == CompiletimeOpsPackageObjectInt.moduleClass && compiletimePackageIntTypes.contains(sym.name) ||
967+
sym.owner == CompiletimeOpsPackageObjectBoolean.moduleClass && compiletimePackageBooleanTypes.contains(sym.name) ||
968+
sym.owner == CompiletimeOpsPackageObjectString.moduleClass && compiletimePackageStringTypes.contains(sym.name)
969+
970+
sym.isType && (isCompiletime_S(sym) || isOpsPackageObjectAppliedType)
971+
}
972+
973+
949974
// ----- Symbol sets ---------------------------------------------------
950975

951976
@tu lazy val AbstractFunctionType: Array[TypeRef] = mkArityArray("scala.runtime.AbstractFunction", MaxImplementedFunctionArity, 0)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,34 @@ object StdNames {
202202
final val Product: N = "Product"
203203
final val PartialFunction: N = "PartialFunction"
204204
final val PrefixType: N = "PrefixType"
205-
final val S: N = "S"
206205
final val Serializable: N = "Serializable"
207206
final val Singleton: N = "Singleton"
208207
final val Throwable: N = "Throwable"
209208
final val IOOBException: N = "IndexOutOfBoundsException"
210209
final val FunctionXXL: N = "FunctionXXL"
211210

211+
final val Abs: N = "Abs"
212+
final val And: N = "&&"
213+
final val Div: N = "/"
214+
final val Equals: N = "=="
215+
final val Ge: N = ">="
216+
final val Gt: N = ">"
217+
final val Le: N = "<="
218+
final val Lt: N = "<"
219+
final val Max: N = "Max"
220+
final val Min: N = "Min"
221+
final val Minus: N = "-"
222+
final val Mod: N = "%"
223+
final val Negate: N = "Negate"
224+
final val Not: N = "!"
225+
final val NotEquals: N = "!="
226+
final val Or: N = "||"
227+
final val Plus: N = "+"
228+
final val S: N = "S"
229+
final val Times: N = "*"
230+
final val ToString: N = "ToString"
231+
final val Xor: N = "^"
232+
212233
final val ClassfileAnnotation: N = "ClassfileAnnotation"
213234
final val ClassManifest: N = "ClassManifest"
214235
final val Enum: N = "Enum"

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,12 +371,16 @@ class TypeApplications(val self: Type) extends AnyVal {
371371
// just eta-reduction (ignoring variance annotations).
372372
// See i2201*.scala for examples where more aggressive
373373
// reduction would break type inference.
374-
dealiased.paramRefs == dealiasedArgs
374+
dealiased.paramRefs == dealiasedArgs ||
375+
defn.isCompiletimeAppliedType(tyconBody.typeSymbol)
375376
case _ => false
376377
}
377378
}
378379
if ((dealiased eq stripped) || followAlias)
379-
try dealiased.instantiate(args)
380+
try {
381+
val instantiated = dealiased.instantiate(args)
382+
if (followAlias) instantiated.normalized else instantiated
383+
}
380384
catch { case ex: IndexOutOfBoundsException => AppliedType(self, args) }
381385
else AppliedType(self, args)
382386
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
971971
compareLower(bounds(param2), tyconIsTypeRef = false)
972972
case tycon2: TypeRef =>
973973
isMatchingApply(tp1) ||
974-
defn.isCompiletime_S(tycon2.symbol) && compareS(tp2, tp1, fromBelow = true) || {
974+
defn.isCompiletimeAppliedType(tycon2.symbol) && compareCompiletimeAppliedType(tp2, tp1, fromBelow = true) || {
975975
tycon2.info match {
976976
case info2: TypeBounds =>
977977
compareLower(info2, tyconIsTypeRef = true)
@@ -1011,7 +1011,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
10111011
case tycon1: TypeRef =>
10121012
val sym = tycon1.symbol
10131013
!sym.isClass && {
1014-
defn.isCompiletime_S(sym) && compareS(tp1, tp2, fromBelow = false) ||
1014+
defn.isCompiletimeAppliedType(sym) && compareCompiletimeAppliedType(tp1, tp2, fromBelow = false) ||
10151015
recur(tp1.superType, tp2) ||
10161016
tryLiftedToThis1
10171017
}
@@ -1021,7 +1021,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
10211021
false
10221022
}
10231023

1024-
/** Compare `tp` of form `S[arg]` with `other`, via ">:>` if fromBelow is true, "<:<" otherwise.
1024+
/** Compare `tp` of form `S[arg]` with `other`, via ">:>" if fromBelow is true, "<:<" otherwise.
10251025
* If `arg` is a Nat constant `n`, proceed with comparing `n + 1` and `other`.
10261026
* Otherwise, if `other` is a Nat constant `n`, proceed with comparing `arg` and `n - 1`.
10271027
*/
@@ -1043,6 +1043,18 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
10431043
case _ => false
10441044
}
10451045

1046+
/** Compare `tp` of form `tycon[...args]`, where `tycon` is a scala.compiletime type,
1047+
* with `other` via ">:>" if fromBelow is true, "<:<" otherwise.
1048+
* Delegates to compareS if `tycon` is scala.compiletime.S. Otherwise, constant folds if possible.
1049+
*/
1050+
def compareCompiletimeAppliedType(tp: AppliedType, other: Type, fromBelow: Boolean): Boolean = {
1051+
if (defn.isCompiletime_S(tp.tycon.typeSymbol)) compareS(tp, other, fromBelow)
1052+
else {
1053+
val folded = tp.tryCompiletimeConstantFold
1054+
if (fromBelow) recur(other, folded) else recur(folded, other)
1055+
}
1056+
}
1057+
10461058
/** Like tp1 <:< tp2, but returns false immediately if we know that
10471059
* the case was covered previously during subtyping.
10481060
*/

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3640,19 +3640,97 @@ object Types {
36403640
case _ =>
36413641
NoType
36423642
}
3643-
if (defn.isCompiletime_S(tycon.symbol) && args.length == 1)
3644-
trace(i"normalize S $this", typr, show = true) {
3645-
args.head.normalized match {
3646-
case ConstantType(Constant(n: Int)) if n >= 0 && n < Int.MaxValue =>
3647-
ConstantType(Constant(n + 1))
3648-
case none => tryMatchAlias
3649-
}
3650-
}
3651-
else tryMatchAlias
3643+
3644+
tryCompiletimeConstantFold.orElse(tryMatchAlias)
3645+
36523646
case _ =>
36533647
NoType
36543648
}
36553649

3650+
def tryCompiletimeConstantFold(implicit ctx: Context): Type = tycon match {
3651+
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
3652+
def constValue(tp: Type): Option[Any] = tp match {
3653+
case ConstantType(Constant(n)) => Some(n)
3654+
case _ => None
3655+
}
3656+
3657+
def boolValue(tp: Type): Option[Boolean] = tp match {
3658+
case ConstantType(Constant(n: Boolean)) => Some(n)
3659+
case _ => None
3660+
}
3661+
3662+
def intValue(tp: Type): Option[Int] = tp match {
3663+
case ConstantType(Constant(n: Int)) => Some(n)
3664+
case _ => None
3665+
}
3666+
3667+
def stringValue(tp: Type): Option[String] = tp match {
3668+
case ConstantType(Constant(n: String)) => Some(n)
3669+
case _ => None
3670+
}
3671+
3672+
def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)
3673+
3674+
def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
3675+
extractor(args.head.normalized).map(a => ConstantType(Constant(op(a))))
3676+
3677+
def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
3678+
for {
3679+
a <- extractor(args.head.normalized)
3680+
b <- extractor(args.tail.head.normalized)
3681+
} yield ConstantType(Constant(op(a, b)))
3682+
3683+
trace(i"compiletime constant fold $this", typr, show = true) {
3684+
val name = tycon.symbol.name
3685+
val owner = tycon.symbol.owner
3686+
val nArgs = args.length
3687+
val constantType =
3688+
if (owner == defn.CompiletimePackageObject.moduleClass) name match {
3689+
case tpnme.S if nArgs == 1 => constantFold1(natValue, _ + 1)
3690+
case _ => None
3691+
} else if (owner == defn.CompiletimeOpsPackageObjectAny.moduleClass) name match {
3692+
case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
3693+
case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
3694+
case _ => None
3695+
} else if (owner == defn.CompiletimeOpsPackageObjectInt.moduleClass) name match {
3696+
case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
3697+
case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => -x)
3698+
case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
3699+
case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
3700+
case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
3701+
case tpnme.Times if nArgs == 2 => constantFold2(intValue, _ * _)
3702+
case tpnme.Div if nArgs == 2 => constantFold2(intValue, {
3703+
case (_, 0) => throw new TypeError("Division by 0")
3704+
case (a, b) => a / b
3705+
})
3706+
case tpnme.Mod if nArgs == 2 => constantFold2(intValue, {
3707+
case (_, 0) => throw new TypeError("Modulo by 0")
3708+
case (a, b) => a % b
3709+
})
3710+
case tpnme.Lt if nArgs == 2 => constantFold2(intValue, _ < _)
3711+
case tpnme.Gt if nArgs == 2 => constantFold2(intValue, _ > _)
3712+
case tpnme.Ge if nArgs == 2 => constantFold2(intValue, _ >= _)
3713+
case tpnme.Le if nArgs == 2 => constantFold2(intValue, _ <= _)
3714+
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
3715+
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
3716+
case _ => None
3717+
} else if (owner == defn.CompiletimeOpsPackageObjectString.moduleClass) name match {
3718+
case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
3719+
case _ => None
3720+
} else if (owner == defn.CompiletimeOpsPackageObjectBoolean.moduleClass) name match {
3721+
case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => !x)
3722+
case tpnme.And if nArgs == 2 => constantFold2(boolValue, _ && _)
3723+
case tpnme.Or if nArgs == 2 => constantFold2(boolValue, _ || _)
3724+
case tpnme.Xor if nArgs == 2 => constantFold2(boolValue, _ ^ _)
3725+
case _ => None
3726+
} else None
3727+
3728+
constantType.getOrElse(NoType)
3729+
}
3730+
3731+
case _ => NoType
3732+
}
3733+
36563734
def lowerBound(implicit ctx: Context): Type = tycon.stripTypeVar match {
36573735
case tycon: TypeRef =>
36583736
tycon.info match {
@@ -4022,7 +4100,7 @@ object Types {
40224100
myReduced =
40234101
trace(i"reduce match type $this $hashCode", typr, show = true) {
40244102
try
4025-
typeComparer.matchCases(scrutinee, cases)(trackingCtx)
4103+
typeComparer.matchCases(scrutinee.normalized, cases)(trackingCtx)
40264104
catch {
40274105
case ex: Throwable =>
40284106
handleRecursive("reduce type ", i"$scrutinee match ...", ex)

docs/docs/reference/metaprogramming/inline.md

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ val intTwo: 2 = natTwo
295295

296296
The `scala.compiletime` package contains helper definitions that provide support for compile time operations over values. They are described in the following.
297297

298-
#### `constValue`, `constValueOpt`, and the `S` combinator
298+
### `constValue`, `constValueOpt`, and the `S` combinator
299299

300300
`constvalue` is a function that produces the constant value represented by a
301301
type.
@@ -317,7 +317,7 @@ enabling us to handle situations where a value is not present. Note that `S` is
317317
the type of the successor of some singleton type. For example the type `S[1]` is
318318
the singleton type `2`.
319319

320-
#### `erasedValue`
320+
### `erasedValue`
321321

322322
So far we have seen inline methods that take terms (tuples and integers) as
323323
parameters. What if we want to base case distinctions on types instead? For
@@ -381,7 +381,7 @@ final val two = toIntT[Succ[Succ[Zero.type]]]
381381
behavior. Since `toInt` performs static checks over the static type of `N` we
382382
can safely use it to scrutinize its return type (`S[S[Z]]` in this case).
383383

384-
#### `error`
384+
### `error`
385385

386386
The `error` method is used to produce user-defined compile errors during inline expansion.
387387
It has the following signature:
@@ -411,6 +411,54 @@ inline def fail(p1: => Any) = {
411411
fail(identity("foo")) // error: failed on: identity("foo")
412412
```
413413

414+
### The `scala.compiletime.ops` package
415+
416+
The `scala.compiletime.ops` package contains types that provide support for
417+
primitive operations on singleton types. For example,
418+
`scala.compiletime.ops.int.*` provides support for multiplying two singleton
419+
`Int` types, and `scala.compiletime.ops.boolean.&&` for the conjunction of two
420+
`Boolean` types. When all arguments to a type in `scala.compiletime.ops` are
421+
singleton types, the compiler can evaluate the result of the operation.
422+
423+
```scala
424+
import scala.compiletime.ops.int._
425+
import scala.compiletime.ops.boolean._
426+
427+
val conjunction: true && true = true
428+
val multiplication: 3 * 5 = 15
429+
```
430+
431+
Many of these singleton operation types are meant to be used infix (as in [SLS §
432+
3.2.8](https://www.scala-lang.org/files/archive/spec/2.12/03-types.html#infix-types)),
433+
and are annotated with [`@infix`](scala.annotation.infix) accordingly.
434+
435+
Since type aliases have the same precedence rules as their term-level
436+
equivalents, the operations compose with the expected precedence rules:
437+
438+
```scala
439+
import scala.compiletime.ops.int._
440+
val x: 1 + 2 * 3 = 7
441+
```
442+
443+
The operation types are located in packages named after the type of the
444+
left-hand side parameter: for instance, `scala.compiletime.int.+` represents
445+
addition of two numbers, while `scala.compiletime.string.+` represents string
446+
concatenation. To use both and distinguish the two types from each other, a
447+
match type can dispatch to the correct implementation:
448+
449+
```scala
450+
import scala.compiletime.ops._
451+
import scala.annotation.infix
452+
453+
@infix type +[X <: Int | String, Y <: Int | String] = (X, Y) match {
454+
case (Int, Int) => int.+[X, Y]
455+
case (String, String) => string.+[X, Y]
456+
}
457+
458+
val concat: "a" + "b" = "ab"
459+
val addition: 1 + 1 = 2
460+
```
461+
414462
## Summoning Implicits Selectively
415463

416464
It is foreseen that many areas of typelevel programming can be done with rewrite

0 commit comments

Comments
 (0)