Skip to content

Add primitive compiletime operations on singleton types #7628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jan 8, 2020
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ class Definitions {
@tu lazy val CompiletimeTesting_ErrorKind: Symbol = ctx.requiredModule("scala.compiletime.testing.ErrorKind")
@tu lazy val CompiletimeTesting_ErrorKind_Parser: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Parser")
@tu lazy val CompiletimeTesting_ErrorKind_Typer: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Typer")
@tu lazy val CompiletimeOpsPackageObject: Symbol = ctx.requiredModule("scala.compiletime.ops.package")

/** The `scalaShadowing` package is used to safely modify classes and
* objects in scala so that they can be used from dotty. They will
Expand Down Expand Up @@ -898,6 +899,20 @@ class Definitions {
final def isCompiletime_S(sym: Symbol)(implicit ctx: Context): Boolean =
sym.name == tpnme.S && sym.owner == CompiletimePackageObject.moduleClass

final def isCompiletimeAppliedType(sym: Symbol)(implicit ctx: Context): Boolean = {
def isOpsPackageObjectAppliedType: Boolean =
sym.owner == CompiletimeOpsPackageObject.moduleClass && Set(
tpnme.Equals, tpnme.NotEquals,
tpnme.Plus, tpnme.Minus, tpnme.Times, tpnme.Div, tpnme.Mod,
tpnme.Lt, tpnme.Gt, tpnme.Ge, tpnme.Le,
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max, tpnme.ToString,
tpnme.Not, tpnme.Xor, tpnme.And, tpnme.Or
).contains(sym.name)

sym.isType && (isCompiletime_S(sym) || isOpsPackageObjectAppliedType)
}


// ----- Symbol sets ---------------------------------------------------

@tu lazy val AbstractFunctionType: Array[TypeRef] = mkArityArray("scala.runtime.AbstractFunction", MaxImplementedFunctionArity, 0)
Expand Down
23 changes: 22 additions & 1 deletion compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,34 @@ object StdNames {
final val Product: N = "Product"
final val PartialFunction: N = "PartialFunction"
final val PrefixType: N = "PrefixType"
final val S: N = "S"
final val Serializable: N = "Serializable"
final val Singleton: N = "Singleton"
final val Throwable: N = "Throwable"
final val IOOBException: N = "IndexOutOfBoundsException"
final val FunctionXXL: N = "FunctionXXL"

final val Abs: N = "Abs"
final val And: N = "&&"
final val Div: N = "/"
final val Equals: N = "=="
final val Ge: N = ">="
final val Gt: N = ">"
final val Le: N = "<="
final val Lt: N = "<"
final val Max: N = "Max"
final val Min: N = "Min"
final val Minus: N = "-"
final val Mod: N = "%"
final val Negate: N = "Negate"
final val Not: N = "!"
final val NotEquals: N = "!="
final val Or: N = "||"
final val Plus: N = "+"
final val S: N = "S"
final val Times: N = "*"
final val ToString: N = "ToString"
final val Xor: N = "^"

final val ClassfileAnnotation: N = "ClassfileAnnotation"
final val ClassManifest: N = "ClassManifest"
final val Enum: N = "Enum"
Expand Down
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,16 @@ class TypeApplications(val self: Type) extends AnyVal {
// just eta-reduction (ignoring variance annotations).
// See i2201*.scala for examples where more aggressive
// reduction would break type inference.
dealiased.paramRefs == dealiasedArgs
dealiased.paramRefs == dealiasedArgs ||
defn.isCompiletimeAppliedType(tyconBody.typeSymbol)
case _ => false
}
}
if ((dealiased eq stripped) || followAlias)
try dealiased.instantiate(args)
try {
val instantiated = dealiased.instantiate(args)
if (followAlias) instantiated.normalized else instantiated
}
catch { case ex: IndexOutOfBoundsException => AppliedType(self, args) }
else AppliedType(self, args)
}
Expand Down
15 changes: 12 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
compareLower(bounds(param2), tyconIsTypeRef = false)
case tycon2: TypeRef =>
isMatchingApply(tp1) ||
defn.isCompiletime_S(tycon2.symbol) && compareS(tp2, tp1, fromBelow = true) || {
defn.isCompiletimeAppliedType(tycon2.symbol) && compareCompiletimeAppliedType(tp2, tp1, fromBelow = true) || {
tycon2.info match {
case info2: TypeBounds =>
compareLower(info2, tyconIsTypeRef = true)
Expand Down Expand Up @@ -1005,7 +1005,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
case tycon1: TypeRef =>
val sym = tycon1.symbol
!sym.isClass && {
defn.isCompiletime_S(sym) && compareS(tp1, tp2, fromBelow = false) ||
defn.isCompiletimeAppliedType(sym) && compareCompiletimeAppliedType(tp1, tp2, fromBelow = false) ||
recur(tp1.superType, tp2) ||
tryLiftedToThis1
}
Expand All @@ -1015,7 +1015,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
false
}

/** Compare `tp` of form `S[arg]` with `other`, via ">:>` if fromBelow is true, "<:<" otherwise.
/** Compare `tp` of form `S[arg]` with `other`, via ">:>" if fromBelow is true, "<:<" otherwise.
* If `arg` is a Nat constant `n`, proceed with comparing `n + 1` and `other`.
* Otherwise, if `other` is a Nat constant `n`, proceed with comparing `arg` and `n - 1`.
*/
Expand All @@ -1037,6 +1037,15 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
case _ => false
}

/** Compare `tp` of form `tycon[...args]`, where `tycon` is a scala.compiletime type,
* with `other` via ">:>" if fromBelow is true, "<:<" otherwise.
* Delegates to compareS if `tycon` is scala.compiletime.S. Otherwise, constant folds if possible.
*/
def compareCompiletimeAppliedType(tp: AppliedType, other: Type, fromBelow: Boolean): Boolean = {
if (defn.isCompiletime_S(tp.tycon.typeSymbol)) compareS(tp, other, fromBelow)
else tp.tryCompiletimeConstantFold.exists(folded => if (fromBelow) recur(other, folded) else recur(folded, other))
}

/** Like tp1 <:< tp2, but returns false immediately if we know that
* the case was covered previously during subtyping.
*/
Expand Down
80 changes: 70 additions & 10 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3595,19 +3595,79 @@ object Types {
case _ =>
NoType
}
if (defn.isCompiletime_S(tycon.symbol) && args.length == 1)
trace(i"normalize S $this", typr, show = true) {
args.head.normalized match {
case ConstantType(Constant(n: Int)) if n >= 0 && n < Int.MaxValue =>
ConstantType(Constant(n + 1))
case none => tryMatchAlias
}
}
else tryMatchAlias

tryCompiletimeConstantFold.getOrElse(tryMatchAlias)

case _ =>
NoType
}

def tryCompiletimeConstantFold(implicit ctx: Context): Option[Type] = tycon match {
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
def constValue(tp: Type): Option[Any] = tp match {
case ConstantType(Constant(n)) => Some(n)
case _ => None
}

def boolValue(tp: Type): Option[Boolean] = tp match {
case ConstantType(Constant(n: Boolean)) => Some(n)
case _ => None
}

def intValue(tp: Type): Option[Int] = tp match {
case ConstantType(Constant(n: Int)) => Some(n)
case _ => None
}

def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)

def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
extractor(args.head.normalized).map(a => ConstantType(Constant(op(a))))

def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
for {
a <- extractor(args.head.normalized)
b <- extractor(args.tail.head.normalized)
} yield ConstantType(Constant(op(a, b)))

trace(i"compiletime constant fold $this", typr, show = true) {
if (args.length == 1) tycon.symbol.name match {
case tpnme.S => constantFold1(natValue, _ + 1)
case tpnme.Abs => constantFold1(intValue, _.abs)
case tpnme.Negate => constantFold1(intValue, x => -x)
case tpnme.Not => constantFold1(boolValue, x => !x)
case tpnme.ToString => constantFold1(intValue, _.toString)
case _ => None
} else if (args.length == 2) tycon.symbol.name match {
case tpnme.Equals => constantFold2(constValue, _ == _)
case tpnme.NotEquals => constantFold2(constValue, _ != _)
case tpnme.Plus => constantFold2(intValue, _ + _)
case tpnme.Minus => constantFold2(intValue, _ - _)
case tpnme.Times => constantFold2(intValue, _ * _)
case tpnme.Div => constantFold2(intValue, {
case (_, 0) => throw new TypeError("Division by 0")
case (a, b) => a / b
})
case tpnme.Mod => constantFold2(intValue, {
case (_, 0) => throw new TypeError("Modulo by 0")
case (a, b) => a % b
})
case tpnme.Lt => constantFold2(intValue, _ < _)
case tpnme.Gt => constantFold2(intValue, _ > _)
case tpnme.Ge => constantFold2(intValue, _ >= _)
case tpnme.Le => constantFold2(intValue, _ <= _)
case tpnme.Min => constantFold2(intValue, _ min _)
case tpnme.Max => constantFold2(intValue, _ max _)
case tpnme.And => constantFold2(boolValue, _ && _)
case tpnme.Or => constantFold2(boolValue, _ || _)
case tpnme.Xor => constantFold2(boolValue, _ ^ _)
case _ => None
} else None
}

case _ => None
}

def lowerBound(implicit ctx: Context): Type = tycon.stripTypeVar match {
case tycon: TypeRef =>
tycon.info match {
Expand Down Expand Up @@ -3974,7 +4034,7 @@ object Types {
myReduced =
trace(i"reduce match type $this $hashCode", typr, show = true) {
try
typeComparer.matchCases(scrutinee, cases)(trackingCtx)
typeComparer.matchCases(scrutinee.normalized, cases)(trackingCtx)
catch {
case ex: Throwable =>
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
Expand Down
30 changes: 30 additions & 0 deletions library/src/scala/compiletime/ops/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package scala.compiletime

import scala.annotation.infix

package object ops {
@infix type ==[X <: AnyVal, Y <: AnyVal] <: Boolean
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe == and != should be split as well. Either everything is split according to the supertype or non of it is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved them into scala.compiletime.ops.any in 8c117c1.

The alternative was to duplicate them into each subpackage for each supported type, and then duplicate the constant folding code... which felt like a lot of duplication. Seeing that == and != are defined on Any, I think this solution makes the most sense. It also emphasizes that equality is between to Any values, and that 1 == "1" will return false.

@infix type !=[X <: AnyVal, Y <: AnyVal] <: Boolean

@infix type +[X <: Int, Y <: Int] <: Int
@infix type -[X <: Int, Y <: Int] <: Int
@infix type *[X <: Int, Y <: Int] <: Int
@infix type /[X <: Int, Y <: Int] <: Int
@infix type %[X <: Int, Y <: Int] <: Int

@infix type <[X <: Int, Y <: Int] <: Boolean
@infix type >[X <: Int, Y <: Int] <: Boolean
@infix type >=[X <: Int, Y <: Int] <: Boolean
@infix type <=[X <: Int, Y <: Int] <: Boolean

type Abs[X <: Int] <: Int
type Negate[X <: Int] <: Int
type Min[X <: Int, Y <: Int] <: Int
type Max[X <: Int, Y <: Int] <: Int
type ToString[X <: Int] <: String

type ![X <: Boolean] <: Boolean
@infix type ^[X <: Boolean, Y <: Boolean] <: Boolean
@infix type &&[X <: Boolean, Y <: Boolean] <: Boolean
@infix type ||[X <: Boolean, Y <: Boolean] <: Boolean
}
12 changes: 12 additions & 0 deletions tests/neg/singleton-ops-match-type-scrutinee.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import scala.compiletime.ops._

object Test {
type Max2[A <: Int, B <: Int] <: Int = (A < B) match {
case true => B
case false => A
}
val t0: Max2[-1, 10] = 10
val t1: Max2[4, 2] = 4
val t2: Max2[2, 2] = 1 // error
val t3: Max2[-1, -1] = 0 // error
}
12 changes: 12 additions & 0 deletions tests/neg/singleton-ops-recursive-match-type.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import scala.compiletime.ops._

object Test {
type GCD[A <: Int, B <: Int] <: Int = B match {
case 0 => A
case _ => GCD[B, A % B]
}
val t0: GCD[10, 0] = 10
val t1: GCD[252, 105] = 21
val t3: GCD[105, 147] = 10 // error
val t4: GCD[1, 1] = -1 // error
}
9 changes: 9 additions & 0 deletions tests/neg/singleton-ops-type-alias.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.compiletime.ops._

object Test {
type Xor[A <: Boolean, B <: Boolean] = (A && ![B]) || (![A] && B)
val t0: Xor[true, true] = false
val t1: Xor[false, true] = true
val t2: Xor[true, false] = false // error
val t3: Xor[false, false] = true // error
}
100 changes: 100 additions & 0 deletions tests/neg/singleton-ops.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import scala.compiletime.ops._

object Test {
summon[2 + 3 =:= 6 - 1]
summon[1763 =:= 41 * 43]
summon[2 + 2 =:= 3] // error
summon[29 * 31 =:= 900] // error
summon[Int <:< Int + 1] // error
summon[1 + Int <:< Int]

val t0: 2 + 3 = 5
val t1: 2 + 2 = 5 // error
val t2: -1 + 1 = 0
val t3: -5 + -5 = -11 // error

val t4: 10 * 20 = 200
val t5: 30 * 10 = 400 // error
val t6: -10 * 2 = -20
val t7: -2 * -2 = 4

val t8: 10 / 2 = 5
val t9: 11 / -2 = -5 // Integer division
val t10: 2 / 4 = 2 // error
val t11: -1 / 0 = 1 // error

val t12: 10 % 3 = 1
val t13: 12 % 2 = 1 // error
val t14: 1 % -3 = 1
val t15: -3 % 0 = 0 // error

val t16: 1 < 0 = false
val t17: 0 < 1 = true
val t18: 10 < 5 = true // error
val t19: 5 < 10 = false // error

val t20: 1 <= 0 = false
val t21: 1 <= 1 = true
val t22: 10 <= 5 = true // error
val t23: 5 <= 10 = false // error

val t24: 1 > 0 = true
val t25: 0 > 1 = false
val t26: 10 > 5 = false // error
val t27: 5 > 10 = true // error

val t28: 1 >= 1 = true
val t29: 0 >= 1 = false
val t30: 10 >= 5 = false // error
val t31: 5 >= 10 = true // error

val t32: 1 == 1 = true
val t33: 0 == 1 = false
val t34: 10 == 5 = true // error
val t35: 10 == 10 = false // error

val t36: 1 != 1 = false
val t37: 0 != 1 = true
val t38: 10 != 5 = false // error
val t39: 10 != 10 = true // error

val t40: Abs[0] = 0
val t41: Abs[-1] = 1
val t42: Abs[-1] = -1 // error
val t43: Abs[1] = -1 // error

val t44: Negate[-10] = 10
val t45: Negate[10] = -10
val t46: Negate[1] = 1 // error
val t47: Negate[-1] = -1 // error

val t48: Max[-1, 10] = 10
val t49: Max[4, 2] = 4
val t50: Max[2, 2] = 1 // error
val t51: Max[-1, -1] = 0 // error

val t52: Min[-1, 10] = -1
val t53: Min[4, 2] = 2
val t54: Min[2, 2] = 1 // error
val t55: Min[-1, -1] = 0 // error

val t56: true && true = true
val t57: true && false = false
val t58: false && true = true // error
val t59: false && false = true // error

val t60: true || true = true
val t61: true || false = true
val t62: false || true = false // error
val t63: false || false = true // error

val t64: ![true] = false
val t65: ![false] = true
val t66: ![true] = true // error
val t67: ![false] = false // error

val t68: ToString[213] = "213"
val t69: ToString[-1] = "-1"
val t70: ToString[0] = "-0" // error
val t71: ToString[200] = "100" // error
}
Loading