Implement server
This commit is contained in:
parent
5e498f5882
commit
2934177373
51
src/lib.rs
51
src/lib.rs
@ -194,10 +194,14 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
let stream_type = quote! { tokio::net::TcpStream };
|
let stream_type = quote! { tokio::net::TcpStream };
|
||||||
#[cfg(feature = "tcp")]
|
#[cfg(feature = "tcp")]
|
||||||
let stream_addr_trait = quote! { tokio::net::ToSocketAddrs };
|
let stream_addr_trait = quote! { tokio::net::ToSocketAddrs };
|
||||||
|
#[cfg(feature = "tcp")]
|
||||||
|
let listener_type = quote! { tokio::net::TcpListener };
|
||||||
#[cfg(feature = "unix")]
|
#[cfg(feature = "unix")]
|
||||||
let stream_type = quote! { tokio::net::UnixStream };
|
let stream_type = quote! { tokio::net::UnixStream };
|
||||||
#[cfg(feature = "unix")]
|
#[cfg(feature = "unix")]
|
||||||
let stream_addr_trait = quote! { std::convert::AsRef<std::path::Path> };
|
let stream_addr_trait = quote! { std::convert::AsRef<std::path::Path> };
|
||||||
|
#[cfg(feature = "unix")]
|
||||||
|
let listener_type = quote! { tokio::net::UnixListener };
|
||||||
|
|
||||||
// Create a trait which the server will have to implement
|
// Create a trait which the server will have to implement
|
||||||
let server_trait = quote! {
|
let server_trait = quote! {
|
||||||
@ -210,15 +214,58 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
let sc_struct = quote! {
|
let sc_struct = quote! {
|
||||||
#vis struct #server_connection_struct_name<H: #server_trait_name> {
|
#vis struct #server_connection_struct_name<H: #server_trait_name> {
|
||||||
handler: ::std::sync::Arc<tokio::sync::Mutex<H>>,
|
handler: ::std::sync::Arc<tokio::sync::Mutex<H>>,
|
||||||
to_send: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
stream: #stream_type,
|
||||||
received: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
|
||||||
}
|
}
|
||||||
impl<H: #server_trait_name> #server_connection_struct_name<H> {
|
impl<H: #server_trait_name> #server_connection_struct_name<H> {
|
||||||
|
pub async fn bind<S: #stream_addr_trait>(handler: H, addr: S) -> Result<Self, std::io::Error> {
|
||||||
|
let listener = #listener_type::bind(addr).await?;
|
||||||
|
let (stream, _) = listener.accept().await?;
|
||||||
|
Ok(Self {
|
||||||
|
handler: ::std::sync::Arc::new(tokio::sync::Mutex::new(handler)),
|
||||||
|
stream,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
async fn handle(&self, question: #question_enum_name) -> #answer_enum_name {
|
async fn handle(&self, question: #question_enum_name) -> #answer_enum_name {
|
||||||
match question {
|
match question {
|
||||||
#(#server_handler)*
|
#(#server_handler)*
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn send(&mut self, nonce: u64, answer: #answer_enum_name) {
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!");
|
||||||
|
let len = serialized.len() as u32;
|
||||||
|
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
||||||
|
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!");
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run(mut self) {
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
use tokio::io::AsyncReadExt;
|
||||||
|
let mut buf = Vec::with_capacity(1024);
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
Ok(_) = self.stream.readable() => {
|
||||||
|
match self.stream.try_read(&mut buf) {
|
||||||
|
Ok(0) => break, // Stream closed
|
||||||
|
Ok(n) => {
|
||||||
|
// TODO: This doesn't cope with partial reads, we will handle that later
|
||||||
|
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||||
|
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||||
|
let question: (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!");
|
||||||
|
// TODO: This should ideally be done in a separate task but that's not
|
||||||
|
// necessary for now
|
||||||
|
let answer = self.handle(question.1).await;
|
||||||
|
self.send(question.0, answer).await;
|
||||||
|
},
|
||||||
|
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
||||||
|
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user