From a93c5f3f6cbcb8131e64b0dcdd23582d646b050c Mon Sep 17 00:00:00 2001 From: Fengyun Liu Date: Mon, 4 Dec 2023 21:49:30 +0100 Subject: [PATCH 1/2] Treat new Array(0) as immutable An array of size 0 is immutable, thus we can safely abstract them with the bottom value. For the rules to be simple and understandable, we usually want to avoid such fine-tuning. However, given that we expect such code patterns to be rare and we want to avoid changes in the standard library, we fine-tune the analysis as a compromise. --- .../tools/dotc/transform/init/Objects.scala | 31 +++++++++++-------- tests/init-global/pos/array-size-zero.scala | 10 ++++++ 2 files changed, 28 insertions(+), 13 deletions(-) create mode 100644 tests/init-global/pos/array-size-zero.scala diff --git a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala index ed8d07e90ce5..ac29a7c797a5 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala @@ -565,7 +565,7 @@ object Objects: // --------------------------- domain operations ----------------------------- - type ArgInfo = TraceValue[Value] + case class ArgInfo(value: Value, trace: Trace, tree: Tree) extension (a: Value) def join(b: Value): Value = @@ -875,7 +875,7 @@ object Objects: * @param ctor The symbol of the target constructor. * @param args The arguments passsed to the constructor. */ - def instantiate(outer: Value, klass: ClassSymbol, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("instantiating " + klass.show + ", outer = " + outer + ", args = " + args.map(_.value.show), printer, (_: Value).show) { + def instantiate(outer: Value, klass: ClassSymbol, ctor: Symbol, args: List[ArgInfo], inKlass: ClassSymbol): Contextual[Value] = log("instantiating " + klass.show + ", outer = " + outer + ", args = " + args.map(_.value.show), printer, (_: Value).show) { outer match case _ : Fun | _: OfArray => @@ -884,9 +884,14 @@ object Objects: case outer: (Ref | Cold.type | Bottom.type) => if klass == defn.ArrayClass then - val arr = OfArray(State.currentObject, summon[Regions.Data]) - Heap.writeJoin(arr.addr, Bottom) - arr + args.head.tree.tpe match + case ConstantType(Constants.Constant(0)) => + // new Array(0) + Bottom + case _ => + val arr = OfArray(State.currentObject, summon[Regions.Data]) + Heap.writeJoin(arr.addr, Bottom) + arr else // Widen the outer to finitize the domain. Arguments already widened in `evalArgs`. val (outerWidened, envWidened) = @@ -909,7 +914,7 @@ object Objects: instance case ValueSet(values) => - values.map(ref => instantiate(ref, klass, ctor, args)).join + values.map(ref => instantiate(ref, klass, ctor, args, inKlass)).join } /** Handle local variable definition, `val x = e` or `var x = e`. @@ -1083,7 +1088,7 @@ object Objects: val cls = tref.classSymbol.asClass withTrace(trace2) { val outer = outerValue(tref, thisV, klass) - instantiate(outer, cls, ctor, args) + instantiate(outer, cls, ctor, args, klass) } case Apply(ref, arg :: Nil) if ref.symbol == defn.InitRegionMethod => @@ -1328,7 +1333,7 @@ object Objects: case _ => List() val implicitArgsAfterScrutinee = evalArgs(implicits.map(Arg.apply), thisV, klass) - val args = implicitArgsBeforeScrutinee(fun) ++ (TraceValue(scrutinee, summon[Trace]) :: implicitArgsAfterScrutinee) + val args = implicitArgsBeforeScrutinee(fun) ++ (ArgInfo(scrutinee, summon[Trace], EmptyTree) :: implicitArgsAfterScrutinee) val unapplyRes = call(receiver, funRef.symbol, args, funRef.prefix, superType = NoType, needResolve = true) if fun.symbol.name == nme.unapplySeq then @@ -1425,7 +1430,7 @@ object Objects: // call .lengthCompare or .length val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType) if lengthCompareDenot.exists then - call(scrutinee, lengthCompareDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true) + call(scrutinee, lengthCompareDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true) else val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType) call(scrutinee, lengthDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true) @@ -1433,7 +1438,7 @@ object Objects: // call .apply val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType)) - val applyRes = call(scrutinee, applyDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true) + val applyRes = call(scrutinee, applyDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true) if isWildcardStarArgList(pats) then if pats.size == 1 then @@ -1444,7 +1449,7 @@ object Objects: else // call .drop val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType)) - val dropRes = call(scrutinee, dropDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true) + val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true) for pat <- pats.init do evalPattern(applyRes, pat) evalPattern(dropRes, pats.last) end if @@ -1546,7 +1551,7 @@ object Objects: case _ => res.widen(1) - argInfos += TraceValue(widened, trace.add(arg.tree)) + argInfos += ArgInfo(widened, trace.add(arg.tree), arg.tree) } argInfos.toList @@ -1644,7 +1649,7 @@ object Objects: // The parameter check of traits comes late in the mixin phase. // To avoid crash we supply hot values for erroneous parent calls. // See tests/neg/i16438.scala. - val args: List[ArgInfo] = ctor.info.paramInfoss.flatten.map(_ => new ArgInfo(Bottom, Trace.empty)) + val args: List[ArgInfo] = ctor.info.paramInfoss.flatten.map(_ => new ArgInfo(Bottom, Trace.empty, EmptyTree)) extendTrace(superParent) { superCall(tref, ctor, args, tasks) } diff --git a/tests/init-global/pos/array-size-zero.scala b/tests/init-global/pos/array-size-zero.scala new file mode 100644 index 000000000000..a1a2fc578ad7 --- /dev/null +++ b/tests/init-global/pos/array-size-zero.scala @@ -0,0 +1,10 @@ +object A: + val emptyArray = new Array(0) + +object B: + def build(data: Int*) = + if data.size == 0 then A.emptyArray else Array(data) + + val arr = build(5, 6) + val first = arr(0) + From 81db8cd60a98f1781f67e751492031dfd6e10eaa Mon Sep 17 00:00:00 2001 From: Fengyun Liu Date: Thu, 14 Dec 2023 22:42:16 +0100 Subject: [PATCH 2/2] Address review: Remove unused parameter --- compiler/src/dotty/tools/dotc/transform/init/Objects.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala index ac29a7c797a5..763b71619de8 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala @@ -875,7 +875,7 @@ object Objects: * @param ctor The symbol of the target constructor. * @param args The arguments passsed to the constructor. */ - def instantiate(outer: Value, klass: ClassSymbol, ctor: Symbol, args: List[ArgInfo], inKlass: ClassSymbol): Contextual[Value] = log("instantiating " + klass.show + ", outer = " + outer + ", args = " + args.map(_.value.show), printer, (_: Value).show) { + def instantiate(outer: Value, klass: ClassSymbol, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("instantiating " + klass.show + ", outer = " + outer + ", args = " + args.map(_.value.show), printer, (_: Value).show) { outer match case _ : Fun | _: OfArray => @@ -914,7 +914,7 @@ object Objects: instance case ValueSet(values) => - values.map(ref => instantiate(ref, klass, ctor, args, inKlass)).join + values.map(ref => instantiate(ref, klass, ctor, args)).join } /** Handle local variable definition, `val x = e` or `var x = e`. @@ -1088,7 +1088,7 @@ object Objects: val cls = tref.classSymbol.asClass withTrace(trace2) { val outer = outerValue(tref, thisV, klass) - instantiate(outer, cls, ctor, args, klass) + instantiate(outer, cls, ctor, args) } case Apply(ref, arg :: Nil) if ref.symbol == defn.InitRegionMethod =>