Skip to content

Take expected type into account when typing a sequence argument #8669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ class TypeApplications(val self: Type) extends AnyVal {
self.derivedExprType(tp.translateParameterized(from, to))
case _ =>
if (self.derivesFrom(from)) {
def elemType(tp: Type): Type = tp match
def elemType(tp: Type): Type = tp.widenDealias match
case tp: AndOrType => tp.derivedAndOrType(elemType(tp.tp1), elemType(tp.tp2))
case _ => tp.baseType(from).argInfos.head
val arg = elemType(self)
Expand All @@ -388,18 +388,26 @@ class TypeApplications(val self: Type) extends AnyVal {
else self
}

/** If this is repeated parameter type, its underlying Seq type,
* or, if isJava is true, Array type, else the type itself.
/** If this is a repeated parameter `*T`, translate it to either `Seq[T]` or
* `Array[? <: T]` depending on the value of `toArray`.
* Additionally, if `translateWildcard` is true, a wildcard type
* will be translated to `*<?>`.
* Other types are kept as-is.
*/
def underlyingIfRepeated(isJava: Boolean)(implicit ctx: Context): Type =
if (self.isRepeatedParam) {
val seqClass = if (isJava) defn.ArrayClass else defn.SeqClass
// If `isJava` is set, then we want to turn `RepeatedParam[T]` into `Array[? <: T]`,
// since arrays aren't covariant until after erasure. See `tests/pos/i5140`.
translateParameterized(defn.RepeatedParamClass, seqClass, wildcardArg = isJava)
}
def translateFromRepeated(toArray: Boolean, translateWildcard: Boolean = false)(using Context): Type =
val seqClass = if (toArray) defn.ArrayClass else defn.SeqClass
if translateWildcard && self.isInstanceOf[WildcardType] then
seqClass.typeRef.appliedTo(WildcardType)
else if self.isRepeatedParam then
// We want `Array[? <: T]` because arrays aren't covariant until after
// erasure. See `tests/pos/i5140`.
translateParameterized(defn.RepeatedParamClass, seqClass, wildcardArg = toArray)
else self

/** Translate a `From[T]` into a `*T`. */
def translateToRepeated(from: ClassSymbol)(using Context): Type =
translateParameterized(from, defn.RepeatedParamClass)

/** If this is an encoding of a (partially) applied type, return its arguments,
* otherwise return Nil.
* Existential types in arguments are returned as TypeBounds instances.
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ object TypeErasure {
case etp => etp

def sigName(tp: Type, isJava: Boolean)(implicit ctx: Context): TypeName = {
val normTp = tp.underlyingIfRepeated(isJava)
val normTp = tp.translateFromRepeated(toArray = isJava)
val erase = erasureFn(isJava, semiEraseVCs = false, isConstructor = false, wildcardOK = true)
erase.sigName(normTp)(preErasureCtx)
}
Expand Down Expand Up @@ -448,7 +448,7 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean
val tycon = tp.tycon
if (tycon.isRef(defn.ArrayClass)) eraseArray(tp)
else if (tycon.isRef(defn.PairClass)) erasePair(tp)
else if (tp.isRepeatedParam) apply(tp.underlyingIfRepeated(isJava))
else if (tp.isRepeatedParam) apply(tp.translateFromRepeated(toArray = isJava))
else apply(tp.translucentSuperType)
case _: TermRef | _: ThisType =>
this(tp.widen)
Expand Down Expand Up @@ -540,7 +540,7 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean
// See doc comment for ElimByName for speculation how we could improve this.
else
MethodType(Nil, Nil,
eraseResult(sym.info.finalResultType.underlyingIfRepeated(isJava)))
eraseResult(sym.info.finalResultType.translateFromRepeated(toArray = isJava)))
case tp1: PolyType =>
eraseResult(tp1.resultType) match
case rt: MethodType => rt
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1607,7 +1607,7 @@ object Types {
case res => res
}
val funType = defn.FunctionOf(
formals1 mapConserve (_.underlyingIfRepeated(mt.isJavaMethod)),
formals1 mapConserve (_.translateFromRepeated(toArray = mt.isJavaMethod)),
result1, isContextual, isErased)
if (mt.isResultDependent) RefinedType(funType, nme.apply, mt)
else funType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
if (tag == ALIASsym) TypeAlias(tp1)
else if (denot.isType) checkNonCyclic(denot.symbol, tp1, reportErrors = false)
// we need the checkNonCyclic call to insert LazyRefs for F-bounded cycles
else if (!denot.is(Param)) tp1.underlyingIfRepeated(isJava = false)
else if (!denot.is(Param)) tp1.translateFromRepeated(toArray = false)
else tp1
if (denot.isConstructor) addConstructorTypeParams(denot)
if (atEnd)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/ElimRepeated.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
val resultType1 = elimRepeated(resultType)
val paramTypes1 =
if (paramTypes.nonEmpty && paramTypes.last.isRepeatedParam) {
val last = paramTypes.last.underlyingIfRepeated(tp.isJavaMethod)
val last = paramTypes.last.translateFromRepeated(toArray = tp.isJavaMethod)
paramTypes.init :+ last
}
else paramTypes
Expand Down Expand Up @@ -159,7 +159,7 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
tp.derivedLambdaType(tp.paramNames, tp.paramInfos, toJavaVarArgs(tp.resultType))
case tp: MethodType =>
val inits :+ last = tp.paramInfos
val last1 = last.underlyingIfRepeated(isJava = true)
val last1 = last.translateFromRepeated(toArray = true)
tp.derivedLambdaType(tp.paramNames, inits :+ last1, tp.resultType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
for ((formal, idx) <- methTpe.paramInfos.zipWithIndex) yield {
val elem =
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
.ensureConforms(formal.underlyingIfRepeated(isJava = false))
.ensureConforms(formal.translateFromRepeated(toArray = false))
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
}
New(classRef, elems)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ trait QuotesAndSplices {
try ref(defn.InternalQuoted_patternHole.termRef).appliedToType(tree.tpe).withSpan(tree.span)
finally {
val patType = pat.tpe.widen
val patType1 = patType.underlyingIfRepeated(isJava = false)
val patType1 = patType.translateFromRepeated(toArray = false)
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)
patBuf += pat1
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ trait TypeAssigner {
else sym.info

private def toRepeated(tree: Tree, from: ClassSymbol)(using Context): Tree =
Typed(tree, TypeTree(tree.tpe.widen.translateParameterized(from, defn.RepeatedParamClass)))
Typed(tree, TypeTree(tree.tpe.widen.translateToRepeated(from)))

def seqToRepeated(tree: Tree)(using Context): Tree = toRepeated(tree, defn.SeqClass)

Expand Down
25 changes: 12 additions & 13 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -728,20 +728,19 @@ class Typer extends Namer

if (untpd.isWildcardStarArg(tree)) {
def typedWildcardStarArgExpr = {
// A sequence argument `xs: _*` can be either a `Seq[T]` or an `Array[_ <: T]`,
// irrespective of whether the method we're calling is a Java or Scala method,
// so the expected type is the union `Seq[T] | Array[_ <: T]`.
val ptArg =
if (ctx.mode.is(Mode.QuotedPattern)) pt.underlyingIfRepeated(isJava = false)
else WildcardType
// FIXME(#8680): Quoted patterns do not support Array repeated arguments
if (ctx.mode.is(Mode.QuotedPattern)) pt.translateFromRepeated(toArray = false, translateWildcard = true)
else pt.translateFromRepeated(toArray = false, translateWildcard = true) |
pt.translateFromRepeated(toArray = true, translateWildcard = true)
val tpdExpr = typedExpr(tree.expr, ptArg)
tpdExpr.tpe.widenDealias match {
case defn.ArrayOf(_) =>
val starType = defn.ArrayType.appliedTo(WildcardType)
val exprAdapted = adapt(tpdExpr, starType)
arrayToRepeated(exprAdapted)
case _ =>
val starType = defn.SeqType.appliedTo(defn.AnyType)
val exprAdapted = adapt(tpdExpr, starType)
seqToRepeated(exprAdapted)
}
val expr1 = typedExpr(tree.expr, ptArg)
val fromCls = if expr1.tpe.derivesFrom(defn.ArrayClass) then defn.ArrayClass else defn.SeqClass
val tpt1 = TypeTree(expr1.tpe.widen.translateToRepeated(fromCls)).withSpan(tree.tpt.span)
assignType(cpy.Typed(tree)(expr1, tpt1), tpt1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

}
cases(
ifPat = ascription(TypeTree(defn.RepeatedParamType.appliedTo(pt)), isWildcard = true),
Expand Down Expand Up @@ -1158,7 +1157,7 @@ class Typer extends Namer
if (!param.tpt.isEmpty) param
else cpy.ValDef(param)(
tpt = untpd.TypeTree(
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))
inferredParamType(param, protoFormal(i)).translateFromRepeated(toArray = false)))
desugar.makeClosure(inferredParams, fnBody, resultTpt, isContextual)
}
typed(desugared, pt)
Expand Down
5 changes: 5 additions & 0 deletions tests/pos/case-signature.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// If `translateFromRepeated` translated wildcards by default, the following
// would break because of the use of wildcards in signatures.
case class Benchmark[A](params: List[A],
sqlInsert: (benchId: Long, params: A, session: Int) => Unit,
fun: List[A])
28 changes: 28 additions & 0 deletions tests/pos/sequence-argument/B_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import scala.reflect.ClassTag
import scala.language.implicitConversions

object B {
def doubleSeq[T](x: T): Seq[T] = Seq(x, x)
def doubleArray[T: ClassTag](x: T): Array[T] = Array(x, x)

def box(args: Integer*): Unit = {}
def widen(args: Long*): Unit = {}
def conv(args: Y*): Unit = {}

box(doubleSeq(1): _*)
box(doubleArray(1): _*)
Java_2.box(doubleSeq(1): _*)
Java_2.box(doubleArray(1): _*)

widen(doubleSeq(1): _*)
widen(doubleArray(1): _*)
Java_2.widen(doubleSeq(1): _*)
Java_2.widen(doubleArray(1): _*)

implicit def xToY(x: X): Y = new Y
val x: X = new X
conv(doubleSeq(x): _*)
conv(doubleArray(x): _*)
Java_2.conv(doubleSeq(x): _*)
Java_2.conv(doubleArray(x): _*)
}
5 changes: 5 additions & 0 deletions tests/pos/sequence-argument/Java_2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
public class Java_2 {
public static void box(Integer ...args) {}
public static void widen(Long... args) {}
public static void conv(Y... args) {}
}
2 changes: 2 additions & 0 deletions tests/pos/sequence-argument/XY_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class X
class Y