Sanitization improvements
This commit is contained in:
		
							parent
							
								
									912b69ef93
								
							
						
					
					
						commit
						f4d65a2c51
					
				
							
								
								
									
										2
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										2
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							| @ -153,7 +153,7 @@ checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" | ||||
| 
 | ||||
| [[package]] | ||||
| name = "eagle" | ||||
| version = "0.2.3" | ||||
| version = "0.2.4" | ||||
| dependencies = [ | ||||
|  "env_logger", | ||||
|  "log", | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| [package] | ||||
| name = "eagle" | ||||
| version = "0.2.3" | ||||
| version = "0.2.4" | ||||
| description = "A simple library for creating RPC protocols." | ||||
| repository = "https://git.colon-three.com/kodi/eagle" | ||||
| authors = ["KodiCraft <kodi@kdcf.me>"] | ||||
|  | ||||
							
								
								
									
										172
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										172
									
								
								src/lib.rs
									
									
									
									
									
								
							| @ -198,28 +198,28 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|     // TODO: These logs should be filterable in some way
 | ||||
|     #[cfg(feature = "log")] | ||||
|     #[allow(unused_variables)] | ||||
|     let debug = quote! { log::debug! }; | ||||
|     let debug = quote! { ::log::debug! }; | ||||
|     #[cfg(feature = "log")] | ||||
|     #[allow(unused_variables)] | ||||
|     let info = quote! { log::info! }; | ||||
|     let info = quote! { ::log::info! }; | ||||
|     #[cfg(feature = "log")] | ||||
|     #[allow(unused_variables)] | ||||
|     let warn = quote! { log::warn! }; | ||||
|     let warn = quote! { ::log::warn! }; | ||||
|     #[cfg(feature = "log")] | ||||
|     #[allow(unused_variables)] | ||||
|     let error = quote! { log::error! }; | ||||
|     let error = quote! { ::log::error! }; | ||||
|     #[cfg(not(feature = "log"))] | ||||
|     #[allow(unused_variables)] | ||||
|     let debug = quote! { eprintln! }; | ||||
|     let debug = quote! { ::std::eprintln! }; | ||||
|     #[cfg(not(feature = "log"))] | ||||
|     #[allow(unused_variables)] | ||||
|     let info = quote! { eprintln! }; | ||||
|     let info = quote! { ::std::eprintln! }; | ||||
|     #[cfg(not(feature = "log"))] | ||||
|     #[allow(unused_variables)] | ||||
|     let warn = quote! { eprintln! }; | ||||
|     let warn = quote! { ::std::eprintln! }; | ||||
|     #[cfg(not(feature = "log"))] | ||||
|     #[allow(unused_variables)] | ||||
|     let error = quote! { eprintln! }; | ||||
|     let error = quote! { ::std::eprintln! }; | ||||
| 
 | ||||
|     // Must be on an enum
 | ||||
|     let enum_ = match &input.data { | ||||
| @ -299,8 +299,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|         // There is a function that must be implemented to get the answer from the query enum
 | ||||
|         query_get_answer.push(quote! { | ||||
|             #query_enum_name::#var_name(_, answer) => match answer { | ||||
|                 Some(answer) => Some(#answer_enum_name::#var_name(answer.clone())), | ||||
|                 None => None | ||||
|                 ::std::option::Option::Some(answer) => ::std::option::Option::Some(#answer_enum_name::#var_name(answer.clone())), | ||||
|                 ::std::option::Option::None => ::std::option::Option::None | ||||
|             }, | ||||
|         }); | ||||
|         // There is a function that the server uses to call the appropriate function when receiving a query
 | ||||
| @ -313,7 +313,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|         }); | ||||
|         // The function that the server needs to implement
 | ||||
|         server_trait.push(quote! { | ||||
|             fn #var_name(&mut self, #question_args) -> impl std::future::Future<Output = #answer_type> + Send; | ||||
|             fn #var_name(&mut self, #question_args) -> impl ::std::future::Future<Output = #answer_type> + Send; | ||||
|         }); | ||||
|         // The function that the client uses to communicate
 | ||||
|         client_impl.push(quote! { | ||||
| @ -323,13 +323,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|                 let answer = self.recv_until(nonce).await?; | ||||
|                 match answer { | ||||
|                     #answer_enum_name::#var_name(answer) => Ok(answer), | ||||
|                     _ => panic!("The answer for this query is not the correct type."), | ||||
|                     _ => ::std::panic!("The answer for this query is not the correct type."), | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
|         // The query enum is the same as the source enum, but the second field is always wrapped in a Option<>
 | ||||
|         query_enum.push(quote! { | ||||
|             #var_name(#question_field, Option<#answer_type>) | ||||
|             #var_name(#question_field, ::std::option::Option<#answer_type>) | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
| @ -337,28 +337,28 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|     let error_enum = quote! { | ||||
|         #[derive(Debug)] | ||||
|         #vis enum #error_enum_name { | ||||
|             SendError(tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>), | ||||
|             SendError(::tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>), | ||||
|             Closed, | ||||
|         } | ||||
|     }; | ||||
|     // Create enums for the types of messages the server and client will use
 | ||||
| 
 | ||||
|     let answer_enum = quote! { | ||||
|         #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] | ||||
|         #[derive(::serde::Serialize, ::serde::Deserialize, ::std::clone::Clone, ::std::fmt::Debug)] | ||||
|         #vis enum #answer_enum_name { | ||||
|             #(#server_enum), *, | ||||
|             Ready | ||||
|         } | ||||
|     }; | ||||
|     let question_enum = quote! { | ||||
|         #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] | ||||
|         #[derive(::serde::Serialize, ::serde::Deserialize, ::std::clone::Clone, ::std::fmt::Debug)] | ||||
|         #vis enum #question_enum_name { | ||||
|             #(#client_enum), * | ||||
|         } | ||||
|     }; | ||||
|     // Create an enum to represent the queries the client has sent
 | ||||
|     let query_enum = quote! { | ||||
|         #[derive(Clone, Debug)] | ||||
|         #[derive(::std::clone::Clone, ::std::fmt::Debug)] | ||||
|         #vis enum #query_enum_name { | ||||
|             #(#query_enum), * | ||||
|         } | ||||
| @ -368,7 +368,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|                     #(#query_set_answer)* | ||||
|                 }; | ||||
|             } | ||||
|             pub fn get_answer(&self) -> Option<#answer_enum_name> { | ||||
|             pub fn get_answer(&self) -> ::std::option::Option<#answer_enum_name> { | ||||
|                 match self { | ||||
|                     #(#query_get_answer)* | ||||
|                 } | ||||
| @ -384,17 +384,17 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|     }; | ||||
| 
 | ||||
|     #[cfg(feature = "tcp")] | ||||
|     let stream_type = quote! { tokio::net::TcpStream }; | ||||
|     let stream_type = quote! { ::tokio::net::TcpStream }; | ||||
|     #[cfg(feature = "tcp")] | ||||
|     let stream_addr_trait = quote! { tokio::net::ToSocketAddrs }; | ||||
|     let stream_addr_trait = quote! { ::tokio::net::ToSocketAddrs }; | ||||
|     #[cfg(feature = "tcp")] | ||||
|     let listener_type = quote! { tokio::net::TcpListener }; | ||||
|     let listener_type = quote! { ::tokio::net::TcpListener }; | ||||
|     #[cfg(feature = "unix")] | ||||
|     let stream_type = quote! { tokio::net::UnixStream }; | ||||
|     let stream_type = quote! { ::tokio::net::UnixStream }; | ||||
|     #[cfg(feature = "unix")] | ||||
|     let stream_addr_trait = quote! { std::convert::AsRef<std::path::Path> }; | ||||
|     let stream_addr_trait = quote! { ::std::convert::AsRef<std::path::Path> }; | ||||
|     #[cfg(feature = "unix")] | ||||
|     let listener_type = quote! { tokio::net::UnixListener }; | ||||
|     let listener_type = quote! { ::tokio::net::UnixListener }; | ||||
| 
 | ||||
|     // Create a trait which the server will have to implement
 | ||||
|     let server_trait = quote! { | ||||
| @ -414,21 +414,21 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|     }; | ||||
|     let sc_struct = quote! { | ||||
|         #[derive(Clone)] | ||||
|         #vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + Clone + 'static> { | ||||
|             handler: ::std::sync::Arc<tokio::sync::Mutex<H>>, | ||||
|             tasks: ::std::sync::Arc<tokio::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>>, | ||||
|         #vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + ::std::clone::Clone + 'static> { | ||||
|             handler: ::std::sync::Arc<::tokio::sync::Mutex<H>>, | ||||
|             tasks: ::std::sync::Arc<::tokio::sync::Mutex<::std::vec::Vec<tokio::task::JoinHandle<()>>>>, | ||||
|         } | ||||
|         impl<H: #server_trait_name + ::std::marker::Send + Clone + 'static> #server_connection_struct_name<H> { | ||||
|             pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + std::fmt::Display + 'static>(handler: H, addr: A) -> Self { | ||||
|         impl<H: #server_trait_name + ::std::marker::Send + std::clone::Clone + 'static> #server_connection_struct_name<H> { | ||||
|             pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + ::std::fmt::Display + 'static>(handler: H, addr: A) -> Self { | ||||
|                 #info("Binding server to address {}", addr); | ||||
|                 let handler = ::std::sync::Arc::new(tokio::sync::Mutex::new(handler)); | ||||
|                 let tasks = ::std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())); | ||||
|                 let handler = ::std::sync::Arc::new(::tokio::sync::Mutex::new(handler)); | ||||
|                 let tasks = ::std::sync::Arc::new(::tokio::sync::Mutex::new(::std::vec::Vec::new())); | ||||
|                 let sc = Self { | ||||
|                     handler, | ||||
|                     tasks, | ||||
|                 }; | ||||
|                 let sc_clone = sc.clone(); | ||||
|                 let acc_task = tokio::spawn(async move { | ||||
|                 let acc_task = ::tokio::spawn(async move { | ||||
|                     sc_clone.accept_connections(addr).await.expect("Failed to accept connections!"); | ||||
|                 }); | ||||
|                 sc.tasks.lock().await.push(acc_task); | ||||
| @ -438,13 +438,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|             pub async fn accept_connections<A: #stream_addr_trait>( | ||||
|                 &self, | ||||
|                 addr: A, | ||||
|             ) -> Result<(), std::io::Error> { | ||||
|             ) -> ::std::result::Result<(), ::std::io::Error> { | ||||
|                 #listener_statement | ||||
|                 loop { | ||||
|                     let (stream, _) = listener.accept().await?; | ||||
|                     #info("Accepted connection from {:?}", stream.peer_addr()?); | ||||
|                     let self_clone = self.clone(); | ||||
|                     let run_task = tokio::spawn(async move { | ||||
|                     let run_task = ::tokio::spawn(async move { | ||||
|                         self_clone.run(stream).await; | ||||
|                     }); | ||||
|                     self.tasks.lock().await.push(run_task); | ||||
| @ -458,7 +458,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|             } | ||||
| 
 | ||||
|             async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) { | ||||
|                 use tokio::io::AsyncWriteExt; | ||||
|                 use ::tokio::io::AsyncWriteExt; | ||||
|                 let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!"); | ||||
|                 let len = serialized.len() as u32; | ||||
|                 #debug("Sending `{}`", serialized); | ||||
| @ -467,24 +467,24 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|             } | ||||
| 
 | ||||
|             async fn run(&self, mut stream: #stream_type) { | ||||
|                 use tokio::io::AsyncWriteExt; | ||||
|                 use tokio::io::AsyncReadExt; | ||||
|                 let mut buf = Vec::with_capacity(1024); | ||||
|                 use ::tokio::io::AsyncWriteExt; | ||||
|                 use ::tokio::io::AsyncReadExt; | ||||
|                 let mut buf = ::std::vec::Vec::with_capacity(1024); | ||||
|                 self.send(&mut stream, 0, #answer_enum_name::Ready).await; | ||||
|                 loop { | ||||
|                     tokio::select! { | ||||
|                         Ok(_) = stream.readable() => { | ||||
|                     ::tokio::select! { | ||||
|                         ::std::result::Result::Ok(_) = stream.readable() => { | ||||
|                             let mut read_buf = [0; 1024]; | ||||
|                             match stream.try_read(&mut read_buf) { | ||||
|                                 Ok(0) => break, // Stream closed
 | ||||
|                                 Ok(n) => { | ||||
|                                 ::std::result::Result::Ok(0) => break, // Stream closed
 | ||||
|                                 ::std::result::Result::Ok(n) => { | ||||
|                                     #debug("Received {} bytes (server)", n); | ||||
|                                     buf.extend_from_slice(&read_buf[..n]); | ||||
|                                     while buf.len() >= 4 { | ||||
|                                         let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32")); | ||||
|                                         if buf.len() >= (4 + len as usize) { | ||||
|                                             let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string"); | ||||
|                                             let (nonce, question): (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize query!"); | ||||
|                                             let serialized = ::std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string"); | ||||
|                                             let (nonce, question): (u64, #question_enum_name) = ::ron::de::from_str(serialized).expect("Failed to deserialize query!"); | ||||
|                                             let answer = self.handle(question).await; | ||||
|                                             self.send(&mut stream, nonce, answer).await; | ||||
|                                             buf.drain(0..(4 + len as usize)); | ||||
| @ -493,8 +493,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|                                         } | ||||
|                                     } | ||||
|                                 }, | ||||
|                                 Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, | ||||
|                                 Err(e) => eprintln!("Error reading from stream: {:?}", e), | ||||
|                                 ::std::result::Result::Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, | ||||
|                                 ::std::result::Result::Err(e) => ::std::eprintln!("Error reading from stream: {:?}", e), | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
| @ -522,12 +522,12 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|                 self.queries.lock().unwrap().insert(nonce, query); | ||||
|             } | ||||
| 
 | ||||
|             pub fn get(&self, nonce: &u64) -> Option<#query_enum_name> { | ||||
|             pub fn get(&self, nonce: &u64) -> ::std::option::Option<#query_enum_name> { | ||||
|                 self.queries.lock().unwrap().get(nonce).cloned() | ||||
|             } | ||||
| 
 | ||||
|             pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) { | ||||
|                 if let Some(query) = self.queries.lock().unwrap().get_mut(&nonce) { | ||||
|                 if let ::std::option::Option::Some(query) = self.queries.lock().unwrap().get_mut(&nonce) { | ||||
|                     query.set_answer(answer); | ||||
|                 } | ||||
|             } | ||||
| @ -541,16 +541,16 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|     // Create a struct to handle the connection from the client to the server
 | ||||
|     let cc_struct = quote! { | ||||
|         struct #client_connection_struct_name { | ||||
|             to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, | ||||
|             received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, | ||||
|             ready: std::sync::Arc<tokio::sync::Notify>, | ||||
|             to_send: ::tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, | ||||
|             received: ::tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, | ||||
|             ready: ::std::sync::Arc<tokio::sync::Notify>, | ||||
|             stream: #stream_type, | ||||
|         } | ||||
|         impl #client_connection_struct_name { | ||||
|             pub fn new( | ||||
|                 to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, | ||||
|                 received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, | ||||
|                 ready: std::sync::Arc<tokio::sync::Notify>, | ||||
|                 to_send: ::tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, | ||||
|                 received: ::tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, | ||||
|                 ready: ::std::sync::Arc<::tokio::sync::Notify>, | ||||
|                 stream: #stream_type, | ||||
|             ) -> Self { | ||||
|                 Self { | ||||
| @ -562,23 +562,23 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|             } | ||||
| 
 | ||||
|             pub async fn run(mut self) { | ||||
|                 use tokio::io::AsyncWriteExt; | ||||
|                 use tokio::io::AsyncReadExt; | ||||
|                 let mut buf = Vec::with_capacity(1024); | ||||
|                 use ::tokio::io::AsyncWriteExt; | ||||
|                 use ::tokio::io::AsyncReadExt; | ||||
|                 let mut buf = ::std::vec::Vec::with_capacity(1024); | ||||
|                 loop { | ||||
|                     tokio::select! { | ||||
|                         Some(msg) = self.to_send.recv() => { | ||||
|                     ::tokio::select! { | ||||
|                         ::std::option::Option::Some(msg) = self.to_send.recv() => { | ||||
|                             let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!"); | ||||
|                             let len = serialized.len() as u32; | ||||
|                             #debug("Sending `{}`", serialized); | ||||
|                             self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!"); | ||||
|                             self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!"); | ||||
|                         }, | ||||
|                         Ok(_) = self.stream.readable() => { | ||||
|                         ::std::result::Result::Ok(_) = self.stream.readable() => { | ||||
|                             let mut read_buf = [0; 1024]; | ||||
|                             match self.stream.try_read(&mut read_buf) { | ||||
|                                 Ok(0) => { break; }, | ||||
|                                 Ok(n) => { | ||||
|                                 ::std::result::Result::Ok(0) => { break; }, | ||||
|                                 ::std::result::Result::Ok(n) => { | ||||
|                                     #debug("Received {} bytes (client)", n); | ||||
|                                     buf.extend_from_slice(&read_buf[..n]); | ||||
|                                     while buf.len() >= 4 { | ||||
| @ -598,8 +598,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|                                         } | ||||
|                                     } | ||||
|                                 }, | ||||
|                                 Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, | ||||
|                                 Err(e) => eprintln!("Error reading from stream: {:?}", e), | ||||
|                                 ::std::result::Result::Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, | ||||
|                                 ::std::result::Result::Err(e) => eprintln!("Error reading from stream: {:?}", e), | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
| @ -612,31 +612,31 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|     let client_struct = quote! { | ||||
|         #[derive(Clone)] | ||||
|         struct #client_recv_queue_wrapper { | ||||
|             recv_queue: ::std::sync::Arc<::tokio::sync::Mutex<tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>>>, | ||||
|             recv_queue: ::std::sync::Arc<::tokio::sync::Mutex<::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>>>, | ||||
|         } | ||||
|         impl #client_recv_queue_wrapper { | ||||
|             fn new(recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { | ||||
|             fn new(recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { | ||||
|                 Self { | ||||
|                     recv_queue: ::std::sync::Arc::new(::tokio::sync::Mutex::new(recv_queue)), | ||||
|                 } | ||||
|             } | ||||
|             async fn recv(&self) -> Option<(u64, #answer_enum_name)> { | ||||
|             async fn recv(&self) -> ::std::option::Option<(u64, #answer_enum_name)> { | ||||
|                 self.recv_queue.lock().await.recv().await | ||||
|             } | ||||
|         } | ||||
|         #[derive(Clone)] | ||||
|         #vis struct #client_struct_name { | ||||
|             queries: #queries_struct_name, | ||||
|             send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, | ||||
|             send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, | ||||
|             recv_queue: #client_recv_queue_wrapper, | ||||
|             ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>, | ||||
|             ready_notify: ::std::sync::Arc<tokio::sync::Notify>, | ||||
|             connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>, | ||||
|             connection_task: ::std::option::Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>, | ||||
|         } | ||||
|         impl #client_struct_name { | ||||
|             pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, | ||||
|                         recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>, | ||||
|                         connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>, | ||||
|             pub fn new(send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, | ||||
|                         recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>, | ||||
|                         connection_task: ::std::option::Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>, | ||||
|                         ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self { | ||||
|                 Self { | ||||
|                     queries: #queries_struct_name::new(), | ||||
| @ -650,19 +650,19 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|             pub async fn connect<A: #stream_addr_trait + ::std::fmt::Display>(addr: A) -> Result<Self, std::io::Error> { | ||||
|                 #info("Connecting to server at address {}", addr); | ||||
|                 let stream = #stream_type::connect(addr).await?; | ||||
|                 let (send_queue, to_send) = tokio::sync::mpsc::channel(16); | ||||
|                 let (to_recv, recv_queue) = tokio::sync::mpsc::channel(16); | ||||
|                 let (send_queue, to_send) = ::tokio::sync::mpsc::channel(16); | ||||
|                 let (to_recv, recv_queue) = ::tokio::sync::mpsc::channel(16); | ||||
|                 let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new()); | ||||
|                 let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream); | ||||
|                 let connection_task = tokio::spawn(connection.run()); | ||||
|                 Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task)), ready_notify)) | ||||
|                 let connection_task = ::tokio::spawn(connection.run()); | ||||
|                 Ok(Self::new(send_queue, recv_queue, ::std::option::Option::Some(::std::sync::Arc::new(connection_task)), ready_notify)) | ||||
|             } | ||||
|             pub fn close(self) { | ||||
|                 if let Some(task) = self.connection_task { | ||||
|                 if let ::std::option::Option::Some(task) = self.connection_task { | ||||
|                     task.abort(); | ||||
|                 } | ||||
|             } | ||||
|             async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> { | ||||
|             async fn send(&self, query: #question_enum_name) -> ::std::result::Result<u64, #error_enum_name> { | ||||
|                 // Wait until the connection is ready
 | ||||
|                 if !*self.ready.lock().await { | ||||
|                     self.ready_notify.notified().await; | ||||
| @ -672,28 +672,28 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|                 let nonce = self.queries.len() as u64; | ||||
|                 let res = self.send_queue.send((nonce, query.clone())).await; | ||||
|                 match res { | ||||
|                     Ok(_) => { | ||||
|                     ::std::result::Result::Ok(_) => { | ||||
|                         self.queries.insert(nonce, query.into()); | ||||
|                         Ok(nonce) | ||||
|                         ::std::result::Result::Ok(nonce) | ||||
|                     } | ||||
|                     Err(e) => Err(#error_enum_name::SendError(e)), | ||||
|                     ::std::result::Result::Err(e) => ::std::result::Result::Err(#error_enum_name::SendError(e)), | ||||
|                 } | ||||
|             } | ||||
|             async fn recv_until(&self, id: u64) -> Result<#answer_enum_name, #error_enum_name> { | ||||
|             async fn recv_until(&self, id: u64) -> ::std::result::Result<#answer_enum_name, #error_enum_name> { | ||||
|                 loop { | ||||
|                     // Check if we've received the answer for the query we're looking for
 | ||||
|                     if let Some(query) = self.queries.get(&id) { | ||||
|                         if let Some(answer) = query.get_answer() { | ||||
|                     if let ::std::option::Option::Some(query) = self.queries.get(&id) { | ||||
|                         if let ::std::option::Option::Some(answer) = query.get_answer() { | ||||
|                             #info("Found answer for query {}", id); | ||||
|                             return Ok(answer); | ||||
|                         } | ||||
|                     } | ||||
|                     match self.recv_queue.recv().await { | ||||
|                         Some((nonce, answer)) => { | ||||
|                         ::std::option::Option::Some((nonce, answer)) => { | ||||
|                             #info("Received answer for query {}", nonce); | ||||
|                             self.queries.set_answer(nonce, answer.clone()); | ||||
|                         } | ||||
|                         None => return Err(#error_enum_name::Closed), | ||||
|                         ::std::option::Option::None => return ::std::result::Result::Err(#error_enum_name::Closed), | ||||
|                     }; | ||||
|                 } | ||||
|             } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user