@@ -21,7 +21,8 @@ object Test:
21
21
end Test
22
22
23
23
class mainAwait (timeout : Int = 2 ) extends MainAnnotation :
24
- self =>
24
+ import MainAnnotation ._
25
+ import main .{Arg }
25
26
26
27
private val maxLineLength = 120
27
28
@@ -33,8 +34,12 @@ class mainAwait(timeout: Int = 2) extends MainAnnotation:
33
34
}
34
35
35
36
override def command (args : Array [String ], commandName : String , docComment : String ) =
36
- new MainAnnotation .Command [ArgumentParser , MainResultType ]:
37
+ new Command [ArgumentParser , MainResultType ]:
38
+ private val argMarker = " --"
39
+ private val shortArgMarker = " -"
40
+
37
41
private var argNames = new mutable.ArrayBuffer [String ]
42
+ private var argShortNames = new mutable.ArrayBuffer [Option [Char ]]
38
43
private var argTypes = new mutable.ArrayBuffer [String ]
39
44
private var argDocs = new mutable.ArrayBuffer [String ]
40
45
private var argKinds = new mutable.ArrayBuffer [ArgumentKind ]
@@ -53,23 +58,34 @@ class mainAwait(timeout: Int = 2) extends MainAnnotation:
53
58
private def argAt (idx : Int ): Option [String ] =
54
59
if idx < args.length then Some (args(idx)) else None
55
60
61
+ private def isArgNameAt (idx : Int ): Boolean =
62
+ val arg = args(argIdx)
63
+ val isFullName = arg.startsWith(argMarker)
64
+ val isShortName = arg.startsWith(shortArgMarker) && arg.length == 2 && shortNameIsValid(arg(1 ))
65
+
66
+ isFullName || isShortName
67
+
56
68
private def nextPositionalArg (): Option [String ] =
57
- while argIdx < args.length && args (argIdx).startsWith( " -- " ) do argIdx += 2
69
+ while argIdx < args.length && isArgNameAt (argIdx) do argIdx += 2
58
70
val result = argAt(argIdx)
59
71
argIdx += 1
60
72
result
61
73
74
+ private def shortNameIsValid (shortName : Char ): Boolean =
75
+ shortName == 0 || shortName.isLetter
76
+
62
77
private def convert [T ](argName : String , arg : String , p : ArgumentParser [T ]): () => T =
63
78
p.fromStringOption(arg) match
64
79
case Some (t) => () => t
65
80
case None => error(s " invalid argument for $argName: $arg" )
66
81
67
82
private def argUsage (pos : Int ): String =
68
83
val name = argNames(pos)
84
+ val namePrint = argShortNames(pos).map(short => s " [ $shortArgMarker$short | $argMarker$name] " ).getOrElse(s " [ $argMarker$name] " )
69
85
70
86
argKinds(pos) match {
71
- case ArgumentKind .SimpleArgument => s " [-- $name ] < ${argTypes(pos)}> "
72
- case ArgumentKind .OptionalArgument => s " [[-- $name ] < ${argTypes(pos)}>] "
87
+ case ArgumentKind .SimpleArgument => s " $namePrint < ${argTypes(pos)}> "
88
+ case ArgumentKind .OptionalArgument => s " [ $namePrint < ${argTypes(pos)}>] "
73
89
case ArgumentKind .VarArgument => s " [< ${argTypes(pos)}> [< ${argTypes(pos)}> [...]]] "
74
90
}
75
91
@@ -137,19 +153,18 @@ class mainAwait(timeout: Int = 2) extends MainAnnotation:
137
153
println(argDoc)
138
154
}
139
155
140
- private def indicesOfArg (argName : String ): Seq [Int ] =
141
- def allIndicesOf (s : String ): Seq [Int ] =
142
- def recurse (s : String , from : Int ): Seq [Int ] =
143
- val i = args.indexOf(s, from)
144
- if i < 0 then Seq () else i +: recurse(s, i + 1 )
145
-
146
- recurse(s, 0 )
156
+ private def indicesOfArg (argName : String , shortArgName : Option [Char ]): Seq [Int ] =
157
+ def allIndicesOf (s : String , from : Int ): Seq [Int ] =
158
+ val i = args.indexOf(s, from)
159
+ if i < 0 then Seq () else i +: allIndicesOf(s, i + 1 )
147
160
148
- val indices = allIndicesOf(s " -- $argName" )
149
- indices.filter(_ >= 0 )
161
+ val indices = allIndicesOf(s " $argMarker$argName" , 0 )
162
+ val indicesShort = shortArgName.map(shortName => allIndicesOf(s " $shortArgMarker$shortName" , 0 )).getOrElse(Seq ())
163
+ (indices ++: indicesShort).filter(_ >= 0 )
150
164
151
- private def getArgGetter [T ](argName : String , getDefaultGetter : () => () => T )(using p : ArgumentParser [T ]): () => T =
152
- indicesOfArg(argName) match {
165
+ private def getArgGetter [T ](paramInfos : ParameterInfos [_], getDefaultGetter : () => () => T )(using p : ArgumentParser [T ]): () => T =
166
+ val argName = getEffectiveName(paramInfos)
167
+ indicesOfArg(argName, getShortName(paramInfos)) match {
153
168
case s @ (Seq () | Seq (_)) =>
154
169
val argOpt = s.headOption.map(idx => argAt(idx + 1 )).getOrElse(nextPositionalArg())
155
170
argOpt match {
@@ -161,48 +176,64 @@ class mainAwait(timeout: Int = 2) extends MainAnnotation:
161
176
error(s " more than one value for $argName: ${multValues.mkString(" , " )}" )
162
177
}
163
178
164
- private def registerArg (paramInfos : MainAnnotation .ParameterInfos [_], argKind : ArgumentKind ): Unit =
165
- argNames += paramInfos.name
179
+ private inline def getEffectiveName (paramInfos : ParameterInfos [_]): String =
180
+ paramInfos.annotations.collectFirst{ case arg : Arg if arg.name.length > 0 => arg.name }.getOrElse(paramInfos.name)
181
+
182
+ private inline def getShortName (paramInfos : ParameterInfos [_]): Option [Char ] =
183
+ paramInfos.annotations.collectFirst{ case arg : Arg if arg.shortName != 0 => arg.shortName }
184
+
185
+ private def registerArg (paramInfos : ParameterInfos [_], argKind : ArgumentKind ): Unit =
186
+ argNames += getEffectiveName(paramInfos)
166
187
argTypes += paramInfos.typeName
167
188
argDocs += paramInfos.documentation.getOrElse(" " )
168
189
argKinds += argKind
169
190
170
- override def argGetter [T ](paramInfos : MainAnnotation .ParameterInfos [T ])(using p : ArgumentParser [T ]): () => T =
171
- val name = paramInfos.name
172
- val (defaultGetter, argumentKind) = paramInfos.defaultValue match {
173
- case Some (value) => (() => () => value, ArgumentKind .OptionalArgument )
191
+ val shortName = getShortName(paramInfos)
192
+ shortName.foreach(c => if ! shortNameIsValid(c) then throw IllegalArgumentException (s " Invalid short name: $shortArgMarker$c" ))
193
+ argShortNames += shortName
194
+
195
+ override def argGetter [T ](paramInfos : ParameterInfos [T ])(using p : ArgumentParser [T ]): () => T =
196
+ val name = getEffectiveName(paramInfos)
197
+ val (defaultGetter, argumentKind) = paramInfos.defaultValueOpt match {
198
+ case Some (value) => (() => value, ArgumentKind .OptionalArgument )
174
199
case None => (() => error(s " missing argument for $name" ), ArgumentKind .SimpleArgument )
175
200
}
176
201
registerArg(paramInfos, argumentKind)
177
- getArgGetter(name , defaultGetter)
202
+ getArgGetter(paramInfos , defaultGetter)
178
203
179
- override def varargGetter [T ](paramInfos : MainAnnotation . ParameterInfos [T ])(using p : ArgumentParser [T ]): () => Seq [T ] =
204
+ override def varargGetter [T ](paramInfos : ParameterInfos [T ])(using p : ArgumentParser [T ]): () => Seq [T ] =
180
205
registerArg(paramInfos, ArgumentKind .VarArgument )
181
206
def remainingArgGetters (): List [() => T ] = nextPositionalArg() match
182
- case Some (arg) => convert(paramInfos.name , arg, p) :: remainingArgGetters()
207
+ case Some (arg) => convert(getEffectiveName( paramInfos) , arg, p) :: remainingArgGetters()
183
208
case None => Nil
184
209
val getters = remainingArgGetters()
185
210
() => getters.map(_())
186
211
187
212
override def run (f : => MainResultType ): Unit =
213
+ def checkShortNamesUnique (): Unit =
214
+ val shortNameToIndices = argShortNames.collect{ case Some (short) => short }.zipWithIndex.groupBy(_._1).view.mapValues(_.map(_._2))
215
+ for ((shortName, indices) <- shortNameToIndices if indices.length > 1 )
216
+ error(s " $shortName is used as short name for multiple parameters: ${indices.map(idx => argNames(idx)).mkString(" , " )}" )
217
+
188
218
def flagUnused (): Unit = nextPositionalArg() match
189
219
case Some (arg) =>
190
220
error(s " unused argument: $arg" )
191
221
flagUnused()
192
222
case None =>
193
223
for
194
224
arg <- args
195
- if arg.startsWith(" -- " ) && ! argNames.contains(arg.drop(2 ))
225
+ if arg.startsWith(argMarker ) && ! argNames.contains(arg.drop(2 ))
196
226
do
197
227
error(s " unknown argument name: $arg" )
198
228
end flagUnused
199
229
200
- if args.contains(" -- help" ) then
230
+ if args.contains(s " ${argMarker} help " ) then
201
231
usage()
202
232
println()
203
233
explain()
204
234
else
205
235
flagUnused()
236
+ checkShortNamesUnique()
206
237
if errors.nonEmpty then
207
238
for msg <- errors do println(s " Error: $msg" )
208
239
usage()
0 commit comments