Skip to content

Implement AppliedTermRef (singleton types for term-level applications) #3887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,15 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
false
}
compareTypeBounds
case tp2: AppliedTermRef =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to have this in thirdTry and not firstTry? Also I think it'd be good to have more testcases that stress-test the subtype checks for AppliedTermRef. For example, what happens when an AppliedTermRef is hidden in a type alias?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do have a test that relies on a type application:

val t: Id[{a + 1}] = a + 1

I'll also add one that relies on a simple alias.

// TODO(gsps): Check whether rule or position thereof should change
def compareAppliedTerm = tp1 match {
case tp1: AppliedTermRef =>
sameLength(tp1.args, tp2.args) && isSubType(tp1.fn, tp2.fn) &&
tp1.args.zip(tp2.args).forall((arg1, arg2) => isSubType(arg1, arg2))
case _ => fourthTry
}
compareAppliedTerm
case tp2: AnnotatedType if tp2.isRefining =>
(tp1.derivesAnnotWith(tp2.annot.sameAnnotation) || defn.isBottomType(tp1)) &&
recur(tp1, tp2.parent)
Expand Down
99 changes: 98 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1044,9 +1044,11 @@ object Types {
case _ => this
}

/** Strip PolyType prefixes */
/** Strip PolyType and AppliedTermRef prefixes */
// TODO(gsps): Rename this to also reflect the removal of AppliedTermRefs
def stripPoly(implicit ctx: Context): Type = this match {
case tp: PolyType => tp.resType.stripPoly
case tp: AppliedTermRef => tp.resType.stripPoly
case _ => this
}

Expand Down Expand Up @@ -2608,6 +2610,73 @@ object Types {
override def hashCode: Int = System.identityHashCode(this)
}

// --- AppliedTermRef -------------------------------------------------------

/** A precise representation of a term-level application `fn(... args)`. **/
abstract case class AppliedTermRef(fn: /*TermRef | AppliedTermRef*/ SingletonType, args: List[Type])
extends CachedProxyType with SingletonType
{
private[this] var myResType: Type = _
def resType(implicit ctx: Context): Type = {
if (myResType == null)
fn.widen match {
case methTpe: MethodType => myResType = ctx.typer.applicationResultType(methTpe, args)
}
myResType
}

def underlying(implicit ctx: Context): Type = resType

/** Compute the derived AppliedTermRef, widening to the result type if any of its
* components is unstable.
*/
def derivedAppliedTermRef(fn: Type, args: List[Type])(implicit ctx: Context): Type =
if ((this.fn eq fn) && (this.args eq args)) this
else AppliedTermRef.make(fn, args)

override def computeHash(bs: Binders) = doHash(bs, fn, args)
override def hashIsStable: Boolean = fn.hashIsStable && args.forall(_.hashIsStable)

override def eql(that: Type) = that match {
case that: AppliedTermRef => (this.fn eq that.fn) && this.args.eqElements(that.args)
case _ => false
}

// TODO(gsps): AppliedTermRef#iso?
}

final class CachedAppliedTermRef(fn: SingletonType, args: List[Type]) extends AppliedTermRef(fn, args)

object AppliedTermRef {
def apply(fn: SingletonType, args: List[Type])(implicit ctx: Context): AppliedTermRef = {
assertUnerased()
assert(fn.isStable, args.forall(_.isStable))
unique(new CachedAppliedTermRef(fn, args))
}

def make(fn: Type, args: List[Type])(implicit ctx: Context): Type = {
def fallbackToResult(): Type =
fn.widenDealias match {
case methTpe: MethodType => ctx.typer.applicationResultType(methTpe, args)
case _: WildcardType => WildcardType
case tp => throw new AssertionError(i"Don't know how to apply $tp.")
}
def complete(fn: SingletonType): Type =
if (!ctx.erasedTypes && fn.isStable && args.forall(_.isStable))
AppliedTermRef(fn, args)
else
fallbackToResult()
fn.dealias match {
case fn: TermRef =>
complete(fn)
case fn: AppliedTermRef =>
complete(fn)
case _ =>
fallbackToResult()
}
}
}

// --- Refined Type and RecType ------------------------------------------------

abstract class RefinedOrRecType extends CachedProxyType with ValueType {
Expand Down Expand Up @@ -4800,6 +4869,8 @@ object Types {
tp.derivedSuperType(thistp, supertp)
protected def derivedAppliedType(tp: AppliedType, tycon: Type, args: List[Type]): Type =
tp.derivedAppliedType(tycon, args)
protected def derivedAppliedTermRef(tp: AppliedTermRef, fn: Type, args: List[Type]): Type =
tp.derivedAppliedTermRef(fn, args)
protected def derivedAndType(tp: AndType, tp1: Type, tp2: Type): Type =
tp.derivedAndType(tp1, tp2)
protected def derivedOrType(tp: OrType, tp1: Type, tp2: Type): Type =
Expand Down Expand Up @@ -4859,6 +4930,9 @@ object Types {
}
derivedAppliedType(tp, this(tp.tycon), mapArgs(tp.args, tp.tyconTypeParams))

case tp: AppliedTermRef =>
derivedAppliedTermRef(tp, this(tp.fn), tp.args.mapConserve(this))

case tp: RefinedType =>
derivedRefinedType(tp, this(tp.parent), this(tp.refinedInfo))

Expand Down Expand Up @@ -5173,6 +5247,26 @@ object Types {
else tp.derivedAppliedType(tycon, args)
}

// TODO(gsps): Double-check for changes in similar derivations
override protected def derivedAppliedTermRef(tp: AppliedTermRef, fn: Type, args: List[Type]): Type =
fn match {
case Range(fnLo, fnHi) =>
range(derivedAppliedTermRef(tp, fnLo, args), derivedAppliedTermRef(tp, fnHi, args))
case _ =>
if (fn.isBottomType) {
fn
} else if (args.exists(isRange)) {
val loBuf, hiBuf = new mutable.ListBuffer[Type]
args foreach {
case Range(lo, hi) => loBuf += lo; hiBuf += hi
case arg => loBuf += arg; hiBuf += arg
}
range(tp.derivedAppliedTermRef(fn, loBuf.toList), tp.derivedAppliedTermRef(fn, hiBuf.toList))
} else {
tp.derivedAppliedTermRef(fn, args)
}
}

override protected def derivedAndType(tp: AndType, tp1: Type, tp2: Type): Type =
if (isRange(tp1) || isRange(tp2)) range(lower(tp1) & lower(tp2), upper(tp1) & upper(tp2))
else tp.derivedAndType(tp1, tp2)
Expand Down Expand Up @@ -5275,6 +5369,9 @@ object Types {
}
foldArgs(this(x, tycon), tp.tyconTypeParams, args)

case tp: AppliedTermRef =>
foldOver(this(x, tp.fn), tp.args)

case _: BoundType | _: ThisType => x

case tp: LambdaType =>
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TastyPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class TastyPrinter(bytes: Array[Byte])(implicit ctx: Context) {
POLYtype | TYPELAMBDAtype =>
printTree()
until(end) { printName(); printTree() }
case APPLIEDTERMREF =>
printTree()
until(end) { printTree() }
case PARAMtype =>
printNat(); printNat()
case _ =>
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ class TreePickler(pickler: TastyPickler) {
case AppliedType(tycon, args) =>
writeByte(APPLIEDtype)
withLength { pickleType(tycon); args.foreach(pickleType(_)) }
case AppliedTermRef(fn, args) =>
writeByte(APPLIEDTERMREF)
withLength { pickleType(fn); args.foreach(pickleType(_)) }
case ConstantType(value) =>
pickleConstant(value)
case tpe: NamedType =>
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ class TreeUnpickler(reader: TastyReader,
// Eta expansion of the latter puts readType() out of the expression.
case APPLIEDtype =>
readType().appliedTo(until(end)(readType()))
case APPLIEDTERMREF =>
AppliedTermRef(readType().asInstanceOf[SingletonType], until(end)(readType()))
case TYPEBOUNDS =>
val lo = readType()
if nothingButMods(end) then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,11 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
}
}

// TODO(gsps): This is a hack to mark certain primitive methods as stable.
// In the future such primitive methods should be pickled with the StableRealizable flag set.
def markPrimitiveStable(owner: Symbol, name: Name, flags: FlagSet): FlagSet =
if (tpnme.ScalaValueNames.contains(defn.scalaClassName(owner))) flags | StableRealizable else flags

tag match {
case NONEsym => return NoSymbol
case EXTref | EXTMODCLASSref => return readExtSymbol()
Expand All @@ -457,10 +462,12 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
}

name = name.adjustIfModuleClass(flags)
if (flags.is(Method))
if (flags.is(Method)) {
name =
if (name == nme.TRAIT_CONSTRUCTOR) nme.CONSTRUCTOR
else name.asTermName.unmangle(Scala2MethodNameKinds)
flags = markPrimitiveStable(owner, name, flags)
}
if ((flags.is(Scala2ExpandedName))) {
name = name.unmangle(ExpandedName)
flags = flags &~ Scala2ExpandedName
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,7 @@ object Parsers {
* | `(' ArgTypes `)'
* | `_' TypeBounds
* | Refinement
* | `{` PostfixExpr `}`
* | Literal
* | ‘$’ ‘{’ Block ‘}’
*/
Expand All @@ -1545,7 +1546,7 @@ object Parsers {
makeTupleOrParens(inParens(argTypes(namedOK = false, wildOK = true)))
}
else if (in.token == LBRACE)
atSpan(in.offset) { RefinedTypeTree(EmptyTree, refinement()) }
atSpan(in.offset) { inBraces(emptyRefinementOrSingletonExpr()) }
else if (isSimpleLiteral) { SingletonTypeTree(literal(inType = true)) }
else if (isIdent(nme.raw.MINUS) && in.lookaheadIn(numericLitTokens)) {
val start = in.offset
Expand Down Expand Up @@ -1575,6 +1576,11 @@ object Parsers {
}
}

def emptyRefinementOrSingletonExpr(): Tree = {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call this refinementOfEmptyOrSingletonExpr().

if (!isStatSeqEnd && !isDclIntro) SingletonTypeTree(postfixExpr())
else RefinedTypeTree(EmptyTree, refineStatSeq())
}

val handleSingletonType: Tree => Tree = t =>
if (in.token == TYPE) {
in.nextToken()
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class PlainPrinter(_ctx: Context) extends Printer {
case tp @ AppliedType(tycon, args) =>
if (defn.isCompiletimeAppliedType(tycon.typeSymbol)) tp.tryCompiletimeConstantFold
else tycon.dealias.appliedTo(args)
case tp @ AppliedTermRef(fn, args) =>
tp.derivedAppliedTermRef(homogenize(fn), args.mapConserve(homogenize))
case _ =>
tp
}
Expand Down Expand Up @@ -140,6 +142,8 @@ class PlainPrinter(_ctx: Context) extends Printer {
toTextRef(tp) ~ ".type"
case tp: TermRef if tp.denot.isOverloaded =>
"<overloaded " ~ toTextRef(tp) ~ ">"
case tp: AppliedTermRef =>
"{" ~ toTextRef(tp) ~ "}"
case tp: TypeRef =>
if (printWithoutPrefix.contains(tp.symbol))
toText(tp.name)
Expand Down Expand Up @@ -288,6 +292,12 @@ class PlainPrinter(_ctx: Context) extends Printer {
tp match {
case tp: TermRef =>
toTextPrefix(tp.prefix) ~ selectionString(tp)
case AppliedTermRef(fn, args) =>
val argTexts = args.map {
case arg: SingletonType => toTextRef(arg)
case arg => argText(arg)
}
(toTextRef(fn) ~ "(" ~ Text(argTexts, ", ") ~ ")").close
case tp: ThisType =>
nameString(tp.cls) + ".this"
case SuperType(thistpe: SingletonType, _) =>
Expand Down
15 changes: 8 additions & 7 deletions compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -514,13 +514,14 @@ object Erasure {
ref(meth).appliedToArgs(args.toList ++ followingArgs)
}

private def protoArgs(pt: Type, methTp: Type): List[untpd.Tree] = (pt, methTp) match {
case (pt: FunProto, methTp: MethodType) if methTp.isErasedMethod =>
protoArgs(pt.resType, methTp.resType)
case (pt: FunProto, methTp: MethodType) =>
pt.args ++ protoArgs(pt.resType, methTp.resType)
case _ => Nil
}
private def protoArgs(pt: Type, methTp: Type)(implicit ctx: Context): List[untpd.Tree] =
(pt, methTp.stripPoly) match {
case (pt: FunProto, methTp: MethodType) if methTp.isErasedMethod =>
protoArgs(pt.resType, methTp.resType)
case (pt: FunProto, methTp: MethodType) =>
pt.args ++ protoArgs(pt.resType, methTp.resType)
case _ => Nil
}

override def typedTypeApply(tree: untpd.TypeApply, pt: Type)(implicit ctx: Context): Tree = {
val ntree = interceptTypeApply(tree.asInstanceOf[TypeApply])(ctx.withPhase(ctx.erasurePhase)).withSpan(tree.span)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ trait FullParameterization {
else {
// this type could have changed on forwarding. Need to insert a cast.
originalDef.vparamss.foldLeft(fun)((acc, vparams) => {
val meth = acc.tpe.asInstanceOf[MethodType]
val meth = acc.tpe.stripPoly.asInstanceOf[MethodType]
val paramTypes = meth.instantiateParamInfos(vparams.map(_.tpe))
acc.appliedToArgs(
vparams.lazyZip(paramTypes).map((vparam, paramType) => {
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ trait Implicits { self: Typer =>
success(Literal(Constant(())))
case n: TermRef =>
success(ref(n))
// TODO(gsps): Handle AppliedTermRef
case tp =>
EmptyTree
}
Expand Down
16 changes: 10 additions & 6 deletions compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -416,14 +416,18 @@ trait TypeAssigner {
tp
}

def applicationResultType(methTp: MethodType, args: List[Type])(implicit ctx: Context): Type =
if (methTp.isResultDependent) safeSubstParams(methTp.resultType, methTp.paramRefs, args)
else methTp.resultType

def assignType(tree: untpd.Apply, fn: Tree, args: List[Tree])(implicit ctx: Context): Apply = {
val ownType = fn.tpe.widen match {
case fntpe: MethodType =>
if (sameLength(fntpe.paramInfos, args) || ctx.phase.prev.relaxedTyping)
if (fntpe.isResultDependent) safeSubstParams(fntpe.resultType, fntpe.paramRefs, args.tpes)
else fntpe.resultType
val fnTpe = fn.tpe
val ownType = fnTpe.widen match {
case methTp: MethodType =>
if (sameLength(methTp.paramInfos, args) || ctx.phase.prev.relaxedTyping)
AppliedTermRef.make(fnTpe, args.tpes)
else
errorType(i"wrong number of arguments at ${ctx.phase.prev} for $fntpe: ${fn.tpe}, expected: ${fntpe.paramInfos.length}, found: ${args.length}", tree.sourcePos)
errorType(i"wrong number of arguments at ${ctx.phase.prev} for $methTp: $fnTpe, expected: ${methTp.paramInfos.length}, found: ${args.length}", tree.sourcePos)
case t =>
if (ctx.settings.Ydebug.value) new FatalError("").printStackTrace()
errorType(err.takesNoParamsStr(fn, ""), tree.sourcePos)
Expand Down
1 change: 1 addition & 0 deletions docs/docs/internals/syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ SimpleType ::= SimpleType TypeArgs
| ‘(’ ArgTypes ‘)’ Tuple(ts)
| ‘?’ SubtypeBounds
| Refinement RefinedTypeTree(EmptyTree, refinement)
| `{` PostfixExpr `}` SingletonTypeTree(expr)
| SimpleLiteral SingletonTypeTree(l)
| ‘$’ ‘{’ Block ‘}’
ArgTypes ::= Type {‘,’ Type}
Expand Down
5 changes: 4 additions & 1 deletion tasty/src/dotty/tools/tasty/TastyFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ Standard-Section: "ASTs" TopLevelStat*
THIS clsRef_Type -- cls.this
RECthis recType_ASTRef -- The `this` in a recursive refined type `recType`.
SHAREDtype path_ASTRef -- link to previously serialized path
APPLIEDTERMREF Length fn_Type arg_Type* -- The stable result of `fn` applied to `arg`s

Constant = UNITconst -- ()
FALSEconst -- false
Expand Down Expand Up @@ -254,7 +255,7 @@ object TastyFormat {

final val header: Array[Int] = Array(0x5C, 0xA1, 0xAB, 0x1F)
val MajorVersion: Int = 20
val MinorVersion: Int = 0
val MinorVersion: Int = 1

/** Tags used to serialize names, should update [[nameTagToString]] if a new constant is added */
class NameTags {
Expand Down Expand Up @@ -453,6 +454,7 @@ object TastyFormat {
final val ANNOTATION = 173
final val TERMREFin = 174
final val TYPEREFin = 175
final val APPLIEDTERMREF = 176

final val METHODtype = 180
final val ERASEDMETHODtype = 181
Expand Down Expand Up @@ -660,6 +662,7 @@ object TastyFormat {
case SUPERtype => "SUPERtype"
case TERMREFin => "TERMREFin"
case TYPEREFin => "TYPEREFin"
case APPLIEDTERMREF => "APPLIEDTERMREF"

case REFINEDtype => "REFINEDtype"
case REFINEDtpt => "REFINEDtpt"
Expand Down
2 changes: 1 addition & 1 deletion tests/init/neg/private.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
class A(a: Int) {
a + 3
val _ = a + 3
def foo() = a * 2
}

Expand Down
Loading