Client can now perform requests asynchronously
This commit is contained in:
parent
b4f1e1b092
commit
b9128a465c
73
src/lib.rs
73
src/lib.rs
@ -41,6 +41,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
let answer_enum_name = format_ident!("__{}Answer", name);
|
let answer_enum_name = format_ident!("__{}Answer", name);
|
||||||
let question_enum_name = format_ident!("__{}Question", name);
|
let question_enum_name = format_ident!("__{}Question", name);
|
||||||
let query_enum_name = format_ident!("__{}Query", name);
|
let query_enum_name = format_ident!("__{}Query", name);
|
||||||
|
let queries_struct_name = format_ident!("__{}Queries", name);
|
||||||
let server_trait_name = format_ident!("{}Server", name);
|
let server_trait_name = format_ident!("{}Server", name);
|
||||||
let client_struct_name = format_ident!("{}Client", name);
|
let client_struct_name = format_ident!("{}Client", name);
|
||||||
|
|
||||||
@ -106,7 +107,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
});
|
});
|
||||||
// The function that the client uses to communicate
|
// The function that the client uses to communicate
|
||||||
client_impl.push(quote! {
|
client_impl.push(quote! {
|
||||||
pub async fn #var_name(&mut self, #question_args) -> Result<#answer_type, #error_enum_name> {
|
pub async fn #var_name(&self, #question_args) -> Result<#answer_type, #error_enum_name> {
|
||||||
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
|
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
|
||||||
let answer = self.recv_until(nonce).await?;
|
let answer = self.recv_until(nonce).await?;
|
||||||
match answer {
|
match answer {
|
||||||
@ -175,22 +176,72 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
#(#server_trait)*
|
#(#server_trait)*
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
// Create a struct to hold queries behind an Arc<Mutex<>> to enable async access
|
||||||
|
// TODO: It might be a good idea to just make this a generic struct and write it in actual code
|
||||||
|
// rather than in this macro
|
||||||
|
let queries_struct = quote! {
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct #queries_struct_name {
|
||||||
|
queries: ::std::sync::Arc<::std::sync::Mutex<::std::collections::HashMap<u64, #query_enum_name>>>,
|
||||||
|
}
|
||||||
|
impl #queries_struct_name {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
queries: ::std::sync::Arc::new(::std::sync::Mutex::new(::std::collections::HashMap::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert(&self, nonce: u64, query: #query_enum_name) {
|
||||||
|
self.queries.lock().unwrap().insert(nonce, query);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, nonce: &u64) -> Option<#query_enum_name> {
|
||||||
|
self.queries.lock().unwrap().get(nonce).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) {
|
||||||
|
if let Some(query) = self.queries.lock().unwrap().get_mut(&nonce) {
|
||||||
|
query.set_answer(answer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.queries.lock().unwrap().len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
// Create a struct which the client will use to communicate
|
// Create a struct which the client will use to communicate
|
||||||
|
let client_recv_queue_wrapper = format_ident!("__{}RecvQueueWrapper", name);
|
||||||
let client_struct = quote! {
|
let client_struct = quote! {
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct #client_recv_queue_wrapper {
|
||||||
|
recv_queue: ::std::sync::Arc<::tokio::sync::Mutex<tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>>>,
|
||||||
|
}
|
||||||
|
impl #client_recv_queue_wrapper {
|
||||||
|
fn new(recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self {
|
||||||
|
Self {
|
||||||
|
recv_queue: ::std::sync::Arc::new(::tokio::sync::Mutex::new(recv_queue)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async fn recv(&self) -> Option<(u64, #answer_enum_name)> {
|
||||||
|
self.recv_queue.lock().await.recv().await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[derive(Clone)]
|
||||||
#vis struct #client_struct_name {
|
#vis struct #client_struct_name {
|
||||||
queries: ::std::collections::HashMap<u64, #query_enum_name>,
|
queries: #queries_struct_name,
|
||||||
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||||
recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
|
recv_queue: #client_recv_queue_wrapper,
|
||||||
} // TODO: This struct will have some fields to handle the actual connection
|
} // TODO: This struct will have some fields to handle the actual connection
|
||||||
impl #client_struct_name {
|
impl #client_struct_name {
|
||||||
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self {
|
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
queries: ::std::collections::HashMap::new(),
|
queries: #queries_struct_name::new(),
|
||||||
|
recv_queue: #client_recv_queue_wrapper::new(recv_queue),
|
||||||
send_queue,
|
send_queue,
|
||||||
recv_queue,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
async fn send(&mut self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
|
async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
|
||||||
let nonce = self.queries.len() as u64;
|
let nonce = self.queries.len() as u64;
|
||||||
let res = self.send_queue.send((nonce, query.clone())).await;
|
let res = self.send_queue.send((nonce, query.clone())).await;
|
||||||
match res {
|
match res {
|
||||||
@ -201,7 +252,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
Err(e) => Err(#error_enum_name::SendError(e)),
|
Err(e) => Err(#error_enum_name::SendError(e)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
async fn recv_until(&mut self, id: u64) -> Result<#answer_enum_name, #error_enum_name> {
|
async fn recv_until(&self, id: u64) -> Result<#answer_enum_name, #error_enum_name> {
|
||||||
loop {
|
loop {
|
||||||
// Check if we've received the answer for the query we're looking for
|
// Check if we've received the answer for the query we're looking for
|
||||||
if let Some(query) = self.queries.get(&id) {
|
if let Some(query) = self.queries.get(&id) {
|
||||||
@ -211,12 +262,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
}
|
}
|
||||||
match self.recv_queue.recv().await {
|
match self.recv_queue.recv().await {
|
||||||
Some((nonce, answer)) => {
|
Some((nonce, answer)) => {
|
||||||
// Replace the Option<> in the query with the answer
|
self.queries.set_answer(nonce, answer.clone());
|
||||||
if let Some(query) = self.queries.get_mut(&nonce) {
|
|
||||||
query.set_answer(answer);
|
|
||||||
} else {
|
|
||||||
panic!("Received an answer for a query we did not send");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
None => return Err(#error_enum_name::Closed),
|
None => return Err(#error_enum_name::Closed),
|
||||||
};
|
};
|
||||||
@ -231,6 +277,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
#answer_enum
|
#answer_enum
|
||||||
#question_enum
|
#question_enum
|
||||||
#query_enum
|
#query_enum
|
||||||
|
#queries_struct
|
||||||
#server_trait
|
#server_trait
|
||||||
#client_struct
|
#client_struct
|
||||||
};
|
};
|
||||||
|
22
tests/mod.rs
22
tests/mod.rs
@ -30,7 +30,7 @@ enum TestProtocol {
|
|||||||
async fn main() {
|
async fn main() {
|
||||||
let (qtx, qrx) = mpsc::channel(16);
|
let (qtx, qrx) = mpsc::channel(16);
|
||||||
let (atx, arx) = mpsc::channel(16);
|
let (atx, arx) = mpsc::channel(16);
|
||||||
let mut client = TestProtocolClient::new(qtx, arx);
|
let client = TestProtocolClient::new(qtx, arx);
|
||||||
let server = tokio::spawn(server_loop(qrx, atx));
|
let server = tokio::spawn(server_loop(qrx, atx));
|
||||||
let result = client.addition(2, 5).await.unwrap();
|
let result = client.addition(2, 5).await.unwrap();
|
||||||
assert_eq!(result, 7);
|
assert_eq!(result, 7);
|
||||||
@ -86,3 +86,23 @@ async fn server_loop(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn heavy_async() {
|
||||||
|
let (qtx, qrx) = mpsc::channel(16);
|
||||||
|
let (atx, arx) = mpsc::channel(16);
|
||||||
|
let client = TestProtocolClient::new(qtx, arx);
|
||||||
|
let server = tokio::spawn(server_loop(qrx, atx));
|
||||||
|
let mut tasks = Vec::new();
|
||||||
|
for i in 0..100 {
|
||||||
|
let client = client.clone();
|
||||||
|
tasks.push(tokio::spawn(async move {
|
||||||
|
let result = client.addition(i, i).await.unwrap();
|
||||||
|
assert_eq!(result, i + i);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
for task in tasks {
|
||||||
|
task.await.unwrap();
|
||||||
|
}
|
||||||
|
server.abort();
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user