/* Eagle - A library for easy communication in full-stack Rust applications Copyright (c) 2024 KodiCraft This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see . */ use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::{parse2, spanned::Spanned, DeriveInput, Field, Ident}; #[proc_macro_derive(Protocol)] pub fn derive_protocol_derive(input: TokenStream) -> TokenStream { let expanded = derive_protocol(input.into()); TokenStream::from(expanded) } fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream { let input = parse2::(input).unwrap(); // Must be on an enum let enum_ = match &input.data { syn::Data::Enum(e) => e, _ => { return syn::Error::new(input.span(), "Protocol can only be derived on enums") .to_compile_error() } }; let name = &input.ident; let error_enum_name = format_ident!("__{}Error", name); let answer_enum_name = format_ident!("__{}Answer", name); let question_enum_name = format_ident!("__{}Question", name); let query_enum_name = format_ident!("__{}Query", name); let queries_struct_name = format_ident!("__{}Queries", name); let server_trait_name = format_ident!("{}Server", name); let client_struct_name = format_ident!("{}Client", name); let vis = &input.vis; let mut server_trait = Vec::new(); let mut server_enum = Vec::new(); let mut client_impl = Vec::new(); let mut client_enum = Vec::new(); let mut query_enum = Vec::new(); let mut query_from_question_enum = Vec::new(); let mut query_set_answer = Vec::new(); let mut query_get_answer = Vec::new(); for variant in &enum_.variants { // Every variant must have 2 fields // The first field is the question (serverbound), the second field is the answer (clientbound) if variant.fields.len() != 2 { return syn::Error::new( variant.span(), "Every variant on a protocol must have exactly 2 fields", ) .to_compile_error(); } let var_name = ident_to_snake_case(&variant.ident); let mut variant_fields = variant.fields.iter(); let question_field = variant_fields.next().unwrap(); let question_args = field_to_args(question_field); let question_tuple_args = field_to_tuple_args(question_field); let answer_type = variant_fields.next().unwrap().ty.clone(); // The variants that either the server or the client will use // The "server" enum contains messages the server can send, the "client" enum contains messages the client can send server_enum.push(quote! { #var_name(#answer_type) }); client_enum.push(quote! { #var_name(#question_field) }); // There is a From implementation for the client enum to the query enum query_from_question_enum.push(quote! { #question_enum_name::#var_name(question) => #query_enum_name::#var_name(question, None), }); // There is a function that must be implemented to set the answer in the query enum query_set_answer.push(quote! { #query_enum_name::#var_name(question, answer_opt) => match answer { #answer_enum_name::#var_name(answer) => {answer_opt.replace(answer);}, _ => panic!("The answer for this query is not the correct type."), }, }); // There is a function that must be implemented to get the answer from the query enum query_get_answer.push(quote! { #query_enum_name::#var_name(_, answer) => match answer { Some(answer) => Some(#answer_enum_name::#var_name(answer.clone())), None => None }, }); // The function that the server needs to implement server_trait.push(quote! { fn #var_name(&mut self, #question_args) -> #answer_type; }); // The function that the client uses to communicate client_impl.push(quote! { pub async fn #var_name(&self, #question_args) -> Result<#answer_type, #error_enum_name> { let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?; let answer = self.recv_until(nonce).await?; match answer { #answer_enum_name::#var_name(answer) => Ok(answer), _ => panic!("The answer for this query is not the correct type."), } } }); // The query enum is the same as the source enum, but the second field is always wrapped in a Option<> query_enum.push(quote! { #var_name(#question_field, Option<#answer_type>) }); } // Create an error and result type for sending messages let error_enum = quote! { #[derive(Debug)] #vis enum #error_enum_name { SendError(tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>), Closed, } }; // Create enums for the types of messages the server and client will use let answer_enum = quote! { #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] #vis enum #answer_enum_name { #(#server_enum), * } }; let question_enum = quote! { #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] #vis enum #question_enum_name { #(#client_enum), * } }; // Create an enum to represent the queries the client has sent let query_enum = quote! { #[derive(Clone, Debug)] #vis enum #query_enum_name { #(#query_enum), * } impl #query_enum_name { pub fn set_answer(&mut self, answer: #answer_enum_name) { match self { #(#query_set_answer)* }; } pub fn get_answer(&self) -> Option<#answer_enum_name> { match self { #(#query_get_answer)* } } } impl From<#question_enum_name> for #query_enum_name { fn from(query: #question_enum_name) -> Self { match query { #(#query_from_question_enum)* } } } }; // Create a trait which the server will have to implement let server_trait = quote! { #vis trait #server_trait_name { #(#server_trait)* } }; // Create a struct to hold queries behind an Arc> to enable async access // TODO: It might be a good idea to just make this a generic struct and write it in actual code // rather than in this macro let queries_struct = quote! { #[derive(Clone)] struct #queries_struct_name { queries: ::std::sync::Arc<::std::sync::Mutex<::std::collections::HashMap>>, } impl #queries_struct_name { fn new() -> Self { Self { queries: ::std::sync::Arc::new(::std::sync::Mutex::new(::std::collections::HashMap::new())), } } pub fn insert(&self, nonce: u64, query: #query_enum_name) { self.queries.lock().unwrap().insert(nonce, query); } pub fn get(&self, nonce: &u64) -> Option<#query_enum_name> { self.queries.lock().unwrap().get(nonce).cloned() } pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) { if let Some(query) = self.queries.lock().unwrap().get_mut(&nonce) { query.set_answer(answer); } } pub fn len(&self) -> usize { self.queries.lock().unwrap().len() } } }; // Create a struct which the client will use to communicate let client_recv_queue_wrapper = format_ident!("__{}RecvQueueWrapper", name); let client_struct = quote! { #[derive(Clone)] struct #client_recv_queue_wrapper { recv_queue: ::std::sync::Arc<::tokio::sync::Mutex>>, } impl #client_recv_queue_wrapper { fn new(recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { Self { recv_queue: ::std::sync::Arc::new(::tokio::sync::Mutex::new(recv_queue)), } } async fn recv(&self) -> Option<(u64, #answer_enum_name)> { self.recv_queue.lock().await.recv().await } } #[derive(Clone)] #vis struct #client_struct_name { queries: #queries_struct_name, send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: #client_recv_queue_wrapper, } // TODO: This struct will have some fields to handle the actual connection impl #client_struct_name { pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { Self { queries: #queries_struct_name::new(), recv_queue: #client_recv_queue_wrapper::new(recv_queue), send_queue, } } async fn send(&self, query: #question_enum_name) -> Result { let nonce = self.queries.len() as u64; let res = self.send_queue.send((nonce, query.clone())).await; match res { Ok(_) => { self.queries.insert(nonce, query.into()); Ok(nonce) } Err(e) => Err(#error_enum_name::SendError(e)), } } async fn recv_until(&self, id: u64) -> Result<#answer_enum_name, #error_enum_name> { loop { // Check if we've received the answer for the query we're looking for if let Some(query) = self.queries.get(&id) { if let Some(answer) = query.get_answer() { return Ok(answer); } } match self.recv_queue.recv().await { Some((nonce, answer)) => { self.queries.set_answer(nonce, answer.clone()); } None => return Err(#error_enum_name::Closed), }; } } #(#client_impl)* } }; let expanded = quote! { #error_enum #answer_enum #question_enum #query_enum #queries_struct #server_trait #client_struct }; expanded } fn ident_to_snake_case(ident: &Ident) -> Ident { let ident = ident.to_string(); let mut out = String::new(); for (i, c) in ident.chars().enumerate() { if c.is_uppercase() { if i != 0 { out.push('_'); } out.push(c.to_lowercase().next().unwrap()); } else { out.push(c); } } Ident::new(&out, ident.span()) } fn field_to_args(field: &Field) -> proc_macro2::TokenStream { let type_ = &field.ty; if let syn::Type::Tuple(tuple) = type_ { let mut args = Vec::new(); for (i, elem) in tuple.elems.iter().enumerate() { let arg = Ident::new(&format!("arg{}", i), elem.span()); args.push(quote! { #arg: #elem }); } quote! { #( #args ), * } } else { quote! { arg: #type_ } } } fn field_to_tuple_args(field: &Field) -> proc_macro2::TokenStream { let type_ = &field.ty; if let syn::Type::Tuple(tuple) = type_ { let mut args = Vec::new(); for (i, elem) in tuple.elems.iter().enumerate() { let arg = Ident::new(&format!("arg{}", i), elem.span()); args.push(quote! { #arg }); } quote! { ( #( #args ), * ) } } else { quote! { (arg) } } }