Skip to content

Commit 48aa0a3

Browse files
authored
Merge pull request #354 from scala/backport-lts-3.3-22957
Backport "improvement: Support using directives in worksheets" to 3.3 LTS
2 parents 3985282 + 5ca62f8 commit 48aa0a3

File tree

8 files changed

+162
-92
lines changed

8 files changed

+162
-92
lines changed

compiler/src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ private sealed trait WarningSettings:
173173
private val WvalueDiscard: Setting[Boolean] = BooleanSetting("-Wvalue-discard", "Warn when non-Unit expression results are unused.")
174174
private val WNonUnitStatement = BooleanSetting("-Wnonunit-statement", "Warn when block statements are non-Unit expressions.")
175175
private val WenumCommentDiscard = BooleanSetting("-Wenum-comment-discard", "Warn when a comment ambiguously assigned to multiple enum cases is discarded.")
176+
private val WtoStringInterpolated = BooleanSetting("-Wtostring-interpolated", "Warn a standard interpolator used toString on a reference type.")
176177
private val Wunused: Setting[List[ChoiceWithHelp[String]]] = MultiChoiceHelpSetting(
177178
name = "-Wunused",
178179
helpArg = "warning",
@@ -288,6 +289,7 @@ private sealed trait WarningSettings:
288289
def valueDiscard(using Context): Boolean = allOr(WvalueDiscard)
289290
def nonUnitStatement(using Context): Boolean = allOr(WNonUnitStatement)
290291
def enumCommentDiscard(using Context): Boolean = allOr(WenumCommentDiscard)
292+
def toStringInterpolated(using Context): Boolean = allOr(WtoStringInterpolated)
291293
def checkInit(using Context): Boolean = allOr(YcheckInit)
292294

293295
/** -X "Extended" or "Advanced" settings */

compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala

Lines changed: 92 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import scala.annotation.tailrec
55
import scala.collection.mutable.ListBuffer
66
import scala.util.matching.Regex.Match
77

8-
import PartialFunction.cond
9-
108
import dotty.tools.dotc.ast.tpd.{Match => _, *}
119
import dotty.tools.dotc.core.Contexts.*
1210
import dotty.tools.dotc.core.Symbols.*
@@ -30,8 +28,9 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
3028
def argType(argi: Int, types: Type*): Type =
3129
require(argi < argc, s"$argi out of range picking from $types")
3230
val tpe = argTypes(argi)
33-
types.find(t => argConformsTo(argi, tpe, t))
34-
.orElse(types.find(t => argConvertsTo(argi, tpe, t)))
31+
types.find(t => t != defn.AnyType && argConformsTo(argi, tpe, t))
32+
.orElse(types.find(t => t != defn.AnyType && argConvertsTo(argi, tpe, t)))
33+
.orElse(types.find(t => t == defn.AnyType && argConformsTo(argi, tpe, t)))
3534
.getOrElse {
3635
report.argError(s"Found: ${tpe.show}, Required: ${types.map(_.show).mkString(", ")}", argi)
3736
actuals += args(argi)
@@ -64,50 +63,57 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
6463

6564
/** For N part strings and N-1 args to interpolate, normalize parts and check arg types.
6665
*
67-
* Returns normalized part strings and args, where args correcpond to conversions in tail of parts.
66+
* Returns normalized part strings and args, where args correspond to conversions in tail of parts.
6867
*/
6968
def checked: (List[String], List[Tree]) =
7069
val amended = ListBuffer.empty[String]
7170
val convert = ListBuffer.empty[Conversion]
7271

72+
def checkPart(part: String, n: Int): Unit =
73+
val matches = formatPattern.findAllMatchIn(part)
74+
75+
def insertStringConversion(): Unit =
76+
amended += "%s" + part
77+
val cv = Conversion.stringXn(n)
78+
cv.accepts(argType(n-1, defn.AnyType))
79+
convert += cv
80+
cv.lintToString(argTypes(n-1))
81+
82+
def errorLeading(op: Conversion) = op.errorAt(Spec):
83+
s"conversions must follow a splice; ${Conversion.literalHelp}"
84+
85+
def accept(op: Conversion): Unit =
86+
if !op.isLeading then errorLeading(op)
87+
op.accepts(argType(n-1, op.acceptableVariants*))
88+
amended += part
89+
convert += op
90+
op.lintToString(argTypes(n-1))
91+
92+
// after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed
93+
if n == 0 then amended += part
94+
else if !matches.hasNext then insertStringConversion()
95+
else
96+
val cv = Conversion(matches.next(), n)
97+
if cv.isLiteral then insertStringConversion()
98+
else if cv.isIndexed then
99+
if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion()
100+
else if !cv.isError then accept(cv)
101+
102+
// any remaining conversions in this part must be either literals or indexed
103+
while matches.hasNext do
104+
val cv = Conversion(matches.next(), n)
105+
if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg")
106+
else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv)
107+
end checkPart
108+
73109
@tailrec
74-
def loop(remaining: List[String], n: Int): Unit =
75-
remaining match
76-
case part0 :: more =>
77-
def badPart(t: Throwable): String = "".tap(_ => report.partError(t.getMessage.nn, index = n, offset = 0))
78-
val part = try StringContext.processEscapes(part0) catch badPart
79-
val matches = formatPattern.findAllMatchIn(part)
80-
81-
def insertStringConversion(): Unit =
82-
amended += "%s" + part
83-
convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve
84-
argType(n-1, defn.AnyType)
85-
def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}")
86-
def accept(op: Conversion): Unit =
87-
if !op.isLeading then errorLeading(op)
88-
op.accepts(argType(n-1, op.acceptableVariants*))
89-
amended += part
90-
convert += op
91-
92-
// after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed
93-
if n == 0 then amended += part
94-
else if !matches.hasNext then insertStringConversion()
95-
else
96-
val cv = Conversion(matches.next(), n)
97-
if cv.isLiteral then insertStringConversion()
98-
else if cv.isIndexed then
99-
if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion()
100-
else if !cv.isError then accept(cv)
101-
102-
// any remaining conversions in this part must be either literals or indexed
103-
while matches.hasNext do
104-
val cv = Conversion(matches.next(), n)
105-
if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg")
106-
else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv)
107-
108-
loop(more, n + 1)
109-
case Nil => ()
110-
end loop
110+
def loop(remaining: List[String], n: Int): Unit = remaining match
111+
case part0 :: remaining =>
112+
def badPart(t: Throwable): String = "".tap(_ => report.partError(t.getMessage.nn, index = n, offset = 0))
113+
val part = try StringContext.processEscapes(part0) catch badPart
114+
checkPart(part, n)
115+
loop(remaining, n + 1)
116+
case Nil =>
111117

112118
loop(parts, n = 0)
113119
if reported then (Nil, Nil)
@@ -125,10 +131,8 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
125131
def intOf(g: SpecGroup): Option[Int] = group(g).map(_.toInt)
126132

127133
extension (inline value: Boolean)
128-
inline def or(inline body: => Unit): Boolean = value || { body ; false }
129-
inline def orElse(inline body: => Unit): Boolean = value || { body ; true }
130-
inline def and(inline body: => Unit): Boolean = value && { body ; true }
131-
inline def but(inline body: => Unit): Boolean = value && { body ; false }
134+
inline infix def or(inline body: => Unit): Boolean = value || { body; false }
135+
inline infix def and(inline body: => Unit): Boolean = value && { body; true }
132136

133137
enum Kind:
134138
case StringXn, HashXn, BooleanXn, CharacterXn, IntegralXn, FloatingPointXn, DateTimeXn, LiteralXn, ErrorXn
@@ -147,9 +151,10 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
147151
// the conversion char is the head of the op string (but see DateTimeXn)
148152
val cc: Char =
149153
kind match
150-
case ErrorXn => if op.isEmpty then '?' else op(0)
151-
case DateTimeXn => if op.length > 1 then op(1) else '?'
152-
case _ => op(0)
154+
case ErrorXn => if op.isEmpty then '?' else op(0)
155+
case DateTimeXn => if op.length <= 1 then '?' else op(1)
156+
case StringXn => if op.isEmpty then 's' else op(0) // accommodate the default %s
157+
case _ => op(0)
153158

154159
def isIndexed: Boolean = index.nonEmpty || hasFlag('<')
155160
def isError: Boolean = kind == ErrorXn
@@ -209,18 +214,28 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
209214
// is the specifier OK with the given arg
210215
def accepts(arg: Type): Boolean =
211216
kind match
212-
case BooleanXn => arg == defn.BooleanType orElse warningAt(CC)("Boolean format is null test for non-Boolean")
213-
case IntegralXn =>
214-
arg == BigIntType || !cond(cc) {
215-
case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; true
216-
}
217+
case BooleanXn if arg != defn.BooleanType =>
218+
warningAt(CC):
219+
"""non-Boolean value formats as "true" for non-null references and boxed primitives, otherwise "false""""
220+
true
221+
case IntegralXn if arg != BigIntType =>
222+
cc match
223+
case 'o' | 'x' | 'X' if hasAnyFlag("+ (") =>
224+
"+ (".filter(hasFlag).foreach: bad =>
225+
badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")
226+
false
217227
case _ => true
228+
case _ => true
229+
230+
def lintToString(arg: Type): Unit =
231+
if ctx.settings.Whas.toStringInterpolated && kind == StringXn && !(arg.widen =:= defn.StringType) && !arg.isPrimitiveValueType
232+
then warningAt(CC)("interpolation uses toString")
218233

219234
// what arg type if any does the conversion accept
220235
def acceptableVariants: List[Type] =
221236
kind match
222237
case StringXn => if hasFlag('#') then FormattableType :: Nil else defn.AnyType :: Nil
223-
case BooleanXn => defn.BooleanType :: defn.NullType :: Nil
238+
case BooleanXn => defn.BooleanType :: defn.NullType :: defn.AnyType :: Nil // warn if not boolean
224239
case HashXn => defn.AnyType :: Nil
225240
case CharacterXn => defn.CharType :: defn.ByteType :: defn.ShortType :: defn.IntType :: Nil
226241
case IntegralXn => defn.IntType :: defn.LongType :: defn.ByteType :: defn.ShortType :: BigIntType :: Nil
@@ -249,25 +264,30 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
249264

250265
object Conversion:
251266
def apply(m: Match, i: Int): Conversion =
252-
def kindOf(cc: Char) = cc match
253-
case 's' | 'S' => StringXn
254-
case 'h' | 'H' => HashXn
255-
case 'b' | 'B' => BooleanXn
256-
case 'c' | 'C' => CharacterXn
257-
case 'd' | 'o' |
258-
'x' | 'X' => IntegralXn
259-
case 'e' | 'E' |
260-
'f' |
261-
'g' | 'G' |
262-
'a' | 'A' => FloatingPointXn
263-
case 't' | 'T' => DateTimeXn
264-
case '%' | 'n' => LiteralXn
265-
case _ => ErrorXn
266-
end kindOf
267267
m.group(CC) match
268-
case Some(cc) => new Conversion(m, i, kindOf(cc(0))).tap(_.verify)
269-
case None => new Conversion(m, i, ErrorXn).tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp"))
268+
case Some(cc) =>
269+
val xn = cc(0) match
270+
case 's' | 'S' => StringXn
271+
case 'h' | 'H' => HashXn
272+
case 'b' | 'B' => BooleanXn
273+
case 'c' | 'C' => CharacterXn
274+
case 'd' | 'o' |
275+
'x' | 'X' => IntegralXn
276+
case 'e' | 'E' |
277+
'f' |
278+
'g' | 'G' |
279+
'a' | 'A' => FloatingPointXn
280+
case 't' | 'T' => DateTimeXn
281+
case '%' | 'n' => LiteralXn
282+
case _ => ErrorXn
283+
new Conversion(m, i, xn)
284+
.tap(_.verify)
285+
case None =>
286+
new Conversion(m, i, ErrorXn)
287+
.tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp"))
270288
end apply
289+
// construct a default %s conversion
290+
def stringXn(i: Int): Conversion = new Conversion(formatPattern.findAllMatchIn("%").next(), i, StringXn)
271291
val literalHelp = "use %% for literal %, %n for newline"
272292
end Conversion
273293

compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,22 @@ class StringInterpolatorOpt extends MiniPhase:
9696
def mkConcat(strs: List[Literal], elems: List[Tree]): Tree =
9797
val stri = strs.iterator
9898
val elemi = elems.iterator
99-
var result: Tree = stri.next
99+
var result: Tree = stri.next()
100100
def concat(tree: Tree): Unit =
101101
result = result.select(defn.String_+).appliedTo(tree).withSpan(tree.span)
102102
while elemi.hasNext
103103
do
104-
concat(elemi.next)
105-
val str = stri.next
104+
val elem = elemi.next()
105+
lintToString(elem)
106+
concat(elem)
107+
val str = stri.next()
106108
if !str.const.stringValue.isEmpty then concat(str)
107109
result
108110
end mkConcat
111+
def lintToString(t: Tree): Unit =
112+
val arg: Type = t.tpe
113+
if ctx.settings.Whas.toStringInterpolated && !(arg.widen =:= defn.StringType) && !arg.isPrimitiveValueType
114+
then report.warning("interpolation uses toString", t.srcPos)
109115
val sym = tree.symbol
110116
// Test names first to avoid loading scala.StringContext if not used, and common names first
111117
val isInterpolatedMethod =

presentation-compiler/src/main/dotty/tools/pc/PcInlineValueProvider.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,16 @@ final class PcInlineValueProvider(
7272
text
7373
)(startOffset, endOffset)
7474
val startPos = new l.Position(
75-
range.getStart.getLine,
76-
range.getStart.getCharacter - (startOffset - startWithSpace)
75+
range.getStart.nn.getLine,
76+
range.getStart.nn.getCharacter - (startOffset - startWithSpace)
7777
)
7878
val endPos =
7979
if (endWithSpace - 1 >= 0 && text(endWithSpace - 1) == '\n')
80-
new l.Position(range.getEnd.getLine + 1, 0)
80+
new l.Position(range.getEnd.nn.getLine + 1, 0)
8181
else
8282
new l.Position(
83-
range.getEnd.getLine,
84-
range.getEnd.getCharacter + endWithSpace - endOffset
83+
range.getEnd.nn.getLine,
84+
range.getEnd.nn.getCharacter + endWithSpace - endOffset
8585
)
8686

8787
new l.Range(startPos, endPos)
@@ -129,15 +129,15 @@ final class PcInlineValueProvider(
129129
end defAndRefs
130130

131131
private def stripIndentPrefix(rhs: String, refIndent: String, defIndent: String): String =
132-
val rhsLines = rhs.split("\n").toList
132+
val rhsLines = rhs.split("\n").nn.toList
133133
rhsLines match
134134
case h :: Nil => rhs
135135
case h :: t =>
136-
val noPrefixH = h.stripPrefix(refIndent)
136+
val noPrefixH = h.nn.stripPrefix(refIndent)
137137
if noPrefixH.startsWith("{") then
138-
noPrefixH ++ t.map(refIndent ++ _.stripPrefix(defIndent)).mkString("\n","\n", "")
138+
noPrefixH ++ t.map(refIndent ++ _.nn.stripPrefix(defIndent)).mkString("\n","\n", "")
139139
else
140-
((" " ++ h) :: t).map(refIndent ++ _.stripPrefix(defIndent)).mkString("\n", "\n", "")
140+
((" " ++ h.nn) :: t).map(refIndent ++ _.nn.stripPrefix(defIndent)).mkString("\n", "\n", "")
141141
case Nil => rhs
142142

143143
private def definitionRequiresBrackets(tree: Tree)(using Context): Boolean =

presentation-compiler/src/main/dotty/tools/pc/completions/ScalaCliCompletions.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,20 @@ class ScalaCliCompletions(
1313
):
1414
def unapply(path: List[Tree]) =
1515
def scalaCliDep = CoursierComplete.isScalaCliDep(
16-
pos.lineContent.take(pos.column).stripPrefix("/*<script>*/")
16+
pos.lineContent.take(pos.column).stripPrefix("/*<script>*/").dropWhile(c => c == ' ' || c == '\t')
1717
)
18+
19+
lazy val supportsUsing =
20+
val filename = pos.source.file.path
21+
filename.endsWith(".sc.scala") ||
22+
filename.endsWith(".worksheet.sc")
23+
1824
path match
1925
case Nil | (_: PackageDef) :: _ => scalaCliDep
2026
// generated script file will end with .sc.scala
21-
case (_: TypeDef) :: (_: PackageDef) :: Nil if pos.source.file.path.endsWith(".sc.scala") =>
27+
case (_: TypeDef) :: (_: PackageDef) :: Nil if supportsUsing =>
2228
scalaCliDep
23-
case (_: Template) :: (_: TypeDef) :: Nil if pos.source.file.path.endsWith(".sc.scala") =>
29+
case (_: Template) :: (_: TypeDef) :: Nil if supportsUsing =>
2430
scalaCliDep
2531
case head :: next => None
2632

tests/neg/f-interpolator-neg.check

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
7 | new StringContext("", "").f() // error
1515
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
1616
| too few arguments for interpolated string
17-
-- [E209] Interpolation Error: tests/neg/f-interpolator-neg.scala:11:7 -------------------------------------------------
18-
11 | f"$s%b" // error
19-
| ^
20-
| Found: (s : String), Required: Boolean, Null
17+
-- [E209] Interpolation Warning: tests/neg/f-interpolator-neg.scala:11:9 -----------------------------------------------
18+
11 | f"$s%b" // warn only
19+
| ^
20+
| non-Boolean value formats as "true" for non-null references and boxed primitives, otherwise "false"
2121
-- [E209] Interpolation Error: tests/neg/f-interpolator-neg.scala:12:7 -------------------------------------------------
2222
12 | f"$s%c" // error
2323
| ^

tests/neg/f-interpolator-neg.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ object Test {
88
}
99

1010
def interpolationMismatches(s : String, f : Double, b : Boolean) = {
11-
f"$s%b" // error
11+
f"$s%b" // warn only
1212
f"$s%c" // error
1313
f"$f%c" // error
1414
f"$s%x" // error
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//> using options -Wtostring-interpolated
2+
//> abusing options -Wconf:cat=w-flag-tostring-interpolated:e -Wtostring-interpolated
3+
4+
case class C(x: Int)
5+
6+
trait T {
7+
def c = C(42)
8+
def f = f"$c" // warn
9+
def s = s"$c" // warn
10+
def r = raw"$c" // warn
11+
12+
def format = f"${c.x}%d in $c or $c%s" // warn using c.toString // warn
13+
14+
def bool = f"$c%b" // warn just a null check
15+
16+
def oops = s"${null} slipped thru my fingers" // warn
17+
18+
def ok = s"${c.toString}"
19+
20+
def sb = new StringBuilder().append("hello")
21+
def greeting = s"$sb, world" // warn
22+
}
23+
24+
class Mitigations {
25+
26+
val s = "hello, world"
27+
val i = 42
28+
def shown = println("shown")
29+
30+
def ok = s"$s is ok"
31+
def jersey = s"number $i"
32+
def unitized = s"unfortunately $shown" // maybe tell them about unintended ()?
33+
34+
def nopct = f"$s is ok"
35+
def nofmt = f"number $i"
36+
}

0 commit comments

Comments
 (0)