Implement code for recieving the answer to a query
This commit is contained in:
parent
aa7d51f088
commit
db036064e7
80
src/lib.rs
80
src/lib.rs
@ -33,10 +33,10 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
|
|||||||
};
|
};
|
||||||
let name = &input.ident;
|
let name = &input.ident;
|
||||||
|
|
||||||
let error_enum_name = format_ident!("{}Error", name);
|
let error_enum_name = format_ident!("__{}Error", name);
|
||||||
let answer_enum_name = format_ident!("{}Answer", name);
|
let answer_enum_name = format_ident!("__{}Answer", name);
|
||||||
let question_enum_name = format_ident!("{}Question", name);
|
let question_enum_name = format_ident!("__{}Question", name);
|
||||||
let query_enum_name = format_ident!("{}Query", name);
|
let query_enum_name = format_ident!("__{}Query", name);
|
||||||
let server_trait_name = format_ident!("{}Server", name);
|
let server_trait_name = format_ident!("{}Server", name);
|
||||||
let client_struct_name = format_ident!("{}Client", name);
|
let client_struct_name = format_ident!("{}Client", name);
|
||||||
|
|
||||||
@ -49,6 +49,8 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
|
|||||||
|
|
||||||
let mut query_enum = Vec::new();
|
let mut query_enum = Vec::new();
|
||||||
let mut query_from_question_enum = Vec::new();
|
let mut query_from_question_enum = Vec::new();
|
||||||
|
let mut query_set_answer = Vec::new();
|
||||||
|
let mut query_get_answer = Vec::new();
|
||||||
|
|
||||||
for variant in &enum_.variants {
|
for variant in &enum_.variants {
|
||||||
// Every variant must have 2 fields
|
// Every variant must have 2 fields
|
||||||
@ -81,6 +83,20 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
|
|||||||
query_from_question_enum.push(quote! {
|
query_from_question_enum.push(quote! {
|
||||||
#question_enum_name::#var_name(question) => #query_enum_name::#var_name(question, None),
|
#question_enum_name::#var_name(question) => #query_enum_name::#var_name(question, None),
|
||||||
});
|
});
|
||||||
|
// There is a function that must be implemented to set the answer in the query enum
|
||||||
|
query_set_answer.push(quote! {
|
||||||
|
#query_enum_name::#var_name(question, answer_opt) => match answer {
|
||||||
|
#answer_enum_name::#var_name(answer) => {answer_opt.replace(answer);},
|
||||||
|
_ => panic!("The answer for this query is not the correct type."),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
// There is a function that must be implemented to get the answer from the query enum
|
||||||
|
query_get_answer.push(quote! {
|
||||||
|
#query_enum_name::#var_name(_, answer) => match answer {
|
||||||
|
Some(answer) => Some(#answer_enum_name::#var_name(answer.clone())),
|
||||||
|
None => None
|
||||||
|
},
|
||||||
|
});
|
||||||
// The function that the server needs to implement
|
// The function that the server needs to implement
|
||||||
server_trait.push(quote! {
|
server_trait.push(quote! {
|
||||||
fn #var_name(&mut self, #question_args) -> #answer_type;
|
fn #var_name(&mut self, #question_args) -> #answer_type;
|
||||||
@ -89,7 +105,11 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
|
|||||||
client_impl.push(quote! {
|
client_impl.push(quote! {
|
||||||
pub async fn #var_name(&mut self, #question_args) -> Result<#answer_type, #error_enum_name> {
|
pub async fn #var_name(&mut self, #question_args) -> Result<#answer_type, #error_enum_name> {
|
||||||
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
|
let nonce = self.send(#question_enum_name::#var_name(#question_tuple_args)).await?;
|
||||||
todo!("Wait for the answer")
|
let answer = self.recv_until(nonce).await?;
|
||||||
|
match answer {
|
||||||
|
#answer_enum_name::#var_name(answer) => Ok(answer),
|
||||||
|
_ => panic!("The answer for this query is not the correct type."),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
// The query enum is the same as the source enum, but the second field is always wrapped in a Option<>
|
// The query enum is the same as the source enum, but the second field is always wrapped in a Option<>
|
||||||
@ -101,7 +121,8 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
|
|||||||
// Create an error and result type for sending messages
|
// Create an error and result type for sending messages
|
||||||
let error_enum = quote! {
|
let error_enum = quote! {
|
||||||
#vis enum #error_enum_name {
|
#vis enum #error_enum_name {
|
||||||
SendError(tokio::sync::mpsc::error::SendError<#question_enum_name>),
|
SendError(tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>),
|
||||||
|
Closed,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// Create enums for the types of messages the server and client will use
|
// Create enums for the types of messages the server and client will use
|
||||||
@ -124,6 +145,18 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
|
|||||||
#vis enum #query_enum_name {
|
#vis enum #query_enum_name {
|
||||||
#(#query_enum), *
|
#(#query_enum), *
|
||||||
}
|
}
|
||||||
|
impl #query_enum_name {
|
||||||
|
pub fn set_answer(&mut self, answer: #answer_enum_name) {
|
||||||
|
match self {
|
||||||
|
#(#query_set_answer)*
|
||||||
|
};
|
||||||
|
}
|
||||||
|
pub fn get_answer(&self) -> Option<#answer_enum_name> {
|
||||||
|
match self {
|
||||||
|
#(#query_get_answer)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
impl From<#question_enum_name> for #query_enum_name {
|
impl From<#question_enum_name> for #query_enum_name {
|
||||||
fn from(query: #question_enum_name) -> Self {
|
fn from(query: #question_enum_name) -> Self {
|
||||||
match query {
|
match query {
|
||||||
@ -142,11 +175,11 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
|
|||||||
let client_struct = quote! {
|
let client_struct = quote! {
|
||||||
#vis struct #client_struct_name {
|
#vis struct #client_struct_name {
|
||||||
queries: ::std::collections::HashMap<u64, #query_enum_name>,
|
queries: ::std::collections::HashMap<u64, #query_enum_name>,
|
||||||
send_queue: tokio::sync::mpsc::Sender<#question_enum_name>,
|
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||||
recv_queue: tokio::sync::mpsc::Receiver<#answer_enum_name>,
|
recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
|
||||||
} // TODO: This struct will have some fields to handle the actual connection
|
} // TODO: This struct will have some fields to handle the actual connection
|
||||||
impl #client_struct_name {
|
impl #client_struct_name {
|
||||||
pub fn new(send_queue: tokio::sync::mpsc::Sender<#question_enum_name>, recv_queue: tokio::sync::mpsc::Receiver<#answer_enum_name>) -> Self {
|
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
queries: ::std::collections::HashMap::new(),
|
queries: ::std::collections::HashMap::new(),
|
||||||
send_queue,
|
send_queue,
|
||||||
@ -154,16 +187,37 @@ pub fn derive_protocol(input: TokenStream) -> TokenStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
async fn send(&mut self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
|
async fn send(&mut self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
|
||||||
let res = self.send_queue.send(query.clone()).await;
|
let nonce = self.queries.len() as u64;
|
||||||
|
let res = self.send_queue.send((nonce, query.clone())).await;
|
||||||
match res {
|
match res {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
let id = self.queries.len() as u64;
|
self.queries.insert(nonce, query.into());
|
||||||
self.queries.insert(id, query.into());
|
Ok(nonce)
|
||||||
Ok(id)
|
|
||||||
}
|
}
|
||||||
Err(e) => Err(#error_enum_name::SendError(e)),
|
Err(e) => Err(#error_enum_name::SendError(e)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
async fn recv_until(&mut self, id: u64) -> Result<#answer_enum_name, #error_enum_name> {
|
||||||
|
loop {
|
||||||
|
// Check if we've received the answer for the query we're looking for
|
||||||
|
if let Some(query) = self.queries.get(&id) {
|
||||||
|
if let Some(answer) = query.get_answer() {
|
||||||
|
return Ok(answer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
match self.recv_queue.recv().await {
|
||||||
|
Some((nonce, answer)) => {
|
||||||
|
// Replace the Option<> in the query with the answer
|
||||||
|
if let Some(query) = self.queries.get_mut(&nonce) {
|
||||||
|
query.set_answer(answer);
|
||||||
|
} else {
|
||||||
|
panic!("Received an answer for a query we did not send");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => return Err(#error_enum_name::Closed),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
#(#client_impl)*
|
#(#client_impl)*
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
10
tests/mod.rs
10
tests/mod.rs
@ -21,6 +21,8 @@ use eagle::Protocol;
|
|||||||
enum TestProtocol {
|
enum TestProtocol {
|
||||||
Addition((i32, i32), i32),
|
Addition((i32, i32), i32),
|
||||||
SomeKindOfQuestion(String, i32),
|
SomeKindOfQuestion(String, i32),
|
||||||
|
ThisRespondsWithAString(i32, String),
|
||||||
|
Void((), ()),
|
||||||
}
|
}
|
||||||
|
|
||||||
struct DummyServer;
|
struct DummyServer;
|
||||||
@ -32,6 +34,14 @@ impl TestProtocolServer for DummyServer {
|
|||||||
fn addition(&mut self, a: i32, b: i32) -> i32 {
|
fn addition(&mut self, a: i32, b: i32) -> i32 {
|
||||||
a + b
|
a + b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn this_responds_with_a_string(&mut self, arg: i32) -> String {
|
||||||
|
format!("The number is {}", arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn void(&mut self) {
|
||||||
|
println!("Void function called!")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {}
|
fn main() {}
|
||||||
|
Loading…
Reference in New Issue
Block a user