@@ -11,8 +11,7 @@ use candle_transformers::generation::LogitsProcessor;
11
11
12
12
use candle_transformers:: models:: quantized_llama as model;
13
13
use model:: ModelWeights ;
14
- use std:: sync:: Arc ;
15
- use tokio:: sync:: Mutex ;
14
+ use std:: sync:: { Arc , Mutex , TryLockError } ;
16
15
pub struct ModelBuilder {
17
16
sample_len : usize ,
18
17
temperature : f64 ,
@@ -136,7 +135,21 @@ impl Model {
136
135
prompt_str : String ,
137
136
pre_prompt_tokens : & Vec < u32 > ,
138
137
) -> Result < ( String , Vec < u32 > ) , Box < dyn std:: error:: Error > > {
139
- let mut model_weights = self . model_weights . lock ( ) . await ;
138
+ let mut model_weights = loop {
139
+ match self . model_weights . try_lock ( ) {
140
+ Ok ( model_weights) => break model_weights,
141
+ Err ( TryLockError :: Poisoned ( e) ) => {
142
+ let guard = e. into_inner ( ) ;
143
+ // waiting for https://github.com/rust-lang/rust/issues/96469
144
+ // *guard = build_model_weights()?;
145
+ // self.model_weights.clear_poison();
146
+ // println!("Note: model_weights mutex was poisoned, will try to rebuild");
147
+ break guard;
148
+ }
149
+ Err ( TryLockError :: WouldBlock ) => { }
150
+ }
151
+ tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 100 ) ) . await ;
152
+ } ;
140
153
141
154
tokio:: task:: block_in_place ( move || {
142
155
let prompt_str = format ! ( "[INST] {prompt_str} [/INST]" ) ;
@@ -272,4 +285,27 @@ mod tests {
272
285
let ( output, _) = model. interact ( prompt, & pre_prompt_tokens) . await . unwrap ( ) ;
273
286
println ! ( "{output}" ) ;
274
287
}
288
+
289
+ // waiting for https://github.com/rust-lang/rust/issues/96469
290
+ // #[tokio::test]
291
+ // async fn poisoning_rebuild() {
292
+ // let model = ModelBuilder::default().build().unwrap();
293
+ // let c_model = model.clone();
294
+
295
+ // #[allow(unused_variables, unreachable_code)]
296
+ // std::thread::spawn(move || {
297
+ // let lock = c_model.model_weights.lock().unwrap();
298
+ // panic!();
299
+ // drop(lock);
300
+ // })
301
+ // .join()
302
+ // .unwrap_or_default();
303
+
304
+ // assert!(model.model_weights.is_poisoned());
305
+
306
+ // let prompt = "Create a basic Rust program".to_string();
307
+ // let pre_prompt_tokens = vec![];
308
+ // let (output, _) = model.interact(prompt, &pre_prompt_tokens).await.unwrap();
309
+ // println!("{output}");
310
+ // }
275
311
}
0 commit comments