From c892053cbd83bd07243aa815f175a76ed623be0f Mon Sep 17 00:00:00 2001 From: Kodi Craft Date: Sun, 23 Jun 2024 01:46:42 +0200 Subject: [PATCH] Implement additional codegen for calling server handler methods --- src/lib.rs | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 2473b58..e8655fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -83,6 +83,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let mut variant_fields = variant.fields.iter(); let question_field = variant_fields.next().unwrap(); let question_args = field_to_args(question_field); + let question_handler_args = field_to_handler_args(question_field); let question_tuple_args = field_to_tuple_args(question_field); let answer_type = variant_fields.next().unwrap().ty.clone(); @@ -114,14 +115,14 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream }); // There is a function that the server uses to call the appropriate function when receiving a query server_handler.push(quote! { - #question_enum_name::#var_name(#question_args) => { - let answer = self.handler.#var_name(#question_tuple_args); + #question_enum_name::#var_name(#question_tuple_args) => { + let answer = self.handler.lock().await.#var_name(#question_handler_args).await; return #answer_enum_name::#var_name(answer); - } + }, }); // The function that the server needs to implement server_trait.push(quote! { - fn #var_name(&mut self, #question_args) -> #answer_type; + async fn #var_name(&mut self, #question_args) -> #answer_type; }); // The function that the client uses to communicate client_impl.push(quote! { @@ -212,6 +213,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream to_send: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, received: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, } + impl #server_connection_struct_name { + async fn handle(&self, question: #question_enum_name) -> #answer_enum_name { + match question { + #(#server_handler)* + } + } + } }; // Create a struct to hold queries behind an Arc> to enable async access @@ -437,3 +445,17 @@ fn field_to_tuple_args(field: &Field) -> proc_macro2::TokenStream { quote! { (arg) } } } + +fn field_to_handler_args(field: &Field) -> proc_macro2::TokenStream { + let type_ = &field.ty; + if let syn::Type::Tuple(tuple) = type_ { + let mut args = Vec::new(); + for (i, elem) in tuple.elems.iter().enumerate() { + let arg = Ident::new(&format!("arg{}", i), elem.span()); + args.push(quote! { #arg }); + } + quote! { #( #args ), * } + } else { + quote! { arg } + } +}