Tweak usage of mutexes to reduce chance of deadlocks
Some checks failed
Build library & run tests / build (push) Failing after 27s

This commit is contained in:
Kodi Craft 2024-06-21 16:25:58 +02:00
parent 06aa6a1f71
commit f35166ce78
Signed by: kodi
GPG Key ID: 69D9EED60B242822

View File

@ -41,7 +41,6 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
let answer_enum_name = format_ident!("__{}Answer", name); let answer_enum_name = format_ident!("__{}Answer", name);
let question_enum_name = format_ident!("__{}Question", name); let question_enum_name = format_ident!("__{}Question", name);
let query_enum_name = format_ident!("__{}Query", name); let query_enum_name = format_ident!("__{}Query", name);
let queries_struct_name = format_ident!("__{}Queries", name);
let client_connection_struct_name = format_ident!("__{}Connection", name); let client_connection_struct_name = format_ident!("__{}Connection", name);
let server_trait_name = format_ident!("{}ServerTrait", name); let server_trait_name = format_ident!("{}ServerTrait", name);
let client_struct_name = format_ident!("{}Client", name); let client_struct_name = format_ident!("{}Client", name);
@ -177,40 +176,6 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
#(#server_trait)* #(#server_trait)*
} }
}; };
// Create a struct to hold queries behind an Arc<Mutex<>> to enable async access
// TODO: It might be a good idea to just make this a generic struct and write it in actual code
// rather than in this macro
let queries_struct = quote! {
#[derive(Clone)]
struct #queries_struct_name {
queries: ::std::sync::Arc<::std::sync::Mutex<::std::collections::HashMap<u64, #query_enum_name>>>,
}
impl #queries_struct_name {
fn new() -> Self {
Self {
queries: ::std::sync::Arc::new(::std::sync::Mutex::new(::std::collections::HashMap::new())),
}
}
pub fn insert(&self, nonce: u64, query: #query_enum_name) {
self.queries.lock().unwrap().insert(nonce, query);
}
pub fn get(&self, nonce: &u64) -> Option<#query_enum_name> {
self.queries.lock().unwrap().get(nonce).cloned()
}
pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) {
if let Some(query) = self.queries.lock().unwrap().get_mut(&nonce) {
query.set_answer(answer);
}
}
pub fn len(&self) -> usize {
self.queries.lock().unwrap().len()
}
}
};
// Create a struct to handle the connection from the client to the server // Create a struct to handle the connection from the client to the server
let stream_type = quote! { tokio::net::TcpStream }; // TODO: In the future we could support other stream types let stream_type = quote! { tokio::net::TcpStream }; // TODO: In the future we could support other stream types
let cc_struct = quote! { let cc_struct = quote! {
@ -283,24 +248,24 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
} }
#[derive(Clone)] #[derive(Clone)]
#vis struct #client_struct_name { #vis struct #client_struct_name {
queries: #queries_struct_name, queries: ::std::sync::Arc<::tokio::sync::Mutex<::std::collections::HashMap<u64, #query_enum_name>>>,
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
recv_queue: #client_recv_queue_wrapper, recv_queue: #client_recv_queue_wrapper,
} }
impl #client_struct_name { impl #client_struct_name {
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self {
Self { Self {
queries: #queries_struct_name::new(), queries: ::std::sync::Arc::new(::tokio::sync::Mutex::new(::std::collections::HashMap::new())),
recv_queue: #client_recv_queue_wrapper::new(recv_queue), recv_queue: #client_recv_queue_wrapper::new(recv_queue),
send_queue, send_queue,
} }
} }
async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> { async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
let nonce = self.queries.len() as u64; let nonce = self.queries.lock().await.len() as u64;
let res = self.send_queue.send((nonce, query.clone())).await; let res = self.send_queue.send((nonce, query.clone())).await;
match res { match res {
Ok(_) => { Ok(_) => {
self.queries.insert(nonce, query.into()); self.queries.lock().await.insert(nonce, query.into());
Ok(nonce) Ok(nonce)
} }
Err(e) => Err(#error_enum_name::SendError(e)), Err(e) => Err(#error_enum_name::SendError(e)),
@ -308,15 +273,18 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
} }
async fn recv_until(&self, id: u64) -> Result<#answer_enum_name, #error_enum_name> { async fn recv_until(&self, id: u64) -> Result<#answer_enum_name, #error_enum_name> {
loop { loop {
let mut queries = self.queries.lock().await;
// Check if we've received the answer for the query we're looking for // Check if we've received the answer for the query we're looking for
if let Some(query) = self.queries.get(&id) { if let Some(query) = queries.get(&id) {
if let Some(answer) = query.get_answer() { if let Some(answer) = query.get_answer() {
return Ok(answer); return Ok(answer);
} }
} }
match self.recv_queue.recv().await { match self.recv_queue.recv().await {
Some((nonce, answer)) => { Some((nonce, answer)) => {
self.queries.set_answer(nonce, answer.clone()); if let Some(query) = queries.get_mut(&nonce) {
query.set_answer(answer.clone());
}
} }
None => return Err(#error_enum_name::Closed), None => return Err(#error_enum_name::Closed),
}; };
@ -331,7 +299,6 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
#answer_enum #answer_enum
#question_enum #question_enum
#query_enum #query_enum
#queries_struct
#server_trait #server_trait
#cc_struct #cc_struct
#client_struct #client_struct