Compare commits
2 Commits
49cabbed95
...
cd2cf3346f
Author | SHA1 | Date | |
---|---|---|---|
cd2cf3346f | |||
62262cb0fe |
142
Cargo.lock
generated
142
Cargo.lock
generated
@ -17,6 +17,64 @@ version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"anstyle-parse",
|
||||
"anstyle-query",
|
||||
"anstyle-wincon",
|
||||
"colorchoice",
|
||||
"is_terminal_polyfill",
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b"
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-parse"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4"
|
||||
dependencies = [
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-query"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391"
|
||||
dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-wincon"
|
||||
version = "3.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.73"
|
||||
@ -65,10 +123,18 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
|
||||
|
||||
[[package]]
|
||||
name = "eagle"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"env_logger",
|
||||
"log",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rand",
|
||||
@ -78,6 +144,29 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_filter"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea"
|
||||
dependencies = [
|
||||
"log",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"env_filter",
|
||||
"humantime",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.15"
|
||||
@ -101,12 +190,30 @@ version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||
|
||||
[[package]]
|
||||
name = "is_terminal_polyfill"
|
||||
version = "1.70.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.155"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.4"
|
||||
@ -212,6 +319,35 @@ dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.10.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-automata",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
|
||||
|
||||
[[package]]
|
||||
name = "ron"
|
||||
version = "0.8.1"
|
||||
@ -305,6 +441,12 @@ version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
||||
|
||||
[[package]]
|
||||
name = "utf8parse"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.0+wasi-snapshot-preview1"
|
||||
|
@ -7,9 +7,10 @@ publish = ["gitea"]
|
||||
resolver = "2"
|
||||
|
||||
[features]
|
||||
default = ["tcp"]
|
||||
default = ["tcp", "log"]
|
||||
tcp = ["tokio/net"]
|
||||
unix = ["tokio/net"]
|
||||
log = ["dep:log", "dep:env_logger"]
|
||||
|
||||
[dependencies]
|
||||
proc-macro2 = "1.0.85"
|
||||
@ -19,9 +20,13 @@ ron = "0.8.1"
|
||||
serde = { version = "1.0.203", features = ["serde_derive"] }
|
||||
syn = "2.0.66"
|
||||
tokio = { version = "1.38.0", features = ["sync", "io-util"] }
|
||||
env_logger = { version = "0.11.3", optional = true }
|
||||
log = { version = "0.4.21", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1.38.0", features = ["sync", "rt-multi-thread", "macros", "time", "io-util", "net"] }
|
||||
env_logger = "0.11.3"
|
||||
log = "0.4.21"
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
@ -49,6 +49,7 @@
|
||||
};
|
||||
devShell = with pkgs; mkShell {
|
||||
nativeBuildInputs = with pkgs; [rustc cargo rustfmt pre-commit clippy cargo-expand];
|
||||
RUST_LOG = "debug";
|
||||
};
|
||||
});
|
||||
}
|
174
src/lib.rs
174
src/lib.rs
@ -34,6 +34,33 @@ pub fn derive_protocol_derive(input: TokenStream) -> TokenStream {
|
||||
|
||||
fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
|
||||
let input = parse2::<DeriveInput>(input).unwrap();
|
||||
|
||||
// TODO: These logs should be filterable in some way
|
||||
#[cfg(feature = "log")]
|
||||
#[allow(unused_variables)]
|
||||
let debug = quote! { log::debug! };
|
||||
#[cfg(feature = "log")]
|
||||
#[allow(unused_variables)]
|
||||
let info = quote! { log::info! };
|
||||
#[cfg(feature = "log")]
|
||||
#[allow(unused_variables)]
|
||||
let warn = quote! { log::warn! };
|
||||
#[cfg(feature = "log")]
|
||||
#[allow(unused_variables)]
|
||||
let error = quote! { log::error! };
|
||||
#[cfg(not(feature = "log"))]
|
||||
#[allow(unused_variables)]
|
||||
let debug = quote! { eprintln! };
|
||||
#[cfg(not(feature = "log"))]
|
||||
#[allow(unused_variables)]
|
||||
let info = quote! { eprintln! };
|
||||
#[cfg(not(feature = "log"))]
|
||||
#[allow(unused_variables)]
|
||||
let warn = quote! { eprintln! };
|
||||
#[cfg(not(feature = "log"))]
|
||||
#[allow(unused_variables)]
|
||||
let error = quote! { eprintln! };
|
||||
|
||||
// Must be on an enum
|
||||
let enum_ = match &input.data {
|
||||
syn::Data::Enum(e) => e,
|
||||
@ -102,7 +129,10 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
// There is a function that must be implemented to set the answer in the query enum
|
||||
query_set_answer.push(quote! {
|
||||
#query_enum_name::#var_name(question, answer_opt) => match answer {
|
||||
#answer_enum_name::#var_name(answer) => {answer_opt.replace(answer);},
|
||||
#answer_enum_name::#var_name(answer) => {
|
||||
#debug("Setting answer for query {}", stringify!(#var_name));
|
||||
answer_opt.replace(answer);
|
||||
},
|
||||
_ => panic!("The answer for this query is not the correct type."),
|
||||
},
|
||||
});
|
||||
@ -116,17 +146,19 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
// There is a function that the server uses to call the appropriate function when receiving a query
|
||||
server_handler.push(quote! {
|
||||
#question_enum_name::#var_name(#question_tuple_args) => {
|
||||
#info("Received query {}", stringify!(#var_name));
|
||||
let answer = self.handler.lock().await.#var_name(#question_handler_args).await;
|
||||
return #answer_enum_name::#var_name(answer);
|
||||
},
|
||||
});
|
||||
// The function that the server needs to implement
|
||||
server_trait.push(quote! {
|
||||
async fn #var_name(&mut self, #question_args) -> #answer_type;
|
||||
fn #var_name(&mut self, #question_args) -> impl std::future::Future<Output = #answer_type> + Send;
|
||||
});
|
||||
// The function that the client uses to communicate
|
||||
client_impl.push(quote! {
|
||||
pub async fn #var_name(&self, #question_args) -> Result<#answer_type, #error_enum_name> {
|
||||
#info("Sending query {}", stringify!(#var_name));
|
||||
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
|
||||
let answer = self.recv_until(nonce).await?;
|
||||
match answer {
|
||||
@ -154,7 +186,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
let answer_enum = quote! {
|
||||
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
|
||||
#vis enum #answer_enum_name {
|
||||
#(#server_enum), *
|
||||
#(#server_enum), *,
|
||||
Ready
|
||||
}
|
||||
};
|
||||
let question_enum = quote! {
|
||||
@ -220,18 +253,42 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
let listener = #listener_type::bind(addr.as_ref())?;
|
||||
};
|
||||
let sc_struct = quote! {
|
||||
#vis struct #server_connection_struct_name<H: #server_trait_name> {
|
||||
#[derive(Clone)]
|
||||
#vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + Clone + 'static> {
|
||||
handler: ::std::sync::Arc<tokio::sync::Mutex<H>>,
|
||||
stream: #stream_type,
|
||||
tasks: ::std::sync::Arc<tokio::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>>,
|
||||
}
|
||||
impl<H: #server_trait_name> #server_connection_struct_name<H> {
|
||||
pub async fn bind<S: #stream_addr_trait>(handler: H, addr: S) -> Result<Self, std::io::Error> {
|
||||
impl<H: #server_trait_name + ::std::marker::Send + Clone + 'static> #server_connection_struct_name<H> {
|
||||
pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + std::fmt::Display + 'static>(handler: H, addr: A) -> Self {
|
||||
#info("Binding server to address {}", addr);
|
||||
let handler = ::std::sync::Arc::new(tokio::sync::Mutex::new(handler));
|
||||
let tasks = ::std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
|
||||
let sc = Self {
|
||||
handler,
|
||||
tasks,
|
||||
};
|
||||
let sc_clone = sc.clone();
|
||||
let acc_task = tokio::spawn(async move {
|
||||
sc_clone.accept_connections(addr).await.expect("Failed to accept connections!");
|
||||
});
|
||||
sc.tasks.lock().await.push(acc_task);
|
||||
sc
|
||||
}
|
||||
|
||||
pub async fn accept_connections<A: #stream_addr_trait>(
|
||||
&self,
|
||||
addr: A,
|
||||
) -> Result<(), std::io::Error> {
|
||||
#listener_statement
|
||||
let (stream, _) = listener.accept().await?;
|
||||
Ok(Self {
|
||||
handler: ::std::sync::Arc::new(tokio::sync::Mutex::new(handler)),
|
||||
stream,
|
||||
})
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
#info("Accepted connection from {}", stream.peer_addr()?);
|
||||
let self_clone = self.clone();
|
||||
let run_task = tokio::spawn(async move {
|
||||
self_clone.run(stream).await;
|
||||
});
|
||||
self.tasks.lock().await.push(run_task);
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle(&self, question: #question_enum_name) -> #answer_enum_name {
|
||||
@ -240,32 +297,41 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
}
|
||||
}
|
||||
|
||||
async fn send(&mut self, nonce: u64, answer: #answer_enum_name) {
|
||||
async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!");
|
||||
let len = serialized.len() as u32;
|
||||
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
||||
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!");
|
||||
#debug("Sending `{}`", serialized);
|
||||
stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
||||
stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!");
|
||||
}
|
||||
|
||||
async fn run(mut self) {
|
||||
async fn run(&self, mut stream: #stream_type) {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::AsyncReadExt;
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
self.send(&mut stream, 0, #answer_enum_name::Ready).await;
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok(_) = self.stream.readable() => {
|
||||
match self.stream.try_read(&mut buf) {
|
||||
Ok(_) = stream.readable() => {
|
||||
let mut read_buf = [0; 1024];
|
||||
match stream.try_read(&mut read_buf) {
|
||||
Ok(0) => break, // Stream closed
|
||||
Ok(n) => {
|
||||
// TODO: This doesn't cope with partial reads, we will handle that later
|
||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||
let question: (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
|
||||
// TODO: This should ideally be done in a separate task but that's not
|
||||
// necessary for now
|
||||
let answer = self.handle(question.1).await;
|
||||
self.send(question.0, answer).await;
|
||||
#debug("Received {} bytes (server)", n);
|
||||
buf.extend_from_slice(&read_buf[..n]);
|
||||
loop {
|
||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||
if buf.len() >= (4 + len as usize) {
|
||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||
let (nonce, question): (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize query!");
|
||||
let answer = self.handle(question).await;
|
||||
self.send(&mut stream, nonce, answer).await;
|
||||
buf.drain(0..(4 + len as usize));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
||||
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
||||
@ -317,17 +383,20 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
struct #client_connection_struct_name {
|
||||
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
||||
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
||||
ready: std::sync::Arc<tokio::sync::Notify>,
|
||||
stream: #stream_type,
|
||||
}
|
||||
impl #client_connection_struct_name {
|
||||
pub fn new(
|
||||
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
||||
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
||||
ready: std::sync::Arc<tokio::sync::Notify>,
|
||||
stream: #stream_type,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_send,
|
||||
received,
|
||||
ready,
|
||||
stream,
|
||||
}
|
||||
}
|
||||
@ -341,19 +410,33 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
Some(msg) = self.to_send.recv() => {
|
||||
let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!");
|
||||
let len = serialized.len() as u32;
|
||||
#debug("Sending `{}`", serialized);
|
||||
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
||||
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!");
|
||||
},
|
||||
Ok(_) = self.stream.readable() => {
|
||||
match self.stream.try_read(&mut buf) {
|
||||
Ok(0) => break, // Stream closed
|
||||
let mut read_buf = [0; 1024];
|
||||
match self.stream.try_read(&mut read_buf) {
|
||||
Ok(0) => { break; },
|
||||
Ok(n) => {
|
||||
// TODO: This doesn't cope with partial reads, we will handle that later
|
||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||
let response: (u64, #answer_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
|
||||
self.received.send(response).await.expect("Failed to send response!");
|
||||
buf.clear();
|
||||
#debug("Received {} bytes (client)", n);
|
||||
buf.extend_from_slice(&read_buf[..n]);
|
||||
while buf.len() >= 4 {
|
||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||
if buf.len() >= (4 + len as usize) {
|
||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||
let response: (u64, #answer_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
|
||||
if let #answer_enum_name::Ready = response.1 {
|
||||
#debug("Received ready signal");
|
||||
self.ready.notify_one();
|
||||
} else {
|
||||
self.received.send(response).await.expect("Failed to send response!");
|
||||
}
|
||||
buf.drain(0..(4 + len as usize));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
||||
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
||||
@ -386,26 +469,33 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
queries: #queries_struct_name,
|
||||
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||
recv_queue: #client_recv_queue_wrapper,
|
||||
ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>,
|
||||
ready_notify: ::std::sync::Arc<tokio::sync::Notify>,
|
||||
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
||||
}
|
||||
impl #client_struct_name {
|
||||
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||
recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
|
||||
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>) -> Self {
|
||||
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
||||
ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self {
|
||||
Self {
|
||||
queries: #queries_struct_name::new(),
|
||||
recv_queue: #client_recv_queue_wrapper::new(recv_queue),
|
||||
ready: ::std::sync::Arc::new(false.into()),
|
||||
ready_notify,
|
||||
send_queue,
|
||||
connection_task,
|
||||
}
|
||||
}
|
||||
pub async fn connect<A: #stream_addr_trait>(addr: A) -> Result<Self, std::io::Error> {
|
||||
pub async fn connect<A: #stream_addr_trait + ::std::fmt::Display>(addr: A) -> Result<Self, std::io::Error> {
|
||||
#info("Connecting to server at address {}", addr);
|
||||
let stream = #stream_type::connect(addr).await?;
|
||||
let (send_queue, to_send) = tokio::sync::mpsc::channel(16);
|
||||
let (to_recv, recv_queue) = tokio::sync::mpsc::channel(16);
|
||||
let connection = #client_connection_struct_name::new(to_send, to_recv, stream);
|
||||
let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream);
|
||||
let connection_task = tokio::spawn(connection.run());
|
||||
Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task))))
|
||||
Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task)), ready_notify))
|
||||
}
|
||||
pub fn close(self) {
|
||||
if let Some(task) = self.connection_task {
|
||||
@ -413,6 +503,12 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
}
|
||||
}
|
||||
async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
|
||||
// Wait until the connection is ready
|
||||
if !*self.ready.lock().await {
|
||||
self.ready_notify.notified().await;
|
||||
let mut ready = self.ready.lock().await;
|
||||
*ready = true;
|
||||
}
|
||||
let nonce = self.queries.len() as u64;
|
||||
let res = self.send_queue.send((nonce, query.clone())).await;
|
||||
match res {
|
||||
@ -428,11 +524,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
// Check if we've received the answer for the query we're looking for
|
||||
if let Some(query) = self.queries.get(&id) {
|
||||
if let Some(answer) = query.get_answer() {
|
||||
#info("Found answer for query {}", id);
|
||||
return Ok(answer);
|
||||
}
|
||||
}
|
||||
match self.recv_queue.recv().await {
|
||||
Some((nonce, answer)) => {
|
||||
#info("Received answer for query {}", nonce);
|
||||
self.queries.set_answer(nonce, answer.clone());
|
||||
}
|
||||
None => return Err(#error_enum_name::Closed),
|
||||
|
135
tests/client.rs
Normal file
135
tests/client.rs
Normal file
@ -0,0 +1,135 @@
|
||||
/*
|
||||
Eagle - A library for easy communication in full-stack Rust applications
|
||||
Copyright (c) 2024 KodiCraft
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
use eagle::Protocol;
|
||||
use env_logger::{Builder, Env};
|
||||
use std::sync::Arc;
|
||||
use std::sync::Once;
|
||||
use tokio::sync::{
|
||||
mpsc::{self, Receiver, Sender},
|
||||
Notify,
|
||||
};
|
||||
|
||||
static INIT: Once = Once::new();
|
||||
pub fn init_logger() {
|
||||
INIT.call_once(|| {
|
||||
let env = Env::default()
|
||||
.filter_or("RUST_LOG", "info")
|
||||
.write_style_or("LOG_STYLE", "always");
|
||||
|
||||
Builder::from_env(env).format_timestamp_nanos().init();
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(Protocol)]
|
||||
enum TestProtocol {
|
||||
Addition((i32, i32), i32),
|
||||
SomeKindOfQuestion(String, i32),
|
||||
ThisRespondsWithAString(i32, String),
|
||||
Void((), ()),
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client() {
|
||||
init_logger();
|
||||
let (qtx, qrx) = mpsc::channel(16);
|
||||
let (atx, arx) = mpsc::channel(16);
|
||||
let ready_notify = Arc::new(Notify::new());
|
||||
let client = TestProtocolClient::new(qtx, arx, None, ready_notify.clone());
|
||||
ready_notify.notify_one();
|
||||
let server = tokio::spawn(server_loop(qrx, atx));
|
||||
let result = client.addition(2, 5).await.unwrap();
|
||||
assert_eq!(result, 7);
|
||||
let result = client
|
||||
.some_kind_of_question("Hello, world!".to_string())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "Hello, world!".len() as i32);
|
||||
let result = client.this_responds_with_a_string(42).await.unwrap();
|
||||
assert_eq!(result, "The number is 42");
|
||||
client.void().await.unwrap();
|
||||
server.abort();
|
||||
}
|
||||
|
||||
async fn server_loop(
|
||||
mut qrx: Receiver<(u64, __TestProtocolQuestion)>,
|
||||
atx: Sender<(u64, __TestProtocolAnswer)>,
|
||||
) {
|
||||
loop {
|
||||
if let Some((nonce, q)) = qrx.recv().await {
|
||||
match q {
|
||||
__TestProtocolQuestion::addition((a, b)) => {
|
||||
atx.send((nonce, __TestProtocolAnswer::addition(a + b)))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
__TestProtocolQuestion::some_kind_of_question(s) => {
|
||||
atx.send((
|
||||
nonce,
|
||||
__TestProtocolAnswer::some_kind_of_question(s.len() as i32),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
__TestProtocolQuestion::this_responds_with_a_string(i) => {
|
||||
atx.send((
|
||||
nonce,
|
||||
__TestProtocolAnswer::this_responds_with_a_string(format!(
|
||||
"The number is {}",
|
||||
i
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
__TestProtocolQuestion::void(()) => {
|
||||
println!("Received void question");
|
||||
atx.send((nonce, __TestProtocolAnswer::void(())))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn heavy_async() {
|
||||
init_logger();
|
||||
let (qtx, qrx) = mpsc::channel(16);
|
||||
let (atx, arx) = mpsc::channel(16);
|
||||
let ready_notify = Arc::new(Notify::new());
|
||||
let client = TestProtocolClient::new(qtx, arx, None, ready_notify.clone());
|
||||
ready_notify.notify_one();
|
||||
let server = tokio::spawn(server_loop(qrx, atx));
|
||||
let mut tasks = Vec::new();
|
||||
for i in 0..100 {
|
||||
let client = client.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(
|
||||
rand::random::<u64>() % 100,
|
||||
))
|
||||
.await;
|
||||
let result = client.addition(i, i).await.unwrap();
|
||||
assert_eq!(result, i + i);
|
||||
}));
|
||||
}
|
||||
for task in tasks {
|
||||
task.await.unwrap();
|
||||
}
|
||||
server.abort();
|
||||
}
|
81
tests/full.rs
Normal file
81
tests/full.rs
Normal file
@ -0,0 +1,81 @@
|
||||
/*
|
||||
Eagle - A library for easy communication in full-stack Rust applications
|
||||
Copyright (c) 2024 KodiCraft
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
use eagle::Protocol;
|
||||
use env_logger::{Builder, Env};
|
||||
use std::sync::Once;
|
||||
|
||||
static INIT: Once = Once::new();
|
||||
pub fn init_logger() {
|
||||
INIT.call_once(|| {
|
||||
let env = Env::default()
|
||||
.filter_or("RUST_LOG", "info")
|
||||
.write_style_or("LOG_STYLE", "always");
|
||||
|
||||
Builder::from_env(env).format_timestamp_nanos().init();
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(Protocol)]
|
||||
enum TestProtocol {
|
||||
Addition((i32, i32), i32),
|
||||
SomeKindOfQuestion(String, i32),
|
||||
ThisRespondsWithAString(i32, String),
|
||||
Void((), ()),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TrivialServer;
|
||||
impl TestProtocolServerTrait for TrivialServer {
|
||||
async fn addition(&mut self, a: i32, b: i32) -> i32 {
|
||||
a + b
|
||||
}
|
||||
async fn some_kind_of_question(&mut self, s: String) -> i32 {
|
||||
s.len() as i32
|
||||
}
|
||||
async fn this_responds_with_a_string(&mut self, i: i32) -> String {
|
||||
format!("The number is {}", i)
|
||||
}
|
||||
async fn void(&mut self) {}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn e2e() {
|
||||
init_logger();
|
||||
#[cfg(feature = "unix")]
|
||||
let address = "/tmp/eagle-test.sock";
|
||||
#[cfg(feature = "tcp")]
|
||||
let address = format!("127.0.0.1:{}", 10000 + rand::random::<u64>() % 1000);
|
||||
let server_task = tokio::spawn(TestProtocolServer::bind(TrivialServer, address.clone()));
|
||||
// Wait for the server to start
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
let client = TestProtocolClient::connect(address).await.unwrap();
|
||||
let res = client.addition(2, 5).await.unwrap();
|
||||
// assert_eq!(client.addition(2, 5).await.unwrap(), 7);
|
||||
// assert_eq!(
|
||||
// client.some_kind_of_question("Hello, world!".to_string())
|
||||
// .await
|
||||
// .unwrap(),
|
||||
// "Hello, world!".len() as i32
|
||||
// );
|
||||
// assert_eq!(
|
||||
// client.this_responds_with_a_string(42).await.unwrap(),
|
||||
// "The number is 42"
|
||||
// );
|
||||
// client.void().await.unwrap();
|
||||
// server_task.abort();
|
||||
}
|
113
tests/mod.rs
113
tests/mod.rs
@ -1,112 +1 @@
|
||||
/*
|
||||
Eagle - A library for easy communication in full-stack Rust applications
|
||||
Copyright (c) 2024 KodiCraft
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
use eagle::Protocol;
|
||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
||||
|
||||
#[derive(Protocol)]
|
||||
enum TestProtocol {
|
||||
Addition((i32, i32), i32),
|
||||
SomeKindOfQuestion(String, i32),
|
||||
ThisRespondsWithAString(i32, String),
|
||||
Void((), ()),
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn main() {
|
||||
let (qtx, qrx) = mpsc::channel(16);
|
||||
let (atx, arx) = mpsc::channel(16);
|
||||
let client = TestProtocolClient::new(qtx, arx, None);
|
||||
let server = tokio::spawn(server_loop(qrx, atx));
|
||||
let result = client.addition(2, 5).await.unwrap();
|
||||
assert_eq!(result, 7);
|
||||
let result = client
|
||||
.some_kind_of_question("Hello, world!".to_string())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "Hello, world!".len() as i32);
|
||||
let result = client.this_responds_with_a_string(42).await.unwrap();
|
||||
assert_eq!(result, "The number is 42");
|
||||
client.void().await.unwrap();
|
||||
server.abort();
|
||||
}
|
||||
|
||||
async fn server_loop(
|
||||
mut qrx: Receiver<(u64, __TestProtocolQuestion)>,
|
||||
atx: Sender<(u64, __TestProtocolAnswer)>,
|
||||
) {
|
||||
loop {
|
||||
if let Some((nonce, q)) = qrx.recv().await {
|
||||
match q {
|
||||
__TestProtocolQuestion::addition((a, b)) => {
|
||||
atx.send((nonce, __TestProtocolAnswer::addition(a + b)))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
__TestProtocolQuestion::some_kind_of_question(s) => {
|
||||
atx.send((
|
||||
nonce,
|
||||
__TestProtocolAnswer::some_kind_of_question(s.len() as i32),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
__TestProtocolQuestion::this_responds_with_a_string(i) => {
|
||||
atx.send((
|
||||
nonce,
|
||||
__TestProtocolAnswer::this_responds_with_a_string(format!(
|
||||
"The number is {}",
|
||||
i
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
__TestProtocolQuestion::void(()) => {
|
||||
println!("Received void question");
|
||||
atx.send((nonce, __TestProtocolAnswer::void(())))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn heavy_async() {
|
||||
let (qtx, qrx) = mpsc::channel(16);
|
||||
let (atx, arx) = mpsc::channel(16);
|
||||
let client = TestProtocolClient::new(qtx, arx, None);
|
||||
let server = tokio::spawn(server_loop(qrx, atx));
|
||||
let mut tasks = Vec::new();
|
||||
for i in 0..100 {
|
||||
let client = client.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(
|
||||
rand::random::<u64>() % 100,
|
||||
))
|
||||
.await;
|
||||
let result = client.addition(i, i).await.unwrap();
|
||||
assert_eq!(result, i + i);
|
||||
}));
|
||||
}
|
||||
for task in tasks {
|
||||
task.await.unwrap();
|
||||
}
|
||||
server.abort();
|
||||
}
|
||||
mod client;
|
||||
|
Loading…
Reference in New Issue
Block a user