@@ -2,30 +2,40 @@ package dotty.tools.dotc
2
2
package ast
3
3
4
4
import core ._
5
- import Symbols ._ , Types ._ , Contexts ._ , Flags ._ , Constants ._
6
- import StdNames .nme
7
-
8
- /** Generate proxy classes for @main functions.
9
- * A function like
10
- *
11
- * @main def f(x: S, ys: T*) = ...
12
- *
13
- * would be translated to something like
14
- *
15
- * import CommandLineParser._
16
- * class f {
17
- * @static def main(args: Array[String]): Unit =
18
- * try
19
- * f(
20
- * parseArgument[S](args, 0),
21
- * parseRemainingArguments[T](args, 1): _*
22
- * )
23
- * catch case err: ParseError => showError(err)
24
- * }
25
- */
5
+ import Symbols ._ , Types ._ , Contexts ._ , Decorators ._ , util .Spans ._ , Flags ._ , Constants ._
6
+ import StdNames .{nme , tpnme }
7
+ import ast .Trees ._
8
+ import Names .Name
9
+ import Comments .Comment
10
+ import NameKinds .DefaultGetterName
11
+ import Annotations .Annotation
12
+
26
13
object MainProxies {
27
14
28
- def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
15
+ /** Generate proxy classes for @main functions and @myMain functions where myMain <:< MainAnnotation */
16
+ def proxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
17
+ mainAnnotationProxies(stats) ++ mainProxies(stats)
18
+ }
19
+
20
+ /** Generate proxy classes for @main functions.
21
+ * A function like
22
+ *
23
+ * @main def f(x: S, ys: T*) = ...
24
+ *
25
+ * would be translated to something like
26
+ *
27
+ * import CommandLineParser._
28
+ * class f {
29
+ * @static def main(args: Array[String]): Unit =
30
+ * try
31
+ * f(
32
+ * parseArgument[S](args, 0),
33
+ * parseRemainingArguments[T](args, 1): _*
34
+ * )
35
+ * catch case err: ParseError => showError(err)
36
+ * }
37
+ */
38
+ private def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
29
39
import tpd ._
30
40
def mainMethods (stats : List [Tree ]): List [Symbol ] = stats.flatMap {
31
41
case stat : DefDef if stat.symbol.hasAnnotation(defn.MainAnnot ) =>
@@ -39,7 +49,7 @@ object MainProxies {
39
49
}
40
50
41
51
import untpd ._
42
- def mainProxy (mainFun : Symbol )(using Context ): List [TypeDef ] = {
52
+ private def mainProxy (mainFun : Symbol )(using Context ): List [TypeDef ] = {
43
53
val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot ).get.tree.span
44
54
def pos = mainFun.sourcePos
45
55
val argsRef = Ident (nme.args)
@@ -116,4 +126,298 @@ object MainProxies {
116
126
}
117
127
result
118
128
}
129
+
130
+ private type DefaultValueSymbols = Map [Int , Symbol ]
131
+ private type ParameterAnnotationss = Seq [Seq [Annotation ]]
132
+
133
+ /**
134
+ * Generate proxy classes for main functions.
135
+ * A function like
136
+ *
137
+ * /* *
138
+ * * Lorem ipsum dolor sit amet
139
+ * * consectetur adipiscing elit.
140
+ * *
141
+ * * @param x my param x
142
+ * * @param ys all my params y
143
+ * */
144
+ * @myMain(80) def f(
145
+ * @myMain.Alias("myX") x: S,
146
+ * ys: T*
147
+ * ) = ...
148
+ *
149
+ * would be translated to something like
150
+ *
151
+ * final class f {
152
+ * static def main(args: Array[String]): Unit = {
153
+ * val cmd = new myMain(80).command(
154
+ * info = new CommandInfo(
155
+ * name = "f",
156
+ * documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
157
+ * parameters = Seq(
158
+ * new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
159
+ * new scala.annotation.MainAnnotation.ParameterInfo("ys", "T", false, false, "all my params y", Seq())
160
+ * )
161
+ * )
162
+ * args = args
163
+ * )
164
+ *
165
+ * val args0: () => S = cmd.argGetter[S](0, None)
166
+ * val args1: () => Seq[T] = cmd.varargGetter[T]
167
+ *
168
+ * cmd.run(() => f(args0(), args1()*))
169
+ * }
170
+ * }
171
+ */
172
+ private def mainAnnotationProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
173
+ import tpd ._
174
+
175
+ /**
176
+ * Computes the symbols of the default values of the function. Since they cannot be inferred anymore at this
177
+ * point of the compilation, they must be explicitly passed by [[mainProxy ]].
178
+ */
179
+ def defaultValueSymbols (scope : Tree , funSymbol : Symbol ): DefaultValueSymbols =
180
+ scope match {
181
+ case TypeDef (_, template : Template ) =>
182
+ template.body.flatMap((_ : Tree ) match {
183
+ case dd : DefDef if dd.name.is(DefaultGetterName ) && dd.name.firstPart == funSymbol.name =>
184
+ val DefaultGetterName .NumberedInfo (index) = dd.name.info
185
+ List (index -> dd.symbol)
186
+ case _ => Nil
187
+ }).toMap
188
+ case _ => Map .empty
189
+ }
190
+
191
+ /** Computes the list of main methods present in the code. */
192
+ def mainMethods (scope : Tree , stats : List [Tree ]): List [(Symbol , ParameterAnnotationss , DefaultValueSymbols , Option [Comment ])] = stats.flatMap {
193
+ case stat : DefDef =>
194
+ val sym = stat.symbol
195
+ sym.annotations.filter(_.matches(defn.MainAnnotationClass )) match {
196
+ case Nil =>
197
+ Nil
198
+ case _ :: Nil =>
199
+ val paramAnnotations = stat.paramss.flatMap(_.map(
200
+ valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotationParameterAnnotation ))
201
+ ))
202
+ (sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
203
+ case mainAnnot :: others =>
204
+ report.error(s " method cannot have multiple main annotations " , mainAnnot.tree)
205
+ Nil
206
+ }
207
+ case stat @ TypeDef (_, impl : Template ) if stat.symbol.is(Module ) =>
208
+ mainMethods(stat, impl.body)
209
+ case _ =>
210
+ Nil
211
+ }
212
+
213
+ // Assuming that the top-level object was already generated, all main methods will have a scope
214
+ mainMethods(EmptyTree , stats).flatMap(mainAnnotationProxy)
215
+ }
216
+
217
+ private def mainAnnotationProxy (mainFun : Symbol , paramAnnotations : ParameterAnnotationss , defaultValueSymbols : DefaultValueSymbols , docComment : Option [Comment ])(using Context ): Option [TypeDef ] = {
218
+ val mainAnnot = mainFun.getAnnotation(defn.MainAnnotationClass ).get
219
+ def pos = mainFun.sourcePos
220
+
221
+ val documentation = new Documentation (docComment)
222
+
223
+ /** () => value */
224
+ def unitToValue (value : Tree ): Tree =
225
+ val defDef = DefDef (nme.ANON_FUN , List (Nil ), TypeTree (), value)
226
+ Block (defDef, Closure (Nil , Ident (nme.ANON_FUN ), EmptyTree ))
227
+
228
+ /** Generate a list of trees containing the ParamInfo instantiations.
229
+ *
230
+ * A ParamInfo has the following shape
231
+ * ```
232
+ * new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
233
+ * ```
234
+ */
235
+ def parameterInfos (mt : MethodType ): List [Tree ] =
236
+ extension (tree : Tree ) def withProperty (sym : Symbol , args : List [Tree ]) =
237
+ Apply (Select (tree, sym.name), args)
238
+
239
+ for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
240
+ val param = paramName.toString
241
+ val paramType0 = if formal.isRepeatedParam then formal.argTypes.head.dealias else formal.dealias
242
+ val paramType = paramType0.dealias
243
+
244
+ val paramTypeStr = formal.dealias.typeSymbol.owner.showFullName + " ." + paramType.show
245
+ val hasDefault = defaultValueSymbols.contains(idx)
246
+ val isRepeated = formal.isRepeatedParam
247
+ val paramDoc = documentation.argDocs.getOrElse(param, " " )
248
+ val paramAnnots =
249
+ val annotationTrees = paramAnnotations(idx).map(instantiateAnnotation).toList
250
+ Apply (ref(defn.SeqModule .termRef), annotationTrees)
251
+
252
+ val constructorArgs = List (param, paramTypeStr, hasDefault, isRepeated, paramDoc)
253
+ .map(value => Literal (Constant (value)))
254
+
255
+ New (TypeTree (defn.MainAnnotationParameterInfo .typeRef), List (constructorArgs :+ paramAnnots))
256
+
257
+ end parameterInfos
258
+
259
+ /**
260
+ * Creates a list of references and definitions of arguments.
261
+ * The goal is to create the
262
+ * `val args0: () => S = cmd.argGetter[S](0, None)`
263
+ * part of the code.
264
+ */
265
+ def argValDefs (mt : MethodType ): List [ValDef ] =
266
+ for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
267
+ val argName = nme.args ++ idx.toString
268
+ val isRepeated = formal.isRepeatedParam
269
+ val formalType = if isRepeated then formal.argTypes.head else formal
270
+ val getterName = if isRepeated then nme.varargGetter else nme.argGetter
271
+ val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
272
+ case None => ref(defn.NoneModule .termRef)
273
+ case Some (dvSym) =>
274
+ val value = unitToValue(ref(dvSym.termRef))
275
+ Apply (ref(defn.SomeClass .companionModule.termRef), value)
276
+ val argGetter0 = TypeApply (Select (Ident (nme.cmd), getterName), TypeTree (formalType) :: Nil )
277
+ val argGetter =
278
+ if isRepeated then argGetter0
279
+ else Apply (argGetter0, List (Literal (Constant (idx)), defaultValueGetterOpt))
280
+
281
+ ValDef (argName, TypeTree (), argGetter)
282
+ end argValDefs
283
+
284
+
285
+ /** Create a list of argument references that will be passed as argument to the main method.
286
+ * `args0`, ...`argn*`
287
+ */
288
+ def argRefs (mt : MethodType ): List [Tree ] =
289
+ for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
290
+ val argRef = Apply (Ident (nme.args ++ idx.toString), Nil )
291
+ if formal.isRepeatedParam then repeated(argRef) else argRef
292
+ end argRefs
293
+
294
+
295
+ /** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */
296
+ def instantiateAnnotation (annot : Annotation ): Tree =
297
+ val argss = {
298
+ def recurse (t : tpd.Tree , acc : List [List [Tree ]]): List [List [Tree ]] = t match {
299
+ case Apply (t, args : List [tpd.Tree ]) => recurse(t, extractArgs(args) :: acc)
300
+ case _ => acc
301
+ }
302
+
303
+ def extractArgs (args : List [tpd.Tree ]): List [Tree ] =
304
+ args.flatMap {
305
+ case Typed (SeqLiteral (varargs, _), _) => varargs.map(arg => TypedSplice (arg))
306
+ case arg : Select if arg.name.is(DefaultGetterName ) => Nil // Ignore default values, they will be added later by the compiler
307
+ case arg => List (TypedSplice (arg))
308
+ }
309
+
310
+ recurse(annot.tree, Nil )
311
+ }
312
+
313
+ New (TypeTree (annot.symbol.typeRef), argss)
314
+ end instantiateAnnotation
315
+
316
+ def generateMainClass (mainCall : Tree , args : List [Tree ], parameterInfos : List [Tree ]): TypeDef =
317
+ val cmdInfo =
318
+ val nameTree = Literal (Constant (mainFun.showName))
319
+ val docTree = Literal (Constant (documentation.mainDoc))
320
+ val paramInfos = Apply (ref(defn.SeqModule .termRef), parameterInfos)
321
+ New (TypeTree (defn.MainAnnotationCommandInfo .typeRef), List (List (nameTree, docTree, paramInfos)))
322
+
323
+ val cmd = ValDef (
324
+ nme.cmd,
325
+ TypeTree (),
326
+ Apply (
327
+ Select (instantiateAnnotation(mainAnnot), nme.command),
328
+ List (cmdInfo, Ident (nme.args))
329
+ )
330
+ )
331
+ val run = Apply (Select (Ident (nme.cmd), nme.run), mainCall)
332
+ val body = Block (cmdInfo :: cmd :: args, run)
333
+ val mainArg = ValDef (nme.args, TypeTree (defn.ArrayType .appliedTo(defn.StringType )), EmptyTree )
334
+ .withFlags(Param )
335
+ /** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
336
+ * The annotations will be retype-checked in another scope that may not have the same imports.
337
+ */
338
+ def insertTypeSplices = new TreeMap {
339
+ override def transform (tree : Tree )(using Context ): Tree = tree match
340
+ case tree : tpd.Ident @ unchecked => TypedSplice (tree)
341
+ case tree => super .transform(tree)
342
+ }
343
+ val annots = mainFun.annotations
344
+ .filterNot(_.matches(defn.MainAnnotationClass ))
345
+ .map(annot => insertTypeSplices.transform(annot.tree))
346
+ val mainMeth = DefDef (nme.main, (mainArg :: Nil ) :: Nil , TypeTree (defn.UnitType ), body)
347
+ .withFlags(JavaStatic )
348
+ .withAnnotations(annots)
349
+ val mainTempl = Template (emptyConstructor, Nil , Nil , EmptyValDef , mainMeth :: Nil )
350
+ val mainCls = TypeDef (mainFun.name.toTypeName, mainTempl)
351
+ .withFlags(Final | Invisible )
352
+ mainCls.withSpan(mainAnnot.tree.span.toSynthetic)
353
+ end generateMainClass
354
+
355
+ if (! mainFun.owner.isStaticOwner)
356
+ report.error(s " main method is not statically accessible " , pos)
357
+ None
358
+ else mainFun.info match {
359
+ case _ : ExprType =>
360
+ Some (generateMainClass(unitToValue(ref(mainFun.termRef)), Nil , Nil ))
361
+ case mt : MethodType =>
362
+ if (mt.isImplicitMethod)
363
+ report.error(s " main method cannot have implicit parameters " , pos)
364
+ None
365
+ else mt.resType match
366
+ case restpe : MethodType =>
367
+ report.error(s " main method cannot be curried " , pos)
368
+ None
369
+ case _ =>
370
+ Some (generateMainClass(unitToValue(Apply (ref(mainFun.termRef), argRefs(mt))), argValDefs(mt), parameterInfos(mt)))
371
+ case _ : PolyType =>
372
+ report.error(s " main method cannot have type parameters " , pos)
373
+ None
374
+ case _ =>
375
+ report.error(s " main can only annotate a method " , pos)
376
+ None
377
+ }
378
+ }
379
+
380
+ /** A class responsible for extracting the docstrings of a method. */
381
+ private class Documentation (docComment : Option [Comment ]):
382
+ import util .CommentParsing ._
383
+
384
+ /** The main part of the documentation. */
385
+ lazy val mainDoc : String = _mainDoc
386
+ /** The parameters identified by @param. Maps from parameter name to its documentation. */
387
+ lazy val argDocs : Map [String , String ] = _argDocs
388
+
389
+ private var _mainDoc : String = " "
390
+ private var _argDocs : Map [String , String ] = Map ()
391
+
392
+ docComment match {
393
+ case Some (comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw
394
+ case None =>
395
+ }
396
+
397
+ private def cleanComment (raw : String ): String =
398
+ var lines : Seq [String ] = raw.trim.nn.split('\n ' ).nn.toSeq
399
+ lines = lines.map(l => l.substring(skipLineLead(l, - 1 ), l.length).nn.trim.nn)
400
+ var s = lines.foldLeft(" " ) {
401
+ case (" " , s2) => s2
402
+ case (s1, " " ) if s1.last == '\n ' => s1 // Multiple newlines are kept as single newlines
403
+ case (s1, " " ) => s1 + '\n '
404
+ case (s1, s2) if s1.last == '\n ' => s1 + s2
405
+ case (s1, s2) => s1 + ' ' + s2
406
+ }
407
+ s.replaceAll(raw " \[\[ " , " " ).nn.replaceAll(raw " \]\] " , " " ).nn.trim.nn
408
+
409
+ private def parseDocComment (raw : String ): Unit =
410
+ // Positions of the sections (@) in the docstring
411
+ val tidx : List [(Int , Int )] = tagIndex(raw)
412
+
413
+ // Parse main comment
414
+ var mainComment : String = raw.substring(skipLineLead(raw, 0 ), startTag(raw, tidx)).nn
415
+ _mainDoc = cleanComment(mainComment)
416
+
417
+ // Parse arguments comments
418
+ val argsCommentsSpans : Map [String , (Int , Int )] = paramDocs(raw, " @param" , tidx)
419
+ val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _))
420
+ val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end).nn })
421
+ _argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap
422
+ end Documentation
119
423
}
0 commit comments