Fully implement communication
This commit is contained in:
166
src/lib.rs
166
src/lib.rs
@@ -34,6 +34,25 @@ pub fn derive_protocol_derive(input: TokenStream) -> TokenStream {
|
||||
|
||||
fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
|
||||
let input = parse2::<DeriveInput>(input).unwrap();
|
||||
|
||||
// TODO: These logs should be filterable in some way
|
||||
#[cfg(feature = "log")]
|
||||
let debug = quote! { log::debug! };
|
||||
#[cfg(feature = "log")]
|
||||
let info = quote! { log::info! };
|
||||
#[cfg(feature = "log")]
|
||||
let _warn = quote! { log::warn! };
|
||||
#[cfg(feature = "log")]
|
||||
let _error = quote! { log::error! };
|
||||
#[cfg(not(feature = "log"))]
|
||||
let debug = quote! { eprintln! };
|
||||
#[cfg(not(feature = "log"))]
|
||||
let info = quote! { eprintln! };
|
||||
#[cfg(not(feature = "log"))]
|
||||
let _warn = quote! { eprintln! };
|
||||
#[cfg(not(feature = "log"))]
|
||||
let _error = quote! { eprintln! };
|
||||
|
||||
// Must be on an enum
|
||||
let enum_ = match &input.data {
|
||||
syn::Data::Enum(e) => e,
|
||||
@@ -102,7 +121,10 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
// There is a function that must be implemented to set the answer in the query enum
|
||||
query_set_answer.push(quote! {
|
||||
#query_enum_name::#var_name(question, answer_opt) => match answer {
|
||||
#answer_enum_name::#var_name(answer) => {answer_opt.replace(answer);},
|
||||
#answer_enum_name::#var_name(answer) => {
|
||||
#debug("Setting answer for query {}", stringify!(#var_name));
|
||||
answer_opt.replace(answer);
|
||||
},
|
||||
_ => panic!("The answer for this query is not the correct type."),
|
||||
},
|
||||
});
|
||||
@@ -116,17 +138,19 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
// There is a function that the server uses to call the appropriate function when receiving a query
|
||||
server_handler.push(quote! {
|
||||
#question_enum_name::#var_name(#question_tuple_args) => {
|
||||
#info("Received query {}", stringify!(#var_name));
|
||||
let answer = self.handler.lock().await.#var_name(#question_handler_args).await;
|
||||
return #answer_enum_name::#var_name(answer);
|
||||
},
|
||||
});
|
||||
// The function that the server needs to implement
|
||||
server_trait.push(quote! {
|
||||
async fn #var_name(&mut self, #question_args) -> #answer_type;
|
||||
fn #var_name(&mut self, #question_args) -> impl std::future::Future<Output = #answer_type> + Send;
|
||||
});
|
||||
// The function that the client uses to communicate
|
||||
client_impl.push(quote! {
|
||||
pub async fn #var_name(&self, #question_args) -> Result<#answer_type, #error_enum_name> {
|
||||
#info("Sending query {}", stringify!(#var_name));
|
||||
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
|
||||
let answer = self.recv_until(nonce).await?;
|
||||
match answer {
|
||||
@@ -154,7 +178,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
let answer_enum = quote! {
|
||||
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
|
||||
#vis enum #answer_enum_name {
|
||||
#(#server_enum), *
|
||||
#(#server_enum), *,
|
||||
Ready
|
||||
}
|
||||
};
|
||||
let question_enum = quote! {
|
||||
@@ -220,18 +245,42 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
let listener = #listener_type::bind(addr.as_ref())?;
|
||||
};
|
||||
let sc_struct = quote! {
|
||||
#vis struct #server_connection_struct_name<H: #server_trait_name> {
|
||||
#[derive(Clone)]
|
||||
#vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + Clone + 'static> {
|
||||
handler: ::std::sync::Arc<tokio::sync::Mutex<H>>,
|
||||
stream: #stream_type,
|
||||
tasks: ::std::sync::Arc<tokio::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>>,
|
||||
}
|
||||
impl<H: #server_trait_name> #server_connection_struct_name<H> {
|
||||
pub async fn bind<S: #stream_addr_trait>(handler: H, addr: S) -> Result<Self, std::io::Error> {
|
||||
impl<H: #server_trait_name + ::std::marker::Send + Clone + 'static> #server_connection_struct_name<H> {
|
||||
pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + std::fmt::Display + 'static>(handler: H, addr: A) -> Self {
|
||||
#info("Binding server to address {}", addr);
|
||||
let handler = ::std::sync::Arc::new(tokio::sync::Mutex::new(handler));
|
||||
let tasks = ::std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
|
||||
let sc = Self {
|
||||
handler,
|
||||
tasks,
|
||||
};
|
||||
let sc_clone = sc.clone();
|
||||
let acc_task = tokio::spawn(async move {
|
||||
sc_clone.accept_connections(addr).await.expect("Failed to accept connections!");
|
||||
});
|
||||
sc.tasks.lock().await.push(acc_task);
|
||||
sc
|
||||
}
|
||||
|
||||
pub async fn accept_connections<A: #stream_addr_trait>(
|
||||
&self,
|
||||
addr: A,
|
||||
) -> Result<(), std::io::Error> {
|
||||
#listener_statement
|
||||
let (stream, _) = listener.accept().await?;
|
||||
Ok(Self {
|
||||
handler: ::std::sync::Arc::new(tokio::sync::Mutex::new(handler)),
|
||||
stream,
|
||||
})
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
#info("Accepted connection from {}", stream.peer_addr()?);
|
||||
let self_clone = self.clone();
|
||||
let run_task = tokio::spawn(async move {
|
||||
self_clone.run(stream).await;
|
||||
});
|
||||
self.tasks.lock().await.push(run_task);
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle(&self, question: #question_enum_name) -> #answer_enum_name {
|
||||
@@ -240,32 +289,41 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
}
|
||||
}
|
||||
|
||||
async fn send(&mut self, nonce: u64, answer: #answer_enum_name) {
|
||||
async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!");
|
||||
let len = serialized.len() as u32;
|
||||
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
||||
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!");
|
||||
#debug("Sending `{}`", serialized);
|
||||
stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
||||
stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!");
|
||||
}
|
||||
|
||||
async fn run(mut self) {
|
||||
async fn run(&self, mut stream: #stream_type) {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::AsyncReadExt;
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
self.send(&mut stream, 0, #answer_enum_name::Ready).await;
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok(_) = self.stream.readable() => {
|
||||
match self.stream.try_read(&mut buf) {
|
||||
Ok(_) = stream.readable() => {
|
||||
let mut read_buf = [0; 1024];
|
||||
match stream.try_read(&mut read_buf) {
|
||||
Ok(0) => break, // Stream closed
|
||||
Ok(n) => {
|
||||
// TODO: This doesn't cope with partial reads, we will handle that later
|
||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||
let question: (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
|
||||
// TODO: This should ideally be done in a separate task but that's not
|
||||
// necessary for now
|
||||
let answer = self.handle(question.1).await;
|
||||
self.send(question.0, answer).await;
|
||||
#debug("Received {} bytes (server)", n);
|
||||
buf.extend_from_slice(&read_buf[..n]);
|
||||
loop {
|
||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||
if buf.len() >= (4 + len as usize) {
|
||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||
let (nonce, question): (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize query!");
|
||||
let answer = self.handle(question).await;
|
||||
self.send(&mut stream, nonce, answer).await;
|
||||
buf.drain(0..(4 + len as usize));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
||||
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
||||
@@ -317,17 +375,20 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
struct #client_connection_struct_name {
|
||||
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
||||
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
||||
ready: std::sync::Arc<tokio::sync::Notify>,
|
||||
stream: #stream_type,
|
||||
}
|
||||
impl #client_connection_struct_name {
|
||||
pub fn new(
|
||||
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
||||
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
||||
ready: std::sync::Arc<tokio::sync::Notify>,
|
||||
stream: #stream_type,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_send,
|
||||
received,
|
||||
ready,
|
||||
stream,
|
||||
}
|
||||
}
|
||||
@@ -341,19 +402,33 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
Some(msg) = self.to_send.recv() => {
|
||||
let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!");
|
||||
let len = serialized.len() as u32;
|
||||
#debug("Sending `{}`", serialized);
|
||||
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
||||
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!");
|
||||
},
|
||||
Ok(_) = self.stream.readable() => {
|
||||
match self.stream.try_read(&mut buf) {
|
||||
Ok(0) => break, // Stream closed
|
||||
let mut read_buf = [0; 1024];
|
||||
match self.stream.try_read(&mut read_buf) {
|
||||
Ok(0) => { break; },
|
||||
Ok(n) => {
|
||||
// TODO: This doesn't cope with partial reads, we will handle that later
|
||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||
let response: (u64, #answer_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
|
||||
self.received.send(response).await.expect("Failed to send response!");
|
||||
buf.clear();
|
||||
#debug("Received {} bytes (client)", n);
|
||||
buf.extend_from_slice(&read_buf[..n]);
|
||||
while buf.len() >= 4 {
|
||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||
if buf.len() >= (4 + len as usize) {
|
||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||
let response: (u64, #answer_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
|
||||
if let #answer_enum_name::Ready = response.1 {
|
||||
#debug("Received ready signal");
|
||||
self.ready.notify_one();
|
||||
} else {
|
||||
self.received.send(response).await.expect("Failed to send response!");
|
||||
}
|
||||
buf.drain(0..(4 + len as usize));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
||||
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
||||
@@ -386,26 +461,33 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
queries: #queries_struct_name,
|
||||
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||
recv_queue: #client_recv_queue_wrapper,
|
||||
ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>,
|
||||
ready_notify: ::std::sync::Arc<tokio::sync::Notify>,
|
||||
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
||||
}
|
||||
impl #client_struct_name {
|
||||
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||
recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
|
||||
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>) -> Self {
|
||||
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
||||
ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self {
|
||||
Self {
|
||||
queries: #queries_struct_name::new(),
|
||||
recv_queue: #client_recv_queue_wrapper::new(recv_queue),
|
||||
ready: ::std::sync::Arc::new(false.into()),
|
||||
ready_notify,
|
||||
send_queue,
|
||||
connection_task,
|
||||
}
|
||||
}
|
||||
pub async fn connect<A: #stream_addr_trait>(addr: A) -> Result<Self, std::io::Error> {
|
||||
pub async fn connect<A: #stream_addr_trait + ::std::fmt::Display>(addr: A) -> Result<Self, std::io::Error> {
|
||||
#info("Connecting to server at address {}", addr);
|
||||
let stream = #stream_type::connect(addr).await?;
|
||||
let (send_queue, to_send) = tokio::sync::mpsc::channel(16);
|
||||
let (to_recv, recv_queue) = tokio::sync::mpsc::channel(16);
|
||||
let connection = #client_connection_struct_name::new(to_send, to_recv, stream);
|
||||
let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream);
|
||||
let connection_task = tokio::spawn(connection.run());
|
||||
Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task))))
|
||||
Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task)), ready_notify))
|
||||
}
|
||||
pub fn close(self) {
|
||||
if let Some(task) = self.connection_task {
|
||||
@@ -413,6 +495,12 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
}
|
||||
}
|
||||
async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
|
||||
// Wait until the connection is ready
|
||||
if !*self.ready.lock().await {
|
||||
self.ready_notify.notified().await;
|
||||
let mut ready = self.ready.lock().await;
|
||||
*ready = true;
|
||||
}
|
||||
let nonce = self.queries.len() as u64;
|
||||
let res = self.send_queue.send((nonce, query.clone())).await;
|
||||
match res {
|
||||
@@ -428,11 +516,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
||||
// Check if we've received the answer for the query we're looking for
|
||||
if let Some(query) = self.queries.get(&id) {
|
||||
if let Some(answer) = query.get_answer() {
|
||||
#info("Found answer for query {}", id);
|
||||
return Ok(answer);
|
||||
}
|
||||
}
|
||||
match self.recv_queue.recv().await {
|
||||
Some((nonce, answer)) => {
|
||||
#info("Received answer for query {}", nonce);
|
||||
self.queries.set_answer(nonce, answer.clone());
|
||||
}
|
||||
None => return Err(#error_enum_name::Closed),
|
||||
|
||||
Reference in New Issue
Block a user