Skip to content

Commit a937d18

Browse files
committed
Implement barrier.
1 parent de57a22 commit a937d18

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

src/libextra/sync.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ use std::unstable::finally::Finally;
2727
use std::util;
2828
use std::util::NonCopyable;
2929

30+
use arc::MutexArc;
31+
3032
/****************************************************************************
3133
* Internals
3234
****************************************************************************/
@@ -682,6 +684,67 @@ impl<'a> RWLockReadMode<'a> {
682684
pub fn read<U>(&self, blk: || -> U) -> U { blk() }
683685
}
684686

687+
/// A barrier enables multiple tasks to synchronize the beginning
688+
/// of some computation.
689+
/// ```rust
690+
/// use extra::sync::Barrier;
691+
///
692+
/// let barrier = Barrier::new(10);
693+
/// 10.times(|| {
694+
/// let c = barrier.clone();
695+
/// // The same messages will be printed together.
696+
/// // You will NOT see any interleaving.
697+
/// do spawn {
698+
/// println!("before wait");
699+
/// c.wait();
700+
/// println!("after wait");
701+
/// }
702+
/// });
703+
/// ```
704+
#[deriving(Clone)]
705+
pub struct Barrier {
706+
priv arc: MutexArc<BarrierState>,
707+
priv num_tasks: uint,
708+
}
709+
710+
// The inner state of a double barrier
711+
struct BarrierState {
712+
priv count: uint,
713+
priv generation_id: uint,
714+
}
715+
716+
impl Barrier {
717+
/// Create a new barrier that can block a given number of tasks.
718+
pub fn new(num_tasks: uint) -> Barrier {
719+
Barrier {
720+
arc: MutexArc::new(BarrierState {
721+
count: 0,
722+
generation_id: 0,
723+
}),
724+
num_tasks: num_tasks,
725+
}
726+
}
727+
728+
/// Block the current task until a certain number of tasks is waiting.
729+
pub fn wait(&self) {
730+
self.arc.access_cond(|state, cond| {
731+
let local_gen = state.generation_id;
732+
state.count += 1;
733+
if state.count < self.num_tasks {
734+
// We need a while loop to guard against spurious wakeups.
735+
// http://en.wikipedia.org/wiki/Spurious_wakeup
736+
while local_gen == state.generation_id && state.count < self.num_tasks {
737+
cond.wait();
738+
}
739+
} else {
740+
state.count = 0;
741+
state.generation_id += 1;
742+
cond.broadcast();
743+
}
744+
});
745+
}
746+
}
747+
685748
/****************************************************************************
686749
* Tests
687750
****************************************************************************/
@@ -693,6 +756,7 @@ mod tests {
693756
use std::cast;
694757
use std::result;
695758
use std::task;
759+
use std::comm::{SharedChan, Empty};
696760

697761
/************************************************************************
698762
* Semaphore tests
@@ -1315,4 +1379,35 @@ mod tests {
13151379
})
13161380
})
13171381
}
1382+
1383+
/************************************************************************
1384+
* Barrier tests
1385+
************************************************************************/
1386+
#[test]
1387+
fn test_barrier() {
1388+
let barrier = Barrier::new(10);
1389+
let (port, chan) = SharedChan::new();
1390+
1391+
9.times(|| {
1392+
let c = barrier.clone();
1393+
let chan = chan.clone();
1394+
do spawn {
1395+
c.wait();
1396+
chan.send(true);
1397+
}
1398+
});
1399+
1400+
// At this point, all spawned tasks should be blocked,
1401+
// so we shouldn't get anything from the port
1402+
assert!(match port.try_recv() {
1403+
Empty => true,
1404+
_ => false,
1405+
});
1406+
1407+
barrier.wait();
1408+
// Now, the barrier is cleared and we should get data.
1409+
9.times(|| {
1410+
port.recv();
1411+
});
1412+
}
13181413
}

0 commit comments

Comments
 (0)