Skip to content

Add primitive compiletime operations on singleton types #7628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jan 8, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ class Definitions {
@tu lazy val CompiletimeTesting_ErrorKind: Symbol = ctx.requiredModule("scala.compiletime.testing.ErrorKind")
@tu lazy val CompiletimeTesting_ErrorKind_Parser: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Parser")
@tu lazy val CompiletimeTesting_ErrorKind_Typer: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Typer")
@tu lazy val CompiletimeIntPackageObject: Symbol = ctx.requiredModule("scala.compiletime.int.package")
@tu lazy val CompiletimeBooleanPackageObject: Symbol = ctx.requiredModule("scala.compiletime.boolean.package")

/** The `scalaShadowing` package is used to safely modify classes and
* objects in scala so that they can be used from dotty. They will
Expand Down Expand Up @@ -898,6 +900,28 @@ class Definitions {
final def isCompiletime_S(sym: Symbol)(implicit ctx: Context): Boolean =
sym.name == tpnme.S && sym.owner == CompiletimePackageObject.moduleClass

final def isCompiletimeAppliedType(sym: Symbol)(implicit ctx: Context): Boolean = {
def isPackageObjectAppliedType: Boolean =
sym.owner == CompiletimePackageObject.moduleClass && Set(
tpnme.S, tpnme.Equals, tpnme.NotEquals
).contains(sym.name)

def isIntAppliedType: Boolean =
sym.owner == CompiletimeIntPackageObject.moduleClass && Set(
tpnme.Plus, tpnme.Minus, tpnme.Times, tpnme.Div, tpnme.Mod,
tpnme.Lt, tpnme.Gt, tpnme.Ge, tpnme.Le,
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max
).contains(sym.name)

def isBooleanAppliedType: Boolean =
sym.owner == CompiletimeBooleanPackageObject.moduleClass && Set(
tpnme.Not, tpnme.Xor, tpnme.And, tpnme.Or
).contains(sym.name)

isPackageObjectAppliedType || isIntAppliedType || isBooleanAppliedType
}


// ----- Symbol sets ---------------------------------------------------

@tu lazy val AbstractFunctionType: Array[TypeRef] = mkArityArray("scala.runtime.AbstractFunction", MaxImplementedFunctionArity, 0)
Expand Down
22 changes: 21 additions & 1 deletion compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,33 @@ object StdNames {
final val Product: N = "Product"
final val PartialFunction: N = "PartialFunction"
final val PrefixType: N = "PrefixType"
final val S: N = "S"
final val Serializable: N = "Serializable"
final val Singleton: N = "Singleton"
final val Throwable: N = "Throwable"
final val IOOBException: N = "IndexOutOfBoundsException"
final val FunctionXXL: N = "FunctionXXL"

final val Abs: N = "Abs"
final val And: N = "&&"
final val Div: N = "/"
final val Equals: N = "=="
final val Ge: N = ">="
final val Gt: N = ">"
final val Le: N = "<="
final val Lt: N = "<"
final val Max: N = "Max"
final val Min: N = "Min"
final val Minus: N = "-"
final val Mod: N = "%"
final val Negate: N = "Negate"
final val Not: N = "!"
final val NotEquals: N = "!="
final val Or: N = "||"
final val Plus: N = "+"
final val S: N = "S"
final val Times: N = "*"
final val Xor: N = "^"

final val ClassfileAnnotation: N = "ClassfileAnnotation"
final val ClassManifest: N = "ClassManifest"
final val Enum: N = "Enum"
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ class TypeApplications(val self: Type) extends AnyVal {
}
}
if ((dealiased eq stripped) || followAlias)
try dealiased.instantiate(args)
try dealiased.instantiate(args).normalized
catch { case ex: IndexOutOfBoundsException => AppliedType(self, args) }
else AppliedType(self, args)
}
Expand Down
9 changes: 7 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
compareLower(bounds(param2), tyconIsTypeRef = false)
case tycon2: TypeRef =>
isMatchingApply(tp1) ||
defn.isCompiletime_S(tycon2.symbol) && compareS(tp2, tp1, fromBelow = true) || {
defn.isCompiletimeAppliedType(tycon2.symbol) && compareCompiletimeAppliedType(tp2, tp1, fromBelow = true) || {
tycon2.info match {
case info2: TypeBounds =>
compareLower(info2, tyconIsTypeRef = true)
Expand Down Expand Up @@ -1005,7 +1005,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
case tycon1: TypeRef =>
val sym = tycon1.symbol
!sym.isClass && {
defn.isCompiletime_S(sym) && compareS(tp1, tp2, fromBelow = false) ||
defn.isCompiletimeAppliedType(sym) && compareCompiletimeAppliedType(tp1, tp2, fromBelow = false) ||
recur(tp1.superType, tp2) ||
tryLiftedToThis1
}
Expand Down Expand Up @@ -1037,6 +1037,11 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
case _ => false
}

def compareCompiletimeAppliedType(tp: AppliedType, other: Type, fromBelow: Boolean): Boolean = {
if (defn.isCompiletime_S(tp.tycon.typeSymbol)) compareS(tp, other, fromBelow)
else tp.tryCompiletimeConstantFold.exists(folded => recur(folded, other))
}

/** Like tp1 <:< tp2, but returns false immediately if we know that
* the case was covered previously during subtyping.
*/
Expand Down
73 changes: 63 additions & 10 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3595,19 +3595,72 @@ object Types {
case _ =>
NoType
}
if (defn.isCompiletime_S(tycon.symbol) && args.length == 1)
trace(i"normalize S $this", typr, show = true) {
args.head.normalized match {
case ConstantType(Constant(n: Int)) if n >= 0 && n < Int.MaxValue =>
ConstantType(Constant(n + 1))
case none => tryMatchAlias
}
}
else tryMatchAlias

tryCompiletimeConstantFold.getOrElse(tryMatchAlias)

case _ =>
NoType
}

def tryCompiletimeConstantFold(implicit ctx: Context): Option[Type] = tycon match {
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
def constValue(tp: Type): Option[Any] = tp match {
case ConstantType(Constant(n)) => Some(n)
case _ => None
}

def boolValue(tp: Type): Option[Boolean] = tp match {
case ConstantType(Constant(n: Boolean)) => Some(n)
case _ => None
}

def intValue(tp: Type): Option[Int] = tp match {
case ConstantType(Constant(n: Int)) => Some(n)
case _ => None
}

def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)

def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
extractor(args.head.normalized).map(a => ConstantType(Constant(op(a))))

def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
for {
a <- extractor(args.head.normalized)
b <- extractor(args.tail.head.normalized)
} yield ConstantType(Constant(op(a, b)))

trace(i"compiletime constant fold $this", typr, show = true) {
if (args.length == 1) tycon.symbol.name match {
case tpnme.S => constantFold1(natValue, _ + 1)
case tpnme.Abs => constantFold1(intValue, _.abs)
case tpnme.Negate => constantFold1(intValue, x => -x)
case tpnme.Not => constantFold1(boolValue, x => !x)
case _ => None
} else if (args.length == 2) tycon.symbol.name match {
case tpnme.Equals => constantFold2(constValue, _ == _)
case tpnme.NotEquals => constantFold2(constValue, _ != _)
case tpnme.Plus => constantFold2(intValue, _ + _)
case tpnme.Minus => constantFold2(intValue, _ - _)
case tpnme.Times => constantFold2(intValue, _ * _)
case tpnme.Div => constantFold2(intValue, _ / _)
case tpnme.Mod => constantFold2(intValue, _ % _)
case tpnme.Lt => constantFold2(intValue, _ < _)
case tpnme.Gt => constantFold2(intValue, _ > _)
case tpnme.Ge => constantFold2(intValue, _ >= _)
case tpnme.Le => constantFold2(intValue, _ <= _)
case tpnme.Min => constantFold2(intValue, _ min _)
case tpnme.Max => constantFold2(intValue, _ max _)
case tpnme.And => constantFold2(boolValue, _ && _)
case tpnme.Or => constantFold2(boolValue, _ || _)
case tpnme.Xor => constantFold2(boolValue, _ ^ _)
case _ => None
} else None
}

case _ => None
}

def lowerBound(implicit ctx: Context): Type = tycon.stripTypeVar match {
case tycon: TypeRef =>
tycon.info match {
Expand Down Expand Up @@ -3974,7 +4027,7 @@ object Types {
myReduced =
trace(i"reduce match type $this $hashCode", typr, show = true) {
try
typeComparer.matchCases(scrutinee, cases)(trackingCtx)
typeComparer.matchCases(scrutinee.normalized, cases)(trackingCtx)
catch {
case ex: Throwable =>
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
Expand Down
8 changes: 8 additions & 0 deletions library/src/scala/compiletime/boolean/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package scala.compiletime

package object boolean {
type ![X <: Boolean] <: Boolean
type ^[X <: Boolean, Y <: Boolean] <: Boolean
type &&[X <: Boolean, Y <: Boolean] <: Boolean
type ||[X <: Boolean, Y <: Boolean] <: Boolean
}
19 changes: 19 additions & 0 deletions library/src/scala/compiletime/int/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package scala.compiletime

package object int {
type +[X <: Int, Y <: Int] <: Int
type -[X <: Int, Y <: Int] <: Int
type *[X <: Int, Y <: Int] <: Int
type /[X <: Int, Y <: Int] <: Int
type %[X <: Int, Y <: Int] <: Int

type <[X <: Int, Y <: Int] <: Boolean
type >[X <: Int, Y <: Int] <: Boolean
type >=[X <: Int, Y <: Int] <: Boolean
type <=[X <: Int, Y <: Int] <: Boolean

type Abs[X <: Int] <: Int
type Negate[X <: Int] <: Int
type Min[X <: Int, Y <: Int] <: Int
type Max[X <: Int, Y <: Int] <: Int
}
3 changes: 3 additions & 0 deletions library/src/scala/compiletime/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,7 @@ package object compiletime {
* }
*/
type S[N <: Int] <: Int

type ==[X <: AnyVal, Y <: AnyVal] <: Boolean
type !=[X <: AnyVal, Y <: AnyVal] <: Boolean
}
117 changes: 117 additions & 0 deletions tests/neg/compiletime-singleton-ops.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import scala.compiletime._
import scala.compiletime.int._
import scala.compiletime.boolean._

object Test {
val t0: 2 + 3 = 5
val t1: 2 + 2 = 5 // error
val t2: -1 + 1 = 0
val t3: -5 + -5 = -11 // error

val t4: 10 * 20 = 200
val t5: 30 * 10 = 400 // error
val t6: -10 * 2 = -20
val t7: -2 * -2 = 4

val t8: 10 / 2 = 5
val t9: 11 / 2 = 5 // Integer division
val t10: 2 / 4 = 2 // error
val t11: -1 / -1 = 1

val t12: 10 % 3 = 1
val t13: 12 % 2 = 1 // error
val t14: 1 % -3 = 1
val t15: -3 % -2 = 0 // error

val t16: 1 < 0 = false
val t17: 0 < 1 = true
val t18: 10 < 5 = true // error
val t19: 5 < 10 = false // error

val t20: 1 <= 0 = false
val t21: 1 <= 1 = true
val t22: 10 <= 5 = true // error
val t23: 5 <= 10 = false // error

val t24: 1 > 0 = true
val t25: 0 > 1 = false
val t26: 10 > 5 = false // error
val t27: 5 > 10 = true // error

val t28: 1 >= 1 = true
val t29: 0 >= 1 = false
val t30: 10 >= 5 = false // error
val t31: 5 >= 10 = true // error

val t32: 1 == 1 = true
val t33: 0 == 1 = false
val t34: 10 == 5 = true // error
val t35: 10 == 10 = false // error

val t36: 1 != 1 = false
val t37: 0 != 1 = true
val t38: 10 != 5 = false // error
val t39: 10 != 10 = true // error

val t40: Abs[0] = 0
val t41: Abs[-1] = 1
val t42: Abs[-1] = -1 // error
val t43: Abs[1] = -1 // error

val t44: Negate[-10] = 10
val t45: Negate[10] = -10
val t46: Negate[1] = 1 // error
val t47: Negate[-1] = -1 // error

val t48: Max[-1, 10] = 10
val t49: Max[4, 2] = 4
val t50: Max[2, 2] = 1 // error
val t51: Max[-1, -1] = 0 // error

val t52: Min[-1, 10] = -1
val t53: Min[4, 2] = 2
val t54: Min[2, 2] = 1 // error
val t55: Min[-1, -1] = 0 // error

val t56: true && true = true
val t57: true && false = false
val t58: false && true = true // error
val t59: false && false = true // error

val t60: true || true = true
val t61: true || false = true
val t62: false || true = false // error
val t63: false || false = true // error

val t64: ![true] = false
val t65: ![false] = true
val t66: ![true] = true // error
val t67: ![false] = false // error

// Test singleton ops in type alias:
type Xor[A <: Boolean, B <: Boolean] = (A && ![B]) || (![A] && B)
val t68: Xor[true, true] = false
val t69: Xor[false, true] = true
val t70: Xor[true, false] = false // error
val t71: Xor[false, false] = true // error

// Test singleton ops in recursive match types:
type GCD[A <: Int, B <: Int] <: Int = B match {
case 0 => A
case _ => GCD[B, A % B]
}
val t72: GCD[10, 0] = 10
val t73: GCD[252, 105] = 21
val t74: GCD[105, 147] = 10 // error
val t75: GCD[1, 1] = -1 // error

// Test singleton ops in match type scrutinee:
type Max2[A <: Int, B <: Int] <: Int = (A < B) match {
case true => B
case false => A
}
val t76: Max[-1, 10] = 10
val t77: Max[4, 2] = 4
val t78: Max[2, 2] = 1 // error
val t79: Max[-1, -1] = 0 // error
}