From 4f1ad71ef2dccc1c175360c42aeb724e0fd5dd61 Mon Sep 17 00:00:00 2001 From: Tom Grigg Date: Fri, 4 Mar 2022 19:00:38 -0800 Subject: [PATCH 1/2] Harden REPL in presence of values that fail to initialize The right hand side of value definitions in the REPL are computed in the static initializer for the wrapper object created for that input line (e.g. rs$line$1). If any of these definitions throws an exception, the wrapper class will fail to initialize, and further attempts to use the class will throw NoClassDefFoundError. In this commit, we avoid all reflective access on a wrapper class once we notice that it failed to initialize, and mark that wrapper object as invalid in the REPL state. We discard all input from the failed wrapper (which may have been multi-line containing many statements and definitions); any types, terms, aliases, or imports defined there will not override any existing with the same name, and will not be accessible in subsequent runs. Fixes #4416 Fixes #14473 --- compiler/src/dotty/tools/repl/Rendering.scala | 16 ++-- .../src/dotty/tools/repl/ReplCompiler.scala | 2 +- .../src/dotty/tools/repl/ReplDriver.scala | 51 ++++++++--- .../dotty/tools/repl/ReplCompilerTests.scala | 91 +++++++++++++++++++ 4 files changed, 139 insertions(+), 21 deletions(-) diff --git a/compiler/src/dotty/tools/repl/Rendering.scala b/compiler/src/dotty/tools/repl/Rendering.scala index 5a0355570663..1391b30251a8 100644 --- a/compiler/src/dotty/tools/repl/Rendering.scala +++ b/compiler/src/dotty/tools/repl/Rendering.scala @@ -129,13 +129,15 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) { infoDiagnostic(d.symbol.showUser, d) /** Render value definition result */ - def renderVal(d: Denotation)(using Context): Option[Diagnostic] = + def renderVal(d: Denotation)(using Context): Either[InvocationTargetException, Option[Diagnostic]] = val dcl = d.symbol.showUser def msg(s: String) = infoDiagnostic(s, d) try - if (d.symbol.is(Flags.Lazy)) Some(msg(dcl)) - else valueOf(d.symbol).map(value => msg(s"$dcl = $value")) - catch case e: InvocationTargetException => Some(msg(renderError(e, d))) + Right( + if d.symbol.is(Flags.Lazy) then Some(msg(dcl)) + else valueOf(d.symbol).map(value => msg(s"$dcl = $value")) + ) + catch case e: InvocationTargetException => Left(e) end renderVal /** Force module initialization in the absence of members. */ @@ -144,10 +146,10 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) { val objectName = sym.fullName.encode.toString Class.forName(objectName, true, classLoader()) Nil - try load() catch case e: ExceptionInInitializerError => List(infoDiagnostic(renderError(e, sym.denot), sym.denot)) + try load() catch case e: ExceptionInInitializerError => List(renderError(e, sym.denot)) /** Render the stack trace of the underlying exception. */ - private def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): String = + def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): Diagnostic = import dotty.tools.dotc.util.StackTraceOps._ val cause = ite.getCause match case e: ExceptionInInitializerError => e.getCause @@ -159,7 +161,7 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) { ste.getClassName.startsWith(REPL_WRAPPER_NAME_PREFIX) // d.symbol.owner.name.show is simple name && (ste.getMethodName == nme.STATIC_CONSTRUCTOR.show || ste.getMethodName == nme.CONSTRUCTOR.show) - cause.formatStackTracePrefix(!isWrapperInitialization(_)) + infoDiagnostic(cause.formatStackTracePrefix(!isWrapperInitialization(_)), d) end renderError private def infoDiagnostic(msg: String, d: Denotation)(using Context): Diagnostic = diff --git a/compiler/src/dotty/tools/repl/ReplCompiler.scala b/compiler/src/dotty/tools/repl/ReplCompiler.scala index 205d1acf3805..fb71d4bbb805 100644 --- a/compiler/src/dotty/tools/repl/ReplCompiler.scala +++ b/compiler/src/dotty/tools/repl/ReplCompiler.scala @@ -61,7 +61,7 @@ class ReplCompiler extends Compiler { val rootCtx = super.rootContext.fresh .setOwner(defn.EmptyPackageClass) .withRootImports - (1 to state.objectIndex).foldLeft(rootCtx)((ctx, id) => + (state.validObjectIndexes).foldLeft(rootCtx)((ctx, id) => importPreviousRun(id)(using ctx)) } } diff --git a/compiler/src/dotty/tools/repl/ReplDriver.scala b/compiler/src/dotty/tools/repl/ReplDriver.scala index 337f87725762..184e1c0817fb 100644 --- a/compiler/src/dotty/tools/repl/ReplDriver.scala +++ b/compiler/src/dotty/tools/repl/ReplDriver.scala @@ -35,6 +35,7 @@ import dotty.tools.runner.ScalaClassLoader.* import org.jline.reader._ import scala.annotation.tailrec +import scala.collection.mutable import scala.collection.JavaConverters._ import scala.util.Using @@ -55,12 +56,15 @@ import scala.util.Using * @param objectIndex the index of the next wrapper * @param valIndex the index of next value binding for free expressions * @param imports a map from object index to the list of user defined imports + * @param invalidObjectIndexes the set of object indexes that failed to initialize * @param context the latest compiler context */ case class State(objectIndex: Int, valIndex: Int, imports: Map[Int, List[tpd.Import]], - context: Context) + invalidObjectIndexes: Set[Int], + context: Context): + def validObjectIndexes = (1 to objectIndex).filterNot(invalidObjectIndexes.contains(_)) /** Main REPL instance, orchestrating input, compilation and presentation */ class ReplDriver(settings: Array[String], @@ -94,7 +98,7 @@ class ReplDriver(settings: Array[String], } /** the initial, empty state of the REPL session */ - final def initialState: State = State(0, 0, Map.empty, rootCtx) + final def initialState: State = State(0, 0, Map.empty, Set.empty, rootCtx) /** Reset state of repl to the initial state * @@ -237,7 +241,7 @@ class ReplDriver(settings: Array[String], completions.map(_.label).distinct.map(makeCandidate) } .getOrElse(Nil) - end completions + end completions private def interpret(res: ParseResult)(implicit state: State): State = { res match { @@ -353,14 +357,33 @@ class ReplDriver(settings: Array[String], val typeAliases = info.bounds.hi.typeMembers.filter(_.symbol.info.isTypeAlias) - val formattedMembers = - typeAliases.map(rendering.renderTypeAlias) ++ - defs.map(rendering.renderMethod) ++ - vals.flatMap(rendering.renderVal) - - val diagnostics = if formattedMembers.isEmpty then rendering.forceModule(symbol) else formattedMembers - - (state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), diagnostics) + // The wrapper object may fail to initialize if the rhs of a ValDef throws. + // In that case, don't attempt to render any subsequent vals, and mark this + // wrapper object index as invalid. + var failedInit = false + val renderedVals = + val buf = mutable.ListBuffer[Diagnostic]() + for d <- vals do if !failedInit then rendering.renderVal(d) match + case Right(Some(v)) => + buf += v + case Left(e) => + buf += rendering.renderError(e, d) + failedInit = true + case _ => + buf.toList + + if failedInit then + // We limit the returned diagnostics here to `renderedVals`, which will contain the rendered error + // for the val which failed to initialize. Since any other defs, aliases, imports, etc. from this + // input line will be inaccessible, we avoid rendering those so as not to confuse the user. + (state.copy(invalidObjectIndexes = state.invalidObjectIndexes + state.objectIndex), renderedVals) + else + val formattedMembers = + typeAliases.map(rendering.renderTypeAlias) + ++ defs.map(rendering.renderMethod) + ++ renderedVals + val diagnostics = if formattedMembers.isEmpty then rendering.forceModule(symbol) else formattedMembers + (state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), diagnostics) } else (state, Seq.empty) @@ -378,8 +401,10 @@ class ReplDriver(settings: Array[String], tree.symbol.info.memberClasses .find(_.symbol.name == newestWrapper.moduleClassName) .map { wrapperModule => - val formattedTypeDefs = typeDefs(wrapperModule.symbol) val (newState, formattedMembers) = extractAndFormatMembers(wrapperModule.symbol) + val formattedTypeDefs = // don't render type defs if wrapper initialization failed + if newState.invalidObjectIndexes.contains(state.objectIndex) then Seq.empty + else typeDefs(wrapperModule.symbol) val highlighted = (formattedTypeDefs ++ formattedMembers) .map(d => new Diagnostic(d.msg.mapMsg(SyntaxHighlighting.highlight), d.pos, d.level)) (newState, highlighted) @@ -420,7 +445,7 @@ class ReplDriver(settings: Array[String], case Imports => for { - objectIndex <- 1 to state.objectIndex + objectIndex <- state.validObjectIndexes imp <- state.imports.getOrElse(objectIndex, Nil) } out.println(imp.show(using state.context)) state diff --git a/compiler/test/dotty/tools/repl/ReplCompilerTests.scala b/compiler/test/dotty/tools/repl/ReplCompilerTests.scala index e1c25937ecf8..5b1ec9a8bd3b 100644 --- a/compiler/test/dotty/tools/repl/ReplCompilerTests.scala +++ b/compiler/test/dotty/tools/repl/ReplCompilerTests.scala @@ -243,6 +243,97 @@ class ReplCompilerTests extends ReplTest: assertEquals(List("// defined class C"), lines()) } + def assertNotFoundError(id: String): Unit = + val lines = storedOutput().linesIterator + assert(lines.next().startsWith("-- [E006] Not Found Error:")) + assert(lines.drop(2).next().trim().endsWith(s"Not found: $id")) + + @Test def i4416 = initially { + val state = run("val x = 1 / 0") + val all = lines() + assertEquals(2, all.length) + assert(all.head.startsWith("java.lang.ArithmeticException:")) + state + } andThen { + val state = run("def foo = x") + assertNotFoundError("x") + state + } andThen { + run("x") + assertNotFoundError("x") + } + + @Test def i4416b = initially { + val state = run("val a = 1234") + val _ = storedOutput() // discard output + state + } andThen { + val state = run("val a = 1; val x = ???; val y = x") + val all = lines() + assertEquals(3, all.length) + assertEquals("scala.NotImplementedError: an implementation is missing", all.head) + state + } andThen { + val state = run("x") + assertNotFoundError("x") + state + } andThen { + val state = run("y") + assertNotFoundError("y") + state + } andThen { + run("a") // `a` should retain its original binding + assertEquals("val res0: Int = 1234", storedOutput().trim) + } + + @Test def i4416_imports = initially { + run("import scala.collection.mutable") + } andThen { + val state = run("import scala.util.Try; val x = ???") + val _ = storedOutput() // discard output + state + } andThen { + run(":imports") // scala.util.Try should not be imported + assertEquals("import scala.collection.mutable", storedOutput().trim) + } + + @Test def i4416_types_defs_aliases = initially { + val state = + run("""|type Foo = String + |trait Bar + |def bar: Bar = ??? + |val x = ??? + |""".stripMargin) + val all = lines() + assertEquals(3, all.length) + assertEquals("scala.NotImplementedError: an implementation is missing", all.head) + assert("type alias in failed wrapper should not be rendered", + !all.exists(_.startsWith("// defined alias type Foo = String"))) + assert("type definitions in failed wrapper should not be rendered", + !all.exists(_.startsWith("// defined trait Bar"))) + assert("defs in failed wrapper should not be rendered", + !all.exists(_.startsWith("def bar: Bar"))) + state + } andThen { + val state = run("def foo: Foo = ???") + assertNotFoundError("type Foo") + state + } andThen { + val state = run("type B = Bar") + assertNotFoundError("type Bar") + state + } andThen { + run("bar") + assertNotFoundError("bar") + } + + @Test def i14473 = initially { + run("""val (x,y) = if true then "hi" else (42,17)""") + val all = lines() + assertEquals(2, all.length) + assertEquals("scala.MatchError: hi (of class java.lang.String)", all.head) + } + @Test def i14491 = initially { run("import language.experimental.fewerBraces") From 0b0f62666d2227de00da1542231556f65cc31c3b Mon Sep 17 00:00:00 2001 From: Tom Grigg Date: Sun, 13 Mar 2022 16:57:00 -0700 Subject: [PATCH 2/2] Fix #14701: avoid REPL crash when input is of the form `val _ = ???` The REPL should not crash when the right hand side of `val _ =` throws a non-fatal error that is a subclass of java.lang.Error. --- compiler/src/dotty/tools/repl/Rendering.scala | 6 +++++- .../test/dotty/tools/repl/ReplCompilerTests.scala | 13 +++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/repl/Rendering.scala b/compiler/src/dotty/tools/repl/Rendering.scala index 1391b30251a8..98944c9ab48c 100644 --- a/compiler/src/dotty/tools/repl/Rendering.scala +++ b/compiler/src/dotty/tools/repl/Rendering.scala @@ -142,11 +142,15 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) { /** Force module initialization in the absence of members. */ def forceModule(sym: Symbol)(using Context): Seq[Diagnostic] = + import scala.util.control.NonFatal def load() = val objectName = sym.fullName.encode.toString Class.forName(objectName, true, classLoader()) Nil - try load() catch case e: ExceptionInInitializerError => List(renderError(e, sym.denot)) + try load() + catch + case e: ExceptionInInitializerError => List(renderError(e, sym.denot)) + case NonFatal(e) => List(renderError(InvocationTargetException(e), sym.denot)) /** Render the stack trace of the underlying exception. */ def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): Diagnostic = diff --git a/compiler/test/dotty/tools/repl/ReplCompilerTests.scala b/compiler/test/dotty/tools/repl/ReplCompilerTests.scala index 5b1ec9a8bd3b..963f269af8a6 100644 --- a/compiler/test/dotty/tools/repl/ReplCompilerTests.scala +++ b/compiler/test/dotty/tools/repl/ReplCompilerTests.scala @@ -334,6 +334,19 @@ class ReplCompilerTests extends ReplTest: assertEquals("scala.MatchError: hi (of class java.lang.String)", all.head) } + @Test def i14701 = initially { + val state = run("val _ = ???") + val all = lines() + assertEquals(3, all.length) + assertEquals("scala.NotImplementedError: an implementation is missing", all.head) + state + } andThen { + run("val _ = assert(false)") + val all = lines() + assertEquals(3, all.length) + assertEquals("java.lang.AssertionError: assertion failed", all.head) + } + @Test def i14491 = initially { run("import language.experimental.fewerBraces")