diff --git a/modules/scala/scala-interpreter/src/main/scala/almond/internals/UpdatableResults.scala b/modules/scala/scala-interpreter/src/main/scala/almond/internals/UpdatableResults.scala index 924315313..f59f2da0e 100644 --- a/modules/scala/scala-interpreter/src/main/scala/almond/internals/UpdatableResults.scala +++ b/modules/scala/scala-interpreter/src/main/scala/almond/internals/UpdatableResults.scala @@ -1,5 +1,7 @@ package almond.internals +import java.util.concurrent.ConcurrentHashMap + import almond.interpreter.api.DisplayData import almond.logger.LoggerContext import ammonite.util.Ref @@ -15,33 +17,55 @@ final class UpdatableResults( private val log = logCtx(getClass) - val refs = new mutable.HashMap[String, Ref[(DisplayData, Map[String, String])]] + val refs = new ConcurrentHashMap[String, Ref[(DisplayData, Map[String, String])]] + + val addRefsLock = new Object + + val earlyUpdates = new mutable.HashMap[String, (String, Boolean)] def add(data: DisplayData, variables: Map[String, String]): DisplayData = { val ref = Ref((data, variables)) - refs ++= variables.keysIterator.map { k => - k -> ref + addRefsLock.synchronized { + val variables0 = variables.map { + case (k, v) => + val vOpt = earlyUpdates.remove(k) + if (!vOpt.exists(_._2)) + refs.put(k, ref) + k -> vOpt.fold(v)(_._1) + } + UpdatableResults.substituteVariables(data, variables0) } - UpdatableResults.substituteVariables(data, variables) } - def update(k: String, v: String, last: Boolean): Unit = - Future( - refs.get(k) match { - case None => - log.warn(s"Updatable variable $k not found") - throw new NoSuchElementException(s"Result variable $k") - case Some(ref) => - log.info(s"Updating variable $k with $v") - val (data0, m0) = ref() - val m = m0 + (k -> v) - val data = UpdatableResults.substituteVariables(data0, m) - ref() = (data, m) - updateData(data) - if (last) - refs -= k - } - )(ec) // FIXME Log failures + def update(k: String, v: String, last: Boolean): Unit = { + + def updateRef(ref: Ref[(DisplayData, Map[String, String])]): Unit = { + log.info(s"Updating variable $k with $v") + val (data0, m0) = ref() + val m = m0 + (k -> v) + val data = UpdatableResults.substituteVariables(data0, m) + ref() = (data, m) + Future(updateData(data))(ec) + if (last) + refs.remove(k) + } + + Option(refs.get(k)) match { + case None => + val r = addRefsLock.synchronized { + val r = Option(refs.get(k)) + if (r.isEmpty) { + log.warn(s"Updatable variable $k not found") + earlyUpdates += k -> (v, last) + } + r + } + for (ref <- r) + updateRef(ref) + case Some(ref) => + updateRef(ref) + } + } } diff --git a/modules/scala/scala-interpreter/src/test/scala/almond/UpdatableResultsTests.scala b/modules/scala/scala-interpreter/src/test/scala/almond/UpdatableResultsTests.scala new file mode 100644 index 000000000..fdc548e1c --- /dev/null +++ b/modules/scala/scala-interpreter/src/test/scala/almond/UpdatableResultsTests.scala @@ -0,0 +1,32 @@ +package almond + +import almond.internals.UpdatableResults +import almond.interpreter.api.DisplayData +import almond.logger.LoggerContext +import utest._ + +import scala.concurrent.ExecutionContext + +object UpdatableResultsTests extends TestSuite { + + private val ec: ExecutionContext = + new ExecutionContext { + def execute(runnable: Runnable) = runnable.run() + def reportFailure(cause: Throwable) = () + } + + val tests = Tests { + + "early update" - { + val updates = new java.util.concurrent.ConcurrentLinkedQueue[DisplayData] + val r = new UpdatableResults(ec, LoggerContext.nop, updates.add) + r.update("", "value", last = true) + val data = r.add(DisplayData.text("Foo "), Map("" -> "---")) + val expectedData = DisplayData.text("Foo value") + assert(data == expectedData) + assert(r.earlyUpdates.isEmpty) + } + + } + +}