diff --git a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala index 722867215adc..4a7be54cfc28 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala @@ -209,6 +209,7 @@ object TypeTestsCasts { * can be true in some cases. Issues a warning or an error otherwise. */ def checkSensical(foundClasses: List[Symbol])(using Context): Boolean = + def exprType = i"type ${expr.tpe.widen.stripAnnots}" def check(foundCls: Symbol): Boolean = if (!isCheckable(foundCls)) true else if (!foundCls.derivesFrom(testCls)) { @@ -216,9 +217,9 @@ object TypeTestsCasts { testCls.is(Final) || !testCls.is(Trait) && !foundCls.is(Trait) ) if (foundCls.is(Final)) - unreachable(i"type ${expr.tpe.widen} is not a subclass of $testCls") + unreachable(i"$exprType is not a subclass of $testCls") else if (unrelated) - unreachable(i"type ${expr.tpe.widen} and $testCls are unrelated") + unreachable(i"$exprType and $testCls are unrelated") else true } else true @@ -227,7 +228,7 @@ object TypeTestsCasts { val foundEffectiveClass = effectiveClass(expr.tpe.widen) if foundEffectiveClass.isPrimitiveValueClass && !testCls.isPrimitiveValueClass then - ctx.error("cannot test if value types are references", tree.sourcePos) + ctx.error(i"cannot test if value of $exprType is a reference of $testCls", tree.sourcePos) false else foundClasses.exists(check) end checkSensical diff --git a/tests/pos/main-method-scheme-class-based.scala b/tests/pos/main-method-scheme-class-based.scala new file mode 100644 index 000000000000..4dece92c9817 --- /dev/null +++ b/tests/pos/main-method-scheme-class-based.scala @@ -0,0 +1,205 @@ +import annotation.StaticAnnotation +import collection.mutable + +/** MainAnnotation provides the functionality for a compiler-generated main class. + * It links a compiler-generated main method (call it compiler-main) to a user + * written main method (user-main). + * The protocol of calls from compiler-main is as follows: + * + * - create a `command` with the command line arguments, + * - for each parameter of user-main, a call to `command.argGetter`, + * or `command.argsGetter` if is a final varargs parameter, + * - a call to `command.run` with the closure of user-main applied to all arguments. + */ +trait MainAnnotation extends StaticAnnotation: + + /** The class used for argument string parsing. E.g. `scala.util.FromString`, + * but could be something else + */ + type ArgumentParser[T] + + /** The required result type of the main function */ + type MainResultType + + /** A new command with arguments from `args` */ + def command(args: Array[String]): Command + + /** A class representing a command to run */ + abstract class Command: + + /** The getter for the next argument of type `T` */ + def argGetter[T](argName: String, fromString: ArgumentParser[T], defaultValue: Option[T] = None): () => T + + /** The getter for a final varargs argument of type `T*` */ + def argsGetter[T](argName: String, fromString: ArgumentParser[T]): () => Seq[T] + + /** Run `program` if all arguments are valid, + * or print usage information and/or error messages. + */ + def run(program: => MainResultType, progName: String, docComment: String): Unit + end Command +end MainAnnotation + +//Sample main class, can be freely implemented: + +class main extends MainAnnotation: + + type ArgumentParser[T] = util.FromString[T] + type MainResultType = Any + + def command(args: Array[String]): Command = new Command: + + /** A buffer of demanded argument names, plus + * "?" if it has a default + * "*" if it is a vararg + * "" otherwise + */ + private var argInfos = new mutable.ListBuffer[(String, String)] + + /** A buffer for all errors */ + private var errors = new mutable.ListBuffer[String] + + /** Issue an error, and return an uncallable getter */ + private def error(msg: String): () => Nothing = + errors += msg + () => assertFail("trying to get invalid argument") + + /** The next argument index */ + private var argIdx: Int = 0 + + private def argAt(idx: Int): Option[String] = + if idx < args.length then Some(args(idx)) else None + + private def nextPositionalArg(): Option[String] = + while argIdx < args.length && args(argIdx).startsWith("--") do argIdx += 2 + val result = argAt(argIdx) + argIdx += 1 + result + + private def convert[T](argName: String, arg: String, p: ArgumentParser[T]): () => T = + p.fromStringOption(arg) match + case Some(t) => () => t + case None => error(s"invalid argument for $argName: $arg") + + def argGetter[T](argName: String, p: ArgumentParser[T], defaultValue: Option[T] = None): () => T = + argInfos += ((argName, if defaultValue.isDefined then "?" else "")) + val idx = args.indexOf(s"--$argName") + val argOpt = if idx >= 0 then argAt(idx + 1) else nextPositionalArg() + argOpt match + case Some(arg) => convert(argName, arg, p) + case None => defaultValue match + case Some(t) => () => t + case None => error(s"missing argument for $argName") + + def argsGetter[T](argName: String, p: ArgumentParser[T]): () => Seq[T] = + argInfos += ((argName, "*")) + def remainingArgGetters(): List[() => T] = nextPositionalArg() match + case Some(arg) => convert(arg, argName, p) :: remainingArgGetters() + case None => Nil + val getters = remainingArgGetters() + () => getters.map(_()) + + def run(f: => MainResultType, progName: String, docComment: String): Unit = + def usage(): Unit = + println(s"Usage: $progName ${argInfos.map(_ + _).mkString(" ")}") + + def explain(): Unit = + if docComment.nonEmpty then println(docComment) // todo: process & format doc comment + + def flagUnused(): Unit = nextPositionalArg() match + case Some(arg) => + error(s"unused argument: $arg") + flagUnused() + case None => + for + arg <- args + if arg.startsWith("--") && !argInfos.map(_._1).contains(arg.drop(2)) + do + error(s"unknown argument name: $arg") + end flagUnused + + if args.isEmpty || args.contains("--help") then + usage() + explain() + else + flagUnused() + if errors.nonEmpty then + for msg <- errors do println(s"Error: $msg") + usage() + else f match + case n: Int if n < 0 => System.exit(-n) + case _ => + end run + end command +end main + +// Sample main method + +object myProgram: + + /** Adds two numbers */ + @main def add(num: Int, inc: Int = 1): Unit = + println(s"$num + $inc = ${num + inc}") + +end myProgram + +// Compiler generated code: + +object add extends main: + def main(args: Array[String]) = + val cmd = command(args) + val arg1 = cmd.argGetter[Int]("num", summon[ArgumentParser[Int]]) + val arg2 = cmd.argGetter[Int]("inc", summon[ArgumentParser[Int]], Some(1)) + cmd.run(myProgram.add(arg1(), arg2()), "add", "Adds two numbers") +end add + +/** --- Some scenarios ---------------------------------------- + +> java add 2 3 +2 + 3 = 5 +> java add 4 +4 + 1 = 5 +> java add --num 10 --inc -2 +10 + -2 = 8 +> java add --num 10 +10 + 1 = 11 +> java add --help +Usage: add num inc? +Adds two numbers +> java add +Usage: add num inc? +Adds two numbers +> java add 1 2 3 4 +Error: unused argument: 3 +Error: unused argument: 4 +Usage: add num inc? +> java add -n 1 -i 10 +Error: invalid argument for num: -n +Error: unused argument: -i +Error: unused argument: 10 +Usage: add num inc? +> java add --n 1 --i 10 +Error: missing argument for num +Error: unknown argument name: --n +Error: unknown argument name: --i +Usage: add num inc? +> java add true 10 +Error: invalid argument for num: true +Usage: add num inc? +> java add true false +Error: invalid argument for num: true +Error: invalid argument for inc: false +Usage: add num inc? +> java add true false 10 +Error: invalid argument for num: true +Error: invalid argument for inc: false +Error: unused argument: 10 +Usage: add num inc? +> java add --inc 10 --num 20 +20 + 10 = 30 +> java add binary 10 01 +Error: invalid argument for num: binary +Error: unused argument: 01 +Usage: add num inc? + +*/ \ No newline at end of file diff --git a/tests/pos/main-method-scheme.scala b/tests/pos/main-method-scheme.scala new file mode 100644 index 000000000000..ce8ea3ac3c19 --- /dev/null +++ b/tests/pos/main-method-scheme.scala @@ -0,0 +1,170 @@ +import annotation.StaticAnnotation +import collection.mutable + +trait MainAnnotation extends StaticAnnotation: + + type ArgumentParser[T] + + // get single argument + def getArg[T](argName: String, fromString: ArgumentParser[T], defaultValue: Option[T] = None): () => T + + // get varargs argument + def getArgs[T](argName: String, fromString: ArgumentParser[T]): () => List[T] + + // check that everything is parsed + def done(): Boolean + +end MainAnnotation + +//Sample main class, can be freely implemented: + +class main(progName: String, args: Array[String], docComment: String) extends MainAnnotation: + + def this() = this("", Array(), "") + + type ArgumentParser[T] = util.FromString[T] + + /** A buffer of demanded argument names, plus + * "?" if it has a default + * "*" if it is a vararg + * "" otherwise + */ + private var argInfos = new mutable.ListBuffer[(String, String)] + + /** A buffer for all errors */ + private var errors = new mutable.ListBuffer[String] + + /** The next argument index */ + private var n: Int = 0 + + private def error(msg: String): () => Nothing = + errors += msg + () => ??? + + private def argAt(idx: Int): Option[String] = + if idx < args.length then Some(args(idx)) else None + + private def nextPositionalArg(): Option[String] = + while n < args.length && args(n).startsWith("--") do n += 2 + val result = argAt(n) + n += 1 + result + + private def convert[T](argName: String, arg: String, p: ArgumentParser[T]): () => T = + p.fromStringOption(arg) match + case Some(t) => () => t + case None => error(s"invalid argument for $argName: $arg") + + def getArg[T](argName: String, p: ArgumentParser[T], defaultValue: Option[T] = None): () => T = + argInfos += ((argName, if defaultValue.isDefined then "?" else "")) + val idx = args.indexOf(s"--$argName") + val argOpt = if idx >= 0 then argAt(idx + 1) else nextPositionalArg() + argOpt match + case Some(arg) => convert(argName, arg, p) + case None => defaultValue match + case Some(t) => () => t + case None => error(s"missing argument for $argName") + + def getArgs[T](argName: String, fromString: ArgumentParser[T]): () => List[T] = + argInfos += ((argName, "*")) + def recur(): List[() => T] = nextPositionalArg() match + case Some(arg) => convert(arg, argName, fromString) :: recur() + case None => Nil + val fns = recur() + () => fns.map(_()) + + def usage(): Boolean = + println(s"Usage: $progName ${argInfos.map(_ + _).mkString(" ")}") + if docComment.nonEmpty then + println(docComment) // todo: process & format doc comment + false + + def showUnused(): Unit = nextPositionalArg() match + case Some(arg) => + error(s"unused argument: $arg") + showUnused() + case None => + for + arg <- args + if arg.startsWith("--") && !argInfos.map(_._1).contains(arg.drop(2)) + do + error(s"unknown argument name: $arg") + end showUnused + + def done(): Boolean = + if args.contains("--help") then + usage() + else + showUnused() + if errors.nonEmpty then + for msg <- errors do println(s"Error: $msg") + usage() + else + true + end done +end main + +// Sample main method + +object myProgram: + + /** Adds two numbers */ + @main def add(num: Int, inc: Int = 1) = + println(s"$num + $inc = ${num + inc}") + +end myProgram + +// Compiler generated code: + +object add: + def main(args: Array[String]) = + val cmd = new main("add", args, "Adds two numbers") + val arg1 = cmd.getArg[Int]("num", summon[cmd.ArgumentParser[Int]]) + val arg2 = cmd.getArg[Int]("inc", summon[cmd.ArgumentParser[Int]], Some(1)) + if cmd.done() then myProgram.add(arg1(), arg2()) +end add + +/** --- Some scenarios ---------------------------------------- + +> java add 2 3 +2 + 3 = 5 +> java add 2 3 +2 + 3 = 5 +> java add 4 +4 + 1 = 5 +> java add --num 10 --inc -2 +10 + -2 = 8 +> java add --num 10 +10 + 1 = 11 +> java add --help +Usage: add num inc? +Adds two numbers +> java add +error: missing argument for num +Usage: add num inc? +Adds two numbers +> java add 1 2 3 +error: unused argument: 3 +Usage: add num inc? +Adds two numbers +> java add --num 1 --incr 2 +error: unknown argument name: --incr +Usage: add num inc? +Adds two numbers +> java add 1 true +error: invalid argument for inc: true +Usage: add num inc? +Adds two numbers +> java add true false +error: invalid argument for num: true +error: invalid argument for inc: false +Usage: add num inc? +Adds two numbers +> java add true false --foo 33 +Error: invalid argument for num: true +Error: invalid argument for inc: false +Error: unknown argument name: --foo +Usage: add num inc? +Adds two numbers + +*/ \ No newline at end of file diff --git a/tests/run/decorators.check b/tests/run/decorators.check new file mode 100644 index 000000000000..ffa7186e3093 --- /dev/null +++ b/tests/run/decorators.check @@ -0,0 +1,85 @@ +> java add 2 3 +2 + 3 = 5 + +> java add 4 +4 + 1 = 5 + +> java add --num 10 --inc -2 +10 + -2 = 8 + +> java add --num 10 +10 + 1 = 11 + +> java add --help +Adds two numbers +Usage: java add num inc? +where + num the first number + inc the second number + +> java add +Error: invalid argument for num: +Usage: java add num inc? +--help gives more information + +> java add 1 2 3 4 +Error: unused argument: 3 +Error: unused argument: 4 +Usage: java add num inc? +--help gives more information + +> java add -n 1 -i 2 +Error: invalid argument for num: -n +Error: unused argument: -i +Error: unused argument: 2 +Usage: java add num inc? +--help gives more information + +> java add true 10 +Error: invalid argument for num: true +Usage: java add num inc? +--help gives more information + +> java add true false +Error: invalid argument for num: true +Error: invalid argument for inc: false +Usage: java add num inc? +--help gives more information + +> java add true false 10 +Error: invalid argument for num: true +Error: invalid argument for inc: false +Error: unused argument: 10 +Usage: java add num inc? +--help gives more information + +> java add --inc 10 --num 20 +20 + 10 = 30 + +> java add binary 10 01 +Error: invalid argument for num: binary +Error: unused argument: 01 +Usage: java add num inc? +--help gives more information + +> java addAll 1 2 3 4 5 +15 +[log] MyProgram.addAll(1 2 3 4 5) -> () + +> java addAll --nums +0 +[log] MyProgram.addAll(--nums) -> () + +> java addAll --nums 33 44 +44 +[log] MyProgram.addAll(--nums 33 44) -> () + +> java addAll true 1 2 3 +Error: invalid argument for nums: true +Usage: java addAll --nums +[log] MyProgram.addAll(true 1 2 3) -> () + +> java addAll --help +Usage: java addAll --nums +[log] MyProgram.addAll(--help) -> () + diff --git a/tests/run/decorators/DocComment.scala b/tests/run/decorators/DocComment.scala new file mode 100644 index 000000000000..85b30fbce393 --- /dev/null +++ b/tests/run/decorators/DocComment.scala @@ -0,0 +1,26 @@ +/** Represents a doc comment, splitting it into `body` and `tags` + * `tags` are all lines starting with an `@`, where the tag thats starts + * with `@` is paired with the text that follows, up to the next + * tagged line. + * `body` what comes before the first tagged line + */ +case class DocComment(body: String, tags: Map[String, List[String]]) +object DocComment: + def fromString(str: String): DocComment = + val lines = str.linesIterator.toList + def tagged(line: String): Option[(String, String)] = + val ws = WordSplitter(line) + val tag = ws.next() + if tag.startsWith("@") then Some(tag, line.drop(ws.nextOffset)) + else None + val (bodyLines, taggedLines) = lines.span(tagged(_).isEmpty) + def tagPairs(lines: List[String]): List[(String, String)] = lines match + case line :: lines1 => + val (tag, descPrefix) = tagged(line).get + val (untaggedLines, lines2) = lines1.span(tagged(_).isEmpty) + val following = untaggedLines.map(_.dropWhile(_ <= ' ')) + (tag, (descPrefix :: following).mkString("\n")) :: tagPairs(lines2) + case _ => + Nil + DocComment(bodyLines.mkString("\n"), tagPairs(taggedLines).groupMap(_._1)(_._2)) +end DocComment \ No newline at end of file diff --git a/tests/run/decorators/EntryPoint.scala b/tests/run/decorators/EntryPoint.scala new file mode 100644 index 000000000000..f87519fabee0 --- /dev/null +++ b/tests/run/decorators/EntryPoint.scala @@ -0,0 +1,183 @@ +import collection.mutable + +/** A framework for defining stackable entry point wrappers */ +object EntryPoint: + + /** A base trait for wrappers of entry points. + * Sub-traits: Annotation#Wrapper + * Adapter#Wrapper + */ + sealed trait Wrapper + + /** This class provides a framework for compiler-generated wrappers + * of "entry-point" methods. It routes and transforms parameters and results + * between a compiler-generated wrapper method that has calling conventions + * fixed by a framework and a user-written entry-point method that can have + * flexible argument lists. It allows the wrapper to provide help and usage + * information as well as customised error messages if the actual wrapper arguments + * do not match the expected entry-point parameters. + * + * The protocol of calls from the wrapper method is as follows: + * + * 1. Create a `call` instance with the wrapper argument. + * 2. For each parameter of the entry-point, invoke `call.nextArgGetter`, + * or `call.finalArgsGetter` if is a final varargs parameter. + * 3. Invoke `call.run` with the closure of entry-point applied to all arguments. + * + * The wrapper class has this outline: + * + * object : + * @WrapperAnnotation def (args: ) = + * ... + * + * Here `` and `` are obtained from an + * inline call to the `wrapperName` method. + */ + trait Annotation extends annotation.StaticAnnotation: + + /** The class used for argument parsing. E.g. `scala.util.FromString`, if + * arguments are strings, but it could be something else. + */ + type ArgumentParser[T] + + /** The required result type of the user-defined main function */ + type EntryPointResult + + /** The fully qualified name (relative to enclosing package) to + * use for the static wrapper method. + * @param entryPointName the fully qualified name of the user-defined entry point method + */ + inline def wrapperName(entryPointName: String): String + + /** Create an entry point wrapper. + * @param entryPointName the fully qualified name of the user-defined entry point method + * @param docComment the doc comment of the user-defined entry point method + */ + def wrapper(entryPointName: String, docComment: String): Wrapper + + /** Base class for descriptions of an entry point wrappers */ + abstract class Wrapper extends EntryPoint.Wrapper: + + /** The type of the wrapper argument. E.g., for Java main methods: `Array[String]` */ + type Argument + + /** The return type of the generated wrapper. E.g., for Java main methods: `Unit` */ + type Result + + /** An annotation type with which the wrapper method is decorated. + * No annotation is generated if the type is left abstract. + * Multiple annotations are generated if the type is an intersection of annotations. + */ + type WrapperAnnotation <: annotation.Annotation + + /** The fully qualified name of the user-defined entry point method that is wrapped */ + val entryPointName: String + + /** The doc comment of the user-defined entry point method that is wrapped */ + val docComment: String + + /** A new wrapper call with arguments from `args` */ + def call(arg: Argument): Call + + /** A class representing a wrapper call */ + abstract class Call: + + /** The getter for the next argument of type `T` */ + def nextArgGetter[T](argName: String, fromString: ArgumentParser[T], defaultValue: Option[T] = None): () => T + + /** The getter for a final varargs argument of type `T*` */ + def finalArgsGetter[T](argName: String, fromString: ArgumentParser[T]): () => Seq[T] + + /** Run `entryPointWithArgs` if all arguments are valid, + * or print usage information and/or error messages. + * @param entryPointWithArgs the applied entry-point to run + */ + def run(entryPointWithArgs: => EntryPointResult): Result + end Call + end Wrapper + end Annotation + + /** An annotation that generates an adapter of an entry point wrapper. + * An `EntryPoint.Adapter` annotation should always be written together + * with an `EntryPoint.Annotation` and the adapter should be given first. + * If several adapters are present, they are applied right to left. + * Example: + * + * @logged @transactional @main def f(...) + * + * This wraps the main method generated by @main first in a `transactional` + * wrapper and then in a `logged` wrapper. The result would look like this: + * + * $logged$wrapper.adapt { y => + * $transactional$wrapper.adapt { z => + * val cll = $main$wrapper.call(z) + * val arg1 = ... + * ... + * val argN = ... + * cll.run(...) + * } (y) + * } (x) + * + * where + * + * - $logged$wrapper, $transactional$wrapper, $main$wrapper are the wrappers + * created from @logged, @transactional, and @main, respectively. + * - `x` is the argument of the outer $logged$wrapper. + */ + trait Adapter extends annotation.StaticAnnotation: + + /** Creates a new wrapper around `wrapped` */ + def wrapper(wrapped: EntryPoint.Wrapper): Wrapper + + /** The wrapper class. A wrapper class must define a method `adapt` + * that maps unary functions to unary functions. A typical definition + * of `adapt` is: + * + * def adapt(f: A1 => B1)(a: A2): B2 = toB2(f(toA1(a))) + * + * This version of `adapt` converts its argument `a` to the wrapped + * function's argument type `A1`, applies the function, and converts + * the application's result back to type `B2`. `adapt` can also call + * the wrapped function only under a condition or call it multiple times. + * + * `adapt` can also be polymorphic. For instance: + * + * def adapt[R](f: A1 => R)(a: A2): R = f(toA1(a)) + * + * or + * + * def adapt[A, R](f: A => R)(a: A): R = { log(a); f(a) } + * + * Since `adapt` can be of many forms, the base class does not provide + * an abstract method that needs to be implemented in concrete wrapper + * classes. Instead it is checked that the types line up when adapt chains + * are assembled. + * + * I investigated an explicitly typed approach, but could not arrive at a + * solution that was clean and simple enough. If Scala had more dependent + * type support, it would be quite straightforward, i.e. generic `Wrapper` + * could be defined like this: + * + * class Wrapper(wrapped: EntryPoint.Wrapper): + * type Argument + * type Result + * def adapt(f: wrapped.Argument => wrapped.Result)(x: Argument): Result + * + * But to get this to work, we'd need support for types depending on their + * arguments, e.g. a type of the form `Wrapper(wrapped)`. That's an interesting + * avenue to pursue. Until that materializes I think it's preferable to + * keep the `adapt` type contract implicit (types are still checked when adapts + * are generated, of course). + */ + abstract class Wrapper extends EntryPoint.Wrapper: + /** The wrapper that this wrapped in turn by this wrapper */ + val wrapped: EntryPoint.Wrapper + + /** The wrapper of the entry point annotation that this wrapper + * wraps directly or indirectly + */ + def finalWrapped: EntryPoint.Annotation#Wrapper = wrapped match + case wrapped: EntryPoint.Adapter#Wrapper => wrapped.finalWrapped + case wrapped: EntryPoint.Annotation#Wrapper => wrapped + end Wrapper + end Adapter diff --git a/tests/run/decorators/Test.scala b/tests/run/decorators/Test.scala new file mode 100644 index 000000000000..1d821d2109d2 --- /dev/null +++ b/tests/run/decorators/Test.scala @@ -0,0 +1,30 @@ +object Test: + def main(args: Array[String]) = + def testAdd(args: String) = + println(s"> java add $args") + add.main(args.split(" ")) + println() + def testAddAll(args: String) = + println(s"> java addAll $args") + addAll.main(args.split(" ")) + println() + + testAdd("2 3") + testAdd("4") + testAdd("--num 10 --inc -2") + testAdd("--num 10") + testAdd("--help") + testAdd("") + testAdd("1 2 3 4") + testAdd("-n 1 -i 2") + testAdd("true 10") + testAdd("true false") + testAdd("true false 10") + testAdd("--inc 10 --num 20") + testAdd("binary 10 01") + testAddAll("1 2 3 4 5") + testAddAll("--nums") + testAddAll("--nums 33 44") + testAddAll("true 1 2 3") + testAddAll("--help") +end Test \ No newline at end of file diff --git a/tests/run/decorators/WordSplitter.scala b/tests/run/decorators/WordSplitter.scala new file mode 100644 index 000000000000..4195fa73a220 --- /dev/null +++ b/tests/run/decorators/WordSplitter.scala @@ -0,0 +1,30 @@ +/** An iterator to return words in a string while keeping tarck of their offsets */ +class WordSplitter(str: String, start: Int = 0, isSeparator: Char => Boolean = _ <= ' ') +extends Iterator[String]: + private var idx: Int = start + private var lastIdx: Int = start + private var word: String = _ + + private def skipSeparators() = + while idx < str.length && isSeparator(str(idx)) do + idx += 1 + + def lastWord = word + def lastOffset = lastIdx + + def nextOffset = + skipSeparators() + idx + + def next(): String = + skipSeparators() + lastIdx = idx + val b = new StringBuilder + while idx < str.length && !isSeparator(str(idx)) do + b += str(idx) + idx += 1 + word = b.toString + word + + def hasNext: Boolean = nextOffset < str.length +end WordSplitter \ No newline at end of file diff --git a/tests/run/decorators/main.scala b/tests/run/decorators/main.scala new file mode 100644 index 000000000000..a73760252049 --- /dev/null +++ b/tests/run/decorators/main.scala @@ -0,0 +1,124 @@ +import collection.mutable + +/** A sample @main entry point annotation. + * Generates a main function. + */ +class main extends EntryPoint.Annotation: + + type ArgumentParser[T] = util.FromString[T] + type EntryPointResult = Unit + + inline def wrapperName(entryPointName: String): String = + s"${entryPointName.drop(entryPointName.lastIndexOf('.') + 1)}.main" + + def wrapper(name: String, doc: String): MainWrapper = new MainWrapper(name, doc) + + class MainWrapper(val entryPointName: String, val docComment: String) extends Wrapper: + type Argument = Array[String] + type Result = Unit + + def call(args: Array[String]) = new Call: + + /** A buffer of demanded argument names, plus + * "?" if it has a default + * "*" if it is a vararg + * "" otherwise + */ + private var argInfos = new mutable.ListBuffer[(String, String)] + + /** A buffer for all errors */ + private var errors = new mutable.ListBuffer[String] + + /** Issue an error, and return an uncallable getter */ + private def error(msg: String): () => Nothing = + errors += msg + () => assertFail("trying to get invalid argument") + + /** The next argument index */ + private var argIdx: Int = 0 + + private def argAt(idx: Int): Option[String] = + if idx < args.length then Some(args(idx)) else None + + private def nextPositionalArg(): Option[String] = + while argIdx < args.length && args(argIdx).startsWith("--") do argIdx += 2 + val result = argAt(argIdx) + argIdx += 1 + result + + private def convert[T](argName: String, arg: String, p: ArgumentParser[T]): () => T = + p.fromStringOption(arg) match + case Some(t) => () => t + case None => error(s"invalid argument for $argName: $arg") + + def nextArgGetter[T](argName: String, p: ArgumentParser[T], defaultValue: Option[T] = None): () => T = + argInfos += ((argName, if defaultValue.isDefined then "?" else "")) + val idx = args.indexOf(s"--$argName") + val argOpt = if idx >= 0 then argAt(idx + 1) else nextPositionalArg() + argOpt match + case Some(arg) => convert(argName, arg, p) + case None => defaultValue match + case Some(t) => () => t + case None => error(s"missing argument for $argName") + + def finalArgsGetter[T](argName: String, p: ArgumentParser[T]): () => Seq[T] = + argInfos += ((argName, "*")) + def remainingArgGetters(): List[() => T] = nextPositionalArg() match + case Some(arg) => convert(argName, arg, p) :: remainingArgGetters() + case None => Nil + val getters = remainingArgGetters() + () => getters.map(_()) + + def run(entryPointWithArgs: => EntryPointResult): Unit = + lazy val DocComment(explanation, docTags) = DocComment.fromString(docComment) + + def usageString = + docTags.get("@usage") match + case Some(s :: _) => s + case _ => + val cmd = wrapperName(entryPointName).stripSuffix(".main") + val params = argInfos.map(_ + _).mkString(" ") + s"java $cmd $params" + + def printUsage() = println(s"Usage: $usageString") + + def explain(): Unit = + if explanation.nonEmpty then println(explanation) + printUsage() + docTags.get("@param") match + case Some(paramInfos) => + println("where") + for paramInfo <- paramInfos do + val ws = WordSplitter(paramInfo) + val name = ws.next() + val desc = paramInfo.drop(ws.nextOffset) + println(s" $name $desc") + case None => + end explain + + def flagUnused(): Unit = nextPositionalArg() match + case Some(arg) => + error(s"unused argument: $arg") + flagUnused() + case None => + for + arg <- args + if arg.startsWith("--") && !argInfos.map(_._1).contains(arg.drop(2)) + do + error(s"unknown argument name: $arg") + end flagUnused + + if args.contains("--help") then + explain() + else + flagUnused() + if errors.nonEmpty then + for msg <- errors do println(s"Error: $msg") + printUsage() + if explanation.nonEmpty || docTags.contains("@param") then + println("--help gives more information") + else entryPointWithArgs + end run + end call + end MainWrapper +end main diff --git a/tests/run/decorators/sample-adapters.scala b/tests/run/decorators/sample-adapters.scala new file mode 100644 index 000000000000..77622261d16e --- /dev/null +++ b/tests/run/decorators/sample-adapters.scala @@ -0,0 +1,34 @@ +// Sample adapters: + +class logged extends EntryPoint.Adapter: + + def wrapper(wrapped: EntryPoint.Wrapper): LoggedWrapper = LoggedWrapper(wrapped) + + class LoggedWrapper(val wrapped: EntryPoint.Wrapper) extends Wrapper: + def adapt[A, R](op: A => R)(args: A): R = + val argsString: String = args match + case args: Array[_] => args.mkString(", ") + case args: Seq[_] => args.mkString(", ") + case args: Unit => "()" + case args => args.toString + val result = op(args) + println(s"[log] ${finalWrapped.entryPointName}($argsString) -> $result") + result + end LoggedWrapper +end logged + +class split extends EntryPoint.Adapter: + + def wrapper(wrapped: EntryPoint.Wrapper): SplitWrapper = SplitWrapper(wrapped) + + class SplitWrapper(val wrapped: EntryPoint.Wrapper) extends Wrapper: + def adapt[R](op: Array[String] => R)(args: String): R = op(args.split(" ")) +end split + +class join extends EntryPoint.Adapter: + + def wrapper(wrapped: EntryPoint.Wrapper): JoinWrapper = JoinWrapper(wrapped) + + class JoinWrapper(val wrapped: EntryPoint.Wrapper) extends Wrapper: + def adapt[R](op: String => R)(args: Array[String]): R = op(args.mkString(" ")) +end join diff --git a/tests/run/decorators/sample-program.scala b/tests/run/decorators/sample-program.scala new file mode 100644 index 000000000000..1f58f4cd3260 --- /dev/null +++ b/tests/run/decorators/sample-program.scala @@ -0,0 +1,51 @@ +object myProgram: + + /** Adds two numbers + * @param num the first number + * @param inc the second number + */ + @main def add(num: Int, inc: Int = 1): Unit = + println(s"$num + $inc = ${num + inc}") + + /** @usage java addAll --nums */ + @join @logged @split @main def addAll(nums: Int*): Unit = + println(nums.sum) + +end myProgram + +// Compiler generated code: + +object add: + private val $main = new main() + private val $main$wrapper = $main.wrapper( + "MyProgram.add", + """Adds two numbers + |@param num the first number + |@param inc the second number""".stripMargin) + def main(args: Array[String]): Unit = + val cll = $main$wrapper.call(args) + val arg1 = cll.nextArgGetter[Int]("num", summon[$main.ArgumentParser[Int]]) + val arg2 = cll.nextArgGetter[Int]("inc", summon[$main.ArgumentParser[Int]], Some(1)) + cll.run(myProgram.add(arg1(), arg2())) +end add + +object addAll: + private val $main = new main() + private val $split = new split() + private val $logged = new logged() + private val $join = new join() + private val $main$wrapper = $main.wrapper("MyProgram.addAll", "@usage java addAll --nums ") + private val $split$wrapper = $split.wrapper($main$wrapper) + private val $logged$wrapper = $logged.wrapper($split$wrapper) + private val $join$wrapper = $join.wrapper($logged$wrapper) + def main(args: Array[String]): Unit = + $join$wrapper.adapt { (args: String) => + $logged$wrapper.adapt { (args: String) => + $split$wrapper.adapt { (args: Array[String]) => + val cll = $main$wrapper.call(args) + val arg1 = cll.finalArgsGetter[Int]("nums", summon[$main.ArgumentParser[Int]]) + cll.run(myProgram.addAll(arg1(): _*)) + } (args) + } (args) + } (args) +end addAll \ No newline at end of file