diff --git a/src/lib.rs b/src/lib.rs index 0a43fc6..3061801 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let answer_enum_name = format_ident!("__{}Answer", name); let question_enum_name = format_ident!("__{}Question", name); let query_enum_name = format_ident!("__{}Query", name); + let queries_struct_name = format_ident!("__{}Queries", name); let client_connection_struct_name = format_ident!("__{}Connection", name); let server_trait_name = format_ident!("{}ServerTrait", name); let client_struct_name = format_ident!("{}Client", name); @@ -176,6 +177,40 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream #(#server_trait)* } }; + // Create a struct to hold queries behind an Arc> to enable async access + // TODO: It might be a good idea to just make this a generic struct and write it in actual code + // rather than in this macro + let queries_struct = quote! { + #[derive(Clone)] + struct #queries_struct_name { + queries: ::std::sync::Arc<::std::sync::Mutex<::std::collections::HashMap>>, + } + impl #queries_struct_name { + fn new() -> Self { + Self { + queries: ::std::sync::Arc::new(::std::sync::Mutex::new(::std::collections::HashMap::new())), + } + } + + pub fn insert(&self, nonce: u64, query: #query_enum_name) { + self.queries.lock().unwrap().insert(nonce, query); + } + + pub fn get(&self, nonce: &u64) -> Option<#query_enum_name> { + self.queries.lock().unwrap().get(nonce).cloned() + } + + pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) { + if let Some(query) = self.queries.lock().unwrap().get_mut(&nonce) { + query.set_answer(answer); + } + } + + pub fn len(&self) -> usize { + self.queries.lock().unwrap().len() + } + } + }; // Create a struct to handle the connection from the client to the server let stream_type = quote! { tokio::net::TcpStream }; // TODO: In the future we could support other stream types let cc_struct = quote! { @@ -248,24 +283,24 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } #[derive(Clone)] #vis struct #client_struct_name { - queries: ::std::sync::Arc<::tokio::sync::Mutex<::std::collections::HashMap>>, + queries: #queries_struct_name, send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: #client_recv_queue_wrapper, } impl #client_struct_name { pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { Self { - queries: ::std::sync::Arc::new(::tokio::sync::Mutex::new(::std::collections::HashMap::new())), + queries: #queries_struct_name::new(), recv_queue: #client_recv_queue_wrapper::new(recv_queue), send_queue, } } async fn send(&self, query: #question_enum_name) -> Result { - let nonce = self.queries.lock().await.len() as u64; + let nonce = self.queries.len() as u64; let res = self.send_queue.send((nonce, query.clone())).await; match res { Ok(_) => { - self.queries.lock().await.insert(nonce, query.into()); + self.queries.insert(nonce, query.into()); Ok(nonce) } Err(e) => Err(#error_enum_name::SendError(e)), @@ -273,18 +308,15 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } async fn recv_until(&self, id: u64) -> Result<#answer_enum_name, #error_enum_name> { loop { - let mut queries = self.queries.lock().await; // Check if we've received the answer for the query we're looking for - if let Some(query) = queries.get(&id) { + if let Some(query) = self.queries.get(&id) { if let Some(answer) = query.get_answer() { return Ok(answer); } } match self.recv_queue.recv().await { Some((nonce, answer)) => { - if let Some(query) = queries.get_mut(&nonce) { - query.set_answer(answer.clone()); - } + self.queries.set_answer(nonce, answer.clone()); } None => return Err(#error_enum_name::Closed), }; @@ -299,6 +331,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream #answer_enum #question_enum #query_enum + #queries_struct #server_trait #cc_struct #client_struct