Skip to content

Commit d91bad7

Browse files
timotheeandresnicolasstucki
authored andcommitted
Add scala.annotation.MainAnnotation
See `docs/_docs/reference/experimental/main-annotation.md`
1 parent 220b753 commit d91bad7

26 files changed

+1220
-27
lines changed

compiler/src/dotty/tools/dotc/ast/MainProxies.scala

Lines changed: 327 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,40 @@ package dotty.tools.dotc
22
package ast
33

44
import core._
5-
import Symbols._, Types._, Contexts._, Flags._, Constants._
6-
import StdNames.nme
7-
8-
/** Generate proxy classes for @main functions.
9-
* A function like
10-
*
11-
* @main def f(x: S, ys: T*) = ...
12-
*
13-
* would be translated to something like
14-
*
15-
* import CommandLineParser._
16-
* class f {
17-
* @static def main(args: Array[String]): Unit =
18-
* try
19-
* f(
20-
* parseArgument[S](args, 0),
21-
* parseRemainingArguments[T](args, 1): _*
22-
* )
23-
* catch case err: ParseError => showError(err)
24-
* }
25-
*/
5+
import Symbols._, Types._, Contexts._, Decorators._, util.Spans._, Flags._, Constants._
6+
import StdNames.{nme, tpnme}
7+
import ast.Trees._
8+
import Names.Name
9+
import Comments.Comment
10+
import NameKinds.DefaultGetterName
11+
import Annotations.Annotation
12+
2613
object MainProxies {
2714

28-
def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
15+
/** Generate proxy classes for @main functions and @myMain functions where myMain <:< MainAnnotation */
16+
def proxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
17+
mainAnnotationProxies(stats) ++ mainProxies(stats)
18+
}
19+
20+
/** Generate proxy classes for @main functions.
21+
* A function like
22+
*
23+
* @main def f(x: S, ys: T*) = ...
24+
*
25+
* would be translated to something like
26+
*
27+
* import CommandLineParser._
28+
* class f {
29+
* @static def main(args: Array[String]): Unit =
30+
* try
31+
* f(
32+
* parseArgument[S](args, 0),
33+
* parseRemainingArguments[T](args, 1): _*
34+
* )
35+
* catch case err: ParseError => showError(err)
36+
* }
37+
*/
38+
private def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
2939
import tpd._
3040
def mainMethods(stats: List[Tree]): List[Symbol] = stats.flatMap {
3141
case stat: DefDef if stat.symbol.hasAnnotation(defn.MainAnnot) =>
@@ -39,7 +49,7 @@ object MainProxies {
3949
}
4050

4151
import untpd._
42-
def mainProxy(mainFun: Symbol)(using Context): List[TypeDef] = {
52+
private def mainProxy(mainFun: Symbol)(using Context): List[TypeDef] = {
4353
val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot).get.tree.span
4454
def pos = mainFun.sourcePos
4555
val argsRef = Ident(nme.args)
@@ -114,4 +124,298 @@ object MainProxies {
114124
}
115125
result
116126
}
127+
128+
private type DefaultValueSymbols = Map[Int, Symbol]
129+
private type ParameterAnnotationss = Seq[Seq[Annotation]]
130+
131+
/**
132+
* Generate proxy classes for main functions.
133+
* A function like
134+
*
135+
* /**
136+
* * Lorem ipsum dolor sit amet
137+
* * consectetur adipiscing elit.
138+
* *
139+
* * @param x my param x
140+
* * @param ys all my params y
141+
* */
142+
* @myMain(80) def f(
143+
* @myMain.Alias("myX") x: S,
144+
* ys: T*
145+
* ) = ...
146+
*
147+
* would be translated to something like
148+
*
149+
* final class f {
150+
* static def main(args: Array[String]): Unit = {
151+
* val cmd = new myMain(80).command(
152+
* info = new CommandInfo(
153+
* name = "f",
154+
* documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
155+
* parameters = Seq(
156+
* new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
157+
* new scala.annotation.MainAnnotation.ParameterInfo("ys", "T", false, false, "all my params y", Seq())
158+
* )
159+
* )
160+
* args = args
161+
* )
162+
*
163+
* val args0: () => S = cmd.argGetter[S](0, None)
164+
* val args1: () => Seq[T] = cmd.varargGetter[T]
165+
*
166+
* cmd.run(() => f(args0(), args1()*))
167+
* }
168+
* }
169+
*/
170+
private def mainAnnotationProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = {
171+
import tpd._
172+
173+
/**
174+
* Computes the symbols of the default values of the function. Since they cannot be inferred anymore at this
175+
* point of the compilation, they must be explicitly passed by [[mainProxy]].
176+
*/
177+
def defaultValueSymbols(scope: Tree, funSymbol: Symbol): DefaultValueSymbols =
178+
scope match {
179+
case TypeDef(_, template: Template) =>
180+
template.body.flatMap((_: Tree) match {
181+
case dd: DefDef if dd.name.is(DefaultGetterName) && dd.name.firstPart == funSymbol.name =>
182+
val DefaultGetterName.NumberedInfo(index) = dd.name.info
183+
List(index -> dd.symbol)
184+
case _ => Nil
185+
}).toMap
186+
case _ => Map.empty
187+
}
188+
189+
/** Computes the list of main methods present in the code. */
190+
def mainMethods(scope: Tree, stats: List[Tree]): List[(Symbol, ParameterAnnotationss, DefaultValueSymbols, Option[Comment])] = stats.flatMap {
191+
case stat: DefDef =>
192+
val sym = stat.symbol
193+
sym.annotations.filter(_.matches(defn.MainAnnotationClass)) match {
194+
case Nil =>
195+
Nil
196+
case _ :: Nil =>
197+
val paramAnnotations = stat.paramss.flatMap(_.map(
198+
valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotationParameterAnnotation))
199+
))
200+
(sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
201+
case mainAnnot :: others =>
202+
report.error(s"method cannot have multiple main annotations", mainAnnot.tree)
203+
Nil
204+
}
205+
case stat @ TypeDef(_, impl: Template) if stat.symbol.is(Module) =>
206+
mainMethods(stat, impl.body)
207+
case _ =>
208+
Nil
209+
}
210+
211+
// Assuming that the top-level object was already generated, all main methods will have a scope
212+
mainMethods(EmptyTree, stats).flatMap(mainAnnotationProxy)
213+
}
214+
215+
private def mainAnnotationProxy(mainFun: Symbol, paramAnnotations: ParameterAnnotationss, defaultValueSymbols: DefaultValueSymbols, docComment: Option[Comment])(using Context): Option[TypeDef] = {
216+
val mainAnnot = mainFun.getAnnotation(defn.MainAnnotationClass).get
217+
def pos = mainFun.sourcePos
218+
219+
val documentation = new Documentation(docComment)
220+
221+
/** () => value */
222+
def unitToValue(value: Tree): Tree =
223+
val defDef = DefDef(nme.ANON_FUN, List(Nil), TypeTree(), value)
224+
Block(defDef, Closure(Nil, Ident(nme.ANON_FUN), EmptyTree))
225+
226+
/** Generate a list of trees containing the ParamInfo instantiations.
227+
*
228+
* A ParamInfo has the following shape
229+
* ```
230+
* new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
231+
* ```
232+
*/
233+
def parameterInfos(mt: MethodType): List[Tree] =
234+
extension (tree: Tree) def withProperty(sym: Symbol, args: List[Tree]) =
235+
Apply(Select(tree, sym.name), args)
236+
237+
for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
238+
val param = paramName.toString
239+
val paramType0 = if formal.isRepeatedParam then formal.argTypes.head.dealias else formal.dealias
240+
val paramType = paramType0.dealias
241+
242+
val paramTypeStr = formal.dealias.typeSymbol.owner.showFullName + "." + paramType.show
243+
val hasDefault = defaultValueSymbols.contains(idx)
244+
val isRepeated = formal.isRepeatedParam
245+
val paramDoc = documentation.argDocs.getOrElse(param, "")
246+
val paramAnnots =
247+
val annotationTrees = paramAnnotations(idx).map(instantiateAnnotation).toList
248+
Apply(ref(defn.SeqModule.termRef), annotationTrees)
249+
250+
val constructorArgs = List(param, paramTypeStr, hasDefault, isRepeated, paramDoc)
251+
.map(value => Literal(Constant(value)))
252+
253+
New(TypeTree(defn.MainAnnotationParameterInfo.typeRef), List(constructorArgs :+ paramAnnots))
254+
255+
end parameterInfos
256+
257+
/**
258+
* Creates a list of references and definitions of arguments.
259+
* The goal is to create the
260+
* `val args0: () => S = cmd.argGetter[S](0, None)`
261+
* part of the code.
262+
*/
263+
def argValDefs(mt: MethodType): List[ValDef] =
264+
for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
265+
val argName = nme.args ++ idx.toString
266+
val isRepeated = formal.isRepeatedParam
267+
val formalType = if isRepeated then formal.argTypes.head else formal
268+
val getterName = if isRepeated then nme.varargGetter else nme.argGetter
269+
val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
270+
case None => ref(defn.NoneModule.termRef)
271+
case Some(dvSym) =>
272+
val value = unitToValue(ref(dvSym.termRef))
273+
Apply(ref(defn.SomeClass.companionModule.termRef), value)
274+
val argGetter0 = TypeApply(Select(Ident(nme.cmd), getterName), TypeTree(formalType) :: Nil)
275+
val argGetter =
276+
if isRepeated then argGetter0
277+
else Apply(argGetter0, List(Literal(Constant(idx)), defaultValueGetterOpt))
278+
279+
ValDef(argName, TypeTree(), argGetter)
280+
end argValDefs
281+
282+
283+
/** Create a list of argument references that will be passed as argument to the main method.
284+
* `args0`, ...`argn*`
285+
*/
286+
def argRefs(mt: MethodType): List[Tree] =
287+
for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
288+
val argRef = Apply(Ident(nme.args ++ idx.toString), Nil)
289+
if formal.isRepeatedParam then repeated(argRef) else argRef
290+
end argRefs
291+
292+
293+
/** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */
294+
def instantiateAnnotation(annot: Annotation): Tree =
295+
val argss = {
296+
def recurse(t: tpd.Tree, acc: List[List[Tree]]): List[List[Tree]] = t match {
297+
case Apply(t, args: List[tpd.Tree]) => recurse(t, extractArgs(args) :: acc)
298+
case _ => acc
299+
}
300+
301+
def extractArgs(args: List[tpd.Tree]): List[Tree] =
302+
args.flatMap {
303+
case Typed(SeqLiteral(varargs, _), _) => varargs.map(arg => TypedSplice(arg))
304+
case arg: Select if arg.name.is(DefaultGetterName) => Nil // Ignore default values, they will be added later by the compiler
305+
case arg => List(TypedSplice(arg))
306+
}
307+
308+
recurse(annot.tree, Nil)
309+
}
310+
311+
New(TypeTree(annot.symbol.typeRef), argss)
312+
end instantiateAnnotation
313+
314+
def generateMainClass(mainCall: Tree, args: List[Tree], parameterInfos: List[Tree]): TypeDef =
315+
val cmdInfo =
316+
val nameTree = Literal(Constant(mainFun.showName))
317+
val docTree = Literal(Constant(documentation.mainDoc))
318+
val paramInfos = Apply(ref(defn.SeqModule.termRef), parameterInfos)
319+
New(TypeTree(defn.MainAnnotationCommandInfo.typeRef), List(List(nameTree, docTree, paramInfos)))
320+
321+
val cmd = ValDef(
322+
nme.cmd,
323+
TypeTree(),
324+
Apply(
325+
Select(instantiateAnnotation(mainAnnot), nme.command),
326+
List(cmdInfo, Ident(nme.args))
327+
)
328+
)
329+
val run = Apply(Select(Ident(nme.cmd), nme.run), mainCall)
330+
val body = Block(cmdInfo :: cmd :: args, run)
331+
val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree)
332+
.withFlags(Param)
333+
/** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
334+
* The annotations will be retype-checked in another scope that may not have the same imports.
335+
*/
336+
def insertTypeSplices = new TreeMap {
337+
override def transform(tree: Tree)(using Context): Tree = tree match
338+
case tree: tpd.Ident @unchecked => TypedSplice(tree)
339+
case tree => super.transform(tree)
340+
}
341+
val annots = mainFun.annotations
342+
.filterNot(_.matches(defn.MainAnnotationClass))
343+
.map(annot => insertTypeSplices.transform(annot.tree))
344+
val mainMeth = DefDef(nme.main, (mainArg :: Nil) :: Nil, TypeTree(defn.UnitType), body)
345+
.withFlags(JavaStatic)
346+
.withAnnotations(annots)
347+
val mainTempl = Template(emptyConstructor, Nil, Nil, EmptyValDef, mainMeth :: Nil)
348+
val mainCls = TypeDef(mainFun.name.toTypeName, mainTempl)
349+
.withFlags(Final | Invisible)
350+
mainCls.withSpan(mainAnnot.tree.span.toSynthetic)
351+
end generateMainClass
352+
353+
if (!mainFun.owner.isStaticOwner)
354+
report.error(s"main method is not statically accessible", pos)
355+
None
356+
else mainFun.info match {
357+
case _: ExprType =>
358+
Some(generateMainClass(unitToValue(ref(mainFun.termRef)), Nil, Nil))
359+
case mt: MethodType =>
360+
if (mt.isImplicitMethod)
361+
report.error(s"main method cannot have implicit parameters", pos)
362+
None
363+
else mt.resType match
364+
case restpe: MethodType =>
365+
report.error(s"main method cannot be curried", pos)
366+
None
367+
case _ =>
368+
Some(generateMainClass(unitToValue(Apply(ref(mainFun.termRef), argRefs(mt))), argValDefs(mt), parameterInfos(mt)))
369+
case _: PolyType =>
370+
report.error(s"main method cannot have type parameters", pos)
371+
None
372+
case _ =>
373+
report.error(s"main can only annotate a method", pos)
374+
None
375+
}
376+
}
377+
378+
/** A class responsible for extracting the docstrings of a method. */
379+
private class Documentation(docComment: Option[Comment]):
380+
import util.CommentParsing._
381+
382+
/** The main part of the documentation. */
383+
lazy val mainDoc: String = _mainDoc
384+
/** The parameters identified by @param. Maps from parameter name to its documentation. */
385+
lazy val argDocs: Map[String, String] = _argDocs
386+
387+
private var _mainDoc: String = ""
388+
private var _argDocs: Map[String, String] = Map()
389+
390+
docComment match {
391+
case Some(comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw
392+
case None =>
393+
}
394+
395+
private def cleanComment(raw: String): String =
396+
var lines: Seq[String] = raw.trim.nn.split('\n').nn.toSeq
397+
lines = lines.map(l => l.substring(skipLineLead(l, -1), l.length).nn.trim.nn)
398+
var s = lines.foldLeft("") {
399+
case ("", s2) => s2
400+
case (s1, "") if s1.last == '\n' => s1 // Multiple newlines are kept as single newlines
401+
case (s1, "") => s1 + '\n'
402+
case (s1, s2) if s1.last == '\n' => s1 + s2
403+
case (s1, s2) => s1 + ' ' + s2
404+
}
405+
s.replaceAll(raw"\[\[", "").nn.replaceAll(raw"\]\]", "").nn.trim.nn
406+
407+
private def parseDocComment(raw: String): Unit =
408+
// Positions of the sections (@) in the docstring
409+
val tidx: List[(Int, Int)] = tagIndex(raw)
410+
411+
// Parse main comment
412+
var mainComment: String = raw.substring(skipLineLead(raw, 0), startTag(raw, tidx)).nn
413+
_mainDoc = cleanComment(mainComment)
414+
415+
// Parse arguments comments
416+
val argsCommentsSpans: Map[String, (Int, Int)] = paramDocs(raw, "@param", tidx)
417+
val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _))
418+
val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end).nn })
419+
_argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap
420+
end Documentation
117421
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,8 @@ class Definitions {
524524
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
525525
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
526526
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
527+
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
528+
527529

528530
@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
529531
@tu lazy val StringOps_format: Symbol = StringOps.requiredMethod(nme.format)
@@ -849,6 +851,12 @@ class Definitions {
849851

850852
@tu lazy val XMLTopScopeModule: Symbol = requiredModule("scala.xml.TopScope")
851853

854+
@tu lazy val MainAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MainAnnotation")
855+
@tu lazy val MainAnnotationCommandInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.CommandInfo")
856+
@tu lazy val MainAnnotationParameterInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterInfo")
857+
@tu lazy val MainAnnotationParameterAnnotation: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterAnnotation")
858+
@tu lazy val MainAnnotationCommand: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Command")
859+
852860
@tu lazy val CommandLineParserModule: Symbol = requiredModule("scala.util.CommandLineParser")
853861
@tu lazy val CLP_ParseError: ClassSymbol = CommandLineParserModule.requiredClass("ParseError").typeRef.symbol.asClass
854862
@tu lazy val CLP_parseArgument: Symbol = CommandLineParserModule.requiredMethod("parseArgument")

0 commit comments

Comments
 (0)