Skip to content

Commit c174d60

Browse files
Pass ParameterInfos in command instead of getters
- Pass ParameterInfos for parameters in the command method instead of the getter functions. This way, we know about all of them beforehand, and parsing can be done more efficiently.
1 parent ddbb6d0 commit c174d60

8 files changed

+222
-229
lines changed

compiler/src/dotty/tools/dotc/ast/MainProxies.scala

Lines changed: 96 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -10,43 +10,25 @@ import Comments.Comment
1010
import NameKinds.DefaultGetterName
1111
import Annotations.Annotation
1212

13-
/** Generate proxy classes for main functions.
14-
* A function like
15-
*
16-
* /**
17-
* * Lorem ipsum dolor sit amet
18-
* * consectetur adipiscing elit.
19-
* *
20-
* * @param x my param x
21-
* * @param ys all my params y
22-
* */
23-
* @main(80) def f(
24-
* @main.ShortName('x') @main.Name("myX") x: S,
25-
* ys: T*
26-
* ) = ...
27-
*
28-
* would be translated to something like
29-
*
30-
* final class f {
31-
* static def main(args: Array[String]): Unit = {
32-
* val cmd = new main(80).command(args, "f", "Lorem ipsum dolor sit amet consectetur adipiscing elit.")
33-
*
34-
* val args0: () => S = cmd.argGetter[S](
35-
* new scala.annotation.MainAnnotation.ParameterInfos[S]("x", "S")
36-
* .withDocumentation("my param x")
37-
* .withAnnotations(new scala.main.ShortName('x'), new scala.main.Name("myX"))
38-
* )
39-
*
40-
* val args1: () => Seq[T] = cmd.varargGetter[T](
41-
* new scala.annotation.MainAnnotation.ParameterInfos[T]("ys", "T")
42-
* .withDocumentation("all my params y")
43-
* )
44-
*
45-
* cmd.run(f(args0.apply(), args1.apply()*))
46-
* }
47-
* }
48-
*/
4913
object MainProxies {
14+
/** Generate proxy classes for @main functions.
15+
* A function like
16+
*
17+
* @main def f(x: S, ys: T*) = ...
18+
*
19+
* would be translated to something like
20+
*
21+
* import CommandLineParser._
22+
* class f {
23+
* @static def main(args: Array[String]): Unit =
24+
* try
25+
* f(
26+
* parseArgument[S](args, 0),
27+
* parseRemainingArguments[T](args, 1): _*
28+
* )
29+
* catch case err: ParseError => showError(err)
30+
* }
31+
*/
5032
def mainProxiesOld(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
5133
import tpd._
5234
def mainMethods(stats: List[Tree]): List[Symbol] = stats.flatMap {
@@ -140,6 +122,44 @@ object MainProxies {
140122
private type DefaultValueSymbols = Map[Int, Symbol]
141123
private type ParameterAnnotationss = Seq[Seq[Annotation]]
142124

125+
/**
126+
* Generate proxy classes for main functions.
127+
* A function like
128+
*
129+
* /**
130+
* * Lorem ipsum dolor sit amet
131+
* * consectetur adipiscing elit.
132+
* *
133+
* * @param x my param x
134+
* * @param ys all my params y
135+
* */
136+
* @main(80) def f(
137+
* @main.ShortName('x') @main.Name("myX") x: S,
138+
* ys: T*
139+
* ) = ...
140+
*
141+
* would be translated to something like
142+
*
143+
* final class f {
144+
* static def main(args: Array[String]): Unit = {
145+
* val cmd = new main(80).command(
146+
* args,
147+
* "f",
148+
* "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
149+
* new scala.annotation.MainAnnotation.ParameterInfos("x", "S")
150+
* .withDocumentation("my param x")
151+
* .withAnnotations(new scala.main.ShortName('x'), new scala.main.Name("myX")),
152+
* new scala.annotation.MainAnnotation.ParameterInfos("ys", "T")
153+
* .withDocumentation("all my params y")
154+
* )
155+
*
156+
* val args0: () => S = cmd.argGetter[S]("x", None)
157+
* val args1: () => Seq[T] = cmd.varargGetter[T]("ys")
158+
*
159+
* cmd.run(f(args0(), args1()*))
160+
* }
161+
* }
162+
*/
143163
def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
144164
import tpd._
145165

@@ -195,6 +215,12 @@ object MainProxies {
195215
/** A literal value (Boolean, Int, String, etc.) */
196216
inline def lit(any: Any): Literal = Literal(Constant(any))
197217

218+
/** None */
219+
inline def none: Tree = ref(defn.NoneModule.termRef)
220+
221+
/** Some(value) */
222+
inline def some(value: Tree): Tree = Apply(ref(defn.SomeClass.companionModule.termRef), value)
223+
198224
/** () => value */
199225
def unitToValue(value: Tree): Tree =
200226
val anonName = nme.ANON_FUN
@@ -204,10 +230,12 @@ object MainProxies {
204230
/**
205231
* Creates a list of references and definitions of arguments, the first referencing the second.
206232
* The goal is to create the
207-
* `val arg0: () => S = ...`
208-
* part of the code. The first element of a tuple is a ref to `arg0`, the second is the whole definition.
233+
* `val args0: () => S = cmd.argGetter[S]("x", None)`
234+
* part of the code.
235+
* For each tuple, the first element is a ref to `args0`, the second is the whole definition, the third
236+
* is the ParameterInfos definition associated to this argument.
209237
*/
210-
def createArgs(mt: MethodType, cmdName: TermName): List[(Tree, ValDef)] =
238+
def createArgs(mt: MethodType, cmdName: TermName): List[(Tree, ValDef, Tree)] =
211239
mt.paramInfos.zip(mt.paramNames).zipWithIndex.map {
212240
case ((formal, paramName), n) =>
213241
val argName = nme.args ++ n.toString
@@ -220,11 +248,11 @@ object MainProxies {
220248
else (argRef0, formal, defn.MainAnnotCommand_argGetter)
221249
}
222250

223-
// The ParameterInfos to be passed to the arg getter
251+
// The ParameterInfos
224252
val parameterInfos = {
225253
val param = paramName.toString
226254
val paramInfosTree = New(
227-
AppliedTypeTree(TypeTree(defn.MainAnnotParameterInfos.typeRef), List(TypeTree(formalType))),
255+
TypeTree(defn.MainAnnotParameterInfos.typeRef),
228256
// Arguments to be passed to ParameterInfos' constructor
229257
List(List(lit(param), lit(formalType.show)))
230258
)
@@ -234,31 +262,38 @@ object MainProxies {
234262
* For example:
235263
* args0paramInfos.withDocumentation("my param x")
236264
* is represented by the pair
237-
* (defn.MainAnnotationParameterInfos_withDocumentation, List(lit("my param x")))
265+
* defn.MainAnnotationParameterInfos_withDocumentation -> List(lit("my param x"))
238266
*/
239267
var assignations: List[(Symbol, List[Tree])] = Nil
240-
for (dvSym <- defaultValueSymbols.get(n))
241-
assignations = (defn.MainAnnotationParameterInfos_withDefaultValue -> List(unitToValue(ref(dvSym.termRef)))) :: assignations
242268
for (doc <- documentation.argDocs.get(param))
243269
assignations = (defn.MainAnnotationParameterInfos_withDocumentation -> List(lit(doc))) :: assignations
244270

245271
val instanciatedAnnots = paramAnnotations(n).map(instanciateAnnotation).toList
246272
if instanciatedAnnots.nonEmpty then
247273
assignations = (defn.MainAnnotationParameterInfos_withAnnotations -> instanciatedAnnots) :: assignations
248274

249-
if assignations.isEmpty then
250-
paramInfosTree
251-
else
252-
assignations.foldLeft[Tree](paramInfosTree){ case (tree, (setterSym, values)) => Apply(Select(tree, setterSym.name), values) }
275+
assignations.foldLeft[Tree](paramInfosTree){ case (tree, (setterSym, values)) => Apply(Select(tree, setterSym.name), values) }
253276
}
254277

278+
val argParams =
279+
if formal.isRepeatedParam then
280+
List(lit(paramName.toString))
281+
else
282+
val defaultValueGetterOpt = defaultValueSymbols.get(n) match {
283+
case None =>
284+
none
285+
case Some(dvSym) =>
286+
some(unitToValue(ref(dvSym.termRef)))
287+
}
288+
List(lit(paramName.toString), defaultValueGetterOpt)
289+
255290
val argDef = ValDef(
256291
argName,
257292
TypeTree(),
258-
Apply(TypeApply(Select(Ident(cmdName), getterSym.name), TypeTree(formalType) :: Nil), parameterInfos),
293+
Apply(TypeApply(Select(Ident(cmdName), getterSym.name), TypeTree(formalType) :: Nil), argParams),
259294
)
260295

261-
(argRef, argDef)
296+
(argRef, argDef, parameterInfos)
262297
}
263298
end createArgs
264299

@@ -287,16 +322,9 @@ object MainProxies {
287322
if (!mainFun.owner.isStaticOwner)
288323
report.error(s"main method is not statically accessible", pos)
289324
else {
290-
val cmd = ValDef(
291-
cmdName,
292-
TypeTree(),
293-
Apply(
294-
Select(instanciateAnnotation(mainAnnot), defn.MainAnnot_command.name),
295-
Ident(nme.args) :: lit(mainFun.showName) :: lit(documentation.mainDoc) :: Nil
296-
)
297-
)
298325
var args: List[ValDef] = Nil
299326
var mainCall: Tree = ref(mainFun.termRef)
327+
var parameterInfoss: List[Tree] = Nil
300328

301329
mainFun.info match {
302330
case _: ExprType =>
@@ -309,16 +337,25 @@ object MainProxies {
309337
report.error(s"main method cannot be curried", pos)
310338
Nil
311339
case _ =>
312-
val (argRefs, argVals) = createArgs(mt, cmdName).unzip
340+
val (argRefs, argVals, paramInfoss) = createArgs(mt, cmdName).unzip3
313341
args = argVals
314342
mainCall = Apply(mainCall, argRefs)
343+
parameterInfoss = paramInfoss
315344
}
316345
case _: PolyType =>
317346
report.error(s"main method cannot have type parameters", pos)
318347
case _ =>
319348
report.error(s"main can only annotate a method", pos)
320349
}
321350

351+
val cmd = ValDef(
352+
cmdName,
353+
TypeTree(),
354+
Apply(
355+
Select(instanciateAnnotation(mainAnnot), defn.MainAnnot_command.name),
356+
Ident(nme.args) :: lit(mainFun.showName) :: lit(documentation.mainDoc) :: parameterInfoss
357+
)
358+
)
322359
val run = Apply(Select(Ident(cmdName), defn.MainAnnotCommand_run.name), mainCall)
323360
val body = Block(cmd :: args, run)
324361
val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree)

library/src/scala/annotation/MainAnnotation.scala

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,30 @@ trait MainAnnotation extends StaticAnnotation:
2121
type MainResultType
2222

2323
/** A new command with arguments from `args` */
24-
def command(args: Array[String], commandName: String, documentation: String): MainAnnotation.Command[ArgumentParser, MainResultType]
24+
def command(args: Array[String], commandName: String, documentation: String, parameterInfoss: MainAnnotation.ParameterInfos*): MainAnnotation.Command[ArgumentParser, MainResultType]
2525
end MainAnnotation
2626

2727
object MainAnnotation:
2828
// Inspired by https://github.com/scala-js/scala-js/blob/0708917912938714d52be1426364f78a3d1fd269/linker-interface/shared/src/main/scala/org/scalajs/linker/interface/StandardConfig.scala#L23-L218
29-
final class ParameterInfos[T] private (
29+
final class ParameterInfos private (
3030
/** The name of the parameter */
3131
val name: String,
3232
/** The name of the parameter's type */
3333
val typeName: String,
3434
/** The docstring of the parameter. Defaults to None. */
3535
val documentation: Option[String],
36-
/** The default value that the parameter has. Defaults to None. */
37-
val defaultValueGetterOpt: Option[() => T],
3836
/** The ParameterAnnotations associated with the parameter. Defaults to Seq.empty. */
3937
val annotations: Seq[ParameterAnnotation],
4038
) {
4139
// Main public constructor
4240
def this(name: String, typeName: String) =
43-
this(name, typeName, None, None, Seq.empty)
41+
this(name, typeName, None, Seq.empty)
4442

45-
def withDefaultValue(defaultValueGetter: () => T): ParameterInfos[T] =
46-
new ParameterInfos(name, typeName, documentation, Some(defaultValueGetter), annotations)
43+
def withDocumentation(doc: String): ParameterInfos =
44+
new ParameterInfos(name, typeName, Some(doc), annotations)
4745

48-
def withDocumentation(doc: String): ParameterInfos[T] =
49-
new ParameterInfos(name, typeName, Some(doc), defaultValueGetterOpt, annotations)
50-
51-
def withAnnotations(annots: ParameterAnnotation*): ParameterInfos[T] =
52-
new ParameterInfos(name, typeName, documentation, defaultValueGetterOpt, annots)
46+
def withAnnotations(annots: ParameterAnnotation*): ParameterInfos =
47+
new ParameterInfos(name, typeName, documentation, annots)
5348

5449
override def toString: String = s"$name: $typeName"
5550
}
@@ -58,10 +53,10 @@ object MainAnnotation:
5853
trait Command[ArgumentParser[_], MainResultType]:
5954

6055
/** The getter for the next argument of type `T` */
61-
def argGetter[T](paramInfos: ParameterInfos[T])(using fromString: ArgumentParser[T]): () => T
56+
def argGetter[T](name: String, optDefaultGetter: Option[() => T])(using fromString: ArgumentParser[T]): () => T
6257

6358
/** The getter for a final varargs argument of type `T*` */
64-
def varargGetter[T](paramInfos: ParameterInfos[T])(using fromString: ArgumentParser[T]): () => Seq[T]
59+
def varargGetter[T](name: String)(using fromString: ArgumentParser[T]): () => Seq[T]
6560

6661
/** Run `program` if all arguments are valid,
6762
* or print usage information and/or error messages.

0 commit comments

Comments
 (0)