2024-06-19 23:25:45 +02:00
|
|
|
/*
|
2024-06-19 23:27:10 +02:00
|
|
|
Eagle - A library for easy communication in full-stack Rust applications
|
2024-06-19 23:25:45 +02:00
|
|
|
Copyright (c) 2024 KodiCraft
|
|
|
|
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
|
|
it under the terms of the GNU Affero General Public License as
|
|
|
|
published by the Free Software Foundation, either version 3 of the
|
|
|
|
License, or (at your option) any later version.
|
|
|
|
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
GNU Affero General Public License for more details.
|
|
|
|
|
|
|
|
You should have received a copy of the GNU Affero General Public License
|
|
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
*/
|
|
|
|
use proc_macro::TokenStream;
|
2024-06-20 11:34:04 +02:00
|
|
|
use quote::{format_ident, quote};
|
2024-06-19 23:25:45 +02:00
|
|
|
use syn::{parse_macro_input, spanned::Spanned, DeriveInput, Field, Ident};
|
|
|
|
|
|
|
|
#[proc_macro_derive(Protocol)]
|
|
|
|
pub fn derive_protocol(input: TokenStream) -> TokenStream {
|
|
|
|
let input = parse_macro_input!(input as DeriveInput);
|
|
|
|
// Must be on an enum
|
|
|
|
let enum_ = match &input.data {
|
|
|
|
syn::Data::Enum(e) => e,
|
|
|
|
_ => {
|
|
|
|
return syn::Error::new(input.span(), "Protocol can only be derived on enums")
|
|
|
|
.to_compile_error()
|
|
|
|
.into()
|
|
|
|
}
|
|
|
|
};
|
|
|
|
let name = &input.ident;
|
|
|
|
let vis = &input.vis;
|
|
|
|
|
|
|
|
let mut server_trait = Vec::new();
|
2024-06-20 11:34:04 +02:00
|
|
|
let mut server_enum = Vec::new();
|
2024-06-19 23:25:45 +02:00
|
|
|
let mut client_impl = Vec::new();
|
2024-06-20 11:34:04 +02:00
|
|
|
let mut client_enum = Vec::new();
|
|
|
|
|
|
|
|
let mut query_enum = Vec::new();
|
2024-06-19 23:25:45 +02:00
|
|
|
|
|
|
|
for variant in &enum_.variants {
|
|
|
|
// Every variant must have 2 fields
|
|
|
|
// The first field is the question (serverbound), the second field is the answer (clientbound)
|
|
|
|
if variant.fields.len() != 2 {
|
|
|
|
return syn::Error::new(
|
|
|
|
variant.span(),
|
|
|
|
"Every variant on a protocol must have exactly 2 fields",
|
|
|
|
)
|
|
|
|
.to_compile_error()
|
|
|
|
.into();
|
|
|
|
}
|
|
|
|
|
|
|
|
let var_name = ident_to_snake_case(&variant.ident);
|
|
|
|
let mut variant_fields = variant.fields.iter();
|
|
|
|
let question_field = variant_fields.next().unwrap();
|
|
|
|
let question_args = field_to_args(question_field);
|
|
|
|
let answer_type = variant_fields.next().unwrap().ty.clone();
|
|
|
|
|
2024-06-20 11:34:04 +02:00
|
|
|
// The variants that either the server or the client will use
|
|
|
|
// The "server" enum contains messages the server can send, the "client" enum contains messages the client can send
|
|
|
|
server_enum.push(quote! {
|
|
|
|
#var_name(#answer_type)
|
|
|
|
});
|
|
|
|
client_enum.push(quote! {
|
|
|
|
#var_name(#question_field)
|
|
|
|
});
|
|
|
|
// The function that the server needs to implement
|
2024-06-19 23:25:45 +02:00
|
|
|
server_trait.push(quote! {
|
|
|
|
fn #var_name(&mut self, #question_args) -> #answer_type;
|
|
|
|
});
|
2024-06-20 11:34:04 +02:00
|
|
|
// The function that the client uses to communicate
|
2024-06-19 23:25:45 +02:00
|
|
|
client_impl.push(quote! {
|
|
|
|
pub fn #var_name(&mut self, #question_args) -> #answer_type {
|
|
|
|
::std::unimplemented!()
|
|
|
|
}
|
2024-06-20 11:34:04 +02:00
|
|
|
});
|
|
|
|
// The query enum is the same as the source enum, but the second field is always wrapped in a Option<>
|
|
|
|
query_enum.push(quote! {
|
|
|
|
#var_name(#question_field, Option<#answer_type>)
|
|
|
|
});
|
2024-06-19 23:25:45 +02:00
|
|
|
}
|
|
|
|
|
2024-06-20 11:34:04 +02:00
|
|
|
// Create enums for the types of messages the server and client will use
|
|
|
|
let server_enum_name = format_ident!("{}Answer", name);
|
|
|
|
let server_enum = quote! {
|
|
|
|
#vis enum #server_enum_name {
|
|
|
|
#(#server_enum), *
|
|
|
|
}
|
|
|
|
};
|
|
|
|
let client_enum_name = format_ident!("{}Question", name);
|
|
|
|
let client_enum = quote! {
|
2024-06-20 11:52:36 +02:00
|
|
|
#[derive(serde::Serialize, serde::Deserialize)]
|
2024-06-20 11:34:04 +02:00
|
|
|
#vis enum #client_enum_name {
|
|
|
|
#(#client_enum), *
|
|
|
|
}
|
|
|
|
};
|
|
|
|
// Create an enum to represent the queries the client has sent
|
|
|
|
let query_enum_name = format_ident!("{}Query", name);
|
|
|
|
let query_enum = quote! {
|
2024-06-20 11:52:36 +02:00
|
|
|
#[derive(serde::Serialize, serde::Deserialize)]
|
2024-06-20 11:34:04 +02:00
|
|
|
#vis enum #query_enum_name {
|
|
|
|
#(#query_enum), *
|
|
|
|
}
|
|
|
|
};
|
2024-06-19 23:25:45 +02:00
|
|
|
// Create a trait which the server will have to implement
|
2024-06-20 11:34:04 +02:00
|
|
|
let server_trait_name = format_ident!("{}Server", name);
|
2024-06-19 23:25:45 +02:00
|
|
|
let server_trait = quote! {
|
|
|
|
#vis trait #server_trait_name {
|
|
|
|
#(#server_trait)*
|
|
|
|
}
|
|
|
|
};
|
2024-06-20 11:34:04 +02:00
|
|
|
// Create a struct which the client will use to communicate
|
|
|
|
let client_struct_name = format_ident!("{}Client", name);
|
2024-06-19 23:25:45 +02:00
|
|
|
let client_struct = quote! {
|
2024-06-20 11:34:04 +02:00
|
|
|
#vis struct #client_struct_name {
|
|
|
|
queries: ::std::collections::HashMap<u64, #query_enum_name>,
|
|
|
|
send_queue: tokio::sync::mpsc::Sender<#client_enum_name>,
|
|
|
|
recv_queue: tokio::sync::mpsc::Receiver<#server_enum_name>,
|
|
|
|
} // TODO: This struct will have some fields to handle the actual connection
|
2024-06-19 23:25:45 +02:00
|
|
|
impl #client_struct_name {
|
2024-06-20 11:34:04 +02:00
|
|
|
pub fn new(send_queue: tokio::sync::mpsc::Sender<#client_enum_name>, recv_queue: tokio::sync::mpsc::Receiver<#server_enum_name>) -> Self {
|
|
|
|
Self {
|
|
|
|
queries: ::std::collections::HashMap::new(),
|
|
|
|
send_queue,
|
|
|
|
recv_queue,
|
|
|
|
}
|
|
|
|
}
|
2024-06-19 23:25:45 +02:00
|
|
|
#(#client_impl)*
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
let expanded = quote! {
|
2024-06-20 11:34:04 +02:00
|
|
|
#server_enum
|
|
|
|
#client_enum
|
|
|
|
#query_enum
|
2024-06-19 23:25:45 +02:00
|
|
|
#server_trait
|
|
|
|
#client_struct
|
|
|
|
};
|
|
|
|
expanded.into()
|
|
|
|
}
|
|
|
|
|
|
|
|
fn ident_to_snake_case(ident: &Ident) -> Ident {
|
|
|
|
let ident = ident.to_string();
|
|
|
|
let mut out = String::new();
|
|
|
|
for (i, c) in ident.chars().enumerate() {
|
|
|
|
if c.is_uppercase() {
|
|
|
|
if i != 0 {
|
|
|
|
out.push('_');
|
|
|
|
}
|
|
|
|
out.push(c.to_lowercase().next().unwrap());
|
|
|
|
} else {
|
|
|
|
out.push(c);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Ident::new(&out, ident.span())
|
|
|
|
}
|
|
|
|
|
|
|
|
fn field_to_args(field: &Field) -> proc_macro2::TokenStream {
|
|
|
|
let type_ = &field.ty;
|
|
|
|
if let syn::Type::Tuple(tuple) = type_ {
|
|
|
|
let mut args = Vec::new();
|
|
|
|
for (i, elem) in tuple.elems.iter().enumerate() {
|
|
|
|
let arg = Ident::new(&format!("arg{}", i), elem.span());
|
|
|
|
args.push(quote! { #arg: #elem });
|
|
|
|
}
|
|
|
|
quote! { #( #args ), * }
|
|
|
|
} else {
|
|
|
|
quote! { arg: #type_ }
|
|
|
|
}
|
|
|
|
}
|