Compare commits
2 Commits
49cabbed95
...
cd2cf3346f
Author | SHA1 | Date | |
---|---|---|---|
cd2cf3346f | |||
62262cb0fe |
142
Cargo.lock
generated
142
Cargo.lock
generated
@ -17,6 +17,64 @@ version = "1.0.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aho-corasick"
|
||||||
|
version = "1.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anstream"
|
||||||
|
version = "0.6.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b"
|
||||||
|
dependencies = [
|
||||||
|
"anstyle",
|
||||||
|
"anstyle-parse",
|
||||||
|
"anstyle-query",
|
||||||
|
"anstyle-wincon",
|
||||||
|
"colorchoice",
|
||||||
|
"is_terminal_polyfill",
|
||||||
|
"utf8parse",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anstyle"
|
||||||
|
version = "1.0.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anstyle-parse"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4"
|
||||||
|
dependencies = [
|
||||||
|
"utf8parse",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anstyle-query"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391"
|
||||||
|
dependencies = [
|
||||||
|
"windows-sys 0.52.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anstyle-wincon"
|
||||||
|
version = "3.0.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19"
|
||||||
|
dependencies = [
|
||||||
|
"anstyle",
|
||||||
|
"windows-sys 0.52.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "backtrace"
|
name = "backtrace"
|
||||||
version = "0.3.73"
|
version = "0.3.73"
|
||||||
@ -65,10 +123,18 @@ version = "1.0.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "colorchoice"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "eagle"
|
name = "eagle"
|
||||||
version = "0.2.0"
|
version = "0.2.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"env_logger",
|
||||||
|
"log",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"rand",
|
"rand",
|
||||||
@ -78,6 +144,29 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "env_filter"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea"
|
||||||
|
dependencies = [
|
||||||
|
"log",
|
||||||
|
"regex",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "env_logger"
|
||||||
|
version = "0.11.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9"
|
||||||
|
dependencies = [
|
||||||
|
"anstream",
|
||||||
|
"anstyle",
|
||||||
|
"env_filter",
|
||||||
|
"humantime",
|
||||||
|
"log",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getrandom"
|
name = "getrandom"
|
||||||
version = "0.2.15"
|
version = "0.2.15"
|
||||||
@ -101,12 +190,30 @@ version = "0.3.9"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
|
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "humantime"
|
||||||
|
version = "2.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "is_terminal_polyfill"
|
||||||
|
version = "1.70.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libc"
|
name = "libc"
|
||||||
version = "0.2.155"
|
version = "0.2.155"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "log"
|
||||||
|
version = "0.4.21"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "memchr"
|
name = "memchr"
|
||||||
version = "2.7.4"
|
version = "2.7.4"
|
||||||
@ -212,6 +319,35 @@ dependencies = [
|
|||||||
"getrandom",
|
"getrandom",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex"
|
||||||
|
version = "1.10.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-automata",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-automata"
|
||||||
|
version = "0.4.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-syntax"
|
||||||
|
version = "0.8.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ron"
|
name = "ron"
|
||||||
version = "0.8.1"
|
version = "0.8.1"
|
||||||
@ -305,6 +441,12 @@ version = "1.0.12"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "utf8parse"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wasi"
|
name = "wasi"
|
||||||
version = "0.11.0+wasi-snapshot-preview1"
|
version = "0.11.0+wasi-snapshot-preview1"
|
||||||
|
@ -7,9 +7,10 @@ publish = ["gitea"]
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["tcp"]
|
default = ["tcp", "log"]
|
||||||
tcp = ["tokio/net"]
|
tcp = ["tokio/net"]
|
||||||
unix = ["tokio/net"]
|
unix = ["tokio/net"]
|
||||||
|
log = ["dep:log", "dep:env_logger"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
proc-macro2 = "1.0.85"
|
proc-macro2 = "1.0.85"
|
||||||
@ -19,9 +20,13 @@ ron = "0.8.1"
|
|||||||
serde = { version = "1.0.203", features = ["serde_derive"] }
|
serde = { version = "1.0.203", features = ["serde_derive"] }
|
||||||
syn = "2.0.66"
|
syn = "2.0.66"
|
||||||
tokio = { version = "1.38.0", features = ["sync", "io-util"] }
|
tokio = { version = "1.38.0", features = ["sync", "io-util"] }
|
||||||
|
env_logger = { version = "0.11.3", optional = true }
|
||||||
|
log = { version = "0.4.21", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { version = "1.38.0", features = ["sync", "rt-multi-thread", "macros", "time", "io-util", "net"] }
|
tokio = { version = "1.38.0", features = ["sync", "rt-multi-thread", "macros", "time", "io-util", "net"] }
|
||||||
|
env_logger = "0.11.3"
|
||||||
|
log = "0.4.21"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
proc-macro = true
|
proc-macro = true
|
||||||
|
@ -49,6 +49,7 @@
|
|||||||
};
|
};
|
||||||
devShell = with pkgs; mkShell {
|
devShell = with pkgs; mkShell {
|
||||||
nativeBuildInputs = with pkgs; [rustc cargo rustfmt pre-commit clippy cargo-expand];
|
nativeBuildInputs = with pkgs; [rustc cargo rustfmt pre-commit clippy cargo-expand];
|
||||||
|
RUST_LOG = "debug";
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
}
|
}
|
160
src/lib.rs
160
src/lib.rs
@ -34,6 +34,33 @@ pub fn derive_protocol_derive(input: TokenStream) -> TokenStream {
|
|||||||
|
|
||||||
fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
|
fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
|
||||||
let input = parse2::<DeriveInput>(input).unwrap();
|
let input = parse2::<DeriveInput>(input).unwrap();
|
||||||
|
|
||||||
|
// TODO: These logs should be filterable in some way
|
||||||
|
#[cfg(feature = "log")]
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
let debug = quote! { log::debug! };
|
||||||
|
#[cfg(feature = "log")]
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
let info = quote! { log::info! };
|
||||||
|
#[cfg(feature = "log")]
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
let warn = quote! { log::warn! };
|
||||||
|
#[cfg(feature = "log")]
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
let error = quote! { log::error! };
|
||||||
|
#[cfg(not(feature = "log"))]
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
let debug = quote! { eprintln! };
|
||||||
|
#[cfg(not(feature = "log"))]
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
let info = quote! { eprintln! };
|
||||||
|
#[cfg(not(feature = "log"))]
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
let warn = quote! { eprintln! };
|
||||||
|
#[cfg(not(feature = "log"))]
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
let error = quote! { eprintln! };
|
||||||
|
|
||||||
// Must be on an enum
|
// Must be on an enum
|
||||||
let enum_ = match &input.data {
|
let enum_ = match &input.data {
|
||||||
syn::Data::Enum(e) => e,
|
syn::Data::Enum(e) => e,
|
||||||
@ -102,7 +129,10 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
// There is a function that must be implemented to set the answer in the query enum
|
// There is a function that must be implemented to set the answer in the query enum
|
||||||
query_set_answer.push(quote! {
|
query_set_answer.push(quote! {
|
||||||
#query_enum_name::#var_name(question, answer_opt) => match answer {
|
#query_enum_name::#var_name(question, answer_opt) => match answer {
|
||||||
#answer_enum_name::#var_name(answer) => {answer_opt.replace(answer);},
|
#answer_enum_name::#var_name(answer) => {
|
||||||
|
#debug("Setting answer for query {}", stringify!(#var_name));
|
||||||
|
answer_opt.replace(answer);
|
||||||
|
},
|
||||||
_ => panic!("The answer for this query is not the correct type."),
|
_ => panic!("The answer for this query is not the correct type."),
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@ -116,17 +146,19 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
// There is a function that the server uses to call the appropriate function when receiving a query
|
// There is a function that the server uses to call the appropriate function when receiving a query
|
||||||
server_handler.push(quote! {
|
server_handler.push(quote! {
|
||||||
#question_enum_name::#var_name(#question_tuple_args) => {
|
#question_enum_name::#var_name(#question_tuple_args) => {
|
||||||
|
#info("Received query {}", stringify!(#var_name));
|
||||||
let answer = self.handler.lock().await.#var_name(#question_handler_args).await;
|
let answer = self.handler.lock().await.#var_name(#question_handler_args).await;
|
||||||
return #answer_enum_name::#var_name(answer);
|
return #answer_enum_name::#var_name(answer);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
// The function that the server needs to implement
|
// The function that the server needs to implement
|
||||||
server_trait.push(quote! {
|
server_trait.push(quote! {
|
||||||
async fn #var_name(&mut self, #question_args) -> #answer_type;
|
fn #var_name(&mut self, #question_args) -> impl std::future::Future<Output = #answer_type> + Send;
|
||||||
});
|
});
|
||||||
// The function that the client uses to communicate
|
// The function that the client uses to communicate
|
||||||
client_impl.push(quote! {
|
client_impl.push(quote! {
|
||||||
pub async fn #var_name(&self, #question_args) -> Result<#answer_type, #error_enum_name> {
|
pub async fn #var_name(&self, #question_args) -> Result<#answer_type, #error_enum_name> {
|
||||||
|
#info("Sending query {}", stringify!(#var_name));
|
||||||
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
|
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
|
||||||
let answer = self.recv_until(nonce).await?;
|
let answer = self.recv_until(nonce).await?;
|
||||||
match answer {
|
match answer {
|
||||||
@ -154,7 +186,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
let answer_enum = quote! {
|
let answer_enum = quote! {
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
|
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
|
||||||
#vis enum #answer_enum_name {
|
#vis enum #answer_enum_name {
|
||||||
#(#server_enum), *
|
#(#server_enum), *,
|
||||||
|
Ready
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let question_enum = quote! {
|
let question_enum = quote! {
|
||||||
@ -220,18 +253,42 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
let listener = #listener_type::bind(addr.as_ref())?;
|
let listener = #listener_type::bind(addr.as_ref())?;
|
||||||
};
|
};
|
||||||
let sc_struct = quote! {
|
let sc_struct = quote! {
|
||||||
#vis struct #server_connection_struct_name<H: #server_trait_name> {
|
#[derive(Clone)]
|
||||||
|
#vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + Clone + 'static> {
|
||||||
handler: ::std::sync::Arc<tokio::sync::Mutex<H>>,
|
handler: ::std::sync::Arc<tokio::sync::Mutex<H>>,
|
||||||
stream: #stream_type,
|
tasks: ::std::sync::Arc<tokio::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>>,
|
||||||
}
|
}
|
||||||
impl<H: #server_trait_name> #server_connection_struct_name<H> {
|
impl<H: #server_trait_name + ::std::marker::Send + Clone + 'static> #server_connection_struct_name<H> {
|
||||||
pub async fn bind<S: #stream_addr_trait>(handler: H, addr: S) -> Result<Self, std::io::Error> {
|
pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + std::fmt::Display + 'static>(handler: H, addr: A) -> Self {
|
||||||
|
#info("Binding server to address {}", addr);
|
||||||
|
let handler = ::std::sync::Arc::new(tokio::sync::Mutex::new(handler));
|
||||||
|
let tasks = ::std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
|
||||||
|
let sc = Self {
|
||||||
|
handler,
|
||||||
|
tasks,
|
||||||
|
};
|
||||||
|
let sc_clone = sc.clone();
|
||||||
|
let acc_task = tokio::spawn(async move {
|
||||||
|
sc_clone.accept_connections(addr).await.expect("Failed to accept connections!");
|
||||||
|
});
|
||||||
|
sc.tasks.lock().await.push(acc_task);
|
||||||
|
sc
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn accept_connections<A: #stream_addr_trait>(
|
||||||
|
&self,
|
||||||
|
addr: A,
|
||||||
|
) -> Result<(), std::io::Error> {
|
||||||
#listener_statement
|
#listener_statement
|
||||||
|
loop {
|
||||||
let (stream, _) = listener.accept().await?;
|
let (stream, _) = listener.accept().await?;
|
||||||
Ok(Self {
|
#info("Accepted connection from {}", stream.peer_addr()?);
|
||||||
handler: ::std::sync::Arc::new(tokio::sync::Mutex::new(handler)),
|
let self_clone = self.clone();
|
||||||
stream,
|
let run_task = tokio::spawn(async move {
|
||||||
})
|
self_clone.run(stream).await;
|
||||||
|
});
|
||||||
|
self.tasks.lock().await.push(run_task);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle(&self, question: #question_enum_name) -> #answer_enum_name {
|
async fn handle(&self, question: #question_enum_name) -> #answer_enum_name {
|
||||||
@ -240,32 +297,41 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send(&mut self, nonce: u64, answer: #answer_enum_name) {
|
async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) {
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!");
|
let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!");
|
||||||
let len = serialized.len() as u32;
|
let len = serialized.len() as u32;
|
||||||
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
#debug("Sending `{}`", serialized);
|
||||||
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!");
|
stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
||||||
|
stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!");
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run(mut self) {
|
async fn run(&self, mut stream: #stream_type) {
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
use tokio::io::AsyncReadExt;
|
use tokio::io::AsyncReadExt;
|
||||||
let mut buf = Vec::with_capacity(1024);
|
let mut buf = Vec::with_capacity(1024);
|
||||||
|
self.send(&mut stream, 0, #answer_enum_name::Ready).await;
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
Ok(_) = self.stream.readable() => {
|
Ok(_) = stream.readable() => {
|
||||||
match self.stream.try_read(&mut buf) {
|
let mut read_buf = [0; 1024];
|
||||||
|
match stream.try_read(&mut read_buf) {
|
||||||
Ok(0) => break, // Stream closed
|
Ok(0) => break, // Stream closed
|
||||||
Ok(n) => {
|
Ok(n) => {
|
||||||
// TODO: This doesn't cope with partial reads, we will handle that later
|
#debug("Received {} bytes (server)", n);
|
||||||
|
buf.extend_from_slice(&read_buf[..n]);
|
||||||
|
loop {
|
||||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||||
|
if buf.len() >= (4 + len as usize) {
|
||||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||||
let question: (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
|
let (nonce, question): (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize query!");
|
||||||
// TODO: This should ideally be done in a separate task but that's not
|
let answer = self.handle(question).await;
|
||||||
// necessary for now
|
self.send(&mut stream, nonce, answer).await;
|
||||||
let answer = self.handle(question.1).await;
|
buf.drain(0..(4 + len as usize));
|
||||||
self.send(question.0, answer).await;
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
||||||
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
||||||
@ -317,17 +383,20 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
struct #client_connection_struct_name {
|
struct #client_connection_struct_name {
|
||||||
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
||||||
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
||||||
|
ready: std::sync::Arc<tokio::sync::Notify>,
|
||||||
stream: #stream_type,
|
stream: #stream_type,
|
||||||
}
|
}
|
||||||
impl #client_connection_struct_name {
|
impl #client_connection_struct_name {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
||||||
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
||||||
|
ready: std::sync::Arc<tokio::sync::Notify>,
|
||||||
stream: #stream_type,
|
stream: #stream_type,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
to_send,
|
to_send,
|
||||||
received,
|
received,
|
||||||
|
ready,
|
||||||
stream,
|
stream,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -341,19 +410,33 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
Some(msg) = self.to_send.recv() => {
|
Some(msg) = self.to_send.recv() => {
|
||||||
let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!");
|
let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!");
|
||||||
let len = serialized.len() as u32;
|
let len = serialized.len() as u32;
|
||||||
|
#debug("Sending `{}`", serialized);
|
||||||
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
||||||
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!");
|
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!");
|
||||||
},
|
},
|
||||||
Ok(_) = self.stream.readable() => {
|
Ok(_) = self.stream.readable() => {
|
||||||
match self.stream.try_read(&mut buf) {
|
let mut read_buf = [0; 1024];
|
||||||
Ok(0) => break, // Stream closed
|
match self.stream.try_read(&mut read_buf) {
|
||||||
|
Ok(0) => { break; },
|
||||||
Ok(n) => {
|
Ok(n) => {
|
||||||
// TODO: This doesn't cope with partial reads, we will handle that later
|
#debug("Received {} bytes (client)", n);
|
||||||
|
buf.extend_from_slice(&read_buf[..n]);
|
||||||
|
while buf.len() >= 4 {
|
||||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||||
|
if buf.len() >= (4 + len as usize) {
|
||||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||||
let response: (u64, #answer_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
|
let response: (u64, #answer_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
|
||||||
|
if let #answer_enum_name::Ready = response.1 {
|
||||||
|
#debug("Received ready signal");
|
||||||
|
self.ready.notify_one();
|
||||||
|
} else {
|
||||||
self.received.send(response).await.expect("Failed to send response!");
|
self.received.send(response).await.expect("Failed to send response!");
|
||||||
buf.clear();
|
}
|
||||||
|
buf.drain(0..(4 + len as usize));
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
||||||
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
||||||
@ -386,26 +469,33 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
queries: #queries_struct_name,
|
queries: #queries_struct_name,
|
||||||
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||||
recv_queue: #client_recv_queue_wrapper,
|
recv_queue: #client_recv_queue_wrapper,
|
||||||
|
ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>,
|
||||||
|
ready_notify: ::std::sync::Arc<tokio::sync::Notify>,
|
||||||
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
||||||
}
|
}
|
||||||
impl #client_struct_name {
|
impl #client_struct_name {
|
||||||
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||||
recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
|
recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
|
||||||
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>) -> Self {
|
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
||||||
|
ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
queries: #queries_struct_name::new(),
|
queries: #queries_struct_name::new(),
|
||||||
recv_queue: #client_recv_queue_wrapper::new(recv_queue),
|
recv_queue: #client_recv_queue_wrapper::new(recv_queue),
|
||||||
|
ready: ::std::sync::Arc::new(false.into()),
|
||||||
|
ready_notify,
|
||||||
send_queue,
|
send_queue,
|
||||||
connection_task,
|
connection_task,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub async fn connect<A: #stream_addr_trait>(addr: A) -> Result<Self, std::io::Error> {
|
pub async fn connect<A: #stream_addr_trait + ::std::fmt::Display>(addr: A) -> Result<Self, std::io::Error> {
|
||||||
|
#info("Connecting to server at address {}", addr);
|
||||||
let stream = #stream_type::connect(addr).await?;
|
let stream = #stream_type::connect(addr).await?;
|
||||||
let (send_queue, to_send) = tokio::sync::mpsc::channel(16);
|
let (send_queue, to_send) = tokio::sync::mpsc::channel(16);
|
||||||
let (to_recv, recv_queue) = tokio::sync::mpsc::channel(16);
|
let (to_recv, recv_queue) = tokio::sync::mpsc::channel(16);
|
||||||
let connection = #client_connection_struct_name::new(to_send, to_recv, stream);
|
let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new());
|
||||||
|
let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream);
|
||||||
let connection_task = tokio::spawn(connection.run());
|
let connection_task = tokio::spawn(connection.run());
|
||||||
Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task))))
|
Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task)), ready_notify))
|
||||||
}
|
}
|
||||||
pub fn close(self) {
|
pub fn close(self) {
|
||||||
if let Some(task) = self.connection_task {
|
if let Some(task) = self.connection_task {
|
||||||
@ -413,6 +503,12 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
|
async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
|
||||||
|
// Wait until the connection is ready
|
||||||
|
if !*self.ready.lock().await {
|
||||||
|
self.ready_notify.notified().await;
|
||||||
|
let mut ready = self.ready.lock().await;
|
||||||
|
*ready = true;
|
||||||
|
}
|
||||||
let nonce = self.queries.len() as u64;
|
let nonce = self.queries.len() as u64;
|
||||||
let res = self.send_queue.send((nonce, query.clone())).await;
|
let res = self.send_queue.send((nonce, query.clone())).await;
|
||||||
match res {
|
match res {
|
||||||
@ -428,11 +524,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
// Check if we've received the answer for the query we're looking for
|
// Check if we've received the answer for the query we're looking for
|
||||||
if let Some(query) = self.queries.get(&id) {
|
if let Some(query) = self.queries.get(&id) {
|
||||||
if let Some(answer) = query.get_answer() {
|
if let Some(answer) = query.get_answer() {
|
||||||
|
#info("Found answer for query {}", id);
|
||||||
return Ok(answer);
|
return Ok(answer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
match self.recv_queue.recv().await {
|
match self.recv_queue.recv().await {
|
||||||
Some((nonce, answer)) => {
|
Some((nonce, answer)) => {
|
||||||
|
#info("Received answer for query {}", nonce);
|
||||||
self.queries.set_answer(nonce, answer.clone());
|
self.queries.set_answer(nonce, answer.clone());
|
||||||
}
|
}
|
||||||
None => return Err(#error_enum_name::Closed),
|
None => return Err(#error_enum_name::Closed),
|
||||||
|
135
tests/client.rs
Normal file
135
tests/client.rs
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
/*
|
||||||
|
Eagle - A library for easy communication in full-stack Rust applications
|
||||||
|
Copyright (c) 2024 KodiCraft
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as
|
||||||
|
published by the Free Software Foundation, either version 3 of the
|
||||||
|
License, or (at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
use eagle::Protocol;
|
||||||
|
use env_logger::{Builder, Env};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::Once;
|
||||||
|
use tokio::sync::{
|
||||||
|
mpsc::{self, Receiver, Sender},
|
||||||
|
Notify,
|
||||||
|
};
|
||||||
|
|
||||||
|
static INIT: Once = Once::new();
|
||||||
|
pub fn init_logger() {
|
||||||
|
INIT.call_once(|| {
|
||||||
|
let env = Env::default()
|
||||||
|
.filter_or("RUST_LOG", "info")
|
||||||
|
.write_style_or("LOG_STYLE", "always");
|
||||||
|
|
||||||
|
Builder::from_env(env).format_timestamp_nanos().init();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Protocol)]
|
||||||
|
enum TestProtocol {
|
||||||
|
Addition((i32, i32), i32),
|
||||||
|
SomeKindOfQuestion(String, i32),
|
||||||
|
ThisRespondsWithAString(i32, String),
|
||||||
|
Void((), ()),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn client() {
|
||||||
|
init_logger();
|
||||||
|
let (qtx, qrx) = mpsc::channel(16);
|
||||||
|
let (atx, arx) = mpsc::channel(16);
|
||||||
|
let ready_notify = Arc::new(Notify::new());
|
||||||
|
let client = TestProtocolClient::new(qtx, arx, None, ready_notify.clone());
|
||||||
|
ready_notify.notify_one();
|
||||||
|
let server = tokio::spawn(server_loop(qrx, atx));
|
||||||
|
let result = client.addition(2, 5).await.unwrap();
|
||||||
|
assert_eq!(result, 7);
|
||||||
|
let result = client
|
||||||
|
.some_kind_of_question("Hello, world!".to_string())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "Hello, world!".len() as i32);
|
||||||
|
let result = client.this_responds_with_a_string(42).await.unwrap();
|
||||||
|
assert_eq!(result, "The number is 42");
|
||||||
|
client.void().await.unwrap();
|
||||||
|
server.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn server_loop(
|
||||||
|
mut qrx: Receiver<(u64, __TestProtocolQuestion)>,
|
||||||
|
atx: Sender<(u64, __TestProtocolAnswer)>,
|
||||||
|
) {
|
||||||
|
loop {
|
||||||
|
if let Some((nonce, q)) = qrx.recv().await {
|
||||||
|
match q {
|
||||||
|
__TestProtocolQuestion::addition((a, b)) => {
|
||||||
|
atx.send((nonce, __TestProtocolAnswer::addition(a + b)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
__TestProtocolQuestion::some_kind_of_question(s) => {
|
||||||
|
atx.send((
|
||||||
|
nonce,
|
||||||
|
__TestProtocolAnswer::some_kind_of_question(s.len() as i32),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
__TestProtocolQuestion::this_responds_with_a_string(i) => {
|
||||||
|
atx.send((
|
||||||
|
nonce,
|
||||||
|
__TestProtocolAnswer::this_responds_with_a_string(format!(
|
||||||
|
"The number is {}",
|
||||||
|
i
|
||||||
|
)),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
__TestProtocolQuestion::void(()) => {
|
||||||
|
println!("Received void question");
|
||||||
|
atx.send((nonce, __TestProtocolAnswer::void(())))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn heavy_async() {
|
||||||
|
init_logger();
|
||||||
|
let (qtx, qrx) = mpsc::channel(16);
|
||||||
|
let (atx, arx) = mpsc::channel(16);
|
||||||
|
let ready_notify = Arc::new(Notify::new());
|
||||||
|
let client = TestProtocolClient::new(qtx, arx, None, ready_notify.clone());
|
||||||
|
ready_notify.notify_one();
|
||||||
|
let server = tokio::spawn(server_loop(qrx, atx));
|
||||||
|
let mut tasks = Vec::new();
|
||||||
|
for i in 0..100 {
|
||||||
|
let client = client.clone();
|
||||||
|
tasks.push(tokio::spawn(async move {
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(
|
||||||
|
rand::random::<u64>() % 100,
|
||||||
|
))
|
||||||
|
.await;
|
||||||
|
let result = client.addition(i, i).await.unwrap();
|
||||||
|
assert_eq!(result, i + i);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
for task in tasks {
|
||||||
|
task.await.unwrap();
|
||||||
|
}
|
||||||
|
server.abort();
|
||||||
|
}
|
81
tests/full.rs
Normal file
81
tests/full.rs
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
/*
|
||||||
|
Eagle - A library for easy communication in full-stack Rust applications
|
||||||
|
Copyright (c) 2024 KodiCraft
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as
|
||||||
|
published by the Free Software Foundation, either version 3 of the
|
||||||
|
License, or (at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
use eagle::Protocol;
|
||||||
|
use env_logger::{Builder, Env};
|
||||||
|
use std::sync::Once;
|
||||||
|
|
||||||
|
static INIT: Once = Once::new();
|
||||||
|
pub fn init_logger() {
|
||||||
|
INIT.call_once(|| {
|
||||||
|
let env = Env::default()
|
||||||
|
.filter_or("RUST_LOG", "info")
|
||||||
|
.write_style_or("LOG_STYLE", "always");
|
||||||
|
|
||||||
|
Builder::from_env(env).format_timestamp_nanos().init();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Protocol)]
|
||||||
|
enum TestProtocol {
|
||||||
|
Addition((i32, i32), i32),
|
||||||
|
SomeKindOfQuestion(String, i32),
|
||||||
|
ThisRespondsWithAString(i32, String),
|
||||||
|
Void((), ()),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct TrivialServer;
|
||||||
|
impl TestProtocolServerTrait for TrivialServer {
|
||||||
|
async fn addition(&mut self, a: i32, b: i32) -> i32 {
|
||||||
|
a + b
|
||||||
|
}
|
||||||
|
async fn some_kind_of_question(&mut self, s: String) -> i32 {
|
||||||
|
s.len() as i32
|
||||||
|
}
|
||||||
|
async fn this_responds_with_a_string(&mut self, i: i32) -> String {
|
||||||
|
format!("The number is {}", i)
|
||||||
|
}
|
||||||
|
async fn void(&mut self) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e() {
|
||||||
|
init_logger();
|
||||||
|
#[cfg(feature = "unix")]
|
||||||
|
let address = "/tmp/eagle-test.sock";
|
||||||
|
#[cfg(feature = "tcp")]
|
||||||
|
let address = format!("127.0.0.1:{}", 10000 + rand::random::<u64>() % 1000);
|
||||||
|
let server_task = tokio::spawn(TestProtocolServer::bind(TrivialServer, address.clone()));
|
||||||
|
// Wait for the server to start
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||||
|
let client = TestProtocolClient::connect(address).await.unwrap();
|
||||||
|
let res = client.addition(2, 5).await.unwrap();
|
||||||
|
// assert_eq!(client.addition(2, 5).await.unwrap(), 7);
|
||||||
|
// assert_eq!(
|
||||||
|
// client.some_kind_of_question("Hello, world!".to_string())
|
||||||
|
// .await
|
||||||
|
// .unwrap(),
|
||||||
|
// "Hello, world!".len() as i32
|
||||||
|
// );
|
||||||
|
// assert_eq!(
|
||||||
|
// client.this_responds_with_a_string(42).await.unwrap(),
|
||||||
|
// "The number is 42"
|
||||||
|
// );
|
||||||
|
// client.void().await.unwrap();
|
||||||
|
// server_task.abort();
|
||||||
|
}
|
113
tests/mod.rs
113
tests/mod.rs
@ -1,112 +1 @@
|
|||||||
/*
|
mod client;
|
||||||
Eagle - A library for easy communication in full-stack Rust applications
|
|
||||||
Copyright (c) 2024 KodiCraft
|
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
|
||||||
it under the terms of the GNU Affero General Public License as
|
|
||||||
published by the Free Software Foundation, either version 3 of the
|
|
||||||
License, or (at your option) any later version.
|
|
||||||
|
|
||||||
This program is distributed in the hope that it will be useful,
|
|
||||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
GNU Affero General Public License for more details.
|
|
||||||
|
|
||||||
You should have received a copy of the GNU Affero General Public License
|
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
*/
|
|
||||||
use eagle::Protocol;
|
|
||||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
|
||||||
|
|
||||||
#[derive(Protocol)]
|
|
||||||
enum TestProtocol {
|
|
||||||
Addition((i32, i32), i32),
|
|
||||||
SomeKindOfQuestion(String, i32),
|
|
||||||
ThisRespondsWithAString(i32, String),
|
|
||||||
Void((), ()),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn main() {
|
|
||||||
let (qtx, qrx) = mpsc::channel(16);
|
|
||||||
let (atx, arx) = mpsc::channel(16);
|
|
||||||
let client = TestProtocolClient::new(qtx, arx, None);
|
|
||||||
let server = tokio::spawn(server_loop(qrx, atx));
|
|
||||||
let result = client.addition(2, 5).await.unwrap();
|
|
||||||
assert_eq!(result, 7);
|
|
||||||
let result = client
|
|
||||||
.some_kind_of_question("Hello, world!".to_string())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(result, "Hello, world!".len() as i32);
|
|
||||||
let result = client.this_responds_with_a_string(42).await.unwrap();
|
|
||||||
assert_eq!(result, "The number is 42");
|
|
||||||
client.void().await.unwrap();
|
|
||||||
server.abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn server_loop(
|
|
||||||
mut qrx: Receiver<(u64, __TestProtocolQuestion)>,
|
|
||||||
atx: Sender<(u64, __TestProtocolAnswer)>,
|
|
||||||
) {
|
|
||||||
loop {
|
|
||||||
if let Some((nonce, q)) = qrx.recv().await {
|
|
||||||
match q {
|
|
||||||
__TestProtocolQuestion::addition((a, b)) => {
|
|
||||||
atx.send((nonce, __TestProtocolAnswer::addition(a + b)))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
__TestProtocolQuestion::some_kind_of_question(s) => {
|
|
||||||
atx.send((
|
|
||||||
nonce,
|
|
||||||
__TestProtocolAnswer::some_kind_of_question(s.len() as i32),
|
|
||||||
))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
__TestProtocolQuestion::this_responds_with_a_string(i) => {
|
|
||||||
atx.send((
|
|
||||||
nonce,
|
|
||||||
__TestProtocolAnswer::this_responds_with_a_string(format!(
|
|
||||||
"The number is {}",
|
|
||||||
i
|
|
||||||
)),
|
|
||||||
))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
__TestProtocolQuestion::void(()) => {
|
|
||||||
println!("Received void question");
|
|
||||||
atx.send((nonce, __TestProtocolAnswer::void(())))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn heavy_async() {
|
|
||||||
let (qtx, qrx) = mpsc::channel(16);
|
|
||||||
let (atx, arx) = mpsc::channel(16);
|
|
||||||
let client = TestProtocolClient::new(qtx, arx, None);
|
|
||||||
let server = tokio::spawn(server_loop(qrx, atx));
|
|
||||||
let mut tasks = Vec::new();
|
|
||||||
for i in 0..100 {
|
|
||||||
let client = client.clone();
|
|
||||||
tasks.push(tokio::spawn(async move {
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(
|
|
||||||
rand::random::<u64>() % 100,
|
|
||||||
))
|
|
||||||
.await;
|
|
||||||
let result = client.addition(i, i).await.unwrap();
|
|
||||||
assert_eq!(result, i + i);
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
for task in tasks {
|
|
||||||
task.await.unwrap();
|
|
||||||
}
|
|
||||||
server.abort();
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user