Skip to content

Commit a2f7e84

Browse files
committed
use enum for condvar locks
1 parent a1d94d4 commit a2f7e84

File tree

3 files changed

+76
-44
lines changed

3 files changed

+76
-44
lines changed

src/tools/miri/src/concurrency/sync.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,25 @@ struct RwLock {
116116

117117
declare_id!(CondvarId);
118118

119+
#[derive(Debug, Copy, Clone)]
120+
pub enum RwLockMode {
121+
Read,
122+
Write,
123+
}
124+
125+
#[derive(Debug)]
126+
pub enum CondvarLock {
127+
Mutex(MutexId),
128+
RwLock { id: RwLockId, mode: RwLockMode },
129+
}
130+
119131
/// A thread waiting on a conditional variable.
120132
#[derive(Debug)]
121133
struct CondvarWaiter {
122134
/// The thread that is waiting on this variable.
123135
thread: ThreadId,
124136
/// The mutex or rwlock on which the thread is waiting.
125-
lock: u32,
126-
/// If the lock is shared or exclusive
127-
shared: bool,
137+
lock: CondvarLock,
128138
}
129139

130140
/// The conditional variable state.
@@ -571,16 +581,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
571581
}
572582

573583
/// Mark that the thread is waiting on the conditional variable.
574-
fn condvar_wait(&mut self, id: CondvarId, thread: ThreadId, lock: u32, shared: bool) {
584+
fn condvar_wait(&mut self, id: CondvarId, thread: ThreadId, lock: CondvarLock) {
575585
let this = self.eval_context_mut();
576586
let waiters = &mut this.machine.threads.sync.condvars[id].waiters;
577587
assert!(waiters.iter().all(|waiter| waiter.thread != thread), "thread is already waiting");
578-
waiters.push_back(CondvarWaiter { thread, lock, shared });
588+
waiters.push_back(CondvarWaiter { thread, lock });
579589
}
580590

581591
/// Wake up some thread (if there is any) sleeping on the conditional
582592
/// variable.
583-
fn condvar_signal(&mut self, id: CondvarId) -> Option<(ThreadId, u32, bool)> {
593+
fn condvar_signal(&mut self, id: CondvarId) -> Option<(ThreadId, CondvarLock)> {
584594
let this = self.eval_context_mut();
585595
let current_thread = this.get_active_thread();
586596
let condvar = &mut this.machine.threads.sync.condvars[id];
@@ -594,7 +604,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
594604
if let Some(data_race) = data_race {
595605
data_race.validate_lock_acquire(&condvar.data_race, waiter.thread);
596606
}
597-
(waiter.thread, waiter.lock, waiter.shared)
607+
(waiter.thread, waiter.lock)
598608
})
599609
}
600610

src/tools/miri/src/shims/unix/sync.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::time::SystemTime;
33
use rustc_hir::LangItem;
44
use rustc_middle::ty::{layout::TyAndLayout, query::TyCtxtAt, Ty};
55

6+
use crate::concurrency::sync::CondvarLock;
67
use crate::concurrency::thread::{MachineCallback, Time};
78
use crate::*;
89

@@ -696,9 +697,12 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
696697
fn pthread_cond_signal(&mut self, cond_op: &OpTy<'tcx, Provenance>) -> InterpResult<'tcx, i32> {
697698
let this = self.eval_context_mut();
698699
let id = this.condvar_get_or_create_id(cond_op, CONDVAR_ID_OFFSET)?;
699-
if let Some((thread, mutex, shared)) = this.condvar_signal(id) {
700-
assert!(!shared);
701-
post_cond_signal(this, thread, MutexId::from_u32(mutex))?;
700+
if let Some((thread, lock)) = this.condvar_signal(id) {
701+
if let CondvarLock::Mutex(mutex) = lock {
702+
post_cond_signal(this, thread, mutex)?;
703+
} else {
704+
panic!("condvar should not have an rwlock on unix");
705+
}
702706
}
703707

704708
Ok(0)
@@ -711,9 +715,12 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
711715
let this = self.eval_context_mut();
712716
let id = this.condvar_get_or_create_id(cond_op, CONDVAR_ID_OFFSET)?;
713717

714-
while let Some((thread, mutex, shared)) = this.condvar_signal(id) {
715-
assert!(!shared);
716-
post_cond_signal(this, thread, MutexId::from_u32(mutex))?;
718+
while let Some((thread, lock)) = this.condvar_signal(id) {
719+
if let CondvarLock::Mutex(mutex) = lock {
720+
post_cond_signal(this, thread, mutex)?;
721+
} else {
722+
panic!("condvar should not have an rwlock on unix");
723+
}
717724
}
718725

719726
Ok(0)
@@ -731,7 +738,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
731738
let active_thread = this.get_active_thread();
732739

733740
release_cond_mutex_and_block(this, active_thread, mutex_id)?;
734-
this.condvar_wait(id, active_thread, mutex_id.to_u32(), false);
741+
this.condvar_wait(id, active_thread, CondvarLock::Mutex(mutex_id));
735742

736743
Ok(0)
737744
}
@@ -770,7 +777,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
770777
};
771778

772779
release_cond_mutex_and_block(this, active_thread, mutex_id)?;
773-
this.condvar_wait(id, active_thread, mutex_id.to_u32(), false);
780+
this.condvar_wait(id, active_thread, CondvarLock::Mutex(mutex_id));
774781

775782
// We return success for now and override it in the timeout callback.
776783
this.write_scalar(Scalar::from_i32(0), dest)?;

src/tools/miri/src/shims/windows/sync.rs

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::time::Duration;
33
use rustc_target::abi::Size;
44

55
use crate::concurrency::init_once::InitOnceStatus;
6+
use crate::concurrency::sync::{CondvarLock, RwLockMode};
67
use crate::concurrency::thread::MachineCallback;
78
use crate::*;
89

@@ -18,23 +19,24 @@ pub trait EvalContextExtPriv<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tc
1819
&mut self,
1920
thread: ThreadId,
2021
lock: RwLockId,
21-
shared: bool,
22+
mode: RwLockMode,
2223
) -> InterpResult<'tcx> {
2324
let this = self.eval_context_mut();
2425
this.unblock_thread(thread);
2526

26-
if shared {
27-
if this.rwlock_is_locked(lock) {
28-
this.rwlock_enqueue_and_block_reader(lock, thread);
29-
} else {
30-
this.rwlock_reader_lock(lock, thread);
31-
}
32-
} else {
33-
if this.rwlock_is_write_locked(lock) {
34-
this.rwlock_enqueue_and_block_writer(lock, thread);
35-
} else {
36-
this.rwlock_writer_lock(lock, thread);
37-
}
27+
match mode {
28+
RwLockMode::Read =>
29+
if this.rwlock_is_locked(lock) {
30+
this.rwlock_enqueue_and_block_reader(lock, thread);
31+
} else {
32+
this.rwlock_reader_lock(lock, thread);
33+
},
34+
RwLockMode::Write =>
35+
if this.rwlock_is_write_locked(lock) {
36+
this.rwlock_enqueue_and_block_writer(lock, thread);
37+
} else {
38+
this.rwlock_writer_lock(lock, thread);
39+
},
3840
}
3941

4042
Ok(())
@@ -383,14 +385,19 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
383385
};
384386

385387
let shared_mode = 0x1; // CONDITION_VARIABLE_LOCKMODE_SHARED is not in std
386-
let shared = flags == shared_mode;
388+
let mode = if flags == 0 {
389+
RwLockMode::Write
390+
} else if flags == shared_mode {
391+
RwLockMode::Read
392+
} else {
393+
throw_unsup_format!("unsupported `Flags` {flags} in `SleepConditionVariableSRW`");
394+
};
387395

388396
let active_thread = this.get_active_thread();
389397

390-
let was_locked = if shared {
391-
this.rwlock_reader_unlock(lock_id, active_thread)
392-
} else {
393-
this.rwlock_writer_unlock(lock_id, active_thread)
398+
let was_locked = match mode {
399+
RwLockMode::Read => this.rwlock_reader_unlock(lock_id, active_thread),
400+
RwLockMode::Write => this.rwlock_writer_unlock(lock_id, active_thread),
394401
};
395402

396403
if !was_locked {
@@ -400,27 +407,27 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
400407
}
401408

402409
this.block_thread(active_thread);
403-
this.condvar_wait(condvar_id, active_thread, lock_id.to_u32(), shared);
410+
this.condvar_wait(condvar_id, active_thread, CondvarLock::RwLock { id: lock_id, mode });
404411

405412
if let Some(timeout_time) = timeout_time {
406413
struct Callback<'tcx> {
407414
thread: ThreadId,
408415
condvar_id: CondvarId,
409416
lock_id: RwLockId,
410-
shared: bool,
417+
mode: RwLockMode,
411418
dest: PlaceTy<'tcx, Provenance>,
412419
}
413420

414421
impl<'tcx> VisitTags for Callback<'tcx> {
415422
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
416-
let Callback { thread: _, condvar_id: _, lock_id: _, shared: _, dest } = self;
423+
let Callback { thread: _, condvar_id: _, lock_id: _, mode: _, dest } = self;
417424
dest.visit_tags(visit);
418425
}
419426
}
420427

421428
impl<'mir, 'tcx: 'mir> MachineCallback<'mir, 'tcx> for Callback<'tcx> {
422429
fn call(&self, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
423-
this.reacquire_cond_lock(self.thread, self.lock_id, self.shared)?;
430+
this.reacquire_cond_lock(self.thread, self.lock_id, self.mode)?;
424431

425432
this.condvar_remove_waiter(self.condvar_id, self.thread);
426433

@@ -438,7 +445,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
438445
thread: active_thread,
439446
condvar_id,
440447
lock_id,
441-
shared,
448+
mode,
442449
dest: dest.clone(),
443450
}),
444451
);
@@ -451,9 +458,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
451458
let this = self.eval_context_mut();
452459
let condvar_id = this.condvar_get_or_create_id(condvar_op, CONDVAR_ID_OFFSET)?;
453460

454-
if let Some((thread, lock, shared)) = this.condvar_signal(condvar_id) {
455-
this.reacquire_cond_lock(thread, RwLockId::from_u32(lock), shared)?;
456-
this.unregister_timeout_callback_if_exists(thread);
461+
if let Some((thread, lock)) = this.condvar_signal(condvar_id) {
462+
if let CondvarLock::RwLock { id, mode } = lock {
463+
this.reacquire_cond_lock(thread, id, mode)?;
464+
this.unregister_timeout_callback_if_exists(thread);
465+
} else {
466+
panic!("mutexes should not exist on windows");
467+
}
457468
}
458469

459470
Ok(())
@@ -466,9 +477,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
466477
let this = self.eval_context_mut();
467478
let condvar_id = this.condvar_get_or_create_id(condvar_op, CONDVAR_ID_OFFSET)?;
468479

469-
while let Some((thread, lock, shared)) = this.condvar_signal(condvar_id) {
470-
this.reacquire_cond_lock(thread, RwLockId::from_u32(lock), shared)?;
471-
this.unregister_timeout_callback_if_exists(thread);
480+
while let Some((thread, lock)) = this.condvar_signal(condvar_id) {
481+
if let CondvarLock::RwLock { id, mode } = lock {
482+
this.reacquire_cond_lock(thread, id, mode)?;
483+
this.unregister_timeout_callback_if_exists(thread);
484+
} else {
485+
panic!("mutexes should not exist on windows");
486+
}
472487
}
473488

474489
Ok(())

0 commit comments

Comments
 (0)