From aa7d51f088d268392c7cd0b27f67ceeb44dce8ff Mon Sep 17 00:00:00 2001 From: Kodi Craft Date: Thu, 20 Jun 2024 12:52:42 +0200 Subject: [PATCH] Generate code to "send" a question --- src/lib.rs | 87 +++++++++++++++++++++++++++++++++++++++++----------- tests/mod.rs | 4 +-- 2 files changed, 71 insertions(+), 20 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 662ec26..bacbd67 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,14 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { } }; let name = &input.ident; + + let error_enum_name = format_ident!("{}Error", name); + let answer_enum_name = format_ident!("{}Answer", name); + let question_enum_name = format_ident!("{}Question", name); + let query_enum_name = format_ident!("{}Query", name); + let server_trait_name = format_ident!("{}Server", name); + let client_struct_name = format_ident!("{}Client", name); + let vis = &input.vis; let mut server_trait = Vec::new(); @@ -40,6 +48,7 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { let mut client_enum = Vec::new(); let mut query_enum = Vec::new(); + let mut query_from_question_enum = Vec::new(); for variant in &enum_.variants { // Every variant must have 2 fields @@ -57,6 +66,7 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { let mut variant_fields = variant.fields.iter(); let question_field = variant_fields.next().unwrap(); let question_args = field_to_args(question_field); + let question_tuple_args = field_to_tuple_args(question_field); let answer_type = variant_fields.next().unwrap().ty.clone(); // The variants that either the server or the client will use @@ -67,14 +77,19 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { client_enum.push(quote! { #var_name(#question_field) }); + // There is a From implementation for the client enum to the query enum + query_from_question_enum.push(quote! { + #question_enum_name::#var_name(question) => #query_enum_name::#var_name(question, None), + }); // The function that the server needs to implement server_trait.push(quote! { fn #var_name(&mut self, #question_args) -> #answer_type; }); // The function that the client uses to communicate client_impl.push(quote! { - pub fn #var_name(&mut self, #question_args) -> #answer_type { - ::std::unimplemented!() + pub async fn #var_name(&mut self, #question_args) -> Result<#answer_type, #error_enum_name> { + let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?; + todo!("Wait for the answer") } }); // The query enum is the same as the source enum, but the second field is always wrapped in a Option<> @@ -83,58 +98,80 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream { }); } + // Create an error and result type for sending messages + let error_enum = quote! { + #vis enum #error_enum_name { + SendError(tokio::sync::mpsc::error::SendError<#question_enum_name>), + } + }; // Create enums for the types of messages the server and client will use - let server_enum_name = format_ident!("{}Answer", name); - let server_enum = quote! { - #vis enum #server_enum_name { + + let answer_enum = quote! { + #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] + #vis enum #answer_enum_name { #(#server_enum), * } }; - let client_enum_name = format_ident!("{}Question", name); - let client_enum = quote! { - #[derive(serde::Serialize, serde::Deserialize)] - #vis enum #client_enum_name { + let question_enum = quote! { + #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] + #vis enum #question_enum_name { #(#client_enum), * } }; // Create an enum to represent the queries the client has sent - let query_enum_name = format_ident!("{}Query", name); let query_enum = quote! { - #[derive(serde::Serialize, serde::Deserialize)] + #[derive(Clone, Debug)] #vis enum #query_enum_name { #(#query_enum), * } + impl From<#question_enum_name> for #query_enum_name { + fn from(query: #question_enum_name) -> Self { + match query { + #(#query_from_question_enum)* + } + } + } }; // Create a trait which the server will have to implement - let server_trait_name = format_ident!("{}Server", name); let server_trait = quote! { #vis trait #server_trait_name { #(#server_trait)* } }; // Create a struct which the client will use to communicate - let client_struct_name = format_ident!("{}Client", name); let client_struct = quote! { #vis struct #client_struct_name { queries: ::std::collections::HashMap, - send_queue: tokio::sync::mpsc::Sender<#client_enum_name>, - recv_queue: tokio::sync::mpsc::Receiver<#server_enum_name>, + send_queue: tokio::sync::mpsc::Sender<#question_enum_name>, + recv_queue: tokio::sync::mpsc::Receiver<#answer_enum_name>, } // TODO: This struct will have some fields to handle the actual connection impl #client_struct_name { - pub fn new(send_queue: tokio::sync::mpsc::Sender<#client_enum_name>, recv_queue: tokio::sync::mpsc::Receiver<#server_enum_name>) -> Self { + pub fn new(send_queue: tokio::sync::mpsc::Sender<#question_enum_name>, recv_queue: tokio::sync::mpsc::Receiver<#answer_enum_name>) -> Self { Self { queries: ::std::collections::HashMap::new(), send_queue, recv_queue, } } + async fn send(&mut self, query: #question_enum_name) -> Result { + let res = self.send_queue.send(query.clone()).await; + match res { + Ok(_) => { + let id = self.queries.len() as u64; + self.queries.insert(id, query.into()); + Ok(id) + } + Err(e) => Err(#error_enum_name::SendError(e)), + } + } #(#client_impl)* } }; let expanded = quote! { - #server_enum - #client_enum + #error_enum + #answer_enum + #question_enum #query_enum #server_trait #client_struct @@ -171,3 +208,17 @@ fn field_to_args(field: &Field) -> proc_macro2::TokenStream { quote! { arg: #type_ } } } + +fn field_to_tuple_args(field: &Field) -> proc_macro2::TokenStream { + let type_ = &field.ty; + if let syn::Type::Tuple(tuple) = type_ { + let mut args = Vec::new(); + for (i, elem) in tuple.elems.iter().enumerate() { + let arg = Ident::new(&format!("arg{}", i), elem.span()); + args.push(quote! { #arg }); + } + quote! { ( #( #args ), * ) } + } else { + quote! { (arg) } + } +} diff --git a/tests/mod.rs b/tests/mod.rs index e508ae9..617847e 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -29,8 +29,8 @@ impl TestProtocolServer for DummyServer { question.len() as i32 } - fn addition(&mut self, arg0: i32, arg1: i32) -> i32 { - todo!() + fn addition(&mut self, a: i32, b: i32) -> i32 { + a + b } }