Skip to content

Commit 652553d

Browse files
committed
Intercepted methods transformer
Replace member references for: methods inside Any( == and !=) ## on primitives .getClass on primitives
1 parent ea5acb5 commit 652553d

File tree

6 files changed

+159
-8
lines changed

6 files changed

+159
-8
lines changed

src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Compiler {
2121
List(
2222
List(new FrontEnd),
2323
List(new LazyValsCreateCompanionObjects, new PatternMatcher), //force separataion between lazyVals and LVCreateCO
24-
List(new LazyValTranformContext().transformer, new Splitter, new TypeTestsCasts),
24+
List(new LazyValTranformContext().transformer, new Splitter, new TypeTestsCasts, new InterceptedMethods),
2525
List(new Erasure),
2626
List(new UncurryTreeTransform)
2727
)

src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ object Contexts {
205205

206206
final def withPhase(phase: Phase): Context =
207207
withPhase(phase.id)
208-
/** If -Ydebug is on, the top of the stack trace where this context
208+
209+
/** If -Ydebug is on, the top of the stack trace where this context
209210
* was created, otherwise `null`.
210211
*/
211212
private var creationTrace: Array[StackTraceElement] = _
@@ -298,6 +299,7 @@ object Contexts {
298299
setCreationTrace()
299300
this
300301
}
302+
301303
/** A fresh clone of this context. */
302304
def fresh: FreshContext = clone.asInstanceOf[FreshContext].init(this)
303305

src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class Definitions {
126126

127127
lazy val AnyValClass: ClassSymbol = ctx.requiredClass("scala.AnyVal")
128128

129+
lazy val AnyVal_getClass = AnyValClass.requiredMethod(nme.getClass_)
129130
lazy val Any_== = newMethod(AnyClass, nme.EQ, methOfAny(BooleanType), Final)
130131
lazy val Any_!= = newMethod(AnyClass, nme.NE, methOfAny(BooleanType), Final)
131132
lazy val Any_equals = newMethod(AnyClass, nme.equals_, methOfAny(BooleanType))
@@ -154,6 +155,7 @@ class Definitions {
154155
ScalaPackageClass, tpnme.Null, AbstractFinal, List(ObjectClass.typeRef))
155156

156157
lazy val ScalaPredefModule = ctx.requiredModule("scala.Predef")
158+
lazy val ScalaRuntimeModule = ctx.requiredModule("scala.runtime.ScalaRunTime")
157159
lazy val DottyPredefModule = ctx.requiredModule("dotty.DottyPredef")
158160
lazy val NilModule = ctx.requiredModule("scala.collection.immutable.Nil")
159161

@@ -170,14 +172,20 @@ class Definitions {
170172

171173
lazy val UnitClass = valueClassSymbol("scala.Unit", BoxedUnitClass, java.lang.Void.TYPE, UnitEnc)
172174
lazy val BooleanClass = valueClassSymbol("scala.Boolean", BoxedBooleanClass, java.lang.Boolean.TYPE, BooleanEnc)
173-
175+
lazy val Boolean_! = BooleanClass.requiredMethod(nme.UNARY_!)
174176
lazy val Boolean_and = BooleanClass.requiredMethod(nme.ZAND)
175-
177+
176178
lazy val ByteClass = valueClassSymbol("scala.Byte", BoxedByteClass, java.lang.Byte.TYPE, ByteEnc)
177179
lazy val ShortClass = valueClassSymbol("scala.Short", BoxedShortClass, java.lang.Short.TYPE, ShortEnc)
178180
lazy val CharClass = valueClassSymbol("scala.Char", BoxedCharClass, java.lang.Character.TYPE, CharEnc)
179181
lazy val IntClass = valueClassSymbol("scala.Int", BoxedIntClass, java.lang.Integer.TYPE, IntEnc)
180182
lazy val LongClass = valueClassSymbol("scala.Long", BoxedLongClass, java.lang.Long.TYPE, LongEnc)
183+
lazy val Long_XOR_Long = LongClass.info.member(nme.XOR).requiredSymbol(
184+
x => (x is Method) && (x.info.firstParamTypes.head isRef defn.LongClass)
185+
)
186+
lazy val Long_LSR_Int = LongClass.info.member(nme.LSR).requiredSymbol(
187+
x => (x is Method) && (x.info.firstParamTypes.head isRef defn.IntClass)
188+
)
181189
lazy val FloatClass = valueClassSymbol("scala.Float", BoxedFloatClass, java.lang.Float.TYPE, FloatEnc)
182190
lazy val DoubleClass = valueClassSymbol("scala.Double", BoxedDoubleClass, java.lang.Double.TYPE, DoubleEnc)
183191

@@ -421,16 +429,16 @@ class Definitions {
421429

422430
// ----- primitive value class machinery ------------------------------------------
423431

424-
lazy val ScalaValueClasses: collection.Set[Symbol] = Set(
425-
UnitClass,
426-
BooleanClass,
432+
lazy val ScalaNumericValueClasses: collection.Set[Symbol] = Set(
427433
ByteClass,
428434
ShortClass,
429435
CharClass,
430436
IntClass,
431437
LongClass,
432438
FloatClass,
433439
DoubleClass)
440+
441+
lazy val ScalaValueClasses: collection.Set[Symbol] = ScalaNumericValueClasses + UnitClass + BooleanClass
434442

435443
lazy val ScalaBoxedClasses = ScalaValueClasses map boxedClass
436444

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import TreeTransforms._
5+
import core.DenotTransformers._
6+
import core.Denotations._
7+
import core.SymDenotations._
8+
import core.Contexts._
9+
import core.Types._
10+
import ast.Trees._
11+
import ast.tpd.{Apply, Tree, cpy}
12+
import dotty.tools.dotc.ast.tpd
13+
import scala.collection.mutable
14+
import dotty.tools.dotc._
15+
import core._
16+
import Contexts._
17+
import Symbols._
18+
import Decorators._
19+
import NameOps._
20+
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransformer, TreeTransform}
21+
import dotty.tools.dotc.ast.Trees._
22+
import dotty.tools.dotc.ast.{untpd, tpd}
23+
import dotty.tools.dotc.core.Constants.Constant
24+
import dotty.tools.dotc.core.Types.MethodType
25+
import dotty.tools.dotc.core.Names.Name
26+
import dotty.runtime.LazyVals
27+
import scala.collection.mutable.ListBuffer
28+
import dotty.tools.dotc.core.Denotations.SingleDenotation
29+
import dotty.tools.dotc.core.SymDenotations.SymDenotation
30+
import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
31+
import StdNames._
32+
33+
/** Replace member references as follows:
34+
*
35+
* - `x == y` for == in class Any becomes `x equals y` with equals in class Object.
36+
* - `x != y` for != in class Any becomes `!(x equals y)` with equals in class Object.
37+
* - `x.##` for ## in other classes becomes calls to ScalaRunTime.hash,
38+
* using the most precise overload available
39+
* - `x.getClass` for getClass in primitives becomes `x.getClass` with getClass in class Object.
40+
* - `x.isInstanceOf[O]` if O is object becomes `x eq O` (reference equality)
41+
*/
42+
class InterceptedMethods extends TreeTransform {
43+
44+
import tpd._
45+
46+
override def name: String = "intercepted"
47+
48+
private var getClassMethods: Set[Symbol] = _
49+
private var poundPoundMethods: Set[Symbol] = _
50+
private var Any_comparisons: Set[Symbol] = _
51+
private var interceptedMethods: Set[Symbol] = _
52+
private var primitiveGetClassMethods: Set[Symbol] = _
53+
54+
/** perform context-dependant initialization */
55+
override def init(implicit ctx: Context, info: TransformerInfo): Unit = {
56+
getClassMethods = Set(defn.Any_getClass, defn.AnyVal_getClass)
57+
poundPoundMethods = Set(defn.Any_##, defn.Object_##)
58+
Any_comparisons = Set(defn.Any_==, defn.Any_!=)
59+
interceptedMethods = getClassMethods ++ poundPoundMethods ++ Any_comparisons
60+
primitiveGetClassMethods = Set[Symbol](defn.Any_getClass, defn.AnyVal_getClass) ++ defn.ScalaValueClasses.map(x => x.requiredMethod(nme.getClass_))
61+
}
62+
63+
64+
// this should be removed if we have guarantee that ## will get Apply node
65+
override def transformSelect(tree: tpd.Select)(implicit ctx: Context, info: TransformerInfo): Tree = {
66+
if (tree.symbol.isTerm && poundPoundMethods.contains(tree.symbol.asTerm)) {
67+
val rewrite = PoundPoundValue(tree.qualifier)
68+
ctx.log(s"$name rewrote $tree to $rewrite")
69+
rewrite
70+
}
71+
else tree
72+
}
73+
74+
private def PoundPoundValue(tree: Tree)(implicit ctx: Context) = {
75+
val s = tree.tpe.widen.typeSymbol
76+
if (s == defn.NullClass) Literal(Constant(0))
77+
else {
78+
// Since we are past typer, we need to avoid creating trees carrying
79+
// overloaded types. This logic is custom (and technically incomplete,
80+
// although serviceable) for def hash. What is really needed is for
81+
// the overloading logic presently hidden away in a few different
82+
// places to be properly exposed so we can just call "resolveOverload"
83+
// after typer. Until then:
84+
def alts = defn.ScalaRuntimeModule.info.member(nme.hash_).alternatives
85+
def alt1 = alts find (_.info.firstParamTypes.head =:= tree.tpe)
86+
def alt2 = defn.ScalaRuntimeModule.info.member(nme.hash_)
87+
.suchThat (_.info.firstParamTypes.head.typeSymbol == defn.AnyClass)
88+
89+
tpd.Apply(Ident(alt1.getOrElse(alt2).termRef), List(tree))
90+
}
91+
}
92+
93+
override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo): Tree = {
94+
def unknown = {
95+
ctx.debugwarn(s"The symbol '${tree.fun.symbol}' was interecepted but didn't match any cases, " +
96+
s"that means the intercepted methods set doesn't match the code")
97+
tree
98+
}
99+
if (tree.fun.symbol.isTerm && tree.args.isEmpty &&
100+
(interceptedMethods contains tree.fun.symbol.asTerm)) {
101+
val rewrite: Tree = tree.fun match {
102+
case Select(qual, name) =>
103+
if (poundPoundMethods contains tree.fun.symbol.asTerm) {
104+
PoundPoundValue(qual)
105+
} else if (Any_comparisons contains tree.fun.symbol.asTerm) {
106+
if (tree.fun.symbol eq defn.Any_==) {
107+
Apply(Select(qual, defn.Object_equals.termRef), tree.args)
108+
} else if (tree.fun.symbol eq defn.Any_!=) {
109+
Select(Apply(Select(qual, defn.Object_equals.termRef), tree.args), defn.Boolean_!.termRef)
110+
} else unknown
111+
} /* else if (isPrimitiveValueClass(qual.tpe.typeSymbol)) {
112+
// todo: this is needed to support value classes
113+
// Rewrite 5.getClass to ScalaRunTime.anyValClass(5)
114+
global.typer.typed(gen.mkRuntimeCall(nme.anyValClass,
115+
List(qual, typer.resolveClassTag(tree.pos, qual.tpe.widen))))
116+
}*/
117+
else if (primitiveGetClassMethods.contains(tree.fun.symbol)) {
118+
// if we got here then we're trying to send a primitive getClass method to either
119+
// a) an Any, in which cage Object_getClass works because Any erases to object. Or
120+
//
121+
// b) a non-primitive, e.g. because the qualifier's type is a refinement type where one parent
122+
// of the refinement is a primitive and another is AnyRef. In that case
123+
// we get a primitive form of _getClass trying to target a boxed value
124+
// so we need replace that method name with Object_getClass to get correct behavior.
125+
// See SI-5568.
126+
Apply(Select(qual, defn.Object_getClass.termRef), Nil)
127+
} else {
128+
unknown
129+
}
130+
case _ =>
131+
unknown
132+
}
133+
ctx.log(s"$name rewrote $tree to $rewrite")
134+
rewrite
135+
}
136+
else tree
137+
}
138+
}

src/dotty/tools/dotc/transform/TypeTestsCasts.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ import typer.ErrorReporting._
1515
import ast.Trees._
1616
import Erasure.Boxing.box
1717

18-
/** This transform normalizes type tests and type casts.
18+
/** This transform normalizes type tests and type casts,
19+
* also replacing type tests with singleton argument type with refference equality check
1920
* Any remaining type tests
2021
* - use the object methods $isInstanceOf and $asInstanceOf
2122
* - have a reference type as receiver

tests/untried/pos/hashhash-overloads.scala renamed to tests/pos/hashhash-overloads.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@ object Test {
33
def g = 5f.##
44
def h = ({ 5 ; println("abc") }).##
55
def f2 = null.##
6+
def l = 3L.##
7+
def b(arg: Boolean) = arg.##
68
}

0 commit comments

Comments
 (0)