Skip to content

Intrinsify constValueTuple and summonAll #18013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1520,6 +1520,25 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

/** Creates the tuple containing the given elements */
def tupleTree(elems: List[Tree])(using Context): Tree = {
val arity = elems.length
if arity == 0 then
ref(defn.EmptyTupleModule)
else if arity <= Definitions.MaxTupleArity then
// TupleN[elem1Tpe, ...](elem1, ...)
ref(defn.TupleType(arity).nn.typeSymbol.companionModule)
.select(nme.apply)
.appliedToTypes(elems.map(_.tpe.widenIfUnstable))
.appliedToArgs(elems)
else
// TupleXXL.apply(elems*) // TODO add and use Tuple.apply(elems*) ?
ref(defn.TupleXXLModule)
.select(nme.apply)
.appliedToVarargs(elems.map(_.asInstance(defn.ObjectType)), TypeTree(defn.ObjectType))
.asInstance(defn.tupleType(elems.map(elem => elem.tpe.widenIfUnstable)))
}

/** Creates the tuple type tree representation of the type trees in `ts` */
def tupleTypeTree(elems: List[Tree])(using Context): Tree = {
val arity = elems.length
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,10 @@ class Definitions {
@tu lazy val Compiletime_requireConst : Symbol = CompiletimePackageClass.requiredMethod("requireConst")
@tu lazy val Compiletime_constValue : Symbol = CompiletimePackageClass.requiredMethod("constValue")
@tu lazy val Compiletime_constValueOpt: Symbol = CompiletimePackageClass.requiredMethod("constValueOpt")
@tu lazy val Compiletime_constValueTuple: Symbol = CompiletimePackageClass.requiredMethod("constValueTuple")
@tu lazy val Compiletime_summonFrom : Symbol = CompiletimePackageClass.requiredMethod("summonFrom")
@tu lazy val Compiletime_summonInline : Symbol = CompiletimePackageClass.requiredMethod("summonInline")
@tu lazy val Compiletime_summonInline : Symbol = CompiletimePackageClass.requiredMethod("summonInline")
@tu lazy val Compiletime_summonAll : Symbol = CompiletimePackageClass.requiredMethod("summonAll")
@tu lazy val CompiletimeTestingPackage: Symbol = requiredPackage("scala.compiletime.testing")
@tu lazy val CompiletimeTesting_typeChecks: Symbol = CompiletimeTestingPackage.requiredMethod("typeChecks")
@tu lazy val CompiletimeTesting_typeCheckErrors: Symbol = CompiletimeTestingPackage.requiredMethod("typeCheckErrors")
Expand Down Expand Up @@ -932,6 +934,8 @@ class Definitions {
@tu lazy val TupleTypeRef: TypeRef = requiredClassRef("scala.Tuple")
def TupleClass(using Context): ClassSymbol = TupleTypeRef.symbol.asClass
@tu lazy val Tuple_cons: Symbol = TupleClass.requiredMethod("*:")
@tu lazy val TupleModule: Symbol = requiredModule("scala.Tuple")
@tu lazy val EmptyTupleClass: Symbol = requiredClass("scala.EmptyTuple")
@tu lazy val EmptyTupleModule: Symbol = requiredModule("scala.EmptyTuple")
@tu lazy val NonEmptyTupleTypeRef: TypeRef = requiredClassRef("scala.NonEmptyTuple")
def NonEmptyTupleClass(using Context): ClassSymbol = NonEmptyTupleTypeRef.symbol.asClass
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/inlines/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,8 @@ class Inliner(val call: tpd.Tree)(using Context):
// assertAllPositioned(tree) // debug
tree.changeOwner(originalOwner, ctx.owner)

def tryConstValue: Tree =
TypeComparer.constValue(callTypeArgs.head.tpe) match {
def tryConstValue(tpe: Type): Tree =
TypeComparer.constValue(tpe) match {
case Some(c) => Literal(c).withSpan(call.span)
case _ => EmptyTree
}
Expand Down
69 changes: 50 additions & 19 deletions compiler/src/dotty/tools/dotc/inlines/Inlines.scala
Original file line number Diff line number Diff line change
Expand Up @@ -413,36 +413,67 @@ object Inlines:
return Intrinsics.codeOf(arg, call.srcPos)
case _ =>

// Special handling of `constValue[T]`, `constValueOpt[T], and summonInline[T]`
// Special handling of `constValue[T]`, `constValueOpt[T]`, `constValueTuple[T]`, `summonInline[T]` and `summonAll[T]`
if callTypeArgs.length == 1 then
if (inlinedMethod == defn.Compiletime_constValue) {
val constVal = tryConstValue

def constValueOrError(tpe: Type): Tree =
val constVal = tryConstValue(tpe)
if constVal.isEmpty then
val msg = NotConstant("cannot take constValue", callTypeArgs.head.tpe)
return ref(defn.Predef_undefined).withSpan(call.span).withType(ErrorType(msg))
val msg = NotConstant("cannot take constValue", tpe)
ref(defn.Predef_undefined).withSpan(callTypeArgs.head.span).withType(ErrorType(msg))
else
return constVal
constVal

def searchImplicitOrError(tpe: Type): Tree =
val evTyper = new Typer(ctx.nestingLevel + 1)
val evCtx = ctx.fresh.setTyper(evTyper)
inContext(evCtx) {
val evidence = evTyper.inferImplicitArg(tpe, callTypeArgs.head.span)
evidence.tpe match
case fail: Implicits.SearchFailureType =>
errorTree(call, evTyper.missingArgMsg(evidence, tpe, ""))
case _ =>
evidence
}

def unrollTupleTypes(tpe: Type): Option[List[Type]] = tpe.dealias match
case AppliedType(tycon, args) if defn.isTupleClass(tycon.typeSymbol) =>
Some(args)
case AppliedType(tycon, head :: tail :: Nil) if tycon.isRef(defn.PairClass) =>
unrollTupleTypes(tail).map(head :: _)
case tpe: TermRef if tpe.symbol == defn.EmptyTupleModule =>
Some(Nil)
case _ =>
None

if (inlinedMethod == defn.Compiletime_constValue) {
return constValueOrError(callTypeArgs.head.tpe)
}
else if (inlinedMethod == defn.Compiletime_constValueOpt) {
val constVal = tryConstValue
val constVal = tryConstValue(callTypeArgs.head.tpe)
return (
if (constVal.isEmpty) ref(defn.NoneModule.termRef)
else New(defn.SomeClass.typeRef.appliedTo(constVal.tpe), constVal :: Nil)
)
}
else if (inlinedMethod == defn.Compiletime_constValueTuple) {
unrollTupleTypes(callTypeArgs.head.tpe) match
case Some(types) =>
val constants = types.map(constValueOrError)
return Typed(tpd.tupleTree(constants), TypeTree(callTypeArgs.head.tpe)).withSpan(call.span)
case _ =>
return errorTree(call, em"Tuple element types must be known at compile time")
}
else if (inlinedMethod == defn.Compiletime_summonInline) {
def searchImplicit(tpt: Tree) =
val evTyper = new Typer(ctx.nestingLevel + 1)
val evCtx = ctx.fresh.setTyper(evTyper)
inContext(evCtx) {
val evidence = evTyper.inferImplicitArg(tpt.tpe, tpt.span)
evidence.tpe match
case fail: Implicits.SearchFailureType =>
errorTree(call, evTyper.missingArgMsg(evidence, tpt.tpe, ""))
case _ =>
evidence
}
return searchImplicit(callTypeArgs.head)
return searchImplicitOrError(callTypeArgs.head.tpe)
}
else if (inlinedMethod == defn.Compiletime_summonAll) {
unrollTupleTypes(callTypeArgs.head.tpe) match
case Some(types) =>
val implicits = types.map(searchImplicitOrError)
return Typed(tpd.tupleTree(implicits), TypeTree(callTypeArgs.head.tpe)).withSpan(call.span)
case _ =>
return errorTree(call, em"Tuple element types must be known at compile time")
}
end if

Expand Down
19 changes: 5 additions & 14 deletions library/src/scala/compiletime/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,9 @@ transparent inline def constValue[T]: T =
* `(constValue[X1], ..., constValue[Xn])`.
*/
inline def constValueTuple[T <: Tuple]: T =
val res =
inline erasedValue[T] match
case _: EmptyTuple => EmptyTuple
case _: (t *: ts) => constValue[t] *: constValueTuple[ts]
end match
res.asInstanceOf[T]
end constValueTuple
// implemented in dotty.tools.dotc.typer.Inliner
error("Compiler bug: `constValueTuple` was not evaluated by the compiler")


/** Summons first given matching one of the listed cases. E.g. in
*
Expand Down Expand Up @@ -168,13 +164,8 @@ transparent inline def summonInline[T]: T =
* @return the given values typed as elements of the tuple
*/
inline def summonAll[T <: Tuple]: T =
val res =
inline erasedValue[T] match
case _: EmptyTuple => EmptyTuple
case _: (t *: ts) => summonInline[t] *: summonAll[ts]
end match
res.asInstanceOf[T]
end summonAll
// implemented in dotty.tools.dotc.typer.Inliner
error("Compiler bug: `summonAll` was not evaluated by the compiler")

/** Assertion that an argument is by-name. Used for nullability checking. */
def byName[T](x: => T): T = x
Expand Down
20 changes: 10 additions & 10 deletions tests/neg/17211.check
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
-- [E182] Type Error: tests/neg/17211.scala:14:12 ----------------------------------------------------------------------
-- [E182] Type Error: tests/neg/17211.scala:14:13 ----------------------------------------------------------------------
14 | constValue[IsInt[Foo.Foo]] // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
| IsInt[Foo.Foo] is not a constant type; cannot take constValue
| ^^^^^^^^^^^^^^
| IsInt[Foo.Foo] is not a constant type; cannot take constValue
|
| Note: a match type could not be fully reduced:
| Note: a match type could not be fully reduced:
|
| trying to reduce IsInt[Foo.Foo]
| failed since selector Foo.Foo
| does not match case Int => (true : Boolean)
| and cannot be shown to be disjoint from it either.
| Therefore, reduction cannot advance to the remaining case
| trying to reduce IsInt[Foo.Foo]
| failed since selector Foo.Foo
| does not match case Int => (true : Boolean)
| and cannot be shown to be disjoint from it either.
| Therefore, reduction cannot advance to the remaining case
|
| case _ => (false : Boolean)
| case _ => (false : Boolean)
2 changes: 1 addition & 1 deletion tests/neg/i14177a.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ import scala.compiletime.*
trait C[A]

inline given [Tup <: Tuple]: C[Tup] with
val cs = summonAll[Tuple.Map[Tup, C]] // error cannot reduce inline match with
val cs = summonAll[Tuple.Map[Tup, C]] // error: Tuple element types must be known at compile time
6 changes: 6 additions & 0 deletions tests/run/i15988a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import scala.compiletime.constValueTuple

@main def Test: Unit =
assert(constValueTuple[EmptyTuple] == EmptyTuple)
assert(constValueTuple[("foo", 5, 3.14, "bar", false)] == ("foo", 5, 3.14, "bar", false))
assert(constValueTuple[(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)] == (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23))
21 changes: 21 additions & 0 deletions tests/run/i15988b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import scala.compiletime.summonAll

@main def Test: Unit =
assert(summonAll[EmptyTuple] == EmptyTuple)
assert(summonAll[(5, 5, 5)] == (5, 5, 5))
assert(
summonAll[(
5, 5, 5, 5, 5,
5, 5, 5, 5, 5,
5, 5, 5, 5, 5,
5, 5, 5, 5, 5,
5, 5, 5, 5, 5,
)] == (
5, 5, 5, 5, 5,
5, 5, 5, 5, 5,
5, 5, 5, 5, 5,
5, 5, 5, 5, 5,
5, 5, 5, 5, 5,
))

given 5 = 5