-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
147 additions
and
5 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
package in.rcard.sus4s | ||
|
||
import java.util.concurrent.{CompletableFuture, StructuredTaskScope} | ||
|
||
object sus4s { | ||
trait Suspend { | ||
val scope: StructuredTaskScope[Any] | ||
} | ||
|
||
type Suspended[A] = Suspend ?=> A | ||
|
||
class Job[A] private[sus4s] (private val cf: CompletableFuture[A]) { | ||
def value: A = cf.join() | ||
} | ||
|
||
inline def structured[A](inline block: Suspend ?=> A): A = { | ||
val _scope = new StructuredTaskScope.ShutdownOnFailure() | ||
|
||
given suspended: Suspend = new Suspend { | ||
override val scope: StructuredTaskScope[Any] = _scope | ||
} | ||
|
||
try { | ||
val mainTask = _scope.fork(() => { | ||
block | ||
}) | ||
_scope.join().throwIfFailed(identity) | ||
mainTask.get() | ||
} finally { | ||
_scope.close() | ||
} | ||
} | ||
|
||
def fork[A](block: Suspend ?=> A): Suspend ?=> Job[A] = { | ||
val result = new CompletableFuture[A]() | ||
summon[Suspend].scope.fork(() => { | ||
try result.complete(block) | ||
catch | ||
case throwable: Throwable => | ||
result.completeExceptionally(throwable) | ||
throw throwable; | ||
}) | ||
Job(result) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import in.rcard.sus4s.sus4s | ||
import in.rcard.sus4s.sus4s.{fork, structured} | ||
import org.scalatest.TryValues.* | ||
import org.scalatest.flatspec.AnyFlatSpec | ||
import org.scalatest.matchers.should.Matchers | ||
|
||
import java.util.concurrent.ConcurrentLinkedQueue | ||
import scala.util.Try | ||
|
||
class StructuredSpec extends AnyFlatSpec with Matchers { | ||
|
||
"A structured program" should "wait the completion of all the forked jobs" in { | ||
val results = structured { | ||
val queue = new ConcurrentLinkedQueue[String]() | ||
val job1 = fork { | ||
Thread.sleep(1000) | ||
queue.add("job1") | ||
} | ||
val job2 = fork { | ||
Thread.sleep(500) | ||
queue.add("job2") | ||
} | ||
val job3 = fork { | ||
Thread.sleep(100) | ||
queue.add("job3") | ||
} | ||
queue | ||
} | ||
|
||
results.toArray should contain theSameElementsInOrderAs List("job3", "job2", "job1") | ||
} | ||
|
||
it should "stop the execution if one the job throws an exception" in { | ||
val results = new ConcurrentLinkedQueue[String]() | ||
val tryResult = Try { | ||
structured { | ||
val job1 = fork { | ||
Thread.sleep(1000) | ||
results.add("job1") | ||
} | ||
val job2 = fork { | ||
Thread.sleep(500) | ||
results.add("job2") | ||
throw new RuntimeException("Error") | ||
} | ||
val job3 = fork { | ||
Thread.sleep(100) | ||
results.add("job3") | ||
} | ||
} | ||
} | ||
|
||
tryResult.failure.exception shouldBe a[RuntimeException] | ||
tryResult.failure.exception.getMessage shouldBe "Error" | ||
results.toArray should contain theSameElementsInOrderAs List("job3", "job2") | ||
} | ||
|
||
it should "stop the execution if the block throws an exception" in { | ||
val results = new ConcurrentLinkedQueue[String]() | ||
val tryResult = Try { | ||
structured { | ||
val job1 = fork { | ||
Thread.sleep(1000) | ||
results.add("job1") | ||
} | ||
val job2 = fork { | ||
Thread.sleep(500) | ||
results.add("job2") | ||
} | ||
val job3 = fork { | ||
Thread.sleep(100) | ||
results.add("job3") | ||
} | ||
throw new RuntimeException("Error") | ||
} | ||
} | ||
|
||
tryResult.failure.exception shouldBe a[RuntimeException] | ||
tryResult.failure.exception.getMessage shouldBe "Error" | ||
results.toArray shouldBe empty | ||
} | ||
|
||
it should "join the values of different jobs" in { | ||
val queue = new ConcurrentLinkedQueue[String]() | ||
val result = structured { | ||
val job1 = fork { | ||
Thread.sleep(1000) | ||
queue.add("job1") | ||
42 | ||
} | ||
val job2 = fork { | ||
Thread.sleep(500) | ||
queue.add("job2") | ||
43 | ||
} | ||
job1.value + job2.value | ||
} | ||
|
||
queue.toArray should contain theSameElementsInOrderAs List("job2", "job1") | ||
result shouldBe 85 | ||
} | ||
} |