diff --git a/src/lib.rs b/src/lib.rs index bacbd67..0d08eda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,10 +33,10 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { }; let name = &input.ident; - let error_enum_name = format_ident!("{}Error", name); - let answer_enum_name = format_ident!("{}Answer", name); - let question_enum_name = format_ident!("{}Question", name); - let query_enum_name = format_ident!("{}Query", name); + let error_enum_name = format_ident!("__{}Error", name); + let answer_enum_name = format_ident!("__{}Answer", name); + let question_enum_name = format_ident!("__{}Question", name); + let query_enum_name = format_ident!("__{}Query", name); let server_trait_name = format_ident!("{}Server", name); let client_struct_name = format_ident!("{}Client", name); @@ -49,6 +49,8 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { let mut query_enum = Vec::new(); let mut query_from_question_enum = Vec::new(); + let mut query_set_answer = Vec::new(); + let mut query_get_answer = Vec::new(); for variant in &enum_.variants { // Every variant must have 2 fields @@ -81,6 +83,20 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { query_from_question_enum.push(quote! { #question_enum_name::#var_name(question) => #query_enum_name::#var_name(question, None), }); + // There is a function that must be implemented to set the answer in the query enum + query_set_answer.push(quote! { + #query_enum_name::#var_name(question, answer_opt) => match answer { + #answer_enum_name::#var_name(answer) => {answer_opt.replace(answer);}, + _ => panic!("The answer for this query is not the correct type."), + }, + }); + // There is a function that must be implemented to get the answer from the query enum + query_get_answer.push(quote! { + #query_enum_name::#var_name(_, answer) => match answer { + Some(answer) => Some(#answer_enum_name::#var_name(answer.clone())), + None => None + }, + }); // The function that the server needs to implement server_trait.push(quote! { fn #var_name(&mut self, #question_args) -> #answer_type; @@ -89,7 +105,11 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { client_impl.push(quote! { pub async fn #var_name(&mut self, #question_args) -> Result<#answer_type, #error_enum_name> { let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?; - todo!("Wait for the answer") + let answer = self.recv_until(nonce).await?; + match answer { + #answer_enum_name::#var_name(answer) => Ok(answer), + _ => panic!("The answer for this query is not the correct type."), + } } }); // The query enum is the same as the source enum, but the second field is always wrapped in a Option<> @@ -101,7 +121,8 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { // Create an error and result type for sending messages let error_enum = quote! { #vis enum #error_enum_name { - SendError(tokio::sync::mpsc::error::SendError<#question_enum_name>), + SendError(tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>), + Closed, } }; // Create enums for the types of messages the server and client will use @@ -124,6 +145,18 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { #vis enum #query_enum_name { #(#query_enum), * } + impl #query_enum_name { + pub fn set_answer(&mut self, answer: #answer_enum_name) { + match self { + #(#query_set_answer)* + }; + } + pub fn get_answer(&self) -> Option<#answer_enum_name> { + match self { + #(#query_get_answer)* + } + } + } impl From<#question_enum_name> for #query_enum_name { fn from(query: #question_enum_name) -> Self { match query { @@ -142,11 +175,11 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { let client_struct = quote! { #vis struct #client_struct_name { queries: ::std::collections::HashMap, - send_queue: tokio::sync::mpsc::Sender<#question_enum_name>, - recv_queue: tokio::sync::mpsc::Receiver<#answer_enum_name>, + send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, + recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>, } // TODO: This struct will have some fields to handle the actual connection impl #client_struct_name { - pub fn new(send_queue: tokio::sync::mpsc::Sender<#question_enum_name>, recv_queue: tokio::sync::mpsc::Receiver<#answer_enum_name>) -> Self { + pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { Self { queries: ::std::collections::HashMap::new(), send_queue, @@ -154,16 +187,37 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { } } async fn send(&mut self, query: #question_enum_name) -> Result { - let res = self.send_queue.send(query.clone()).await; + let nonce = self.queries.len() as u64; + let res = self.send_queue.send((nonce, query.clone())).await; match res { Ok(_) => { - let id = self.queries.len() as u64; - self.queries.insert(id, query.into()); - Ok(id) + self.queries.insert(nonce, query.into()); + Ok(nonce) } Err(e) => Err(#error_enum_name::SendError(e)), } } + async fn recv_until(&mut self, id: u64) -> Result<#answer_enum_name, #error_enum_name> { + loop { + // Check if we've received the answer for the query we're looking for + if let Some(query) = self.queries.get(&id) { + if let Some(answer) = query.get_answer() { + return Ok(answer); + } + } + match self.recv_queue.recv().await { + Some((nonce, answer)) => { + // Replace the Option<> in the query with the answer + if let Some(query) = self.queries.get_mut(&nonce) { + query.set_answer(answer); + } else { + panic!("Received an answer for a query we did not send"); + } + } + None => return Err(#error_enum_name::Closed), + }; + } + } #(#client_impl)* } }; diff --git a/tests/mod.rs b/tests/mod.rs index 617847e..d9d1d7a 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -21,6 +21,8 @@ use eagle::Protocol; enum TestProtocol { Addition((i32, i32), i32), SomeKindOfQuestion(String, i32), + ThisRespondsWithAString(i32, String), + Void((), ()), } struct DummyServer; @@ -32,6 +34,14 @@ impl TestProtocolServer for DummyServer { fn addition(&mut self, a: i32, b: i32) -> i32 { a + b } + + fn this_responds_with_a_string(&mut self, arg: i32) -> String { + format!("The number is {}", arg) + } + + fn void(&mut self) { + println!("Void function called!") + } } fn main() {}