From b9128a465c62f8f048f4ee35893cdf2f5958eab4 Mon Sep 17 00:00:00 2001 From: Kodi Craft Date: Fri, 21 Jun 2024 11:51:33 +0200 Subject: [PATCH] Client can now perform requests asynchronously --- src/lib.rs | 73 ++++++++++++++++++++++++++++++++++++++++++---------- tests/mod.rs | 22 +++++++++++++++- 2 files changed, 81 insertions(+), 14 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 331df8d..da12f4e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let answer_enum_name = format_ident!("__{}Answer", name); let question_enum_name = format_ident!("__{}Question", name); let query_enum_name = format_ident!("__{}Query", name); + let queries_struct_name = format_ident!("__{}Queries", name); let server_trait_name = format_ident!("{}Server", name); let client_struct_name = format_ident!("{}Client", name); @@ -106,7 +107,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream }); // The function that the client uses to communicate client_impl.push(quote! { - pub async fn #var_name(&mut self, #question_args) -> Result<#answer_type, #error_enum_name> { + pub async fn #var_name(&self, #question_args) -> Result<#answer_type, #error_enum_name> { let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?; let answer = self.recv_until(nonce).await?; match answer { @@ -175,22 +176,72 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream #(#server_trait)* } }; + // Create a struct to hold queries behind an Arc> to enable async access + // TODO: It might be a good idea to just make this a generic struct and write it in actual code + // rather than in this macro + let queries_struct = quote! { + #[derive(Clone)] + struct #queries_struct_name { + queries: ::std::sync::Arc<::std::sync::Mutex<::std::collections::HashMap>>, + } + impl #queries_struct_name { + fn new() -> Self { + Self { + queries: ::std::sync::Arc::new(::std::sync::Mutex::new(::std::collections::HashMap::new())), + } + } + + pub fn insert(&self, nonce: u64, query: #query_enum_name) { + self.queries.lock().unwrap().insert(nonce, query); + } + + pub fn get(&self, nonce: &u64) -> Option<#query_enum_name> { + self.queries.lock().unwrap().get(nonce).cloned() + } + + pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) { + if let Some(query) = self.queries.lock().unwrap().get_mut(&nonce) { + query.set_answer(answer); + } + } + + pub fn len(&self) -> usize { + self.queries.lock().unwrap().len() + } + } + }; // Create a struct which the client will use to communicate + let client_recv_queue_wrapper = format_ident!("__{}RecvQueueWrapper", name); let client_struct = quote! { + #[derive(Clone)] + struct #client_recv_queue_wrapper { + recv_queue: ::std::sync::Arc<::tokio::sync::Mutex>>, + } + impl #client_recv_queue_wrapper { + fn new(recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { + Self { + recv_queue: ::std::sync::Arc::new(::tokio::sync::Mutex::new(recv_queue)), + } + } + async fn recv(&self) -> Option<(u64, #answer_enum_name)> { + self.recv_queue.lock().await.recv().await + } + } + #[derive(Clone)] #vis struct #client_struct_name { - queries: ::std::collections::HashMap, + queries: #queries_struct_name, send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, - recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>, + recv_queue: #client_recv_queue_wrapper, } // TODO: This struct will have some fields to handle the actual connection impl #client_struct_name { pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { Self { - queries: ::std::collections::HashMap::new(), + queries: #queries_struct_name::new(), + recv_queue: #client_recv_queue_wrapper::new(recv_queue), send_queue, - recv_queue, } } - async fn send(&mut self, query: #question_enum_name) -> Result { + async fn send(&self, query: #question_enum_name) -> Result { let nonce = self.queries.len() as u64; let res = self.send_queue.send((nonce, query.clone())).await; match res { @@ -201,7 +252,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream Err(e) => Err(#error_enum_name::SendError(e)), } } - async fn recv_until(&mut self, id: u64) -> Result<#answer_enum_name, #error_enum_name> { + async fn recv_until(&self, id: u64) -> Result<#answer_enum_name, #error_enum_name> { loop { // Check if we've received the answer for the query we're looking for if let Some(query) = self.queries.get(&id) { @@ -211,12 +262,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } match self.recv_queue.recv().await { Some((nonce, answer)) => { - // Replace the Option<> in the query with the answer - if let Some(query) = self.queries.get_mut(&nonce) { - query.set_answer(answer); - } else { - panic!("Received an answer for a query we did not send"); - } + self.queries.set_answer(nonce, answer.clone()); } None => return Err(#error_enum_name::Closed), }; @@ -231,6 +277,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream #answer_enum #question_enum #query_enum + #queries_struct #server_trait #client_struct }; diff --git a/tests/mod.rs b/tests/mod.rs index 9b980fa..d99e968 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -30,7 +30,7 @@ enum TestProtocol { async fn main() { let (qtx, qrx) = mpsc::channel(16); let (atx, arx) = mpsc::channel(16); - let mut client = TestProtocolClient::new(qtx, arx); + let client = TestProtocolClient::new(qtx, arx); let server = tokio::spawn(server_loop(qrx, atx)); let result = client.addition(2, 5).await.unwrap(); assert_eq!(result, 7); @@ -86,3 +86,23 @@ async fn server_loop( } } } + +#[tokio::test] +async fn heavy_async() { + let (qtx, qrx) = mpsc::channel(16); + let (atx, arx) = mpsc::channel(16); + let client = TestProtocolClient::new(qtx, arx); + let server = tokio::spawn(server_loop(qrx, atx)); + let mut tasks = Vec::new(); + for i in 0..100 { + let client = client.clone(); + tasks.push(tokio::spawn(async move { + let result = client.addition(i, i).await.unwrap(); + assert_eq!(result, i + i); + })); + } + for task in tasks { + task.await.unwrap(); + } + server.abort(); +}