Skip to content

Commit 266a1e4

Browse files
author
Rodrigo Raya
authored
Merge pull request #8863 from dotty-staging/depscala-2
Experimental Support for Dependent Type Arguments - Step 1: Parsing I have been working with these changes in the typing phase. They look fine to me.
2 parents 30bebc3 + 8385c17 commit 266a1e4

File tree

13 files changed

+170
-44
lines changed

13 files changed

+170
-44
lines changed

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,11 @@ object Trees {
669669
type ThisTree[-T >: Untyped] = LambdaTypeTree[T]
670670
}
671671

672+
case class TermLambdaTypeTree[-T >: Untyped] private[ast] (params: List[ValDef[T]], body: Tree[T])(implicit @constructorOnly src: SourceFile)
673+
extends TypTree[T] {
674+
type ThisTree[-T >: Untyped] = TermLambdaTypeTree[T]
675+
}
676+
672677
/** [bound] selector match { cases } */
673678
case class MatchTypeTree[-T >: Untyped] private[ast] (bound: Tree[T], selector: Tree[T], cases: List[CaseDef[T]])(implicit @constructorOnly src: SourceFile)
674679
extends TypTree[T] {
@@ -964,6 +969,7 @@ object Trees {
964969
type RefinedTypeTree = Trees.RefinedTypeTree[T]
965970
type AppliedTypeTree = Trees.AppliedTypeTree[T]
966971
type LambdaTypeTree = Trees.LambdaTypeTree[T]
972+
type TermLambdaTypeTree = Trees.TermLambdaTypeTree[T]
967973
type MatchTypeTree = Trees.MatchTypeTree[T]
968974
type ByNameTypeTree = Trees.ByNameTypeTree[T]
969975
type TypeBoundsTree = Trees.TypeBoundsTree[T]
@@ -1135,6 +1141,10 @@ object Trees {
11351141
case tree: LambdaTypeTree if (tparams eq tree.tparams) && (body eq tree.body) => tree
11361142
case _ => finalize(tree, untpd.LambdaTypeTree(tparams, body)(sourceFile(tree)))
11371143
}
1144+
def TermLambdaTypeTree(tree: Tree)(params: List[ValDef], body: Tree)(implicit ctx: Context): TermLambdaTypeTree = tree match {
1145+
case tree: TermLambdaTypeTree if (params eq tree.params) && (body eq tree.body) => tree
1146+
case _ => finalize(tree, untpd.TermLambdaTypeTree(params, body)(sourceFile(tree)))
1147+
}
11381148
def MatchTypeTree(tree: Tree)(bound: Tree, selector: Tree, cases: List[CaseDef])(implicit ctx: Context): MatchTypeTree = tree match {
11391149
case tree: MatchTypeTree if (bound eq tree.bound) && (selector eq tree.selector) && (cases eq tree.cases) => tree
11401150
case _ => finalize(tree, untpd.MatchTypeTree(bound, selector, cases)(sourceFile(tree)))
@@ -1293,6 +1303,10 @@ object Trees {
12931303
inContext(localCtx) {
12941304
cpy.LambdaTypeTree(tree)(transformSub(tparams), transform(body))
12951305
}
1306+
case TermLambdaTypeTree(params, body) =>
1307+
inContext(localCtx) {
1308+
cpy.TermLambdaTypeTree(tree)(transformSub(params), transform(body))
1309+
}
12961310
case MatchTypeTree(bound, selector, cases) =>
12971311
cpy.MatchTypeTree(tree)(transform(bound), transform(selector), transformSub(cases))
12981312
case ByNameTypeTree(result) =>
@@ -1422,6 +1436,10 @@ object Trees {
14221436
inContext(localCtx) {
14231437
this(this(x, tparams), body)
14241438
}
1439+
case TermLambdaTypeTree(params, body) =>
1440+
inContext(localCtx) {
1441+
this(this(x, params), body)
1442+
}
14251443
case MatchTypeTree(bound, selector, cases) =>
14261444
this(this(this(x, bound), selector), cases)
14271445
case ByNameTypeTree(result) =>

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
378378
def RefinedTypeTree(tpt: Tree, refinements: List[Tree])(implicit src: SourceFile): RefinedTypeTree = new RefinedTypeTree(tpt, refinements)
379379
def AppliedTypeTree(tpt: Tree, args: List[Tree])(implicit src: SourceFile): AppliedTypeTree = new AppliedTypeTree(tpt, args)
380380
def LambdaTypeTree(tparams: List[TypeDef], body: Tree)(implicit src: SourceFile): LambdaTypeTree = new LambdaTypeTree(tparams, body)
381+
def TermLambdaTypeTree(params: List[ValDef], body: Tree)(implicit src: SourceFile): TermLambdaTypeTree = new TermLambdaTypeTree(params, body)
381382
def MatchTypeTree(bound: Tree, selector: Tree, cases: List[CaseDef])(implicit src: SourceFile): MatchTypeTree = new MatchTypeTree(bound, selector, cases)
382383
def ByNameTypeTree(result: Tree)(implicit src: SourceFile): ByNameTypeTree = new ByNameTypeTree(result)
383384
def TypeBoundsTree(lo: Tree, hi: Tree, alias: Tree = EmptyTree)(implicit src: SourceFile): TypeBoundsTree = new TypeBoundsTree(lo, hi, alias)
@@ -487,8 +488,14 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
487488
ValDef(nme.syntheticParamName(n), if (tpt == null) TypeTree() else tpt, EmptyTree)
488489
.withFlags(flags)
489490

490-
def lambdaAbstract(tparams: List[TypeDef], tpt: Tree)(implicit ctx: Context): Tree =
491-
if (tparams.isEmpty) tpt else LambdaTypeTree(tparams, tpt)
491+
def lambdaAbstract(params: List[ValDef] | List[TypeDef], tpt: Tree)(using Context): Tree =
492+
params match
493+
case Nil => tpt
494+
case (vd: ValDef) :: _ => TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], tpt)
495+
case _ => LambdaTypeTree(params.asInstanceOf[List[TypeDef]], tpt)
496+
497+
def lambdaAbstractAll(paramss: List[List[ValDef] | List[TypeDef]], tpt: Tree)(using Context): Tree =
498+
paramss.foldRight(tpt)(lambdaAbstract)
492499

493500
/** A reference to given definition. If definition is a repeated
494501
* parameter, the reference will be a repeated argument.

compiler/src/dotty/tools/dotc/config/Feature.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package dotc
33
package config
44

55
import core._
6-
import Contexts._, Symbols._, Names._
6+
import Contexts._, Symbols._, Names._, NameOps._
77
import StdNames.nme
88
import Decorators.{given _}
99
import util.SourcePosition
@@ -22,7 +22,7 @@ object Feature:
2222
def enabledBySetting(feature: TermName, owner: Symbol = NoSymbol)(using Context): Boolean =
2323
def toPrefix(sym: Symbol): String =
2424
if !sym.exists || sym == defn.LanguageModule.moduleClass then ""
25-
else toPrefix(sym.owner) + sym.name + "."
25+
else toPrefix(sym.owner) + sym.name.stripModuleClassSuffix + "."
2626
val prefix = if owner.exists then toPrefix(owner) else ""
2727
ctx.base.settings.language.value.contains(prefix + feature)
2828

@@ -38,7 +38,7 @@ object Feature:
3838
def enabledByImport(feature: TermName, owner: Symbol = NoSymbol)(using Context): Boolean =
3939
ctx.atPhase(ctx.typerPhase) {
4040
ctx.importInfo != null
41-
&& ctx.importInfo.featureImported(feature.toTermName,
41+
&& ctx.importInfo.featureImported(feature,
4242
if owner.exists then owner else defn.LanguageModule.moduleClass)
4343
}
4444

@@ -57,6 +57,9 @@ object Feature:
5757
def dynamicsEnabled(using Context): Boolean =
5858
enabled(nme.dynamics)
5959

60+
def dependentEnabled(using Context) =
61+
enabled(nme.dependent, defn.LanguageExperimentalModule.moduleClass)
62+
6063
def sourceVersionSetting(using Context): SourceVersion =
6164
SourceVersion.valueOf(ctx.settings.source.value)
6265

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,7 @@ class Definitions {
663663
@tu lazy val Mirror_SingletonProxyClass: ClassSymbol = ctx.requiredClass("scala.deriving.Mirror.SingletonProxy")
664664

665665
@tu lazy val LanguageModule: Symbol = ctx.requiredModule("scala.language")
666+
@tu lazy val LanguageExperimentalModule: Symbol = ctx.requiredModule("scala.language.experimental")
666667
@tu lazy val NonLocalReturnControlClass: ClassSymbol = ctx.requiredClass("scala.runtime.NonLocalReturnControl")
667668
@tu lazy val SelectableClass: ClassSymbol = ctx.requiredClass("scala.Selectable")
668669

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ object StdNames {
440440
val definitions: N = "definitions"
441441
val delayedInit: N = "delayedInit"
442442
val delayedInitArg: N = "delayedInit$body"
443+
val dependent: N = "dependent"
443444
val derived: N = "derived"
444445
val derives: N = "derives"
445446
val doubleHash: N = "doubleHash"

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 80 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ object Parsers {
192192

193193
def isIdent = in.isIdent
194194
def isIdent(name: Name) = in.isIdent(name)
195-
def isSimpleLiteral = simpleLiteralTokens contains in.token
195+
def isSimpleLiteral =
196+
simpleLiteralTokens.contains(in.token)
197+
|| isIdent(nme.raw.MINUS) && numericLitTokens.contains(in.lookahead.token)
196198
def isLiteral = literalTokens contains in.token
197199
def isNumericLit = numericLitTokens contains in.token
198200
def isTemplateIntro = templateIntroTokens contains in.token
@@ -1101,18 +1103,42 @@ object Parsers {
11011103
*/
11021104
def qualId(): Tree = dotSelectors(termIdent())
11031105

1104-
/** SimpleExpr ::= literal
1105-
* | 'id | 'this | 'true | 'false | 'null
1106-
* | null
1106+
/** Singleton ::= SimpleRef
1107+
* | SimpleLiteral
1108+
* | Singleton ‘.’ id
1109+
*/
1110+
def singleton(): Tree =
1111+
if isSimpleLiteral then simpleLiteral()
1112+
else dotSelectors(simpleRef())
1113+
1114+
/** SimpleLiteral ::= [‘-’] integerLiteral
1115+
* | [‘-’] floatingPointLiteral
1116+
* | booleanLiteral
1117+
* | characterLiteral
1118+
* | stringLiteral
1119+
*/
1120+
def simpleLiteral(): Tree =
1121+
if isIdent(nme.raw.MINUS) then
1122+
val start = in.offset
1123+
in.nextToken()
1124+
literal(negOffset = start, inTypeOrSingleton = true)
1125+
else
1126+
literal(inTypeOrSingleton = true)
1127+
1128+
/** Literal ::= SimpleLiteral
1129+
* | processedStringLiteral
1130+
* | symbolLiteral
1131+
* | ‘null’
1132+
*
11071133
* @param negOffset The offset of a preceding `-' sign, if any.
1108-
* If the literal is not negated, negOffset = in.offset.
1134+
* If the literal is not negated, negOffset == in.offset.
11091135
*/
1110-
def literal(negOffset: Int = in.offset, inPattern: Boolean = false, inType: Boolean = false, inStringInterpolation: Boolean = false): Tree = {
1136+
def literal(negOffset: Int = in.offset, inPattern: Boolean = false, inTypeOrSingleton: Boolean = false, inStringInterpolation: Boolean = false): Tree = {
11111137
def literalOf(token: Token): Tree = {
11121138
val isNegated = negOffset < in.offset
11131139
def digits0 = in.removeNumberSeparators(in.strVal)
11141140
def digits = if (isNegated) "-" + digits0 else digits0
1115-
if (!inType)
1141+
if !inTypeOrSingleton then
11161142
token match {
11171143
case INTLIT => return Number(digits, NumberKind.Whole(in.base))
11181144
case DECILIT => return Number(digits, NumberKind.Decimal)
@@ -1333,22 +1359,33 @@ object Parsers {
13331359
case _ => false
13341360
}
13351361

1336-
/** Type ::= FunType
1337-
* | HkTypeParamClause ‘=>>’ Type
1338-
* | MatchType
1339-
* | InfixType
1340-
* FunType ::= (MonoFunType | PolyFunType)
1341-
* MonoFunType ::= FunArgTypes (‘=>’ | ‘?=>’) Type
1342-
* PolyFunType ::= HKTypeParamClause '=>' Type
1343-
* FunArgTypes ::= InfixType
1344-
* | `(' [ [ ‘[using]’ ‘['erased'] FunArgType {`,' FunArgType } ] `)'
1345-
* | '(' [ ‘[using]’ ‘['erased'] TypedFunParam {',' TypedFunParam } ')'
1362+
/** Type ::= FunType
1363+
* | HkTypeParamClause ‘=>>’ Type
1364+
* | FunParamClause ‘=>>’ Type
1365+
* | MatchType
1366+
* | InfixType
1367+
* FunType ::= (MonoFunType | PolyFunType)
1368+
* MonoFunType ::= FunArgTypes (‘=>’ | ‘?=>’) Type
1369+
* PolyFunType ::= HKTypeParamClause '=>' Type
1370+
* FunArgTypes ::= InfixType
1371+
* | `(' [ [ ‘[using]’ ‘['erased'] FunArgType {`,' FunArgType } ] `)'
1372+
* | '(' [ ‘[using]’ ‘['erased'] TypedFunParam {',' TypedFunParam } ')'
13461373
*/
13471374
def typ(): Tree = {
13481375
val start = in.offset
13491376
var imods = Modifiers()
13501377
def functionRest(params: List[Tree]): Tree =
13511378
atSpan(start, in.offset) {
1379+
if in.token == TLARROW then
1380+
if !imods.flags.isEmpty || params.isEmpty then
1381+
syntaxError(em"illegal parameter list for type lambda", start)
1382+
in.token = ARROW
1383+
else
1384+
for case ValDef(_, tpt: ByNameTypeTree, _) <- params do
1385+
syntaxError(em"parameter of type lambda may not be call-by-name", tpt.span)
1386+
in.nextToken()
1387+
return TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], typ())
1388+
13521389
if in.token == CTXARROW then
13531390
in.nextToken()
13541391
imods |= Given
@@ -1475,10 +1512,19 @@ object Parsers {
14751512
Span(start, start + nme.IMPLICITkw.asSimpleName.length)
14761513

14771514
/** TypedFunParam ::= id ':' Type */
1478-
def typedFunParam(start: Offset, name: TermName, mods: Modifiers = EmptyModifiers): Tree = atSpan(start) {
1479-
accept(COLON)
1480-
makeParameter(name, typ(), mods | Param)
1481-
}
1515+
def typedFunParam(start: Offset, name: TermName, mods: Modifiers = EmptyModifiers): ValDef =
1516+
atSpan(start) {
1517+
accept(COLON)
1518+
makeParameter(name, typ(), mods | Param)
1519+
}
1520+
1521+
/** FunParamClause ::= ‘(’ TypedFunParam {‘,’ TypedFunParam } ‘)’
1522+
*/
1523+
def funParamClause(): List[ValDef] =
1524+
inParens(commaSeparated(() => typedFunParam(in.offset, ident())))
1525+
1526+
def funParamClauses(): List[List[ValDef]] =
1527+
if in.token == LPAREN then funParamClause() :: funParamClauses() else Nil
14821528

14831529
/** InfixType ::= RefinedType {id [nl] RefinedType}
14841530
*/
@@ -1556,14 +1602,12 @@ object Parsers {
15561602
/** SimpleType ::= SimpleLiteral
15571603
* | ‘?’ SubtypeBounds
15581604
* | SimpleType1
1605+
* | SimpeType ‘(’ Singletons ‘)’ -- under language.experimental.dependent, checked in Typer
1606+
* Singletons ::= Singleton {‘,’ Singleton}
15591607
*/
15601608
def simpleType(): Tree =
15611609
if isSimpleLiteral then
1562-
SingletonTypeTree(literal(inType = true))
1563-
else if isIdent(nme.raw.MINUS) && numericLitTokens.contains(in.lookahead.token) then
1564-
val start = in.offset
1565-
in.nextToken()
1566-
SingletonTypeTree(literal(negOffset = start, inType = true))
1610+
SingletonTypeTree(simpleLiteral())
15671611
else if in.token == USCORE then
15681612
if sourceVersion.isAtLeast(`3.1`) then
15691613
deprecationWarning(em"`_` is deprecated for wildcard arguments of types: use `?` instead")
@@ -1576,7 +1620,11 @@ object Parsers {
15761620
else if isIdent(nme.*) && ctx.settings.YkindProjector.value then
15771621
typeIdent()
15781622
else
1579-
simpleType1()
1623+
def singletonArgs(t: Tree): Tree =
1624+
if in.token == LPAREN
1625+
then singletonArgs(AppliedTypeTree(t, inParens(commaSeparated(singleton))))
1626+
else t
1627+
singletonArgs(simpleType1())
15801628

15811629
/** SimpleType1 ::= id
15821630
* | Singleton `.' id
@@ -2811,11 +2859,11 @@ object Parsers {
28112859
else tree1
28122860
}
28132861

2814-
/** Annotation ::= `@' SimpleType {ParArgumentExprs}
2862+
/** Annotation ::= `@' SimpleType1 {ParArgumentExprs}
28152863
*/
28162864
def annot(): Tree =
28172865
adjustStart(accept(AT)) {
2818-
ensureApplied(parArgumentExprss(wrapNew(simpleType())))
2866+
ensureApplied(parArgumentExprss(wrapNew(simpleType1())))
28192867
}
28202868

28212869
def annotations(skipNewLines: Boolean = false): List[Tree] = {
@@ -3348,15 +3396,16 @@ object Parsers {
33483396
argumentExprss(mkApply(Ident(nme.CONSTRUCTOR), argumentExprs()))
33493397
}
33503398

3351-
/** TypeDcl ::= id [TypeParamClause] TypeBounds [‘=’ Type]
3399+
/** TypeDcl ::= id [TypeParamClause] {FunParamClause} TypeBounds [‘=’ Type]
33523400
*/
33533401
def typeDefOrDcl(start: Offset, mods: Modifiers): Tree = {
33543402
newLinesOpt()
33553403
atSpan(start, nameStart) {
33563404
val nameIdent = typeIdent()
33573405
val tparams = typeParamClauseOpt(ParamOwner.Type)
3406+
val vparamss = funParamClauses()
33583407
def makeTypeDef(rhs: Tree): Tree = {
3359-
val rhs1 = lambdaAbstract(tparams, rhs)
3408+
val rhs1 = lambdaAbstractAll(tparams :: vparamss, rhs)
33603409
val tdef = TypeDef(nameIdent.name.toTypeName, rhs1)
33613410
if (nameIdent.isBackquoted)
33623411
tdef.pushAttachment(Backquoted, ())

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,12 +477,19 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
477477
changePrec(OrTypePrec) { toText(args(0)) ~ " | " ~ atPrec(OrTypePrec + 1) { toText(args(1)) } }
478478
else if (tpt.symbol == defn.andType && args.length == 2)
479479
changePrec(AndTypePrec) { toText(args(0)) ~ " & " ~ atPrec(AndTypePrec + 1) { toText(args(1)) } }
480-
else
481-
toTextLocal(tpt) ~ "[" ~ Text(args map argText, ", ") ~ "]"
480+
else args match
481+
case arg :: _ if arg.isTerm =>
482+
toTextLocal(tpt) ~ "(" ~ Text(args.map(argText), ", ") ~ ")"
483+
case _ =>
484+
toTextLocal(tpt) ~ "[" ~ Text(args.map(argText), ", ") ~ "]"
482485
case LambdaTypeTree(tparams, body) =>
483486
changePrec(GlobalPrec) {
484487
tparamsText(tparams) ~ " =>> " ~ toText(body)
485488
}
489+
case TermLambdaTypeTree(params, body) =>
490+
changePrec(GlobalPrec) {
491+
paramsText(params) ~ " =>> " ~ toText(body)
492+
}
486493
case MatchTypeTree(bound, sel, cases) =>
487494
changePrec(GlobalPrec) {
488495
toText(sel) ~ keywordStr(" match ") ~ blockText(cases) ~

compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,9 @@ object ErrorReporting {
146146
else ""
147147
}
148148

149+
def dependentStr =
150+
"""Term-dependent types are experimental,
151+
|they must be enabled with a `experimental.dependent` language import or setting""".stripMargin
152+
149153
def err(using Context): Errors = new Errors
150154
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,14 @@ class Typer extends Namer
15791579
}
15801580

15811581
def typedAppliedTypeTree(tree: untpd.AppliedTypeTree)(using Context): Tree = {
1582+
tree.args match
1583+
case arg :: _ if arg.isTerm =>
1584+
if dependentEnabled then
1585+
return errorTree(tree, i"Not yet implemented: T(...)")
1586+
else
1587+
return errorTree(tree, dependentStr)
1588+
case _ =>
1589+
15821590
val tpt1 = typed(tree.tpt, AnyTypeConstructorProto)(using ctx.retractMode(Mode.Pattern))
15831591
val tparams = tpt1.tpe.typeParams
15841592
if (tparams.isEmpty) {
@@ -1664,6 +1672,12 @@ class Typer extends Namer
16641672
assignType(cpy.LambdaTypeTree(tree)(tparams1, body1), tparams1, body1)
16651673
}
16661674

1675+
def typedTermLambdaTypeTree(tree: untpd.TermLambdaTypeTree)(using Context): Tree =
1676+
if dependentEnabled then
1677+
errorTree(tree, i"Not yet implemented: (...) =>> ...")
1678+
else
1679+
errorTree(tree, dependentStr)
1680+
16671681
def typedMatchTypeTree(tree: untpd.MatchTypeTree, pt: Type)(using Context): Tree = {
16681682
val bound1 =
16691683
if (tree.bound.isEmpty && isFullyDefined(pt, ForceDegree.none)) TypeTree(pt)
@@ -2381,6 +2395,7 @@ class Typer extends Namer
23812395
case tree: untpd.RefinedTypeTree => typedRefinedTypeTree(tree)
23822396
case tree: untpd.AppliedTypeTree => typedAppliedTypeTree(tree)
23832397
case tree: untpd.LambdaTypeTree => typedLambdaTypeTree(tree)(using ctx.localContext(tree, NoSymbol).setNewScope)
2398+
case tree: untpd.TermLambdaTypeTree => typedTermLambdaTypeTree(tree)(using ctx.localContext(tree, NoSymbol).setNewScope)
23842399
case tree: untpd.MatchTypeTree => typedMatchTypeTree(tree, pt)
23852400
case tree: untpd.ByNameTypeTree => typedByNameTypeTree(tree)
23862401
case tree: untpd.TypeBoundsTree => typedTypeBoundsTree(tree, pt)

0 commit comments

Comments
 (0)