diff --git a/compiler/src/dotty/tools/dotc/ast/MainProxies.scala b/compiler/src/dotty/tools/dotc/ast/MainProxies.scala index 183854f3aede..5e969c0c38c9 100644 --- a/compiler/src/dotty/tools/dotc/ast/MainProxies.scala +++ b/compiler/src/dotty/tools/dotc/ast/MainProxies.scala @@ -2,30 +2,40 @@ package dotty.tools.dotc package ast import core._ -import Symbols._, Types._, Contexts._, Flags._, Constants._ -import StdNames.nme - -/** Generate proxy classes for @main functions. - * A function like - * - * @main def f(x: S, ys: T*) = ... - * - * would be translated to something like - * - * import CommandLineParser._ - * class f { - * @static def main(args: Array[String]): Unit = - * try - * f( - * parseArgument[S](args, 0), - * parseRemainingArguments[T](args, 1): _* - * ) - * catch case err: ParseError => showError(err) - * } - */ +import Symbols._, Types._, Contexts._, Decorators._, util.Spans._, Flags._, Constants._ +import StdNames.{nme, tpnme} +import ast.Trees._ +import Names.Name +import Comments.Comment +import NameKinds.DefaultGetterName +import Annotations.Annotation + object MainProxies { - def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = { + /** Generate proxy classes for @main functions and @myMain functions where myMain <:< MainAnnotation */ + def proxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = { + mainAnnotationProxies(stats) ++ mainProxies(stats) + } + + /** Generate proxy classes for @main functions. + * A function like + * + * @main def f(x: S, ys: T*) = ... + * + * would be translated to something like + * + * import CommandLineParser._ + * class f { + * @static def main(args: Array[String]): Unit = + * try + * f( + * parseArgument[S](args, 0), + * parseRemainingArguments[T](args, 1): _* + * ) + * catch case err: ParseError => showError(err) + * } + */ + private def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = { import tpd._ def mainMethods(stats: List[Tree]): List[Symbol] = stats.flatMap { case stat: DefDef if stat.symbol.hasAnnotation(defn.MainAnnot) => @@ -39,7 +49,7 @@ object MainProxies { } import untpd._ - def mainProxy(mainFun: Symbol)(using Context): List[TypeDef] = { + private def mainProxy(mainFun: Symbol)(using Context): List[TypeDef] = { val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot).get.tree.span def pos = mainFun.sourcePos val argsRef = Ident(nme.args) @@ -116,4 +126,322 @@ object MainProxies { } result } + + private type DefaultValueSymbols = Map[Int, Symbol] + private type ParameterAnnotationss = Seq[Seq[Annotation]] + + /** + * Generate proxy classes for main functions. + * A function like + * + * /** + * * Lorem ipsum dolor sit amet + * * consectetur adipiscing elit. + * * + * * @param x my param x + * * @param ys all my params y + * */ + * @myMain(80) def f( + * @myMain.Alias("myX") x: S, + * y: S, + * ys: T* + * ) = ... + * + * would be translated to something like + * + * final class f { + * static def main(args: Array[String]): Unit = { + * val annotation = new myMain(80) + * val info = new Info( + * name = "f", + * documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.", + * parameters = Seq( + * new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))), + * new scala.annotation.MainAnnotation.Parameter("y", "S", true, false, "", Seq()), + * new scala.annotation.MainAnnotation.Parameter("ys", "T", false, true, "all my params y", Seq()) + * ) + * ), + * val command = annotation.command(info, args) + * if command.isDefined then + * val cmd = command.get + * val args0: () => S = annotation.argGetter[S](info.parameters(0), cmd(0), None) + * val args1: () => S = annotation.argGetter[S](info.parameters(1), mainArgs(1), Some(() => sum$default$1())) + * val args2: () => Seq[T] = annotation.varargGetter[T](info.parameters(2), cmd.drop(2)) + * annotation.run(() => f(args0(), args1(), args2()*)) + * } + * } + */ + private def mainAnnotationProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = { + import tpd._ + + /** + * Computes the symbols of the default values of the function. Since they cannot be inferred anymore at this + * point of the compilation, they must be explicitly passed by [[mainProxy]]. + */ + def defaultValueSymbols(scope: Tree, funSymbol: Symbol): DefaultValueSymbols = + scope match { + case TypeDef(_, template: Template) => + template.body.flatMap((_: Tree) match { + case dd: DefDef if dd.name.is(DefaultGetterName) && dd.name.firstPart == funSymbol.name => + val DefaultGetterName.NumberedInfo(index) = dd.name.info + List(index -> dd.symbol) + case _ => Nil + }).toMap + case _ => Map.empty + } + + /** Computes the list of main methods present in the code. */ + def mainMethods(scope: Tree, stats: List[Tree]): List[(Symbol, ParameterAnnotationss, DefaultValueSymbols, Option[Comment])] = stats.flatMap { + case stat: DefDef => + val sym = stat.symbol + sym.annotations.filter(_.matches(defn.MainAnnotationClass)) match { + case Nil => + Nil + case _ :: Nil => + val paramAnnotations = stat.paramss.flatMap(_.map( + valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotationParameterAnnotation)) + )) + (sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil + case mainAnnot :: others => + report.error(s"method cannot have multiple main annotations", mainAnnot.tree) + Nil + } + case stat @ TypeDef(_, impl: Template) if stat.symbol.is(Module) => + mainMethods(stat, impl.body) + case _ => + Nil + } + + // Assuming that the top-level object was already generated, all main methods will have a scope + mainMethods(EmptyTree, stats).flatMap(mainAnnotationProxy) + } + + private def mainAnnotationProxy(mainFun: Symbol, paramAnnotations: ParameterAnnotationss, defaultValueSymbols: DefaultValueSymbols, docComment: Option[Comment])(using Context): Option[TypeDef] = { + val mainAnnot = mainFun.getAnnotation(defn.MainAnnotationClass).get + def pos = mainFun.sourcePos + + val documentation = new Documentation(docComment) + + /** () => value */ + def unitToValue(value: Tree): Tree = + val defDef = DefDef(nme.ANON_FUN, List(Nil), TypeTree(), value) + Block(defDef, Closure(Nil, Ident(nme.ANON_FUN), EmptyTree)) + + /** Generate a list of trees containing the ParamInfo instantiations. + * + * A ParamInfo has the following shape + * ``` + * new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))) + * ``` + */ + def parameterInfos(mt: MethodType): List[Tree] = + extension (tree: Tree) def withProperty(sym: Symbol, args: List[Tree]) = + Apply(Select(tree, sym.name), args) + + for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield + val param = paramName.toString + val paramType0 = if formal.isRepeatedParam then formal.argTypes.head.dealias else formal.dealias + val paramType = paramType0.dealias + + val paramTypeStr = formal.dealias.typeSymbol.owner.showFullName + "." + paramType.show + val hasDefault = defaultValueSymbols.contains(idx) + val isRepeated = formal.isRepeatedParam + val paramDoc = documentation.argDocs.getOrElse(param, "") + val paramAnnots = + val annotationTrees = paramAnnotations(idx).map(instantiateAnnotation).toList + Apply(ref(defn.SeqModule.termRef), annotationTrees) + + val constructorArgs = List(param, paramTypeStr, hasDefault, isRepeated, paramDoc) + .map(value => Literal(Constant(value))) + + New(TypeTree(defn.MainAnnotationParameter.typeRef), List(constructorArgs :+ paramAnnots)) + + end parameterInfos + + /** + * Creates a list of references and definitions of arguments. + * The goal is to create the + * `val args0: () => S = annotation.argGetter[S](0, cmd(0), None)` + * part of the code. + */ + def argValDefs(mt: MethodType): List[ValDef] = + for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield + val argName = nme.args ++ idx.toString + val isRepeated = formal.isRepeatedParam + val formalType = if isRepeated then formal.argTypes.head else formal + val getterName = if isRepeated then nme.varargGetter else nme.argGetter + val defaultValueGetterOpt = defaultValueSymbols.get(idx) match + case None => ref(defn.NoneModule.termRef) + case Some(dvSym) => + val value = unitToValue(ref(dvSym.termRef)) + Apply(ref(defn.SomeClass.companionModule.termRef), value) + val argGetter0 = TypeApply(Select(Ident(nme.annotation), getterName), TypeTree(formalType) :: Nil) + val index = Literal(Constant(idx)) + val paramInfo = Apply(Select(Ident(nme.info), nme.parameters), index) + val argGetter = + if isRepeated then Apply(argGetter0, List(paramInfo, Apply(Select(Ident(nme.cmd), nme.drop), List(index)))) + else Apply(argGetter0, List(paramInfo, Apply(Ident(nme.cmd), List(index)), defaultValueGetterOpt)) + ValDef(argName, TypeTree(), argGetter) + end argValDefs + + + /** Create a list of argument references that will be passed as argument to the main method. + * `args0`, ...`argn*` + */ + def argRefs(mt: MethodType): List[Tree] = + for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield + val argRef = Apply(Ident(nme.args ++ idx.toString), Nil) + if formal.isRepeatedParam then repeated(argRef) else argRef + end argRefs + + + /** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */ + def instantiateAnnotation(annot: Annotation): Tree = + val argss = { + def recurse(t: tpd.Tree, acc: List[List[Tree]]): List[List[Tree]] = t match { + case Apply(t, args: List[tpd.Tree]) => recurse(t, extractArgs(args) :: acc) + case _ => acc + } + + def extractArgs(args: List[tpd.Tree]): List[Tree] = + args.flatMap { + case Typed(SeqLiteral(varargs, _), _) => varargs.map(arg => TypedSplice(arg)) + case arg: Select if arg.name.is(DefaultGetterName) => Nil // Ignore default values, they will be added later by the compiler + case arg => List(TypedSplice(arg)) + } + + recurse(annot.tree, Nil) + } + + New(TypeTree(annot.symbol.typeRef), argss) + end instantiateAnnotation + + def generateMainClass(mainCall: Tree, args: List[Tree], parameterInfos: List[Tree]): TypeDef = + val cmdInfo = + val nameTree = Literal(Constant(mainFun.showName)) + val docTree = Literal(Constant(documentation.mainDoc)) + val paramInfos = Apply(ref(defn.SeqModule.termRef), parameterInfos) + New(TypeTree(defn.MainAnnotationInfo.typeRef), List(List(nameTree, docTree, paramInfos))) + + val annotVal = ValDef( + nme.annotation, + TypeTree(), + instantiateAnnotation(mainAnnot) + ) + val infoVal = ValDef( + nme.info, + TypeTree(), + cmdInfo + ) + val command = ValDef( + nme.command, + TypeTree(), + Apply( + Select(Ident(nme.annotation), nme.command), + List(Ident(nme.info), Ident(nme.args)) + ) + ) + val argsVal = ValDef( + nme.cmd, + TypeTree(), + Select(Ident(nme.command), nme.get) + ) + val run = Apply(Select(Ident(nme.annotation), nme.run), mainCall) + val body0 = If( + Select(Ident(nme.command), nme.isDefined), + Block(argsVal :: args, run), + EmptyTree + ) + val body = Block(List(annotVal, infoVal, command), body0) // TODO add `if (cmd.nonEmpty)` + + val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree) + .withFlags(Param) + /** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol. + * The annotations will be retype-checked in another scope that may not have the same imports. + */ + def insertTypeSplices = new TreeMap { + override def transform(tree: Tree)(using Context): Tree = tree match + case tree: tpd.Ident @unchecked => TypedSplice(tree) + case tree => super.transform(tree) + } + val annots = mainFun.annotations + .filterNot(_.matches(defn.MainAnnotationClass)) + .map(annot => insertTypeSplices.transform(annot.tree)) + val mainMeth = DefDef(nme.main, (mainArg :: Nil) :: Nil, TypeTree(defn.UnitType), body) + .withFlags(JavaStatic) + .withAnnotations(annots) + val mainTempl = Template(emptyConstructor, Nil, Nil, EmptyValDef, mainMeth :: Nil) + val mainCls = TypeDef(mainFun.name.toTypeName, mainTempl) + .withFlags(Final | Invisible) + mainCls.withSpan(mainAnnot.tree.span.toSynthetic) + end generateMainClass + + if (!mainFun.owner.isStaticOwner) + report.error(s"main method is not statically accessible", pos) + None + else mainFun.info match { + case _: ExprType => + Some(generateMainClass(unitToValue(ref(mainFun.termRef)), Nil, Nil)) + case mt: MethodType => + if (mt.isImplicitMethod) + report.error(s"main method cannot have implicit parameters", pos) + None + else mt.resType match + case restpe: MethodType => + report.error(s"main method cannot be curried", pos) + None + case _ => + Some(generateMainClass(unitToValue(Apply(ref(mainFun.termRef), argRefs(mt))), argValDefs(mt), parameterInfos(mt))) + case _: PolyType => + report.error(s"main method cannot have type parameters", pos) + None + case _ => + report.error(s"main can only annotate a method", pos) + None + } + } + + /** A class responsible for extracting the docstrings of a method. */ + private class Documentation(docComment: Option[Comment]): + import util.CommentParsing._ + + /** The main part of the documentation. */ + lazy val mainDoc: String = _mainDoc + /** The parameters identified by @param. Maps from parameter name to its documentation. */ + lazy val argDocs: Map[String, String] = _argDocs + + private var _mainDoc: String = "" + private var _argDocs: Map[String, String] = Map() + + docComment match { + case Some(comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw + case None => + } + + private def cleanComment(raw: String): String = + var lines: Seq[String] = raw.trim.nn.split('\n').nn.toSeq + lines = lines.map(l => l.substring(skipLineLead(l, -1), l.length).nn.trim.nn) + var s = lines.foldLeft("") { + case ("", s2) => s2 + case (s1, "") if s1.last == '\n' => s1 // Multiple newlines are kept as single newlines + case (s1, "") => s1 + '\n' + case (s1, s2) if s1.last == '\n' => s1 + s2 + case (s1, s2) => s1 + ' ' + s2 + } + s.replaceAll(raw"\[\[", "").nn.replaceAll(raw"\]\]", "").nn.trim.nn + + private def parseDocComment(raw: String): Unit = + // Positions of the sections (@) in the docstring + val tidx: List[(Int, Int)] = tagIndex(raw) + + // Parse main comment + var mainComment: String = raw.substring(skipLineLead(raw, 0), startTag(raw, tidx)).nn + _mainDoc = cleanComment(mainComment) + + // Parse arguments comments + val argsCommentsSpans: Map[String, (Int, Int)] = paramDocs(raw, "@param", tidx) + val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _)) + val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end).nn }) + _argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap + end Documentation } diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 5eb1ccb0f957..8aaaff52708d 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -528,6 +528,8 @@ class Definitions { @tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType)) @tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length) @tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq) + @tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq") + @tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps") @tu lazy val StringOps_format: Symbol = StringOps.requiredMethod(nme.format) @@ -853,6 +855,12 @@ class Definitions { @tu lazy val XMLTopScopeModule: Symbol = requiredModule("scala.xml.TopScope") + @tu lazy val MainAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MainAnnotation") + @tu lazy val MainAnnotationInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Info") + @tu lazy val MainAnnotationParameter: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Parameter") + @tu lazy val MainAnnotationParameterAnnotation: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterAnnotation") + @tu lazy val MainAnnotationCommand: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Command") + @tu lazy val CommandLineParserModule: Symbol = requiredModule("scala.util.CommandLineParser") @tu lazy val CLP_ParseError: ClassSymbol = CommandLineParserModule.requiredClass("ParseError").typeRef.symbol.asClass @tu lazy val CLP_parseArgument: Symbol = CommandLineParserModule.requiredMethod("parseArgument") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 1bf91bf69abe..dc9e48b65f47 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -397,6 +397,7 @@ object StdNames { val applyOrElse: N = "applyOrElse" val args : N = "args" val argv : N = "argv" + val argGetter : N = "argGetter" val arrayClass: N = "arrayClass" val arrayElementClass: N = "arrayElementClass" val arrayType: N = "arrayType" @@ -427,6 +428,8 @@ object StdNames { val classOf: N = "classOf" val classType: N = "classType" val clone_ : N = "clone" + val cmd: N = "cmd" + val command: N = "command" val common: N = "common" val compiletime : N = "compiletime" val conforms_ : N = "$conforms" @@ -540,6 +543,7 @@ object StdNames { val ordinalDollar: N = "$ordinal" val ordinalDollar_ : N = "_$ordinal" val origin: N = "origin" + val parameters: N = "parameters" val parts: N = "parts" val postfixOps: N = "postfixOps" val prefix : N = "prefix" @@ -613,6 +617,7 @@ object StdNames { val fromOrdinal: N = "fromOrdinal" val values: N = "values" val view_ : N = "view" + val varargGetter : N = "varargGetter" val wait_ : N = "wait" val wildcardType: N = "wildcardType" val withFilter: N = "withFilter" diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index b7c65a30e7b4..1cce3fdea280 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -1351,12 +1351,13 @@ trait Checking { def checkAnnotApplicable(annot: Tree, sym: Symbol)(using Context): Boolean = !ctx.reporter.reportsErrorsFor { val annotCls = Annotations.annotClass(annot) + val concreteAnnot = Annotations.ConcreteAnnotation(annot) val pos = annot.srcPos - if (annotCls == defn.MainAnnot) { + if (annotCls == defn.MainAnnot || concreteAnnot.matches(defn.MainAnnotationClass)) { if (!sym.isRealMethod) - report.error(em"@main annotation cannot be applied to $sym", pos) + report.error(em"main annotation cannot be applied to $sym", pos) if (!sym.owner.is(Module) || !sym.owner.isStatic) - report.error(em"$sym cannot be a @main method since it cannot be accessed statically", pos) + report.error(em"$sym cannot be a main method since it cannot be accessed statically", pos) } // TODO: Add more checks here } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index ac8d6152812e..d915a35b88b4 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2602,7 +2602,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer pkg.moduleClass.info.decls.lookup(topLevelClassName).ensureCompleted() var stats1 = typedStats(tree.stats, pkg.moduleClass)._1 if (!ctx.isAfterTyper) - stats1 = stats1 ++ typedBlockStats(MainProxies.mainProxies(stats1))._1 + stats1 = stats1 ++ typedBlockStats(MainProxies.proxies(stats1))._1 cpy.PackageDef(tree)(pid1, stats1).withType(pkg.termRef) } case _ => diff --git a/docs/_docs/reference/experimental/main-annotation.md b/docs/_docs/reference/experimental/main-annotation.md new file mode 100644 index 000000000000..d2172d97a284 --- /dev/null +++ b/docs/_docs/reference/experimental/main-annotation.md @@ -0,0 +1,97 @@ +--- +layout: doc-page +title: "MainAnnotation" +--- + +`MainAnnotation` provides a generic way to define main annotations such as `@main`. + +When a users annotates a method with an annotation that extends `MainAnnotation` a class with a `main` method will be generated. The main method will contain the code needed to parse the command line arguments and run the application. + +```scala +/** Sum all the numbers + * + * @param first Fist number to sum + * @param rest The rest of the numbers to sum + */ +@myMain def sum(first: Int, second: Int = 0, rest: Int*): Int = first + second + rest.sum +``` + +```scala +object foo { + def main(args: Array[String]): Unit = { + val mainAnnot = new myMain() + val info = new Info( + name = "foo.main", + documentation = "Sum all the numbers", + parameters = Seq( + new Parameter("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum", Seq()), + new Parameter("second", "scala.Int", hasDefault=true, isVarargs=false, "", Seq()), + new Parameter("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum", Seq()) + ) + ) + val mainArgsOpt = mainAnnot.command(info, args) + if mainArgsOpt.isDefined then + val mainArgs = mainArgsOpt.get + val args0 = mainAnnot.argGetter[Int](info.parameters(0), mainArgs(0), None) // using a parser of Int + val args1 = mainAnnot.argGetter[Int](info.parameters(1), mainArgs(1), Some(() => sum$default$1())) // using a parser of Int + val args2 = mainAnnot.varargGetter[Int](info.parameters(2), mainArgs.drop(2)) // using a parser of Int + mainAnnot.run(() => sum(args0(), args1(), args2()*)) + } +} +``` + +The implementation of the `main` method first instantiates the annotation and then call `command`. +When calling the `command`, the arguments can be checked and preprocessed. +Then it defines a series of argument getters calling `argGetter` for each parameter and `varargGetter` for the last one if it is a varargs. `argGetter` gets an optional lambda that computes the default argument. +Finally, the `run` method is called to run the application. It receives a by-name argument that contains the call the annotated method with the instantiations arguments (using the lambdas from `argGetter`/`varargGetter`). + + +Example of implementation of `myMain` that takes all arguments positionally. It used `util.CommandLineParser.FromString` and expects no default arguments. For simplicity, any errors in preprocessing or parsing results in crash. + +```scala +// Parser used to parse command line arguments +import scala.util.CommandLineParser.FromString[T] + +// Result type of the annotated method is Int and arguments are parsed using FromString +@experimental class myMain extends MainAnnotation[FromString, Int]: + import MainAnnotation.{ Info, Parameter } + + def command(info: Info, args: Seq[String]): Option[Seq[String]] = + if args.contains("--help") then + println(info.documentation) + None // do not parse or run the program + else if info.parameters.exists(_.hasDefault) then + println("Default arguments are not supported") + None + else if info.hasVarargs then + val numPlainArgs = info.parameters.length - 1 + if numPlainArgs <= args.length then + println("Not enough arguments") + None + else + Some(args) + else + if info.parameters.length <= args.length then + println("Not enough arguments") + None + else if info.parameters.length >= args.length then + println("Too many arguments") + None + else + Some(args) + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T = + () => parser.fromString(arg) + + def varargGetter[T](param: Parameter, args: Seq[String])(using parser: FromString[T]): () => Seq[T] = + () => args.map(arg => parser.fromString(arg)) + + def run(program: () => Int): Unit = + println("executing program") + + try { + val result = program() + println("result: " + result) + println("executed program") +end myMain +``` diff --git a/docs/sidebar.yml b/docs/sidebar.yml index 0f6ed6bf935d..7c68120bfbc2 100644 --- a/docs/sidebar.yml +++ b/docs/sidebar.yml @@ -147,6 +147,7 @@ subsection: - page: reference/experimental/named-typeargs-spec.md - page: reference/experimental/numeric-literals.md - page: reference/experimental/explicit-nulls.md + - page: reference/experimental/main-annotation.md - page: reference/experimental/cc.md - page: reference/experimental/tupled-function.md - page: reference/syntax.md diff --git a/library/src/scala/annotation/MainAnnotation.scala b/library/src/scala/annotation/MainAnnotation.scala new file mode 100644 index 000000000000..9d2f5362ba15 --- /dev/null +++ b/library/src/scala/annotation/MainAnnotation.scala @@ -0,0 +1,126 @@ +package scala.annotation + +/** MainAnnotation provides the functionality for a compiler-generated main class. + * It links a compiler-generated main method (call it compiler-main) to a user + * written main method (user-main). + * The protocol of calls from compiler-main is as follows: + * + * - create a `command` with the command line arguments, + * - for each parameter of user-main, a call to `command.argGetter`, + * or `command.varargGetter` if is a final varargs parameter, + * - a call to `command.run` with the closure of user-main applied to all arguments. + * + * Example: + * ```scala + * /** Sum all the numbers + * * + * * @param first Fist number to sum + * * @param rest The rest of the numbers to sum + * */ + * @myMain def sum(first: Int, second: Int = 0, rest: Int*): Int = first + second + rest.sum + * ``` + * generates + * ```scala + * object foo { + * def main(args: Array[String]): Unit = { + * val mainAnnot = new myMain() + * val info = new Info( + * name = "foo.main", + * documentation = "Sum all the numbers", + * parameters = Seq( + * new Parameter("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum"), + * new Parameter("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum") + * ) + * ) + * val mainArgsOpt = mainAnnot.command(info, args) + * if mainArgsOpt.isDefined then + * val mainArgs = mainArgsOpt.get + * val args0 = mainAnnot.argGetter[Int](info.parameters(0), mainArgs(0), None) // using parser Int + * val args1 = mainAnnot.argGetter[Int](info.parameters(1), mainArgs(1), Some(() => sum$default$1())) // using parser Int + * val args2 = mainAnnot.varargGetter[Int](info.parameters(2), mainArgs.drop(2)) // using parser Int + * mainAnnot.run(() => sum(args0(), args1(), args2()*)) + * } + * } + * ``` + * + * @param Parser The class used for argument string parsing and arguments into a `T` + * @param Result The required result type of the main method. + * If this type is Any or Unit, any type will be accepted. + */ +@experimental +trait MainAnnotation[Parser[_], Result] extends StaticAnnotation: + import MainAnnotation.{Info, Parameter} + + /** Process the command arguments before parsing them. + * + * Return `Some` of the sequence of arguments that will be parsed to be passed to the main method. + * This sequence needs to have the same length as the number of parameters of the main method (i.e. `info.parameters.size`). + * If there is a varags parameter, then the sequence must be at least of length `info.parameters.size - 1`. + * + * Returns `None` if the arguments are invalid and parsing and run should be stopped. + * + * @param info The information about the command (name, documentation and info about parameters) + * @param args The command line arguments + */ + def command(info: Info, args: Seq[String]): Option[Seq[String]] + + /** The getter for the `idx`th argument of type `T` + * + * @param idx The index of the argument + * @param defaultArgument Optional lambda to instantiate the default argument + */ + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using Parser[T]): () => T + + /** The getter for a final varargs argument of type `T*` */ + def varargGetter[T](param: Parameter, args: Seq[String])(using Parser[T]): () => Seq[T] + + /** Run `program` if all arguments are valid if all arguments are valid + * + * @param program A function containing the call to the main method and instantiation of its arguments + */ + def run(program: () => Result): Unit + +end MainAnnotation + +@experimental +object MainAnnotation: + + /** Information about the main method + * + * @param name The name of the main method + * @param documentation The documentation of the main method without the `@param` documentation (see Parameter.documentaion) + * @param parameters Information about the parameters of the main method + */ + final class Info( + val name: String, + val documentation: String, + val parameters: Seq[Parameter], + ): + + /** If the method ends with a varargs parameter */ + def hasVarargs: Boolean = parameters.nonEmpty && parameters.last.isVarargs + + end Info + + /** Information about a parameter of a main method + * + * @param name The name of the parameter + * @param typeName The name of the parameter's type + * @param hasDefault If the parameter has a default argument + * @param isVarargs If the parameter is a varargs parameter (can only be true for the last parameter) + * @param documentation The documentation of the parameter (from `@param` documentation in the main method) + * @param annotations The annotations of the parameter that extend `ParameterAnnotation` + */ + final class Parameter( + val name: String, + val typeName: String, + val hasDefault: Boolean, + val isVarargs: Boolean, + val documentation: String, + val annotations: Seq[ParameterAnnotation], + ) + + /** Marker trait for annotations that will be included in the Parameter annotations. */ + trait ParameterAnnotation extends StaticAnnotation + +end MainAnnotation diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index 2c4fd4992432..8bd16f134f57 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -3,13 +3,21 @@ import com.typesafe.tools.mima.core._ object MiMaFilters { val Library: Seq[ProblemFilter] = Seq( - - // Those are OK because user code is not allowed to inherit from Quotes: + // Experimental APIs that can be added in 3.2.0 or later + ProblemFilters.exclude[DirectMissingMethodProblem]("scala.runtime.Tuples.append"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.asQuotes"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#ClassDefModule.apply"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolModule.newClass"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.typeRef"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.termRef"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#TypeTreeModule.ref"), + + // Experimental `MainAnnotation` APIs. Can be added in 3.3.0 or later. + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation$"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation$Command"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation$CommandInfo"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation$ParameterInfo"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation$ParameterAnnotation"), ) } diff --git a/project/resources/referenceReplacements/sidebar.yml b/project/resources/referenceReplacements/sidebar.yml index a8453449e73e..680b44d353d4 100644 --- a/project/resources/referenceReplacements/sidebar.yml +++ b/project/resources/referenceReplacements/sidebar.yml @@ -127,6 +127,7 @@ subsection: - page: reference/experimental/named-typeargs-spec.md - page: reference/experimental/numeric-literals.md - page: reference/experimental/explicit-nulls.md + - page: reference/experimental/main-annotation.md - page: reference/experimental/cc.md - page: reference/syntax.md - title: Language Versions diff --git a/project/scripts/expected-links/reference-expected-links.txt b/project/scripts/expected-links/reference-expected-links.txt index 737267576c6e..f51727b7b432 100644 --- a/project/scripts/expected-links/reference-expected-links.txt +++ b/project/scripts/expected-links/reference-expected-links.txt @@ -68,6 +68,7 @@ ./experimental/erased-defs.html ./experimental/explicit-nulls.html ./experimental/index.html +./experimental/main-annotation.html ./experimental/named-typeargs-spec.html ./experimental/named-typeargs.html ./experimental/numeric-literals.html diff --git a/tests/neg/main-annotation-mainannotation.scala b/tests/neg/main-annotation-mainannotation.scala new file mode 100644 index 000000000000..21e37d1779af --- /dev/null +++ b/tests/neg/main-annotation-mainannotation.scala @@ -0,0 +1,3 @@ +import scala.annotation.MainAnnotation + +@MainAnnotation def f(i: Int, n: Int) = () // error diff --git a/tests/run/main-annotation-example.check b/tests/run/main-annotation-example.check new file mode 100644 index 000000000000..97fcf11da08b --- /dev/null +++ b/tests/run/main-annotation-example.check @@ -0,0 +1,3 @@ +executing program +result: 28 +executed program diff --git a/tests/run/main-annotation-example.scala b/tests/run/main-annotation-example.scala new file mode 100644 index 000000000000..954278d6b26f --- /dev/null +++ b/tests/run/main-annotation-example.scala @@ -0,0 +1,62 @@ +import scala.annotation.* +import collection.mutable +import scala.util.CommandLineParser.FromString + +/** Sum all the numbers + * + * @param first Fist number to sum + * @param rest The rest of the numbers to sum + */ +@myMain def sum(first: Int, rest: Int*): Int = first + rest.sum + + +object Test: + def callMain(args: Array[String]): Unit = + val clazz = Class.forName("sum") + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, args) + + def main(args: Array[String]): Unit = + callMain(Array("23", "2", "3")) +end Test + +@experimental +class myMain extends MainAnnotation[FromString, Int]: + import MainAnnotation.{ Info, Parameter } + + def command(info: Info, args: Seq[String]): Option[Seq[String]] = + if args.contains("--help") then + println(info.documentation) + None // do not parse or run the program + else if info.parameters.exists(_.hasDefault) then + println("Default arguments are not supported") + None + else if info.hasVarargs then + val numPlainArgs = info.parameters.length - 1 + if numPlainArgs > args.length then + println("Not enough arguments") + None + else + Some(args) + else + if info.parameters.length > args.length then + println("Not enough arguments") + None + else if info.parameters.length < args.length then + println("Too many arguments") + None + else + Some(args) + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T = + () => parser.fromString(arg) + + def varargGetter[T](param: Parameter, args: Seq[String])(using parser: FromString[T]): () => Seq[T] = + () => args.map(arg => parser.fromString(arg)) + + def run(program: () => Int): Unit = + println("executing program") + val result = program() + println("result: " + result) + println("executed program") +end myMain diff --git a/tests/run/main-annotation-homemade-annot-1.check b/tests/run/main-annotation-homemade-annot-1.check new file mode 100644 index 000000000000..4b7ff457bb11 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-1.check @@ -0,0 +1,4 @@ +42 +42 +1 +2 diff --git a/tests/run/main-annotation-homemade-annot-1.scala b/tests/run/main-annotation-homemade-annot-1.scala new file mode 100644 index 000000000000..daf27b944d99 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-1.scala @@ -0,0 +1,46 @@ +import scala.concurrent._ +import scala.annotation.* +import scala.collection.mutable +import ExecutionContext.Implicits.global +import duration._ +import util.CommandLineParser.FromString + +@mainAwait def get(wait: Int): Future[Int] = Future{ + Thread.sleep(1000 * wait) + 42 +} + +@mainAwait def getMany(wait: Int*): Future[Int] = Future{ + Thread.sleep(1000 * wait.sum) + wait.length +} + +object Test: + def callMain(cls: String, args: Array[String]): Unit = + val clazz = Class.forName(cls) + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, args) + + def main(args: Array[String]): Unit = + println(Await.result(get(1), Duration(2, SECONDS))) + callMain("get", Array("1")) + callMain("getMany", Array("1")) + callMain("getMany", Array("0", "1")) +end Test + +@experimental +class mainAwait(timeout: Int = 2) extends MainAnnotation[FromString, Future[Any]]: + import MainAnnotation.* + + // This is a toy example, it only works with positional args + def command(info: Info, args: Seq[String]): Option[Seq[String]] = Some(args) + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = + () => p.fromString(arg) + + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = + () => for arg <- args yield p.fromString(arg) + + def run(f: () => Future[Any]): Unit = println(Await.result(f(), Duration(timeout, SECONDS))) + +end mainAwait diff --git a/tests/run/main-annotation-homemade-annot-2.check b/tests/run/main-annotation-homemade-annot-2.check new file mode 100644 index 000000000000..f57ec79b8dbd --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-2.check @@ -0,0 +1,11 @@ +I was run! +A +I was run! +A +I was run! +A +Here are some colors: +Purple smart, Blue fast, White fashion, Yellow quiet, Orange honest, Pink loud +This will be printed, but nothing more. +This will be printed, but nothing more. +This will be printed, but nothing more. diff --git a/tests/run/main-annotation-homemade-annot-2.scala b/tests/run/main-annotation-homemade-annot-2.scala new file mode 100644 index 000000000000..3cee9151282d --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-2.scala @@ -0,0 +1,49 @@ +import scala.collection.mutable +import scala.annotation.* +import util.CommandLineParser.FromString + +@myMain()("A") +def foo1(): Unit = println("I was run!") + +@myMain(0)("This should not be printed") +def foo2() = throw new Exception("This should not be run") + +@myMain(1)("Purple smart", "Blue fast", "White fashion", "Yellow quiet", "Orange honest", "Pink loud") +def foo3() = println("Here are some colors:") + +@myMain()() +def foo4() = println("This will be printed, but nothing more.") + +object Test: + val allClazzes: Seq[Class[?]] = + LazyList.from(1).map(i => scala.util.Try(Class.forName("foo" + i.toString))).takeWhile(_.isSuccess).map(_.get) + + def callMains(): Unit = + for (clazz <- allClazzes) + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, Array[String]()) + + def main(args: Array[String]) = + callMains() +end Test + +// This is a toy example, it only works with positional args +@experimental +class myMain(runs: Int = 3)(after: String*) extends MainAnnotation[FromString, Any]: + import MainAnnotation.* + + def command(info: Info, args: Seq[String]): Option[Seq[String]] = Some(args) + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = + () => p.fromString(arg) + + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = + () => for arg <- args yield p.fromString(arg) + + def run(f: () => Any): Unit = + for (_ <- 1 to runs) + f() + if after.length > 0 then println(after.mkString(", ")) + end run + +end myMain diff --git a/tests/run/main-annotation-homemade-annot-3.check b/tests/run/main-annotation-homemade-annot-3.check new file mode 100644 index 000000000000..cd0875583aab --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-3.check @@ -0,0 +1 @@ +Hello world! diff --git a/tests/run/main-annotation-homemade-annot-3.scala b/tests/run/main-annotation-homemade-annot-3.scala new file mode 100644 index 000000000000..3fc42abcce79 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-3.scala @@ -0,0 +1,23 @@ +import scala.annotation.* +import scala.util.CommandLineParser.FromString + +@mainNoArgs def foo() = println("Hello world!") + +object Test: + def main(args: Array[String]) = + val clazz = Class.forName("foo") + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, Array[String]()) +end Test + +@experimental +class mainNoArgs extends MainAnnotation[FromString, Any]: + import MainAnnotation.* + + def command(info: Info, args: Seq[String]): Option[Seq[String]] = Some(args) + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = ??? + + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = ??? + + def run(program: () => Any): Unit = program() diff --git a/tests/run/main-annotation-homemade-annot-4.check b/tests/run/main-annotation-homemade-annot-4.check new file mode 100644 index 000000000000..cd0875583aab --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-4.check @@ -0,0 +1 @@ +Hello world! diff --git a/tests/run/main-annotation-homemade-annot-4.scala b/tests/run/main-annotation-homemade-annot-4.scala new file mode 100644 index 000000000000..0dbd006ee5b1 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-4.scala @@ -0,0 +1,24 @@ +import scala.annotation.* +import scala.util.CommandLineParser.FromString + +@mainManyArgs(1, "B", 3) def foo() = println("Hello world!") + +object Test: + def main(args: Array[String]) = + val clazz = Class.forName("foo") + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, Array[String]()) +end Test + +@experimental +class mainManyArgs(i1: Int, s2: String, i3: Int) extends MainAnnotation[FromString, Any]: + import MainAnnotation.* + + def command(info: Info, args: Seq[String]): Option[Seq[String]] = Some(args) + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = ??? + + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = ??? + + + def run(program: () => Any): Unit = program() diff --git a/tests/run/main-annotation-homemade-annot-5.check b/tests/run/main-annotation-homemade-annot-5.check new file mode 100644 index 000000000000..7d60d6656c81 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-5.check @@ -0,0 +1,2 @@ +Hello world! +Hello world! diff --git a/tests/run/main-annotation-homemade-annot-5.scala b/tests/run/main-annotation-homemade-annot-5.scala new file mode 100644 index 000000000000..d61cd55eb852 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-5.scala @@ -0,0 +1,25 @@ +import scala.annotation.* +import scala.util.CommandLineParser.FromString + +@mainManyArgs(Some(1)) def foo() = println("Hello world!") +@mainManyArgs(None) def bar() = println("Hello world!") + +object Test: + def main(args: Array[String]) = + for (methodName <- List("foo", "bar")) + val clazz = Class.forName(methodName) + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, Array[String]()) +end Test + +@experimental +class mainManyArgs(o: Option[Int]) extends MainAnnotation[FromString, Any]: + import MainAnnotation.* + + def command(info: Info, args: Seq[String]): Option[Seq[String]] = Some(args) + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = ??? + + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = ??? + + def run(program: () => Any): Unit = program() diff --git a/tests/run/main-annotation-homemade-annot-6.check b/tests/run/main-annotation-homemade-annot-6.check new file mode 100644 index 000000000000..5cc6c07e1f56 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-6.check @@ -0,0 +1,25 @@ +command( + Array(1, 2), + foo, + "Foo docs", + Seq( + Parameter(name="i", typeName="scala.Int", hasDefault=false, isVarargs=false, documentation="", annotations=List()), + Parameter(name="j", typeName="java.lang.String", hasDefault=true, isVarargs=false, documentation="", annotations=List()) + )* +) +run() +foo(42, abc) + +command( + Array(1, 2), + bar, + "Bar docs", + Seq( + Parameter(name="i", typeName="scala.collection.immutable.List[Int]", hasDefault=false, isVarargs=false, documentation="the first parameter", annotations=List(MyParamAnnot(3))), + Parameter(name="rest", typeName="scala.Int", hasDefault=false, isVarargs=true, documentation="", annotations=List()) + )* +) +varargGetter() +run() +bar(List(42), 42, 42) + diff --git a/tests/run/main-annotation-homemade-annot-6.scala b/tests/run/main-annotation-homemade-annot-6.scala new file mode 100644 index 000000000000..9ba0b31fc689 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-6.scala @@ -0,0 +1,62 @@ +import scala.annotation.* + +/** Foo docs */ +@myMain def foo(i: Int, j: String = "2") = println(s"foo($i, $j)") +/** Bar docs + * + * @param i the first parameter + */ +@myMain def bar(@MyParamAnnot(3) i: List[Int], rest: Int*) = println(s"bar($i, ${rest.mkString(", ")})") + +object Test: + def main(args: Array[String]) = + for (methodName <- List("foo", "bar")) + val clazz = Class.forName(methodName) + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, Array[String]("1", "2")) +end Test + +@experimental +class myMain extends MainAnnotation[Make, Any]: + import MainAnnotation.* + + def command(info: Info, args: Seq[String]): Option[Seq[String]] = + def paramInfoString(paramInfo: Parameter) = + import paramInfo.* + s" Parameter(name=\"$name\", typeName=\"$typeName\", hasDefault=$hasDefault, isVarargs=$isVarargs, documentation=\"$documentation\", annotations=$annotations)" + println( + s"""command( + | ${args.mkString("Array(", ", ", ")")}, + | ${info.name}, + | "${info.documentation}", + | ${info.parameters.map(paramInfoString).mkString("Seq(\n", ",\n", "\n )*")} + |)""".stripMargin) + Some(args) + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: Make[T]): () => T = + () => p.make + + def varargGetter[T](param: Parameter, args: Seq[String])(using p: Make[T]): () => Seq[T] = + println("varargGetter()") + () => Seq(p.make, p.make) + + def run(f: () => Any): Unit = + println("run()") + f() + println() + +@experimental +case class MyParamAnnot(n: Int) extends MainAnnotation.ParameterAnnotation + +trait Make[T]: + def make: T + +given Make[Int] with + def make: Int = 42 + + +given Make[String] with + def make: String = "abc" + +given [T: Make]: Make[List[T]] with + def make: List[T] = List(summon[Make[T]].make) diff --git a/tests/run/main-annotation-newMain.scala b/tests/run/main-annotation-newMain.scala new file mode 100644 index 000000000000..9e85d5f948cc --- /dev/null +++ b/tests/run/main-annotation-newMain.scala @@ -0,0 +1,320 @@ +import scala.annotation.* +import collection.mutable +import scala.util.CommandLineParser.FromString + +@newMain def happyBirthday(age: Int, name: String, others: String*) = + val suffix = + age % 100 match + case 11 | 12 | 13 => "th" + case _ => + age % 10 match + case 1 => "st" + case 2 => "nd" + case 3 => "rd" + case _ => "th" + val bldr = new StringBuilder(s"Happy $age$suffix birthday, $name") + for other <- others do bldr.append(" and ").append(other) + println(bldr) + + +object Test: + def callMain(args: Array[String]): Unit = + val clazz = Class.forName("happyBirthday") + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, args) + + def main(args: Array[String]): Unit = + callMain(Array("23", "Lisa", "Peter")) +end Test + + + +@experimental +final class newMain extends MainAnnotation[FromString, Any]: + import newMain._ + import MainAnnotation._ + + private inline val argMarker = "--" + private inline val shortArgMarker = "-" + + /** The name of the special argument to display the method's help. + * If one of the method's parameters is called the same, will be ignored. + */ + private inline val helpArg = "help" + + /** The short name of the special argument to display the method's help. + * If one of the method's parameters uses the same short name, will be ignored. + */ + private inline val shortHelpArg = 'h' + + private inline val maxUsageLineLength = 120 + + private var info: Info = _ // TODO remove this var + + + /** A buffer for all errors */ + private val errors = new mutable.ArrayBuffer[String] + + /** Issue an error, and return an uncallable getter */ + private def error(msg: String): () => Nothing = + errors += msg + () => throw new AssertionError("trying to get invalid argument") + + private def getAliases(param: Parameter): Seq[String] = + param.annotations.collect{ case a: Alias => a }.flatMap(_.aliases) + + private def getAlternativeNames(param: Parameter): Seq[String] = + getAliases(param).filter(nameIsValid(_)) + + private def getShortNames(param: Parameter): Seq[Char] = + getAliases(param).filter(shortNameIsValid(_)).map(_(0)) + + private inline def nameIsValid(name: String): Boolean = + name.length > 1 // TODO add more checks for illegal characters + + private inline def shortNameIsValid(name: String): Boolean = + name.length == 1 && shortNameIsValidChar(name(0)) + + private inline def shortNameIsValidChar(shortName: Char): Boolean = + ('A' <= shortName && shortName <= 'Z') || ('a' <= shortName && shortName <= 'z') + + private def getNameWithMarker(name: String | Char): String = name match { + case c: Char => shortArgMarker + c + case s: String if shortNameIsValid(s) => shortArgMarker + s + case s => argMarker + s + } + + private def getInvalidNames(param: Parameter): Seq[String | Char] = + getAliases(param).filter(name => !nameIsValid(name) && !shortNameIsValid(name)) + + def command(info: Info, args: Seq[String]): Option[Seq[String]] = + this.info = info + + val namesToCanonicalName: Map[String, String] = info.parameters.flatMap( + infos => + val names = getAlternativeNames(infos) + val canonicalName = infos.name + if nameIsValid(canonicalName) then (canonicalName +: names).map(_ -> canonicalName) + else names.map(_ -> canonicalName) + ).toMap + val shortNamesToCanonicalName: Map[Char, String] = info.parameters.flatMap( + infos => + val names = getShortNames(infos) + val canonicalName = infos.name + if shortNameIsValid(canonicalName) then (canonicalName(0) +: names).map(_ -> canonicalName) + else names.map(_ -> canonicalName) + ).toMap + + val helpIsOverridden = namesToCanonicalName.exists((name, _) => name == helpArg) + val shortHelpIsOverridden = shortNamesToCanonicalName.exists((name, _) => name == shortHelpArg) + + val (positionalArgs, byNameArgs, invalidByNameArgs) = { + def getCanonicalArgName(arg: String): Option[String] = + if arg.startsWith(argMarker) && arg.length > argMarker.length then + namesToCanonicalName.get(arg.drop(argMarker.length)) + else if arg.startsWith(shortArgMarker) && arg.length == shortArgMarker.length + 1 then + shortNamesToCanonicalName.get(arg(shortArgMarker.length)) + else + None + + def isArgName(arg: String): Boolean = + val isFullName = arg.startsWith(argMarker) + val isShortName = arg.startsWith(shortArgMarker) && arg.length == shortArgMarker.length + 1 && shortNameIsValidChar(arg(shortArgMarker.length)) + isFullName || isShortName + + def recurse(remainingArgs: Seq[String], pa: mutable.Queue[String], bna: Seq[(String, String)], ia: Seq[String]): (mutable.Queue[String], Seq[(String, String)], Seq[String]) = + remainingArgs match { + case Seq() => + (pa, bna, ia) + case argName +: argValue +: rest if isArgName(argName) => + getCanonicalArgName(argName) match { + case Some(canonicalName) => recurse(rest, pa, bna :+ (canonicalName -> argValue), ia) + case None => recurse(rest, pa, bna, ia :+ argName) + } + case arg +: rest => + recurse(rest, pa :+ arg, bna, ia) + } + + val (pa, bna, ia) = recurse(args.toSeq, mutable.Queue.empty, Vector(), Vector()) + val nameToArgValues: Map[String, Seq[String]] = if bna.isEmpty then Map.empty else bna.groupMapReduce(_._1)(p => List(p._2))(_ ++ _) + (pa, nameToArgValues, ia) + } + + val argStrings: Seq[Seq[String]] = + for paramInfo <- info.parameters yield { + if (paramInfo.isVarargs) { + val byNameGetters = byNameArgs.getOrElse(paramInfo.name, Seq()) + val positionalGetters = positionalArgs.removeAll() + // First take arguments passed by name, then those passed by position + byNameGetters ++ positionalGetters + } else { + byNameArgs.get(paramInfo.name) match + case Some(Nil) => + throw AssertionError(s"${paramInfo.name} present in byNameArgs, but it has no argument value") + case Some(argValues) => + if argValues.length > 1 then + // Do not accept multiple values + // Remove this test to take last given argument + error(s"more than one value for ${paramInfo.name}: ${argValues.mkString(", ")}") + Nil + else + List(argValues.last) + case None => + if positionalArgs.length > 0 then + List(positionalArgs.dequeue()) + else if paramInfo.hasDefault then + Nil + else + error(s"missing argument for ${paramInfo.name}") + Nil + } + } + + // Check aliases unicity + val nameAndCanonicalName = info.parameters.flatMap { + case paramInfo => (paramInfo.name +: getAlternativeNames(paramInfo) ++: getShortNames(paramInfo)).map(_ -> paramInfo.name) + } + val nameToCanonicalNames = nameAndCanonicalName.groupMap(_._1)(_._2) + + for (name, canonicalNames) <- nameToCanonicalNames if canonicalNames.length > 1 do + throw IllegalArgumentException(s"$name is used for multiple parameters: ${canonicalNames.mkString(", ")}") + + // Check aliases validity + val problematicNames = info.parameters.flatMap(getInvalidNames) + if problematicNames.length > 0 then + throw IllegalArgumentException(s"The following aliases are invalid: ${problematicNames.mkString(", ")}") + + // Handle unused and invalid args + for (remainingArg <- positionalArgs) error(s"unused argument: $remainingArg") + for (invalidArg <- invalidByNameArgs) error(s"unknown argument name: $invalidArg") + + val displayHelp = + (!helpIsOverridden && args.contains(getNameWithMarker(helpArg))) || + (!shortHelpIsOverridden && args.contains(getNameWithMarker(shortHelpArg))) + + if displayHelp then + usage() + println() + explain() + None + else if errors.nonEmpty then + for msg <- errors do println(s"Error: $msg") + usage() + None + else + Some(argStrings.flatten) + end command + + private def usage(): Unit = + def argsUsage: Seq[String] = + for (infos <- info.parameters) + yield { + val canonicalName = getNameWithMarker(infos.name) + val shortNames = getShortNames(infos).map(getNameWithMarker) + val alternativeNames = getAlternativeNames(infos).map(getNameWithMarker) + val namesPrint = (canonicalName +: alternativeNames ++: shortNames).mkString("[", " | ", "]") + val shortTypeName = infos.typeName.split('.').last + if infos.isVarargs then s"[<$shortTypeName> [<$shortTypeName> [...]]]" + else if infos.hasDefault then s"[$namesPrint <$shortTypeName>]" + else s"$namesPrint <$shortTypeName>" + } + + def wrapArgumentUsages(argsUsage: Seq[String], maxLength: Int): Seq[String] = { + def recurse(args: Seq[String], currentLine: String, acc: Vector[String]): Seq[String] = + (args, currentLine) match { + case (Nil, "") => acc + case (Nil, l) => (acc :+ l) + case (arg +: t, "") => recurse(t, arg, acc) + case (arg +: t, l) if l.length + 1 + arg.length <= maxLength => recurse(t, s"$l $arg", acc) + case (arg +: t, l) => recurse(t, arg, acc :+ l) + } + + recurse(argsUsage, "", Vector()).toList + } + + val usageBeginning = s"Usage: ${info.name} " + val argsOffset = usageBeginning.length + val usages = wrapArgumentUsages(argsUsage, maxUsageLineLength - argsOffset) + + println(usageBeginning + usages.mkString("\n" + " " * argsOffset)) + end usage + + private def explain(): Unit = + inline def shiftLines(s: Seq[String], shift: Int): String = s.map(" " * shift + _).mkString("\n") + + def wrapLongLine(line: String, maxLength: Int): List[String] = { + def recurse(s: String, acc: Vector[String]): Seq[String] = + val lastSpace = s.trim.nn.lastIndexOf(' ', maxLength) + if ((s.length <= maxLength) || (lastSpace < 0)) + acc :+ s + else { + val (shortLine, rest) = s.splitAt(lastSpace) + recurse(rest.trim.nn, acc :+ shortLine) + } + + recurse(line, Vector()).toList + } + + if (info.documentation.nonEmpty) + println(wrapLongLine(info.documentation, maxUsageLineLength).mkString("\n")) + if (info.parameters.nonEmpty) { + val argNameShift = 2 + val argDocShift = argNameShift + 2 + + println("Arguments:") + for infos <- info.parameters do + val canonicalName = getNameWithMarker(infos.name) + val shortNames = getShortNames(infos).map(getNameWithMarker) + val alternativeNames = getAlternativeNames(infos).map(getNameWithMarker) + val otherNames = (alternativeNames ++: shortNames) match { + case Seq() => "" + case names => names.mkString("(", ", ", ") ") + } + val argDoc = StringBuilder(" " * argNameShift) + argDoc.append(s"$canonicalName $otherNames- ${infos.typeName.split('.').last}") + if infos.isVarargs then argDoc.append(" (vararg)") + else if infos.hasDefault then argDoc.append(" (optional)") + + if (infos.documentation.nonEmpty) { + val shiftedDoc = + infos.documentation.split("\n").nn + .map(line => shiftLines(wrapLongLine(line.nn, maxUsageLineLength - argDocShift), argDocShift)) + .mkString("\n") + argDoc.append("\n").append(shiftedDoc) + } + + println(argDoc) + } + end explain + + private def convert[T](argName: String, arg: String, p: FromString[T]): () => T = + p.fromStringOption(arg) match + case Some(t) => () => t + case None => error(s"invalid argument for $argName: $arg") + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = { + if arg.nonEmpty then convert(param.name, arg, p) + else defaultArgument match + case Some(defaultGetter) => defaultGetter + case None => error(s"missing argument for ${param.name}") + } + + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = { + val getters = args.map(arg => convert(param.name, arg, p)) + () => getters.map(_()) + } + + def run(execProgram: () => Any): Unit = { + if errors.nonEmpty then + for msg <- errors do println(s"Error: $msg") + usage() + else + execProgram() + } + +end newMain + +object newMain: + @experimental + final class Alias(val aliases: String*) extends MainAnnotation.ParameterAnnotation +end newMain