eagle/src/lib.rs

814 lines
34 KiB
Rust
Raw Permalink Normal View History

2024-06-19 23:25:45 +02:00
/*
2024-06-24 18:26:19 +02:00
Eagle - A simple library for RPC in Rust
2024-06-19 23:25:45 +02:00
Copyright (c) 2024 KodiCraft
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
2024-06-24 18:26:19 +02:00
//! # Eagle - A simple library for RPC in Rust
//!
//! <div class="warning">Eagle is still in early development. This documentation is subject to change and may not be entirely accurate.</div>
//!
//! Eagle is a library for building RPC protocols in Rust. It uses a macro
//! to transform your protocol definition into the necessary code to allow
//! communication between a server and a client.
//!
//! Eagle uses [`tokio`](https://tokio.rs) as its async runtime and
//! [`ron`](https://crates.io/crates/ron) for serialization.
//!
//! ## Usage
//! `eagle` is designed to be used to create your own protocol crate. We
//! recommend creating a new cargo workspace for your project with a shared
//! crate which will contain your protocol definition and individual crates
//! for the server and client.
//!
//! In your shared crate, you can define your protocol using the [`Protocol`]
//! derive macro. This will generate the necessary code for the server and
//! client to communicate.
//!
//! ```rust
//! use eagle::Protocol;
//!
//! #[derive(Protocol)]
//! pub enum Example {
//! Add((i32, i32), i32),
//! Length(String, usize),
//! /* ... */
//! }
//! ```
//!
//! The [`Protocol`] derive macro will generate all the necessary code, including
//! your handler trait, your server struct, and your client struct.
//!
//! On your server, you will need to implement the handler trait. This trait
//! describes the functions that the client can request from the server.
//!
//! ```rust
//! # use eagle::Protocol;
//! # #[derive(Protocol)]
//! # pub enum Example {
//! # Add((i32, i32), i32),
//! # Length(String, usize),
//! # /* ... */
//! # }
//! #
//! #[derive(Clone)]
//! pub struct Handler;
//! impl ExampleServerHandler for Handler {
//! async fn add(&mut self, a: i32, b: i32) -> i32 {
//! a + b
//! }
//! async fn length(&mut self, s: String) -> usize {
//! s.len()
//! }
//! /* ... */
//! }
//! ```
//!
//! To start the server, you simply need to use the generated server struct and
//! pass it your handler.
//!
//! ```no_run
//! # use eagle::Protocol;
//! # #[derive(Protocol)]
//! # pub enum Example {
//! # Add((i32, i32), i32),
//! # Length(String, usize),
//! # /* ... */
//! # }
//! #
//! # #[derive(Clone)]
//! # pub struct Handler;
//! # impl ExampleServerHandler for Handler {
//! # async fn add(&mut self, a: i32, b: i32) -> i32 {
//! # a + b
//! # }
//! # async fn length(&mut self, s: String) -> usize {
//! # s.len()
//! # }
//! # }
//! #
//! # tokio_test::block_on(async {
//! let handler = Handler;
//! let address = "127.0.0.1:12345"; // Or, if using the 'unix' feature, "/tmp/eagle.sock"
//! let server = ExampleServer::bind(handler, address).await;
//! server.close().await;
//! # });
//! ```
//!
//! Once bound, the server will begin listening for incoming connections and
//! queries. **You must remember to call the `close` method to shut down the server.**
//!
//! On the client side, you can simply use the generated client struct to connect
//! to the server and begin sending queries.
//!
//! ```no_run
//! # use eagle::Protocol;
//! # #[derive(Protocol)]
//! # pub enum Example {
//! # Add((i32, i32), i32),
//! # Length(String, usize),
//! # /* ... */
//! # }
//! #
//! # #[derive(Clone)]
//! # pub struct Handler;
//! # impl ExampleServerHandler for Handler {
//! # async fn add(&mut self, a: i32, b: i32) -> i32 {
//! # a + b
//! # }
//! # async fn length(&mut self, s: String) -> usize {
//! # s.len()
//! # }
//! # }
//! #
//! # tokio_test::block_on(async {
//! # let handler = Handler;
//! let address = "127.0.0.1:12345"; // Or, if using the 'unix' feature, "/tmp/eagle.sock"
//! # let server = ExampleServer::bind(handler, address).await;
//! # tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; // Wait for the server to start
//! let client = ExampleClient::connect(address).await.unwrap();
//! assert_eq!(client.add(2, 5).await.unwrap(), 7);
//! # server.close().await;
//! # });
//! ```
//!
//! The client can be closed by calling the `close` method on the client struct.
//! This will abort the connection.
#![warn(missing_docs)]
2024-06-19 23:25:45 +02:00
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse2, spanned::Spanned, DeriveInput, Field, Ident};
2024-06-19 23:25:45 +02:00
// Exactly one transport feature must be selected, and the `unix` transport
// is only available when building for a unix target.
#[cfg(all(feature = "unix", feature = "tcp"))]
compile_error!("You can only enable one of the 'tcp' or 'unix' features");
#[cfg(not(any(feature = "tcp", feature = "unix")))]
compile_error!("You must enable either the 'tcp' or 'unix' feature");
#[cfg(all(not(unix), feature = "unix"))]
compile_error!("The 'unix' feature requires compiling for a unix target");
2024-06-24 18:26:19 +02:00
/// Generate all the necessary RPC code for a protocol from an enum describing it.
///
/// This macro will generate various enums and structs to enable communication
/// between a server and a client. The following items will be generated, where {}
/// is the name of the protocol enum:
/// - `{}ServerHandler` - A trait that the server must implement to handle queries
/// - `{}Server` - A struct that the server uses to communicate with clients
/// - `{}Client` - A struct that the client uses to communicate with a server
///
/// Each variant of the passed enum represents a query that the client can send to the
/// server. The first field of each variant is the question (serverbound) and the second field
/// is the answer (clientbound). You may use tuples to send multiple arguments and
/// the unit type `()` to send no arguments. Only data types which implement
/// [`Clone`], [`Debug`](std::fmt::Debug), [`serde::Serialize`], and [`serde::Deserialize`] can be used.
///
/// For more information on how to use the generated code, see the [crate-level documentation](index.html).
///
/// # Example
/// ```rust
/// use eagle::Protocol;
///
/// #[derive(Protocol)]
/// pub enum Example {
/// Add((i32, i32), i32),
/// Length(String, usize),
/// }
/// ```
2024-06-19 23:25:45 +02:00
#[proc_macro_derive(Protocol)]
pub fn derive_protocol_derive(input: TokenStream) -> TokenStream {
let expanded = derive_protocol(input.into());
TokenStream::from(expanded)
}
fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let input = parse2::<DeriveInput>(input).unwrap();
2024-06-24 15:34:14 +02:00
// TODO: These logs should be filterable in some way
#[cfg(feature = "log")]
2024-06-24 15:36:29 +02:00
#[allow(unused_variables)]
2024-06-24 22:32:03 +02:00
let debug = quote! { ::log::debug! };
2024-06-24 15:34:14 +02:00
#[cfg(feature = "log")]
2024-06-24 15:36:29 +02:00
#[allow(unused_variables)]
2024-06-24 22:32:03 +02:00
let info = quote! { ::log::info! };
2024-06-24 15:34:14 +02:00
#[cfg(feature = "log")]
2024-06-24 15:36:29 +02:00
#[allow(unused_variables)]
2024-06-24 22:32:03 +02:00
let warn = quote! { ::log::warn! };
2024-06-24 15:34:14 +02:00
#[cfg(feature = "log")]
2024-06-24 15:36:29 +02:00
#[allow(unused_variables)]
2024-06-24 22:32:03 +02:00
let error = quote! { ::log::error! };
2024-06-24 15:34:14 +02:00
#[cfg(not(feature = "log"))]
2024-06-24 15:36:29 +02:00
#[allow(unused_variables)]
2024-06-24 22:32:03 +02:00
let debug = quote! { ::std::eprintln! };
2024-06-24 15:34:14 +02:00
#[cfg(not(feature = "log"))]
2024-06-24 15:36:29 +02:00
#[allow(unused_variables)]
2024-06-24 22:32:03 +02:00
let info = quote! { ::std::eprintln! };
2024-06-24 15:34:14 +02:00
#[cfg(not(feature = "log"))]
2024-06-24 15:36:29 +02:00
#[allow(unused_variables)]
2024-06-24 22:32:03 +02:00
let warn = quote! { ::std::eprintln! };
2024-06-24 15:34:14 +02:00
#[cfg(not(feature = "log"))]
2024-06-24 15:36:29 +02:00
#[allow(unused_variables)]
2024-06-24 22:32:03 +02:00
let error = quote! { ::std::eprintln! };
2024-06-24 15:34:14 +02:00
2024-06-19 23:25:45 +02:00
// Must be on an enum
let enum_ = match &input.data {
syn::Data::Enum(e) => e,
_ => {
return syn::Error::new(input.span(), "Protocol can only be derived on enums")
.to_compile_error()
}
};
let name = &input.ident;
2024-06-20 12:52:42 +02:00
let error_enum_name = format_ident!("__{}Error", name);
let answer_enum_name = format_ident!("__{}Answer", name);
let question_enum_name = format_ident!("__{}Question", name);
let query_enum_name = format_ident!("__{}Query", name);
let queries_struct_name = format_ident!("__{}Queries", name);
2024-06-21 15:54:48 +02:00
let client_connection_struct_name = format_ident!("__{}Connection", name);
2024-06-24 16:58:14 +02:00
let server_trait_name = format_ident!("{}ServerHandler", name);
let server_connection_struct_name = format_ident!("{}Server", name);
2024-06-20 12:52:42 +02:00
let client_struct_name = format_ident!("{}Client", name);
2024-06-19 23:25:45 +02:00
let vis = &input.vis;
let mut server_trait = Vec::new();
let mut server_enum = Vec::new();
2024-06-19 23:25:45 +02:00
let mut client_impl = Vec::new();
let mut client_enum = Vec::new();
let mut server_handler = Vec::new();
let mut query_enum = Vec::new();
2024-06-20 12:52:42 +02:00
let mut query_from_question_enum = Vec::new();
let mut query_set_answer = Vec::new();
let mut query_get_answer = Vec::new();
2024-06-19 23:25:45 +02:00
for variant in &enum_.variants {
// Every variant must have 2 fields
// The first field is the question (serverbound), the second field is the answer (clientbound)
if variant.fields.len() != 2 {
return syn::Error::new(
variant.span(),
"Every variant on a protocol must have exactly 2 fields",
)
.to_compile_error();
2024-06-19 23:25:45 +02:00
}
let var_name = ident_to_snake_case(&variant.ident);
let mut variant_fields = variant.fields.iter();
let question_field = variant_fields.next().unwrap();
let question_args = field_to_args(question_field);
let question_handler_args = field_to_handler_args(question_field);
2024-06-20 12:52:42 +02:00
let question_tuple_args = field_to_tuple_args(question_field);
2024-06-19 23:25:45 +02:00
let answer_type = variant_fields.next().unwrap().ty.clone();
// The variants that either the server or the client will use
// The "server" enum contains messages the server can send, the "client" enum contains messages the client can send
server_enum.push(quote! {
#var_name(#answer_type)
});
client_enum.push(quote! {
#var_name(#question_field)
});
2024-06-20 12:52:42 +02:00
// There is a From implementation for the client enum to the query enum
query_from_question_enum.push(quote! {
#question_enum_name::#var_name(question) => #query_enum_name::#var_name(question, None),
});
// There is a function that must be implemented to set the answer in the query enum
query_set_answer.push(quote! {
#query_enum_name::#var_name(question, answer_opt) => match answer {
2024-06-24 15:34:14 +02:00
#answer_enum_name::#var_name(answer) => {
#debug("Setting answer for query {}", stringify!(#var_name));
answer_opt.replace(answer);
},
_ => panic!("The answer for this query is not the correct type."),
},
});
// There is a function that must be implemented to get the answer from the query enum
query_get_answer.push(quote! {
#query_enum_name::#var_name(_, answer) => match answer {
2024-06-24 22:32:03 +02:00
::std::option::Option::Some(answer) => ::std::option::Option::Some(#answer_enum_name::#var_name(answer.clone())),
::std::option::Option::None => ::std::option::Option::None
},
});
// There is a function that the server uses to call the appropriate function when receiving a query
server_handler.push(quote! {
#question_enum_name::#var_name(#question_tuple_args) => {
2024-06-24 15:34:14 +02:00
#info("Received query {}", stringify!(#var_name));
let answer = self.handler.lock().await.#var_name(#question_handler_args).await;
return #answer_enum_name::#var_name(answer);
},
});
// The function that the server needs to implement
2024-06-19 23:25:45 +02:00
server_trait.push(quote! {
2024-06-24 22:32:03 +02:00
fn #var_name(&mut self, #question_args) -> impl ::std::future::Future<Output = #answer_type> + Send;
2024-06-19 23:25:45 +02:00
});
// The function that the client uses to communicate
2024-06-19 23:25:45 +02:00
client_impl.push(quote! {
pub async fn #var_name(&self, #question_args) -> Result<#answer_type, #error_enum_name> {
2024-06-24 15:34:14 +02:00
#info("Sending query {}", stringify!(#var_name));
2024-06-20 12:52:42 +02:00
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
let answer = self.recv_until(nonce).await?;
match answer {
#answer_enum_name::#var_name(answer) => Ok(answer),
2024-06-24 22:32:03 +02:00
_ => ::std::panic!("The answer for this query is not the correct type."),
}
2024-06-19 23:25:45 +02:00
}
});
// The query enum is the same as the source enum, but the second field is always wrapped in a Option<>
query_enum.push(quote! {
2024-06-24 22:32:03 +02:00
#var_name(#question_field, ::std::option::Option<#answer_type>)
});
2024-06-19 23:25:45 +02:00
}
2024-06-20 12:52:42 +02:00
// Create an error and result type for sending messages
let error_enum = quote! {
2024-06-24 23:01:58 +02:00
#[derive(::std::fmt::Debug)]
2024-06-20 12:52:42 +02:00
#vis enum #error_enum_name {
2024-06-24 22:32:03 +02:00
SendError(::tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>),
Closed,
2024-06-20 12:52:42 +02:00
}
2024-06-24 23:01:58 +02:00
impl ::std::fmt::Display for #error_enum_name {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
match self {
#error_enum_name::SendError(e) => write!(f, "Failed to send query: {}", e),
#error_enum_name::Closed => write!(f, "Connection closed"),
}
}
}
impl ::std::error::Error for #error_enum_name {
fn source(&self) -> ::std::option::Option<&(dyn ::std::error::Error + 'static)> {
match self {
#error_enum_name::SendError(e) => ::std::option::Option::Some(e),
#error_enum_name::Closed => ::std::option::Option::None,
}
}
fn description(&self) -> &str {
match self {
#error_enum_name::SendError(_) => "Failed to send query",
#error_enum_name::Closed => "Connection closed",
}
}
fn cause(&self) -> ::std::option::Option<&dyn ::std::error::Error> {
match self {
#error_enum_name::SendError(e) => ::std::option::Option::Some(e),
#error_enum_name::Closed => ::std::option::Option::None,
}
}
}
2024-06-20 12:52:42 +02:00
};
// Create enums for the types of messages the server and client will use
2024-06-20 12:52:42 +02:00
let answer_enum = quote! {
2024-06-24 22:32:03 +02:00
#[derive(::serde::Serialize, ::serde::Deserialize, ::std::clone::Clone, ::std::fmt::Debug)]
2024-06-20 12:52:42 +02:00
#vis enum #answer_enum_name {
2024-06-24 15:34:14 +02:00
#(#server_enum), *,
Ready
}
};
2024-06-20 12:52:42 +02:00
let question_enum = quote! {
2024-06-24 22:32:03 +02:00
#[derive(::serde::Serialize, ::serde::Deserialize, ::std::clone::Clone, ::std::fmt::Debug)]
2024-06-20 12:52:42 +02:00
#vis enum #question_enum_name {
#(#client_enum), *
}
};
// Create an enum to represent the queries the client has sent
let query_enum = quote! {
2024-06-24 22:32:03 +02:00
#[derive(::std::clone::Clone, ::std::fmt::Debug)]
#vis enum #query_enum_name {
#(#query_enum), *
}
impl #query_enum_name {
pub fn set_answer(&mut self, answer: #answer_enum_name) {
match self {
#(#query_set_answer)*
};
}
2024-06-24 22:32:03 +02:00
pub fn get_answer(&self) -> ::std::option::Option<#answer_enum_name> {
match self {
#(#query_get_answer)*
}
}
}
2024-06-20 12:52:42 +02:00
impl From<#question_enum_name> for #query_enum_name {
fn from(query: #question_enum_name) -> Self {
match query {
#(#query_from_question_enum)*
}
}
}
};
#[cfg(feature = "tcp")]
2024-06-24 22:32:03 +02:00
let stream_type = quote! { ::tokio::net::TcpStream };
#[cfg(feature = "tcp")]
2024-06-24 22:32:03 +02:00
let stream_addr_trait = quote! { ::tokio::net::ToSocketAddrs };
2024-06-24 12:22:54 +02:00
#[cfg(feature = "tcp")]
2024-06-24 22:32:03 +02:00
let listener_type = quote! { ::tokio::net::TcpListener };
#[cfg(feature = "unix")]
2024-06-24 22:32:03 +02:00
let stream_type = quote! { ::tokio::net::UnixStream };
#[cfg(feature = "unix")]
2024-06-24 22:32:03 +02:00
let stream_addr_trait = quote! { ::std::convert::AsRef<std::path::Path> };
2024-06-24 12:22:54 +02:00
#[cfg(feature = "unix")]
2024-06-24 22:32:03 +02:00
let listener_type = quote! { ::tokio::net::UnixListener };
2024-06-19 23:25:45 +02:00
// Create a trait which the server will have to implement
let server_trait = quote! {
#vis trait #server_trait_name {
#(#server_trait)*
}
};
// Create a struct to implement the communication between the server and the client
2024-06-24 12:26:33 +02:00
#[cfg(feature = "tcp")]
let listener_statement = quote! {
let listener = #listener_type::bind(addr).await?;
};
#[cfg(feature = "unix")]
let listener_statement = quote! {
let listener = #listener_type::bind(addr.as_ref())?;
};
let sc_struct = quote! {
2024-06-24 15:34:14 +02:00
#[derive(Clone)]
2024-06-24 22:32:03 +02:00
#vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + ::std::clone::Clone + 'static> {
handler: ::std::sync::Arc<::tokio::sync::Mutex<H>>,
tasks: ::std::sync::Arc<::tokio::sync::Mutex<::std::vec::Vec<tokio::task::JoinHandle<()>>>>,
}
2024-06-24 22:32:03 +02:00
impl<H: #server_trait_name + ::std::marker::Send + std::clone::Clone + 'static> #server_connection_struct_name<H> {
pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + ::std::fmt::Display + 'static>(handler: H, addr: A) -> Self {
2024-06-24 15:34:14 +02:00
#info("Binding server to address {}", addr);
2024-06-24 22:32:03 +02:00
let handler = ::std::sync::Arc::new(::tokio::sync::Mutex::new(handler));
let tasks = ::std::sync::Arc::new(::tokio::sync::Mutex::new(::std::vec::Vec::new()));
2024-06-24 15:34:14 +02:00
let sc = Self {
handler,
tasks,
};
let sc_clone = sc.clone();
2024-06-24 22:32:03 +02:00
let acc_task = ::tokio::spawn(async move {
2024-06-24 15:34:14 +02:00
sc_clone.accept_connections(addr).await.expect("Failed to accept connections!");
});
sc.tasks.lock().await.push(acc_task);
sc
}
2024-06-25 11:39:07 +02:00
pub async fn close(self) {
#info("Closing server");
for task in self.tasks.lock().await.drain(..) {
task.abort();
}
}
2024-06-24 15:34:14 +02:00
pub async fn accept_connections<A: #stream_addr_trait>(
&self,
addr: A,
2024-06-24 22:32:03 +02:00
) -> ::std::result::Result<(), ::std::io::Error> {
2024-06-24 12:26:33 +02:00
#listener_statement
2024-06-24 15:34:14 +02:00
loop {
let (stream, _) = listener.accept().await?;
2024-06-24 15:38:55 +02:00
#info("Accepted connection from {:?}", stream.peer_addr()?);
2024-06-24 15:34:14 +02:00
let self_clone = self.clone();
2024-06-24 22:32:03 +02:00
let run_task = ::tokio::spawn(async move {
2024-06-24 15:34:14 +02:00
self_clone.run(stream).await;
});
self.tasks.lock().await.push(run_task);
}
2024-06-24 12:22:54 +02:00
}
async fn handle(&self, question: #question_enum_name) -> #answer_enum_name {
match question {
#(#server_handler)*
}
}
2024-06-24 12:22:54 +02:00
2024-06-24 15:34:14 +02:00
async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) {
2024-06-24 22:32:03 +02:00
use ::tokio::io::AsyncWriteExt;
2024-06-24 12:22:54 +02:00
let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!");
let len = serialized.len() as u32;
2024-06-24 15:34:14 +02:00
#debug("Sending `{}`", serialized);
stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!");
2024-06-24 12:22:54 +02:00
}
2024-06-24 15:34:14 +02:00
async fn run(&self, mut stream: #stream_type) {
2024-06-24 22:32:03 +02:00
use ::tokio::io::AsyncWriteExt;
use ::tokio::io::AsyncReadExt;
let mut buf = ::std::vec::Vec::with_capacity(1024);
2024-06-24 15:34:14 +02:00
self.send(&mut stream, 0, #answer_enum_name::Ready).await;
2024-06-24 12:22:54 +02:00
loop {
2024-06-24 22:32:03 +02:00
::tokio::select! {
::std::result::Result::Ok(_) = stream.readable() => {
2024-06-24 15:34:14 +02:00
let mut read_buf = [0; 1024];
match stream.try_read(&mut read_buf) {
2024-06-24 22:32:03 +02:00
::std::result::Result::Ok(0) => break, // Stream closed
::std::result::Result::Ok(n) => {
2024-06-24 15:34:14 +02:00
#debug("Received {} bytes (server)", n);
buf.extend_from_slice(&read_buf[..n]);
while buf.len() >= 4 {
2024-06-24 15:34:14 +02:00
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
if buf.len() >= (4 + len as usize) {
2024-06-24 22:32:03 +02:00
let serialized = ::std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
let (nonce, question): (u64, #question_enum_name) = ::ron::de::from_str(serialized).expect("Failed to deserialize query!");
2024-06-24 15:34:14 +02:00
let answer = self.handle(question).await;
self.send(&mut stream, nonce, answer).await;
buf.drain(0..(4 + len as usize));
} else {
break;
}
}
2024-06-24 12:22:54 +02:00
},
2024-06-24 22:32:03 +02:00
::std::result::Result::Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
::std::result::Result::Err(e) => ::std::eprintln!("Error reading from stream: {:?}", e),
2024-06-24 12:22:54 +02:00
}
}
}
}
}
}
};
// Create a struct to hold queries behind an Arc<Mutex<>> to enable async access
// TODO: It might be a good idea to just make this a generic struct and write it in actual code
// rather than in this macro
let queries_struct = quote! {
#[derive(Clone)]
struct #queries_struct_name {
queries: ::std::sync::Arc<::std::sync::Mutex<::std::collections::HashMap<u64, #query_enum_name>>>,
}
impl #queries_struct_name {
fn new() -> Self {
Self {
queries: ::std::sync::Arc::new(::std::sync::Mutex::new(::std::collections::HashMap::new())),
}
}
pub fn insert(&self, nonce: u64, query: #query_enum_name) {
self.queries.lock().unwrap().insert(nonce, query);
}
2024-06-24 22:32:03 +02:00
pub fn get(&self, nonce: &u64) -> ::std::option::Option<#query_enum_name> {
self.queries.lock().unwrap().get(nonce).cloned()
}
pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) {
2024-06-24 22:32:03 +02:00
if let ::std::option::Option::Some(query) = self.queries.lock().unwrap().get_mut(&nonce) {
query.set_answer(answer);
}
}
pub fn len(&self) -> usize {
self.queries.lock().unwrap().len()
}
}
};
// Create a struct to handle the connection from the client to the server
2024-06-21 15:54:48 +02:00
let cc_struct = quote! {
struct #client_connection_struct_name {
2024-06-24 22:32:03 +02:00
to_send: ::tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
received: ::tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
ready: ::std::sync::Arc<tokio::sync::Notify>,
2024-06-21 15:54:48 +02:00
stream: #stream_type,
}
impl #client_connection_struct_name {
pub fn new(
2024-06-24 22:32:03 +02:00
to_send: ::tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
received: ::tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
ready: ::std::sync::Arc<::tokio::sync::Notify>,
2024-06-21 15:54:48 +02:00
stream: #stream_type,
) -> Self {
Self {
to_send,
received,
2024-06-24 15:34:14 +02:00
ready,
2024-06-21 15:54:48 +02:00
stream,
}
}
pub async fn run(mut self) {
2024-06-24 22:32:03 +02:00
use ::tokio::io::AsyncWriteExt;
use ::tokio::io::AsyncReadExt;
let mut buf = ::std::vec::Vec::with_capacity(1024);
2024-06-21 15:54:48 +02:00
loop {
2024-06-24 22:32:03 +02:00
::tokio::select! {
::std::option::Option::Some(msg) = self.to_send.recv() => {
let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!");
2024-06-21 15:54:48 +02:00
let len = serialized.len() as u32;
2024-06-24 15:34:14 +02:00
#debug("Sending `{}`", serialized);
2024-06-21 15:54:48 +02:00
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!");
},
2024-06-24 22:32:03 +02:00
::std::result::Result::Ok(_) = self.stream.readable() => {
2024-06-24 15:34:14 +02:00
let mut read_buf = [0; 1024];
match self.stream.try_read(&mut read_buf) {
2024-06-24 22:32:03 +02:00
::std::result::Result::Ok(0) => { break; },
::std::result::Result::Ok(n) => {
2024-06-24 15:34:14 +02:00
#debug("Received {} bytes (client)", n);
buf.extend_from_slice(&read_buf[..n]);
while buf.len() >= 4 {
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
if buf.len() >= (4 + len as usize) {
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
let response: (u64, #answer_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
if let #answer_enum_name::Ready = response.1 {
#debug("Received ready signal");
self.ready.notify_one();
} else {
self.received.send(response).await.expect("Failed to send response!");
}
buf.drain(0..(4 + len as usize));
} else {
break;
}
}
2024-06-21 15:54:48 +02:00
},
2024-06-24 22:32:03 +02:00
::std::result::Result::Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
::std::result::Result::Err(e) => eprintln!("Error reading from stream: {:?}", e),
2024-06-21 15:54:48 +02:00
}
}
}
}
}
}
};
// Create a struct which the client will use to communicate
let client_recv_queue_wrapper = format_ident!("__{}RecvQueueWrapper", name);
2024-06-19 23:25:45 +02:00
let client_struct = quote! {
2024-06-24 23:40:47 +02:00
#[derive(::std::clone::Clone)]
struct #client_recv_queue_wrapper {
2024-06-24 22:32:03 +02:00
recv_queue: ::std::sync::Arc<::tokio::sync::Mutex<::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>>>,
}
impl #client_recv_queue_wrapper {
2024-06-24 22:32:03 +02:00
fn new(recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self {
Self {
recv_queue: ::std::sync::Arc::new(::tokio::sync::Mutex::new(recv_queue)),
}
}
2024-06-24 22:32:03 +02:00
async fn recv(&self) -> ::std::option::Option<(u64, #answer_enum_name)> {
self.recv_queue.lock().await.recv().await
}
}
#[derive(Clone)]
#vis struct #client_struct_name {
queries: #queries_struct_name,
2024-06-24 22:32:03 +02:00
send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
recv_queue: #client_recv_queue_wrapper,
2024-06-24 15:34:14 +02:00
ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>,
ready_notify: ::std::sync::Arc<tokio::sync::Notify>,
2024-06-24 22:32:03 +02:00
connection_task: ::std::option::Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
2024-06-21 15:54:48 +02:00
}
2024-06-19 23:25:45 +02:00
impl #client_struct_name {
2024-06-24 22:32:03 +02:00
pub fn new(send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
connection_task: ::std::option::Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
2024-06-24 15:34:14 +02:00
ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self {
Self {
queries: #queries_struct_name::new(),
recv_queue: #client_recv_queue_wrapper::new(recv_queue),
2024-06-24 15:34:14 +02:00
ready: ::std::sync::Arc::new(false.into()),
ready_notify,
send_queue,
connection_task,
}
}
2024-06-24 15:34:14 +02:00
pub async fn connect<A: #stream_addr_trait + ::std::fmt::Display>(addr: A) -> Result<Self, std::io::Error> {
#info("Connecting to server at address {}", addr);
let stream = #stream_type::connect(addr).await?;
2024-06-24 22:32:03 +02:00
let (send_queue, to_send) = ::tokio::sync::mpsc::channel(16);
let (to_recv, recv_queue) = ::tokio::sync::mpsc::channel(16);
2024-06-24 15:34:14 +02:00
let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new());
let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream);
2024-06-24 22:32:03 +02:00
let connection_task = ::tokio::spawn(connection.run());
Ok(Self::new(send_queue, recv_queue, ::std::option::Option::Some(::std::sync::Arc::new(connection_task)), ready_notify))
}
2024-06-24 23:40:47 +02:00
pub fn close(&mut self) {
if let ::std::option::Option::Some(task) = self.connection_task.take() {
task.abort();
}
}
2024-06-24 22:32:03 +02:00
async fn send(&self, query: #question_enum_name) -> ::std::result::Result<u64, #error_enum_name> {
2024-06-24 15:34:14 +02:00
// Wait until the connection is ready
if !*self.ready.lock().await {
self.ready_notify.notified().await;
let mut ready = self.ready.lock().await;
*ready = true;
}
let nonce = self.queries.len() as u64;
let res = self.send_queue.send((nonce, query.clone())).await;
2024-06-20 12:52:42 +02:00
match res {
2024-06-24 22:32:03 +02:00
::std::result::Result::Ok(_) => {
self.queries.insert(nonce, query.into());
2024-06-24 22:32:03 +02:00
::std::result::Result::Ok(nonce)
2024-06-20 12:52:42 +02:00
}
2024-06-24 22:32:03 +02:00
::std::result::Result::Err(e) => ::std::result::Result::Err(#error_enum_name::SendError(e)),
2024-06-20 12:52:42 +02:00
}
}
2024-06-24 22:32:03 +02:00
async fn recv_until(&self, id: u64) -> ::std::result::Result<#answer_enum_name, #error_enum_name> {
loop {
// Check if we've received the answer for the query we're looking for
2024-06-24 22:32:03 +02:00
if let ::std::option::Option::Some(query) = self.queries.get(&id) {
if let ::std::option::Option::Some(answer) = query.get_answer() {
2024-06-24 15:34:14 +02:00
#info("Found answer for query {}", id);
return Ok(answer);
}
}
match self.recv_queue.recv().await {
2024-06-24 22:32:03 +02:00
::std::option::Option::Some((nonce, answer)) => {
2024-06-24 15:34:14 +02:00
#info("Received answer for query {}", nonce);
self.queries.set_answer(nonce, answer.clone());
}
2024-06-24 22:32:03 +02:00
::std::option::Option::None => return ::std::result::Result::Err(#error_enum_name::Closed),
};
}
}
2024-06-19 23:25:45 +02:00
#(#client_impl)*
}
2024-06-24 23:40:47 +02:00
impl ::std::ops::Drop for #client_struct_name {
fn drop(&mut self) {
self.close();
}
}
2024-06-19 23:25:45 +02:00
};
let expanded = quote! {
2024-06-20 12:52:42 +02:00
#error_enum
#answer_enum
#question_enum
#query_enum
#queries_struct
2024-06-19 23:25:45 +02:00
#server_trait
#sc_struct
2024-06-21 15:54:48 +02:00
#cc_struct
2024-06-19 23:25:45 +02:00
#client_struct
};
expanded
2024-06-19 23:25:45 +02:00
}
/// Convert a `CamelCase` variant name into the `snake_case` method name used by
/// the generated handler trait and client.
///
/// Every uppercase character after the first is preceded by `_`, and all
/// uppercase characters are lowercased (e.g. `GetUser` -> `get_user`).
/// Acronyms are split per letter (`HTTPGet` -> `h_t_t_p_get`). The result
/// carries the span of `ident`, so diagnostics about the generated method
/// point back at the enum variant.
fn ident_to_snake_case(ident: &Ident) -> Ident {
    let camel = ident.to_string();
    let mut snake = String::with_capacity(camel.len() + 4);
    for (i, c) in camel.chars().enumerate() {
        if c.is_uppercase() {
            if i != 0 {
                snake.push('_');
            }
            // `to_lowercase` can yield several chars (e.g. for 'İ'); keep all of them.
            snake.extend(c.to_lowercase());
        } else {
            snake.push(c);
        }
    }
    // Use the variant identifier's own span, not the span of its string form.
    Ident::new(&snake, ident.span())
}
/// Render a question field as the parameter list of a generated method.
///
/// A tuple type `(A, B, ...)` is spread into `arg0: A, arg1: B, ...`, with each
/// name spanned at its element type. Any other type `T` becomes a single
/// `arg: T`.
fn field_to_args(field: &Field) -> proc_macro2::TokenStream {
    match &field.ty {
        syn::Type::Tuple(tuple) => {
            let params = tuple.elems.iter().enumerate().map(|(index, elem)| {
                let param = Ident::new(&format!("arg{}", index), elem.span());
                quote! { #param: #elem }
            });
            quote! { #( #params ), * }
        }
        other => quote! { arg: #other },
    }
}
2024-06-20 12:52:42 +02:00
/// Build the parenthesised pattern/expression that packs or unpacks a question
/// field inside `Question::variant(...)`.
///
/// For a tuple type this yields `(arg0, arg1, ...,)`. The trailing comma keeps
/// a one-element tuple `(A,)` a real tuple `(arg0,)` rather than the
/// parenthesised value `(arg0)`. The unit type `()` yields `()`. Any other type
/// yields `(arg)`. Names and spans match those produced for the method
/// parameters and the handler call arguments.
fn field_to_tuple_args(field: &Field) -> proc_macro2::TokenStream {
    let type_ = &field.ty;
    if let syn::Type::Tuple(tuple) = type_ {
        let args = tuple
            .elems
            .iter()
            .enumerate()
            .map(|(i, elem)| Ident::new(&format!("arg{}", i), elem.span()));
        quote! { ( #( #args , )* ) }
    } else {
        quote! { (arg) }
    }
}
/// Render the argument list used when the server calls a handler method.
///
/// A tuple type yields `arg0, arg1, ...`, with each name spanned at its element
/// type. Any other type yields the single argument `arg`.
fn field_to_handler_args(field: &Field) -> proc_macro2::TokenStream {
    match &field.ty {
        syn::Type::Tuple(tuple) => {
            let names = tuple
                .elems
                .iter()
                .enumerate()
                .map(|(index, elem)| Ident::new(&format!("arg{}", index), elem.span()));
            quote! { #( #names ), * }
        }
        _ => quote! { arg },
    }
}