From f4d65a2c5195513e11887b0d4d72ea2e15e65c03 Mon Sep 17 00:00:00 2001 From: Kodi Craft Date: Mon, 24 Jun 2024 22:32:03 +0200 Subject: [PATCH] Sanitization improvements --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/lib.rs | 172 ++++++++++++++++++++++++++--------------------------- 3 files changed, 88 insertions(+), 88 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f78d359..45f7d28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -153,7 +153,7 @@ checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" [[package]] name = "eagle" -version = "0.2.3" +version = "0.2.4" dependencies = [ "env_logger", "log", diff --git a/Cargo.toml b/Cargo.toml index 2e73bf8..b12b088 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "eagle" -version = "0.2.3" +version = "0.2.4" description = "A simple library for creating RPC protocols." repository = "https://git.colon-three.com/kodi/eagle" authors = ["KodiCraft "] diff --git a/src/lib.rs b/src/lib.rs index 858f204..4c879d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -198,28 +198,28 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream // TODO: These logs should be filterable in some way #[cfg(feature = "log")] #[allow(unused_variables)] - let debug = quote! { log::debug! }; + let debug = quote! { ::log::debug! }; #[cfg(feature = "log")] #[allow(unused_variables)] - let info = quote! { log::info! }; + let info = quote! { ::log::info! }; #[cfg(feature = "log")] #[allow(unused_variables)] - let warn = quote! { log::warn! }; + let warn = quote! { ::log::warn! }; #[cfg(feature = "log")] #[allow(unused_variables)] - let error = quote! { log::error! }; + let error = quote! { ::log::error! }; #[cfg(not(feature = "log"))] #[allow(unused_variables)] - let debug = quote! { eprintln! }; + let debug = quote! { ::std::eprintln! }; #[cfg(not(feature = "log"))] #[allow(unused_variables)] - let info = quote! { eprintln! }; + let info = quote! { ::std::eprintln! }; #[cfg(not(feature = "log"))] #[allow(unused_variables)] - let warn = quote! { eprintln! }; + let warn = quote! { ::std::eprintln! }; #[cfg(not(feature = "log"))] #[allow(unused_variables)] - let error = quote! { eprintln! }; + let error = quote! { ::std::eprintln! }; // Must be on an enum let enum_ = match &input.data { @@ -299,8 +299,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream // There is a function that must be implemented to get the answer from the query enum query_get_answer.push(quote! { #query_enum_name::#var_name(_, answer) => match answer { - Some(answer) => Some(#answer_enum_name::#var_name(answer.clone())), - None => None + ::std::option::Option::Some(answer) => ::std::option::Option::Some(#answer_enum_name::#var_name(answer.clone())), + ::std::option::Option::None => ::std::option::Option::None }, }); // There is a function that the server uses to call the appropriate function when receiving a query @@ -313,7 +313,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream }); // The function that the server needs to implement server_trait.push(quote! { - fn #var_name(&mut self, #question_args) -> impl std::future::Future + Send; + fn #var_name(&mut self, #question_args) -> impl ::std::future::Future + Send; }); // The function that the client uses to communicate client_impl.push(quote! { @@ -323,13 +323,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let answer = self.recv_until(nonce).await?; match answer { #answer_enum_name::#var_name(answer) => Ok(answer), - _ => panic!("The answer for this query is not the correct type."), + _ => ::std::panic!("The answer for this query is not the correct type."), } } }); // The query enum is the same as the source enum, but the second field is always wrapped in a Option<> query_enum.push(quote! { - #var_name(#question_field, Option<#answer_type>) + #var_name(#question_field, ::std::option::Option<#answer_type>) }); } @@ -337,28 +337,28 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let error_enum = quote! { #[derive(Debug)] #vis enum #error_enum_name { - SendError(tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>), + SendError(::tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>), Closed, } }; // Create enums for the types of messages the server and client will use let answer_enum = quote! { - #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] + #[derive(::serde::Serialize, ::serde::Deserialize, ::std::clone::Clone, ::std::fmt::Debug)] #vis enum #answer_enum_name { #(#server_enum), *, Ready } }; let question_enum = quote! { - #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] + #[derive(::serde::Serialize, ::serde::Deserialize, ::std::clone::Clone, ::std::fmt::Debug)] #vis enum #question_enum_name { #(#client_enum), * } }; // Create an enum to represent the queries the client has sent let query_enum = quote! { - #[derive(Clone, Debug)] + #[derive(::std::clone::Clone, ::std::fmt::Debug)] #vis enum #query_enum_name { #(#query_enum), * } @@ -368,7 +368,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream #(#query_set_answer)* }; } - pub fn get_answer(&self) -> Option<#answer_enum_name> { + pub fn get_answer(&self) -> ::std::option::Option<#answer_enum_name> { match self { #(#query_get_answer)* } @@ -384,17 +384,17 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream }; #[cfg(feature = "tcp")] - let stream_type = quote! { tokio::net::TcpStream }; + let stream_type = quote! { ::tokio::net::TcpStream }; #[cfg(feature = "tcp")] - let stream_addr_trait = quote! { tokio::net::ToSocketAddrs }; + let stream_addr_trait = quote! { ::tokio::net::ToSocketAddrs }; #[cfg(feature = "tcp")] - let listener_type = quote! { tokio::net::TcpListener }; + let listener_type = quote! { ::tokio::net::TcpListener }; #[cfg(feature = "unix")] - let stream_type = quote! { tokio::net::UnixStream }; + let stream_type = quote! { ::tokio::net::UnixStream }; #[cfg(feature = "unix")] - let stream_addr_trait = quote! { std::convert::AsRef }; + let stream_addr_trait = quote! { ::std::convert::AsRef }; #[cfg(feature = "unix")] - let listener_type = quote! { tokio::net::UnixListener }; + let listener_type = quote! { ::tokio::net::UnixListener }; // Create a trait which the server will have to implement let server_trait = quote! { @@ -414,21 +414,21 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream }; let sc_struct = quote! { #[derive(Clone)] - #vis struct #server_connection_struct_name { - handler: ::std::sync::Arc>, - tasks: ::std::sync::Arc>>>, + #vis struct #server_connection_struct_name { + handler: ::std::sync::Arc<::tokio::sync::Mutex>, + tasks: ::std::sync::Arc<::tokio::sync::Mutex<::std::vec::Vec>>>, } - impl #server_connection_struct_name { - pub async fn bind(handler: H, addr: A) -> Self { + impl #server_connection_struct_name { + pub async fn bind(handler: H, addr: A) -> Self { #info("Binding server to address {}", addr); - let handler = ::std::sync::Arc::new(tokio::sync::Mutex::new(handler)); - let tasks = ::std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())); + let handler = ::std::sync::Arc::new(::tokio::sync::Mutex::new(handler)); + let tasks = ::std::sync::Arc::new(::tokio::sync::Mutex::new(::std::vec::Vec::new())); let sc = Self { handler, tasks, }; let sc_clone = sc.clone(); - let acc_task = tokio::spawn(async move { + let acc_task = ::tokio::spawn(async move { sc_clone.accept_connections(addr).await.expect("Failed to accept connections!"); }); sc.tasks.lock().await.push(acc_task); @@ -438,13 +438,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream pub async fn accept_connections( &self, addr: A, - ) -> Result<(), std::io::Error> { + ) -> ::std::result::Result<(), ::std::io::Error> { #listener_statement loop { let (stream, _) = listener.accept().await?; #info("Accepted connection from {:?}", stream.peer_addr()?); let self_clone = self.clone(); - let run_task = tokio::spawn(async move { + let run_task = ::tokio::spawn(async move { self_clone.run(stream).await; }); self.tasks.lock().await.push(run_task); @@ -458,7 +458,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) { - use tokio::io::AsyncWriteExt; + use ::tokio::io::AsyncWriteExt; let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!"); let len = serialized.len() as u32; #debug("Sending `{}`", serialized); @@ -467,24 +467,24 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } async fn run(&self, mut stream: #stream_type) { - use tokio::io::AsyncWriteExt; - use tokio::io::AsyncReadExt; - let mut buf = Vec::with_capacity(1024); + use ::tokio::io::AsyncWriteExt; + use ::tokio::io::AsyncReadExt; + let mut buf = ::std::vec::Vec::with_capacity(1024); self.send(&mut stream, 0, #answer_enum_name::Ready).await; loop { - tokio::select! { - Ok(_) = stream.readable() => { + ::tokio::select! { + ::std::result::Result::Ok(_) = stream.readable() => { let mut read_buf = [0; 1024]; match stream.try_read(&mut read_buf) { - Ok(0) => break, // Stream closed - Ok(n) => { + ::std::result::Result::Ok(0) => break, // Stream closed + ::std::result::Result::Ok(n) => { #debug("Received {} bytes (server)", n); buf.extend_from_slice(&read_buf[..n]); while buf.len() >= 4 { let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32")); if buf.len() >= (4 + len as usize) { - let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string"); - let (nonce, question): (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize query!"); + let serialized = ::std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string"); + let (nonce, question): (u64, #question_enum_name) = ::ron::de::from_str(serialized).expect("Failed to deserialize query!"); let answer = self.handle(question).await; self.send(&mut stream, nonce, answer).await; buf.drain(0..(4 + len as usize)); @@ -493,8 +493,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } } }, - Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, - Err(e) => eprintln!("Error reading from stream: {:?}", e), + ::std::result::Result::Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, + ::std::result::Result::Err(e) => ::std::eprintln!("Error reading from stream: {:?}", e), } } } @@ -522,12 +522,12 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream self.queries.lock().unwrap().insert(nonce, query); } - pub fn get(&self, nonce: &u64) -> Option<#query_enum_name> { + pub fn get(&self, nonce: &u64) -> ::std::option::Option<#query_enum_name> { self.queries.lock().unwrap().get(nonce).cloned() } pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) { - if let Some(query) = self.queries.lock().unwrap().get_mut(&nonce) { + if let ::std::option::Option::Some(query) = self.queries.lock().unwrap().get_mut(&nonce) { query.set_answer(answer); } } @@ -541,16 +541,16 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream // Create a struct to handle the connection from the client to the server let cc_struct = quote! { struct #client_connection_struct_name { - to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, - received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, - ready: std::sync::Arc, + to_send: ::tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, + received: ::tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, + ready: ::std::sync::Arc, stream: #stream_type, } impl #client_connection_struct_name { pub fn new( - to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, - received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, - ready: std::sync::Arc, + to_send: ::tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, + received: ::tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, + ready: ::std::sync::Arc<::tokio::sync::Notify>, stream: #stream_type, ) -> Self { Self { @@ -562,23 +562,23 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } pub async fn run(mut self) { - use tokio::io::AsyncWriteExt; - use tokio::io::AsyncReadExt; - let mut buf = Vec::with_capacity(1024); + use ::tokio::io::AsyncWriteExt; + use ::tokio::io::AsyncReadExt; + let mut buf = ::std::vec::Vec::with_capacity(1024); loop { - tokio::select! { - Some(msg) = self.to_send.recv() => { + ::tokio::select! { + ::std::option::Option::Some(msg) = self.to_send.recv() => { let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!"); let len = serialized.len() as u32; #debug("Sending `{}`", serialized); self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!"); self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!"); }, - Ok(_) = self.stream.readable() => { + ::std::result::Result::Ok(_) = self.stream.readable() => { let mut read_buf = [0; 1024]; match self.stream.try_read(&mut read_buf) { - Ok(0) => { break; }, - Ok(n) => { + ::std::result::Result::Ok(0) => { break; }, + ::std::result::Result::Ok(n) => { #debug("Received {} bytes (client)", n); buf.extend_from_slice(&read_buf[..n]); while buf.len() >= 4 { @@ -598,8 +598,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } } }, - Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, - Err(e) => eprintln!("Error reading from stream: {:?}", e), + ::std::result::Result::Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, + ::std::result::Result::Err(e) => eprintln!("Error reading from stream: {:?}", e), } } } @@ -612,31 +612,31 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let client_struct = quote! { #[derive(Clone)] struct #client_recv_queue_wrapper { - recv_queue: ::std::sync::Arc<::tokio::sync::Mutex>>, + recv_queue: ::std::sync::Arc<::tokio::sync::Mutex<::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>>>, } impl #client_recv_queue_wrapper { - fn new(recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { + fn new(recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { Self { recv_queue: ::std::sync::Arc::new(::tokio::sync::Mutex::new(recv_queue)), } } - async fn recv(&self) -> Option<(u64, #answer_enum_name)> { + async fn recv(&self) -> ::std::option::Option<(u64, #answer_enum_name)> { self.recv_queue.lock().await.recv().await } } #[derive(Clone)] #vis struct #client_struct_name { queries: #queries_struct_name, - send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, + send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: #client_recv_queue_wrapper, ready: ::std::sync::Arc>, ready_notify: ::std::sync::Arc, - connection_task: Option<::std::sync::Arc>>, + connection_task: ::std::option::Option<::std::sync::Arc>>, } impl #client_struct_name { - pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, - recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>, - connection_task: Option<::std::sync::Arc>>, + pub fn new(send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, + recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>, + connection_task: ::std::option::Option<::std::sync::Arc>>, ready_notify: ::std::sync::Arc) -> Self { Self { queries: #queries_struct_name::new(), @@ -650,19 +650,19 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream pub async fn connect(addr: A) -> Result { #info("Connecting to server at address {}", addr); let stream = #stream_type::connect(addr).await?; - let (send_queue, to_send) = tokio::sync::mpsc::channel(16); - let (to_recv, recv_queue) = tokio::sync::mpsc::channel(16); + let (send_queue, to_send) = ::tokio::sync::mpsc::channel(16); + let (to_recv, recv_queue) = ::tokio::sync::mpsc::channel(16); let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new()); let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream); - let connection_task = tokio::spawn(connection.run()); - Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task)), ready_notify)) + let connection_task = ::tokio::spawn(connection.run()); + Ok(Self::new(send_queue, recv_queue, ::std::option::Option::Some(::std::sync::Arc::new(connection_task)), ready_notify)) } pub fn close(self) { - if let Some(task) = self.connection_task { + if let ::std::option::Option::Some(task) = self.connection_task { task.abort(); } } - async fn send(&self, query: #question_enum_name) -> Result { + async fn send(&self, query: #question_enum_name) -> ::std::result::Result { // Wait until the connection is ready if !*self.ready.lock().await { self.ready_notify.notified().await; @@ -672,28 +672,28 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let nonce = self.queries.len() as u64; let res = self.send_queue.send((nonce, query.clone())).await; match res { - Ok(_) => { + ::std::result::Result::Ok(_) => { self.queries.insert(nonce, query.into()); - Ok(nonce) + ::std::result::Result::Ok(nonce) } - Err(e) => Err(#error_enum_name::SendError(e)), + ::std::result::Result::Err(e) => ::std::result::Result::Err(#error_enum_name::SendError(e)), } } - async fn recv_until(&self, id: u64) -> Result<#answer_enum_name, #error_enum_name> { + async fn recv_until(&self, id: u64) -> ::std::result::Result<#answer_enum_name, #error_enum_name> { loop { // Check if we've received the answer for the query we're looking for - if let Some(query) = self.queries.get(&id) { - if let Some(answer) = query.get_answer() { + if let ::std::option::Option::Some(query) = self.queries.get(&id) { + if let ::std::option::Option::Some(answer) = query.get_answer() { #info("Found answer for query {}", id); return Ok(answer); } } match self.recv_queue.recv().await { - Some((nonce, answer)) => { + ::std::option::Option::Some((nonce, answer)) => { #info("Received answer for query {}", nonce); self.queries.set_answer(nonce, answer.clone()); } - None => return Err(#error_enum_name::Closed), + ::std::option::Option::None => return ::std::result::Result::Err(#error_enum_name::Closed), }; } }