Skip to content

Commit c50ffaa

Browse files
Simplify homemade annotations tests
1 parent 5755160 commit c50ffaa

File tree

2 files changed

+16
-416
lines changed

2 files changed

+16
-416
lines changed

tests/run/main-annotation-homemade-annot-1.scala

Lines changed: 7 additions & 208 deletions
Original file line numberDiff line numberDiff line change
@@ -22,223 +22,22 @@ end Test
2222

2323
class mainAwait(timeout: Int = 2) extends MainAnnotation:
2424
import MainAnnotation._
25-
import main.{Arg}
26-
27-
private val maxLineLength = 120
2825

2926
override type ArgumentParser[T] = util.CommandLineParser.FromString[T]
3027
override type MainResultType = Future[Any]
3128

32-
private enum ArgumentKind {
33-
case SimpleArgument, OptionalArgument, VarArgument
34-
}
35-
29+
// This is a toy example, it only works with positional args
3630
override def command(args: Array[String], commandName: String, docComment: String) =
3731
new Command[ArgumentParser, MainResultType]:
38-
private val argMarker = "--"
39-
private val shortArgMarker = "-"
40-
41-
private var argNames = new mutable.ArrayBuffer[String]
42-
private var argShortNames = new mutable.ArrayBuffer[Option[Char]]
43-
private var argTypes = new mutable.ArrayBuffer[String]
44-
private var argDocs = new mutable.ArrayBuffer[String]
45-
private var argKinds = new mutable.ArrayBuffer[ArgumentKind]
46-
47-
/** A buffer for all errors */
48-
private var errors = new mutable.ArrayBuffer[String]
49-
50-
/** Issue an error, and return an uncallable getter */
51-
private def error(msg: String): () => Nothing =
52-
errors += msg
53-
() => throw new AssertionError("trying to get invalid argument")
54-
55-
/** The next argument index */
56-
private var argIdx: Int = 0
57-
58-
private def argAt(idx: Int): Option[String] =
59-
if idx < args.length then Some(args(idx)) else None
60-
61-
private def isArgNameAt(idx: Int): Boolean =
62-
val arg = args(argIdx)
63-
val isFullName = arg.startsWith(argMarker)
64-
val isShortName = arg.startsWith(shortArgMarker) && arg.length == 2 && shortNameIsValid(arg(1))
65-
66-
isFullName || isShortName
67-
68-
private def nextPositionalArg(): Option[String] =
69-
while argIdx < args.length && isArgNameAt(argIdx) do argIdx += 2
70-
val result = argAt(argIdx)
71-
argIdx += 1
72-
result
73-
74-
private def shortNameIsValid(shortName: Char): Boolean =
75-
shortName == 0 || shortName.isLetter
76-
77-
private def convert[T](argName: String, arg: String, p: ArgumentParser[T]): () => T =
78-
p.fromStringOption(arg) match
79-
case Some(t) => () => t
80-
case None => error(s"invalid argument for $argName: $arg")
81-
82-
private def argUsage(pos: Int): String =
83-
val name = argNames(pos)
84-
val namePrint = argShortNames(pos).map(short => s"[$shortArgMarker$short | $argMarker$name]").getOrElse(s"[$argMarker$name]")
85-
86-
argKinds(pos) match {
87-
case ArgumentKind.SimpleArgument => s"$namePrint <${argTypes(pos)}>"
88-
case ArgumentKind.OptionalArgument => s"[$namePrint <${argTypes(pos)}>]"
89-
case ArgumentKind.VarArgument => s"[<${argTypes(pos)}> [<${argTypes(pos)}> [...]]]"
90-
}
91-
92-
private def wrapLongLine(line: String, maxLength: Int): List[String] = {
93-
def recurse(s: String, acc: Vector[String]): Seq[String] =
94-
val lastSpace = s.trim.nn.lastIndexOf(' ', maxLength)
95-
if ((s.length <= maxLength) || (lastSpace < 0))
96-
acc :+ s
97-
else {
98-
val (shortLine, rest) = s.splitAt(lastSpace)
99-
recurse(rest.trim.nn, acc :+ shortLine)
100-
}
101-
102-
recurse(line, Vector()).toList
103-
}
104-
105-
private def wrapArgumentUsages(argsUsage: List[String], maxLength: Int): List[String] = {
106-
def recurse(args: List[String], currentLine: String, acc: Vector[String]): Seq[String] =
107-
(args, currentLine) match {
108-
case (Nil, "") => acc
109-
case (Nil, l) => (acc :+ l)
110-
case (arg :: t, "") => recurse(t, arg, acc)
111-
case (arg :: t, l) if l.length + 1 + arg.length <= maxLength => recurse(t, s"$l $arg", acc)
112-
case (arg :: t, l) => recurse(t, arg, acc :+ l)
113-
}
114-
115-
recurse(argsUsage, "", Vector()).toList
116-
}
117-
118-
private inline def shiftLines(s: Seq[String], shift: Int): String = s.map(" " * shift + _).mkString("\n")
119-
120-
private def usage(): Unit =
121-
val usageBeginning = s"Usage: $commandName "
122-
val argsOffset = usageBeginning.length
123-
val argUsages = wrapArgumentUsages((0 until argNames.length).map(argUsage).toList, maxLineLength - argsOffset)
124-
125-
println(usageBeginning + argUsages.mkString("\n" + " " * argsOffset))
126-
127-
private def explain(): Unit =
128-
if (docComment.nonEmpty)
129-
println(wrapLongLine(docComment, maxLineLength).mkString("\n"))
130-
if (argNames.nonEmpty) {
131-
val argNameShift = 2
132-
val argDocShift = argNameShift + 2
133-
134-
println("Arguments:")
135-
for (pos <- 0 until argNames.length)
136-
val argDoc = StringBuilder(" " * argNameShift)
137-
argDoc.append(s"${argNames(pos)} - ${argTypes(pos)}")
138-
139-
argKinds(pos) match {
140-
case ArgumentKind.OptionalArgument => argDoc.append(" (optional)")
141-
case ArgumentKind.VarArgument => argDoc.append(" (vararg)")
142-
case _ =>
143-
}
144-
145-
if (argDocs(pos).nonEmpty) {
146-
val shiftedDoc =
147-
argDocs(pos).split("\n").nn
148-
.map(line => shiftLines(wrapLongLine(line.nn, maxLineLength - argDocShift), argDocShift))
149-
.mkString("\n")
150-
argDoc.append("\n").append(shiftedDoc)
151-
}
152-
153-
println(argDoc)
154-
}
155-
156-
private def indicesOfArg(argName: String, shortArgName: Option[Char]): Seq[Int] =
157-
def allIndicesOf(s: String, from: Int): Seq[Int] =
158-
val i = args.indexOf(s, from)
159-
if i < 0 then Seq() else i +: allIndicesOf(s, i + 1)
160-
161-
val indices = allIndicesOf(s"$argMarker$argName", 0)
162-
val indicesShort = shortArgName.map(shortName => allIndicesOf(s"$shortArgMarker$shortName", 0)).getOrElse(Seq())
163-
(indices ++: indicesShort).filter(_ >= 0)
164-
165-
private def getArgGetter[T](paramInfos: ParameterInfos[_], getDefaultGetter: () => () => T)(using p: ArgumentParser[T]): () => T =
166-
val argName = getEffectiveName(paramInfos)
167-
indicesOfArg(argName, getShortName(paramInfos)) match {
168-
case s @ (Seq() | Seq(_)) =>
169-
val argOpt = s.headOption.map(idx => argAt(idx + 1)).getOrElse(nextPositionalArg())
170-
argOpt match {
171-
case Some(arg) => convert(argName, arg, p)
172-
case None => getDefaultGetter()
173-
}
174-
case s =>
175-
val multValues = s.flatMap(idx => argAt(idx + 1))
176-
error(s"more than one value for $argName: ${multValues.mkString(", ")}")
177-
}
178-
179-
private inline def getEffectiveName(paramInfos: ParameterInfos[_]): String =
180-
paramInfos.annotations.collectFirst{ case arg: Arg if arg.name.length > 0 => arg.name }.getOrElse(paramInfos.name)
181-
182-
private inline def getShortName(paramInfos: ParameterInfos[_]): Option[Char] =
183-
paramInfos.annotations.collectFirst{ case arg: Arg if arg.shortName != 0 => arg.shortName }
184-
185-
private def registerArg(paramInfos: ParameterInfos[_], argKind: ArgumentKind): Unit =
186-
argNames += getEffectiveName(paramInfos)
187-
argTypes += paramInfos.typeName
188-
argDocs += paramInfos.documentation.getOrElse("")
189-
argKinds += argKind
190-
191-
val shortName = getShortName(paramInfos)
192-
shortName.foreach(c => if !shortNameIsValid(c) then throw IllegalArgumentException(s"Invalid short name: $shortArgMarker$c"))
193-
argShortNames += shortName
32+
private var idx = 0
19433

19534
override def argGetter[T](paramInfos: ParameterInfos[T])(using p: ArgumentParser[T]): () => T =
196-
val name = getEffectiveName(paramInfos)
197-
val (defaultGetter, argumentKind) = paramInfos.defaultValueOpt match {
198-
case Some(value) => (() => value, ArgumentKind.OptionalArgument)
199-
case None => (() => error(s"missing argument for $name"), ArgumentKind.SimpleArgument)
200-
}
201-
registerArg(paramInfos, argumentKind)
202-
getArgGetter(paramInfos, defaultGetter)
35+
val i = idx
36+
idx += 1
37+
() => p.fromString(args(i))
20338

20439
override def varargGetter[T](paramInfos: ParameterInfos[T])(using p: ArgumentParser[T]): () => Seq[T] =
205-
registerArg(paramInfos, ArgumentKind.VarArgument)
206-
def remainingArgGetters(): List[() => T] = nextPositionalArg() match
207-
case Some(arg) => convert(getEffectiveName(paramInfos), arg, p) :: remainingArgGetters()
208-
case None => Nil
209-
val getters = remainingArgGetters()
210-
() => getters.map(_())
211-
212-
override def run(f: => MainResultType): Unit =
213-
def checkShortNamesUnique(): Unit =
214-
val shortNameToIndices = argShortNames.collect{ case Some(short) => short }.zipWithIndex.groupBy(_._1).view.mapValues(_.map(_._2))
215-
for ((shortName, indices) <- shortNameToIndices if indices.length > 1)
216-
error(s"$shortName is used as short name for multiple parameters: ${indices.map(idx => argNames(idx)).mkString(", ")}")
217-
218-
def flagUnused(): Unit = nextPositionalArg() match
219-
case Some(arg) =>
220-
error(s"unused argument: $arg")
221-
flagUnused()
222-
case None =>
223-
for
224-
arg <- args
225-
if arg.startsWith(argMarker) && !argNames.contains(arg.drop(2))
226-
do
227-
error(s"unknown argument name: $arg")
228-
end flagUnused
40+
() => for i <- (idx until args.length) yield p.fromString(args(i))
22941

230-
if args.contains(s"${argMarker}help") then
231-
usage()
232-
println()
233-
explain()
234-
else
235-
flagUnused()
236-
checkShortNamesUnique()
237-
if errors.nonEmpty then
238-
for msg <- errors do println(s"Error: $msg")
239-
usage()
240-
else
241-
println(Await.result(f, Duration(timeout, SECONDS)))
242-
end run
243-
end command
42+
override def run(f: => MainResultType): Unit = println(Await.result(f, Duration(timeout, SECONDS)))
24443
end mainAwait

0 commit comments

Comments
 (0)