Skip to content

Commit 9f17663

Browse files
Merge pull request #13367 from som-snytt/issue/9939
Upgrade f-interpolator
2 parents 48d5747 + 0c87244 commit 9f17663

21 files changed

+961
-924
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class Compiler {
8585
new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only)
8686
new ExplicitOuter, // Add accessors to outer classes from nested ones.
8787
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
88-
new StringInterpolatorOpt) :: // Optimizes raw and s string interpolators by rewriting them to string concatenations
88+
new StringInterpolatorOpt) :: // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats
8989
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
9090
new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_`
9191
new InlinePatterns, // Remove placeholders of inlined patterns

compiler/src/dotty/tools/dotc/parsing/Scanners.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,7 +1168,9 @@ object Scanners {
11681168
finishNamedToken(IDENTIFIER, target = next)
11691169
}
11701170
else
1171-
error("invalid string interpolation: `$$`, `$\"`, `$`ident or `$`BlockExpr expected")
1171+
error("invalid string interpolation: `$$`, `$\"`, `$`ident or `$`BlockExpr expected", off = charOffset - 2)
1172+
putChar('$')
1173+
getStringPart(multiLine)
11721174
}
11731175
else {
11741176
val isUnclosedLiteral = !isUnicodeEscape && (ch == SU || (!multiLine && (ch == CR || ch == LF)))
@@ -1251,7 +1253,7 @@ object Scanners {
12511253
nextChar()
12521254
}
12531255
}
1254-
val alt = if oct == LF then raw"\n" else f"\u$oct%04x"
1256+
val alt = if oct == LF then raw"\n" else f"\\u$oct%04x"
12551257
error(s"octal escape literals are unsupported: use $alt instead", start)
12561258
putChar(oct.toChar)
12571259
}

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
542542
case '"' => "\\\""
543543
case '\'' => "\\\'"
544544
case '\\' => "\\\\"
545-
case _ => if (ch.isControl) f"\u${ch.toInt}%04x" else String.valueOf(ch)
545+
case _ => if ch.isControl then f"\\u${ch.toInt}%04x" else String.valueOf(ch)
546546
}
547547

548548
def toText(const: Constant): Text = const.tag match {
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
package dotty.tools.dotc
2+
package transform.localopt
3+
4+
import scala.annotation.tailrec
5+
import scala.collection.mutable.ListBuffer
6+
import scala.util.chaining.*
7+
import scala.util.matching.Regex.Match
8+
9+
import java.util.{Calendar, Date, Formattable}
10+
11+
import PartialFunction.cond
12+
13+
import dotty.tools.dotc.ast.tpd.{Match => _, *}
14+
import dotty.tools.dotc.core.Contexts._
15+
import dotty.tools.dotc.core.Symbols._
16+
import dotty.tools.dotc.core.Types._
17+
import dotty.tools.dotc.core.Phases.typerPhase
18+
import dotty.tools.dotc.util.Spans.Span
19+
20+
/** Formatter string checker. */
21+
class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List[Tree])(using Context):
22+
23+
val argTypes = args.map(_.tpe)
24+
val actuals = ListBuffer.empty[Tree]
25+
26+
// count of args, for checking indexes
27+
val argc = argTypes.length
28+
29+
// Pick the first runtime type which the i'th arg can satisfy.
30+
// If conversion is required, implementation must emit it.
31+
def argType(argi: Int, types: Type*): Type =
32+
require(argi < argc, s"$argi out of range picking from $types")
33+
val tpe = argTypes(argi)
34+
types.find(t => argConformsTo(argi, tpe, t))
35+
.orElse(types.find(t => argConvertsTo(argi, tpe, t)))
36+
.getOrElse {
37+
report.argError(s"Found: ${tpe.show}, Required: ${types.map(_.show).mkString(", ")}", argi)
38+
actuals += args(argi)
39+
types.head
40+
}
41+
42+
object formattableTypes:
43+
val FormattableType = requiredClassRef("java.util.Formattable")
44+
val BigIntType = requiredClassRef("scala.math.BigInt")
45+
val BigDecimalType = requiredClassRef("scala.math.BigDecimal")
46+
val CalendarType = requiredClassRef("java.util.Calendar")
47+
val DateType = requiredClassRef("java.util.Date")
48+
import formattableTypes.*
49+
def argConformsTo(argi: Int, arg: Type, target: Type): Boolean = (arg <:< target).tap(if _ then actuals += args(argi))
50+
def argConvertsTo(argi: Int, arg: Type, target: Type): Boolean =
51+
import typer.Implicits.SearchSuccess
52+
atPhase(typerPhase) {
53+
ctx.typer.inferView(args(argi), target) match
54+
case SearchSuccess(view, ref, _, _) => actuals += view ; true
55+
case _ => false
56+
}
57+
58+
// match a conversion specifier
59+
val formatPattern = """%(?:(\d+)\$)?([-#+ 0,(<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r
60+
61+
// ordinal is the regex group index in the format pattern
62+
enum SpecGroup:
63+
case Spec, Index, Flags, Width, Precision, CC
64+
import SpecGroup.*
65+
66+
/** For N part strings and N-1 args to interpolate, normalize parts and check arg types.
67+
*
68+
* Returns normalized part strings and args, where args correcpond to conversions in tail of parts.
69+
*/
70+
def checked: (List[String], List[Tree]) =
71+
val amended = ListBuffer.empty[String]
72+
val convert = ListBuffer.empty[Conversion]
73+
74+
@tailrec
75+
def loop(remaining: List[String], n: Int): Unit =
76+
remaining match
77+
case part0 :: more =>
78+
def badPart(t: Throwable): String = "".tap(_ => report.partError(t.getMessage, index = n, offset = 0))
79+
val part = try StringContext.processEscapes(part0) catch badPart
80+
val matches = formatPattern.findAllMatchIn(part)
81+
82+
def insertStringConversion(): Unit =
83+
amended += "%s" + part
84+
convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve
85+
argType(n-1, defn.AnyType)
86+
def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}")
87+
def accept(op: Conversion): Unit =
88+
if !op.isLeading then errorLeading(op)
89+
op.accepts(argType(n-1, op.acceptableVariants*))
90+
amended += part
91+
convert += op
92+
93+
// after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed
94+
if n == 0 then amended += part
95+
else if !matches.hasNext then insertStringConversion()
96+
else
97+
val cv = Conversion(matches.next(), n)
98+
if cv.isLiteral then insertStringConversion()
99+
else if cv.isIndexed then
100+
if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion()
101+
else if !cv.isError then accept(cv)
102+
103+
// any remaining conversions in this part must be either literals or indexed
104+
while matches.hasNext do
105+
val cv = Conversion(matches.next(), n)
106+
if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg")
107+
else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv)
108+
109+
loop(more, n + 1)
110+
case Nil => ()
111+
end loop
112+
113+
loop(parts, n = 0)
114+
if reported then (Nil, Nil)
115+
else
116+
assert(argc == actuals.size, s"Expected ${argc} args but got ${actuals.size} for [${parts.mkString(", ")}]")
117+
(amended.toList, actuals.toList)
118+
end checked
119+
120+
extension (descriptor: Match)
121+
def at(g: SpecGroup): Int = descriptor.start(g.ordinal)
122+
def end(g: SpecGroup): Int = descriptor.end(g.ordinal)
123+
def offset(g: SpecGroup, i: Int = 0): Int = at(g) + i
124+
def group(g: SpecGroup): Option[String] = Option(descriptor.group(g.ordinal))
125+
def stringOf(g: SpecGroup): String = group(g).getOrElse("")
126+
def intOf(g: SpecGroup): Option[Int] = group(g).map(_.toInt)
127+
128+
extension (inline value: Boolean)
129+
inline def or(inline body: => Unit): Boolean = value || { body ; false }
130+
inline def orElse(inline body: => Unit): Boolean = value || { body ; true }
131+
inline def and(inline body: => Unit): Boolean = value && { body ; true }
132+
inline def but(inline body: => Unit): Boolean = value && { body ; false }
133+
134+
enum Kind:
135+
case StringXn, HashXn, BooleanXn, CharacterXn, IntegralXn, FloatingPointXn, DateTimeXn, LiteralXn, ErrorXn
136+
import Kind.*
137+
138+
/** A conversion specifier matched in the argi'th string part, with `argc` arguments to interpolate.
139+
*/
140+
final class Conversion(val descriptor: Match, val argi: Int, val kind: Kind):
141+
// the descriptor fields
142+
val index: Option[Int] = descriptor.intOf(Index)
143+
val flags: String = descriptor.stringOf(Flags)
144+
val width: Option[Int] = descriptor.intOf(Width)
145+
val precision: Option[Int] = descriptor.group(Precision).map(_.drop(1).toInt)
146+
val op: String = descriptor.stringOf(CC)
147+
148+
// the conversion char is the head of the op string (but see DateTimeXn)
149+
val cc: Char =
150+
kind match
151+
case ErrorXn => if op.isEmpty then '?' else op(0)
152+
case DateTimeXn => if op.length > 1 then op(1) else '?'
153+
case _ => op(0)
154+
155+
def isIndexed: Boolean = index.nonEmpty || hasFlag('<')
156+
def isError: Boolean = kind == ErrorXn
157+
def isLiteral: Boolean = kind == LiteralXn
158+
159+
// descriptor is at index 0 of the part string
160+
def isLeading: Boolean = descriptor.at(Spec) == 0
161+
162+
// true if passes.
163+
def verify: Boolean =
164+
// various assertions
165+
def goodies = goodFlags && goodIndex
166+
def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed")
167+
def noWidth = width.isEmpty or errorAt(Width)("width not allowed")
168+
def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed")
169+
def only_-(msg: String) =
170+
val badFlags = flags.filterNot { case '-' | '<' => true case _ => false }
171+
badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg")
172+
def goodFlags =
173+
val badFlags = flags.filterNot(okFlags.contains)
174+
for f <- badFlags do badFlag(f, s"Illegal flag '$f'")
175+
badFlags.isEmpty
176+
def goodIndex =
177+
if index.nonEmpty && hasFlag('<') then warningAt(Index)("Argument index ignored if '<' flag is present")
178+
val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true)
179+
okRange || hasFlag('<') or errorAt(Index)("Argument index out of range")
180+
// begin verify
181+
kind match
182+
case StringXn => goodies
183+
case BooleanXn => goodies
184+
case HashXn => goodies
185+
case CharacterXn => goodies && noPrecision && only_-("c conversion")
186+
case IntegralXn =>
187+
def d_# = cc == 'd' && hasFlag('#') and badFlag('#', "# not allowed for d conversion")
188+
def x_comma = cc != 'd' && hasFlag(',') and badFlag(',', "',' only allowed for d conversion of integral types")
189+
goodies && noPrecision && !d_# && !x_comma
190+
case FloatingPointXn =>
191+
goodies && (cc match
192+
case 'a' | 'A' =>
193+
val badFlags = ",(".filter(hasFlag)
194+
noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s"'$badf' not allowed for a, A"))
195+
case _ => true
196+
)
197+
case DateTimeXn =>
198+
def hasCC = op.length == 2 or errorAt(CC)("Date/time conversion must have two characters")
199+
def goodCC = "HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc".contains(cc) or errorAt(CC, 1)(s"'$cc' doesn't seem to be a date or time conversion")
200+
goodies && hasCC && goodCC && noPrecision && only_-("date/time conversions")
201+
case LiteralXn =>
202+
op match
203+
case "%" => goodies && noPrecision and width.foreach(_ => warningAt(Width)("width ignored on literal"))
204+
case "n" => noFlags && noWidth && noPrecision
205+
case ErrorXn =>
206+
errorAt(CC)(s"illegal conversion character '$cc'")
207+
false
208+
end verify
209+
210+
// is the specifier OK with the given arg
211+
def accepts(arg: Type): Boolean =
212+
kind match
213+
case BooleanXn => arg == defn.BooleanType orElse warningAt(CC)("Boolean format is null test for non-Boolean")
214+
case IntegralXn =>
215+
arg == BigIntType || !cond(cc) {
216+
case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; true
217+
}
218+
case _ => true
219+
220+
// what arg type if any does the conversion accept
221+
def acceptableVariants: List[Type] =
222+
kind match
223+
case StringXn => if hasFlag('#') then FormattableType :: Nil else defn.AnyType :: Nil
224+
case BooleanXn => defn.BooleanType :: defn.NullType :: Nil
225+
case HashXn => defn.AnyType :: Nil
226+
case CharacterXn => defn.CharType :: defn.ByteType :: defn.ShortType :: defn.IntType :: Nil
227+
case IntegralXn => defn.IntType :: defn.LongType :: defn.ByteType :: defn.ShortType :: BigIntType :: Nil
228+
case FloatingPointXn => defn.DoubleType :: defn.FloatType :: BigDecimalType :: Nil
229+
case DateTimeXn => defn.LongType :: CalendarType :: DateType :: Nil
230+
case LiteralXn => Nil
231+
case ErrorXn => Nil
232+
233+
// what flags does the conversion accept?
234+
private def okFlags: String =
235+
kind match
236+
case StringXn => "-#<"
237+
case BooleanXn | HashXn => "-<"
238+
case LiteralXn => "-"
239+
case _ => "-#+ 0,(<"
240+
241+
def hasFlag(f: Char) = flags.contains(f)
242+
def hasAnyFlag(fs: String) = fs.exists(hasFlag)
243+
244+
def badFlag(f: Char, msg: String) =
245+
val i = flags.indexOf(f) match { case -1 => 0 case j => j }
246+
errorAt(Flags, i)(msg)
247+
248+
def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partError(msg, argi, descriptor.offset(g, i), descriptor.end(g))
249+
def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partWarning(msg, argi, descriptor.offset(g, i), descriptor.end(g))
250+
251+
object Conversion:
252+
def apply(m: Match, i: Int): Conversion =
253+
def kindOf(cc: Char) = cc match
254+
case 's' | 'S' => StringXn
255+
case 'h' | 'H' => HashXn
256+
case 'b' | 'B' => BooleanXn
257+
case 'c' | 'C' => CharacterXn
258+
case 'd' | 'o' |
259+
'x' | 'X' => IntegralXn
260+
case 'e' | 'E' |
261+
'f' |
262+
'g' | 'G' |
263+
'a' | 'A' => FloatingPointXn
264+
case 't' | 'T' => DateTimeXn
265+
case '%' | 'n' => LiteralXn
266+
case _ => ErrorXn
267+
end kindOf
268+
m.group(CC) match
269+
case Some(cc) => new Conversion(m, i, kindOf(cc(0))).tap(_.verify)
270+
case None => new Conversion(m, i, ErrorXn).tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp"))
271+
end apply
272+
val literalHelp = "use %% for literal %, %n for newline"
273+
end Conversion
274+
275+
var reported = false
276+
277+
private def partPosAt(index: Int, offset: Int, end: Int) =
278+
val pos = partsElems(index).sourcePos
279+
val bgn = pos.span.start + offset
280+
val fin = if end < 0 then pos.span.end else pos.span.start + end
281+
pos.withSpan(Span(bgn, fin, bgn))
282+
283+
extension (r: report.type)
284+
def argError(message: String, index: Int): Unit = r.error(message, args(index).srcPos).tap(_ => reported = true)
285+
def partError(message: String, index: Int, offset: Int, end: Int = -1): Unit = r.error(message, partPosAt(index, offset, end)).tap(_ => reported = true)
286+
def partWarning(message: String, index: Int, offset: Int, end: Int = -1): Unit = r.warning(message, partPosAt(index, offset, end)).tap(_ => reported = true)
287+
end TypedFormatChecker
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package dotty.tools.dotc
2+
package transform.localopt
3+
4+
import dotty.tools.dotc.ast.tpd.*
5+
import dotty.tools.dotc.core.Constants.Constant
6+
import dotty.tools.dotc.core.Contexts.*
7+
8+
object FormatInterpolatorTransform:
9+
10+
/** For f"${arg}%xpart", check format conversions and return (format, args)
11+
* suitable for String.format(format, args).
12+
*/
13+
def checked(fun: Tree, args0: Tree)(using Context): (Tree, Tree) =
14+
val (partsExpr, parts) = fun match
15+
case TypeApply(Select(Apply(_, (parts: SeqLiteral) :: Nil), _), _) =>
16+
(parts.elems, parts.elems.map { case Literal(Constant(s: String)) => s })
17+
case _ =>
18+
report.error("Expected statically known StringContext", fun.srcPos)
19+
(Nil, Nil)
20+
val (args, elemtpt) = args0 match
21+
case seqlit: SeqLiteral => (seqlit.elems, seqlit.elemtpt)
22+
case _ =>
23+
report.error("Expected statically known argument list", args0.srcPos)
24+
(Nil, EmptyTree)
25+
26+
def literally(s: String) = Literal(Constant(s))
27+
if parts.lengthIs != args.length + 1 then
28+
val badParts =
29+
if parts.isEmpty then "there are no parts"
30+
else s"too ${if parts.lengthIs > args.length + 1 then "few" else "many"} arguments for interpolated string"
31+
report.error(badParts, fun.srcPos)
32+
(literally(""), args0)
33+
else
34+
val checker = TypedFormatChecker(partsExpr, parts, args)
35+
val (format, formatArgs) = checker.checked
36+
if format.isEmpty then (literally(parts.mkString), args0)
37+
else (literally(format.mkString), SeqLiteral(formatArgs.toList, elemtpt))
38+
end checked
39+
end FormatInterpolatorTransform

0 commit comments

Comments
 (0)