Skip to content

Fix #7669: Implement extended with syntax for extension methods. #7670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 4, 2019
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ object StdNames {
val eval: N = "eval"
val eqlAny: N = "eqlAny"
val ex: N = "ex"
val extended: N = "extended"
val extension: N = "extension"
val experimental: N = "experimental"
val f: N = "f"
Expand Down
127 changes: 57 additions & 70 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3362,16 +3362,6 @@ object Parsers {
Template(constr, parents, Nil, EmptyValDef, Nil)
}

/** Check that `vparamss` represents a legal collective parameter list for a given extension
*/
def checkExtensionParams(start: Offset, vparamss: List[List[ValDef]]): Unit = vparamss match
case (vparam :: Nil) :: vparamss1 if !vparam.mods.is(Given) =>
vparamss1.foreach(_.foreach(vparam =>
if !vparam.mods.is(Given) then
syntaxError(em"follow-on parameter in extension clause must be `given`", vparam.span)))
case _ =>
syntaxError(em"extension clause must start with a single regular parameter", start)

def checkExtensionMethod(tparams: List[Tree], stat: Tree): Unit = stat match {
case stat: DefDef =>
if stat.mods.is(Extension) then
Expand All @@ -3385,22 +3375,23 @@ object Parsers {

/** GivenDef ::= [GivenSig (‘:’ | <:)] Type ‘=’ Expr
* | [GivenSig ‘:’] ConstrApps [[‘with’] TemplateBody]
* | [id ‘:’] ‘extension’ ExtParamClause {GivenParamClause} ExtMethods
* | [id ‘:’] ExtParamClause {GivenParamClause} ‘extended’ ‘with’ ExtMethods
* GivenSig ::= [id] [DefTypeParamClause] {GivenParamClause}
* ExtParamClause ::= [DefTypeParamClause] DefParamClause {GivenParamClause}
* ExtParamClause ::= [DefTypeParamClause] DefParamClause
* ExtMethods ::= [nl] ‘{’ ‘def’ DefDef {semi ‘def’ DefDef} ‘}’
*/
def givenDef(start: Offset, mods: Modifiers, instanceMod: Mod) = atSpan(start, nameStart) {
var mods1 = addMod(mods, instanceMod)
val hasGivenSig = followingIsGivenSig()
val (name, isExtension) =
val nameStart = in.offset
val (name, isOldExtension) =
if isIdent && hasGivenSig then
(ident(), in.token == COLON && in.lookaheadIn(nme.extension))
else
(EmptyTermName, isIdent(nme.extension))

val gdef = in.endMarkerScope(if name.isEmpty then GIVEN else name) {
if isExtension then
if isOldExtension then
if (in.token == COLON) in.nextToken()
assert(ident() == nme.extension)
val tparams = typeParamClauseOpt(ParamOwner.Def)
Expand All @@ -3412,65 +3403,61 @@ object Parsers {
templ.body.foreach(checkExtensionMethod(tparams, _))
ModuleDef(name, templ)
else
var tparams: List[TypeDef] = Nil
var vparamss: List[List[ValDef]] = Nil
var hasExtensionParams = false

def parseParams(isExtension: Boolean): Unit =
if isExtension && (in.token == LBRACKET || in.token == LPAREN) then
hasExtensionParams = true
if tparams.nonEmpty || vparamss.nonEmpty then
syntaxError(i"cannot have parameters before and after `:` in extension")
if in.token == LBRACKET then
tparams = typeParamClause(ParamOwner.Def)
if in.token == LPAREN && followingIsParamOrGivenType() then
val paramsStart = in.offset
vparamss = paramClauses(givenOnly = !isExtension)
if isExtension then
checkExtensionParams(paramsStart, vparamss)

parseParams(isExtension = false)
val parents =
if in.token == COLON then
in.nextToken()
if in.token == LBRACKET
|| in.token == LPAREN && followingIsParamOrGivenType()
then
parseParams(isExtension = true)
Nil
else
constrApps(commaOK = true, templateCanFollow = true)
else if in.token == SUBTYPE then
if !mods.is(Inline) then
syntaxError("`<:' is only allowed for given with `inline' modifier")
in.nextToken()
TypeBoundsTree(EmptyTree, toplevelTyp()) :: Nil
else if name.isEmpty && !hasExtensionParams then
constrApps(commaOK = true, templateCanFollow = true)
val hasLabel = !name.isEmpty && in.token == COLON
if hasLabel then in.nextToken()
val tparams = typeParamClauseOpt(ParamOwner.Def)
val paramsStart = in.offset
val vparamss =
if in.token == LPAREN && followingIsParamOrGivenType()
then paramClauses()
else Nil

if in.token == EQUALS && parents.length == 1 && parents.head.isType then
val isExtension = isIdent(nme.extended)
def checkAllGivens(vparamss: List[List[ValDef]], what: String) =
vparamss.foreach(_.foreach(vparam =>
if !vparam.mods.is(Given) then syntaxError(em"$what must be `given`", vparam.span)))
if isExtension then
if !name.isEmpty && !hasLabel then
syntaxError(em"name $name of extension clause must be followed by `:`", nameStart)
vparamss match
case (vparam :: Nil) :: vparamss1 if !vparam.mods.is(Given) =>
checkAllGivens(vparamss1, "follow-on parameter in extension clause")
case _ =>
syntaxError("extension clause must start with a single regular parameter", paramsStart)
in.nextToken()
mods1 |= Final
DefDef(name, tparams, vparamss, parents.head, subExpr())
accept(WITH)
val (self, stats) = templateBody()
stats.foreach(checkExtensionMethod(tparams, _))
ModuleDef(name, Template(makeConstructor(tparams, vparamss), Nil, Nil, self, stats))
else
parents match
case TypeBoundsTree(_, _) :: _ => syntaxError("`=' expected")
case _ =>
possibleTemplateStart()
if hasExtensionParams then
in.observeIndented()
checkAllGivens(vparamss, "parameter of given instance")
val parents =
if hasLabel then
constrApps(commaOK = true, templateCanFollow = true)
else if in.token == SUBTYPE then
if !mods.is(Inline) then
syntaxError("`<:' is only allowed for given with `inline' modifier")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use `<:` instead of `<:' (same with inline) ? Both conventions seem to be used in the parser, but markdown-like syntax is probably more user-friendly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's discuss which one to use and then standardize. I'll open an issue.

in.nextToken()
TypeBoundsTree(EmptyTree, toplevelTyp()) :: Nil
else
if !(name.isEmpty && tparams.isEmpty && vparamss.isEmpty) then
accept(COLON)
constrApps(commaOK = true, templateCanFollow = true)
if in.token == EQUALS && parents.length == 1 && parents.head.isType then
in.nextToken()
mods1 |= Final
DefDef(name, tparams, vparamss, parents.head, subExpr())
else
tparams = tparams.map(tparam => tparam.withMods(tparam.mods | PrivateLocal))
vparamss = vparamss.map(_.map(vparam =>
parents match
case TypeBoundsTree(_, _) :: _ => syntaxError("`=' expected")
case _ =>
possibleTemplateStart()
val tparams1 = tparams.map(tparam => tparam.withMods(tparam.mods | PrivateLocal))
val vparamss1 = vparamss.map(_.map(vparam =>
vparam.withMods(vparam.mods &~ Param | ParamAccessor | PrivateLocal)))
val templ = templateBodyOpt(makeConstructor(tparams, vparamss), parents, Nil)
if hasExtensionParams then
templ.body.foreach(checkExtensionMethod(tparams, _))
ModuleDef(name, templ)
else if tparams.isEmpty && vparamss.isEmpty then ModuleDef(name, templ)
else TypeDef(name.toTypeName, templ)
}
val templ = templateBodyOpt(makeConstructor(tparams1, vparamss1), parents, Nil)
if tparams.isEmpty && vparamss.isEmpty then ModuleDef(name, templ)
else TypeDef(name.toTypeName, templ)
}
finalizeDef(gdef, mods1, start)
}

Expand Down Expand Up @@ -3547,8 +3534,8 @@ object Parsers {
checkNextNotIndented()
Template(constr, Nil, Nil, EmptyValDef, Nil)

/** TemplateBody ::= [nl | `with'] `{' TemplateStatSeq `}'
* EnumBody ::= [nl | ‘with’] ‘{’ [SelfType] EnumStat {semi EnumStat} ‘}’
/** TemplateBody ::= [nl] `{' TemplateStatSeq `}'
* EnumBody ::= [nl] ‘{’ [SelfType] EnumStat {semi EnumStat} ‘}’
*/
def templateBodyOpt(constr: DefDef, parents: List[Tree], derived: List[Tree]): Template =
val (self, stats) =
Expand Down
4 changes: 2 additions & 2 deletions docs/docs/internals/syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,8 @@ ObjectDef ::= id [Template]
EnumDef ::= id ClassConstr InheritClauses [‘with’] EnumBody EnumDef(mods, name, tparams, template)
GivenDef ::= [GivenSig (‘:’ | <:)] Type ‘=’ Expr
| [GivenSig ‘:’] ConstrApps [[‘with’] TemplateBody]
| [[id ‘:’] ‘extension’ ExtParamClause {GivenParamClause}
ExtMethods
| [id ‘:’] ExtParamClause {GivenParamClause}
‘extended’ ‘with’ ExtMethods
GivenSig ::= [id] [DefTypeParamClause] {GivenParamClause}
ExtParamClause ::= [DefTypeParamClause] ‘(’ DefParam ‘)’
ExtMethods ::= [nl] ‘{’ ‘def’ DefDef {semi ‘def’ DefDef} ‘}’
Expand Down
6 changes: 3 additions & 3 deletions docs/docs/reference/contextual/extension-methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,19 @@ List(1, 2, 3).second[Int]
`given` extensions are given instances that define extension methods and nothing else. Examples:

```scala
given stringOps: extension (xs: Seq[String]) {
given stringOps: (xs: Seq[String]) extended with {
def longestStrings: Seq[String] = {
val maxLength = xs.map(_.length).max
xs.filter(_.length == maxLength)
}
}

given listOps: extension [T](xs: List[T]) {
given listOps: [T](xs: List[T]) extended with {
def second = xs.tail.head
def third: T = xs.tail.tail.head
}

given extension [T](xs: List[T])(given Ordering[T]) {
given [T](xs: List[T])(given Ordering[T]) extended with {
def largest(n: Int) = xs.sorted.takeRight(n)
}
```
Expand Down
2 changes: 1 addition & 1 deletion docs/docs/reference/other-new-features/opaques.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ object Logarithms {
}

// Extension methods define opaque types' public APIs
given logarithmOps: extension (x: Logarithm) {
given logarithmOps: (x: Logarithm) extended with {
def toDouble: Double = math.exp(x)
def + (y: Logarithm): Logarithm = Logarithm(math.exp(x) + math.exp(y))
def * (y: Logarithm): Logarithm = Logarithm(x + y)
Expand Down
4 changes: 2 additions & 2 deletions tests/neg/extmethod-overload.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
object Test {
given a: extension (x: Int)
given a: (x: Int) extended with
def |+| (y: Int) = x + y

given b: extension (x: Int) {
given b: (x: Int) extended with {
def |+| (y: String) = x + y.length
}
assert((1 |+| 2) == 3) // error ambiguous
Expand Down
2 changes: 1 addition & 1 deletion tests/neg/i6801.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
given MyNumericOps: extension [T](x: T) {
given myNumericOps: [T](x: T) extended with {
def + (y: T)(given n: Numeric[T]): T = n.plus(x,y)
}
def foo[T: Numeric](x: T) = 1f + x // error: no implicit argument of type Numeric[Any]
Expand Down
2 changes: 1 addition & 1 deletion tests/neg/i7529.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
given fooOps: extension [A](a: A) with
given fooOps: [A](a: A) extended with

@nonsense // error: not found: nonsense
def foo = ???
4 changes: 2 additions & 2 deletions tests/pos/reference/delegates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ object Instances extends Common with
if (fst != 0) fst else xs1.compareTo(ys1)
end listOrd

given stringOps: extension (xs: Seq[String]) with
given stringOps: (xs: Seq[String]) extended with
def longestStrings: Seq[String] =
val maxLength = xs.map(_.length).max
xs.filter(_.length == maxLength)

given extension [T](xs: List[T])
given [T](xs: List[T]) extended with
def second = xs.tail.head
def third = xs.tail.tail.head

Expand Down
6 changes: 3 additions & 3 deletions tests/pos/reference/extension-methods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ object ExtMethods with

List(1, 2, 3).second[Int]

given stringOps: extension (xs: Seq[String]) {
given stringOps: (xs: Seq[String]) extended with {
def longestStrings: Seq[String] = {
val maxLength = xs.map(_.length).max
xs.filter(_.length == maxLength)
}
}

given listOps: extension [T](xs: List[T]) with
given listOps: [T](xs: List[T]) extended with
def second = xs.tail.head
def third: T = xs.tail.tail.head


given extension [T](xs: List[T])(given Ordering[T]) with
given [T](xs: List[T])(given Ordering[T]) extended with
def largest(n: Int) = xs.sorted.takeRight(n)

given stringOps1: AnyRef {
Expand Down
2 changes: 1 addition & 1 deletion tests/pos/tasty-reflect-opaque-api-proto.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Reflect(val internal: CompilerInterface) {
opaque type Term <: Tree = internal.Term

object Tree {
given Ops: extension (tree: Tree) {
given ops: (tree: Tree) extended with {
def show: String = ???
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/run/extension-specificity.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
class A
class B extends A

given a: extension (x: A) with
given a: (x: A) extended with
def foo: Int = 1

given b: extension (x: B) with
given b: (x: B) extended with
def foo: Int = 2

@main def Test =
Expand Down
4 changes: 2 additions & 2 deletions tests/run/extmethods2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ object Test extends App {
test(given TC())

object A {
given listOps: extension [T](xs: List[T]) {
given listOps: [T](xs: List[T]) extended with {
def second: T = xs.tail.head
def third: T = xs.tail.tail.head
def concat(ys: List[T]) = xs ++ ys
}
given polyListOps: extension [T, U](xs: List[T]) {
given polyListOps: [T, U](xs: List[T]) extended with {
def zipp(ys: List[U]): List[(T, U)] = xs.zip(ys)
}
given extension (xs: List[Int]) {
Expand Down
6 changes: 3 additions & 3 deletions tests/run/instances.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ object Test extends App {

case class Circle(x: Double, y: Double, radius: Double)

given circleOps: extension (c: Circle) with
given circleOps: (c: Circle) extended with
def circumference: Double = c.radius * math.Pi * 2

val circle = new Circle(1, 1, 2.0)

assert(circle.circumference == circleOps.circumference(circle))

given stringOps: extension (xs: Seq[String]) with
given stringOps: (xs: Seq[String]) extended with
def longestStrings: Seq[String] =
val maxLength = xs.map(_.length).max
xs.filter(_.length == maxLength)
Expand All @@ -28,7 +28,7 @@ object Test extends App {

assert(names.longestStrings.second == "world")

given listListOps: extension [T](xs: List[List[T]]) with
given listListOps: [T](xs: List[List[T]]) extended with
def flattened = xs.foldLeft[List[T]](Nil)(_ ++ _)

// A right associative op. Note: can't use given extension for this!
Expand Down