Skip to content

Commit 1d2436a

Browse files
committed
[debugger] Create a class prepare request for the outermost class when stopping in lambdas
#SCL-22554 fixed #SCL-22145 - Scala 3.4 and later tries to compile lambdas to the outermost static class. - The reasons for this are explained in scala/scala3#19251. - When trying to stop at breakpoints inside lambdas, it is important to load the outermost class and its nested classes. - Java also does this in the platform implementation. - Tests that cover code examples mentioned in the Scala 3 PR have been added.
1 parent 1cd249e commit 1d2436a

File tree

3 files changed

+314
-6
lines changed

3 files changed

+314
-6
lines changed

scala/debugger/src/org/jetbrains/plugins/scala/debugger/ScalaPositionManager.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,10 @@ class ScalaPositionManager(val debugProcess: DebugProcess) extends PositionManag
219219
isLocalClass(definition) || isDelayedInit(definition)
220220
}
221221

222-
def findEnclosingTypeDefinition: Option[ScTypeDefinition] = {
222+
def findTopmostEnclosingTypeDefinition: Option[ScTypeDefinition] = {
223223
@tailrec
224224
def notLocalEnclosingTypeDefinition(element: PsiElement): Option[ScTypeDefinition] = {
225-
PsiTreeUtil.getParentOfType(element, classOf[ScTypeDefinition]) match {
225+
PsiTreeUtil.getTopmostParentOfType(element, classOf[ScTypeDefinition]) match {
226226
case null => None
227227
case td if isLocalClass(td) => notLocalEnclosingTypeDefinition(td.getParent)
228228
case td => Some(td)
@@ -274,7 +274,7 @@ class ScalaPositionManager(val debugProcess: DebugProcess) extends PositionManag
274274
ClassPattern.Single(topLevelMemberClassName(pckg.getContainingFile, Some(pckg)))
275275
case _ =>
276276
val pattern =
277-
findEnclosingTypeDefinition match {
277+
findTopmostEnclosingTypeDefinition match {
278278
case Some(td) => Some(ClassPattern.Nested(getSpecificNameForDebugger(td)))
279279
case None =>
280280
findEnclosingPackageOrFile.map {
@@ -316,9 +316,7 @@ class ScalaPositionManager(val debugProcess: DebugProcess) extends PositionManag
316316
case (position, ClassPattern.Single(pattern)) => (createRequestor(position), pattern)
317317
}
318318

319-
val res = requestorsAndPatterns.flatMap { case (requestor, pattern) => createClassPrepareRequests(requestor, pattern) }.asJava
320-
println()
321-
res
319+
requestorsAndPatterns.flatMap { case (requestor, pattern) => createClassPrepareRequests(requestor, pattern) }.asJava
322320
}
323321

324322
private def throwIfNotScalaFile(file: PsiFile): Unit = {
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
package org.jetbrains.plugins.scala
2+
package debugger
3+
package breakpoints
4+
5+
import org.jetbrains.plugins.scala.extensions.inReadAction
6+
import org.junit.Assert.{assertTrue, fail}
7+
8+
import java.util.concurrent.ConcurrentLinkedQueue
9+
import java.util.stream.Collectors
10+
import scala.jdk.CollectionConverters._
11+
12+
class LambdaBreakpointsTest_2_11 extends LambdaBreakpointsTestBase {
13+
override protected def supportedIn(version: ScalaVersion): Boolean = version == ScalaVersion.Latest.Scala_2_11
14+
}
15+
16+
class LambdaBreakpointsTest_2_12 extends LambdaBreakpointsTestBase {
17+
override protected def supportedIn(version: ScalaVersion): Boolean = version == ScalaVersion.Latest.Scala_2_12
18+
}
19+
20+
class LambdaBreakpointsTest_2_13 extends LambdaBreakpointsTestBase {
21+
override protected def supportedIn(version: ScalaVersion): Boolean = version == ScalaVersion.Latest.Scala_2_13
22+
}
23+
24+
class LambdaBreakpointsTest_3_0 extends LambdaBreakpointsTestBase {
25+
override protected def supportedIn(version: ScalaVersion): Boolean = version == ScalaVersion.Latest.Scala_3_0
26+
27+
override def testLambdaInClassConstructor(): Unit = breakpointsTest()(9, 4, 4, 4, 4, 4, 4)
28+
29+
override def testLambdaInObjectConstructor(): Unit = breakpointsTest()(9, 4, 4, 4, 4, 4, 4)
30+
31+
override def testLambdaInNestedObject(): Unit = breakpointsTest()(15, 8, 8, 8, 8, 8, 8)
32+
33+
override def testLambdaInNestedClass(): Unit = breakpointsTest()(15, 8, 8, 8, 8, 8, 8)
34+
35+
override def testLambdaInLocalMethod(): Unit = breakpointsTest()(
36+
21, 11,
37+
8, 9, 10, 11,
38+
8, 9, 10, 11,
39+
8, 9, 10, 11,
40+
8, 9, 10, 11,
41+
8, 9, 10, 11,
42+
)
43+
44+
addSourceFile("LambdaInExtension.scala",
45+
s"""
46+
|object LambdaInExtension:
47+
| extension (n: Int) def blah(): Unit =
48+
| (0 until n).foreach { x =>
49+
| println(s"blah $$x") $breakpoint
50+
| }
51+
|
52+
| def main(args: Array[String]): Unit =
53+
| 5.blah() $breakpoint
54+
|""".stripMargin)
55+
56+
def testLambdaInExtension(): Unit = breakpointsTest()(8, 4, 4, 4, 4, 4, 4)
57+
58+
addSourceFile("MainAnnotation.scala",
59+
s"""
60+
|@main def multipleBreakpoints(): Unit = {
61+
| def foo(o: Any): Any = {
62+
| o match {
63+
| case s: String if s.nonEmpty => "string" $breakpoint
64+
| case _ => "not string"
65+
| }
66+
| }
67+
|
68+
| foo("abc")
69+
|}
70+
|""".stripMargin)
71+
72+
def testMainAnnotation(): Unit = breakpointsTest("multipleBreakpoints")(4)
73+
}
74+
75+
class LambdaBreakpointsTest_3_1 extends LambdaBreakpointsTest_3_0 {
76+
override protected def supportedIn(version: ScalaVersion): Boolean = version == ScalaVersion.Latest.Scala_3_1
77+
}
78+
79+
class LambdaBreakpointsTest_3_2 extends LambdaBreakpointsTest_3_0 {
80+
override protected def supportedIn(version: ScalaVersion): Boolean = version == ScalaVersion.Latest.Scala_3_2
81+
}
82+
83+
class LambdaBreakpointsTest_3_3 extends LambdaBreakpointsTest_3_0 {
84+
override protected def supportedIn(version: ScalaVersion): Boolean = version == ScalaVersion.Latest.Scala_3_3
85+
}
86+
87+
class LambdaBreakpointsTest_3_4 extends LambdaBreakpointsTest_3_0 {
88+
override protected def supportedIn(version: ScalaVersion): Boolean = version == ScalaVersion.Latest.Scala_3_4
89+
}
90+
91+
class LambdaBreakpointsTest_3_RC extends LambdaBreakpointsTest_3_0 {
92+
override protected def supportedIn(version: ScalaVersion): Boolean = version == ScalaVersion.Latest.Scala_3_5_RC
93+
94+
override def testLambdaInClassConstructor(): Unit = breakpointsTest()(9, 4, 4, 4, 4, 4)
95+
96+
override def testLambdaInObjectConstructor(): Unit = breakpointsTest()(9, 4, 4, 4, 4, 4)
97+
98+
override def testLambdaInExtension(): Unit = breakpointsTest()(8, 4, 4, 4, 4, 4)
99+
100+
override def testLambdaInNestedObject(): Unit = breakpointsTest()(15, 8, 8, 8, 8, 8)
101+
102+
override def testLambdaInNestedClass(): Unit = breakpointsTest()(15, 8, 8, 8, 8, 8)
103+
104+
override def testLambdaInLocalMethod(): Unit = breakpointsTest()(
105+
21,
106+
8, 9, 10, 11,
107+
8, 9, 10, 11,
108+
8, 9, 10, 11,
109+
8, 9, 10, 11,
110+
8, 9, 10, 11
111+
)
112+
}
113+
114+
abstract class LambdaBreakpointsTestBase extends ScalaDebuggerTestCase {
115+
116+
private val expectedLineQueue: ConcurrentLinkedQueue[Int] = new ConcurrentLinkedQueue()
117+
118+
override protected def tearDown(): Unit = {
119+
try {
120+
if (!expectedLineQueue.isEmpty) {
121+
val remaining =
122+
expectedLineQueue.stream().collect(Collectors.toList[Int]).asScala.toList
123+
fail(s"The debugger did not stop on all expected lines. Remaining: $remaining")
124+
}
125+
} finally {
126+
super.tearDown()
127+
}
128+
}
129+
130+
protected def breakpointsTest(className: String = getTestName(false))(lineNumbers: Int*): Unit = {
131+
assertTrue("The test should stop on at least 1 breakpoint", lineNumbers.nonEmpty)
132+
expectedLineQueue.addAll(lineNumbers.asJava)
133+
134+
createLocalProcess(className)
135+
136+
val debugProcess = getDebugProcess
137+
val positionManager = ScalaPositionManager.instance(debugProcess).getOrElse(new ScalaPositionManager(debugProcess))
138+
139+
onEveryBreakpoint { implicit ctx =>
140+
val loc = ctx.getFrameProxy.location()
141+
val srcPos = inReadAction(positionManager.getSourcePosition(loc))
142+
val actual = srcPos.getLine
143+
Option(expectedLineQueue.poll()) match {
144+
case None =>
145+
fail(s"The debugger stopped on line $actual, but there were no more expected lines")
146+
case Some(expected) =>
147+
assertEquals(expected, actual)
148+
resume(ctx)
149+
}
150+
}
151+
}
152+
153+
addSourceFile("LambdaInClassConstructor.scala",
154+
s"""
155+
|object LambdaInClassConstructor {
156+
| class C {
157+
| (0 until 5).foreach { x =>
158+
| println(x) $breakpoint
159+
| }
160+
| }
161+
|
162+
| def main(args: Array[String]): Unit = {
163+
| println(new C()) $breakpoint
164+
| }
165+
|}
166+
|""".stripMargin)
167+
168+
def testLambdaInClassConstructor(): Unit = breakpointsTest()(9, 4, 4, 4, 4, 4)
169+
170+
addSourceFile("LambdaInObjectConstructor.scala",
171+
s"""
172+
|object LambdaInObjectConstructor {
173+
| object O {
174+
| (0 until 5).foreach { x =>
175+
| println(x) $breakpoint
176+
| }
177+
| }
178+
|
179+
| def main(args: Array[String]): Unit = {
180+
| println(O) $breakpoint
181+
| }
182+
|}
183+
|""".stripMargin)
184+
185+
def testLambdaInObjectConstructor(): Unit = breakpointsTest()(9, 4, 4, 4, 4, 4)
186+
187+
addSourceFile("LambdaInNestedObjectStatic.scala",
188+
s"""
189+
|object LambdaInNestedObjectStatic {
190+
| class Outer {
191+
| object Inner {
192+
| def method(n: Int): Unit = {
193+
| (0 until n).foreach { x => println(x) } $breakpoint ${lambdaOrdinal(0)}
194+
| }
195+
| }
196+
| }
197+
|
198+
| def main(args: Array[String]): Unit = {
199+
| new Outer().Inner.method(5) $breakpoint
200+
| }
201+
|}
202+
|""".stripMargin)
203+
204+
def testLambdaInNestedObjectStatic(): Unit = breakpointsTest()(11, 5, 5, 5, 5, 5)
205+
206+
addSourceFile("LambdaInNestedClassStatic.scala",
207+
s"""
208+
|object LambdaInNestedClassStatic {
209+
| object Outer {
210+
| class Inner {
211+
| def method(n: Int): Unit = {
212+
| (0 until n).foreach(println) $breakpoint ${lambdaOrdinal(0)}
213+
| }
214+
| }
215+
| }
216+
|
217+
| def main(args: Array[String]): Unit = {
218+
| new Outer.Inner().method(5) $breakpoint
219+
| }
220+
|}
221+
|""".stripMargin)
222+
223+
def testLambdaInNestedClassStatic(): Unit = breakpointsTest()(11, 5, 5, 5, 5, 5)
224+
225+
addSourceFile("LambdaInNestedObject.scala",
226+
s"""
227+
|object LambdaInNestedObject {
228+
| class Outer {
229+
| val field: Int = 5
230+
|
231+
| object Inner {
232+
| def method(): Unit = {
233+
| (0 until field).foreach { x =>
234+
| println(x) $breakpoint
235+
| }
236+
| }
237+
| }
238+
| }
239+
|
240+
| def main(args: Array[String]): Unit = {
241+
| new Outer().Inner.method() $breakpoint
242+
| }
243+
|}
244+
|""".stripMargin)
245+
246+
def testLambdaInNestedObject(): Unit = breakpointsTest()(15, 8, 8, 8, 8, 8)
247+
248+
addSourceFile("LambdaInNestedClass.scala",
249+
s"""
250+
|object LambdaInNestedClass {
251+
| object Outer {
252+
| val field: Int = 5
253+
|
254+
| class Inner {
255+
| def method(): Unit = {
256+
| (0 until field).foreach { x =>
257+
| println(x) $breakpoint
258+
| }
259+
| }
260+
| }
261+
| }
262+
|
263+
| def main(args: Array[String]): Unit = {
264+
| new Outer.Inner().method() $breakpoint
265+
| }
266+
|}
267+
|""".stripMargin)
268+
269+
def testLambdaInNestedClass(): Unit = breakpointsTest()(15, 8, 8, 8, 8, 8)
270+
271+
addSourceFile("LambdaInLocalMethod.scala",
272+
s"""
273+
|object LambdaInLocalMethod {
274+
| case class A(s: String = "s", i: Int = 1)
275+
|
276+
| object Inside {
277+
| def create(a: A) = {
278+
| def func(a: A, count: Int) = {
279+
| (0 until count).map { i =>
280+
| val number = i + 1 $breakpoint
281+
| val string = i.toString $breakpoint
282+
| val insideA = A(string, number) $breakpoint
283+
| insideA.s * number $breakpoint
284+
| }
285+
|
286+
| }
287+
| func(a, 5)
288+
| a
289+
| }
290+
| }
291+
|
292+
| def main(args: Array[String]): Unit = {
293+
| Inside.create(A()) $breakpoint
294+
| }
295+
|}
296+
|""".stripMargin)
297+
298+
def testLambdaInLocalMethod(): Unit = breakpointsTest()(
299+
21,
300+
8, 9, 10, 11,
301+
8, 9, 10, 11,
302+
8, 9, 10, 11,
303+
8, 9, 10, 11,
304+
8, 9, 10, 11
305+
)
306+
}
307+

scala/scala-impl/src/org/jetbrains/plugins/scala/ScalaVersion.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ object LatestScalaVersions {
9797
//
9898
val Scala_3_RC = new ScalaVersion(ScalaLanguageLevel.Scala_3_4, "2-RC1")
9999

100+
val Scala_3_5_RC = new ScalaVersion(ScalaLanguageLevel.Scala_3_5, "0-RC1")
101+
100102
val allScala2: Seq[ScalaVersion] = Seq(
101103
Scala_2_9,
102104
Scala_2_10,
@@ -114,6 +116,7 @@ object LatestScalaVersions {
114116

115117
val scalaNext: Seq[ScalaVersion] = Seq(
116118
Scala_3_4,
119+
Scala_3_5_RC,
117120
Scala_3_RC
118121
)
119122

0 commit comments

Comments
 (0)