Skip to content

REPL class loader is context class loader #14293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions compiler/src/dotty/tools/repl/ReplDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import dotty.tools.dotc.util.{SourceFile, SourcePosition}
import dotty.tools.dotc.{CompilationUnit, Driver}
import dotty.tools.dotc.config.CompilerCommand
import dotty.tools.io._
import dotty.tools.runner.ScalaClassLoader.*
import org.jline.reader._

import scala.annotation.tailrec
Expand Down Expand Up @@ -62,7 +63,7 @@ case class State(objectIndex: Int,
/** Main REPL instance, orchestrating input, compilation and presentation */
class ReplDriver(settings: Array[String],
out: PrintStream = Console.out,
classLoader: Option[ClassLoader] = None) extends Driver {
classLoader: Option[ClassLoader] = None) extends Driver:

/** Overridden to `false` in order to not have to give sources on the
* commandline
Expand Down Expand Up @@ -161,15 +162,17 @@ class ReplDriver(settings: Array[String],
else loop(interpret(res)(state))
}

try withRedirectedOutput { loop(initialState) }
try runBody { loop(initialState) }
finally terminal.close()
}

final def run(input: String)(implicit state: State): State = withRedirectedOutput {
final def run(input: String)(implicit state: State): State = runBody {
val parsed = ParseResult(input)(state)
interpret(parsed)
}

private def runBody(body: => State): State = rendering.classLoader()(using rootCtx).asContext(withRedirectedOutput(body))

// TODO: i5069
final def bind(name: String, value: Any)(implicit state: State): State = state

Expand Down Expand Up @@ -455,4 +458,5 @@ class ReplDriver(settings: Array[String],
private def printDiagnostic(dia: Diagnostic)(implicit state: State) = dia.level match
case interfaces.Diagnostic.INFO => out.println(dia.msg) // print REPL's special info diagnostics directly to out
case _ => ReplConsoleReporter.doReport(dia)(using state.context)
}

end ReplDriver
17 changes: 11 additions & 6 deletions compiler/src/dotty/tools/runner/ScalaClassLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@ import scala.annotation.tailrec
import scala.util.control.Exception.catching

final class RichClassLoader(private val self: ClassLoader) extends AnyVal {
/** Executing an action with this classloader as context classloader */
private def asContext[T](action: => T): T = {
val saved = Thread.currentThread.getContextClassLoader
try { ScalaClassLoader.setContext(self) ; action }
finally ScalaClassLoader.setContext(saved)
}
/** Execute an action with this classloader as context classloader. */
private def asContext[T](action: => T): T = ScalaClassLoader.asContext(self)(action)

/** Load and link a class with this classloader */
def tryToLoadClass[T <: AnyRef](path: String): Option[Class[T]] = tryClass(path, initialize = false)
Expand Down Expand Up @@ -74,4 +70,13 @@ object ScalaClassLoader {
MethodHandles.lookup().findStatic(classOf[ClassLoader], "getPlatformClassLoader", MethodType.methodType(classOf[ClassLoader])).invoke().asInstanceOf[ClassLoader]
catch case _: Throwable => null
else null

extension (classLoader: ClassLoader)
/** Execute an action with this classloader as context classloader. */
def asContext[T](action: => T): T =
val saved = Thread.currentThread.getContextClassLoader
try
setContext(classLoader)
action
finally setContext(saved)
}
102 changes: 32 additions & 70 deletions compiler/test/dotty/tools/repl/DocTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,110 +6,76 @@ import org.junit.Assert.assertEquals

class DocTests extends ReplTest {

@Test def docOfDef =
eval("/** doc */ def foo = 0").andThen { implicit s =>
assertEquals("doc", doc("foo"))
}
@Test def docOfDef = eval("/** doc */ def foo = 0") andThen assertEquals("doc", doc("foo"))

@Test def docOfVal =
eval("/** doc */ val foo = 0").andThen { implicit s =>
assertEquals("doc", doc("foo"))
}
@Test def docOfVal = eval("/** doc */ val foo = 0") andThen assertEquals("doc", doc("foo"))

@Test def docOfObject =
eval("/** doc */ object Foo").andThen { implicit s =>
assertEquals("doc", doc("Foo"))
}
@Test def docOfObject = eval("/** doc */ object Foo") andThen assertEquals("doc", doc("Foo"))

@Test def docOfClass =
eval("/** doc */ class Foo").andThen { implicit s =>
assertEquals("doc", doc("new Foo"))
}
@Test def docOfClass = eval("/** doc */ class Foo") andThen assertEquals("doc", doc("new Foo"))

@Test def docOfTrait =
eval("/** doc */ trait Foo").andThen { implicit s =>
assertEquals("doc", doc("new Foo"))
}
/*
@Test def docOfExtension1 =
eval("/** doc */ extension (x: Int) def foo = 0").andThen { implicit s =>
assertEquals("doc", doc("extension_foo"))
}
@Test def docOfTrait = eval("/** doc */ trait Foo") andThen assertEquals("doc", doc("new Foo"))

@Test def docOfExtension2 =
eval("extension (x: Int) /** doc */ def foo = 0").andThen { implicit s =>
assertEquals("doc", doc("extension_foo"))
}
/*@Test*/ def docOfExtension1 =
eval("/** doc */ extension (x: Int) def foo = 0") andThen assertEquals("doc", doc("extension_foo"))

@Test def docOfExtension3 =
eval("/** doc0 */ extension (x: Int) { /** doc1 */ def foo = 0; /** doc2 */ def bar = 0; def baz = 0 }").andThen { implicit s =>
/*@Test*/ def docOfExtension2 =
eval("extension (x: Int) /** doc */ def foo = 0") andThen assertEquals("doc", doc("extension_foo"))

/*@Test*/ def docOfExtension3 =
eval("/** doc0 */ extension (x: Int) { /** doc1 */ def foo = 0; /** doc2 */ def bar = 0; def baz = 0 }") andThen {
assertEquals("doc1", doc("extension_foo"))
assertEquals("doc2", doc("extension_bar"))
assertEquals("doc0", doc("extension_baz"))
}
*/
@Test def docOfDefInObject =
eval("object O { /** doc */ def foo = 0 }").andThen { implicit s =>
assertEquals("doc", doc("O.foo"))
}

@Test def docOfValInObject =
eval("object O { /** doc */ val foo = 0 }").andThen { implicit s =>
assertEquals("doc", doc("O.foo"))
}
@Test def docOfDefInObject = eval("object O { /** doc */ def foo = 0 }") andThen assertEquals("doc", doc("O.foo"))

@Test def docOfObjectInObject =
eval("object O { /** doc */ object Foo }").andThen { implicit s =>
assertEquals("doc", doc("O.Foo"))
}
@Test def docOfValInObject = eval("object O { /** doc */ val foo = 0 }") andThen assertEquals("doc", doc("O.foo"))

@Test def docOfClassInObject =
eval("object O { /** doc */ class Foo }").andThen { implicit s =>
assertEquals("doc", doc("new O.Foo"))
}
@Test def docOfObjectInObject = eval("object O { /** doc */ object Foo }") andThen assertEquals("doc", doc("O.Foo"))

@Test def docOfTraitInObject =
eval("object O { /** doc */ trait Foo }").andThen { implicit s =>
assertEquals("doc", doc("new O.Foo"))
}
@Test def docOfClassInObject = eval("object O { /** doc */ class Foo }") andThen assertEquals("doc", doc("new O.Foo"))

@Test def docOfTraitInObject = eval("object O { /** doc */ trait Foo }") andThen assertEquals("doc", doc("new O.Foo"))

@Test def docOfDefInClass =
eval(
"""class C { /** doc */ def foo = 0 }
|val c = new C
""".stripMargin).andThen { implicit s =>
""".stripMargin) andThen {
assertEquals("doc", doc("c.foo"))
}

@Test def docOfValInClass =
eval(
"""class C { /** doc */ val foo = 0 }
|val c = new C
""".stripMargin).andThen { implicit s =>
""".stripMargin) andThen {
assertEquals("doc", doc("c.foo"))
}

@Test def docOfObjectInClass =
eval(
"""class C { /** doc */ object Foo }
|val c = new C
""".stripMargin).andThen { implicit s =>
""".stripMargin) andThen {
assertEquals("doc", doc("c.Foo"))
}

@Test def docOfClassInClass =
eval(
"""class C { /** doc */ class Foo }
|val c = new C
""".stripMargin).andThen { implicit s =>
""".stripMargin) andThen {
assertEquals("doc", doc("new c.Foo"))
}

@Test def docOfTraitInClass =
eval(
"""class C { /** doc */ trait Foo }
|val c = new C
""".stripMargin).andThen { implicit s =>
""".stripMargin) andThen {
assertEquals("doc", doc("new c.Foo"))
}

Expand All @@ -119,7 +85,7 @@ class DocTests extends ReplTest {
| /** doc0 */ def foo(x: Int) = x
| /** doc1 */ def foo(x: String) = x
|}
""".stripMargin).andThen { implicit s =>
""".stripMargin) andThen {
assertEquals("doc0", doc("O.foo(_: Int)"))
assertEquals("doc1", doc("O.foo(_: String)"))
}
Expand All @@ -128,7 +94,7 @@ class DocTests extends ReplTest {
eval(
"""class C { /** doc */ def foo = 0 }
|object O extends C
""".stripMargin).andThen { implicit s =>
""".stripMargin) andThen {
assertEquals("doc", doc("O.foo"))
}

Expand All @@ -142,7 +108,7 @@ class DocTests extends ReplTest {
| override def foo(x: Int): Int = x
| /** overridden doc */ override def foo(x: String): String = x
|}
""".stripMargin).andThen { implicit s =>
""".stripMargin) andThen {
assertEquals("doc0", doc("O.foo(_: Int)"))
assertEquals("overridden doc", doc("O.foo(_: String)"))
}
Expand All @@ -158,38 +124,34 @@ class DocTests extends ReplTest {
| override def bar: Int = 0
| }
|}
""".stripMargin).andThen { implicit s =>
""".stripMargin) andThen {
assertEquals("companion", doc("O.foo"))
assertEquals("doc0", doc("O.foo.bar"))
}

@Test def docIsCooked =
eval(
"""/**
| * An object
"""/** An object
| * @define Variable some-value
| */
|object Foo {
| /** Expansion: $Variable */
| def hello = "world"
|}
""".stripMargin).andThen { implicit s =>
""".stripMargin) andThen {
assertEquals("Expansion: some-value", doc("Foo.hello"))
}

@Test def docOfEmpty =
fromInitialState { implicit s =>
@Test def docOfEmpty = initially {
run(":doc")
assertEquals(":doc <expression>", storedOutput().trim)
}

private def eval(code: String): State =
fromInitialState { implicit s => run(code) }
private def eval(code: String): State = initially(run(code))

private def doc(expr: String)(implicit s: State): String = {
private def doc(expr: String)(using State): String = {
storedOutput()
run(s":doc $expr")
storedOutput().trim
}

}
4 changes: 2 additions & 2 deletions compiler/test/dotty/tools/repl/JavaDefinedTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import org.junit.Assert._
import org.junit.Test

class JavaDefinedTests extends ReplTest {
@Test def typeOfJavaDefinedString = fromInitialState { implicit s =>
@Test def typeOfJavaDefinedString = initially {
run("String")
assertTrue(storedOutput().contains("Java defined class String is not a value"))
}

@Test def typeOfJavaDefinedClass = fromInitialState { implicit s =>
@Test def typeOfJavaDefinedClass = initially {
run("Class")
assertTrue(storedOutput().contains("Java defined class Class is not a value"))
}
Expand Down
6 changes: 2 additions & 4 deletions compiler/test/dotty/tools/repl/LoadTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,13 @@ class LoadTests extends ReplTest {
)

def loadTest(file: String, defs: String, runCode: String, output: String) =
eval(s":load ${writeFile(file)}").andThen { implicit s =>
eval(s":load ${writeFile(file)}") andThen {
assertMultiLineEquals(defs, storedOutput())
run(runCode)
assertMultiLineEquals(output, storedOutput())
}

private def eval(code: String): State =
fromInitialState { implicit s => run(code) }

private def eval(code: String): State = initially(run(code))
}

object LoadTests {
Expand Down
Loading