Skip to content

Commit b85bd2d

Browse files
committed
Change syntax and Parser to allow extension instances
1 parent f802fa4 commit b85bd2d

File tree

5 files changed

+39
-18
lines changed

5 files changed

+39
-18
lines changed

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,11 @@ object Parsers {
931931
lookahead.nextToken()
932932
if lookahead.isIdent && !lookahead.isIdent(nme.on) then
933933
lookahead.nextToken()
934+
if lookahead.isNewLine then
935+
lookahead.nextToken()
934936
lookahead.isIdent(nme.on)
937+
|| lookahead.token == LBRACE
938+
|| lookahead.token == COLON
935939

936940
/* --------- OPERAND/OPERATOR STACK --------------------------------------- */
937941

@@ -3470,10 +3474,13 @@ object Parsers {
34703474
Template(constr, parents, Nil, EmptyValDef, Nil)
34713475
}
34723476

3473-
def checkExtensionMethod(tparams: List[Tree], stat: Tree): Unit = stat match {
3477+
def checkExtensionMethod(tparams: List[Tree],
3478+
vparamss: List[List[Tree]], stat: Tree): Unit = stat match {
34743479
case stat: DefDef =>
3475-
if stat.mods.is(Extension) then
3480+
if stat.mods.is(Extension) && vparamss.nonEmpty then
34763481
syntaxError(i"no extension method allowed here since leading parameter was already given", stat.span)
3482+
else if !stat.mods.is(Extension) && vparamss.isEmpty then
3483+
syntaxError(i"an extension method is required here", stat.span)
34773484
else if tparams.nonEmpty && stat.tparams.nonEmpty then
34783485
syntaxError(i"extension method cannot have type parameters since some were already given previously",
34793486
stat.tparams.head.span)
@@ -3527,21 +3534,25 @@ object Parsers {
35273534
finalizeDef(gdef, mods1, start)
35283535
}
35293536

3530-
/** ExtensionDef ::= [id] ‘on’ ExtParamClause {UsingParamClause} ExtMethods
3537+
/** ExtensionDef ::= [id] [‘on’ ExtParamClause {UsingParamClause}] TemplateBody
35313538
*/
35323539
def extensionDef(start: Offset, mods: Modifiers): ModuleDef =
35333540
in.nextToken()
35343541
val name = if isIdent && !isIdent(nme.on) then ident() else EmptyTermName
35353542
in.endMarkerScope(if name.isEmpty then nme.extension else name) {
3536-
if !isIdent(nme.on) then syntaxErrorOrIncomplete("`on` expected")
3537-
if isIdent(nme.on) then in.nextToken()
3538-
val tparams = typeParamClauseOpt(ParamOwner.Def)
3539-
val extParams = paramClause(0, prefix = true)
3540-
val givenParamss = paramClauses(givenOnly = true)
3543+
val (tparams, vparamss) =
3544+
if isIdent(nme.on) then
3545+
in.nextToken()
3546+
val tparams = typeParamClauseOpt(ParamOwner.Def)
3547+
val extParams = paramClause(0, prefix = true)
3548+
val givenParamss = paramClauses(givenOnly = true)
3549+
(tparams, extParams :: givenParamss)
3550+
else
3551+
(Nil, Nil)
35413552
possibleTemplateStart()
35423553
if !in.isNestedStart then syntaxError("Extension without extension methods")
3543-
val templ = templateBodyOpt(makeConstructor(tparams, extParams :: givenParamss), Nil, Nil)
3544-
templ.body.foreach(checkExtensionMethod(tparams, _))
3554+
val templ = templateBodyOpt(makeConstructor(tparams, vparamss), Nil, Nil)
3555+
templ.body.foreach(checkExtensionMethod(tparams, vparamss, _))
35453556
val edef = ModuleDef(name, templ)
35463557
finalizeDef(edef, addFlag(mods, Given), start)
35473558
}

docs/docs/internals/syntax.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,8 @@ EnumDef ::= id ClassConstr InheritClauses EnumBody
386386
GivenDef ::= [GivenSig] [‘_’ ‘<:’] Type ‘=’ Expr
387387
| [GivenSig] ConstrApps [TemplateBody]
388388
GivenSig ::= [id] [DefTypeParamClause] {UsingParamClause} ‘as’
389-
ExtensionDef ::= [id] ‘on’ ExtParamClause {WithParamsOrTypes} ExtMethods
389+
ExtensionDef ::= [id] [‘on’ ExtParamClause {UsingParamClause}]
390+
TemplateBody
390391
ExtMethods ::= [nl] ‘{’ ‘def’ DefDef {semi ‘def’ DefDef} ‘}’
391392
ExtParamClause ::= [DefTypeParamClause] ‘(’ DefParam ‘)’
392393
Template ::= InheritClauses [TemplateBody] Template(constr, parents, self, stats)

tests/pos/reference/extension-methods.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,25 @@ object ExtMethods:
5656
extension on [T](xs: List[T])(using Ordering[T]):
5757
def largest(n: Int) = xs.sorted.takeRight(n)
5858

59-
given stringOps1 as AnyRef {
59+
extension ops:
60+
def (xs: Seq[String]).longestStrings: Seq[String] =
61+
val maxLength = xs.map(_.length).max
62+
xs.filter(_.length == maxLength)
63+
def (xs: Seq[String]).longestString: String = xs.longestStrings.head
64+
def [T](xs: List[T]).second: T = xs.tail.head
65+
66+
extension:
67+
def [T](xs: List[T]) longest (using Ordering[T])(n: Int) =
68+
xs.sorted.takeRight(n)
69+
70+
given stringOps2 as AnyRef {
6071
def (xs: Seq[String]).longestStrings: Seq[String] = {
6172
val maxLength = xs.map(_.length).max
6273
xs.filter(_.length == maxLength)
6374
}
6475
}
6576

66-
given listOps1 as AnyRef {
77+
given listOps2 as AnyRef {
6778
def [T](xs: List[T]) second = xs.tail.head
6879
def [T](xs: List[T]) third: T = xs.tail.tail.head
6980
}

tests/run/extmethod-overload.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ object Test extends App {
2222
// Test with extension methods in given object
2323
object test1 {
2424

25-
given Foo as AnyRef {
25+
extension Foo:
2626
def (x: Int) |+| (y: Int) = x + y
2727
def (x: Int) |+| (y: String) = x + y.length
2828

2929
def [T](xs: List[T]) +++ (ys: List[T]): List[T] = xs ++ ys ++ ys
3030
def [T](xs: List[T]) +++ (ys: Iterator[T]): List[T] = xs ++ ys ++ ys
31-
}
31+
end Foo
3232

3333
assert((1 |+| 2) == 3)
3434
assert((1 |+| "2") == 2)

tests/run/instances.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@ object Test extends App {
3131
extension listListOps on [T](xs: List[List[T]]):
3232
def flattened = xs.foldLeft[List[T]](Nil)(_ ++ _)
3333

34-
// A right associative op. Note: can't use given extension for this!
35-
given prepend as AnyRef {
34+
extension prepend:
3635
def [T](x: T) :: (xs: Seq[T]) = x +: xs
37-
}
3836

3937
val ss: Seq[Int] = List(1, 2, 3)
4038
val ss1 = 0 :: ss

0 commit comments

Comments
 (0)