|
| 1 | +package dotty.tools.dotc |
| 2 | +package transform.localopt |
| 3 | + |
| 4 | +import scala.annotation.tailrec |
| 5 | +import scala.collection.mutable.ListBuffer |
| 6 | +import scala.util.chaining.* |
| 7 | +import scala.util.matching.Regex.Match |
| 8 | + |
| 9 | +import java.util.{Calendar, Date, Formattable} |
| 10 | + |
| 11 | +import PartialFunction.cond |
| 12 | + |
| 13 | +import dotty.tools.dotc.ast.tpd.{Match => _, *} |
| 14 | +import dotty.tools.dotc.core.Contexts._ |
| 15 | +import dotty.tools.dotc.core.Symbols._ |
| 16 | +import dotty.tools.dotc.core.Types._ |
| 17 | +import dotty.tools.dotc.core.Phases.typerPhase |
| 18 | +import dotty.tools.dotc.util.Spans.Span |
| 19 | + |
| 20 | +/** Formatter string checker. */ |
| 21 | +class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List[Tree])(using Context): |
| 22 | + |
| 23 | + val argTypes = args.map(_.tpe) |
| 24 | + val actuals = ListBuffer.empty[Tree] |
| 25 | + |
| 26 | + // count of args, for checking indexes |
| 27 | + val argc = argTypes.length |
| 28 | + |
| 29 | + // Pick the first runtime type which the i'th arg can satisfy. |
| 30 | + // If conversion is required, implementation must emit it. |
| 31 | + def argType(argi: Int, types: Type*): Type = |
| 32 | + require(argi < argc, s"$argi out of range picking from $types") |
| 33 | + val tpe = argTypes(argi) |
| 34 | + types.find(t => argConformsTo(argi, tpe, t)) |
| 35 | + .orElse(types.find(t => argConvertsTo(argi, tpe, t))) |
| 36 | + .getOrElse { |
| 37 | + report.argError(s"Found: ${tpe.show}, Required: ${types.map(_.show).mkString(", ")}", argi) |
| 38 | + actuals += args(argi) |
| 39 | + types.head |
| 40 | + } |
| 41 | + |
| 42 | + object formattableTypes: |
| 43 | + val FormattableType = requiredClassRef("java.util.Formattable") |
| 44 | + val BigIntType = requiredClassRef("scala.math.BigInt") |
| 45 | + val BigDecimalType = requiredClassRef("scala.math.BigDecimal") |
| 46 | + val CalendarType = requiredClassRef("java.util.Calendar") |
| 47 | + val DateType = requiredClassRef("java.util.Date") |
| 48 | + import formattableTypes.* |
| 49 | + def argConformsTo(argi: Int, arg: Type, target: Type): Boolean = (arg <:< target).tap(if _ then actuals += args(argi)) |
| 50 | + def argConvertsTo(argi: Int, arg: Type, target: Type): Boolean = |
| 51 | + import typer.Implicits.SearchSuccess |
| 52 | + atPhase(typerPhase) { |
| 53 | + ctx.typer.inferView(args(argi), target) match |
| 54 | + case SearchSuccess(view, ref, _, _) => actuals += view ; true |
| 55 | + case _ => false |
| 56 | + } |
| 57 | + |
| 58 | + // match a conversion specifier |
| 59 | + val formatPattern = """%(?:(\d+)\$)?([-#+ 0,(<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r |
| 60 | + |
| 61 | + // ordinal is the regex group index in the format pattern |
| 62 | + enum SpecGroup: |
| 63 | + case Spec, Index, Flags, Width, Precision, CC |
| 64 | + import SpecGroup.* |
| 65 | + |
| 66 | + /** For N part strings and N-1 args to interpolate, normalize parts and check arg types. |
| 67 | + * |
| 68 | + * Returns normalized part strings and args, where args correcpond to conversions in tail of parts. |
| 69 | + */ |
| 70 | + def checked: (List[String], List[Tree]) = |
| 71 | + val amended = ListBuffer.empty[String] |
| 72 | + val convert = ListBuffer.empty[Conversion] |
| 73 | + |
| 74 | + @tailrec |
| 75 | + def loop(remaining: List[String], n: Int): Unit = |
| 76 | + remaining match |
| 77 | + case part0 :: more => |
| 78 | + def badPart(t: Throwable): String = "".tap(_ => report.partError(t.getMessage, index = n, offset = 0)) |
| 79 | + val part = try StringContext.processEscapes(part0) catch badPart |
| 80 | + val matches = formatPattern.findAllMatchIn(part) |
| 81 | + |
| 82 | + def insertStringConversion(): Unit = |
| 83 | + amended += "%s" + part |
| 84 | + convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve |
| 85 | + argType(n-1, defn.AnyType) |
| 86 | + def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}") |
| 87 | + def accept(op: Conversion): Unit = |
| 88 | + if !op.isLeading then errorLeading(op) |
| 89 | + op.accepts(argType(n-1, op.acceptableVariants*)) |
| 90 | + amended += part |
| 91 | + convert += op |
| 92 | + |
| 93 | + // after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed |
| 94 | + if n == 0 then amended += part |
| 95 | + else if !matches.hasNext then insertStringConversion() |
| 96 | + else |
| 97 | + val cv = Conversion(matches.next(), n) |
| 98 | + if cv.isLiteral then insertStringConversion() |
| 99 | + else if cv.isIndexed then |
| 100 | + if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion() |
| 101 | + else if !cv.isError then accept(cv) |
| 102 | + |
| 103 | + // any remaining conversions in this part must be either literals or indexed |
| 104 | + while matches.hasNext do |
| 105 | + val cv = Conversion(matches.next(), n) |
| 106 | + if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg") |
| 107 | + else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv) |
| 108 | + |
| 109 | + loop(more, n + 1) |
| 110 | + case Nil => () |
| 111 | + end loop |
| 112 | + |
| 113 | + loop(parts, n = 0) |
| 114 | + if reported then (Nil, Nil) |
| 115 | + else |
| 116 | + assert(argc == actuals.size, s"Expected ${argc} args but got ${actuals.size} for [${parts.mkString(", ")}]") |
| 117 | + (amended.toList, actuals.toList) |
| 118 | + end checked |
| 119 | + |
| 120 | + extension (descriptor: Match) |
| 121 | + def at(g: SpecGroup): Int = descriptor.start(g.ordinal) |
| 122 | + def end(g: SpecGroup): Int = descriptor.end(g.ordinal) |
| 123 | + def offset(g: SpecGroup, i: Int = 0): Int = at(g) + i |
| 124 | + def group(g: SpecGroup): Option[String] = Option(descriptor.group(g.ordinal)) |
| 125 | + def stringOf(g: SpecGroup): String = group(g).getOrElse("") |
| 126 | + def intOf(g: SpecGroup): Option[Int] = group(g).map(_.toInt) |
| 127 | + |
| 128 | + extension (inline value: Boolean) |
| 129 | + inline def or(inline body: => Unit): Boolean = value || { body ; false } |
| 130 | + inline def orElse(inline body: => Unit): Boolean = value || { body ; true } |
| 131 | + inline def and(inline body: => Unit): Boolean = value && { body ; true } |
| 132 | + inline def but(inline body: => Unit): Boolean = value && { body ; false } |
| 133 | + |
| 134 | + enum Kind: |
| 135 | + case StringXn, HashXn, BooleanXn, CharacterXn, IntegralXn, FloatingPointXn, DateTimeXn, LiteralXn, ErrorXn |
| 136 | + import Kind.* |
| 137 | + |
| 138 | + /** A conversion specifier matched in the argi'th string part, with `argc` arguments to interpolate. |
| 139 | + */ |
| 140 | + final class Conversion(val descriptor: Match, val argi: Int, val kind: Kind): |
| 141 | + // the descriptor fields |
| 142 | + val index: Option[Int] = descriptor.intOf(Index) |
| 143 | + val flags: String = descriptor.stringOf(Flags) |
| 144 | + val width: Option[Int] = descriptor.intOf(Width) |
| 145 | + val precision: Option[Int] = descriptor.group(Precision).map(_.drop(1).toInt) |
| 146 | + val op: String = descriptor.stringOf(CC) |
| 147 | + |
| 148 | + // the conversion char is the head of the op string (but see DateTimeXn) |
| 149 | + val cc: Char = |
| 150 | + kind match |
| 151 | + case ErrorXn => if op.isEmpty then '?' else op(0) |
| 152 | + case DateTimeXn => if op.length > 1 then op(1) else '?' |
| 153 | + case _ => op(0) |
| 154 | + |
| 155 | + def isIndexed: Boolean = index.nonEmpty || hasFlag('<') |
| 156 | + def isError: Boolean = kind == ErrorXn |
| 157 | + def isLiteral: Boolean = kind == LiteralXn |
| 158 | + |
| 159 | + // descriptor is at index 0 of the part string |
| 160 | + def isLeading: Boolean = descriptor.at(Spec) == 0 |
| 161 | + |
| 162 | + // true if passes. |
| 163 | + def verify: Boolean = |
| 164 | + // various assertions |
| 165 | + def goodies = goodFlags && goodIndex |
| 166 | + def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed") |
| 167 | + def noWidth = width.isEmpty or errorAt(Width)("width not allowed") |
| 168 | + def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed") |
| 169 | + def only_-(msg: String) = |
| 170 | + val badFlags = flags.filterNot { case '-' | '<' => true case _ => false } |
| 171 | + badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg") |
| 172 | + def goodFlags = |
| 173 | + val badFlags = flags.filterNot(okFlags.contains) |
| 174 | + for f <- badFlags do badFlag(f, s"Illegal flag '$f'") |
| 175 | + badFlags.isEmpty |
| 176 | + def goodIndex = |
| 177 | + if index.nonEmpty && hasFlag('<') then warningAt(Index)("Argument index ignored if '<' flag is present") |
| 178 | + val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true) |
| 179 | + okRange || hasFlag('<') or errorAt(Index)("Argument index out of range") |
| 180 | + // begin verify |
| 181 | + kind match |
| 182 | + case StringXn => goodies |
| 183 | + case BooleanXn => goodies |
| 184 | + case HashXn => goodies |
| 185 | + case CharacterXn => goodies && noPrecision && only_-("c conversion") |
| 186 | + case IntegralXn => |
| 187 | + def d_# = cc == 'd' && hasFlag('#') and badFlag('#', "# not allowed for d conversion") |
| 188 | + def x_comma = cc != 'd' && hasFlag(',') and badFlag(',', "',' only allowed for d conversion of integral types") |
| 189 | + goodies && noPrecision && !d_# && !x_comma |
| 190 | + case FloatingPointXn => |
| 191 | + goodies && (cc match |
| 192 | + case 'a' | 'A' => |
| 193 | + val badFlags = ",(".filter(hasFlag) |
| 194 | + noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s"'$badf' not allowed for a, A")) |
| 195 | + case _ => true |
| 196 | + ) |
| 197 | + case DateTimeXn => |
| 198 | + def hasCC = op.length == 2 or errorAt(CC)("Date/time conversion must have two characters") |
| 199 | + def goodCC = "HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc".contains(cc) or errorAt(CC, 1)(s"'$cc' doesn't seem to be a date or time conversion") |
| 200 | + goodies && hasCC && goodCC && noPrecision && only_-("date/time conversions") |
| 201 | + case LiteralXn => |
| 202 | + op match |
| 203 | + case "%" => goodies && noPrecision and width.foreach(_ => warningAt(Width)("width ignored on literal")) |
| 204 | + case "n" => noFlags && noWidth && noPrecision |
| 205 | + case ErrorXn => |
| 206 | + errorAt(CC)(s"illegal conversion character '$cc'") |
| 207 | + false |
| 208 | + end verify |
| 209 | + |
| 210 | + // is the specifier OK with the given arg |
| 211 | + def accepts(arg: Type): Boolean = |
| 212 | + kind match |
| 213 | + case BooleanXn => arg == defn.BooleanType orElse warningAt(CC)("Boolean format is null test for non-Boolean") |
| 214 | + case IntegralXn => |
| 215 | + arg == BigIntType || !cond(cc) { |
| 216 | + case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; true |
| 217 | + } |
| 218 | + case _ => true |
| 219 | + |
| 220 | + // what arg type if any does the conversion accept |
| 221 | + def acceptableVariants: List[Type] = |
| 222 | + kind match |
| 223 | + case StringXn => if hasFlag('#') then FormattableType :: Nil else defn.AnyType :: Nil |
| 224 | + case BooleanXn => defn.BooleanType :: defn.NullType :: Nil |
| 225 | + case HashXn => defn.AnyType :: Nil |
| 226 | + case CharacterXn => defn.CharType :: defn.ByteType :: defn.ShortType :: defn.IntType :: Nil |
| 227 | + case IntegralXn => defn.IntType :: defn.LongType :: defn.ByteType :: defn.ShortType :: BigIntType :: Nil |
| 228 | + case FloatingPointXn => defn.DoubleType :: defn.FloatType :: BigDecimalType :: Nil |
| 229 | + case DateTimeXn => defn.LongType :: CalendarType :: DateType :: Nil |
| 230 | + case LiteralXn => Nil |
| 231 | + case ErrorXn => Nil |
| 232 | + |
| 233 | + // what flags does the conversion accept? |
| 234 | + private def okFlags: String = |
| 235 | + kind match |
| 236 | + case StringXn => "-#<" |
| 237 | + case BooleanXn | HashXn => "-<" |
| 238 | + case LiteralXn => "-" |
| 239 | + case _ => "-#+ 0,(<" |
| 240 | + |
| 241 | + def hasFlag(f: Char) = flags.contains(f) |
| 242 | + def hasAnyFlag(fs: String) = fs.exists(hasFlag) |
| 243 | + |
| 244 | + def badFlag(f: Char, msg: String) = |
| 245 | + val i = flags.indexOf(f) match { case -1 => 0 case j => j } |
| 246 | + errorAt(Flags, i)(msg) |
| 247 | + |
| 248 | + def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partError(msg, argi, descriptor.offset(g, i), descriptor.end(g)) |
| 249 | + def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partWarning(msg, argi, descriptor.offset(g, i), descriptor.end(g)) |
| 250 | + |
| 251 | + object Conversion: |
| 252 | + def apply(m: Match, i: Int): Conversion = |
| 253 | + def kindOf(cc: Char) = cc match |
| 254 | + case 's' | 'S' => StringXn |
| 255 | + case 'h' | 'H' => HashXn |
| 256 | + case 'b' | 'B' => BooleanXn |
| 257 | + case 'c' | 'C' => CharacterXn |
| 258 | + case 'd' | 'o' | |
| 259 | + 'x' | 'X' => IntegralXn |
| 260 | + case 'e' | 'E' | |
| 261 | + 'f' | |
| 262 | + 'g' | 'G' | |
| 263 | + 'a' | 'A' => FloatingPointXn |
| 264 | + case 't' | 'T' => DateTimeXn |
| 265 | + case '%' | 'n' => LiteralXn |
| 266 | + case _ => ErrorXn |
| 267 | + end kindOf |
| 268 | + m.group(CC) match |
| 269 | + case Some(cc) => new Conversion(m, i, kindOf(cc(0))).tap(_.verify) |
| 270 | + case None => new Conversion(m, i, ErrorXn).tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp")) |
| 271 | + end apply |
| 272 | + val literalHelp = "use %% for literal %, %n for newline" |
| 273 | + end Conversion |
| 274 | + |
| 275 | + var reported = false |
| 276 | + |
| 277 | + private def partPosAt(index: Int, offset: Int, end: Int) = |
| 278 | + val pos = partsElems(index).sourcePos |
| 279 | + val bgn = pos.span.start + offset |
| 280 | + val fin = if end < 0 then pos.span.end else pos.span.start + end |
| 281 | + pos.withSpan(Span(bgn, fin, bgn)) |
| 282 | + |
| 283 | + extension (r: report.type) |
| 284 | + def argError(message: String, index: Int): Unit = r.error(message, args(index).srcPos).tap(_ => reported = true) |
| 285 | + def partError(message: String, index: Int, offset: Int, end: Int = -1): Unit = r.error(message, partPosAt(index, offset, end)).tap(_ => reported = true) |
| 286 | + def partWarning(message: String, index: Int, offset: Int, end: Int = -1): Unit = r.warning(message, partPosAt(index, offset, end)).tap(_ => reported = true) |
| 287 | +end TypedFormatChecker |
0 commit comments