Skip to content

Keep track of method parameter symbols #8597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class TreeTypeMap(
val (tmap1, tparams1) = transformDefs(ddef.tparams)
val (tmap2, vparamss1) = tmap1.transformVParamss(vparamss)
val res = cpy.DefDef(ddef)(name, tparams1, vparamss1, tmap2.transform(tpt), tmap2.transform(ddef.rhs))
res.symbol.setParamssFromDefs(tparams1, vparamss1)
res.symbol.transformAnnotations {
case ann: BodyAnnotation => ann.derivedAnnotation(transform(ann.tree))
case ann => ann
Expand Down
38 changes: 30 additions & 8 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

def DefDef(sym: TermSymbol, tparams: List[TypeSymbol], vparamss: List[List[TermSymbol]],
resultType: Type, rhs: Tree)(implicit ctx: Context): DefDef =
sym.setParamss(tparams, vparamss)
ta.assignType(
untpd.DefDef(
sym.name,
Expand All @@ -223,15 +224,27 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def DefDef(sym: TermSymbol, rhsFn: List[List[Tree]] => Tree)(implicit ctx: Context): DefDef =
polyDefDef(sym, Function.const(rhsFn))

/** A DefDef with given method symbol `sym`.
* @rhsFn A function from type parameter types and term parameter references
* to the method's right-hand side.
* Parameter symbols are taken from the `rawParamss` field of `sym`, or
* are freshly generated if `rawParamss` is empty.
*/
def polyDefDef(sym: TermSymbol, rhsFn: List[Type] => List[List[Tree]] => Tree)(implicit ctx: Context): DefDef = {
val (tparams, mtp) = sym.info match {

val (tparams, existingParamss, mtp) = sym.info match {
case tp: PolyType =>
val tparams = ctx.newTypeParams(sym, tp.paramNames, EmptyFlags, tp.instantiateParamInfos(_))
(tparams, tp.instantiate(tparams map (_.typeRef)))
case tp => (Nil, tp)
val (tparams, existingParamss) = sym.rawParamss match
case tparams :: vparamss =>
assert(tparams.hasSameLengthAs(tp.paramNames) && tparams.head.isType)
(tparams.asInstanceOf[List[TypeSymbol]], vparamss)
case _ =>
(ctx.newTypeParams(sym, tp.paramNames, EmptyFlags, tp.instantiateParamInfos(_)), Nil)
(tparams, existingParamss, tp.instantiate(tparams map (_.typeRef)))
case tp => (Nil, sym.rawParamss, tp)
}

def valueParamss(tp: Type): (List[List[TermSymbol]], Type) = tp match {
def valueParamss(tp: Type, existingParamss: List[List[Symbol]]): (List[List[TermSymbol]], Type) = tp match {
case tp: MethodType =>
val isParamDependent = tp.isParamDependent
val previousParamRefs = if (isParamDependent) mutable.ListBuffer[TermRef]() else null
Expand All @@ -254,14 +267,23 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
makeSym(origInfo)
}

val params = tp.paramNames.lazyZip(tp.paramInfos).map(valueParam)
val (paramss, rtp) = valueParamss(tp.instantiate(params map (_.termRef)))
val (params, existingParamss1) =
if tp.paramInfos.isEmpty then (Nil, existingParamss)
else existingParamss match
case vparams :: existingParamss1 =>
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
(vparams.asInstanceOf[List[TermSymbol]], existingParamss1)
case _ =>
(tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil)
val (paramss, rtp) =
valueParamss(tp.instantiate(params map (_.termRef)), existingParamss1)
(params :: paramss, rtp)
case tp => (Nil, tp.widenExpr)
}
val (vparamss, rtp) = valueParamss(mtp)
val (vparamss, rtp) = valueParamss(mtp, existingParamss)
val targs = tparams map (_.typeRef)
val argss = vparamss.nestedMap(vparam => Ident(vparam.termRef))
sym.setParamss(tparams, vparamss)
DefDef(sym, tparams, vparamss, rtp, rhsFn(targs)(argss))
}

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ object Flags {
val SyntheticGivenMethod: FlagSet = Synthetic | Given | Method
val SyntheticModule: FlagSet = Synthetic | Module
val SyntheticOpaque: FlagSet = Synthetic | Opaque
val SyntheticParam: FlagSet = Synthetic | Param
val SyntheticTermParam: FlagSet = Synthetic | TermParam
val SyntheticTypeParam: FlagSet = Synthetic | TypeParam
}
61 changes: 59 additions & 2 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Scopes.Scope
import dotty.tools.io.AbstractFile
import Decorators.SymbolIteratorDecorator
import ast._
import ast.Trees.{LambdaTypeTree, TypeBoundsTree}
import ast.Trees.{LambdaTypeTree, TypeBoundsTree, ValDef, TypeDef}
import Trees.Literal
import Variances.Variance
import annotation.tailrec
Expand Down Expand Up @@ -150,6 +150,7 @@ object SymDenotations {
private var myFlags: FlagSet = adaptFlags(initFlags)
private var myPrivateWithin: Symbol = initPrivateWithin
private var myAnnotations: List[Annotation] = Nil
private var myParamss: List[List[Symbol]] = Nil

/** The owner of the symbol; overridden in NoDenotation */
def owner: Symbol = maybeOwner
Expand Down Expand Up @@ -372,6 +373,58 @@ object SymDenotations {
case Nil => Nil
}

/** If this is a method, the parameter symbols, by section.
* Both type and value parameters are included. Empty sections are skipped.
*/
final def rawParamss: List[List[Symbol]] = myParamss
final def rawParamss_=(pss: List[List[Symbol]]): Unit =
myParamss = pss

final def setParamss(tparams: List[Symbol], vparamss: List[List[Symbol]])(using Context): Unit =
rawParamss = (if tparams.isEmpty then vparamss else tparams :: vparamss)
.filterConserve(!_.isEmpty)

final def setParamssFromDefs(tparams: List[TypeDef[?]], vparamss: List[List[ValDef[?]]])(using Context): Unit =
setParamss(tparams.map(_.symbol), vparamss.map(_.map(_.symbol)))

/** A pair consistsing of type paremeter symbols and value parameter symbol lists
* of this method definition, or (Nil, Nil) for other symbols.
* Makes use of `rawParamss` when present, or constructs fresh parameter symbols otherwise.
* This method can be allocation-heavy.
*/
final def paramSymss(using ctx: Context): (List[TypeSymbol], List[List[TermSymbol]]) =

def recurWithParamss(info: Type, paramss: List[List[Symbol]]): List[List[Symbol]] =
info match
case info: LambdaType =>
if info.paramNames.isEmpty then Nil :: recurWithParamss(info.resType, paramss)
else paramss.head :: recurWithParamss(info.resType, paramss.tail)
case _ =>
Nil

def recurWithoutParamss(info: Type): List[List[Symbol]] = info match
case info: LambdaType =>
val params = info.paramNames.lazyZip(info.paramInfos).map((pname, ptype) =>
ctx.newSymbol(symbol, pname, SyntheticParam, ptype))
val prefs = params.map(_.namedType)
for param <- params do
param.info = param.info.substParams(info, prefs)
params :: recurWithoutParamss(info.instantiate(prefs))
case _ =>
Nil

try
val allParamss =
if rawParamss.isEmpty then recurWithoutParamss(info)
else recurWithParamss(info, rawParamss)
info match
case info: PolyType => (allParamss.head, allParamss.tail).asInstanceOf
case _ => (Nil, allParamss).asInstanceOf
catch case NonFatal(ex) =>
println(i"paramSymss failure for $symbol, $info, $rawParamss")
throw ex
end paramSymss

/** The denotation is completed: info is not a lazy type and attributes have defined values */
final def isCompleted: Boolean = !myInfo.isInstanceOf[LazyType]

Expand Down Expand Up @@ -1450,16 +1503,20 @@ object SymDenotations {
initFlags: FlagSet = UndefinedFlags,
info: Type = null,
privateWithin: Symbol = null,
annotations: List[Annotation] = null)(implicit ctx: Context): SymDenotation = {
annotations: List[Annotation] = null,
rawParamss: List[List[Symbol]] = null)(
using ctx: Context): SymDenotation = {
// simulate default parameters, while also passing implicit context ctx to the default values
val initFlags1 = (if (initFlags != UndefinedFlags) initFlags else this.flags)
val info1 = if (info != null) info else this.info
if (ctx.isAfterTyper && changedClassParents(info, info1, completersMatter = false))
assert(ctx.phase.changesParents, i"undeclared parent change at ${ctx.phase} for $this, was: $info, now: $info1")
val privateWithin1 = if (privateWithin != null) privateWithin else this.privateWithin
val annotations1 = if (annotations != null) annotations else this.annotations
val rawParamss1 = if rawParamss != null then rawParamss else this.rawParamss
val d = ctx.SymDenotation(symbol, owner, name, initFlags1, info1, privateWithin1)
d.annotations = annotations1
d.rawParamss = rawParamss1
d.registeredCompanion = registeredCompanion
d
}
Expand Down
7 changes: 4 additions & 3 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,10 @@ class TreeUnpickler(reader: TastyReader,
ta.assignType(untpd.ValDef(sym.name.asTermName, tpt, readRhs(localCtx)), sym)

def DefDef(tparams: List[TypeDef], vparamss: List[List[ValDef]], tpt: Tree) =
ta.assignType(
untpd.DefDef(sym.name.asTermName, tparams, vparamss, tpt, readRhs(localCtx)),
sym)
sym.setParamssFromDefs(tparams, vparamss)
ta.assignType(
untpd.DefDef(sym.name.asTermName, tparams, vparamss, tpt, readRhs(localCtx)),
sym)

def TypeDef(rhs: Tree) =
ta.assignType(untpd.TypeDef(sym.name.asTypeName, rhs), sym)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
/** A map from symbols to their associated `decls` scopes */
private val symScopes = mutable.AnyRefMap[Symbol, Scope]()

/** A dummy buffer to pass to `readType` when no `rawParamss` are collected */
private val throwAwayBuffer = new ListBuffer[List[Symbol]]

protected def errorBadSignature(msg: String, original: Option[RuntimeException] = None)(implicit ctx: Context): Nothing = {
val ex = new BadSignature(
i"""error reading Scala signature of $classRoot from $source:
Expand Down Expand Up @@ -574,7 +577,11 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
if (isSymbolRef(inforef)) inforef = readNat()

// println("reading type for " + denot) // !!! DEBUG
val tp = at(inforef, () => readType()(ctx))
val paramssBuf =
if denot.is(Method) then new ListBuffer[List[Symbol]]
else throwAwayBuffer
val tp = at(inforef, () => readType(paramssBuf)(ctx))
if denot.is(Method) then denot.rawParamss = paramssBuf.toList

denot match {
case denot: ClassDenotation if !isRefinementClass(denot.symbol) =>
Expand Down Expand Up @@ -721,7 +728,7 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
* the flag say that a type of kind * is expected, so that PolyType(tps, restpe) can be disambiguated to PolyType(tps, NullaryMethodType(restpe))
* (if restpe is not a ClassInfoType, a MethodType or a NullaryMethodType, which leaves TypeRef/SingletonType -- the latter would make the polytype a type constructor)
*/
protected def readType()(implicit ctx: Context): Type = {
protected def readType(paramssBuf: ListBuffer[List[Symbol]])(implicit ctx: Context): Type = {
val tag = readByte()
val end = readNat() + readIndex
(tag: @switch) match {
Expand Down Expand Up @@ -790,14 +797,18 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
case METHODtpe | IMPLICITMETHODtpe =>
val restpe = readTypeRef()
val params = until(end, () => readSymbolRef())
if params.nonEmpty then paramssBuf += params
val maker = MethodType.companion(
isImplicit = tag == IMPLICITMETHODtpe || params.nonEmpty && params.head.is(Implicit))
maker.fromSymbols(params, restpe)
case POLYtpe =>
val restpe = readTypeRef()
val typeParams = until(end, () => readSymbolRef())
if (typeParams.nonEmpty) TempPolyType(typeParams.asInstanceOf[List[TypeSymbol]], restpe.widenExpr)
else ExprType(restpe)
if typeParams.nonEmpty then
paramssBuf += typeParams
TempPolyType(typeParams.asInstanceOf[List[TypeSymbol]], restpe.widenExpr)
else
ExprType(restpe)
case EXISTENTIALtpe =>
val restpe = readTypeRef()
val boundSyms = until(end, () => readSymbolRef())
Expand Down Expand Up @@ -881,7 +892,7 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
at(readNat(), () => readDisambiguatedSymbol(p)())

protected def readNameRef()(implicit ctx: Context): Name = at(readNat(), () => readName())
protected def readTypeRef()(implicit ctx: Context): Type = at(readNat(), () => readType()) // after the NMT_TRANSITION period, we can leave off the () => ... ()
protected def readTypeRef()(implicit ctx: Context): Type = at(readNat(), () => readType(throwAwayBuffer)) // after the NMT_TRANSITION period, we can leave off the () => ... ()
protected def readConstantRef()(implicit ctx: Context): Constant = at(readNat(), () => readConstant())

protected def readTypeNameRef()(implicit ctx: Context): TypeName = readNameRef().toTypeName
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/reporting/Message.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ abstract class Message(val errorId: ErrorMessageID) { self =>
* they look weird and are normally follow-up errors to something that was
* diagnosed before.
*/
def isNonSensical: Boolean = { message; myIsNonSensical }
def isNonSensical: Boolean = { message; myIsNonSensical }

/** The implicit `Context` in messages is a large thing that we don't want
* persisted. This method gets around that by duplicating the message,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
override def transformDefDef(tree: DefDef)(implicit ctx: Context): DefDef = {
val sym = tree.symbol
if (sym.isConstructor && sym.owner.derivesFromJavaEnum)
cpy.DefDef(tree)(
val tree1 = cpy.DefDef(tree)(
vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, Param)))
sym.setParamssFromDefs(tree1.tparams, tree1.vparamss)
tree1
else if (sym.name == nme.DOLLAR_NEW && sym.owner.linkedClass.derivesFromJavaEnum) {
val Block((tdef @ TypeDef(tpnme.ANON_CLASS, templ: Template)) :: Nil, call) = tree.rhs
val args = tree.vparamss.last.takeRight(2).map(param => ref(param.symbol)).reverse
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,14 @@ class TreeChecker extends Phase with SymTransformer {
}

override def typedDefDef(ddef: untpd.DefDef, sym: Symbol)(implicit ctx: Context): Tree =
def defParamss =
(ddef.tparams :: ddef.vparamss).filter(!_.isEmpty).map(_.map(_.symbol))
def layout(symss: List[List[Symbol]]): String =
symss.map(syms => i"($syms%, %)").mkString
assert(ctx.erasedTypes || sym.rawParamss == defParamss,
i"""param mismatch for ${sym.showLocated}:
|defined in tree = ${layout(defParamss)}
|stored in symbol = ${layout(sym.rawParamss)}""")
withDefinedSyms(ddef.tparams) {
withDefinedSyms(ddef.vparamss.flatten) {
if (!sym.isClassConstructor && !(sym.name eq nme.STATIC_CONSTRUCTOR))
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,7 @@ class Namer { typer: Typer =>
vparamss foreach completeParams
def typeParams = tparams map symbolOfTree
val termParamss = ctx.normalizeIfConstructor(vparamss.nestedMap(symbolOfTree), isConstructor)
sym.setParamss(typeParams, termParamss)
def wrapMethType(restpe: Type): Type = {
instantiateDependent(restpe, typeParams, termParamss)
ctx.methodType(tparams map symbolOfTree, termParamss, restpe, isJava = ddef.mods.is(JavaDefined))
Expand Down