Skip to content

Commit e852aa7

Browse files
timotheeandresnicolasstucki
authored andcommitted
Add scala.annotation.MainAnnotation
See `docs/_docs/reference/experimental/main-annotation.md`
1 parent bdddd41 commit e852aa7

26 files changed

+1217
-29
lines changed

compiler/src/dotty/tools/dotc/ast/MainProxies.scala

Lines changed: 327 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,40 @@ package dotty.tools.dotc
22
package ast
33

44
import core._
5-
import Symbols._, Types._, Contexts._, Flags._, Constants._
6-
import StdNames.nme
7-
8-
/** Generate proxy classes for @main functions.
9-
* A function like
10-
*
11-
* @main def f(x: S, ys: T*) = ...
12-
*
13-
* would be translated to something like
14-
*
15-
* import CommandLineParser._
16-
* class f {
17-
* @static def main(args: Array[String]): Unit =
18-
* try
19-
* f(
20-
* parseArgument[S](args, 0),
21-
* parseRemainingArguments[T](args, 1): _*
22-
* )
23-
* catch case err: ParseError => showError(err)
24-
* }
25-
*/
5+
import Symbols._, Types._, Contexts._, Decorators._, util.Spans._, Flags._, Constants._
6+
import StdNames.{nme, tpnme}
7+
import ast.Trees._
8+
import Names.Name
9+
import Comments.Comment
10+
import NameKinds.DefaultGetterName
11+
import Annotations.Annotation
12+
2613
object MainProxies {
2714

28-
def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
15+
/** Generate proxy classes for @main functions and @myMain functions where myMain <:< MainAnnotation */
16+
def proxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
17+
mainAnnotationProxies(stats) ++ mainProxies(stats)
18+
}
19+
20+
/** Generate proxy classes for @main functions.
21+
* A function like
22+
*
23+
* @main def f(x: S, ys: T*) = ...
24+
*
25+
* would be translated to something like
26+
*
27+
* import CommandLineParser._
28+
* class f {
29+
* @static def main(args: Array[String]): Unit =
30+
* try
31+
* f(
32+
* parseArgument[S](args, 0),
33+
* parseRemainingArguments[T](args, 1): _*
34+
* )
35+
* catch case err: ParseError => showError(err)
36+
* }
37+
*/
38+
private def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
2939
import tpd._
3040
def mainMethods(stats: List[Tree]): List[Symbol] = stats.flatMap {
3141
case stat: DefDef if stat.symbol.hasAnnotation(defn.MainAnnot) =>
@@ -39,7 +49,7 @@ object MainProxies {
3949
}
4050

4151
import untpd._
42-
def mainProxy(mainFun: Symbol)(using Context): List[TypeDef] = {
52+
private def mainProxy(mainFun: Symbol)(using Context): List[TypeDef] = {
4353
val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot).get.tree.span
4454
def pos = mainFun.sourcePos
4555
val argsRef = Ident(nme.args)
@@ -116,4 +126,298 @@ object MainProxies {
116126
}
117127
result
118128
}
129+
130+
private type DefaultValueSymbols = Map[Int, Symbol]
131+
private type ParameterAnnotationss = Seq[Seq[Annotation]]
132+
133+
/**
134+
* Generate proxy classes for main functions.
135+
* A function like
136+
*
137+
* /**
138+
* * Lorem ipsum dolor sit amet
139+
* * consectetur adipiscing elit.
140+
* *
141+
* * @param x my param x
142+
* * @param ys all my params y
143+
* */
144+
* @myMain(80) def f(
145+
* @myMain.Alias("myX") x: S,
146+
* ys: T*
147+
* ) = ...
148+
*
149+
* would be translated to something like
150+
*
151+
* final class f {
152+
* static def main(args: Array[String]): Unit = {
153+
* val cmd = new myMain(80).command(
154+
* info = new CommandInfo(
155+
* name = "f",
156+
* documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
157+
* parameters = Seq(
158+
* new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
159+
* new scala.annotation.MainAnnotation.ParameterInfo("ys", "T", false, false, "all my params y", Seq())
160+
* )
161+
* )
162+
* args = args
163+
* )
164+
*
165+
* val args0: () => S = cmd.argGetter[S](0, None)
166+
* val args1: () => Seq[T] = cmd.varargGetter[T]
167+
*
168+
* cmd.run(() => f(args0(), args1()*))
169+
* }
170+
* }
171+
*/
172+
private def mainAnnotationProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
173+
import tpd._
174+
175+
/**
176+
* Computes the symbols of the default values of the function. Since they cannot be inferred anymore at this
177+
* point of the compilation, they must be explicitly passed by [[mainProxy]].
178+
*/
179+
def defaultValueSymbols(scope: Tree, funSymbol: Symbol): DefaultValueSymbols =
180+
scope match {
181+
case TypeDef(_, template: Template) =>
182+
template.body.flatMap((_: Tree) match {
183+
case dd: DefDef if dd.name.is(DefaultGetterName) && dd.name.firstPart == funSymbol.name =>
184+
val DefaultGetterName.NumberedInfo(index) = dd.name.info
185+
List(index -> dd.symbol)
186+
case _ => Nil
187+
}).toMap
188+
case _ => Map.empty
189+
}
190+
191+
/** Computes the list of main methods present in the code. */
192+
def mainMethods(scope: Tree, stats: List[Tree]): List[(Symbol, ParameterAnnotationss, DefaultValueSymbols, Option[Comment])] = stats.flatMap {
193+
case stat: DefDef =>
194+
val sym = stat.symbol
195+
sym.annotations.filter(_.matches(defn.MainAnnotationClass)) match {
196+
case Nil =>
197+
Nil
198+
case _ :: Nil =>
199+
val paramAnnotations = stat.paramss.flatMap(_.map(
200+
valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotationParameterAnnotation))
201+
))
202+
(sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
203+
case mainAnnot :: others =>
204+
report.error(s"method cannot have multiple main annotations", mainAnnot.tree)
205+
Nil
206+
}
207+
case stat @ TypeDef(_, impl: Template) if stat.symbol.is(Module) =>
208+
mainMethods(stat, impl.body)
209+
case _ =>
210+
Nil
211+
}
212+
213+
// Assuming that the top-level object was already generated, all main methods will have a scope
214+
mainMethods(EmptyTree, stats).flatMap(mainAnnotationProxy)
215+
}
216+
217+
private def mainAnnotationProxy(mainFun: Symbol, paramAnnotations: ParameterAnnotationss, defaultValueSymbols: DefaultValueSymbols, docComment: Option[Comment])(using Context): Option[TypeDef] = {
218+
val mainAnnot = mainFun.getAnnotation(defn.MainAnnotationClass).get
219+
def pos = mainFun.sourcePos
220+
221+
val documentation = new Documentation(docComment)
222+
223+
/** () => value */
224+
def unitToValue(value: Tree): Tree =
225+
val defDef = DefDef(nme.ANON_FUN, List(Nil), TypeTree(), value)
226+
Block(defDef, Closure(Nil, Ident(nme.ANON_FUN), EmptyTree))
227+
228+
/** Generate a list of trees containing the ParamInfo instantiations.
229+
*
230+
* A ParamInfo has the following shape
231+
* ```
232+
* new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
233+
* ```
234+
*/
235+
def parameterInfos(mt: MethodType): List[Tree] =
236+
extension (tree: Tree) def withProperty(sym: Symbol, args: List[Tree]) =
237+
Apply(Select(tree, sym.name), args)
238+
239+
for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
240+
val param = paramName.toString
241+
val paramType0 = if formal.isRepeatedParam then formal.argTypes.head.dealias else formal.dealias
242+
val paramType = paramType0.dealias
243+
244+
val paramTypeStr = formal.dealias.typeSymbol.owner.showFullName + "." + paramType.show
245+
val hasDefault = defaultValueSymbols.contains(idx)
246+
val isRepeated = formal.isRepeatedParam
247+
val paramDoc = documentation.argDocs.getOrElse(param, "")
248+
val paramAnnots =
249+
val annotationTrees = paramAnnotations(idx).map(instantiateAnnotation).toList
250+
Apply(ref(defn.SeqModule.termRef), annotationTrees)
251+
252+
val constructorArgs = List(param, paramTypeStr, hasDefault, isRepeated, paramDoc)
253+
.map(value => Literal(Constant(value)))
254+
255+
New(TypeTree(defn.MainAnnotationParameterInfo.typeRef), List(constructorArgs :+ paramAnnots))
256+
257+
end parameterInfos
258+
259+
/**
260+
* Creates a list of references and definitions of arguments.
261+
* The goal is to create the
262+
* `val args0: () => S = cmd.argGetter[S](0, None)`
263+
* part of the code.
264+
*/
265+
def argValDefs(mt: MethodType): List[ValDef] =
266+
for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
267+
val argName = nme.args ++ idx.toString
268+
val isRepeated = formal.isRepeatedParam
269+
val formalType = if isRepeated then formal.argTypes.head else formal
270+
val getterName = if isRepeated then nme.varargGetter else nme.argGetter
271+
val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
272+
case None => ref(defn.NoneModule.termRef)
273+
case Some(dvSym) =>
274+
val value = unitToValue(ref(dvSym.termRef))
275+
Apply(ref(defn.SomeClass.companionModule.termRef), value)
276+
val argGetter0 = TypeApply(Select(Ident(nme.cmd), getterName), TypeTree(formalType) :: Nil)
277+
val argGetter =
278+
if isRepeated then argGetter0
279+
else Apply(argGetter0, List(Literal(Constant(idx)), defaultValueGetterOpt))
280+
281+
ValDef(argName, TypeTree(), argGetter)
282+
end argValDefs
283+
284+
285+
/** Create a list of argument references that will be passed as argument to the main method.
286+
* `args0`, ...`argn*`
287+
*/
288+
def argRefs(mt: MethodType): List[Tree] =
289+
for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
290+
val argRef = Apply(Ident(nme.args ++ idx.toString), Nil)
291+
if formal.isRepeatedParam then repeated(argRef) else argRef
292+
end argRefs
293+
294+
295+
/** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */
296+
def instantiateAnnotation(annot: Annotation): Tree =
297+
val argss = {
298+
def recurse(t: tpd.Tree, acc: List[List[Tree]]): List[List[Tree]] = t match {
299+
case Apply(t, args: List[tpd.Tree]) => recurse(t, extractArgs(args) :: acc)
300+
case _ => acc
301+
}
302+
303+
def extractArgs(args: List[tpd.Tree]): List[Tree] =
304+
args.flatMap {
305+
case Typed(SeqLiteral(varargs, _), _) => varargs.map(arg => TypedSplice(arg))
306+
case arg: Select if arg.name.is(DefaultGetterName) => Nil // Ignore default values, they will be added later by the compiler
307+
case arg => List(TypedSplice(arg))
308+
}
309+
310+
recurse(annot.tree, Nil)
311+
}
312+
313+
New(TypeTree(annot.symbol.typeRef), argss)
314+
end instantiateAnnotation
315+
316+
def generateMainClass(mainCall: Tree, args: List[Tree], parameterInfos: List[Tree]): TypeDef =
317+
val cmdInfo =
318+
val nameTree = Literal(Constant(mainFun.showName))
319+
val docTree = Literal(Constant(documentation.mainDoc))
320+
val paramInfos = Apply(ref(defn.SeqModule.termRef), parameterInfos)
321+
New(TypeTree(defn.MainAnnotationCommandInfo.typeRef), List(List(nameTree, docTree, paramInfos)))
322+
323+
val cmd = ValDef(
324+
nme.cmd,
325+
TypeTree(),
326+
Apply(
327+
Select(instantiateAnnotation(mainAnnot), nme.command),
328+
List(cmdInfo, Ident(nme.args))
329+
)
330+
)
331+
val run = Apply(Select(Ident(nme.cmd), nme.run), mainCall)
332+
val body = Block(cmdInfo :: cmd :: args, run)
333+
val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree)
334+
.withFlags(Param)
335+
/** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
336+
* The annotations will be retype-checked in another scope that may not have the same imports.
337+
*/
338+
def insertTypeSplices = new TreeMap {
339+
override def transform(tree: Tree)(using Context): Tree = tree match
340+
case tree: tpd.Ident @unchecked => TypedSplice(tree)
341+
case tree => super.transform(tree)
342+
}
343+
val annots = mainFun.annotations
344+
.filterNot(_.matches(defn.MainAnnotationClass))
345+
.map(annot => insertTypeSplices.transform(annot.tree))
346+
val mainMeth = DefDef(nme.main, (mainArg :: Nil) :: Nil, TypeTree(defn.UnitType), body)
347+
.withFlags(JavaStatic)
348+
.withAnnotations(annots)
349+
val mainTempl = Template(emptyConstructor, Nil, Nil, EmptyValDef, mainMeth :: Nil)
350+
val mainCls = TypeDef(mainFun.name.toTypeName, mainTempl)
351+
.withFlags(Final | Invisible)
352+
mainCls.withSpan(mainAnnot.tree.span.toSynthetic)
353+
end generateMainClass
354+
355+
if (!mainFun.owner.isStaticOwner)
356+
report.error(s"main method is not statically accessible", pos)
357+
None
358+
else mainFun.info match {
359+
case _: ExprType =>
360+
Some(generateMainClass(unitToValue(ref(mainFun.termRef)), Nil, Nil))
361+
case mt: MethodType =>
362+
if (mt.isImplicitMethod)
363+
report.error(s"main method cannot have implicit parameters", pos)
364+
None
365+
else mt.resType match
366+
case restpe: MethodType =>
367+
report.error(s"main method cannot be curried", pos)
368+
None
369+
case _ =>
370+
Some(generateMainClass(unitToValue(Apply(ref(mainFun.termRef), argRefs(mt))), argValDefs(mt), parameterInfos(mt)))
371+
case _: PolyType =>
372+
report.error(s"main method cannot have type parameters", pos)
373+
None
374+
case _ =>
375+
report.error(s"main can only annotate a method", pos)
376+
None
377+
}
378+
}
379+
380+
/** A class responsible for extracting the docstrings of a method. */
381+
private class Documentation(docComment: Option[Comment]):
382+
import util.CommentParsing._
383+
384+
/** The main part of the documentation. */
385+
lazy val mainDoc: String = _mainDoc
386+
/** The parameters identified by @param. Maps from parameter name to its documentation. */
387+
lazy val argDocs: Map[String, String] = _argDocs
388+
389+
private var _mainDoc: String = ""
390+
private var _argDocs: Map[String, String] = Map()
391+
392+
docComment match {
393+
case Some(comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw
394+
case None =>
395+
}
396+
397+
private def cleanComment(raw: String): String =
398+
var lines: Seq[String] = raw.trim.nn.split('\n').nn.toSeq
399+
lines = lines.map(l => l.substring(skipLineLead(l, -1), l.length).nn.trim.nn)
400+
var s = lines.foldLeft("") {
401+
case ("", s2) => s2
402+
case (s1, "") if s1.last == '\n' => s1 // Multiple newlines are kept as single newlines
403+
case (s1, "") => s1 + '\n'
404+
case (s1, s2) if s1.last == '\n' => s1 + s2
405+
case (s1, s2) => s1 + ' ' + s2
406+
}
407+
s.replaceAll(raw"\[\[", "").nn.replaceAll(raw"\]\]", "").nn.trim.nn
408+
409+
private def parseDocComment(raw: String): Unit =
410+
// Positions of the sections (@) in the docstring
411+
val tidx: List[(Int, Int)] = tagIndex(raw)
412+
413+
// Parse main comment
414+
var mainComment: String = raw.substring(skipLineLead(raw, 0), startTag(raw, tidx)).nn
415+
_mainDoc = cleanComment(mainComment)
416+
417+
// Parse arguments comments
418+
val argsCommentsSpans: Map[String, (Int, Int)] = paramDocs(raw, "@param", tidx)
419+
val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _))
420+
val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end).nn })
421+
_argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap
422+
end Documentation
119423
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,8 @@ class Definitions {
528528
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
529529
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
530530
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
531+
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
532+
531533

532534
@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
533535
@tu lazy val StringOps_format: Symbol = StringOps.requiredMethod(nme.format)
@@ -853,6 +855,12 @@ class Definitions {
853855

854856
@tu lazy val XMLTopScopeModule: Symbol = requiredModule("scala.xml.TopScope")
855857

858+
@tu lazy val MainAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MainAnnotation")
859+
@tu lazy val MainAnnotationCommandInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.CommandInfo")
860+
@tu lazy val MainAnnotationParameterInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterInfo")
861+
@tu lazy val MainAnnotationParameterAnnotation: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterAnnotation")
862+
@tu lazy val MainAnnotationCommand: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Command")
863+
856864
@tu lazy val CommandLineParserModule: Symbol = requiredModule("scala.util.CommandLineParser")
857865
@tu lazy val CLP_ParseError: ClassSymbol = CommandLineParserModule.requiredClass("ParseError").typeRef.symbol.asClass
858866
@tu lazy val CLP_parseArgument: Symbol = CommandLineParserModule.requiredMethod("parseArgument")

0 commit comments

Comments
 (0)