Skip to content

Change default params representation #8637

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package ast
import core._
import util.Spans._, Types._, Contexts._, Constants._, Names._, NameOps._, Flags._
import Symbols._, StdNames._, Trees._
import Decorators._, transform.SymUtils._
import Decorators.{given _}, transform.SymUtils._
import NameKinds.{UniqueName, EvidenceParamName, DefaultGetterName}
import typer.{FrontEnd, Namer}
import util.{Property, SourceFile, SourcePosition}
Expand Down Expand Up @@ -268,16 +268,18 @@ object desugar {
Nil
}

def normalizedVparamss = meth1.vparamss.map(_.map(vparam =>
cpy.ValDef(vparam)(rhs = EmptyTree)))
def normalizedVparamss = meth1.vparamss.nestedMapConserve(vparam =>
if vparam.rhs.isEmpty then vparam
else cpy.ValDef(vparam)(rhs = EmptyTree).withMods(vparam.mods | HasDefault)
)

def defaultGetters(vparamss: List[List[ValDef]], n: Int): List[DefDef] = vparamss match {
case (vparam :: vparams) :: vparamss1 =>
def defaultGetter: DefDef =
DefDef(
name = DefaultGetterName(methName, n),
tparams = meth.tparams.map(tparam => dropContextBounds(toDefParam(tparam, keepAnnotations = true))),
vparamss = takeUpTo(normalizedVparamss.nestedMap(toDefParam(_, keepAnnotations = true)), n),
vparamss = takeUpTo(normalizedVparamss.nestedMap(toDefParam(_, keepAnnotations = true, keepDefault = false)), n),
tpt = TypeTree(),
rhs = vparam.rhs
)
Expand All @@ -294,7 +296,6 @@ object desugar {
if (defGetters.isEmpty) meth1
else {
val meth2 = cpy.DefDef(meth1)(vparamss = normalizedVparamss)
.withMods(meth1.mods | DefaultParameterized)
Thicket(meth2 :: defGetters)
}
}
Expand Down Expand Up @@ -386,10 +387,11 @@ object desugar {
if (!keepAnnotations) mods = mods.withAnnotations(Nil)
tparam.withMods(mods & EmptyFlags | Param)
}
private def toDefParam(vparam: ValDef, keepAnnotations: Boolean): ValDef = {
private def toDefParam(vparam: ValDef, keepAnnotations: Boolean, keepDefault: Boolean): ValDef = {
var mods = vparam.rawMods
if (!keepAnnotations) mods = mods.withAnnotations(Nil)
vparam.withMods(mods & (GivenOrImplicit | Erased) | Param)
val hasDefault = if keepDefault then HasDefault else EmptyFlags
vparam.withMods(mods & (GivenOrImplicit | Erased | hasDefault) | Param)
}

/** The expansion of a class definition. See inline comments for what is involved */
Expand Down Expand Up @@ -463,7 +465,7 @@ object desugar {
ctx.error(CaseClassMissingNonImplicitParamList(cdef), namePos)
ListOfNil
}
else originalVparamss.nestedMap(toDefParam(_, keepAnnotations = false))
else originalVparamss.nestedMap(toDefParam(_, keepAnnotations = false, keepDefault = true))
val constr = cpy.DefDef(constr1)(tparams = constrTparams, vparamss = constrVparamss)

val (normalizedBody, enumCases, enumCompanionRef) = {
Expand All @@ -475,7 +477,7 @@ object desugar {
defDef(
addEvidenceParams(
cpy.DefDef(ddef)(tparams = constrTparams ++ ddef.tparams),
evidenceParams(constr1).map(toDefParam(_, keepAnnotations = false)))))
evidenceParams(constr1).map(toDefParam(_, keepAnnotations = false, keepDefault = false)))))
case stat =>
stat
}
Expand All @@ -501,16 +503,11 @@ object desugar {

// Annotations are dropped from the constructor parameters but should be
// preserved in all derived parameters.
val derivedTparams = {
val impliedTparamsIt = impliedTparams.iterator
constrTparams.map(tparam => derivedTypeParam(tparam)
.withAnnotations(impliedTparamsIt.next().mods.annotations))
}
val derivedVparamss = {
val constrVparamsIt = constrVparamss.iterator.flatten
constrVparamss.nestedMap(vparam => derivedTermParam(vparam)
.withAnnotations(constrVparamsIt.next().mods.annotations))
}
val derivedTparams =
constrTparams.zipWithConserve(impliedTparams)((tparam, impliedParam) =>
derivedTypeParam(tparam).withAnnotations(impliedParam.mods.annotations))
val derivedVparamss =
constrVparamss.nestedMap(vparam => derivedTermParam(vparam))

val arity = constrVparamss.head.length

Expand Down Expand Up @@ -712,13 +709,16 @@ object desugar {
val applyMeths =
if (mods.is(Abstract)) Nil
else {
val copiedFlagsMask = DefaultParameterized | (copiedAccessFlags & Private)
val copiedFlagsMask = copiedAccessFlags & Private
val appMods = {
val mods = Modifiers(Synthetic | constr1.mods.flags & copiedFlagsMask)
if (restrictedAccess) mods.withPrivateWithin(constr1.mods.privateWithin)
else mods
}
val app = DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, widenedCreatorExpr)
val appParamss =
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
.withMods(appMods)
app :: widenDefs
}
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class TreeTypeMap(
val (tmap1, tparams1) = transformDefs(ddef.tparams)
val (tmap2, vparamss1) = tmap1.transformVParamss(vparamss)
val res = cpy.DefDef(ddef)(name, tparams1, vparamss1, tmap2.transform(tpt), tmap2.transform(ddef.rhs))
res.symbol.setParamssFromDefs(tparams1, vparamss1)
res.symbol.transformAnnotations {
case ann: BodyAnnotation => ann.derivedAnnotation(transform(ann.tree))
case ann => ann
Expand Down
40 changes: 31 additions & 9 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import transform.TypeUtils._
import core._
import util.Spans._, Types._, Contexts._, Constants._, Names._, Flags._, NameOps._
import Symbols._, StdNames._, Annotations._, Trees._, Symbols._
import Decorators._, DenotTransformers._
import Decorators.{given _}, DenotTransformers._
import collection.{immutable, mutable}
import util.{Property, SourceFile, NoSource}
import NameKinds.{TempResultName, OuterSelectName}
Expand Down Expand Up @@ -208,6 +208,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

def DefDef(sym: TermSymbol, tparams: List[TypeSymbol], vparamss: List[List[TermSymbol]],
resultType: Type, rhs: Tree)(implicit ctx: Context): DefDef =
sym.setParamss(tparams, vparamss)
ta.assignType(
untpd.DefDef(
sym.name,
Expand All @@ -223,15 +224,27 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def DefDef(sym: TermSymbol, rhsFn: List[List[Tree]] => Tree)(implicit ctx: Context): DefDef =
polyDefDef(sym, Function.const(rhsFn))

/** A DefDef with given method symbol `sym`.
* @rhsFn A function from type parameter types and term parameter references
* to the method's right-hand side.
* Parameter symbols are taken from the `rawParamss` field of `sym`, or
* are freshly generated if `rawParamss` is empty.
*/
def polyDefDef(sym: TermSymbol, rhsFn: List[Type] => List[List[Tree]] => Tree)(implicit ctx: Context): DefDef = {
val (tparams, mtp) = sym.info match {

val (tparams, existingParamss, mtp) = sym.info match {
case tp: PolyType =>
val tparams = ctx.newTypeParams(sym, tp.paramNames, EmptyFlags, tp.instantiateParamInfos(_))
(tparams, tp.instantiate(tparams map (_.typeRef)))
case tp => (Nil, tp)
val (tparams, existingParamss) = sym.rawParamss match
case tparams :: vparamss =>
assert(tparams.hasSameLengthAs(tp.paramNames) && tparams.head.isType)
(tparams.asInstanceOf[List[TypeSymbol]], vparamss)
case _ =>
(ctx.newTypeParams(sym, tp.paramNames, EmptyFlags, tp.instantiateParamInfos(_)), Nil)
(tparams, existingParamss, tp.instantiate(tparams map (_.typeRef)))
case tp => (Nil, sym.rawParamss, tp)
}

def valueParamss(tp: Type): (List[List[TermSymbol]], Type) = tp match {
def valueParamss(tp: Type, existingParamss: List[List[Symbol]]): (List[List[TermSymbol]], Type) = tp match {
case tp: MethodType =>
val isParamDependent = tp.isParamDependent
val previousParamRefs = if (isParamDependent) mutable.ListBuffer[TermRef]() else null
Expand All @@ -254,14 +267,23 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
makeSym(origInfo)
}

val params = tp.paramNames.lazyZip(tp.paramInfos).map(valueParam)
val (paramss, rtp) = valueParamss(tp.instantiate(params map (_.termRef)))
val (params, existingParamss1) =
if tp.paramInfos.isEmpty then (Nil, existingParamss)
else existingParamss match
case vparams :: existingParamss1 =>
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
(vparams.asInstanceOf[List[TermSymbol]], existingParamss1)
case _ =>
(tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil)
val (paramss, rtp) =
valueParamss(tp.instantiate(params map (_.termRef)), existingParamss1)
(params :: paramss, rtp)
case tp => (Nil, tp.widenExpr)
}
val (vparamss, rtp) = valueParamss(mtp)
val (vparamss, rtp) = valueParamss(mtp, existingParamss)
val targs = tparams map (_.typeRef)
val argss = vparamss.nestedMap(vparam => Ident(vparam.termRef))
sym.setParamss(tparams, vparamss)
DefDef(sym, tparams, vparamss, rtp, rhsFn(targs)(argss))
}

Expand Down
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Decorators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,13 @@ object Decorators {
def & (ys: List[T]): List[T] = xs filter (ys contains _)
}

implicit class ListOfListDecorator[T](val xss: List[List[T]]) extends AnyVal {
def nestedMap[U](f: T => U): List[List[U]] = xss map (_ map f)
def nestedMapconserve[U](f: T => U): List[List[U]] = xss mapconserve (_ mapconserve f)
}
extension ListOfListDecorator on [T, U](xss: List[List[T]]):
def nestedMap(f: T => U): List[List[U]] =
xss.map(_.map(f))
def nestedMapConserve(f: T => U): List[List[U]] =
xss.mapconserve(_.mapconserve(f))
def nestedZipWithConserve(yss: List[List[U]])(f: (T, U) => T): List[List[T]] =
xss.zipWithConserve(yss)((xs, ys) => xs.zipWithConserve(ys)(f))

implicit class TextToString(val text: Text) extends AnyVal {
def show(implicit ctx: Context): String = text.mkString(ctx.settings.pageWidth.value, ctx.settings.printLines.value)
Expand Down
14 changes: 7 additions & 7 deletions compiler/src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ object Flags {
val (SuperParamAliasOrScala2x @ _, SuperParamAlias @ _, Scala2x @ _) = newFlags(26, "<super-param-alias>", "<scala-2.x>")

/** A method that has default params */
val (_, DefaultParameterized @ _, _) = newFlags(27, "<defaultparam>")
val (_, HasDefault @ _, _) = newFlags(27, "<hasdefault>")

/** An extension method, or a collective extension instance */
val (_, Extension @ _, _) = newFlags(28, "<extension>")
Expand Down Expand Up @@ -391,18 +391,18 @@ object Flags {

/** Translation of Scala2's EXPANDEDNAME flag. This flag is never stored in
* symbols, is only used locally when reading the flags of a Scala2 symbol.
* It's therefore safe to share the code with `InheritedDefaultParams` because
* It's therefore safe to share the code with `HasDefaultParams` because
* the latter is never present in Scala2 unpickle info.
* /
* A method that is known to have inherited default parameters
* A method that is known to have (defined or inherited) default parameters
*/
val (Scala2ExpandedName @ _, InheritedDefaultParams @ _, _) = newFlags(59, "<inherited-default-param>")
val (Scala2ExpandedName @ _, HasDefaultParams @ _, _) = newFlags(59, "<has-default-params>")

/** A method that is known to have no default parameters
* /
* A type symbol with provisional empty bounds
*/
val (_, NoDefaultParams @ _, Provisional @ _) = newFlags(60, "<no-default-param>", "<provisional>")
val (_, NoDefaultParams @ _, Provisional @ _) = newFlags(60, "<no-default-params>", "<provisional>")

/** A denotation that is valid in all run-ids */
val (Permanent @ _, _, _) = newFlags(61, "<permanent>")
Expand Down Expand Up @@ -525,8 +525,7 @@ object Flags {
val EnumCase: FlagSet = Case | Enum
val CovariantLocal: FlagSet = Covariant | Local // A covariant type parameter
val ContravariantLocal: FlagSet = Contravariant | Local // A contravariant type parameter
val HasDefaultParamsFlags: FlagSet = DefaultParameterized | InheritedDefaultParams // Has defined or inherited default parameters
val DefaultParameter: FlagSet = DefaultParameterized | Param // A Scala 2x default parameter
val DefaultParameter: FlagSet = HasDefault | Param // A Scala 2x default parameter
val DeferredOrLazy: FlagSet = Deferred | Lazy
val DeferredOrLazyOrMethod: FlagSet = Deferred | Lazy | Method
val DeferredOrTermParamOrAccessor: FlagSet = Deferred | ParamAccessor | TermParam // term symbols without right-hand sides
Expand Down Expand Up @@ -576,6 +575,7 @@ object Flags {
val SyntheticGivenMethod: FlagSet = Synthetic | Given | Method
val SyntheticModule: FlagSet = Synthetic | Module
val SyntheticOpaque: FlagSet = Synthetic | Opaque
val SyntheticParam: FlagSet = Synthetic | Param
val SyntheticTermParam: FlagSet = Synthetic | TermParam
val SyntheticTypeParam: FlagSet = Synthetic | TypeParam
}
80 changes: 71 additions & 9 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Scopes.Scope
import dotty.tools.io.AbstractFile
import Decorators.SymbolIteratorDecorator
import ast._
import ast.Trees.{LambdaTypeTree, TypeBoundsTree}
import ast.Trees.{LambdaTypeTree, TypeBoundsTree, ValDef, TypeDef}
import Trees.Literal
import Variances.Variance
import annotation.tailrec
Expand Down Expand Up @@ -150,6 +150,7 @@ object SymDenotations {
private var myFlags: FlagSet = adaptFlags(initFlags)
private var myPrivateWithin: Symbol = initPrivateWithin
private var myAnnotations: List[Annotation] = Nil
private var myParamss: List[List[Symbol]] = Nil

/** The owner of the symbol; overridden in NoDenotation */
def owner: Symbol = maybeOwner
Expand Down Expand Up @@ -372,6 +373,59 @@ object SymDenotations {
case Nil => Nil
}

/** If this is a method, the parameter symbols, by section.
* Both type and value parameters are included. Empty sections are skipped.
*/
final def rawParamss: List[List[Symbol]] = myParamss
final def rawParamss_=(pss: List[List[Symbol]]): Unit =
myParamss = pss

final def setParamss(tparams: List[Symbol], vparamss: List[List[Symbol]])(using Context): Unit =
rawParamss = (if tparams.isEmpty then vparamss else tparams :: vparamss)
.filterConserve(!_.isEmpty)

final def setParamssFromDefs(tparams: List[TypeDef[?]], vparamss: List[List[ValDef[?]]])(using Context): Unit =
setParamss(tparams.map(_.symbol), vparamss.map(_.map(_.symbol)))

/** A pair consistsing of type paremeter symbols and value parameter symbol lists
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nicolasstucki The paramSymss method should be exported through Tasty reflect.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added paramSymss to the TASTy reflect API

* of this method definition, or (Nil, Nil) for other symbols.
* Makes use of `rawParamss` when present, or constructs fresh parameter symbols otherwise.
* This method can be allocation-heavy.
*/
final def paramSymss(using ctx: Context): (List[TypeSymbol], List[List[TermSymbol]]) =

def recurWithParamss(info: Type, paramss: List[List[Symbol]]): List[List[Symbol]] =
info match
case info: LambdaType =>
if info.paramNames.isEmpty then Nil :: recurWithParamss(info.resType, paramss)
else paramss.head :: recurWithParamss(info.resType, paramss.tail)
case _ =>
Nil

def recurWithoutParamss(info: Type): List[List[Symbol]] = info match
case info: LambdaType =>
val params = info.paramNames.lazyZip(info.paramInfos).map((pname, ptype) =>
ctx.newSymbol(symbol, pname, SyntheticParam, ptype))
val prefs = params.map(_.namedType)
for param <- params do
param.info = param.info.substParams(info, prefs)
params :: recurWithoutParamss(info.instantiate(prefs))
case _ =>
Nil

try
val allParamss =
if rawParamss.isEmpty then recurWithoutParamss(info)
else recurWithParamss(info, rawParamss)
val result = info match
case info: PolyType => (allParamss.head, allParamss.tail)
case _ => (Nil, allParamss)
result.asInstanceOf[(List[TypeSymbol], List[List[TermSymbol]])]
catch case NonFatal(ex) =>
println(i"paramSymss failure for $symbol, $info, $rawParamss")
throw ex
end paramSymss

/** The denotation is completed: info is not a lazy type and attributes have defined values */
final def isCompleted: Boolean = !myInfo.isInstanceOf[LazyType]

Expand Down Expand Up @@ -916,15 +970,19 @@ object SymDenotations {
def isAsConcrete(that: Symbol)(implicit ctx: Context): Boolean =
!this.is(Deferred) || that.is(Deferred)

/** Does this symbol have defined or inherited default parameters? */
/** Does this symbol have defined or inherited default parameters?
* Default parameters are recognized until erasure.
*/
def hasDefaultParams(implicit ctx: Context): Boolean =
if (this.isOneOf(HasDefaultParamsFlags)) true
else if (this.is(NoDefaultParams)) false
else {
val result = allOverriddenSymbols exists (_.hasDefaultParams)
setFlag(if (result) InheritedDefaultParams else NoDefaultParams)
if ctx.erasedTypes then false
else if is(HasDefaultParams) then true
else if is(NoDefaultParams) then false
else
val result =
rawParamss.exists(_.exists(_.is(HasDefault)))
|| allOverriddenSymbols.exists(_.hasDefaultParams)
setFlag(if result then HasDefaultParams else NoDefaultParams)
result
}

/** Symbol is an owner that would be skipped by effectiveOwner. Skipped are
* - package objects
Expand Down Expand Up @@ -1450,16 +1508,20 @@ object SymDenotations {
initFlags: FlagSet = UndefinedFlags,
info: Type = null,
privateWithin: Symbol = null,
annotations: List[Annotation] = null)(implicit ctx: Context): SymDenotation = {
annotations: List[Annotation] = null,
rawParamss: List[List[Symbol]] = null)(
using ctx: Context): SymDenotation = {
// simulate default parameters, while also passing implicit context ctx to the default values
val initFlags1 = (if (initFlags != UndefinedFlags) initFlags else this.flags)
val info1 = if (info != null) info else this.info
if (ctx.isAfterTyper && changedClassParents(info, info1, completersMatter = false))
assert(ctx.phase.changesParents, i"undeclared parent change at ${ctx.phase} for $this, was: $info, now: $info1")
val privateWithin1 = if (privateWithin != null) privateWithin else this.privateWithin
val annotations1 = if (annotations != null) annotations else this.annotations
val rawParamss1 = if rawParamss != null then rawParamss else this.rawParamss
val d = ctx.SymDenotation(symbol, owner, name, initFlags1, info1, privateWithin1)
d.annotations = annotations1
d.rawParamss = rawParamss1
d.registeredCompanion = registeredCompanion
d
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ class TreePickler(pickler: TastyPickler) {
if (flags.is(Mutable)) writeModTag(MUTABLE)
if (flags.is(Accessor)) writeModTag(FIELDaccessor)
if (flags.is(CaseAccessor)) writeModTag(CASEaccessor)
if (flags.is(DefaultParameterized)) writeModTag(DEFAULTparameterized)
if (flags.is(HasDefault)) writeModTag(HASDEFAULT)
if (flags.is(StableRealizable)) writeModTag(STABLE)
if (flags.is(Extension)) writeModTag(EXTENSION)
if (flags.is(ParamAccessor)) writeModTag(PARAMsetter)
Expand Down
Loading