Implement server
This commit is contained in:
		
							parent
							
								
									5e498f5882
								
							
						
					
					
						commit
						2934177373
					
				
							
								
								
									
										51
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										51
									
								
								src/lib.rs
									
									
									
									
									
								
							| @ -194,10 +194,14 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|     let stream_type = quote! { tokio::net::TcpStream }; | ||||
|     #[cfg(feature = "tcp")] | ||||
|     let stream_addr_trait = quote! { tokio::net::ToSocketAddrs }; | ||||
|     #[cfg(feature = "tcp")] | ||||
|     let listener_type = quote! { tokio::net::TcpListener }; | ||||
|     #[cfg(feature = "unix")] | ||||
|     let stream_type = quote! { tokio::net::UnixStream }; | ||||
|     #[cfg(feature = "unix")] | ||||
|     let stream_addr_trait = quote! { std::convert::AsRef<std::path::Path> }; | ||||
|     #[cfg(feature = "unix")] | ||||
|     let listener_type = quote! { tokio::net::UnixListener }; | ||||
| 
 | ||||
|     // Create a trait which the server will have to implement
 | ||||
|     let server_trait = quote! { | ||||
| @ -210,15 +214,58 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream | ||||
|     let sc_struct = quote! { | ||||
|         #vis struct #server_connection_struct_name<H: #server_trait_name> { | ||||
|             handler: ::std::sync::Arc<tokio::sync::Mutex<H>>, | ||||
|             to_send: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, | ||||
|             received: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, | ||||
|             stream: #stream_type, | ||||
|         } | ||||
|         impl<H: #server_trait_name> #server_connection_struct_name<H> { | ||||
|             pub async fn bind<S: #stream_addr_trait>(handler: H, addr: S) -> Result<Self, std::io::Error> { | ||||
|                 let listener = #listener_type::bind(addr).await?; | ||||
|                 let (stream, _) = listener.accept().await?; | ||||
|                 Ok(Self { | ||||
|                     handler: ::std::sync::Arc::new(tokio::sync::Mutex::new(handler)), | ||||
|                     stream, | ||||
|                 }) | ||||
|             } | ||||
| 
 | ||||
|             async fn handle(&self, question: #question_enum_name) -> #answer_enum_name { | ||||
|                 match question { | ||||
|                     #(#server_handler)* | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             async fn send(&mut self, nonce: u64, answer: #answer_enum_name) { | ||||
|                 use tokio::io::AsyncWriteExt; | ||||
|                 let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!"); | ||||
|                 let len = serialized.len() as u32; | ||||
|                 self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!"); | ||||
|                 self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write response!"); | ||||
|             } | ||||
| 
 | ||||
|             async fn run(mut self) { | ||||
|                 use tokio::io::AsyncWriteExt; | ||||
|                 use tokio::io::AsyncReadExt; | ||||
|                 let mut buf = Vec::with_capacity(1024); | ||||
|                 loop { | ||||
|                     tokio::select! { | ||||
|                         Ok(_) = self.stream.readable() => { | ||||
|                             match self.stream.try_read(&mut buf) { | ||||
|                                 Ok(0) => break, // Stream closed
 | ||||
|                                 Ok(n) => { | ||||
|                                     // TODO: This doesn't cope with partial reads, we will handle that later
 | ||||
|                                     let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32")); | ||||
|                                     let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string"); | ||||
|                                     let question: (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!"); | ||||
|                                     // TODO: This should ideally be done in a separate task but that's not
 | ||||
|                                     // necessary for now
 | ||||
|                                     let answer = self.handle(question.1).await; | ||||
|                                     self.send(question.0, answer).await; | ||||
|                                 }, | ||||
|                                 Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, | ||||
|                                 Err(e) => eprintln!("Error reading from stream: {:?}", e), | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user