Skip to content

Commit da9bdfb

Browse files
Change structure of ParameterInfos
- Based on suggestion by Nicolas Stucki, here: https://gist.github.com/nicolasstucki/84ebcd5c2cfc9aa14abba96ae1a0e996 - Add ability to pass multiple ParameterAnnotations
1 parent a142ceb commit da9bdfb

File tree

6 files changed

+167
-95
lines changed

6 files changed

+167
-95
lines changed

compiler/src/dotty/tools/dotc/ast/MainProxies.scala

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ import Annotations.Annotation
4747
*/
4848
object MainProxies {
4949
private type DefaultValueSymbols = Map[Int, Symbol]
50-
private type ParameterAnnotations = Vector[Option[Annotation]]
50+
private type ParameterAnnotationss = Seq[Seq[Annotation]]
5151

5252
def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
5353
import tpd._
@@ -69,19 +69,15 @@ object MainProxies {
6969
}
7070

7171
/** Computes the list of main methods present in the code. */
72-
def mainMethods(scope: Tree, stats: List[Tree]): List[(Symbol, ParameterAnnotations, DefaultValueSymbols, Option[Comment])] = stats.flatMap {
72+
def mainMethods(scope: Tree, stats: List[Tree]): List[(Symbol, ParameterAnnotationss, DefaultValueSymbols, Option[Comment])] = stats.flatMap {
7373
case stat: DefDef =>
7474
val sym = stat.symbol
7575
sym.annotations.filter(_.matches(defn.MainAnnot)) match {
7676
case Nil =>
7777
Nil
7878
case _ :: Nil =>
7979
val paramAnnotations = stat.paramss.flatMap(_.map(
80-
valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotParameterAnnotation)) match {
81-
case Nil => None
82-
case paramAnnot :: Nil => Some(paramAnnot)
83-
case paramAnnot :: others => report.error(s"parameters cannot have multiple annotations", paramAnnot.tree); None
84-
}
80+
valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotParameterAnnotation))
8581
))
8682
(sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
8783
case mainAnnot :: others =>
@@ -99,7 +95,7 @@ object MainProxies {
9995
}
10096

10197
import untpd._
102-
def mainProxy(mainFun: Symbol, paramAnnotations: ParameterAnnotations, defaultValueSymbols: DefaultValueSymbols, docComment: Option[Comment])(using Context): List[TypeDef] = {
98+
def mainProxy(mainFun: Symbol, paramAnnotations: ParameterAnnotationss, defaultValueSymbols: DefaultValueSymbols, docComment: Option[Comment])(using Context): List[TypeDef] = {
10399
val mainAnnot = mainFun.getAnnotation(defn.MainAnnot).get
104100
def pos = mainFun.sourcePos
105101
val mainArgsName: TermName = nme.args
@@ -111,6 +107,11 @@ object MainProxies {
111107

112108
inline def some(value: Tree): Tree = Apply(ref(defn.SomeClass.companionModule.termRef), value)
113109

110+
def unitToValue(value: Tree): Tree =
111+
val anonName = nme.ANON_FUN
112+
val defdef = DefDef(anonName, List(Nil), TypeTree(), value)
113+
Block(defdef, Closure(Nil, Ident(anonName), EmptyTree))
114+
114115
/**
115116
* Creates a list of references and definitions of arguments, the first referencing the second.
116117
* The goal is to create the
@@ -150,27 +151,24 @@ object MainProxies {
150151
/*
151152
* Assignations to be made after the creation of the ParameterInfos.
152153
* For example:
153-
* args0paramInfos.documentation = Some("my param x")
154+
* args0paramInfos.withDocumentation = Some("my param x")
154155
* is represented by the pair
155-
* ("documentation", some(lit("my param x")))
156+
* (defn.MainAnnotationParameterInfos_withDocumentation, some(lit("my param x")))
156157
*/
157-
var assignations: List[(String, Tree)] = Nil
158+
var assignations: List[(Symbol, List[Tree])] = Nil
158159
for (dvSym <- defaultValueSymbols.get(n))
159-
assignations = ("defaultValue" -> some(ref(dvSym.termRef))) :: assignations
160-
for (annot <- paramAnnotations(n))
161-
assignations = ("annotation" -> some(instanciateAnnotation(annot))) :: assignations
160+
assignations = (defn.MainAnnotationParameterInfos_withDefaultValue -> List(unitToValue(ref(dvSym.termRef)))) :: assignations
162161
for (doc <- documentation.argDocs.get(param))
163-
assignations = ("documentation" -> some(lit(doc))) :: assignations
162+
assignations = (defn.MainAnnotationParameterInfos_withDocumentation -> List(lit(doc))) :: assignations
164163

165-
val assignationsTrees = assignations.map{
166-
case (name, value) => Apply(Select(paramInfosIdent, defn.MainAnnotParameterInfos.requiredMethod(name + "_=").name), value)
167-
}
164+
val instanciatedAnnots = paramAnnotations(n).map(instanciateAnnotation).toList
165+
if instanciatedAnnots.nonEmpty then
166+
assignations = (defn.MainAnnotationParameterInfos_withAnnotations -> instanciatedAnnots) :: assignations
168167

169168
if assignations.isEmpty then
170169
paramInfosTree
171170
else
172-
val paramInfosInstance = ValDef(paramInfosName, TypeTree(), paramInfosTree)
173-
Block(paramInfosInstance :: assignationsTrees, paramInfosIdent)
171+
assignations.foldLeft[Tree](paramInfosTree){ case (tree, (setterSym, values)) => Apply(Select(tree, setterSym.name), values) }
174172
}
175173

176174
val argDef = ValDef(

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,9 @@ class Definitions {
859859
@tu lazy val MainAnnot: ClassSymbol = requiredClass("scala.annotation.MainAnnotation")
860860
@tu lazy val MainAnnot_command: Symbol = MainAnnot.requiredMethod("command")
861861
@tu lazy val MainAnnotParameterInfos: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterInfos")
862+
@tu lazy val MainAnnotationParameterInfos_withDefaultValue: Symbol = MainAnnotParameterInfos.requiredMethod("withDefaultValue")
863+
@tu lazy val MainAnnotationParameterInfos_withDocumentation: Symbol = MainAnnotParameterInfos.requiredMethod("withDocumentation")
864+
@tu lazy val MainAnnotationParameterInfos_withAnnotations: Symbol = MainAnnotParameterInfos.requiredMethod("withAnnotations")
862865
@tu lazy val MainAnnotParameterAnnotation: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterAnnotation")
863866
@tu lazy val MainAnnotCommand: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Command")
864867
@tu lazy val MainAnnotCommand_argGetter: Symbol = MainAnnotCommand.requiredMethod("argGetter")

library/src/scala/annotation/MainAnnotation.scala

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,34 @@ trait MainAnnotation extends StaticAnnotation:
2525
end MainAnnotation
2626

2727
object MainAnnotation:
28-
/**
29-
* The information related to one of the parameters of the annotated method.
30-
* @param name the name of the parameter
31-
* @param typeName the name of the parameter's type
32-
* @tparam T the type of the parameter
33-
*/
34-
class ParameterInfos[T](var name: String, var typeName: String):
28+
// Inspired by https://github.com/scala-js/scala-js/blob/0708917912938714d52be1426364f78a3d1fd269/linker-interface/shared/src/main/scala/org/scalajs/linker/interface/StandardConfig.scala#L23-L218
29+
final class ParameterInfos[T] private (
30+
/** The name of the parameter */
31+
val name: String,
32+
/** The name of the parameter's type */
33+
val typeName: String,
3534
/** The docstring of the parameter. Defaults to None. */
36-
var documentation: Option[String] = None
35+
val documentation: Option[String],
3736
/** The default value that the parameter has. Defaults to None. */
38-
var defaultValue: Option[T] = None
39-
/** If there is one, the ParameterAnnotation associated with the parameter. Defaults to None. */
40-
var annotation: Option[ParameterAnnotation] = None
37+
val defaultValueOpt: Option[() => T],
38+
/** The ParameterAnnotations associated with the parameter. Defaults to Seq.empty. */
39+
val annotations: Seq[ParameterAnnotation],
40+
) {
41+
// Main public constructor
42+
def this(name: String, typeName: String) =
43+
this(name, typeName, None, None, Seq.empty)
44+
45+
def withDefaultValue(defaultValueGetter: () => T): ParameterInfos[T] =
46+
new ParameterInfos(name, typeName, documentation, Some(defaultValueGetter), annotations)
47+
48+
def withDocumentation(doc: String): ParameterInfos[T] =
49+
new ParameterInfos(name, typeName, Some(doc), defaultValueOpt, annotations)
50+
51+
def withAnnotations(annots: ParameterAnnotation*): ParameterInfos[T] =
52+
new ParameterInfos(name, typeName, documentation, defaultValueOpt, annots)
53+
54+
override def toString: String = s"$name: $typeName"
55+
}
4156

4257
/** A class representing a command to run */
4358
trait Command[ArgumentParser[_], MainResultType]:

library/src/scala/main.scala

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -174,17 +174,11 @@ final class main(maxLineLength: Int) extends MainAnnotation:
174174
error(s"more than one value for $argName: ${multValues.mkString(", ")}")
175175
}
176176

177-
private def getAnnotationData[T](paramInfos: ParameterInfos[_], extractor: Arg => T): Option[T] =
178-
paramInfos.annotation match {
179-
case Some(annot: Arg) => Some(extractor(annot))
180-
case _ => None
181-
}
182-
183177
private inline def getEffectiveName(paramInfos: ParameterInfos[_]): String =
184-
getAnnotationData(paramInfos, _.name).filter(_.length > 0).getOrElse(paramInfos.name)
178+
paramInfos.annotations.collectFirst{ case arg: Arg if arg.name.length > 0 => arg.name }.getOrElse(paramInfos.name)
185179

186180
private inline def getShortName(paramInfos: ParameterInfos[_]): Option[Char] =
187-
getAnnotationData(paramInfos, _.shortName).filterNot(_ == 0)
181+
paramInfos.annotations.collectFirst{ case arg: Arg if arg.shortName != 0 => arg.shortName }
188182

189183
private def registerArg(paramInfos: ParameterInfos[_], argKind: ArgumentKind): Unit =
190184
argNames += getEffectiveName(paramInfos)
@@ -198,8 +192,8 @@ final class main(maxLineLength: Int) extends MainAnnotation:
198192

199193
override def argGetter[T](paramInfos: ParameterInfos[T])(using p: ArgumentParser[T]): () => T =
200194
val name = getEffectiveName(paramInfos)
201-
val (defaultGetter, argumentKind) = paramInfos.defaultValue match {
202-
case Some(value) => (() => () => value, ArgumentKind.OptionalArgument)
195+
val (defaultGetter, argumentKind) = paramInfos.defaultValueOpt match {
196+
case Some(value) => (() => value, ArgumentKind.OptionalArgument)
203197
case None => (() => error(s"missing argument for $name"), ArgumentKind.SimpleArgument)
204198
}
205199
registerArg(paramInfos, argumentKind)

tests/run/main-annotation-homemade-annot-1.scala

Lines changed: 58 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ object Test:
2121
end Test
2222

2323
class mainAwait(timeout: Int = 2) extends MainAnnotation:
24-
self =>
24+
import MainAnnotation._
25+
import main.{Arg}
2526

2627
private val maxLineLength = 120
2728

@@ -33,8 +34,12 @@ class mainAwait(timeout: Int = 2) extends MainAnnotation:
3334
}
3435

3536
override def command(args: Array[String], commandName: String, docComment: String) =
36-
new MainAnnotation.Command[ArgumentParser, MainResultType]:
37+
new Command[ArgumentParser, MainResultType]:
38+
private val argMarker = "--"
39+
private val shortArgMarker = "-"
40+
3741
private var argNames = new mutable.ArrayBuffer[String]
42+
private var argShortNames = new mutable.ArrayBuffer[Option[Char]]
3843
private var argTypes = new mutable.ArrayBuffer[String]
3944
private var argDocs = new mutable.ArrayBuffer[String]
4045
private var argKinds = new mutable.ArrayBuffer[ArgumentKind]
@@ -53,23 +58,34 @@ class mainAwait(timeout: Int = 2) extends MainAnnotation:
5358
private def argAt(idx: Int): Option[String] =
5459
if idx < args.length then Some(args(idx)) else None
5560

61+
private def isArgNameAt(idx: Int): Boolean =
62+
val arg = args(argIdx)
63+
val isFullName = arg.startsWith(argMarker)
64+
val isShortName = arg.startsWith(shortArgMarker) && arg.length == 2 && shortNameIsValid(arg(1))
65+
66+
isFullName || isShortName
67+
5668
private def nextPositionalArg(): Option[String] =
57-
while argIdx < args.length && args(argIdx).startsWith("--") do argIdx += 2
69+
while argIdx < args.length && isArgNameAt(argIdx) do argIdx += 2
5870
val result = argAt(argIdx)
5971
argIdx += 1
6072
result
6173

74+
private def shortNameIsValid(shortName: Char): Boolean =
75+
shortName == 0 || shortName.isLetter
76+
6277
private def convert[T](argName: String, arg: String, p: ArgumentParser[T]): () => T =
6378
p.fromStringOption(arg) match
6479
case Some(t) => () => t
6580
case None => error(s"invalid argument for $argName: $arg")
6681

6782
private def argUsage(pos: Int): String =
6883
val name = argNames(pos)
84+
val namePrint = argShortNames(pos).map(short => s"[$shortArgMarker$short | $argMarker$name]").getOrElse(s"[$argMarker$name]")
6985

7086
argKinds(pos) match {
71-
case ArgumentKind.SimpleArgument => s"[--$name] <${argTypes(pos)}>"
72-
case ArgumentKind.OptionalArgument => s"[[--$name] <${argTypes(pos)}>]"
87+
case ArgumentKind.SimpleArgument => s"$namePrint <${argTypes(pos)}>"
88+
case ArgumentKind.OptionalArgument => s"[$namePrint <${argTypes(pos)}>]"
7389
case ArgumentKind.VarArgument => s"[<${argTypes(pos)}> [<${argTypes(pos)}> [...]]]"
7490
}
7591

@@ -137,19 +153,18 @@ class mainAwait(timeout: Int = 2) extends MainAnnotation:
137153
println(argDoc)
138154
}
139155

140-
private def indicesOfArg(argName: String): Seq[Int] =
141-
def allIndicesOf(s: String): Seq[Int] =
142-
def recurse(s: String, from: Int): Seq[Int] =
143-
val i = args.indexOf(s, from)
144-
if i < 0 then Seq() else i +: recurse(s, i + 1)
145-
146-
recurse(s, 0)
156+
private def indicesOfArg(argName: String, shortArgName: Option[Char]): Seq[Int] =
157+
def allIndicesOf(s: String, from: Int): Seq[Int] =
158+
val i = args.indexOf(s, from)
159+
if i < 0 then Seq() else i +: allIndicesOf(s, i + 1)
147160

148-
val indices = allIndicesOf(s"--$argName")
149-
indices.filter(_ >= 0)
161+
val indices = allIndicesOf(s"$argMarker$argName", 0)
162+
val indicesShort = shortArgName.map(shortName => allIndicesOf(s"$shortArgMarker$shortName", 0)).getOrElse(Seq())
163+
(indices ++: indicesShort).filter(_ >= 0)
150164

151-
private def getArgGetter[T](argName: String, getDefaultGetter: () => () => T)(using p: ArgumentParser[T]): () => T =
152-
indicesOfArg(argName) match {
165+
private def getArgGetter[T](paramInfos: ParameterInfos[_], getDefaultGetter: () => () => T)(using p: ArgumentParser[T]): () => T =
166+
val argName = getEffectiveName(paramInfos)
167+
indicesOfArg(argName, getShortName(paramInfos)) match {
153168
case s @ (Seq() | Seq(_)) =>
154169
val argOpt = s.headOption.map(idx => argAt(idx + 1)).getOrElse(nextPositionalArg())
155170
argOpt match {
@@ -161,48 +176,64 @@ class mainAwait(timeout: Int = 2) extends MainAnnotation:
161176
error(s"more than one value for $argName: ${multValues.mkString(", ")}")
162177
}
163178

164-
private def registerArg(paramInfos: MainAnnotation.ParameterInfos[_], argKind: ArgumentKind): Unit =
165-
argNames += paramInfos.name
179+
private inline def getEffectiveName(paramInfos: ParameterInfos[_]): String =
180+
paramInfos.annotations.collectFirst{ case arg: Arg if arg.name.length > 0 => arg.name }.getOrElse(paramInfos.name)
181+
182+
private inline def getShortName(paramInfos: ParameterInfos[_]): Option[Char] =
183+
paramInfos.annotations.collectFirst{ case arg: Arg if arg.shortName != 0 => arg.shortName }
184+
185+
private def registerArg(paramInfos: ParameterInfos[_], argKind: ArgumentKind): Unit =
186+
argNames += getEffectiveName(paramInfos)
166187
argTypes += paramInfos.typeName
167188
argDocs += paramInfos.documentation.getOrElse("")
168189
argKinds += argKind
169190

170-
override def argGetter[T](paramInfos: MainAnnotation.ParameterInfos[T])(using p: ArgumentParser[T]): () => T =
171-
val name = paramInfos.name
172-
val (defaultGetter, argumentKind) = paramInfos.defaultValue match {
173-
case Some(value) => (() => () => value, ArgumentKind.OptionalArgument)
191+
val shortName = getShortName(paramInfos)
192+
shortName.foreach(c => if !shortNameIsValid(c) then throw IllegalArgumentException(s"Invalid short name: $shortArgMarker$c"))
193+
argShortNames += shortName
194+
195+
override def argGetter[T](paramInfos: ParameterInfos[T])(using p: ArgumentParser[T]): () => T =
196+
val name = getEffectiveName(paramInfos)
197+
val (defaultGetter, argumentKind) = paramInfos.defaultValueOpt match {
198+
case Some(value) => (() => value, ArgumentKind.OptionalArgument)
174199
case None => (() => error(s"missing argument for $name"), ArgumentKind.SimpleArgument)
175200
}
176201
registerArg(paramInfos, argumentKind)
177-
getArgGetter(name, defaultGetter)
202+
getArgGetter(paramInfos, defaultGetter)
178203

179-
override def varargGetter[T](paramInfos: MainAnnotation.ParameterInfos[T])(using p: ArgumentParser[T]): () => Seq[T] =
204+
override def varargGetter[T](paramInfos: ParameterInfos[T])(using p: ArgumentParser[T]): () => Seq[T] =
180205
registerArg(paramInfos, ArgumentKind.VarArgument)
181206
def remainingArgGetters(): List[() => T] = nextPositionalArg() match
182-
case Some(arg) => convert(paramInfos.name, arg, p) :: remainingArgGetters()
207+
case Some(arg) => convert(getEffectiveName(paramInfos), arg, p) :: remainingArgGetters()
183208
case None => Nil
184209
val getters = remainingArgGetters()
185210
() => getters.map(_())
186211

187212
override def run(f: => MainResultType): Unit =
213+
def checkShortNamesUnique(): Unit =
214+
val shortNameToIndices = argShortNames.collect{ case Some(short) => short }.zipWithIndex.groupBy(_._1).view.mapValues(_.map(_._2))
215+
for ((shortName, indices) <- shortNameToIndices if indices.length > 1)
216+
error(s"$shortName is used as short name for multiple parameters: ${indices.map(idx => argNames(idx)).mkString(", ")}")
217+
188218
def flagUnused(): Unit = nextPositionalArg() match
189219
case Some(arg) =>
190220
error(s"unused argument: $arg")
191221
flagUnused()
192222
case None =>
193223
for
194224
arg <- args
195-
if arg.startsWith("--") && !argNames.contains(arg.drop(2))
225+
if arg.startsWith(argMarker) && !argNames.contains(arg.drop(2))
196226
do
197227
error(s"unknown argument name: $arg")
198228
end flagUnused
199229

200-
if args.contains("--help") then
230+
if args.contains(s"${argMarker}help") then
201231
usage()
202232
println()
203233
explain()
204234
else
205235
flagUnused()
236+
checkShortNamesUnique()
206237
if errors.nonEmpty then
207238
for msg <- errors do println(s"Error: $msg")
208239
usage()

0 commit comments

Comments
 (0)