Skip to content

Commit 12bc9a8

Browse files
using std::sync::Mutex, waiting for rust-lang/rust#96469 to use clear_poison()
1 parent 89b20be commit 12bc9a8

File tree

1 file changed

+39
-3
lines changed

1 file changed

+39
-3
lines changed

src/routes/llm.rs

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ use candle_transformers::generation::LogitsProcessor;
1111

1212
use candle_transformers::models::quantized_llama as model;
1313
use model::ModelWeights;
14-
use std::sync::Arc;
15-
use tokio::sync::Mutex;
14+
use std::sync::{Arc, Mutex, TryLockError};
1615
pub struct ModelBuilder {
1716
sample_len: usize,
1817
temperature: f64,
@@ -136,7 +135,21 @@ impl Model {
136135
prompt_str: String,
137136
pre_prompt_tokens: &Vec<u32>,
138137
) -> Result<(String, Vec<u32>), Box<dyn std::error::Error>> {
139-
let mut model_weights = self.model_weights.lock().await;
138+
let mut model_weights = loop {
139+
match self.model_weights.try_lock() {
140+
Ok(model_weights) => break model_weights,
141+
Err(TryLockError::Poisoned(e)) => {
142+
let guard = e.into_inner();
143+
// waiting for https://github.com/rust-lang/rust/issues/96469
144+
// *guard = build_model_weights()?;
145+
// self.model_weights.clear_poison();
146+
// println!("Note: model_weights mutex was poisoned, will try to rebuild");
147+
break guard;
148+
}
149+
Err(TryLockError::WouldBlock) => {}
150+
}
151+
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
152+
};
140153

141154
tokio::task::block_in_place(move || {
142155
let prompt_str = format!("[INST] {prompt_str} [/INST]");
@@ -272,4 +285,27 @@ mod tests {
272285
let (output, _) = model.interact(prompt, &pre_prompt_tokens).await.unwrap();
273286
println!("{output}");
274287
}
288+
289+
// waiting for https://github.com/rust-lang/rust/issues/96469
290+
// #[tokio::test]
291+
// async fn poisoning_rebuild() {
292+
// let model = ModelBuilder::default().build().unwrap();
293+
// let c_model = model.clone();
294+
295+
// #[allow(unused_variables, unreachable_code)]
296+
// std::thread::spawn(move || {
297+
// let lock = c_model.model_weights.lock().unwrap();
298+
// panic!();
299+
// drop(lock);
300+
// })
301+
// .join()
302+
// .unwrap_or_default();
303+
304+
// assert!(model.model_weights.is_poisoned());
305+
306+
// let prompt = "Create a basic Rust program".to_string();
307+
// let pre_prompt_tokens = vec![];
308+
// let (output, _) = model.interact(prompt, &pre_prompt_tokens).await.unwrap();
309+
// println!("{output}");
310+
// }
275311
}

0 commit comments

Comments
 (0)