Skip to content

Commit aad399d

Browse files
authored
Merge pull request #14497 from dotty-staging/add-exports-in-extmethods
Allow exports in extension clauses
2 parents 019caf8 + 3c55674 commit aad399d

19 files changed

+802
-478
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,16 @@ object desugar {
393393
vparam.withMods(mods & (GivenOrImplicit | Erased | hasDefault) | Param)
394394
}
395395

396+
def mkApply(fn: Tree, paramss: List[ParamClause])(using Context): Tree =
397+
paramss.foldLeft(fn) { (fn, params) => params match
398+
case TypeDefs(params) =>
399+
TypeApply(fn, params.map(refOfDef))
400+
case (vparam: ValDef) :: _ if vparam.mods.is(Given) =>
401+
Apply(fn, params.map(refOfDef)).setApplyKind(ApplyKind.Using)
402+
case _ =>
403+
Apply(fn, params.map(refOfDef))
404+
}
405+
396406
/** The expansion of a class definition. See inline comments for what is involved */
397407
def classDef(cdef: TypeDef)(using Context): Tree = {
398408
val impl @ Template(constr0, _, self, _) = cdef.rhs
@@ -588,7 +598,7 @@ object desugar {
588598
}
589599

590600
// new C[Ts](paramss)
591-
lazy val creatorExpr = {
601+
lazy val creatorExpr =
592602
val vparamss = constrVparamss match
593603
case (vparam :: _) :: _ if vparam.mods.is(Implicit) => // add a leading () to match class parameters
594604
Nil :: constrVparamss
@@ -607,7 +617,6 @@ object desugar {
607617
}
608618
}
609619
ensureApplied(nu)
610-
}
611620

612621
val copiedAccessFlags = if migrateTo3 then EmptyFlags else AccessFlags
613622

@@ -892,48 +901,50 @@ object desugar {
892901
}
893902
}
894903

904+
def extMethod(mdef: DefDef, extParamss: List[ParamClause])(using Context): DefDef =
905+
cpy.DefDef(mdef)(
906+
name = normalizeName(mdef, mdef.tpt).asTermName,
907+
paramss =
908+
if mdef.name.isRightAssocOperatorName then
909+
val (typaramss, paramss) = mdef.paramss.span(isTypeParamClause) // first extract type parameters
910+
911+
paramss match
912+
case params :: paramss1 => // `params` must have a single parameter and without `given` flag
913+
914+
def badRightAssoc(problem: String) =
915+
report.error(i"right-associative extension method $problem", mdef.srcPos)
916+
extParamss ++ mdef.paramss
917+
918+
params match
919+
case ValDefs(vparam :: Nil) =>
920+
if !vparam.mods.is(Given) then
921+
// we merge the extension parameters with the method parameters,
922+
// swapping the operator arguments:
923+
// e.g.
924+
// extension [A](using B)(c: C)(using D)
925+
// def %:[E](f: F)(g: G)(using H): Res = ???
926+
// will be encoded as
927+
// def %:[A](using B)[E](f: F)(c: C)(using D)(g: G)(using H): Res = ???
928+
val (leadingUsing, otherExtParamss) = extParamss.span(isUsingOrTypeParamClause)
929+
leadingUsing ::: typaramss ::: params :: otherExtParamss ::: paramss1
930+
else
931+
badRightAssoc("cannot start with using clause")
932+
case _ =>
933+
badRightAssoc("must start with a single parameter")
934+
case _ =>
935+
// no value parameters, so not an infix operator.
936+
extParamss ++ mdef.paramss
937+
else
938+
extParamss ++ mdef.paramss
939+
).withMods(mdef.mods | ExtensionMethod)
940+
895941
/** Transform extension construct to list of extension methods */
896942
def extMethods(ext: ExtMethods)(using Context): Tree = flatTree {
897-
for mdef <- ext.methods yield
898-
defDef(
899-
cpy.DefDef(mdef)(
900-
name = normalizeName(mdef, ext).asTermName,
901-
paramss =
902-
if mdef.name.isRightAssocOperatorName then
903-
val (typaramss, paramss) = mdef.paramss.span(isTypeParamClause) // first extract type parameters
904-
905-
paramss match
906-
case params :: paramss1 => // `params` must have a single parameter and without `given` flag
907-
908-
def badRightAssoc(problem: String) =
909-
report.error(i"right-associative extension method $problem", mdef.srcPos)
910-
ext.paramss ++ mdef.paramss
911-
912-
params match
913-
case ValDefs(vparam :: Nil) =>
914-
if !vparam.mods.is(Given) then
915-
// we merge the extension parameters with the method parameters,
916-
// swapping the operator arguments:
917-
// e.g.
918-
// extension [A](using B)(c: C)(using D)
919-
// def %:[E](f: F)(g: G)(using H): Res = ???
920-
// will be encoded as
921-
// def %:[A](using B)[E](f: F)(c: C)(using D)(g: G)(using H): Res = ???
922-
val (leadingUsing, otherExtParamss) = ext.paramss.span(isUsingOrTypeParamClause)
923-
leadingUsing ::: typaramss ::: params :: otherExtParamss ::: paramss1
924-
else
925-
badRightAssoc("cannot start with using clause")
926-
case _ =>
927-
badRightAssoc("must start with a single parameter")
928-
case _ =>
929-
// no value parameters, so not an infix operator.
930-
ext.paramss ++ mdef.paramss
931-
else
932-
ext.paramss ++ mdef.paramss
933-
).withMods(mdef.mods | ExtensionMethod)
934-
)
943+
ext.methods map {
944+
case exp: Export => exp
945+
case mdef: DefDef => defDef(extMethod(mdef, ext.paramss))
946+
}
935947
}
936-
937948
/** Transforms
938949
*
939950
* <mods> type t >: Low <: Hi

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
117117
case class GenAlias(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
118118
case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree])(implicit @constructorOnly src: SourceFile) extends TypTree
119119
case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree)(implicit @constructorOnly src: SourceFile) extends DefTree
120-
case class ExtMethods(paramss: List[ParamClause], methods: List[DefDef])(implicit @constructorOnly src: SourceFile) extends Tree
120+
case class ExtMethods(paramss: List[ParamClause], methods: List[Tree])(implicit @constructorOnly src: SourceFile) extends Tree
121121
case class MacroTree(expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
122122

123123
case class ImportSelector(imported: Ident, renamed: Tree = EmptyTree, bound: Tree = EmptyTree)(implicit @constructorOnly src: SourceFile) extends Tree {
@@ -640,7 +640,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
640640
case tree: PatDef if (mods eq tree.mods) && (pats eq tree.pats) && (tpt eq tree.tpt) && (rhs eq tree.rhs) => tree
641641
case _ => finalize(tree, untpd.PatDef(mods, pats, tpt, rhs)(tree.source))
642642
}
643-
def ExtMethods(tree: Tree)(paramss: List[ParamClause], methods: List[DefDef])(using Context): Tree = tree match
643+
def ExtMethods(tree: Tree)(paramss: List[ParamClause], methods: List[Tree])(using Context): Tree = tree match
644644
case tree: ExtMethods if (paramss eq tree.paramss) && (methods == tree.methods) => tree
645645
case _ => finalize(tree, untpd.ExtMethods(paramss, methods)(tree.source))
646646
def ImportSelector(tree: Tree)(imported: Ident, renamed: Tree, bound: Tree)(using Context): Tree = tree match {

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3123,7 +3123,7 @@ object Parsers {
31233123
/** Import ::= `import' ImportExpr {‘,’ ImportExpr}
31243124
* Export ::= `export' ImportExpr {‘,’ ImportExpr}
31253125
*/
3126-
def importClause(leading: Token, mkTree: ImportConstr): List[Tree] = {
3126+
def importOrExportClause(leading: Token, mkTree: ImportConstr): List[Tree] = {
31273127
val offset = accept(leading)
31283128
commaSeparated(importExpr(mkTree)) match {
31293129
case t :: rest =>
@@ -3136,6 +3136,12 @@ object Parsers {
31363136
}
31373137
}
31383138

3139+
def exportClause() =
3140+
importOrExportClause(EXPORT, Export(_,_))
3141+
3142+
def importClause(outermost: Boolean = false) =
3143+
importOrExportClause(IMPORT, mkImport(outermost))
3144+
31393145
/** Create an import node and handle source version imports */
31403146
def mkImport(outermost: Boolean = false): ImportConstr = (tree, selectors) =>
31413147
val imp = Import(tree, selectors)
@@ -3685,8 +3691,10 @@ object Parsers {
36853691
if in.isColon() then
36863692
syntaxError("no `:` expected here")
36873693
in.nextToken()
3688-
val methods =
3689-
if isDefIntro(modifierTokens) then
3694+
val methods: List[Tree] =
3695+
if in.token == EXPORT then
3696+
exportClause()
3697+
else if isDefIntro(modifierTokens) then
36903698
extMethod(nparams) :: Nil
36913699
else
36923700
in.observeIndented()
@@ -3696,12 +3704,13 @@ object Parsers {
36963704
val result = atSpan(start)(ExtMethods(joinParams(tparams, leadParamss.toList), methods))
36973705
val comment = in.getDocComment(start)
36983706
if comment.isDefined then
3699-
for meth <- methods do
3707+
for case meth: DefDef <- methods do
37003708
if !meth.rawComment.isDefined then meth.setComment(comment)
37013709
result
37023710
end extension
37033711

37043712
/** ExtMethod ::= {Annotation [nl]} {Modifier} ‘def’ DefDef
3713+
* | Export
37053714
*/
37063715
def extMethod(numLeadParams: Int): DefDef =
37073716
val start = in.offset
@@ -3711,16 +3720,18 @@ object Parsers {
37113720

37123721
/** ExtMethods ::= ExtMethod | [nl] ‘{’ ExtMethod {semi ExtMethod ‘}’
37133722
*/
3714-
def extMethods(numLeadParams: Int): List[DefDef] = checkNoEscapingPlaceholders {
3715-
val meths = new ListBuffer[DefDef]
3723+
def extMethods(numLeadParams: Int): List[Tree] = checkNoEscapingPlaceholders {
3724+
val meths = new ListBuffer[Tree]
37163725
while
37173726
val start = in.offset
3718-
val mods = defAnnotsMods(modifierTokens)
3719-
in.token != EOF && {
3720-
accept(DEF)
3721-
meths += defDefOrDcl(start, mods, numLeadParams)
3722-
in.token != EOF && statSepOrEnd(meths, what = "extension method")
3723-
}
3727+
if in.token == EXPORT then
3728+
meths ++= exportClause()
3729+
else
3730+
val mods = defAnnotsMods(modifierTokens)
3731+
if in.token != EOF then
3732+
accept(DEF)
3733+
meths += defDefOrDcl(start, mods, numLeadParams)
3734+
in.token != EOF && statSepOrEnd(meths, what = "extension method")
37243735
do ()
37253736
if meths.isEmpty then syntaxErrorOrIncomplete("`def` expected")
37263737
meths.toList
@@ -3868,9 +3879,9 @@ object Parsers {
38683879
else stats += packaging(start)
38693880
}
38703881
else if (in.token == IMPORT)
3871-
stats ++= importClause(IMPORT, mkImport(outermost))
3882+
stats ++= importClause(outermost)
38723883
else if (in.token == EXPORT)
3873-
stats ++= importClause(EXPORT, Export(_,_))
3884+
stats ++= exportClause()
38743885
else if isIdent(nme.extension) && followingIsExtension() then
38753886
stats += extension()
38763887
else if isDefIntro(modifierTokens) then
@@ -3916,9 +3927,9 @@ object Parsers {
39163927
while
39173928
var empty = false
39183929
if (in.token == IMPORT)
3919-
stats ++= importClause(IMPORT, mkImport())
3930+
stats ++= importClause()
39203931
else if (in.token == EXPORT)
3921-
stats ++= importClause(EXPORT, Export(_,_))
3932+
stats ++= exportClause()
39223933
else if isIdent(nme.extension) && followingIsExtension() then
39233934
stats += extension()
39243935
else if (isDefIntro(modifierTokensOrCase))
@@ -3994,7 +4005,7 @@ object Parsers {
39944005
while
39954006
var empty = false
39964007
if (in.token == IMPORT)
3997-
stats ++= importClause(IMPORT, mkImport())
4008+
stats ++= importClause()
39984009
else if (isExprIntro)
39994010
stats += expr(Location.InBlock)
40004011
else if in.token == IMPLICIT && !in.inModifierPosition() then

0 commit comments

Comments
 (0)