eagle/src/lib.rs

282 lines
11 KiB
Rust

/*
Eagle - A library for easy communication in full-stack Rust applications
Copyright (c) 2024 KodiCraft
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse2, spanned::Spanned, DeriveInput, Field, Ident};
#[proc_macro_derive(Protocol)]
pub fn derive_protocol_derive(input: TokenStream) -> TokenStream {
let expanded = derive_protocol(input.into());
TokenStream::from(expanded)
}
fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let input = parse2::<DeriveInput>(input).unwrap();
// Must be on an enum
let enum_ = match &input.data {
syn::Data::Enum(e) => e,
_ => {
return syn::Error::new(input.span(), "Protocol can only be derived on enums")
.to_compile_error()
}
};
let name = &input.ident;
let error_enum_name = format_ident!("__{}Error", name);
let answer_enum_name = format_ident!("__{}Answer", name);
let question_enum_name = format_ident!("__{}Question", name);
let query_enum_name = format_ident!("__{}Query", name);
let server_trait_name = format_ident!("{}Server", name);
let client_struct_name = format_ident!("{}Client", name);
let vis = &input.vis;
let mut server_trait = Vec::new();
let mut server_enum = Vec::new();
let mut client_impl = Vec::new();
let mut client_enum = Vec::new();
let mut query_enum = Vec::new();
let mut query_from_question_enum = Vec::new();
let mut query_set_answer = Vec::new();
let mut query_get_answer = Vec::new();
for variant in &enum_.variants {
// Every variant must have 2 fields
// The first field is the question (serverbound), the second field is the answer (clientbound)
if variant.fields.len() != 2 {
return syn::Error::new(
variant.span(),
"Every variant on a protocol must have exactly 2 fields",
)
.to_compile_error();
}
let var_name = ident_to_snake_case(&variant.ident);
let mut variant_fields = variant.fields.iter();
let question_field = variant_fields.next().unwrap();
let question_args = field_to_args(question_field);
let question_tuple_args = field_to_tuple_args(question_field);
let answer_type = variant_fields.next().unwrap().ty.clone();
// The variants that either the server or the client will use
// The "server" enum contains messages the server can send, the "client" enum contains messages the client can send
server_enum.push(quote! {
#var_name(#answer_type)
});
client_enum.push(quote! {
#var_name(#question_field)
});
// There is a From implementation for the client enum to the query enum
query_from_question_enum.push(quote! {
#question_enum_name::#var_name(question) => #query_enum_name::#var_name(question, None),
});
// There is a function that must be implemented to set the answer in the query enum
query_set_answer.push(quote! {
#query_enum_name::#var_name(question, answer_opt) => match answer {
#answer_enum_name::#var_name(answer) => {answer_opt.replace(answer);},
_ => panic!("The answer for this query is not the correct type."),
},
});
// There is a function that must be implemented to get the answer from the query enum
query_get_answer.push(quote! {
#query_enum_name::#var_name(_, answer) => match answer {
Some(answer) => Some(#answer_enum_name::#var_name(answer.clone())),
None => None
},
});
// The function that the server needs to implement
server_trait.push(quote! {
fn #var_name(&mut self, #question_args) -> #answer_type;
});
// The function that the client uses to communicate
client_impl.push(quote! {
pub async fn #var_name(&mut self, #question_args) -> Result<#answer_type, #error_enum_name> {
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
let answer = self.recv_until(nonce).await?;
match answer {
#answer_enum_name::#var_name(answer) => Ok(answer),
_ => panic!("The answer for this query is not the correct type."),
}
}
});
// The query enum is the same as the source enum, but the second field is always wrapped in a Option<>
query_enum.push(quote! {
#var_name(#question_field, Option<#answer_type>)
});
}
// Create an error and result type for sending messages
let error_enum = quote! {
#vis enum #error_enum_name {
SendError(tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>),
Closed,
}
};
// Create enums for the types of messages the server and client will use
let answer_enum = quote! {
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
#vis enum #answer_enum_name {
#(#server_enum), *
}
};
let question_enum = quote! {
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
#vis enum #question_enum_name {
#(#client_enum), *
}
};
// Create an enum to represent the queries the client has sent
let query_enum = quote! {
#[derive(Clone, Debug)]
#vis enum #query_enum_name {
#(#query_enum), *
}
impl #query_enum_name {
pub fn set_answer(&mut self, answer: #answer_enum_name) {
match self {
#(#query_set_answer)*
};
}
pub fn get_answer(&self) -> Option<#answer_enum_name> {
match self {
#(#query_get_answer)*
}
}
}
impl From<#question_enum_name> for #query_enum_name {
fn from(query: #question_enum_name) -> Self {
match query {
#(#query_from_question_enum)*
}
}
}
};
// Create a trait which the server will have to implement
let server_trait = quote! {
#vis trait #server_trait_name {
#(#server_trait)*
}
};
// Create a struct which the client will use to communicate
let client_struct = quote! {
#vis struct #client_struct_name {
queries: ::std::collections::HashMap<u64, #query_enum_name>,
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
} // TODO: This struct will have some fields to handle the actual connection
impl #client_struct_name {
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self {
Self {
queries: ::std::collections::HashMap::new(),
send_queue,
recv_queue,
}
}
async fn send(&mut self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
let nonce = self.queries.len() as u64;
let res = self.send_queue.send((nonce, query.clone())).await;
match res {
Ok(_) => {
self.queries.insert(nonce, query.into());
Ok(nonce)
}
Err(e) => Err(#error_enum_name::SendError(e)),
}
}
async fn recv_until(&mut self, id: u64) -> Result<#answer_enum_name, #error_enum_name> {
loop {
// Check if we've received the answer for the query we're looking for
if let Some(query) = self.queries.get(&id) {
if let Some(answer) = query.get_answer() {
return Ok(answer);
}
}
match self.recv_queue.recv().await {
Some((nonce, answer)) => {
// Replace the Option<> in the query with the answer
if let Some(query) = self.queries.get_mut(&nonce) {
query.set_answer(answer);
} else {
panic!("Received an answer for a query we did not send");
}
}
None => return Err(#error_enum_name::Closed),
};
}
}
#(#client_impl)*
}
};
let expanded = quote! {
#error_enum
#answer_enum
#question_enum
#query_enum
#server_trait
#client_struct
};
expanded
}
/// Converts an UpperCamelCase identifier (an enum variant name) to snake_case, e.g.
/// `GetUser` -> `get_user`.
///
/// Every uppercase character starts a new word, so consecutive capitals are split one by
/// one (`HTTPGet` -> `h_t_t_p_get`). The result keeps the span of `ident`, so diagnostics
/// about generated items point back at the original variant.
///
/// NOTE(review): a variant whose snake_case form is a Rust keyword (e.g. `Type` -> `type`)
/// yields an identifier that the generated code cannot use as a method name.
fn ident_to_snake_case(ident: &Ident) -> Ident {
    // Use a separate name for the text so `ident.span()` below still refers to the original
    // identifier; shadowing it with a `String` would silently yield a call-site span.
    let name = ident.to_string();
    let mut out = String::with_capacity(name.len() + 4);
    for (i, c) in name.chars().enumerate() {
        if c.is_uppercase() {
            if i != 0 {
                out.push('_');
            }
            // Some characters lowercase to more than one char; keep all of them.
            out.extend(c.to_lowercase());
        } else {
            out.push(c);
        }
    }
    Ident::new(&out, ident.span())
}
/// Builds the parameter list of a generated method from a question field's type.
///
/// A tuple type `(A, B)` expands to `arg0: A, arg1: B` (the unit type `()` yields no
/// parameters at all); any other type `T` expands to the single parameter `arg: T`.
/// The parameter names must stay in sync with those produced by `field_to_tuple_args`.
fn field_to_args(field: &Field) -> proc_macro2::TokenStream {
    match &field.ty {
        syn::Type::Tuple(tuple) => {
            let params = tuple.elems.iter().enumerate().map(|(index, element)| {
                let param = Ident::new(&format!("arg{}", index), element.span());
                quote! { #param: #element }
            });
            quote! { #( #params ), * }
        }
        other => quote! { arg: #other },
    }
}
/// Builds the expression that packs the parameters produced by `field_to_args` back into a
/// value of the question field's type.
///
/// A tuple type expands to `(arg0, arg1,)`, which is `()` for the unit type and `(arg0,)` for a
/// one-element tuple. Any other type expands to plain `arg`.
fn field_to_tuple_args(field: &Field) -> proc_macro2::TokenStream {
    match &field.ty {
        syn::Type::Tuple(tuple) => {
            let args = tuple
                .elems
                .iter()
                .enumerate()
                .map(|(i, elem)| Ident::new(&format!("arg{}", i), elem.span()));
            // A comma after every element keeps a one-element tuple a tuple: `(arg0,)`, not
            // the parenthesised expression `(arg0)` that fails to type-check against `(A,)`.
            quote! { ( #( #args , )* ) }
        }
        // No surrounding parentheses: the caller already wraps this in `Variant(...)`, and
        // `Variant((arg))` would trigger `unused_parens` in user code.
        _ => quote! { arg },
    }
}