Skip to content

Commit 4356d70

Browse files
authored
Merge pull request #8669 from dotty-staging/repeated-expected
Take expected type into account when typing a sequence argument
2 parents f26808e + 03d5d9b commit 4356d70

File tree

14 files changed

+81
-34
lines changed

14 files changed

+81
-34
lines changed

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ class TypeApplications(val self: Type) extends AnyVal {
378378
self.derivedExprType(tp.translateParameterized(from, to))
379379
case _ =>
380380
if (self.derivesFrom(from)) {
381-
def elemType(tp: Type): Type = tp match
381+
def elemType(tp: Type): Type = tp.widenDealias match
382382
case tp: AndOrType => tp.derivedAndOrType(elemType(tp.tp1), elemType(tp.tp2))
383383
case _ => tp.baseType(from).argInfos.head
384384
val arg = elemType(self)
@@ -388,18 +388,26 @@ class TypeApplications(val self: Type) extends AnyVal {
388388
else self
389389
}
390390

391-
/** If this is repeated parameter type, its underlying Seq type,
392-
* or, if isJava is true, Array type, else the type itself.
391+
/** If this is a repeated parameter `*T`, translate it to either `Seq[T]` or
392+
* `Array[? <: T]` depending on the value of `toArray`.
393+
* Additionally, if `translateWildcard` is true, a wildcard type
394+
* will be translated to `*<?>`.
395+
* Other types are kept as-is.
393396
*/
394-
def underlyingIfRepeated(isJava: Boolean)(implicit ctx: Context): Type =
395-
if (self.isRepeatedParam) {
396-
val seqClass = if (isJava) defn.ArrayClass else defn.SeqClass
397-
// If `isJava` is set, then we want to turn `RepeatedParam[T]` into `Array[? <: T]`,
398-
// since arrays aren't covariant until after erasure. See `tests/pos/i5140`.
399-
translateParameterized(defn.RepeatedParamClass, seqClass, wildcardArg = isJava)
400-
}
397+
def translateFromRepeated(toArray: Boolean, translateWildcard: Boolean = false)(using Context): Type =
398+
val seqClass = if (toArray) defn.ArrayClass else defn.SeqClass
399+
if translateWildcard && self.isInstanceOf[WildcardType] then
400+
seqClass.typeRef.appliedTo(WildcardType)
401+
else if self.isRepeatedParam then
402+
// We want `Array[? <: T]` because arrays aren't covariant until after
403+
// erasure. See `tests/pos/i5140`.
404+
translateParameterized(defn.RepeatedParamClass, seqClass, wildcardArg = toArray)
401405
else self
402406

407+
/** Translate a `From[T]` into a `*T`. */
408+
def translateToRepeated(from: ClassSymbol)(using Context): Type =
409+
translateParameterized(from, defn.RepeatedParamClass)
410+
403411
/** If this is an encoding of a (partially) applied type, return its arguments,
404412
* otherwise return Nil.
405413
* Existential types in arguments are returned as TypeBounds instances.

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ object TypeErasure {
155155
case etp => etp
156156

157157
def sigName(tp: Type, isJava: Boolean)(implicit ctx: Context): TypeName = {
158-
val normTp = tp.underlyingIfRepeated(isJava)
158+
val normTp = tp.translateFromRepeated(toArray = isJava)
159159
val erase = erasureFn(isJava, semiEraseVCs = false, isConstructor = false, wildcardOK = true)
160160
erase.sigName(normTp)(preErasureCtx)
161161
}
@@ -448,7 +448,7 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean
448448
val tycon = tp.tycon
449449
if (tycon.isRef(defn.ArrayClass)) eraseArray(tp)
450450
else if (tycon.isRef(defn.PairClass)) erasePair(tp)
451-
else if (tp.isRepeatedParam) apply(tp.underlyingIfRepeated(isJava))
451+
else if (tp.isRepeatedParam) apply(tp.translateFromRepeated(toArray = isJava))
452452
else apply(tp.translucentSuperType)
453453
case _: TermRef | _: ThisType =>
454454
this(tp.widen)
@@ -540,7 +540,7 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean
540540
// See doc comment for ElimByName for speculation how we could improve this.
541541
else
542542
MethodType(Nil, Nil,
543-
eraseResult(sym.info.finalResultType.underlyingIfRepeated(isJava)))
543+
eraseResult(sym.info.finalResultType.translateFromRepeated(toArray = isJava)))
544544
case tp1: PolyType =>
545545
eraseResult(tp1.resultType) match
546546
case rt: MethodType => rt

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,7 @@ object Types {
16071607
case res => res
16081608
}
16091609
val funType = defn.FunctionOf(
1610-
formals1 mapConserve (_.underlyingIfRepeated(mt.isJavaMethod)),
1610+
formals1 mapConserve (_.translateFromRepeated(toArray = mt.isJavaMethod)),
16111611
result1, isContextual, isErased)
16121612
if (mt.isResultDependent) RefinedType(funType, nme.apply, mt)
16131613
else funType

compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
593593
if (tag == ALIASsym) TypeAlias(tp1)
594594
else if (denot.isType) checkNonCyclic(denot.symbol, tp1, reportErrors = false)
595595
// we need the checkNonCyclic call to insert LazyRefs for F-bounded cycles
596-
else if (!denot.is(Param)) tp1.underlyingIfRepeated(isJava = false)
596+
else if (!denot.is(Param)) tp1.translateFromRepeated(toArray = false)
597597
else tp1
598598
if (denot.isConstructor) addConstructorTypeParams(denot)
599599
if (atEnd)

compiler/src/dotty/tools/dotc/transform/ElimRepeated.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
5252
val resultType1 = elimRepeated(resultType)
5353
val paramTypes1 =
5454
if (paramTypes.nonEmpty && paramTypes.last.isRepeatedParam) {
55-
val last = paramTypes.last.underlyingIfRepeated(tp.isJavaMethod)
55+
val last = paramTypes.last.translateFromRepeated(toArray = tp.isJavaMethod)
5656
paramTypes.init :+ last
5757
}
5858
else paramTypes
@@ -159,7 +159,7 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
159159
tp.derivedLambdaType(tp.paramNames, tp.paramInfos, toJavaVarArgs(tp.resultType))
160160
case tp: MethodType =>
161161
val inits :+ last = tp.paramInfos
162-
val last1 = last.underlyingIfRepeated(isJava = true)
162+
val last1 = last.translateFromRepeated(toArray = true)
163163
tp.derivedLambdaType(tp.paramNames, inits :+ last1, tp.resultType)
164164
}
165165
}

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
420420
for ((formal, idx) <- methTpe.paramInfos.zipWithIndex) yield {
421421
val elem =
422422
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
423-
.ensureConforms(formal.underlyingIfRepeated(isJava = false))
423+
.ensureConforms(formal.translateFromRepeated(toArray = false))
424424
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
425425
}
426426
New(classRef, elems)

compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ trait QuotesAndSplices {
207207
try ref(defn.InternalQuoted_patternHole.termRef).appliedToType(tree.tpe).withSpan(tree.span)
208208
finally {
209209
val patType = pat.tpe.widen
210-
val patType1 = patType.underlyingIfRepeated(isJava = false)
210+
val patType1 = patType.translateFromRepeated(toArray = false)
211211
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)
212212
patBuf += pat1
213213
}

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ trait TypeAssigner {
166166
else sym.info
167167

168168
private def toRepeated(tree: Tree, from: ClassSymbol)(using Context): Tree =
169-
Typed(tree, TypeTree(tree.tpe.widen.translateParameterized(from, defn.RepeatedParamClass)))
169+
Typed(tree, TypeTree(tree.tpe.widen.translateToRepeated(from)))
170170

171171
def seqToRepeated(tree: Tree)(using Context): Tree = toRepeated(tree, defn.SeqClass)
172172

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -728,20 +728,19 @@ class Typer extends Namer
728728

729729
if (untpd.isWildcardStarArg(tree)) {
730730
def typedWildcardStarArgExpr = {
731+
// A sequence argument `xs: _*` can be either a `Seq[T]` or an `Array[_ <: T]`,
732+
// irrespective of whether the method we're calling is a Java or Scala method,
733+
// so the expected type is the union `Seq[T] | Array[_ <: T]`.
731734
val ptArg =
732-
if (ctx.mode.is(Mode.QuotedPattern)) pt.underlyingIfRepeated(isJava = false)
733-
else WildcardType
735+
// FIXME(#8680): Quoted patterns do not support Array repeated arguments
736+
if (ctx.mode.is(Mode.QuotedPattern)) pt.translateFromRepeated(toArray = false, translateWildcard = true)
737+
else pt.translateFromRepeated(toArray = false, translateWildcard = true) |
738+
pt.translateFromRepeated(toArray = true, translateWildcard = true)
734739
val tpdExpr = typedExpr(tree.expr, ptArg)
735-
tpdExpr.tpe.widenDealias match {
736-
case defn.ArrayOf(_) =>
737-
val starType = defn.ArrayType.appliedTo(WildcardType)
738-
val exprAdapted = adapt(tpdExpr, starType)
739-
arrayToRepeated(exprAdapted)
740-
case _ =>
741-
val starType = defn.SeqType.appliedTo(defn.AnyType)
742-
val exprAdapted = adapt(tpdExpr, starType)
743-
seqToRepeated(exprAdapted)
744-
}
740+
val expr1 = typedExpr(tree.expr, ptArg)
741+
val fromCls = if expr1.tpe.derivesFrom(defn.ArrayClass) then defn.ArrayClass else defn.SeqClass
742+
val tpt1 = TypeTree(expr1.tpe.widen.translateToRepeated(fromCls)).withSpan(tree.tpt.span)
743+
assignType(cpy.Typed(tree)(expr1, tpt1), tpt1)
745744
}
746745
cases(
747746
ifPat = ascription(TypeTree(defn.RepeatedParamType.appliedTo(pt)), isWildcard = true),
@@ -1158,7 +1157,7 @@ class Typer extends Namer
11581157
if (!param.tpt.isEmpty) param
11591158
else cpy.ValDef(param)(
11601159
tpt = untpd.TypeTree(
1161-
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))
1160+
inferredParamType(param, protoFormal(i)).translateFromRepeated(toArray = false)))
11621161
desugar.makeClosure(inferredParams, fnBody, resultTpt, isContextual)
11631162
}
11641163
typed(desugared, pt)

tests/pos/case-signature.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// If `translateFromRepeated` translated wildcards by default, the following
2+
// would break because of the use of wildcards in signatures.
3+
case class Benchmark[A](params: List[A],
4+
sqlInsert: (benchId: Long, params: A, session: Int) => Unit,
5+
fun: List[A])

tests/pos/sequence-argument/B_2.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import scala.reflect.ClassTag
2+
import scala.language.implicitConversions
3+
4+
object B {
5+
def doubleSeq[T](x: T): Seq[T] = Seq(x, x)
6+
def doubleArray[T: ClassTag](x: T): Array[T] = Array(x, x)
7+
8+
def box(args: Integer*): Unit = {}
9+
def widen(args: Long*): Unit = {}
10+
def conv(args: Y*): Unit = {}
11+
12+
box(doubleSeq(1): _*)
13+
box(doubleArray(1): _*)
14+
Java_2.box(doubleSeq(1): _*)
15+
Java_2.box(doubleArray(1): _*)
16+
17+
widen(doubleSeq(1): _*)
18+
widen(doubleArray(1): _*)
19+
Java_2.widen(doubleSeq(1): _*)
20+
Java_2.widen(doubleArray(1): _*)
21+
22+
implicit def xToY(x: X): Y = new Y
23+
val x: X = new X
24+
conv(doubleSeq(x): _*)
25+
conv(doubleArray(x): _*)
26+
Java_2.conv(doubleSeq(x): _*)
27+
Java_2.conv(doubleArray(x): _*)
28+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
public class Java_2 {
2+
public static void box(Integer ...args) {}
3+
public static void widen(Long... args) {}
4+
public static void conv(Y... args) {}
5+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class X
2+
class Y

0 commit comments

Comments
 (0)