Skip to content

Commit fdcb455

Browse files
author
EnzeXing
committed
Adding semantics for calling on SafeValue and evaluating SeqLiteral
1 parent f51f91b commit fdcb455

File tree

2 files changed

+82
-30
lines changed

2 files changed

+82
-30
lines changed

compiler/src/dotty/tools/dotc/transform/init/Objects.scala

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,19 @@ class Objects(using Context @constructorOnly):
227227
case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data) extends ValueElement:
228228
def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ")"
229229

230-
/** Represents common base values like Int, String, etc.
230+
/**
231+
* Represents common base values like Int, String, etc.
232+
* Assumption: all methods calls on such values should be pure (no side effects)
231233
*/
232-
case object SafeValue extends ValueElement:
233-
val safeTypes = defn.ScalaNumericValueTypeList ++ List(defn.UnitType, defn.BooleanType, defn.StringType)
234-
def show(using Context): String = "SafeValue"
234+
case class SafeValue(tpe: Type) extends ValueElement:
235+
// tpe could be a AppliedType(java.lang.Class, T)
236+
val baseType = if tpe.isInstanceOf[AppliedType] then tpe.asInstanceOf[AppliedType].underlying else tpe
237+
assert(baseType.isInstanceOf[TypeRef] && SafeValue.safeTypes.contains(baseType), "Invalid creation of SafeValue! Type = " + tpe)
238+
val typeref = baseType.asInstanceOf[TypeRef]
239+
def show(using Context): String = "SafeValue of type " + tpe
240+
241+
object SafeValue:
242+
val safeTypes = defn.ScalaNumericValueTypeList ++ List(defn.UnitType, defn.BooleanType, defn.StringType, defn.NullType, defn.ClassClass.typeRef)
235243

236244
/**
237245
* Represents a set of values
@@ -704,7 +712,7 @@ class Objects(using Context @constructorOnly):
704712
a match
705713
case UnknownValue => UnknownValue
706714
case Package(_) => a
707-
case SafeValue => SafeValue
715+
case SafeValue(_) => a
708716
case ref: Ref => if ref.klass.isSubClass(klass) then ref else Bottom
709717
case ValueSet(values) => values.map(v => v.filterClass(klass)).join
710718
case arr: OfArray => if defn.ArrayClass.isSubClass(klass) then arr else Bottom
@@ -733,7 +741,7 @@ class Objects(using Context @constructorOnly):
733741
* @param superType The type of the super in a super call. NoType for non-super calls.
734742
* @param needResolve Whether the target of the call needs resolution?
735743
*/
736-
def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.tree.show), printer, (_: Value).show) {
744+
def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) {
737745
value.filterClass(meth.owner) match
738746
case UnknownValue =>
739747
if reportUnknown then
@@ -743,11 +751,33 @@ class Objects(using Context @constructorOnly):
743751
UnknownValue
744752

745753
case Package(packageSym) =>
746-
report.warning("[Internal error] Unexpected call on package = " + value.show + ", meth = " + meth.show + Trace.show, Trace.position)
747-
Bottom
748-
749-
case SafeValue =>
750-
SafeValue // Check return type, if not safe, try to analyze body, 1.until(2).map(i => UninitializedObject)
754+
// calls on packages are unexpected. However the typer might mistakenly
755+
// set the receiver to be a package instead of package object.
756+
// See packageObjectStringInterpolator.scala
757+
if !meth.owner.denot.isPackageObject then
758+
report.warning("[Internal error] Unexpected call on package = " + value.show + ", meth = " + meth.show + Trace.show, Trace.position)
759+
Bottom
760+
else
761+
// Method call on package object instead
762+
val packageObj = accessObject(meth.owner.moduleClass.asClass)
763+
call(packageObj, meth, args, receiver, superType, needResolve)
764+
765+
case v @ SafeValue(tpe) =>
766+
// Assume such method is pure. Check return type, only try to analyze body if return type is not safe
767+
val target = resolve(v.typeref.symbol.asClass, meth)
768+
if !target.hasSource then
769+
UnknownValue
770+
else
771+
val ddef = target.defTree.asInstanceOf[DefDef]
772+
val returnType = ddef.tpt.tpe
773+
if SafeValue.safeTypes.contains(returnType) then
774+
// since method is pure and return type is safe, no need to analyze method body
775+
SafeValue(returnType)
776+
else
777+
val cls = target.owner.enclosingClass.asClass
778+
// convert SafeType to an OfClass before analyzing method body
779+
val ref = OfClass(cls, Bottom, NoSymbol, Nil, Env.NoEnv)
780+
call(ref, meth, args, receiver, superType, needResolve)
751781

752782
case Bottom =>
753783
Bottom
@@ -774,7 +804,7 @@ class Objects(using Context @constructorOnly):
774804
Bottom
775805
else
776806
// Array.length is OK
777-
SafeValue
807+
SafeValue(defn.IntType)
778808

779809
case ref: Ref =>
780810
val isLocal = !meth.owner.isClass
@@ -795,10 +825,10 @@ class Objects(using Context @constructorOnly):
795825
arr
796826
else if target.equals(defn.Predef_classOf) then
797827
// Predef.classOf is a stub method in tasty and is replaced in backend
798-
SafeValue
828+
UnknownValue
799829
else if target.equals(defn.ClassTagModule_apply) then
800-
// ClassTag and other reflection related values are considered safe
801-
SafeValue
830+
// ClassTag and other reflection related values are not analyzed
831+
UnknownValue
802832
else if target.hasSource then
803833
val cls = target.owner.enclosingClass.asClass
804834
val ddef = target.defTree.asInstanceOf[DefDef]
@@ -886,6 +916,7 @@ class Objects(using Context @constructorOnly):
886916
Returns.installHandler(ctor)
887917
eval(ddef.rhs, ref, cls, cacheResult = true)
888918
Returns.popHandler(ctor)
919+
value
889920
}
890921
else
891922
// no source code available
@@ -912,8 +943,9 @@ class Objects(using Context @constructorOnly):
912943
else
913944
UnknownValue
914945

915-
case SafeValue =>
916-
SafeValue
946+
case v @ SafeValue(_) =>
947+
report.warning("[Internal error] Unexpected selection on safe value " + v.show + ", field = " + field.show + Trace.show, Trace.position)
948+
Bottom
917949

918950
case Package(packageSym) =>
919951
if field.isStaticObject then
@@ -997,7 +1029,7 @@ class Objects(using Context @constructorOnly):
9971029
case arr: OfArray =>
9981030
report.warning("[Internal error] unexpected tree in assignment, array = " + arr.show + " field = " + field + Trace.show, Trace.position)
9991031

1000-
case SafeValue | UnknownValue =>
1032+
case SafeValue(_) | UnknownValue =>
10011033
report.warning("Assigning to base or unknown value is forbidden. " + Trace.show, Trace.position)
10021034

10031035
case ValueSet(values) =>
@@ -1029,7 +1061,7 @@ class Objects(using Context @constructorOnly):
10291061
*/
10301062
def instantiate(outer: Value, klass: ClassSymbol, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("instantiating " + klass.show + ", outer = " + outer + ", args = " + args.map(_.value.show), printer, (_: Value).show) {
10311063
outer.filterClass(klass.owner) match
1032-
case _ : Fun | _: OfArray | SafeValue =>
1064+
case _ : Fun | _: OfArray | SafeValue(_) =>
10331065
report.warning("[Internal error] unexpected outer in instantiating a class, outer = " + outer.show + ", class = " + klass.show + ", " + Trace.show, Trace.position)
10341066
Bottom
10351067

@@ -1126,7 +1158,7 @@ class Objects(using Context @constructorOnly):
11261158
case UnknownValue =>
11271159
report.warning("Calling on unknown value. " + Trace.show, Trace.position)
11281160
Bottom
1129-
case _: ValueSet | _: Ref | _: OfArray | _: Package | SafeValue =>
1161+
case _: ValueSet | _: Ref | _: OfArray | _: Package | SafeValue(_) =>
11301162
report.warning("[Internal error] Unexpected by-name value " + value.show + ". " + Trace.show, Trace.position)
11311163
Bottom
11321164
else
@@ -1314,8 +1346,8 @@ class Objects(using Context @constructorOnly):
13141346
case _: This =>
13151347
evalType(expr.tpe, thisV, klass)
13161348

1317-
case Literal(_) =>
1318-
SafeValue
1349+
case Literal(const) =>
1350+
SafeValue(const.tpe)
13191351

13201352
case Typed(expr, tpt) =>
13211353
if tpt.tpe.hasAnnotation(defn.UncheckedAnnot) then
@@ -1390,7 +1422,12 @@ class Objects(using Context @constructorOnly):
13901422
res
13911423

13921424
case SeqLiteral(elems, elemtpt) =>
1393-
evalExprs(elems, thisV, klass).join
1425+
// Obtain the output Seq from SeqLiteral tree by calling respective wrapArrayMethod
1426+
val wrapArrayMethodName = ast.tpd.wrapArrayMethodName(elemtpt.tpe)
1427+
val meth = defn.getWrapVarargsArrayModule.requiredMethod(wrapArrayMethodName)
1428+
val module = defn.getWrapVarargsArrayModule.moduleClass.asClass
1429+
val args = evalArgs(elems.map(Arg.apply), thisV, klass)
1430+
call(ObjectRef(module), meth, args, module.typeRef, NoType)
13941431

13951432
case Inlined(call, bindings, expansion) =>
13961433
evalExprs(bindings, thisV, klass)
@@ -1601,7 +1638,7 @@ class Objects(using Context @constructorOnly):
16011638

16021639
// call .apply
16031640
val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
1604-
val applyRes = call(scrutinee, applyDenot.symbol, ArgInfo(SafeValue, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
1641+
val applyRes = call(scrutinee, applyDenot.symbol, ArgInfo(SafeValue(defn.IntType), summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
16051642

16061643
if isWildcardStarArgList(pats) then
16071644
if pats.size == 1 then
@@ -1612,7 +1649,7 @@ class Objects(using Context @constructorOnly):
16121649
else
16131650
// call .drop
16141651
val dropDenot = getMemberMethod(scrutineeType, nme.drop, dropType(elemType))
1615-
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(SafeValue, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
1652+
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(SafeValue(defn.IntType), summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
16161653
for pat <- pats.init do evalPattern(applyRes, pat)
16171654
evalPattern(dropRes, pats.last)
16181655
end if
@@ -1623,8 +1660,7 @@ class Objects(using Context @constructorOnly):
16231660
end evalSeqPatterns
16241661

16251662
def canSkipCase(remainingScrutinee: Value, catchValue: Value) =
1626-
(remainingScrutinee == Bottom && scrutinee != Bottom) ||
1627-
(catchValue == Bottom && remainingScrutinee != Bottom)
1663+
remainingScrutinee == Bottom || catchValue == Bottom
16281664

16291665
var remainingScrutinee = scrutinee
16301666
val caseResults: mutable.ArrayBuffer[Value] = mutable.ArrayBuffer()
@@ -1653,8 +1689,8 @@ class Objects(using Context @constructorOnly):
16531689
*/
16541690
def evalType(tp: Type, thisV: ThisValue, klass: ClassSymbol, elideObjectAccess: Boolean = false): Contextual[Value] = log("evaluating " + tp.show, printer, (_: Value).show) {
16551691
tp match
1656-
case _: ConstantType =>
1657-
SafeValue
1692+
case consttpe: ConstantType =>
1693+
SafeValue(consttpe.underlying)
16581694

16591695
case tmref: TermRef if tmref.prefix == NoPrefix =>
16601696
val sym = tmref.symbol
@@ -1904,7 +1940,7 @@ class Objects(using Context @constructorOnly):
19041940
resolveThis(target, ref.outerValue(klass), outerCls)
19051941
case ValueSet(values) =>
19061942
values.map(ref => resolveThis(target, ref, klass)).join
1907-
case _: Fun | _ : OfArray | _: Package | SafeValue =>
1943+
case _: Fun | _ : OfArray | _: Package | SafeValue(_) =>
19081944
report.warning("[Internal error] unexpected thisV = " + thisV + ", target = " + target.show + ", klass = " + klass.show + Trace.show, Trace.position)
19091945
Bottom
19101946
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package p
2+
package object a {
3+
val b = 10
4+
implicit class CI(s: StringContext) {
5+
def ci(args: Any*) = 10
6+
}
7+
}
8+
9+
import p.a._
10+
11+
object A:
12+
val f = b // p.a(ObjectRef(p.a)).b
13+
def foo(s: String): String = s
14+
val f1 = ci"a" // => p.a(Package(p).select(a)).CI(StringContext"a").ci()
15+
16+

0 commit comments

Comments
 (0)