Compare commits

...

2 Commits

Author SHA1 Message Date
cd2cf3346f
Minor tweak in this awful pile of code
Some checks failed
Build library & run tests / build (tcp) (push) Successful in 51s
Build library & run tests / build (unix) (push) Failing after 54s
2024-06-24 15:36:29 +02:00
62262cb0fe
Fully implement communication 2024-06-24 15:34:14 +02:00
7 changed files with 502 additions and 151 deletions

142
Cargo.lock generated
View File

@ -17,6 +17,64 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]]
name = "anstream"
version = "0.6.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b"
[[package]]
name = "anstyle-parse"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19"
dependencies = [
"anstyle",
"windows-sys 0.52.0",
]
[[package]]
name = "backtrace"
version = "0.3.73"
@ -65,10 +123,18 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "colorchoice"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
[[package]]
name = "eagle"
version = "0.2.0"
dependencies = [
"env_logger",
"log",
"proc-macro2",
"quote",
"rand",
@ -78,6 +144,29 @@ dependencies = [
"tokio",
]
[[package]]
name = "env_filter"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea"
dependencies = [
"log",
"regex",
]
[[package]]
name = "env_logger"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9"
dependencies = [
"anstream",
"anstyle",
"env_filter",
"humantime",
"log",
]
[[package]]
name = "getrandom"
version = "0.2.15"
@ -101,12 +190,30 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800"
[[package]]
name = "libc"
version = "0.2.155"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]]
name = "log"
version = "0.4.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
[[package]]
name = "memchr"
version = "2.7.4"
@ -212,6 +319,35 @@ dependencies = [
"getrandom",
]
[[package]]
name = "regex"
version = "1.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]]
name = "ron"
version = "0.8.1"
@ -305,6 +441,12 @@ version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"

View File

@ -7,9 +7,10 @@ publish = ["gitea"]
resolver = "2"
[features]
default = ["tcp"]
default = ["tcp", "log"]
tcp = ["tokio/net"]
unix = ["tokio/net"]
log = ["dep:log", "dep:env_logger"]
[dependencies]
proc-macro2 = "1.0.85"
@ -19,9 +20,13 @@ ron = "0.8.1"
serde = { version = "1.0.203", features = ["serde_derive"] }
syn = "2.0.66"
tokio = { version = "1.38.0", features = ["sync", "io-util"] }
env_logger = { version = "0.11.3", optional = true }
log = { version = "0.4.21", optional = true }
[dev-dependencies]
tokio = { version = "1.38.0", features = ["sync", "rt-multi-thread", "macros", "time", "io-util", "net"] }
env_logger = "0.11.3"
log = "0.4.21"
[lib]
proc-macro = true

View File

@ -49,6 +49,7 @@
};
devShell = with pkgs; mkShell {
nativeBuildInputs = with pkgs; [rustc cargo rustfmt pre-commit clippy cargo-expand];
RUST_LOG = "debug";
};
});
}

View File

@ -34,6 +34,33 @@ pub fn derive_protocol_derive(input: TokenStream) -> TokenStream {
fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let input = parse2::<DeriveInput>(input).unwrap();
// TODO: These logs should be filterable in some way
#[cfg(feature = "log")]
#[allow(unused_variables)]
let debug = quote! { log::debug! };
#[cfg(feature = "log")]
#[allow(unused_variables)]
let info = quote! { log::info! };
#[cfg(feature = "log")]
#[allow(unused_variables)]
let warn = quote! { log::warn! };
#[cfg(feature = "log")]
#[allow(unused_variables)]
let error = quote! { log::error! };
#[cfg(not(feature = "log"))]
#[allow(unused_variables)]
let debug = quote! { eprintln! };
#[cfg(not(feature = "log"))]
#[allow(unused_variables)]
let info = quote! { eprintln! };
#[cfg(not(feature = "log"))]
#[allow(unused_variables)]
let warn = quote! { eprintln! };
#[cfg(not(feature = "log"))]
#[allow(unused_variables)]
let error = quote! { eprintln! };
// Must be on an enum
let enum_ = match &input.data {
syn::Data::Enum(e) => e,
@ -102,7 +129,10 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
// There is a function that must be implemented to set the answer in the query enum
query_set_answer.push(quote! {
#query_enum_name::#var_name(question, answer_opt) => match answer {
#answer_enum_name::#var_name(answer) => {answer_opt.replace(answer);},
#answer_enum_name::#var_name(answer) => {
#debug("Setting answer for query {}", stringify!(#var_name));
answer_opt.replace(answer);
},
_ => panic!("The answer for this query is not the correct type."),
},
});
@ -116,17 +146,19 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
// There is a function that the server uses to call the appropriate function when receiving a query
server_handler.push(quote! {
#question_enum_name::#var_name(#question_tuple_args) => {
#info("Received query {}", stringify!(#var_name));
let answer = self.handler.lock().await.#var_name(#question_handler_args).await;
return #answer_enum_name::#var_name(answer);
},
});
// The function that the server needs to implement
server_trait.push(quote! {
async fn #var_name(&mut self, #question_args) -> #answer_type;
fn #var_name(&mut self, #question_args) -> impl std::future::Future<Output = #answer_type> + Send;
});
// The function that the client uses to communicate
client_impl.push(quote! {
pub async fn #var_name(&self, #question_args) -> Result<#answer_type, #error_enum_name> {
#info("Sending query {}", stringify!(#var_name));
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
let answer = self.recv_until(nonce).await?;
match answer {
@ -154,7 +186,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
let answer_enum = quote! {
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
#vis enum #answer_enum_name {
#(#server_enum), *
#(#server_enum), *,
Ready
}
};
let question_enum = quote! {
@ -220,18 +253,42 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
let listener = #listener_type::bind(addr.as_ref())?;
};
let sc_struct = quote! {
#vis struct #server_connection_struct_name<H: #server_trait_name> {
#[derive(Clone)]
#vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + Clone + 'static> {
handler: ::std::sync::Arc<tokio::sync::Mutex<H>>,
stream: #stream_type,
tasks: ::std::sync::Arc<tokio::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>>,
}
impl<H: #server_trait_name> #server_connection_struct_name<H> {
pub async fn bind<S: #stream_addr_trait>(handler: H, addr: S) -> Result<Self, std::io::Error> {
impl<H: #server_trait_name + ::std::marker::Send + Clone + 'static> #server_connection_struct_name<H> {
pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + std::fmt::Display + 'static>(handler: H, addr: A) -> Self {
#info("Binding server to address {}", addr);
let handler = ::std::sync::Arc::new(tokio::sync::Mutex::new(handler));
let tasks = ::std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
let sc = Self {
handler,
tasks,
};
let sc_clone = sc.clone();
let acc_task = tokio::spawn(async move {
sc_clone.accept_connections(addr).await.expect("Failed to accept connections!");
});
sc.tasks.lock().await.push(acc_task);
sc
}
pub async fn accept_connections<A: #stream_addr_trait>(
&self,
addr: A,
) -> Result<(), std::io::Error> {
#listener_statement
loop {
let (stream, _) = listener.accept().await?;
Ok(Self {
handler: ::std::sync::Arc::new(tokio::sync::Mutex::new(handler)),
stream,
})
#info("Accepted connection from {}", stream.peer_addr()?);
let self_clone = self.clone();
let run_task = tokio::spawn(async move {
self_clone.run(stream).await;
});
self.tasks.lock().await.push(run_task);
}
}
async fn handle(&self, question: #question_enum_name) -> #answer_enum_name {
@ -240,32 +297,41 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
}
}
async fn send(&mut self, nonce: u64, answer: #answer_enum_name) {
async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) {
use tokio::io::AsyncWriteExt;
let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!");
let len = serialized.len() as u32;
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!");
#debug("Sending `{}`", serialized);
stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!");
}
async fn run(mut self) {
async fn run(&self, mut stream: #stream_type) {
use tokio::io::AsyncWriteExt;
use tokio::io::AsyncReadExt;
let mut buf = Vec::with_capacity(1024);
self.send(&mut stream, 0, #answer_enum_name::Ready).await;
loop {
tokio::select! {
Ok(_) = self.stream.readable() => {
match self.stream.try_read(&mut buf) {
Ok(_) = stream.readable() => {
let mut read_buf = [0; 1024];
match stream.try_read(&mut read_buf) {
Ok(0) => break, // Stream closed
Ok(n) => {
// TODO: This doesn't cope with partial reads, we will handle that later
#debug("Received {} bytes (server)", n);
buf.extend_from_slice(&read_buf[..n]);
loop {
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
if buf.len() >= (4 + len as usize) {
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
let question: (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
// TODO: This should ideally be done in a separate task but that's not
// necessary for now
let answer = self.handle(question.1).await;
self.send(question.0, answer).await;
let (nonce, question): (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize query!");
let answer = self.handle(question).await;
self.send(&mut stream, nonce, answer).await;
buf.drain(0..(4 + len as usize));
} else {
break;
}
}
},
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
Err(e) => eprintln!("Error reading from stream: {:?}", e),
@ -317,17 +383,20 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
struct #client_connection_struct_name {
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
ready: std::sync::Arc<tokio::sync::Notify>,
stream: #stream_type,
}
impl #client_connection_struct_name {
pub fn new(
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
ready: std::sync::Arc<tokio::sync::Notify>,
stream: #stream_type,
) -> Self {
Self {
to_send,
received,
ready,
stream,
}
}
@ -341,19 +410,33 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
Some(msg) = self.to_send.recv() => {
let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!");
let len = serialized.len() as u32;
#debug("Sending `{}`", serialized);
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!");
},
Ok(_) = self.stream.readable() => {
match self.stream.try_read(&mut buf) {
Ok(0) => break, // Stream closed
let mut read_buf = [0; 1024];
match self.stream.try_read(&mut read_buf) {
Ok(0) => { break; },
Ok(n) => {
// TODO: This doesn't cope with partial reads, we will handle that later
#debug("Received {} bytes (client)", n);
buf.extend_from_slice(&read_buf[..n]);
while buf.len() >= 4 {
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
if buf.len() >= (4 + len as usize) {
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
let response: (u64, #answer_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
if let #answer_enum_name::Ready = response.1 {
#debug("Received ready signal");
self.ready.notify_one();
} else {
self.received.send(response).await.expect("Failed to send response!");
buf.clear();
}
buf.drain(0..(4 + len as usize));
} else {
break;
}
}
},
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
Err(e) => eprintln!("Error reading from stream: {:?}", e),
@ -386,26 +469,33 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
queries: #queries_struct_name,
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
recv_queue: #client_recv_queue_wrapper,
ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>,
ready_notify: ::std::sync::Arc<tokio::sync::Notify>,
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
}
impl #client_struct_name {
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>) -> Self {
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self {
Self {
queries: #queries_struct_name::new(),
recv_queue: #client_recv_queue_wrapper::new(recv_queue),
ready: ::std::sync::Arc::new(false.into()),
ready_notify,
send_queue,
connection_task,
}
}
pub async fn connect<A: #stream_addr_trait>(addr: A) -> Result<Self, std::io::Error> {
pub async fn connect<A: #stream_addr_trait + ::std::fmt::Display>(addr: A) -> Result<Self, std::io::Error> {
#info("Connecting to server at address {}", addr);
let stream = #stream_type::connect(addr).await?;
let (send_queue, to_send) = tokio::sync::mpsc::channel(16);
let (to_recv, recv_queue) = tokio::sync::mpsc::channel(16);
let connection = #client_connection_struct_name::new(to_send, to_recv, stream);
let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new());
let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream);
let connection_task = tokio::spawn(connection.run());
Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task))))
Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task)), ready_notify))
}
pub fn close(self) {
if let Some(task) = self.connection_task {
@ -413,6 +503,12 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
}
}
async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
// Wait until the connection is ready
if !*self.ready.lock().await {
self.ready_notify.notified().await;
let mut ready = self.ready.lock().await;
*ready = true;
}
let nonce = self.queries.len() as u64;
let res = self.send_queue.send((nonce, query.clone())).await;
match res {
@ -428,11 +524,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
// Check if we've received the answer for the query we're looking for
if let Some(query) = self.queries.get(&id) {
if let Some(answer) = query.get_answer() {
#info("Found answer for query {}", id);
return Ok(answer);
}
}
match self.recv_queue.recv().await {
Some((nonce, answer)) => {
#info("Received answer for query {}", nonce);
self.queries.set_answer(nonce, answer.clone());
}
None => return Err(#error_enum_name::Closed),

135
tests/client.rs Normal file
View File

@ -0,0 +1,135 @@
/*
Eagle - A library for easy communication in full-stack Rust applications
Copyright (c) 2024 KodiCraft
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
use eagle::Protocol;
use env_logger::{Builder, Env};
use std::sync::Arc;
use std::sync::Once;
use tokio::sync::{
mpsc::{self, Receiver, Sender},
Notify,
};
static INIT: Once = Once::new();
pub fn init_logger() {
INIT.call_once(|| {
let env = Env::default()
.filter_or("RUST_LOG", "info")
.write_style_or("LOG_STYLE", "always");
Builder::from_env(env).format_timestamp_nanos().init();
});
}
#[derive(Protocol)]
enum TestProtocol {
Addition((i32, i32), i32),
SomeKindOfQuestion(String, i32),
ThisRespondsWithAString(i32, String),
Void((), ()),
}
#[tokio::test]
async fn client() {
init_logger();
let (qtx, qrx) = mpsc::channel(16);
let (atx, arx) = mpsc::channel(16);
let ready_notify = Arc::new(Notify::new());
let client = TestProtocolClient::new(qtx, arx, None, ready_notify.clone());
ready_notify.notify_one();
let server = tokio::spawn(server_loop(qrx, atx));
let result = client.addition(2, 5).await.unwrap();
assert_eq!(result, 7);
let result = client
.some_kind_of_question("Hello, world!".to_string())
.await
.unwrap();
assert_eq!(result, "Hello, world!".len() as i32);
let result = client.this_responds_with_a_string(42).await.unwrap();
assert_eq!(result, "The number is 42");
client.void().await.unwrap();
server.abort();
}
async fn server_loop(
mut qrx: Receiver<(u64, __TestProtocolQuestion)>,
atx: Sender<(u64, __TestProtocolAnswer)>,
) {
loop {
if let Some((nonce, q)) = qrx.recv().await {
match q {
__TestProtocolQuestion::addition((a, b)) => {
atx.send((nonce, __TestProtocolAnswer::addition(a + b)))
.await
.unwrap();
}
__TestProtocolQuestion::some_kind_of_question(s) => {
atx.send((
nonce,
__TestProtocolAnswer::some_kind_of_question(s.len() as i32),
))
.await
.unwrap();
}
__TestProtocolQuestion::this_responds_with_a_string(i) => {
atx.send((
nonce,
__TestProtocolAnswer::this_responds_with_a_string(format!(
"The number is {}",
i
)),
))
.await
.unwrap();
}
__TestProtocolQuestion::void(()) => {
println!("Received void question");
atx.send((nonce, __TestProtocolAnswer::void(())))
.await
.unwrap();
}
}
}
}
}
#[tokio::test]
async fn heavy_async() {
init_logger();
let (qtx, qrx) = mpsc::channel(16);
let (atx, arx) = mpsc::channel(16);
let ready_notify = Arc::new(Notify::new());
let client = TestProtocolClient::new(qtx, arx, None, ready_notify.clone());
ready_notify.notify_one();
let server = tokio::spawn(server_loop(qrx, atx));
let mut tasks = Vec::new();
for i in 0..100 {
let client = client.clone();
tasks.push(tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(
rand::random::<u64>() % 100,
))
.await;
let result = client.addition(i, i).await.unwrap();
assert_eq!(result, i + i);
}));
}
for task in tasks {
task.await.unwrap();
}
server.abort();
}

81
tests/full.rs Normal file
View File

@ -0,0 +1,81 @@
/*
Eagle - A library for easy communication in full-stack Rust applications
Copyright (c) 2024 KodiCraft
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
use eagle::Protocol;
use env_logger::{Builder, Env};
use std::sync::Once;
static INIT: Once = Once::new();
pub fn init_logger() {
INIT.call_once(|| {
let env = Env::default()
.filter_or("RUST_LOG", "info")
.write_style_or("LOG_STYLE", "always");
Builder::from_env(env).format_timestamp_nanos().init();
});
}
#[derive(Protocol)]
enum TestProtocol {
Addition((i32, i32), i32),
SomeKindOfQuestion(String, i32),
ThisRespondsWithAString(i32, String),
Void((), ()),
}
#[derive(Clone)]
struct TrivialServer;
impl TestProtocolServerTrait for TrivialServer {
async fn addition(&mut self, a: i32, b: i32) -> i32 {
a + b
}
async fn some_kind_of_question(&mut self, s: String) -> i32 {
s.len() as i32
}
async fn this_responds_with_a_string(&mut self, i: i32) -> String {
format!("The number is {}", i)
}
async fn void(&mut self) {}
}
#[tokio::test]
async fn e2e() {
init_logger();
#[cfg(feature = "unix")]
let address = "/tmp/eagle-test.sock";
#[cfg(feature = "tcp")]
let address = format!("127.0.0.1:{}", 10000 + rand::random::<u64>() % 1000);
let server_task = tokio::spawn(TestProtocolServer::bind(TrivialServer, address.clone()));
// Wait for the server to start
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
let client = TestProtocolClient::connect(address).await.unwrap();
let res = client.addition(2, 5).await.unwrap();
// assert_eq!(client.addition(2, 5).await.unwrap(), 7);
// assert_eq!(
// client.some_kind_of_question("Hello, world!".to_string())
// .await
// .unwrap(),
// "Hello, world!".len() as i32
// );
// assert_eq!(
// client.this_responds_with_a_string(42).await.unwrap(),
// "The number is 42"
// );
// client.void().await.unwrap();
// server_task.abort();
}

View File

@ -1,112 +1 @@
/*
Eagle - A library for easy communication in full-stack Rust applications
Copyright (c) 2024 KodiCraft
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
use eagle::Protocol;
use tokio::sync::mpsc::{self, Receiver, Sender};
#[derive(Protocol)]
enum TestProtocol {
Addition((i32, i32), i32),
SomeKindOfQuestion(String, i32),
ThisRespondsWithAString(i32, String),
Void((), ()),
}
#[tokio::test]
async fn main() {
let (qtx, qrx) = mpsc::channel(16);
let (atx, arx) = mpsc::channel(16);
let client = TestProtocolClient::new(qtx, arx, None);
let server = tokio::spawn(server_loop(qrx, atx));
let result = client.addition(2, 5).await.unwrap();
assert_eq!(result, 7);
let result = client
.some_kind_of_question("Hello, world!".to_string())
.await
.unwrap();
assert_eq!(result, "Hello, world!".len() as i32);
let result = client.this_responds_with_a_string(42).await.unwrap();
assert_eq!(result, "The number is 42");
client.void().await.unwrap();
server.abort();
}
async fn server_loop(
mut qrx: Receiver<(u64, __TestProtocolQuestion)>,
atx: Sender<(u64, __TestProtocolAnswer)>,
) {
loop {
if let Some((nonce, q)) = qrx.recv().await {
match q {
__TestProtocolQuestion::addition((a, b)) => {
atx.send((nonce, __TestProtocolAnswer::addition(a + b)))
.await
.unwrap();
}
__TestProtocolQuestion::some_kind_of_question(s) => {
atx.send((
nonce,
__TestProtocolAnswer::some_kind_of_question(s.len() as i32),
))
.await
.unwrap();
}
__TestProtocolQuestion::this_responds_with_a_string(i) => {
atx.send((
nonce,
__TestProtocolAnswer::this_responds_with_a_string(format!(
"The number is {}",
i
)),
))
.await
.unwrap();
}
__TestProtocolQuestion::void(()) => {
println!("Received void question");
atx.send((nonce, __TestProtocolAnswer::void(())))
.await
.unwrap();
}
}
}
}
}
#[tokio::test]
async fn heavy_async() {
let (qtx, qrx) = mpsc::channel(16);
let (atx, arx) = mpsc::channel(16);
let client = TestProtocolClient::new(qtx, arx, None);
let server = tokio::spawn(server_loop(qrx, atx));
let mut tasks = Vec::new();
for i in 0..100 {
let client = client.clone();
tasks.push(tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(
rand::random::<u64>() % 100,
))
.await;
let result = client.addition(i, i).await.unwrap();
assert_eq!(result, i + i);
}));
}
for task in tasks {
task.await.unwrap();
}
server.abort();
}
mod client;