diff --git a/core/src/main/scala/scala/collection/parallel/Tasks.scala b/core/src/main/scala/scala/collection/parallel/Tasks.scala index 87e918aa..4cbafb33 100644 --- a/core/src/main/scala/scala/collection/parallel/Tasks.scala +++ b/core/src/main/scala/scala/collection/parallel/Tasks.scala @@ -254,12 +254,10 @@ trait ForkJoinTasks extends Tasks with HavingForkJoinPool { def execute[R, Tp](task: Task[R, Tp]): () => R = { val fjtask = newWrappedTask(task) - if (Thread.currentThread.isInstanceOf[ForkJoinWorkerThread]) { - fjtask.fork - } else { - forkJoinPool.execute(fjtask) + Thread.currentThread match { + case fjw: ForkJoinWorkerThread if fjw.getPool eq forkJoinPool => fjtask.fork() + case _ => forkJoinPool.execute(fjtask) } - () => { fjtask.sync() fjtask.body.forwardThrowable() @@ -277,12 +275,10 @@ trait ForkJoinTasks extends Tasks with HavingForkJoinPool { def executeAndWaitResult[R, Tp](task: Task[R, Tp]): R = { val fjtask = newWrappedTask(task) - if (Thread.currentThread.isInstanceOf[ForkJoinWorkerThread]) { - fjtask.fork - } else { - forkJoinPool.execute(fjtask) + Thread.currentThread match { + case fjw: ForkJoinWorkerThread if fjw.getPool eq forkJoinPool => fjtask.fork() + case _ => forkJoinPool.execute(fjtask) } - fjtask.sync() // if (fjtask.body.throwable != null) println("throwing: " + fjtask.body.throwable + " at " + fjtask.body) fjtask.body.forwardThrowable() diff --git a/junit/src/test/scala/scala/collection/parallel/TaskTest.scala b/junit/src/test/scala/scala/collection/parallel/TaskTest.scala new file mode 100644 index 00000000..c44cb89f --- /dev/null +++ b/junit/src/test/scala/scala/collection/parallel/TaskTest.scala @@ -0,0 +1,30 @@ +package scala.collection.parallel + +import org.junit.Test +import org.junit.Assert._ + +import java.util.concurrent.{ForkJoinPool, ForkJoinWorkerThread}, ForkJoinPool._ + +import CollectionConverters._ + +class TaskTest { + @Test + def `t10577 task executes on foreign pool`(): Unit = { + def mkFactory(name: String) = new ForkJoinWorkerThreadFactory { + override def newThread(pool: ForkJoinPool) = { + val t = new ForkJoinWorkerThread(pool) {} + t.setName(name) + t + } + } + def mkPool(name: String) = new ForkJoinPool(1, mkFactory(name), null, false) + + val one = List(1).par + val two = List(2).par + + one.tasksupport = new ForkJoinTaskSupport(mkPool("one")) + two.tasksupport = new ForkJoinTaskSupport(mkPool("two")) + + for (x <- one ; y <- two) assertEquals("two", Thread.currentThread.getName) + } +}