Skip to content

Commit c143941

Browse files
committed
Fix #7669: Implement extended with syntax
Implement `extended with` syntax for extension methods.
1 parent da9e8ae commit c143941

File tree

14 files changed

+81
-93
lines changed

14 files changed

+81
-93
lines changed

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ object StdNames {
435435
val eval: N = "eval"
436436
val eqlAny: N = "eqlAny"
437437
val ex: N = "ex"
438+
val extended: N = "extended"
438439
val extension: N = "extension"
439440
val experimental: N = "experimental"
440441
val f: N = "f"

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 57 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3362,16 +3362,6 @@ object Parsers {
33623362
Template(constr, parents, Nil, EmptyValDef, Nil)
33633363
}
33643364

3365-
/** Check that `vparamss` represents a legal collective parameter list for a given extension
3366-
*/
3367-
def checkExtensionParams(start: Offset, vparamss: List[List[ValDef]]): Unit = vparamss match
3368-
case (vparam :: Nil) :: vparamss1 if !vparam.mods.is(Given) =>
3369-
vparamss1.foreach(_.foreach(vparam =>
3370-
if !vparam.mods.is(Given) then
3371-
syntaxError(em"follow-on parameter in extension clause must be `given`", vparam.span)))
3372-
case _ =>
3373-
syntaxError(em"extension clause must start with a single regular parameter", start)
3374-
33753365
def checkExtensionMethod(tparams: List[Tree], stat: Tree): Unit = stat match {
33763366
case stat: DefDef =>
33773367
if stat.mods.is(Extension) then
@@ -3385,22 +3375,23 @@ object Parsers {
33853375

33863376
/** GivenDef ::= [GivenSig (‘:’ | <:)] Type ‘=’ Expr
33873377
* | [GivenSig ‘:’] ConstrApps [[‘with’] TemplateBody]
3388-
* | [id ‘:’] ‘extension’ ExtParamClause {GivenParamClause} ExtMethods
3378+
* | [id ‘:’] ExtParamClause {GivenParamClause} ‘extended’ ‘with’ ExtMethods
33893379
* GivenSig ::= [id] [DefTypeParamClause] {GivenParamClause}
3390-
* ExtParamClause ::= [DefTypeParamClause] DefParamClause {GivenParamClause}
3380+
* ExtParamClause ::= [DefTypeParamClause] DefParamClause
33913381
* ExtMethods ::= [nl] ‘{’ ‘def’ DefDef {semi ‘def’ DefDef} ‘}’
33923382
*/
33933383
def givenDef(start: Offset, mods: Modifiers, instanceMod: Mod) = atSpan(start, nameStart) {
33943384
var mods1 = addMod(mods, instanceMod)
33953385
val hasGivenSig = followingIsGivenSig()
3396-
val (name, isExtension) =
3386+
val nameStart = in.offset
3387+
val (name, isOldExtension) =
33973388
if isIdent && hasGivenSig then
33983389
(ident(), in.token == COLON && in.lookaheadIn(nme.extension))
33993390
else
34003391
(EmptyTermName, isIdent(nme.extension))
34013392

34023393
val gdef = in.endMarkerScope(if name.isEmpty then GIVEN else name) {
3403-
if isExtension then
3394+
if isOldExtension then
34043395
if (in.token == COLON) in.nextToken()
34053396
assert(ident() == nme.extension)
34063397
val tparams = typeParamClauseOpt(ParamOwner.Def)
@@ -3412,65 +3403,61 @@ object Parsers {
34123403
templ.body.foreach(checkExtensionMethod(tparams, _))
34133404
ModuleDef(name, templ)
34143405
else
3415-
var tparams: List[TypeDef] = Nil
3416-
var vparamss: List[List[ValDef]] = Nil
3417-
var hasExtensionParams = false
3418-
3419-
def parseParams(isExtension: Boolean): Unit =
3420-
if isExtension && (in.token == LBRACKET || in.token == LPAREN) then
3421-
hasExtensionParams = true
3422-
if tparams.nonEmpty || vparamss.nonEmpty then
3423-
syntaxError(i"cannot have parameters before and after `:` in extension")
3424-
if in.token == LBRACKET then
3425-
tparams = typeParamClause(ParamOwner.Def)
3426-
if in.token == LPAREN && followingIsParamOrGivenType() then
3427-
val paramsStart = in.offset
3428-
vparamss = paramClauses(givenOnly = !isExtension)
3429-
if isExtension then
3430-
checkExtensionParams(paramsStart, vparamss)
3431-
3432-
parseParams(isExtension = false)
3433-
val parents =
3434-
if in.token == COLON then
3435-
in.nextToken()
3436-
if in.token == LBRACKET
3437-
|| in.token == LPAREN && followingIsParamOrGivenType()
3438-
then
3439-
parseParams(isExtension = true)
3440-
Nil
3441-
else
3442-
constrApps(commaOK = true, templateCanFollow = true)
3443-
else if in.token == SUBTYPE then
3444-
if !mods.is(Inline) then
3445-
syntaxError("`<:' is only allowed for given with `inline' modifier")
3446-
in.nextToken()
3447-
TypeBoundsTree(EmptyTree, toplevelTyp()) :: Nil
3448-
else if name.isEmpty && !hasExtensionParams then
3449-
constrApps(commaOK = true, templateCanFollow = true)
3406+
val hasLabel = !name.isEmpty && in.token == COLON
3407+
if hasLabel then in.nextToken()
3408+
val tparams = typeParamClauseOpt(ParamOwner.Def)
3409+
val paramsStart = in.offset
3410+
val vparamss =
3411+
if in.token == LPAREN && followingIsParamOrGivenType()
3412+
then paramClauses()
34503413
else Nil
3451-
3452-
if in.token == EQUALS && parents.length == 1 && parents.head.isType then
3414+
val isExtension = isIdent(nme.extended)
3415+
def checkAllGivens(vparamss: List[List[ValDef]], what: String) =
3416+
vparamss.foreach(_.foreach(vparam =>
3417+
if !vparam.mods.is(Given) then syntaxError(em"$what must be `given`", vparam.span)))
3418+
if isExtension then
3419+
if !name.isEmpty && !hasLabel then
3420+
syntaxError(em"name $name of extension clause must be followed by `:`", nameStart)
3421+
vparamss match
3422+
case (vparam :: Nil) :: vparamss1 if !vparam.mods.is(Given) =>
3423+
checkAllGivens(vparamss1, "follow-on parameter in extension clause")
3424+
case _ =>
3425+
syntaxError("extension clause must start with a single regular parameter", paramsStart)
34533426
in.nextToken()
3454-
mods1 |= Final
3455-
DefDef(name, tparams, vparamss, parents.head, subExpr())
3427+
accept(WITH)
3428+
val (self, stats) = templateBody()
3429+
stats.foreach(checkExtensionMethod(tparams, _))
3430+
ModuleDef(name, Template(makeConstructor(tparams, vparamss), Nil, Nil, self, stats))
34563431
else
3457-
parents match
3458-
case TypeBoundsTree(_, _) :: _ => syntaxError("`=' expected")
3459-
case _ =>
3460-
possibleTemplateStart()
3461-
if hasExtensionParams then
3462-
in.observeIndented()
3432+
checkAllGivens(vparamss, "parameter of given instance")
3433+
val parents =
3434+
if hasLabel then
3435+
constrApps(commaOK = true, templateCanFollow = true)
3436+
else if in.token == SUBTYPE then
3437+
if !mods.is(Inline) then
3438+
syntaxError("`<:' is only allowed for given with `inline' modifier")
3439+
in.nextToken()
3440+
TypeBoundsTree(EmptyTree, toplevelTyp()) :: Nil
3441+
else
3442+
if !(name.isEmpty && tparams.isEmpty && vparamss.isEmpty) then
3443+
accept(COLON)
3444+
constrApps(commaOK = true, templateCanFollow = true)
3445+
if in.token == EQUALS && parents.length == 1 && parents.head.isType then
3446+
in.nextToken()
3447+
mods1 |= Final
3448+
DefDef(name, tparams, vparamss, parents.head, subExpr())
34633449
else
3464-
tparams = tparams.map(tparam => tparam.withMods(tparam.mods | PrivateLocal))
3465-
vparamss = vparamss.map(_.map(vparam =>
3450+
parents match
3451+
case TypeBoundsTree(_, _) :: _ => syntaxError("`=' expected")
3452+
case _ =>
3453+
possibleTemplateStart()
3454+
val tparams1 = tparams.map(tparam => tparam.withMods(tparam.mods | PrivateLocal))
3455+
val vparamss1 = vparamss.map(_.map(vparam =>
34663456
vparam.withMods(vparam.mods &~ Param | ParamAccessor | PrivateLocal)))
3467-
val templ = templateBodyOpt(makeConstructor(tparams, vparamss), parents, Nil)
3468-
if hasExtensionParams then
3469-
templ.body.foreach(checkExtensionMethod(tparams, _))
3470-
ModuleDef(name, templ)
3471-
else if tparams.isEmpty && vparamss.isEmpty then ModuleDef(name, templ)
3472-
else TypeDef(name.toTypeName, templ)
3473-
}
3457+
val templ = templateBodyOpt(makeConstructor(tparams1, vparamss1), parents, Nil)
3458+
if tparams.isEmpty && vparamss.isEmpty then ModuleDef(name, templ)
3459+
else TypeDef(name.toTypeName, templ)
3460+
}
34743461
finalizeDef(gdef, mods1, start)
34753462
}
34763463

@@ -3547,8 +3534,8 @@ object Parsers {
35473534
checkNextNotIndented()
35483535
Template(constr, Nil, Nil, EmptyValDef, Nil)
35493536

3550-
/** TemplateBody ::= [nl | `with'] `{' TemplateStatSeq `}'
3551-
* EnumBody ::= [nl | ‘with’] ‘{’ [SelfType] EnumStat {semi EnumStat} ‘}’
3537+
/** TemplateBody ::= [nl] `{' TemplateStatSeq `}'
3538+
* EnumBody ::= [nl] ‘{’ [SelfType] EnumStat {semi EnumStat} ‘}’
35523539
*/
35533540
def templateBodyOpt(constr: DefDef, parents: List[Tree], derived: List[Tree]): Template =
35543541
val (self, stats) =

docs/docs/internals/syntax.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,8 @@ ObjectDef ::= id [Template]
386386
EnumDef ::= id ClassConstr InheritClauses [‘with’] EnumBody EnumDef(mods, name, tparams, template)
387387
GivenDef ::= [GivenSig (‘:’ | <:)] Type ‘=’ Expr
388388
| [GivenSig ‘:’] ConstrApps [[‘with’] TemplateBody]
389-
| [[id ‘:’] ‘extension’ ExtParamClause {GivenParamClause}
390-
ExtMethods
389+
| [id ‘:’] ExtParamClause {GivenParamClause}
390+
‘extended’ ‘with’ ExtMethods
391391
GivenSig ::= [id] [DefTypeParamClause] {GivenParamClause}
392392
ExtParamClause ::= [DefTypeParamClause] ‘(’ DefParam ‘)’
393393
ExtMethods ::= [nl] ‘{’ ‘def’ DefDef {semi ‘def’ DefDef} ‘}’

docs/docs/reference/contextual/extension-methods.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,19 +126,19 @@ List(1, 2, 3).second[Int]
126126
`given` extensions are given instances that define extension methods and nothing else. Examples:
127127

128128
```scala
129-
given stringOps: extension (xs: Seq[String]) {
129+
given stringOps: (xs: Seq[String]) extended with {
130130
def longestStrings: Seq[String] = {
131131
val maxLength = xs.map(_.length).max
132132
xs.filter(_.length == maxLength)
133133
}
134134
}
135135

136-
given listOps: extension [T](xs: List[T]) {
136+
given listOps: [T](xs: List[T]) extended with {
137137
def second = xs.tail.head
138138
def third: T = xs.tail.tail.head
139139
}
140140

141-
given extension [T](xs: List[T])(given Ordering[T]) {
141+
given [T](xs: List[T])(given Ordering[T]) extended with {
142142
def largest(n: Int) = xs.sorted.takeRight(n)
143143
}
144144
```

docs/docs/reference/other-new-features/opaques.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ object Logarithms {
2020
}
2121

2222
// Extension methods define opaque types' public APIs
23-
given logarithmOps: extension (x: Logarithm) {
23+
given logarithmOps: (x: Logarithm) extended with {
2424
def toDouble: Double = math.exp(x)
2525
def + (y: Logarithm): Logarithm = Logarithm(math.exp(x) + math.exp(y))
2626
def * (y: Logarithm): Logarithm = Logarithm(x + y)

tests/neg/extmethod-overload.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
object Test {
2-
given a: extension (x: Int)
2+
given a: (x: Int) extended with
33
def |+| (y: Int) = x + y
44

5-
given b: extension (x: Int) {
5+
given b: (x: Int) extended with {
66
def |+| (y: String) = x + y.length
77
}
88
assert((1 |+| 2) == 3) // error ambiguous

tests/neg/i6801.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
given MyNumericOps: extension [T](x: T) {
1+
given myNumericOps: [T](x: T) extended with {
22
def + (y: T)(given n: Numeric[T]): T = n.plus(x,y)
33
}
44
def foo[T: Numeric](x: T) = 1f + x // error: no implicit argument of type Numeric[Any]

tests/neg/i7529.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
given fooOps: extension [A](a: A) with
1+
given fooOps: [A](a: A) extended with
22

33
@nonsense // error: not found: nonsense
44
def foo = ???

tests/pos/reference/delegates.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ object Instances extends Common with
3939
if (fst != 0) fst else xs1.compareTo(ys1)
4040
end listOrd
4141

42-
given stringOps: extension (xs: Seq[String]) with
42+
given stringOps: (xs: Seq[String]) extended with
4343
def longestStrings: Seq[String] =
4444
val maxLength = xs.map(_.length).max
4545
xs.filter(_.length == maxLength)
4646

47-
given extension [T](xs: List[T])
47+
given [T](xs: List[T]) extended with
4848
def second = xs.tail.head
4949
def third = xs.tail.tail.head
5050

tests/pos/reference/extension-methods.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,19 @@ object ExtMethods with
4141

4242
List(1, 2, 3).second[Int]
4343

44-
given stringOps: extension (xs: Seq[String]) {
44+
given stringOps: (xs: Seq[String]) extended with {
4545
def longestStrings: Seq[String] = {
4646
val maxLength = xs.map(_.length).max
4747
xs.filter(_.length == maxLength)
4848
}
4949
}
5050

51-
given listOps: extension [T](xs: List[T]) with
51+
given listOps: [T](xs: List[T]) extended with
5252
def second = xs.tail.head
5353
def third: T = xs.tail.tail.head
5454

5555

56-
given extension [T](xs: List[T])(given Ordering[T]) with
56+
given [T](xs: List[T])(given Ordering[T]) extended with
5757
def largest(n: Int) = xs.sorted.takeRight(n)
5858

5959
given stringOps1: AnyRef {

tests/pos/tasty-reflect-opaque-api-proto.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class Reflect(val internal: CompilerInterface) {
1010
opaque type Term <: Tree = internal.Term
1111

1212
object Tree {
13-
given Ops: extension (tree: Tree) {
13+
given ops: (tree: Tree) extended with {
1414
def show: String = ???
1515
}
1616
}

tests/run/extension-specificity.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
class A
22
class B extends A
33

4-
given a: extension (x: A) with
4+
given a: (x: A) extended with
55
def foo: Int = 1
66

7-
given b: extension (x: B) with
7+
given b: (x: B) extended with
88
def foo: Int = 2
99

1010
@main def Test =

tests/run/extmethods2.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ object Test extends App {
1616
test(given TC())
1717

1818
object A {
19-
given listOps: extension [T](xs: List[T]) {
19+
given listOps: [T](xs: List[T]) extended with {
2020
def second: T = xs.tail.head
2121
def third: T = xs.tail.tail.head
2222
def concat(ys: List[T]) = xs ++ ys
2323
}
24-
given polyListOps: extension [T, U](xs: List[T]) {
24+
given polyListOps: [T, U](xs: List[T]) extended with {
2525
def zipp(ys: List[U]): List[(T, U)] = xs.zip(ys)
2626
}
2727
given extension (xs: List[Int]) {

tests/run/instances.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ object Test extends App {
88

99
case class Circle(x: Double, y: Double, radius: Double)
1010

11-
given circleOps: extension (c: Circle) with
11+
given circleOps: (c: Circle) extended with
1212
def circumference: Double = c.radius * math.Pi * 2
1313

1414
val circle = new Circle(1, 1, 2.0)
1515

1616
assert(circle.circumference == circleOps.circumference(circle))
1717

18-
given stringOps: extension (xs: Seq[String]) with
18+
given stringOps: (xs: Seq[String]) extended with
1919
def longestStrings: Seq[String] =
2020
val maxLength = xs.map(_.length).max
2121
xs.filter(_.length == maxLength)
@@ -28,7 +28,7 @@ object Test extends App {
2828

2929
assert(names.longestStrings.second == "world")
3030

31-
given listListOps: extension [T](xs: List[List[T]]) with
31+
given listListOps: [T](xs: List[List[T]]) extended with
3232
def flattened = xs.foldLeft[List[T]](Nil)(_ ++ _)
3333

3434
// A right associative op. Note: can't use given extension for this!

0 commit comments

Comments
 (0)