Skip to content

Commit b03a648

Browse files
Intrinsify constValueTuple and summonAll (#18013)
The new implementation instantiates the TupleN/TupleXXL classes directly. This avoids the expensive unrolling of tuples using `*:` recursively. Fixes #15988
2 parents 1637282 + 9d031e2 commit b03a648

File tree

9 files changed

+119
-47
lines changed

9 files changed

+119
-47
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,6 +1520,25 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
15201520
}
15211521
}
15221522

1523+
/** Creates the tuple containing the given elements */
1524+
def tupleTree(elems: List[Tree])(using Context): Tree = {
1525+
val arity = elems.length
1526+
if arity == 0 then
1527+
ref(defn.EmptyTupleModule)
1528+
else if arity <= Definitions.MaxTupleArity then
1529+
// TupleN[elem1Tpe, ...](elem1, ...)
1530+
ref(defn.TupleType(arity).nn.typeSymbol.companionModule)
1531+
.select(nme.apply)
1532+
.appliedToTypes(elems.map(_.tpe.widenIfUnstable))
1533+
.appliedToArgs(elems)
1534+
else
1535+
// TupleXXL.apply(elems*) // TODO add and use Tuple.apply(elems*) ?
1536+
ref(defn.TupleXXLModule)
1537+
.select(nme.apply)
1538+
.appliedToVarargs(elems.map(_.asInstance(defn.ObjectType)), TypeTree(defn.ObjectType))
1539+
.asInstance(defn.tupleType(elems.map(elem => elem.tpe.widenIfUnstable)))
1540+
}
1541+
15231542
/** Creates the tuple type tree representation of the type trees in `ts` */
15241543
def tupleTypeTree(elems: List[Tree])(using Context): Tree = {
15251544
val arity = elems.length

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,10 @@ class Definitions {
243243
@tu lazy val Compiletime_requireConst : Symbol = CompiletimePackageClass.requiredMethod("requireConst")
244244
@tu lazy val Compiletime_constValue : Symbol = CompiletimePackageClass.requiredMethod("constValue")
245245
@tu lazy val Compiletime_constValueOpt: Symbol = CompiletimePackageClass.requiredMethod("constValueOpt")
246+
@tu lazy val Compiletime_constValueTuple: Symbol = CompiletimePackageClass.requiredMethod("constValueTuple")
246247
@tu lazy val Compiletime_summonFrom : Symbol = CompiletimePackageClass.requiredMethod("summonFrom")
247-
@tu lazy val Compiletime_summonInline : Symbol = CompiletimePackageClass.requiredMethod("summonInline")
248+
@tu lazy val Compiletime_summonInline : Symbol = CompiletimePackageClass.requiredMethod("summonInline")
249+
@tu lazy val Compiletime_summonAll : Symbol = CompiletimePackageClass.requiredMethod("summonAll")
248250
@tu lazy val CompiletimeTestingPackage: Symbol = requiredPackage("scala.compiletime.testing")
249251
@tu lazy val CompiletimeTesting_typeChecks: Symbol = CompiletimeTestingPackage.requiredMethod("typeChecks")
250252
@tu lazy val CompiletimeTesting_typeCheckErrors: Symbol = CompiletimeTestingPackage.requiredMethod("typeCheckErrors")
@@ -932,6 +934,8 @@ class Definitions {
932934
@tu lazy val TupleTypeRef: TypeRef = requiredClassRef("scala.Tuple")
933935
def TupleClass(using Context): ClassSymbol = TupleTypeRef.symbol.asClass
934936
@tu lazy val Tuple_cons: Symbol = TupleClass.requiredMethod("*:")
937+
@tu lazy val TupleModule: Symbol = requiredModule("scala.Tuple")
938+
@tu lazy val EmptyTupleClass: Symbol = requiredClass("scala.EmptyTuple")
935939
@tu lazy val EmptyTupleModule: Symbol = requiredModule("scala.EmptyTuple")
936940
@tu lazy val NonEmptyTupleTypeRef: TypeRef = requiredClassRef("scala.NonEmptyTuple")
937941
def NonEmptyTupleClass(using Context): ClassSymbol = NonEmptyTupleTypeRef.symbol.asClass

compiler/src/dotty/tools/dotc/inlines/Inliner.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,8 @@ class Inliner(val call: tpd.Tree)(using Context):
497497
// assertAllPositioned(tree) // debug
498498
tree.changeOwner(originalOwner, ctx.owner)
499499

500-
def tryConstValue: Tree =
501-
TypeComparer.constValue(callTypeArgs.head.tpe) match {
500+
def tryConstValue(tpe: Type): Tree =
501+
TypeComparer.constValue(tpe) match {
502502
case Some(c) => Literal(c).withSpan(call.span)
503503
case _ => EmptyTree
504504
}

compiler/src/dotty/tools/dotc/inlines/Inlines.scala

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -413,36 +413,67 @@ object Inlines:
413413
return Intrinsics.codeOf(arg, call.srcPos)
414414
case _ =>
415415

416-
// Special handling of `constValue[T]`, `constValueOpt[T], and summonInline[T]`
416+
// Special handling of `constValue[T]`, `constValueOpt[T]`, `constValueTuple[T]`, `summonInline[T]` and `summonAll[T]`
417417
if callTypeArgs.length == 1 then
418-
if (inlinedMethod == defn.Compiletime_constValue) {
419-
val constVal = tryConstValue
418+
419+
def constValueOrError(tpe: Type): Tree =
420+
val constVal = tryConstValue(tpe)
420421
if constVal.isEmpty then
421-
val msg = NotConstant("cannot take constValue", callTypeArgs.head.tpe)
422-
return ref(defn.Predef_undefined).withSpan(call.span).withType(ErrorType(msg))
422+
val msg = NotConstant("cannot take constValue", tpe)
423+
ref(defn.Predef_undefined).withSpan(callTypeArgs.head.span).withType(ErrorType(msg))
423424
else
424-
return constVal
425+
constVal
426+
427+
def searchImplicitOrError(tpe: Type): Tree =
428+
val evTyper = new Typer(ctx.nestingLevel + 1)
429+
val evCtx = ctx.fresh.setTyper(evTyper)
430+
inContext(evCtx) {
431+
val evidence = evTyper.inferImplicitArg(tpe, callTypeArgs.head.span)
432+
evidence.tpe match
433+
case fail: Implicits.SearchFailureType =>
434+
errorTree(call, evTyper.missingArgMsg(evidence, tpe, ""))
435+
case _ =>
436+
evidence
437+
}
438+
439+
def unrollTupleTypes(tpe: Type): Option[List[Type]] = tpe.dealias match
440+
case AppliedType(tycon, args) if defn.isTupleClass(tycon.typeSymbol) =>
441+
Some(args)
442+
case AppliedType(tycon, head :: tail :: Nil) if tycon.isRef(defn.PairClass) =>
443+
unrollTupleTypes(tail).map(head :: _)
444+
case tpe: TermRef if tpe.symbol == defn.EmptyTupleModule =>
445+
Some(Nil)
446+
case _ =>
447+
None
448+
449+
if (inlinedMethod == defn.Compiletime_constValue) {
450+
return constValueOrError(callTypeArgs.head.tpe)
425451
}
426452
else if (inlinedMethod == defn.Compiletime_constValueOpt) {
427-
val constVal = tryConstValue
453+
val constVal = tryConstValue(callTypeArgs.head.tpe)
428454
return (
429455
if (constVal.isEmpty) ref(defn.NoneModule.termRef)
430456
else New(defn.SomeClass.typeRef.appliedTo(constVal.tpe), constVal :: Nil)
431457
)
432458
}
459+
else if (inlinedMethod == defn.Compiletime_constValueTuple) {
460+
unrollTupleTypes(callTypeArgs.head.tpe) match
461+
case Some(types) =>
462+
val constants = types.map(constValueOrError)
463+
return Typed(tpd.tupleTree(constants), TypeTree(callTypeArgs.head.tpe)).withSpan(call.span)
464+
case _ =>
465+
return errorTree(call, em"Tuple element types must be known at compile time")
466+
}
433467
else if (inlinedMethod == defn.Compiletime_summonInline) {
434-
def searchImplicit(tpt: Tree) =
435-
val evTyper = new Typer(ctx.nestingLevel + 1)
436-
val evCtx = ctx.fresh.setTyper(evTyper)
437-
inContext(evCtx) {
438-
val evidence = evTyper.inferImplicitArg(tpt.tpe, tpt.span)
439-
evidence.tpe match
440-
case fail: Implicits.SearchFailureType =>
441-
errorTree(call, evTyper.missingArgMsg(evidence, tpt.tpe, ""))
442-
case _ =>
443-
evidence
444-
}
445-
return searchImplicit(callTypeArgs.head)
468+
return searchImplicitOrError(callTypeArgs.head.tpe)
469+
}
470+
else if (inlinedMethod == defn.Compiletime_summonAll) {
471+
unrollTupleTypes(callTypeArgs.head.tpe) match
472+
case Some(types) =>
473+
val implicits = types.map(searchImplicitOrError)
474+
return Typed(tpd.tupleTree(implicits), TypeTree(callTypeArgs.head.tpe)).withSpan(call.span)
475+
case _ =>
476+
return errorTree(call, em"Tuple element types must be known at compile time")
446477
}
447478
end if
448479

library/src/scala/compiletime/package.scala

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,9 @@ transparent inline def constValue[T]: T =
117117
* `(constValue[X1], ..., constValue[Xn])`.
118118
*/
119119
inline def constValueTuple[T <: Tuple]: T =
120-
val res =
121-
inline erasedValue[T] match
122-
case _: EmptyTuple => EmptyTuple
123-
case _: (t *: ts) => constValue[t] *: constValueTuple[ts]
124-
end match
125-
res.asInstanceOf[T]
126-
end constValueTuple
120+
// implemented in dotty.tools.dotc.typer.Inliner
121+
error("Compiler bug: `constValueTuple` was not evaluated by the compiler")
122+
127123

128124
/** Summons first given matching one of the listed cases. E.g. in
129125
*
@@ -168,13 +164,8 @@ transparent inline def summonInline[T]: T =
168164
* @return the given values typed as elements of the tuple
169165
*/
170166
inline def summonAll[T <: Tuple]: T =
171-
val res =
172-
inline erasedValue[T] match
173-
case _: EmptyTuple => EmptyTuple
174-
case _: (t *: ts) => summonInline[t] *: summonAll[ts]
175-
end match
176-
res.asInstanceOf[T]
177-
end summonAll
167+
// implemented in dotty.tools.dotc.typer.Inliner
168+
error("Compiler bug: `summonAll` was not evaluated by the compiler")
178169

179170
/** Assertion that an argument is by-name. Used for nullability checking. */
180171
def byName[T](x: => T): T = x

tests/neg/17211.check

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
-- [E182] Type Error: tests/neg/17211.scala:14:12 ----------------------------------------------------------------------
1+
-- [E182] Type Error: tests/neg/17211.scala:14:13 ----------------------------------------------------------------------
22
14 | constValue[IsInt[Foo.Foo]] // error
3-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
4-
| IsInt[Foo.Foo] is not a constant type; cannot take constValue
3+
| ^^^^^^^^^^^^^^
4+
| IsInt[Foo.Foo] is not a constant type; cannot take constValue
55
|
6-
| Note: a match type could not be fully reduced:
6+
| Note: a match type could not be fully reduced:
77
|
8-
| trying to reduce IsInt[Foo.Foo]
9-
| failed since selector Foo.Foo
10-
| does not match case Int => (true : Boolean)
11-
| and cannot be shown to be disjoint from it either.
12-
| Therefore, reduction cannot advance to the remaining case
8+
| trying to reduce IsInt[Foo.Foo]
9+
| failed since selector Foo.Foo
10+
| does not match case Int => (true : Boolean)
11+
| and cannot be shown to be disjoint from it either.
12+
| Therefore, reduction cannot advance to the remaining case
1313
|
14-
| case _ => (false : Boolean)
14+
| case _ => (false : Boolean)

tests/neg/i14177a.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ import scala.compiletime.*
33
trait C[A]
44

55
inline given [Tup <: Tuple]: C[Tup] with
6-
val cs = summonAll[Tuple.Map[Tup, C]] // error cannot reduce inline match with
6+
val cs = summonAll[Tuple.Map[Tup, C]] // error: Tuple element types must be known at compile time

tests/run/i15988a.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import scala.compiletime.constValueTuple
2+
3+
@main def Test: Unit =
4+
assert(constValueTuple[EmptyTuple] == EmptyTuple)
5+
assert(constValueTuple[("foo", 5, 3.14, "bar", false)] == ("foo", 5, 3.14, "bar", false))
6+
assert(constValueTuple[(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)] == (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23))

tests/run/i15988b.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import scala.compiletime.summonAll
2+
3+
@main def Test: Unit =
4+
assert(summonAll[EmptyTuple] == EmptyTuple)
5+
assert(summonAll[(5, 5, 5)] == (5, 5, 5))
6+
assert(
7+
summonAll[(
8+
5, 5, 5, 5, 5,
9+
5, 5, 5, 5, 5,
10+
5, 5, 5, 5, 5,
11+
5, 5, 5, 5, 5,
12+
5, 5, 5, 5, 5,
13+
)] == (
14+
5, 5, 5, 5, 5,
15+
5, 5, 5, 5, 5,
16+
5, 5, 5, 5, 5,
17+
5, 5, 5, 5, 5,
18+
5, 5, 5, 5, 5,
19+
))
20+
21+
given 5 = 5

0 commit comments

Comments
 (0)