Sanitization improvements
Some checks failed
Build library & run tests / build (unix) (push) Successful in 19s
Build library & run tests / build (tcp) (push) Successful in 20s
Build library & run tests / docs (push) Successful in 23s
Publish library / publish (push) Failing after 26s

This commit is contained in:
Kodi Craft 2024-06-24 22:32:03 +02:00
parent 912b69ef93
commit f4d65a2c51
Signed by: kodi
GPG Key ID: 69D9EED60B242822
3 changed files with 88 additions and 88 deletions

2
Cargo.lock generated
View File

@ -153,7 +153,7 @@ checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
[[package]] [[package]]
name = "eagle" name = "eagle"
version = "0.2.3" version = "0.2.4"
dependencies = [ dependencies = [
"env_logger", "env_logger",
"log", "log",

View File

@ -1,6 +1,6 @@
[package] [package]
name = "eagle" name = "eagle"
version = "0.2.3" version = "0.2.4"
description = "A simple library for creating RPC protocols." description = "A simple library for creating RPC protocols."
repository = "https://git.colon-three.com/kodi/eagle" repository = "https://git.colon-three.com/kodi/eagle"
authors = ["KodiCraft <kodi@kdcf.me>"] authors = ["KodiCraft <kodi@kdcf.me>"]

View File

@ -198,28 +198,28 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
// TODO: These logs should be filterable in some way // TODO: These logs should be filterable in some way
#[cfg(feature = "log")] #[cfg(feature = "log")]
#[allow(unused_variables)] #[allow(unused_variables)]
let debug = quote! { log::debug! }; let debug = quote! { ::log::debug! };
#[cfg(feature = "log")] #[cfg(feature = "log")]
#[allow(unused_variables)] #[allow(unused_variables)]
let info = quote! { log::info! }; let info = quote! { ::log::info! };
#[cfg(feature = "log")] #[cfg(feature = "log")]
#[allow(unused_variables)] #[allow(unused_variables)]
let warn = quote! { log::warn! }; let warn = quote! { ::log::warn! };
#[cfg(feature = "log")] #[cfg(feature = "log")]
#[allow(unused_variables)] #[allow(unused_variables)]
let error = quote! { log::error! }; let error = quote! { ::log::error! };
#[cfg(not(feature = "log"))] #[cfg(not(feature = "log"))]
#[allow(unused_variables)] #[allow(unused_variables)]
let debug = quote! { eprintln! }; let debug = quote! { ::std::eprintln! };
#[cfg(not(feature = "log"))] #[cfg(not(feature = "log"))]
#[allow(unused_variables)] #[allow(unused_variables)]
let info = quote! { eprintln! }; let info = quote! { ::std::eprintln! };
#[cfg(not(feature = "log"))] #[cfg(not(feature = "log"))]
#[allow(unused_variables)] #[allow(unused_variables)]
let warn = quote! { eprintln! }; let warn = quote! { ::std::eprintln! };
#[cfg(not(feature = "log"))] #[cfg(not(feature = "log"))]
#[allow(unused_variables)] #[allow(unused_variables)]
let error = quote! { eprintln! }; let error = quote! { ::std::eprintln! };
// Must be on an enum // Must be on an enum
let enum_ = match &input.data { let enum_ = match &input.data {
@ -299,8 +299,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
// There is a function that must be implemented to get the answer from the query enum // There is a function that must be implemented to get the answer from the query enum
query_get_answer.push(quote! { query_get_answer.push(quote! {
#query_enum_name::#var_name(_, answer) => match answer { #query_enum_name::#var_name(_, answer) => match answer {
Some(answer) => Some(#answer_enum_name::#var_name(answer.clone())), ::std::option::Option::Some(answer) => ::std::option::Option::Some(#answer_enum_name::#var_name(answer.clone())),
None => None ::std::option::Option::None => ::std::option::Option::None
}, },
}); });
// There is a function that the server uses to call the appropriate function when receiving a query // There is a function that the server uses to call the appropriate function when receiving a query
@ -313,7 +313,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
}); });
// The function that the server needs to implement // The function that the server needs to implement
server_trait.push(quote! { server_trait.push(quote! {
fn #var_name(&mut self, #question_args) -> impl std::future::Future<Output = #answer_type> + Send; fn #var_name(&mut self, #question_args) -> impl ::std::future::Future<Output = #answer_type> + Send;
}); });
// The function that the client uses to communicate // The function that the client uses to communicate
client_impl.push(quote! { client_impl.push(quote! {
@ -323,13 +323,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
let answer = self.recv_until(nonce).await?; let answer = self.recv_until(nonce).await?;
match answer { match answer {
#answer_enum_name::#var_name(answer) => Ok(answer), #answer_enum_name::#var_name(answer) => Ok(answer),
_ => panic!("The answer for this query is not the correct type."), _ => ::std::panic!("The answer for this query is not the correct type."),
} }
} }
}); });
// The query enum is the same as the source enum, but the second field is always wrapped in a Option<> // The query enum is the same as the source enum, but the second field is always wrapped in a Option<>
query_enum.push(quote! { query_enum.push(quote! {
#var_name(#question_field, Option<#answer_type>) #var_name(#question_field, ::std::option::Option<#answer_type>)
}); });
} }
@ -337,28 +337,28 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
let error_enum = quote! { let error_enum = quote! {
#[derive(Debug)] #[derive(Debug)]
#vis enum #error_enum_name { #vis enum #error_enum_name {
SendError(tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>), SendError(::tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>),
Closed, Closed,
} }
}; };
// Create enums for the types of messages the server and client will use // Create enums for the types of messages the server and client will use
let answer_enum = quote! { let answer_enum = quote! {
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] #[derive(::serde::Serialize, ::serde::Deserialize, ::std::clone::Clone, ::std::fmt::Debug)]
#vis enum #answer_enum_name { #vis enum #answer_enum_name {
#(#server_enum), *, #(#server_enum), *,
Ready Ready
} }
}; };
let question_enum = quote! { let question_enum = quote! {
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] #[derive(::serde::Serialize, ::serde::Deserialize, ::std::clone::Clone, ::std::fmt::Debug)]
#vis enum #question_enum_name { #vis enum #question_enum_name {
#(#client_enum), * #(#client_enum), *
} }
}; };
// Create an enum to represent the queries the client has sent // Create an enum to represent the queries the client has sent
let query_enum = quote! { let query_enum = quote! {
#[derive(Clone, Debug)] #[derive(::std::clone::Clone, ::std::fmt::Debug)]
#vis enum #query_enum_name { #vis enum #query_enum_name {
#(#query_enum), * #(#query_enum), *
} }
@ -368,7 +368,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
#(#query_set_answer)* #(#query_set_answer)*
}; };
} }
pub fn get_answer(&self) -> Option<#answer_enum_name> { pub fn get_answer(&self) -> ::std::option::Option<#answer_enum_name> {
match self { match self {
#(#query_get_answer)* #(#query_get_answer)*
} }
@ -384,17 +384,17 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
}; };
#[cfg(feature = "tcp")] #[cfg(feature = "tcp")]
let stream_type = quote! { tokio::net::TcpStream }; let stream_type = quote! { ::tokio::net::TcpStream };
#[cfg(feature = "tcp")] #[cfg(feature = "tcp")]
let stream_addr_trait = quote! { tokio::net::ToSocketAddrs }; let stream_addr_trait = quote! { ::tokio::net::ToSocketAddrs };
#[cfg(feature = "tcp")] #[cfg(feature = "tcp")]
let listener_type = quote! { tokio::net::TcpListener }; let listener_type = quote! { ::tokio::net::TcpListener };
#[cfg(feature = "unix")] #[cfg(feature = "unix")]
let stream_type = quote! { tokio::net::UnixStream }; let stream_type = quote! { ::tokio::net::UnixStream };
#[cfg(feature = "unix")] #[cfg(feature = "unix")]
let stream_addr_trait = quote! { std::convert::AsRef<std::path::Path> }; let stream_addr_trait = quote! { ::std::convert::AsRef<std::path::Path> };
#[cfg(feature = "unix")] #[cfg(feature = "unix")]
let listener_type = quote! { tokio::net::UnixListener }; let listener_type = quote! { ::tokio::net::UnixListener };
// Create a trait which the server will have to implement // Create a trait which the server will have to implement
let server_trait = quote! { let server_trait = quote! {
@ -414,21 +414,21 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
}; };
let sc_struct = quote! { let sc_struct = quote! {
#[derive(Clone)] #[derive(Clone)]
#vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + Clone + 'static> { #vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + ::std::clone::Clone + 'static> {
handler: ::std::sync::Arc<tokio::sync::Mutex<H>>, handler: ::std::sync::Arc<::tokio::sync::Mutex<H>>,
tasks: ::std::sync::Arc<tokio::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>>, tasks: ::std::sync::Arc<::tokio::sync::Mutex<::std::vec::Vec<tokio::task::JoinHandle<()>>>>,
} }
impl<H: #server_trait_name + ::std::marker::Send + Clone + 'static> #server_connection_struct_name<H> { impl<H: #server_trait_name + ::std::marker::Send + std::clone::Clone + 'static> #server_connection_struct_name<H> {
pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + std::fmt::Display + 'static>(handler: H, addr: A) -> Self { pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + ::std::fmt::Display + 'static>(handler: H, addr: A) -> Self {
#info("Binding server to address {}", addr); #info("Binding server to address {}", addr);
let handler = ::std::sync::Arc::new(tokio::sync::Mutex::new(handler)); let handler = ::std::sync::Arc::new(::tokio::sync::Mutex::new(handler));
let tasks = ::std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())); let tasks = ::std::sync::Arc::new(::tokio::sync::Mutex::new(::std::vec::Vec::new()));
let sc = Self { let sc = Self {
handler, handler,
tasks, tasks,
}; };
let sc_clone = sc.clone(); let sc_clone = sc.clone();
let acc_task = tokio::spawn(async move { let acc_task = ::tokio::spawn(async move {
sc_clone.accept_connections(addr).await.expect("Failed to accept connections!"); sc_clone.accept_connections(addr).await.expect("Failed to accept connections!");
}); });
sc.tasks.lock().await.push(acc_task); sc.tasks.lock().await.push(acc_task);
@ -438,13 +438,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
pub async fn accept_connections<A: #stream_addr_trait>( pub async fn accept_connections<A: #stream_addr_trait>(
&self, &self,
addr: A, addr: A,
) -> Result<(), std::io::Error> { ) -> ::std::result::Result<(), ::std::io::Error> {
#listener_statement #listener_statement
loop { loop {
let (stream, _) = listener.accept().await?; let (stream, _) = listener.accept().await?;
#info("Accepted connection from {:?}", stream.peer_addr()?); #info("Accepted connection from {:?}", stream.peer_addr()?);
let self_clone = self.clone(); let self_clone = self.clone();
let run_task = tokio::spawn(async move { let run_task = ::tokio::spawn(async move {
self_clone.run(stream).await; self_clone.run(stream).await;
}); });
self.tasks.lock().await.push(run_task); self.tasks.lock().await.push(run_task);
@ -458,7 +458,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
} }
async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) { async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) {
use tokio::io::AsyncWriteExt; use ::tokio::io::AsyncWriteExt;
let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!"); let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!");
let len = serialized.len() as u32; let len = serialized.len() as u32;
#debug("Sending `{}`", serialized); #debug("Sending `{}`", serialized);
@ -467,24 +467,24 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
} }
async fn run(&self, mut stream: #stream_type) { async fn run(&self, mut stream: #stream_type) {
use tokio::io::AsyncWriteExt; use ::tokio::io::AsyncWriteExt;
use tokio::io::AsyncReadExt; use ::tokio::io::AsyncReadExt;
let mut buf = Vec::with_capacity(1024); let mut buf = ::std::vec::Vec::with_capacity(1024);
self.send(&mut stream, 0, #answer_enum_name::Ready).await; self.send(&mut stream, 0, #answer_enum_name::Ready).await;
loop { loop {
tokio::select! { ::tokio::select! {
Ok(_) = stream.readable() => { ::std::result::Result::Ok(_) = stream.readable() => {
let mut read_buf = [0; 1024]; let mut read_buf = [0; 1024];
match stream.try_read(&mut read_buf) { match stream.try_read(&mut read_buf) {
Ok(0) => break, // Stream closed ::std::result::Result::Ok(0) => break, // Stream closed
Ok(n) => { ::std::result::Result::Ok(n) => {
#debug("Received {} bytes (server)", n); #debug("Received {} bytes (server)", n);
buf.extend_from_slice(&read_buf[..n]); buf.extend_from_slice(&read_buf[..n]);
while buf.len() >= 4 { while buf.len() >= 4 {
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32")); let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
if buf.len() >= (4 + len as usize) { if buf.len() >= (4 + len as usize) {
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string"); let serialized = ::std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
let (nonce, question): (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize query!"); let (nonce, question): (u64, #question_enum_name) = ::ron::de::from_str(serialized).expect("Failed to deserialize query!");
let answer = self.handle(question).await; let answer = self.handle(question).await;
self.send(&mut stream, nonce, answer).await; self.send(&mut stream, nonce, answer).await;
buf.drain(0..(4 + len as usize)); buf.drain(0..(4 + len as usize));
@ -493,8 +493,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
} }
} }
}, },
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, ::std::result::Result::Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
Err(e) => eprintln!("Error reading from stream: {:?}", e), ::std::result::Result::Err(e) => ::std::eprintln!("Error reading from stream: {:?}", e),
} }
} }
} }
@ -522,12 +522,12 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
self.queries.lock().unwrap().insert(nonce, query); self.queries.lock().unwrap().insert(nonce, query);
} }
pub fn get(&self, nonce: &u64) -> Option<#query_enum_name> { pub fn get(&self, nonce: &u64) -> ::std::option::Option<#query_enum_name> {
self.queries.lock().unwrap().get(nonce).cloned() self.queries.lock().unwrap().get(nonce).cloned()
} }
pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) { pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) {
if let Some(query) = self.queries.lock().unwrap().get_mut(&nonce) { if let ::std::option::Option::Some(query) = self.queries.lock().unwrap().get_mut(&nonce) {
query.set_answer(answer); query.set_answer(answer);
} }
} }
@ -541,16 +541,16 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
// Create a struct to handle the connection from the client to the server // Create a struct to handle the connection from the client to the server
let cc_struct = quote! { let cc_struct = quote! {
struct #client_connection_struct_name { struct #client_connection_struct_name {
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, to_send: ::tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, received: ::tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
ready: std::sync::Arc<tokio::sync::Notify>, ready: ::std::sync::Arc<tokio::sync::Notify>,
stream: #stream_type, stream: #stream_type,
} }
impl #client_connection_struct_name { impl #client_connection_struct_name {
pub fn new( pub fn new(
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, to_send: ::tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, received: ::tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
ready: std::sync::Arc<tokio::sync::Notify>, ready: ::std::sync::Arc<::tokio::sync::Notify>,
stream: #stream_type, stream: #stream_type,
) -> Self { ) -> Self {
Self { Self {
@ -562,23 +562,23 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
} }
pub async fn run(mut self) { pub async fn run(mut self) {
use tokio::io::AsyncWriteExt; use ::tokio::io::AsyncWriteExt;
use tokio::io::AsyncReadExt; use ::tokio::io::AsyncReadExt;
let mut buf = Vec::with_capacity(1024); let mut buf = ::std::vec::Vec::with_capacity(1024);
loop { loop {
tokio::select! { ::tokio::select! {
Some(msg) = self.to_send.recv() => { ::std::option::Option::Some(msg) = self.to_send.recv() => {
let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!"); let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!");
let len = serialized.len() as u32; let len = serialized.len() as u32;
#debug("Sending `{}`", serialized); #debug("Sending `{}`", serialized);
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!"); self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!"); self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!");
}, },
Ok(_) = self.stream.readable() => { ::std::result::Result::Ok(_) = self.stream.readable() => {
let mut read_buf = [0; 1024]; let mut read_buf = [0; 1024];
match self.stream.try_read(&mut read_buf) { match self.stream.try_read(&mut read_buf) {
Ok(0) => { break; }, ::std::result::Result::Ok(0) => { break; },
Ok(n) => { ::std::result::Result::Ok(n) => {
#debug("Received {} bytes (client)", n); #debug("Received {} bytes (client)", n);
buf.extend_from_slice(&read_buf[..n]); buf.extend_from_slice(&read_buf[..n]);
while buf.len() >= 4 { while buf.len() >= 4 {
@ -598,8 +598,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
} }
} }
}, },
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, ::std::result::Result::Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
Err(e) => eprintln!("Error reading from stream: {:?}", e), ::std::result::Result::Err(e) => eprintln!("Error reading from stream: {:?}", e),
} }
} }
} }
@ -612,31 +612,31 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
let client_struct = quote! { let client_struct = quote! {
#[derive(Clone)] #[derive(Clone)]
struct #client_recv_queue_wrapper { struct #client_recv_queue_wrapper {
recv_queue: ::std::sync::Arc<::tokio::sync::Mutex<tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>>>, recv_queue: ::std::sync::Arc<::tokio::sync::Mutex<::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>>>,
} }
impl #client_recv_queue_wrapper { impl #client_recv_queue_wrapper {
fn new(recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { fn new(recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self {
Self { Self {
recv_queue: ::std::sync::Arc::new(::tokio::sync::Mutex::new(recv_queue)), recv_queue: ::std::sync::Arc::new(::tokio::sync::Mutex::new(recv_queue)),
} }
} }
async fn recv(&self) -> Option<(u64, #answer_enum_name)> { async fn recv(&self) -> ::std::option::Option<(u64, #answer_enum_name)> {
self.recv_queue.lock().await.recv().await self.recv_queue.lock().await.recv().await
} }
} }
#[derive(Clone)] #[derive(Clone)]
#vis struct #client_struct_name { #vis struct #client_struct_name {
queries: #queries_struct_name, queries: #queries_struct_name,
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
recv_queue: #client_recv_queue_wrapper, recv_queue: #client_recv_queue_wrapper,
ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>, ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>,
ready_notify: ::std::sync::Arc<tokio::sync::Notify>, ready_notify: ::std::sync::Arc<tokio::sync::Notify>,
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>, connection_task: ::std::option::Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
} }
impl #client_struct_name { impl #client_struct_name {
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, pub fn new(send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>, recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>, connection_task: ::std::option::Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self { ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self {
Self { Self {
queries: #queries_struct_name::new(), queries: #queries_struct_name::new(),
@ -650,19 +650,19 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
pub async fn connect<A: #stream_addr_trait + ::std::fmt::Display>(addr: A) -> Result<Self, std::io::Error> { pub async fn connect<A: #stream_addr_trait + ::std::fmt::Display>(addr: A) -> Result<Self, std::io::Error> {
#info("Connecting to server at address {}", addr); #info("Connecting to server at address {}", addr);
let stream = #stream_type::connect(addr).await?; let stream = #stream_type::connect(addr).await?;
let (send_queue, to_send) = tokio::sync::mpsc::channel(16); let (send_queue, to_send) = ::tokio::sync::mpsc::channel(16);
let (to_recv, recv_queue) = tokio::sync::mpsc::channel(16); let (to_recv, recv_queue) = ::tokio::sync::mpsc::channel(16);
let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new()); let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new());
let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream); let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream);
let connection_task = tokio::spawn(connection.run()); let connection_task = ::tokio::spawn(connection.run());
Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task)), ready_notify)) Ok(Self::new(send_queue, recv_queue, ::std::option::Option::Some(::std::sync::Arc::new(connection_task)), ready_notify))
} }
pub fn close(self) { pub fn close(self) {
if let Some(task) = self.connection_task { if let ::std::option::Option::Some(task) = self.connection_task {
task.abort(); task.abort();
} }
} }
async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> { async fn send(&self, query: #question_enum_name) -> ::std::result::Result<u64, #error_enum_name> {
// Wait until the connection is ready // Wait until the connection is ready
if !*self.ready.lock().await { if !*self.ready.lock().await {
self.ready_notify.notified().await; self.ready_notify.notified().await;
@ -672,28 +672,28 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
let nonce = self.queries.len() as u64; let nonce = self.queries.len() as u64;
let res = self.send_queue.send((nonce, query.clone())).await; let res = self.send_queue.send((nonce, query.clone())).await;
match res { match res {
Ok(_) => { ::std::result::Result::Ok(_) => {
self.queries.insert(nonce, query.into()); self.queries.insert(nonce, query.into());
Ok(nonce) ::std::result::Result::Ok(nonce)
} }
Err(e) => Err(#error_enum_name::SendError(e)), ::std::result::Result::Err(e) => ::std::result::Result::Err(#error_enum_name::SendError(e)),
} }
} }
async fn recv_until(&self, id: u64) -> Result<#answer_enum_name, #error_enum_name> { async fn recv_until(&self, id: u64) -> ::std::result::Result<#answer_enum_name, #error_enum_name> {
loop { loop {
// Check if we've received the answer for the query we're looking for // Check if we've received the answer for the query we're looking for
if let Some(query) = self.queries.get(&id) { if let ::std::option::Option::Some(query) = self.queries.get(&id) {
if let Some(answer) = query.get_answer() { if let ::std::option::Option::Some(answer) = query.get_answer() {
#info("Found answer for query {}", id); #info("Found answer for query {}", id);
return Ok(answer); return Ok(answer);
} }
} }
match self.recv_queue.recv().await { match self.recv_queue.recv().await {
Some((nonce, answer)) => { ::std::option::Option::Some((nonce, answer)) => {
#info("Received answer for query {}", nonce); #info("Received answer for query {}", nonce);
self.queries.set_answer(nonce, answer.clone()); self.queries.set_answer(nonce, answer.clone());
} }
None => return Err(#error_enum_name::Closed), ::std::option::Option::None => return ::std::result::Result::Err(#error_enum_name::Closed),
}; };
} }
} }