@@ -2,30 +2,40 @@ package dotty.tools.dotc
2
2
package ast
3
3
4
4
import core ._
5
- import Symbols ._ , Types ._ , Contexts ._ , Flags ._ , Constants ._
6
- import StdNames .nme
7
-
8
- /** Generate proxy classes for @main functions.
9
- * A function like
10
- *
11
- * @main def f(x: S, ys: T*) = ...
12
- *
13
- * would be translated to something like
14
- *
15
- * import CommandLineParser._
16
- * class f {
17
- * @static def main(args: Array[String]): Unit =
18
- * try
19
- * f(
20
- * parseArgument[S](args, 0),
21
- * parseRemainingArguments[T](args, 1): _*
22
- * )
23
- * catch case err: ParseError => showError(err)
24
- * }
25
- */
5
+ import Symbols ._ , Types ._ , Contexts ._ , Decorators ._ , util .Spans ._ , Flags ._ , Constants ._
6
+ import StdNames .{nme , tpnme }
7
+ import ast .Trees ._
8
+ import Names .Name
9
+ import Comments .Comment
10
+ import NameKinds .DefaultGetterName
11
+ import Annotations .Annotation
12
+
26
13
object MainProxies {
27
14
28
- def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
15
+ /** Generate proxy classes for @main functions and @myMain functions where myMain <:< MainAnnotation */
16
+ def proxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
17
+ mainAnnotationProxies(stats) ++ mainProxies(stats)
18
+ }
19
+
20
+ /** Generate proxy classes for @main functions.
21
+ * A function like
22
+ *
23
+ * @main def f(x: S, ys: T*) = ...
24
+ *
25
+ * would be translated to something like
26
+ *
27
+ * import CommandLineParser._
28
+ * class f {
29
+ * @static def main(args: Array[String]): Unit =
30
+ * try
31
+ * f(
32
+ * parseArgument[S](args, 0),
33
+ * parseRemainingArguments[T](args, 1): _*
34
+ * )
35
+ * catch case err: ParseError => showError(err)
36
+ * }
37
+ */
38
+ private def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
29
39
import tpd ._
30
40
def mainMethods (stats : List [Tree ]): List [Symbol ] = stats.flatMap {
31
41
case stat : DefDef if stat.symbol.hasAnnotation(defn.MainAnnot ) =>
@@ -39,7 +49,7 @@ object MainProxies {
39
49
}
40
50
41
51
import untpd ._
42
- def mainProxy (mainFun : Symbol )(using Context ): List [TypeDef ] = {
52
+ private def mainProxy (mainFun : Symbol )(using Context ): List [TypeDef ] = {
43
53
val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot ).get.tree.span
44
54
def pos = mainFun.sourcePos
45
55
val argsRef = Ident (nme.args)
@@ -114,4 +124,298 @@ object MainProxies {
114
124
}
115
125
result
116
126
}
127
+
128
+ private type DefaultValueSymbols = Map [Int , Symbol ]
129
+ private type ParameterAnnotationss = Seq [Seq [Annotation ]]
130
+
131
+ /**
132
+ * Generate proxy classes for main functions.
133
+ * A function like
134
+ *
135
+ * /* *
136
+ * * Lorem ipsum dolor sit amet
137
+ * * consectetur adipiscing elit.
138
+ * *
139
+ * * @param x my param x
140
+ * * @param ys all my params y
141
+ * */
142
+ * @myMain(80) def f(
143
+ * @myMain.Alias("myX") x: S,
144
+ * ys: T*
145
+ * ) = ...
146
+ *
147
+ * would be translated to something like
148
+ *
149
+ * final class f {
150
+ * static def main(args: Array[String]): Unit = {
151
+ * val cmd = new myMain(80).command(
152
+ * info = new CommandInfo(
153
+ * name = "f",
154
+ * documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
155
+ * parameters = Seq(
156
+ * new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
157
+ * new scala.annotation.MainAnnotation.ParameterInfo("ys", "T", false, false, "all my params y", Seq())
158
+ * )
159
+ * )
160
+ * args = args
161
+ * )
162
+ *
163
+ * val args0: () => S = cmd.argGetter[S](0, None)
164
+ * val args1: () => Seq[T] = cmd.varargGetter[T]
165
+ *
166
+ * cmd.run(() => f(args0(), args1()*))
167
+ * }
168
+ * }
169
+ */
170
+ private def mainAnnotationProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
171
+ import tpd ._
172
+
173
+ /**
174
+ * Computes the symbols of the default values of the function. Since they cannot be inferred anymore at this
175
+ * point of the compilation, they must be explicitly passed by [[mainProxy ]].
176
+ */
177
+ def defaultValueSymbols (scope : Tree , funSymbol : Symbol ): DefaultValueSymbols =
178
+ scope match {
179
+ case TypeDef (_, template : Template ) =>
180
+ template.body.flatMap((_ : Tree ) match {
181
+ case dd : DefDef if dd.name.is(DefaultGetterName ) && dd.name.firstPart == funSymbol.name =>
182
+ val DefaultGetterName .NumberedInfo (index) = dd.name.info
183
+ List (index -> dd.symbol)
184
+ case _ => Nil
185
+ }).toMap
186
+ case _ => Map .empty
187
+ }
188
+
189
+ /** Computes the list of main methods present in the code. */
190
+ def mainMethods (scope : Tree , stats : List [Tree ]): List [(Symbol , ParameterAnnotationss , DefaultValueSymbols , Option [Comment ])] = stats.flatMap {
191
+ case stat : DefDef =>
192
+ val sym = stat.symbol
193
+ sym.annotations.filter(_.matches(defn.MainAnnotationClass )) match {
194
+ case Nil =>
195
+ Nil
196
+ case _ :: Nil =>
197
+ val paramAnnotations = stat.paramss.flatMap(_.map(
198
+ valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotationParameterAnnotation ))
199
+ ))
200
+ (sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
201
+ case mainAnnot :: others =>
202
+ report.error(s " method cannot have multiple main annotations " , mainAnnot.tree)
203
+ Nil
204
+ }
205
+ case stat @ TypeDef (_, impl : Template ) if stat.symbol.is(Module ) =>
206
+ mainMethods(stat, impl.body)
207
+ case _ =>
208
+ Nil
209
+ }
210
+
211
+ // Assuming that the top-level object was already generated, all main methods will have a scope
212
+ mainMethods(EmptyTree , stats).flatMap(mainAnnotationProxy)
213
+ }
214
+
215
+ private def mainAnnotationProxy (mainFun : Symbol , paramAnnotations : ParameterAnnotationss , defaultValueSymbols : DefaultValueSymbols , docComment : Option [Comment ])(using Context ): Option [TypeDef ] = {
216
+ val mainAnnot = mainFun.getAnnotation(defn.MainAnnotationClass ).get
217
+ def pos = mainFun.sourcePos
218
+
219
+ val documentation = new Documentation (docComment)
220
+
221
+ /** () => value */
222
+ def unitToValue (value : Tree ): Tree =
223
+ val defDef = DefDef (nme.ANON_FUN , List (Nil ), TypeTree (), value)
224
+ Block (defDef, Closure (Nil , Ident (nme.ANON_FUN ), EmptyTree ))
225
+
226
+ /** Generate a list of trees containing the ParamInfo instantiations.
227
+ *
228
+ * A ParamInfo has the following shape
229
+ * ```
230
+ * new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
231
+ * ```
232
+ */
233
+ def parameterInfos (mt : MethodType ): List [Tree ] =
234
+ extension (tree : Tree ) def withProperty (sym : Symbol , args : List [Tree ]) =
235
+ Apply (Select (tree, sym.name), args)
236
+
237
+ for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
238
+ val param = paramName.toString
239
+ val paramType0 = if formal.isRepeatedParam then formal.argTypes.head.dealias else formal.dealias
240
+ val paramType = paramType0.dealias
241
+
242
+ val paramTypeStr = formal.dealias.typeSymbol.owner.showFullName + " ." + paramType.show
243
+ val hasDefault = defaultValueSymbols.contains(idx)
244
+ val isRepeated = formal.isRepeatedParam
245
+ val paramDoc = documentation.argDocs.getOrElse(param, " " )
246
+ val paramAnnots =
247
+ val annotationTrees = paramAnnotations(idx).map(instantiateAnnotation).toList
248
+ Apply (ref(defn.SeqModule .termRef), annotationTrees)
249
+
250
+ val constructorArgs = List (param, paramTypeStr, hasDefault, isRepeated, paramDoc)
251
+ .map(value => Literal (Constant (value)))
252
+
253
+ New (TypeTree (defn.MainAnnotationParameterInfo .typeRef), List (constructorArgs :+ paramAnnots))
254
+
255
+ end parameterInfos
256
+
257
+ /**
258
+ * Creates a list of references and definitions of arguments.
259
+ * The goal is to create the
260
+ * `val args0: () => S = cmd.argGetter[S](0, None)`
261
+ * part of the code.
262
+ */
263
+ def argValDefs (mt : MethodType ): List [ValDef ] =
264
+ for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
265
+ val argName = nme.args ++ idx.toString
266
+ val isRepeated = formal.isRepeatedParam
267
+ val formalType = if isRepeated then formal.argTypes.head else formal
268
+ val getterName = if isRepeated then nme.varargGetter else nme.argGetter
269
+ val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
270
+ case None => ref(defn.NoneModule .termRef)
271
+ case Some (dvSym) =>
272
+ val value = unitToValue(ref(dvSym.termRef))
273
+ Apply (ref(defn.SomeClass .companionModule.termRef), value)
274
+ val argGetter0 = TypeApply (Select (Ident (nme.cmd), getterName), TypeTree (formalType) :: Nil )
275
+ val argGetter =
276
+ if isRepeated then argGetter0
277
+ else Apply (argGetter0, List (Literal (Constant (idx)), defaultValueGetterOpt))
278
+
279
+ ValDef (argName, TypeTree (), argGetter)
280
+ end argValDefs
281
+
282
+
283
+ /** Create a list of argument references that will be passed as argument to the main method.
284
+ * `args0`, ...`argn*`
285
+ */
286
+ def argRefs (mt : MethodType ): List [Tree ] =
287
+ for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
288
+ val argRef = Apply (Ident (nme.args ++ idx.toString), Nil )
289
+ if formal.isRepeatedParam then repeated(argRef) else argRef
290
+ end argRefs
291
+
292
+
293
+ /** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */
294
+ def instantiateAnnotation (annot : Annotation ): Tree =
295
+ val argss = {
296
+ def recurse (t : tpd.Tree , acc : List [List [Tree ]]): List [List [Tree ]] = t match {
297
+ case Apply (t, args : List [tpd.Tree ]) => recurse(t, extractArgs(args) :: acc)
298
+ case _ => acc
299
+ }
300
+
301
+ def extractArgs (args : List [tpd.Tree ]): List [Tree ] =
302
+ args.flatMap {
303
+ case Typed (SeqLiteral (varargs, _), _) => varargs.map(arg => TypedSplice (arg))
304
+ case arg : Select if arg.name.is(DefaultGetterName ) => Nil // Ignore default values, they will be added later by the compiler
305
+ case arg => List (TypedSplice (arg))
306
+ }
307
+
308
+ recurse(annot.tree, Nil )
309
+ }
310
+
311
+ New (TypeTree (annot.symbol.typeRef), argss)
312
+ end instantiateAnnotation
313
+
314
+ def generateMainClass (mainCall : Tree , args : List [Tree ], parameterInfos : List [Tree ]): TypeDef =
315
+ val cmdInfo =
316
+ val nameTree = Literal (Constant (mainFun.showName))
317
+ val docTree = Literal (Constant (documentation.mainDoc))
318
+ val paramInfos = Apply (ref(defn.SeqModule .termRef), parameterInfos)
319
+ New (TypeTree (defn.MainAnnotationCommandInfo .typeRef), List (List (nameTree, docTree, paramInfos)))
320
+
321
+ val cmd = ValDef (
322
+ nme.cmd,
323
+ TypeTree (),
324
+ Apply (
325
+ Select (instantiateAnnotation(mainAnnot), nme.command),
326
+ List (cmdInfo, Ident (nme.args))
327
+ )
328
+ )
329
+ val run = Apply (Select (Ident (nme.cmd), nme.run), mainCall)
330
+ val body = Block (cmdInfo :: cmd :: args, run)
331
+ val mainArg = ValDef (nme.args, TypeTree (defn.ArrayType .appliedTo(defn.StringType )), EmptyTree )
332
+ .withFlags(Param )
333
+ /** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
334
+ * The annotations will be retype-checked in another scope that may not have the same imports.
335
+ */
336
+ def insertTypeSplices = new TreeMap {
337
+ override def transform (tree : Tree )(using Context ): Tree = tree match
338
+ case tree : tpd.Ident @ unchecked => TypedSplice (tree)
339
+ case tree => super .transform(tree)
340
+ }
341
+ val annots = mainFun.annotations
342
+ .filterNot(_.matches(defn.MainAnnotationClass ))
343
+ .map(annot => insertTypeSplices.transform(annot.tree))
344
+ val mainMeth = DefDef (nme.main, (mainArg :: Nil ) :: Nil , TypeTree (defn.UnitType ), body)
345
+ .withFlags(JavaStatic )
346
+ .withAnnotations(annots)
347
+ val mainTempl = Template (emptyConstructor, Nil , Nil , EmptyValDef , mainMeth :: Nil )
348
+ val mainCls = TypeDef (mainFun.name.toTypeName, mainTempl)
349
+ .withFlags(Final | Invisible )
350
+ mainCls.withSpan(mainAnnot.tree.span.toSynthetic)
351
+ end generateMainClass
352
+
353
+ if (! mainFun.owner.isStaticOwner)
354
+ report.error(s " main method is not statically accessible " , pos)
355
+ None
356
+ else mainFun.info match {
357
+ case _ : ExprType =>
358
+ Some (generateMainClass(unitToValue(ref(mainFun.termRef)), Nil , Nil ))
359
+ case mt : MethodType =>
360
+ if (mt.isImplicitMethod)
361
+ report.error(s " main method cannot have implicit parameters " , pos)
362
+ None
363
+ else mt.resType match
364
+ case restpe : MethodType =>
365
+ report.error(s " main method cannot be curried " , pos)
366
+ None
367
+ case _ =>
368
+ Some (generateMainClass(unitToValue(Apply (ref(mainFun.termRef), argRefs(mt))), argValDefs(mt), parameterInfos(mt)))
369
+ case _ : PolyType =>
370
+ report.error(s " main method cannot have type parameters " , pos)
371
+ None
372
+ case _ =>
373
+ report.error(s " main can only annotate a method " , pos)
374
+ None
375
+ }
376
+ }
377
+
378
+ /** A class responsible for extracting the docstrings of a method. */
379
+ private class Documentation (docComment : Option [Comment ]):
380
+ import util .CommentParsing ._
381
+
382
+ /** The main part of the documentation. */
383
+ lazy val mainDoc : String = _mainDoc
384
+ /** The parameters identified by @param. Maps from parameter name to its documentation. */
385
+ lazy val argDocs : Map [String , String ] = _argDocs
386
+
387
+ private var _mainDoc : String = " "
388
+ private var _argDocs : Map [String , String ] = Map ()
389
+
390
+ docComment match {
391
+ case Some (comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw
392
+ case None =>
393
+ }
394
+
395
+ private def cleanComment (raw : String ): String =
396
+ var lines : Seq [String ] = raw.trim.nn.split('\n ' ).nn.toSeq
397
+ lines = lines.map(l => l.substring(skipLineLead(l, - 1 ), l.length).nn.trim.nn)
398
+ var s = lines.foldLeft(" " ) {
399
+ case (" " , s2) => s2
400
+ case (s1, " " ) if s1.last == '\n ' => s1 // Multiple newlines are kept as single newlines
401
+ case (s1, " " ) => s1 + '\n '
402
+ case (s1, s2) if s1.last == '\n ' => s1 + s2
403
+ case (s1, s2) => s1 + ' ' + s2
404
+ }
405
+ s.replaceAll(raw " \[\[ " , " " ).nn.replaceAll(raw " \]\] " , " " ).nn.trim.nn
406
+
407
+ private def parseDocComment (raw : String ): Unit =
408
+ // Positions of the sections (@) in the docstring
409
+ val tidx : List [(Int , Int )] = tagIndex(raw)
410
+
411
+ // Parse main comment
412
+ var mainComment : String = raw.substring(skipLineLead(raw, 0 ), startTag(raw, tidx)).nn
413
+ _mainDoc = cleanComment(mainComment)
414
+
415
+ // Parse arguments comments
416
+ val argsCommentsSpans : Map [String , (Int , Int )] = paramDocs(raw, " @param" , tidx)
417
+ val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _))
418
+ val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end).nn })
419
+ _argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap
420
+ end Documentation
117
421
}
0 commit comments