Attempt to deal with task/memory leaks
This commit is contained in:
66
src/lib.rs
66
src/lib.rs
@@ -103,11 +103,10 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
//! let handler = Handler;
|
||||
//! let address = "127.0.0.1:12345"; // Or, if using the 'unix' feature, "/tmp/eagle.sock"
|
||||
//! let server = ExampleServer::bind(handler, address).await;
|
||||
//! server.close().await;
|
||||
//! # });
|
||||
//! ```
|
||||
//! Once bound, the server will begin listening for incoming connections and
|
||||
//! queries. **You must remember to use the `close` method to shut down the server.**
|
||||
//! queries.
|
||||
//!
|
||||
//! On the client side, you can simply use the generated client struct to connect
|
||||
//! to the server and begin sending queries.
|
||||
@@ -139,12 +138,8 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
//! # tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; // Wait for the server to start
|
||||
//! let client = ExampleClient::connect(address).await.unwrap();
|
||||
//! assert_eq!(client.add(2, 5).await.unwrap(), 7);
|
||||
//! # server.close().await;
|
||||
//! # });
|
||||
//! ```
|
||||
//!
|
||||
//! The client can be closed by calling the `close` method on the client struct.
|
||||
//! This will abort the connection.
|
||||
|
||||
#![warn(missing_docs)]
|
||||
use proc_macro::TokenStream;
|
||||
@@ -254,6 +249,29 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
let mut query_set_answer = Vec::new();
|
||||
let mut query_get_answer = Vec::new();
|
||||
|
||||
// TODO: This should just be regular code!
|
||||
let join_handle_guard_name = format_ident!("__{}JoinHandleGuard", name);
|
||||
let join_handle_guard = quote! {
|
||||
struct #join_handle_guard_name<T: ::std::fmt::Debug>(::tokio::task::JoinHandle<T>);
|
||||
impl<T: ::std::fmt::Debug> From<::tokio::task::JoinHandle<T>> for #join_handle_guard_name<T> {
|
||||
fn from(handle: ::tokio::task::JoinHandle<T>) -> Self {
|
||||
#debug("Creating join handle guard for task {:?}", handle);
|
||||
Self(handle)
|
||||
}
|
||||
}
|
||||
impl #join_handle_guard_name<()> {
|
||||
pub fn abort(self) {
|
||||
self.0.abort();
|
||||
}
|
||||
}
|
||||
impl<T: ::std::fmt::Debug> Drop for #join_handle_guard_name<T> {
|
||||
fn drop(&mut self) {
|
||||
#debug("Dropping join handle guard for task {:?}", self.0);
|
||||
self.0.abort();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for variant in &enum_.variants {
|
||||
// Every variant must have 2 fields
|
||||
// The first field is the question (serverbound), the second field is the answer (clientbound)
|
||||
@@ -443,7 +461,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
#[derive(Clone)]
|
||||
#vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + ::std::clone::Clone + 'static> {
|
||||
handler: ::std::sync::Arc<::tokio::sync::Mutex<H>>,
|
||||
tasks: ::std::sync::Arc<::tokio::sync::Mutex<::std::vec::Vec<tokio::task::JoinHandle<()>>>>,
|
||||
tasks: ::std::sync::Arc<::tokio::sync::Mutex<::std::vec::Vec<#join_handle_guard_name<()>>>>,
|
||||
}
|
||||
impl<H: #server_trait_name + ::std::marker::Send + std::clone::Clone + 'static> #server_connection_struct_name<H> {
|
||||
pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + ::std::fmt::Display + 'static>(handler: H, addr: A) -> Self {
|
||||
@@ -458,22 +476,16 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
let acc_task = ::tokio::spawn(async move {
|
||||
sc_clone.accept_connections(addr).await.expect("Failed to accept connections!");
|
||||
});
|
||||
sc.tasks.lock().await.push(acc_task);
|
||||
sc.tasks.lock().await.push(acc_task.into());
|
||||
sc
|
||||
}
|
||||
|
||||
pub async fn close(self) {
|
||||
#info("Closing server");
|
||||
for task in self.tasks.lock().await.drain(..) {
|
||||
task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn accept_connections<A: #stream_addr_trait>(
|
||||
&self,
|
||||
addr: A,
|
||||
) -> ::std::result::Result<(), ::std::io::Error> {
|
||||
#listener_statement
|
||||
#info("Listening for clients on {:?}", listener.local_addr()?);
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
#info("Accepted connection from {:?}", stream.peer_addr()?);
|
||||
@@ -481,7 +493,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
let run_task = ::tokio::spawn(async move {
|
||||
self_clone.run(stream).await;
|
||||
});
|
||||
self.tasks.lock().await.push(run_task);
|
||||
self.tasks.lock().await.push(run_task.into());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -624,7 +636,10 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
#debug("Received ready signal");
|
||||
self.ready.notify_one();
|
||||
} else {
|
||||
self.received.send(response).await.expect("Failed to send response!");
|
||||
match self.received.send(response).await {
|
||||
::std::result::Result::Ok(_) => {},
|
||||
::std::result::Result::Err(e) => #error("Failed to send received answer to : {:?}", e),
|
||||
};
|
||||
}
|
||||
buf.drain(0..(4 + len as usize));
|
||||
} else {
|
||||
@@ -665,12 +680,12 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
recv_queue: #client_recv_queue_wrapper,
|
||||
ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>,
|
||||
ready_notify: ::std::sync::Arc<tokio::sync::Notify>,
|
||||
connection_task: ::std::option::Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
||||
connection_task: ::std::option::Option<::std::sync::Arc<#join_handle_guard_name<()>>>,
|
||||
}
|
||||
impl #client_struct_name {
|
||||
pub fn new(send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||
recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
|
||||
connection_task: ::std::option::Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
||||
connection_task: ::std::option::Option<::std::sync::Arc<#join_handle_guard_name<()>>>,
|
||||
ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self {
|
||||
Self {
|
||||
queries: #queries_struct_name::new(),
|
||||
@@ -689,12 +704,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream);
|
||||
let connection_task = ::tokio::spawn(connection.run());
|
||||
Ok(Self::new(send_queue, recv_queue, ::std::option::Option::Some(::std::sync::Arc::new(connection_task)), ready_notify))
|
||||
}
|
||||
pub fn close(&mut self) {
|
||||
if let ::std::option::Option::Some(task) = self.connection_task.take() {
|
||||
task.abort();
|
||||
}
|
||||
Ok(Self::new(send_queue, recv_queue, ::std::option::Option::Some(::std::sync::Arc::new(connection_task.into())), ready_notify))
|
||||
}
|
||||
async fn send(&self, query: #question_enum_name) -> ::std::result::Result<u64, #error_enum_name> {
|
||||
// Wait until the connection is ready
|
||||
@@ -733,14 +743,10 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
}
|
||||
#(#client_impl)*
|
||||
}
|
||||
impl ::std::ops::Drop for #client_struct_name {
|
||||
fn drop(&mut self) {
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let expanded = quote! {
|
||||
#join_handle_guard // TODO: This should just be regular code and not in the macro!
|
||||
#error_enum
|
||||
#answer_enum
|
||||
#question_enum
|
||||
|
||||
Reference in New Issue
Block a user