Generate code to "send" a question

This commit is contained in:
Kodi Craft 2024-06-20 12:52:42 +02:00
parent cc4d14fe69
commit aa7d51f088
Signed by: kodi
GPG Key ID: 69D9EED60B242822
2 changed files with 71 additions and 20 deletions

View File

@ -32,6 +32,14 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
} }
}; };
let name = &input.ident; let name = &input.ident;
let error_enum_name = format_ident!("{}Error", name);
let answer_enum_name = format_ident!("{}Answer", name);
let question_enum_name = format_ident!("{}Question", name);
let query_enum_name = format_ident!("{}Query", name);
let server_trait_name = format_ident!("{}Server", name);
let client_struct_name = format_ident!("{}Client", name);
let vis = &input.vis; let vis = &input.vis;
let mut server_trait = Vec::new(); let mut server_trait = Vec::new();
@ -40,6 +48,7 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
let mut client_enum = Vec::new(); let mut client_enum = Vec::new();
let mut query_enum = Vec::new(); let mut query_enum = Vec::new();
let mut query_from_question_enum = Vec::new();
for variant in &enum_.variants { for variant in &enum_.variants {
// Every variant must have 2 fields // Every variant must have 2 fields
@ -57,6 +66,7 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
let mut variant_fields = variant.fields.iter(); let mut variant_fields = variant.fields.iter();
let question_field = variant_fields.next().unwrap(); let question_field = variant_fields.next().unwrap();
let question_args = field_to_args(question_field); let question_args = field_to_args(question_field);
let question_tuple_args = field_to_tuple_args(question_field);
let answer_type = variant_fields.next().unwrap().ty.clone(); let answer_type = variant_fields.next().unwrap().ty.clone();
// The variants that either the server or the client will use // The variants that either the server or the client will use
@ -67,14 +77,19 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
client_enum.push(quote! { client_enum.push(quote! {
#var_name(#question_field) #var_name(#question_field)
}); });
// There is a From implementation for the client enum to the query enum
query_from_question_enum.push(quote! {
#question_enum_name::#var_name(question) => #query_enum_name::#var_name(question, None),
});
// The function that the server needs to implement // The function that the server needs to implement
server_trait.push(quote! { server_trait.push(quote! {
fn #var_name(&mut self, #question_args) -> #answer_type; fn #var_name(&mut self, #question_args) -> #answer_type;
}); });
// The function that the client uses to communicate // The function that the client uses to communicate
client_impl.push(quote! { client_impl.push(quote! {
pub fn #var_name(&mut self, #question_args) -> #answer_type { pub async fn #var_name(&mut self, #question_args) -> Result<#answer_type, #error_enum_name> {
::std::unimplemented!() let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
todo!("Wait for the answer")
} }
}); });
// The query enum is the same as the source enum, but the second field is always wrapped in a Option<> // The query enum is the same as the source enum, but the second field is always wrapped in a Option<>
@ -83,58 +98,80 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
}); });
} }
// Create an error and result type for sending messages
let error_enum = quote! {
#vis enum #error_enum_name {
SendError(tokio::sync::mpsc::error::SendError<#question_enum_name>),
}
};
// Create enums for the types of messages the server and client will use // Create enums for the types of messages the server and client will use
let server_enum_name = format_ident!("{}Answer", name);
let server_enum = quote! { let answer_enum = quote! {
#vis enum #server_enum_name { #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
#vis enum #answer_enum_name {
#(#server_enum), * #(#server_enum), *
} }
}; };
let client_enum_name = format_ident!("{}Question", name); let question_enum = quote! {
let client_enum = quote! { #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
#[derive(serde::Serialize, serde::Deserialize)] #vis enum #question_enum_name {
#vis enum #client_enum_name {
#(#client_enum), * #(#client_enum), *
} }
}; };
// Create an enum to represent the queries the client has sent // Create an enum to represent the queries the client has sent
let query_enum_name = format_ident!("{}Query", name);
let query_enum = quote! { let query_enum = quote! {
#[derive(serde::Serialize, serde::Deserialize)] #[derive(Clone, Debug)]
#vis enum #query_enum_name { #vis enum #query_enum_name {
#(#query_enum), * #(#query_enum), *
} }
impl From<#question_enum_name> for #query_enum_name {
fn from(query: #question_enum_name) -> Self {
match query {
#(#query_from_question_enum)*
}
}
}
}; };
// Create a trait which the server will have to implement // Create a trait which the server will have to implement
let server_trait_name = format_ident!("{}Server", name);
let server_trait = quote! { let server_trait = quote! {
#vis trait #server_trait_name { #vis trait #server_trait_name {
#(#server_trait)* #(#server_trait)*
} }
}; };
// Create a struct which the client will use to communicate // Create a struct which the client will use to communicate
let client_struct_name = format_ident!("{}Client", name);
let client_struct = quote! { let client_struct = quote! {
#vis struct #client_struct_name { #vis struct #client_struct_name {
queries: ::std::collections::HashMap<u64, #query_enum_name>, queries: ::std::collections::HashMap<u64, #query_enum_name>,
send_queue: tokio::sync::mpsc::Sender<#client_enum_name>, send_queue: tokio::sync::mpsc::Sender<#question_enum_name>,
recv_queue: tokio::sync::mpsc::Receiver<#server_enum_name>, recv_queue: tokio::sync::mpsc::Receiver<#answer_enum_name>,
} // TODO: This struct will have some fields to handle the actual connection } // TODO: This struct will have some fields to handle the actual connection
impl #client_struct_name { impl #client_struct_name {
pub fn new(send_queue: tokio::sync::mpsc::Sender<#client_enum_name>, recv_queue: tokio::sync::mpsc::Receiver<#server_enum_name>) -> Self { pub fn new(send_queue: tokio::sync::mpsc::Sender<#question_enum_name>, recv_queue: tokio::sync::mpsc::Receiver<#answer_enum_name>) -> Self {
Self { Self {
queries: ::std::collections::HashMap::new(), queries: ::std::collections::HashMap::new(),
send_queue, send_queue,
recv_queue, recv_queue,
} }
} }
async fn send(&mut self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
let res = self.send_queue.send(query.clone()).await;
match res {
Ok(_) => {
let id = self.queries.len() as u64;
self.queries.insert(id, query.into());
Ok(id)
}
Err(e) => Err(#error_enum_name::SendError(e)),
}
}
#(#client_impl)* #(#client_impl)*
} }
}; };
let expanded = quote! { let expanded = quote! {
#server_enum #error_enum
#client_enum #answer_enum
#question_enum
#query_enum #query_enum
#server_trait #server_trait
#client_struct #client_struct
@ -171,3 +208,17 @@ fn field_to_args(field: &Field) -> proc_macro2::TokenStream {
quote! { arg: #type_ } quote! { arg: #type_ }
} }
} }
fn field_to_tuple_args(field: &Field) -> proc_macro2::TokenStream {
let type_ = &field.ty;
if let syn::Type::Tuple(tuple) = type_ {
let mut args = Vec::new();
for (i, elem) in tuple.elems.iter().enumerate() {
let arg = Ident::new(&format!("arg{}", i), elem.span());
args.push(quote! { #arg });
}
quote! { ( #( #args ), * ) }
} else {
quote! { (arg) }
}
}

View File

@ -29,8 +29,8 @@ impl TestProtocolServer for DummyServer {
question.len() as i32 question.len() as i32
} }
fn addition(&mut self, arg0: i32, arg1: i32) -> i32 { fn addition(&mut self, a: i32, b: i32) -> i32 {
todo!() a + b
} }
} }