Skip to content

Commit a66cb8d

Browse files
Merge pull request #12027 from dotty-staging/fix-#11996
Add missing Erased flag to inline bindings
2 parents 104f437 + 280109e commit a66cb8d

File tree

7 files changed

+41
-3
lines changed

7 files changed

+41
-3
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,7 @@ class Definitions {
890890
@tu lazy val ImplicitAmbiguousAnnot: ClassSymbol = requiredClass("scala.annotation.implicitAmbiguous")
891891
@tu lazy val ImplicitNotFoundAnnot: ClassSymbol = requiredClass("scala.annotation.implicitNotFound")
892892
@tu lazy val InlineParamAnnot: ClassSymbol = requiredClass("scala.annotation.internal.InlineParam")
893+
@tu lazy val ErasedParamAnnot: ClassSymbol = requiredClass("scala.annotation.internal.ErasedParam")
893894
@tu lazy val InvariantBetweenAnnot: ClassSymbol = requiredClass("scala.annotation.internal.InvariantBetween")
894895
@tu lazy val MainAnnot: ClassSymbol = requiredClass("scala.main")
895896
@tu lazy val MigrationAnnot: ClassSymbol = requiredClass("scala.annotation.migration")

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3640,9 +3640,15 @@ object Types {
36403640
case ExprType(resType) => ExprType(AnnotatedType(resType, Annotation(defn.InlineParamAnnot)))
36413641
case _ => AnnotatedType(tp, Annotation(defn.InlineParamAnnot))
36423642
}
3643+
def translateErased(tp: Type): Type = tp match {
3644+
case ExprType(resType) => ExprType(AnnotatedType(resType, Annotation(defn.ErasedParamAnnot)))
3645+
case _ => AnnotatedType(tp, Annotation(defn.ErasedParamAnnot))
3646+
}
36433647
def paramInfo(param: Symbol) = {
3644-
val paramType = param.info.annotatedToRepeated
3645-
if (param.is(Inline)) translateInline(paramType) else paramType
3648+
var paramType = param.info.annotatedToRepeated
3649+
if (param.is(Inline)) paramType = translateInline(paramType)
3650+
if (param.is(Erased)) paramType = translateErased(paramType)
3651+
paramType
36463652
}
36473653

36483654
apply(params.map(_.name.asTermName))(

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
220220
toTextGlobal(tp.resultType)
221221
}
222222
case AnnotatedType(tpe, annot) =>
223-
if annot.symbol == defn.InlineParamAnnot then toText(tpe)
223+
if annot.symbol == defn.InlineParamAnnot || annot.symbol == defn.ErasedParamAnnot then toText(tpe)
224224
else toTextLocal(tpe) ~ " " ~ toText(annot)
225225
case tp: TypeVar =>
226226
if (tp.isInstantiated)

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,8 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
467467
var bindingFlags: FlagSet = InlineProxy
468468
if formal.widenExpr.hasAnnotation(defn.InlineParamAnnot) then
469469
bindingFlags |= Inline
470+
if formal.widenExpr.hasAnnotation(defn.ErasedParamAnnot) then
471+
bindingFlags |= Erased
470472
if isByName then
471473
bindingFlags |= Method
472474
val boundSym = newSym(InlineBinderName.fresh(name.asTermName), bindingFlags, bindingType).asTerm
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package scala.annotation.internal
2+
3+
import scala.annotation.Annotation
4+
5+
/** An annotation produced by Namer to indicate an erased parameter */
6+
final class ErasedParam() extends Annotation
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
final class UnivEq[A]
2+
3+
object UnivEq:
4+
erased def force[A]: UnivEq[A] =
5+
compiletime.erasedValue
6+
7+
extension [A](a: A)
8+
inline def ==*[B >: A](b: B)(using erased UnivEq[B]): Boolean = a == b
9+
inline def !=*[B >: A](b: B)(using erased UnivEq[B]): Boolean = a != b
10+
11+
case class I(i: Int)
12+
13+
@main def Test = {
14+
def test[A](a: A, b: A): Unit = {
15+
erased given UnivEq[A] = UnivEq.force[A]
16+
println(a ==* a)
17+
println(a !=* b)
18+
}
19+
println("Test starting...")
20+
test(I(1), I(2)) // error
21+
test(1, 2)
22+
test(true, false)
23+
}

0 commit comments

Comments
 (0)