diff --git a/src/lib.rs b/src/lib.rs index e8655fe..a37a529 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -194,10 +194,14 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let stream_type = quote! { tokio::net::TcpStream }; #[cfg(feature = "tcp")] let stream_addr_trait = quote! { tokio::net::ToSocketAddrs }; + #[cfg(feature = "tcp")] + let listener_type = quote! { tokio::net::TcpListener }; #[cfg(feature = "unix")] let stream_type = quote! { tokio::net::UnixStream }; #[cfg(feature = "unix")] let stream_addr_trait = quote! { std::convert::AsRef }; + #[cfg(feature = "unix")] + let listener_type = quote! { tokio::net::UnixListener }; // Create a trait which the server will have to implement let server_trait = quote! { @@ -210,15 +214,58 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let sc_struct = quote! { #vis struct #server_connection_struct_name { handler: ::std::sync::Arc>, - to_send: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, - received: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, + stream: #stream_type, } impl #server_connection_struct_name { + pub async fn bind(handler: H, addr: S) -> Result { + let listener = #listener_type::bind(addr).await?; + let (stream, _) = listener.accept().await?; + Ok(Self { + handler: ::std::sync::Arc::new(tokio::sync::Mutex::new(handler)), + stream, + }) + } + async fn handle(&self, question: #question_enum_name) -> #answer_enum_name { match question { #(#server_handler)* } } + + async fn send(&mut self, nonce: u64, answer: #answer_enum_name) { + use tokio::io::AsyncWriteExt; + let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!"); + let len = serialized.len() as u32; + self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!"); + self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!"); + } + + async fn run(mut self) { + use tokio::io::AsyncWriteExt; + use tokio::io::AsyncReadExt; + let mut buf = Vec::with_capacity(1024); + loop { + tokio::select! { + Ok(_) = self.stream.readable() => { + match self.stream.try_read(&mut buf) { + Ok(0) => break, // Stream closed + Ok(n) => { + // TODO: This doesn't cope with partial reads, we will handle that later + let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32")); + let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string"); + let question: (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!"); + // TODO: This should ideally be done in a separate task but that's not + // necessary for now + let answer = self.handle(question.1).await; + self.send(question.0, answer).await; + }, + Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, + Err(e) => eprintln!("Error reading from stream: {:?}", e), + } + } + } + } + } } };