diff --git a/src/lib.rs b/src/lib.rs index 2cfba2f..2473b58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,6 +51,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let queries_struct_name = format_ident!("__{}Queries", name); let client_connection_struct_name = format_ident!("__{}Connection", name); let server_trait_name = format_ident!("{}ServerTrait", name); + let server_connection_struct_name = format_ident!("{}Server", name); let client_struct_name = format_ident!("{}Client", name); let vis = &input.vis; @@ -60,6 +61,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let mut client_impl = Vec::new(); let mut client_enum = Vec::new(); + let mut server_handler = Vec::new(); + let mut query_enum = Vec::new(); let mut query_from_question_enum = Vec::new(); let mut query_set_answer = Vec::new(); @@ -109,6 +112,13 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream None => None }, }); + // There is a function that the server uses to call the appropriate function when receiving a query + server_handler.push(quote! { + #question_enum_name::#var_name(#question_args) => { + let answer = self.handler.#var_name(#question_tuple_args); + return #answer_enum_name::#var_name(answer); + } + }); // The function that the server needs to implement server_trait.push(quote! { fn #var_name(&mut self, #question_args) -> #answer_type; @@ -178,12 +188,32 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } } }; + + #[cfg(feature = "tcp")] + let stream_type = quote! { tokio::net::TcpStream }; + #[cfg(feature = "tcp")] + let stream_addr_trait = quote! { tokio::net::ToSocketAddrs }; + #[cfg(feature = "unix")] + let stream_type = quote! { tokio::net::UnixStream }; + #[cfg(feature = "unix")] + let stream_addr_trait = quote! { std::convert::AsRef }; + // Create a trait which the server will have to implement let server_trait = quote! { #vis trait #server_trait_name { #(#server_trait)* } }; + + // Create a struct to implement the communication between the server and the client + let sc_struct = quote! { + #vis struct #server_connection_struct_name { + handler: ::std::sync::Arc>, + to_send: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>, + received: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, + } + }; + // Create a struct to hold queries behind an Arc> to enable async access // TODO: It might be a good idea to just make this a generic struct and write it in actual code // rather than in this macro @@ -218,16 +248,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } } }; - // Create a struct to handle the connection from the client to the server - #[cfg(feature = "tcp")] - let stream_type = quote! { tokio::net::TcpStream }; - #[cfg(feature = "tcp")] - let stream_addr_trait = quote! { tokio::net::ToSocketAddrs }; - #[cfg(feature = "unix")] - let stream_type = quote! { tokio::net::UnixStream }; - #[cfg(feature = "unix")] - let stream_addr_trait = quote! { std::convert::AsRef }; + // Create a struct to handle the connection from the client to the server let cc_struct = quote! { struct #client_connection_struct_name { to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, @@ -253,8 +275,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream let mut buf = Vec::with_capacity(1024); loop { tokio::select! { - Some((nonce, query)) = self.to_send.recv() => { - let serialized = ron::ser::to_string(&query).expect("Failed to serialize query!"); + Some(msg) = self.to_send.recv() => { + let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!"); let len = serialized.len() as u32; self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!"); self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!"); @@ -266,8 +288,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream // TODO: This doesn't cope with partial reads, we will handle that later let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32")); let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string"); - let query: #answer_enum_name = ron::de::from_str(serialized).expect("Failed to deserialize query!"); - self.received.send((0, query)).await.expect("Failed to send query!"); + let response: (u64, #answer_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize response!"); + self.received.send(response).await.expect("Failed to send response!"); buf.clear(); }, Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; }, @@ -365,6 +387,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream #query_enum #queries_struct #server_trait + #sc_struct #cc_struct #client_struct };