Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
a1e10f93ce
|
|||
|
2cf0b9abe4
|
|||
|
8b0f01e606
|
|||
|
beda8c151d
|
|||
|
84f7009ad2
|
|||
|
267b741ac4
|
|||
|
bffb41e8a1
|
|||
|
b5870e62fe
|
|||
|
f4d65a2c51
|
|||
|
912b69ef93
|
|||
|
2353c1648e
|
|||
|
e1f453fa8b
|
@@ -14,7 +14,7 @@ jobs:
|
|||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
- name: Publish to Gitea Cargo registry
|
- name: Publish to Gitea Cargo registry
|
||||||
run: nix develop -c cargo publish --token ${{ secrets.GITEA_TOKEN }} --index sparse+https://git.colon-three.com/api/packages/kodi/cargo/
|
run: nix develop -c cargo publish --token "Bearer ${{ secrets.GITHUB_TOKEN }}" --index sparse+https://git.colon-three.com/api/packages/kodi/cargo/
|
||||||
- name: Publish to crates.io
|
- name: Publish to crates.io
|
||||||
run: nix develop -c cargo publish --token ${{ secrets.CRATESIO_TOKEN }}
|
run: nix develop -c cargo publish --token ${{ secrets.CRATESIO_TOKEN }}
|
||||||
- name: Publish to Gitea Releases
|
- name: Publish to Gitea Releases
|
||||||
|
|||||||
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -153,7 +153,7 @@ checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "eagle"
|
name = "eagle"
|
||||||
version = "0.2.2"
|
version = "0.3.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"env_logger",
|
"env_logger",
|
||||||
"log",
|
"log",
|
||||||
|
|||||||
20
Cargo.toml
20
Cargo.toml
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "eagle"
|
name = "eagle"
|
||||||
version = "0.2.2"
|
version = "0.3.0"
|
||||||
description = "A simple library for creating RPC protocols."
|
description = "A simple library for creating RPC protocols."
|
||||||
repository = "https://git.colon-three.com/kodi/eagle"
|
repository = "https://git.colon-three.com/kodi/eagle"
|
||||||
authors = ["KodiCraft <kodi@kdcf.me>"]
|
authors = ["KodiCraft <kodi@kdcf.me>"]
|
||||||
@@ -10,26 +10,24 @@ resolver = "2"
|
|||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["tcp", "log"]
|
default = ["tcp", "log"]
|
||||||
tcp = ["tokio/net"]
|
tcp = []
|
||||||
unix = ["tokio/net"]
|
unix = []
|
||||||
log = ["dep:log", "dep:env_logger"]
|
log = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
proc-macro2 = "1.0.85"
|
proc-macro2 = "1.0.85"
|
||||||
quote = "1.0.36"
|
quote = "1.0.36"
|
||||||
ron = "0.8.1"
|
|
||||||
serde = { version = "1.0.203", features = ["serde_derive"] }
|
|
||||||
syn = "2.0.66"
|
syn = "2.0.66"
|
||||||
tokio = { version = "1.38.0", features = ["sync", "io-util"] }
|
|
||||||
env_logger = { version = "0.11.3", optional = true }
|
|
||||||
log = { version = "0.4.21", optional = true }
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
# tokio = { version = "1.38.0", features = ["sync", "io-util"] }
|
||||||
|
ron = "0.8.1"
|
||||||
|
serde = { version = "1.0.203", features = ["serde_derive"] }
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
tokio = { version = "1.38.0", features = ["sync", "rt-multi-thread", "macros", "time", "io-util", "net"] }
|
tokio = { version = "1.38.0", features = ["sync", "rt-multi-thread", "macros", "time", "io-util", "net"] }
|
||||||
env_logger = "0.11.3"
|
|
||||||
log = "0.4.21"
|
|
||||||
tokio-test = "0.4.4"
|
tokio-test = "0.4.4"
|
||||||
|
env_logger = { version = "0.11.3" }
|
||||||
|
log = { version = "0.4.21" }
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
proc-macro = true
|
proc-macro = true
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ Eagle is still in early development. Performance is not ideal, the interface is
|
|||||||
Eagle is a library which allows you to easily build an [RPC](https://en.wikipedia.org/wiki/Remote_procedure_call) protocol.
|
Eagle is a library which allows you to easily build an [RPC](https://en.wikipedia.org/wiki/Remote_procedure_call) protocol.
|
||||||
It uses a macro to generate the required communication code and makes adding new functions easy and quick. Eagle is designed to work specifically with [`tokio`](https://crates.io/crates/tokio) and uses [`serde`](https://crates.io/crates/serde) for formatting data.
|
It uses a macro to generate the required communication code and makes adding new functions easy and quick. Eagle is designed to work specifically with [`tokio`](https://crates.io/crates/tokio) and uses [`serde`](https://crates.io/crates/serde) for formatting data.
|
||||||
|
|
||||||
|
Please note that since `eagle` is a pure proc-macro library, you must manually add compatible versions of `tokio`, `serde`, `ron` and optionally `log` to your dependencies.
|
||||||
|
|
||||||
## Using Eagle
|
## Using Eagle
|
||||||
|
|
||||||
The way that `eagle` is designed to be used is inside a shared dependency between your "server" and your "client". Both of these should be in a workspace. Create a `shared` crate which both components should depend on, this crate should have `eagle` as a dependency. By default `eagle` uses TCP for communication, but you may disable default features and enable the `unix` feature on `eagle` to use unix sockets instead.
|
The way that `eagle` is designed to be used is inside a shared dependency between your "server" and your "client". Both of these should be in a workspace. Create a `shared` crate which both components should depend on, this crate should have `eagle` as a dependency. By default `eagle` uses TCP for communication, but you may disable default features and enable the `unix` feature on `eagle` to use unix sockets instead.
|
||||||
@@ -67,9 +69,9 @@ Your handler can now be used by the server. You can easily bind your server to a
|
|||||||
use shared::ExampleServer;
|
use shared::ExampleServer;
|
||||||
|
|
||||||
let handler = ExampleHandler { state: 0 };
|
let handler = ExampleHandler { state: 0 };
|
||||||
let server_task = tokio::spawn(ExampleServer::bind(handler, "127.0.0.1:1234"));
|
let server_task = ExampleServer::bind(handler, "127.0.0.1:1234").await;
|
||||||
// Or, if you're using the 'unix' feature...
|
// Or, if you're using the 'unix' feature...
|
||||||
let server_task = tokio::spawn(ExampleServer::bind(handler, "/tmp/sock"));
|
let server_task = ExampleServer::bind(handler, "/tmp/sock").await;
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
233
src/lib.rs
233
src/lib.rs
@@ -102,13 +102,12 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||||||
//! # tokio_test::block_on(async {
|
//! # tokio_test::block_on(async {
|
||||||
//! let handler = Handler;
|
//! let handler = Handler;
|
||||||
//! let address = "127.0.0.1:12345"; // Or, if using the 'unix' feature, "/tmp/eagle.sock"
|
//! let address = "127.0.0.1:12345"; // Or, if using the 'unix' feature, "/tmp/eagle.sock"
|
||||||
//! let server_task = tokio::spawn(ExampleServer::bind(handler, address));
|
//! let server = ExampleServer::bind(handler, address).await;
|
||||||
|
//! server.close().await;
|
||||||
//! # });
|
//! # });
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//! Once bound, the server will begin listening for incoming connections and
|
||||||
//! Please note the usage of `tokio::spawn`. This is because the `bind` function
|
//! queries. **You must remember to use the `close` method to shut down the server.**
|
||||||
//! will not return until the server is closed. You can use the `abort` method
|
|
||||||
//! on the task to close the server.
|
|
||||||
//!
|
//!
|
||||||
//! On the client side, you can simply use the generated client struct to connect
|
//! On the client side, you can simply use the generated client struct to connect
|
||||||
//! to the server and begin sending queries.
|
//! to the server and begin sending queries.
|
||||||
@@ -136,11 +135,11 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||||||
//! # tokio_test::block_on(async {
|
//! # tokio_test::block_on(async {
|
||||||
//! # let handler = Handler;
|
//! # let handler = Handler;
|
||||||
//! let address = "127.0.0.1:12345"; // Or, if using the 'unix' feature, "/tmp/eagle.sock"
|
//! let address = "127.0.0.1:12345"; // Or, if using the 'unix' feature, "/tmp/eagle.sock"
|
||||||
//! # let server_task = tokio::spawn(ExampleServer::bind(handler, address));
|
//! # let server = ExampleServer::bind(handler, address).await;
|
||||||
|
//! # tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; // Wait for the server to start
|
||||||
//! let client = ExampleClient::connect(address).await.unwrap();
|
//! let client = ExampleClient::connect(address).await.unwrap();
|
||||||
//! # // Wait for the server to start, the developer is responsible for this in production
|
|
||||||
//! # tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
|
||||||
//! assert_eq!(client.add(2, 5).await.unwrap(), 7);
|
//! assert_eq!(client.add(2, 5).await.unwrap(), 7);
|
||||||
|
//! # server.close().await;
|
||||||
//! # });
|
//! # });
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
@@ -198,28 +197,28 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
// TODO: These logs should be filterable in some way
|
// TODO: These logs should be filterable in some way
|
||||||
#[cfg(feature = "log")]
|
#[cfg(feature = "log")]
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
let debug = quote! { log::debug! };
|
let debug = quote! { ::log::debug! };
|
||||||
#[cfg(feature = "log")]
|
#[cfg(feature = "log")]
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
let info = quote! { log::info! };
|
let info = quote! { ::log::info! };
|
||||||
#[cfg(feature = "log")]
|
#[cfg(feature = "log")]
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
let warn = quote! { log::warn! };
|
let warn = quote! { ::log::warn! };
|
||||||
#[cfg(feature = "log")]
|
#[cfg(feature = "log")]
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
let error = quote! { log::error! };
|
let error = quote! { ::log::error! };
|
||||||
#[cfg(not(feature = "log"))]
|
#[cfg(not(feature = "log"))]
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
let debug = quote! { eprintln! };
|
let debug = quote! { ::std::eprintln! };
|
||||||
#[cfg(not(feature = "log"))]
|
#[cfg(not(feature = "log"))]
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
let info = quote! { eprintln! };
|
let info = quote! { ::std::eprintln! };
|
||||||
#[cfg(not(feature = "log"))]
|
#[cfg(not(feature = "log"))]
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
let warn = quote! { eprintln! };
|
let warn = quote! { ::std::eprintln! };
|
||||||
#[cfg(not(feature = "log"))]
|
#[cfg(not(feature = "log"))]
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
let error = quote! { eprintln! };
|
let error = quote! { ::std::eprintln! };
|
||||||
|
|
||||||
// Must be on an enum
|
// Must be on an enum
|
||||||
let enum_ = match &input.data {
|
let enum_ = match &input.data {
|
||||||
@@ -299,8 +298,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
// There is a function that must be implemented to get the answer from the query enum
|
// There is a function that must be implemented to get the answer from the query enum
|
||||||
query_get_answer.push(quote! {
|
query_get_answer.push(quote! {
|
||||||
#query_enum_name::#var_name(_, answer) => match answer {
|
#query_enum_name::#var_name(_, answer) => match answer {
|
||||||
Some(answer) => Some(#answer_enum_name::#var_name(answer.clone())),
|
::std::option::Option::Some(answer) => ::std::option::Option::Some(#answer_enum_name::#var_name(answer.clone())),
|
||||||
None => None
|
::std::option::Option::None => ::std::option::Option::None
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
// There is a function that the server uses to call the appropriate function when receiving a query
|
// There is a function that the server uses to call the appropriate function when receiving a query
|
||||||
@@ -313,7 +312,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
});
|
});
|
||||||
// The function that the server needs to implement
|
// The function that the server needs to implement
|
||||||
server_trait.push(quote! {
|
server_trait.push(quote! {
|
||||||
fn #var_name(&mut self, #question_args) -> impl std::future::Future<Output = #answer_type> + Send;
|
fn #var_name(&mut self, #question_args) -> impl ::std::future::Future<Output = #answer_type> + Send;
|
||||||
});
|
});
|
||||||
// The function that the client uses to communicate
|
// The function that the client uses to communicate
|
||||||
client_impl.push(quote! {
|
client_impl.push(quote! {
|
||||||
@@ -323,42 +322,70 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
let answer = self.recv_until(nonce).await?;
|
let answer = self.recv_until(nonce).await?;
|
||||||
match answer {
|
match answer {
|
||||||
#answer_enum_name::#var_name(answer) => Ok(answer),
|
#answer_enum_name::#var_name(answer) => Ok(answer),
|
||||||
_ => panic!("The answer for this query is not the correct type."),
|
_ => ::std::panic!("The answer for this query is not the correct type."),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
// The query enum is the same as the source enum, but the second field is always wrapped in a Option<>
|
// The query enum is the same as the source enum, but the second field is always wrapped in a Option<>
|
||||||
query_enum.push(quote! {
|
query_enum.push(quote! {
|
||||||
#var_name(#question_field, Option<#answer_type>)
|
#var_name(#question_field, ::std::option::Option<#answer_type>)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an error and result type for sending messages
|
// Create an error and result type for sending messages
|
||||||
let error_enum = quote! {
|
let error_enum = quote! {
|
||||||
#[derive(Debug)]
|
#[derive(::std::fmt::Debug)]
|
||||||
#vis enum #error_enum_name {
|
#vis enum #error_enum_name {
|
||||||
SendError(tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>),
|
SendError(::tokio::sync::mpsc::error::SendError<(u64, #question_enum_name)>),
|
||||||
Closed,
|
Closed,
|
||||||
}
|
}
|
||||||
|
impl ::std::fmt::Display for #error_enum_name {
|
||||||
|
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
#error_enum_name::SendError(e) => write!(f, "Failed to send query: {}", e),
|
||||||
|
#error_enum_name::Closed => write!(f, "Connection closed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl ::std::error::Error for #error_enum_name {
|
||||||
|
fn source(&self) -> ::std::option::Option<&(dyn ::std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
#error_enum_name::SendError(e) => ::std::option::Option::Some(e),
|
||||||
|
#error_enum_name::Closed => ::std::option::Option::None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
#error_enum_name::SendError(_) => "Failed to send query",
|
||||||
|
#error_enum_name::Closed => "Connection closed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn cause(&self) -> ::std::option::Option<&dyn ::std::error::Error> {
|
||||||
|
match self {
|
||||||
|
#error_enum_name::SendError(e) => ::std::option::Option::Some(e),
|
||||||
|
#error_enum_name::Closed => ::std::option::Option::None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
// Create enums for the types of messages the server and client will use
|
// Create enums for the types of messages the server and client will use
|
||||||
|
|
||||||
let answer_enum = quote! {
|
let answer_enum = quote! {
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
|
#[derive(::serde::Serialize, ::serde::Deserialize, ::std::clone::Clone, ::std::fmt::Debug)]
|
||||||
#vis enum #answer_enum_name {
|
#vis enum #answer_enum_name {
|
||||||
#(#server_enum), *,
|
#(#server_enum), *,
|
||||||
Ready
|
Ready
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let question_enum = quote! {
|
let question_enum = quote! {
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
|
#[derive(::serde::Serialize, ::serde::Deserialize, ::std::clone::Clone, ::std::fmt::Debug)]
|
||||||
#vis enum #question_enum_name {
|
#vis enum #question_enum_name {
|
||||||
#(#client_enum), *
|
#(#client_enum), *
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// Create an enum to represent the queries the client has sent
|
// Create an enum to represent the queries the client has sent
|
||||||
let query_enum = quote! {
|
let query_enum = quote! {
|
||||||
#[derive(Clone, Debug)]
|
#[derive(::std::clone::Clone, ::std::fmt::Debug)]
|
||||||
#vis enum #query_enum_name {
|
#vis enum #query_enum_name {
|
||||||
#(#query_enum), *
|
#(#query_enum), *
|
||||||
}
|
}
|
||||||
@@ -368,7 +395,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
#(#query_set_answer)*
|
#(#query_set_answer)*
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
pub fn get_answer(&self) -> Option<#answer_enum_name> {
|
pub fn get_answer(&self) -> ::std::option::Option<#answer_enum_name> {
|
||||||
match self {
|
match self {
|
||||||
#(#query_get_answer)*
|
#(#query_get_answer)*
|
||||||
}
|
}
|
||||||
@@ -384,17 +411,17 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(feature = "tcp")]
|
#[cfg(feature = "tcp")]
|
||||||
let stream_type = quote! { tokio::net::TcpStream };
|
let stream_type = quote! { ::tokio::net::TcpStream };
|
||||||
#[cfg(feature = "tcp")]
|
#[cfg(feature = "tcp")]
|
||||||
let stream_addr_trait = quote! { tokio::net::ToSocketAddrs };
|
let stream_addr_trait = quote! { ::tokio::net::ToSocketAddrs };
|
||||||
#[cfg(feature = "tcp")]
|
#[cfg(feature = "tcp")]
|
||||||
let listener_type = quote! { tokio::net::TcpListener };
|
let listener_type = quote! { ::tokio::net::TcpListener };
|
||||||
#[cfg(feature = "unix")]
|
#[cfg(feature = "unix")]
|
||||||
let stream_type = quote! { tokio::net::UnixStream };
|
let stream_type = quote! { ::tokio::net::UnixStream };
|
||||||
#[cfg(feature = "unix")]
|
#[cfg(feature = "unix")]
|
||||||
let stream_addr_trait = quote! { std::convert::AsRef<std::path::Path> };
|
let stream_addr_trait = quote! { ::std::convert::AsRef<std::path::Path> };
|
||||||
#[cfg(feature = "unix")]
|
#[cfg(feature = "unix")]
|
||||||
let listener_type = quote! { tokio::net::UnixListener };
|
let listener_type = quote! { ::tokio::net::UnixListener };
|
||||||
|
|
||||||
// Create a trait which the server will have to implement
|
// Create a trait which the server will have to implement
|
||||||
let server_trait = quote! {
|
let server_trait = quote! {
|
||||||
@@ -414,37 +441,44 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
};
|
};
|
||||||
let sc_struct = quote! {
|
let sc_struct = quote! {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
#vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + Clone + 'static> {
|
#vis struct #server_connection_struct_name<H: #server_trait_name + ::std::marker::Send + ::std::clone::Clone + 'static> {
|
||||||
handler: ::std::sync::Arc<tokio::sync::Mutex<H>>,
|
handler: ::std::sync::Arc<::tokio::sync::Mutex<H>>,
|
||||||
tasks: ::std::sync::Arc<tokio::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>>,
|
tasks: ::std::sync::Arc<::tokio::sync::Mutex<::std::vec::Vec<tokio::task::JoinHandle<()>>>>,
|
||||||
}
|
}
|
||||||
impl<H: #server_trait_name + ::std::marker::Send + Clone + 'static> #server_connection_struct_name<H> {
|
impl<H: #server_trait_name + ::std::marker::Send + std::clone::Clone + 'static> #server_connection_struct_name<H> {
|
||||||
pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + std::fmt::Display + 'static>(handler: H, addr: A) -> Self {
|
pub async fn bind<A: #stream_addr_trait + ::std::marker::Send + ::std::fmt::Display + 'static>(handler: H, addr: A) -> Self {
|
||||||
#info("Binding server to address {}", addr);
|
#info("Binding server to address {}", addr);
|
||||||
let handler = ::std::sync::Arc::new(tokio::sync::Mutex::new(handler));
|
let handler = ::std::sync::Arc::new(::tokio::sync::Mutex::new(handler));
|
||||||
let tasks = ::std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
|
let tasks = ::std::sync::Arc::new(::tokio::sync::Mutex::new(::std::vec::Vec::new()));
|
||||||
let sc = Self {
|
let sc = Self {
|
||||||
handler,
|
handler,
|
||||||
tasks,
|
tasks,
|
||||||
};
|
};
|
||||||
let sc_clone = sc.clone();
|
let sc_clone = sc.clone();
|
||||||
let acc_task = tokio::spawn(async move {
|
let acc_task = ::tokio::spawn(async move {
|
||||||
sc_clone.accept_connections(addr).await.expect("Failed to accept connections!");
|
sc_clone.accept_connections(addr).await.expect("Failed to accept connections!");
|
||||||
});
|
});
|
||||||
sc.tasks.lock().await.push(acc_task);
|
sc.tasks.lock().await.push(acc_task);
|
||||||
sc
|
sc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn close(self) {
|
||||||
|
#info("Closing server");
|
||||||
|
for task in self.tasks.lock().await.drain(..) {
|
||||||
|
task.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn accept_connections<A: #stream_addr_trait>(
|
pub async fn accept_connections<A: #stream_addr_trait>(
|
||||||
&self,
|
&self,
|
||||||
addr: A,
|
addr: A,
|
||||||
) -> Result<(), std::io::Error> {
|
) -> ::std::result::Result<(), ::std::io::Error> {
|
||||||
#listener_statement
|
#listener_statement
|
||||||
loop {
|
loop {
|
||||||
let (stream, _) = listener.accept().await?;
|
let (stream, _) = listener.accept().await?;
|
||||||
#info("Accepted connection from {:?}", stream.peer_addr()?);
|
#info("Accepted connection from {:?}", stream.peer_addr()?);
|
||||||
let self_clone = self.clone();
|
let self_clone = self.clone();
|
||||||
let run_task = tokio::spawn(async move {
|
let run_task = ::tokio::spawn(async move {
|
||||||
self_clone.run(stream).await;
|
self_clone.run(stream).await;
|
||||||
});
|
});
|
||||||
self.tasks.lock().await.push(run_task);
|
self.tasks.lock().await.push(run_task);
|
||||||
@@ -458,7 +492,7 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) {
|
async fn send(&self, stream: &mut #stream_type, nonce: u64, answer: #answer_enum_name) {
|
||||||
use tokio::io::AsyncWriteExt;
|
use ::tokio::io::AsyncWriteExt;
|
||||||
let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!");
|
let serialized = ron::ser::to_string(&(nonce, answer)).expect("Failed to serialize response!");
|
||||||
let len = serialized.len() as u32;
|
let len = serialized.len() as u32;
|
||||||
#debug("Sending `{}`", serialized);
|
#debug("Sending `{}`", serialized);
|
||||||
@@ -467,24 +501,24 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn run(&self, mut stream: #stream_type) {
|
async fn run(&self, mut stream: #stream_type) {
|
||||||
use tokio::io::AsyncWriteExt;
|
use ::tokio::io::AsyncWriteExt;
|
||||||
use tokio::io::AsyncReadExt;
|
use ::tokio::io::AsyncReadExt;
|
||||||
let mut buf = Vec::with_capacity(1024);
|
let mut buf = ::std::vec::Vec::with_capacity(1024);
|
||||||
self.send(&mut stream, 0, #answer_enum_name::Ready).await;
|
self.send(&mut stream, 0, #answer_enum_name::Ready).await;
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
::tokio::select! {
|
||||||
Ok(_) = stream.readable() => {
|
::std::result::Result::Ok(_) = stream.readable() => {
|
||||||
let mut read_buf = [0; 1024];
|
let mut read_buf = [0; 1024];
|
||||||
match stream.try_read(&mut read_buf) {
|
match stream.try_read(&mut read_buf) {
|
||||||
Ok(0) => break, // Stream closed
|
::std::result::Result::Ok(0) => break, // Stream closed
|
||||||
Ok(n) => {
|
::std::result::Result::Ok(n) => {
|
||||||
#debug("Received {} bytes (server)", n);
|
#debug("Received {} bytes (server)", n);
|
||||||
buf.extend_from_slice(&read_buf[..n]);
|
buf.extend_from_slice(&read_buf[..n]);
|
||||||
while buf.len() >= 4 {
|
while buf.len() >= 4 {
|
||||||
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
let len = u32::from_le_bytes(buf[..4].try_into().expect("Failed to convert bytes to u32"));
|
||||||
if buf.len() >= (4 + len as usize) {
|
if buf.len() >= (4 + len as usize) {
|
||||||
let serialized = std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
let serialized = ::std::str::from_utf8(&buf[4..(4 + len as usize)]).expect("Failed to convert bytes to string");
|
||||||
let (nonce, question): (u64, #question_enum_name) = ron::de::from_str(serialized).expect("Failed to deserialize query!");
|
let (nonce, question): (u64, #question_enum_name) = ::ron::de::from_str(serialized).expect("Failed to deserialize query!");
|
||||||
let answer = self.handle(question).await;
|
let answer = self.handle(question).await;
|
||||||
self.send(&mut stream, nonce, answer).await;
|
self.send(&mut stream, nonce, answer).await;
|
||||||
buf.drain(0..(4 + len as usize));
|
buf.drain(0..(4 + len as usize));
|
||||||
@@ -493,8 +527,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
::std::result::Result::Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
||||||
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
::std::result::Result::Err(e) => ::std::eprintln!("Error reading from stream: {:?}", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -522,12 +556,12 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
self.queries.lock().unwrap().insert(nonce, query);
|
self.queries.lock().unwrap().insert(nonce, query);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, nonce: &u64) -> Option<#query_enum_name> {
|
pub fn get(&self, nonce: &u64) -> ::std::option::Option<#query_enum_name> {
|
||||||
self.queries.lock().unwrap().get(nonce).cloned()
|
self.queries.lock().unwrap().get(nonce).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) {
|
pub fn set_answer(&self, nonce: u64, answer: #answer_enum_name) {
|
||||||
if let Some(query) = self.queries.lock().unwrap().get_mut(&nonce) {
|
if let ::std::option::Option::Some(query) = self.queries.lock().unwrap().get_mut(&nonce) {
|
||||||
query.set_answer(answer);
|
query.set_answer(answer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -541,16 +575,16 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
// Create a struct to handle the connection from the client to the server
|
// Create a struct to handle the connection from the client to the server
|
||||||
let cc_struct = quote! {
|
let cc_struct = quote! {
|
||||||
struct #client_connection_struct_name {
|
struct #client_connection_struct_name {
|
||||||
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
to_send: ::tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
||||||
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
received: ::tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
||||||
ready: std::sync::Arc<tokio::sync::Notify>,
|
ready: ::std::sync::Arc<tokio::sync::Notify>,
|
||||||
stream: #stream_type,
|
stream: #stream_type,
|
||||||
}
|
}
|
||||||
impl #client_connection_struct_name {
|
impl #client_connection_struct_name {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
to_send: ::tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>,
|
||||||
received: tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
received: ::tokio::sync::mpsc::Sender<(u64, #answer_enum_name)>,
|
||||||
ready: std::sync::Arc<tokio::sync::Notify>,
|
ready: ::std::sync::Arc<::tokio::sync::Notify>,
|
||||||
stream: #stream_type,
|
stream: #stream_type,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -562,23 +596,23 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(mut self) {
|
pub async fn run(mut self) {
|
||||||
use tokio::io::AsyncWriteExt;
|
use ::tokio::io::AsyncWriteExt;
|
||||||
use tokio::io::AsyncReadExt;
|
use ::tokio::io::AsyncReadExt;
|
||||||
let mut buf = Vec::with_capacity(1024);
|
let mut buf = ::std::vec::Vec::with_capacity(1024);
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
::tokio::select! {
|
||||||
Some(msg) = self.to_send.recv() => {
|
::std::option::Option::Some(msg) = self.to_send.recv() => {
|
||||||
let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!");
|
let serialized = ron::ser::to_string(&msg).expect("Failed to serialize query!");
|
||||||
let len = serialized.len() as u32;
|
let len = serialized.len() as u32;
|
||||||
#debug("Sending `{}`", serialized);
|
#debug("Sending `{}`", serialized);
|
||||||
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
self.stream.write_all(&len.to_le_bytes()).await.expect("Failed to write length!");
|
||||||
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!");
|
self.stream.write_all(serialized.as_bytes()).await.expect("Failed to write query!");
|
||||||
},
|
},
|
||||||
Ok(_) = self.stream.readable() => {
|
::std::result::Result::Ok(_) = self.stream.readable() => {
|
||||||
let mut read_buf = [0; 1024];
|
let mut read_buf = [0; 1024];
|
||||||
match self.stream.try_read(&mut read_buf) {
|
match self.stream.try_read(&mut read_buf) {
|
||||||
Ok(0) => { break; },
|
::std::result::Result::Ok(0) => { break; },
|
||||||
Ok(n) => {
|
::std::result::Result::Ok(n) => {
|
||||||
#debug("Received {} bytes (client)", n);
|
#debug("Received {} bytes (client)", n);
|
||||||
buf.extend_from_slice(&read_buf[..n]);
|
buf.extend_from_slice(&read_buf[..n]);
|
||||||
while buf.len() >= 4 {
|
while buf.len() >= 4 {
|
||||||
@@ -598,8 +632,8 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
::std::result::Result::Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => { continue; },
|
||||||
Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
::std::result::Result::Err(e) => eprintln!("Error reading from stream: {:?}", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -610,33 +644,33 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
// Create a struct which the client will use to communicate
|
// Create a struct which the client will use to communicate
|
||||||
let client_recv_queue_wrapper = format_ident!("__{}RecvQueueWrapper", name);
|
let client_recv_queue_wrapper = format_ident!("__{}RecvQueueWrapper", name);
|
||||||
let client_struct = quote! {
|
let client_struct = quote! {
|
||||||
#[derive(Clone)]
|
#[derive(::std::clone::Clone)]
|
||||||
struct #client_recv_queue_wrapper {
|
struct #client_recv_queue_wrapper {
|
||||||
recv_queue: ::std::sync::Arc<::tokio::sync::Mutex<tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>>>,
|
recv_queue: ::std::sync::Arc<::tokio::sync::Mutex<::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>>>,
|
||||||
}
|
}
|
||||||
impl #client_recv_queue_wrapper {
|
impl #client_recv_queue_wrapper {
|
||||||
fn new(recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self {
|
fn new(recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
recv_queue: ::std::sync::Arc::new(::tokio::sync::Mutex::new(recv_queue)),
|
recv_queue: ::std::sync::Arc::new(::tokio::sync::Mutex::new(recv_queue)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
async fn recv(&self) -> Option<(u64, #answer_enum_name)> {
|
async fn recv(&self) -> ::std::option::Option<(u64, #answer_enum_name)> {
|
||||||
self.recv_queue.lock().await.recv().await
|
self.recv_queue.lock().await.recv().await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
#vis struct #client_struct_name {
|
#vis struct #client_struct_name {
|
||||||
queries: #queries_struct_name,
|
queries: #queries_struct_name,
|
||||||
send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||||
recv_queue: #client_recv_queue_wrapper,
|
recv_queue: #client_recv_queue_wrapper,
|
||||||
ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>,
|
ready: ::std::sync::Arc<tokio::sync::Mutex<bool>>,
|
||||||
ready_notify: ::std::sync::Arc<tokio::sync::Notify>,
|
ready_notify: ::std::sync::Arc<tokio::sync::Notify>,
|
||||||
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
connection_task: ::std::option::Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
||||||
}
|
}
|
||||||
impl #client_struct_name {
|
impl #client_struct_name {
|
||||||
pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
pub fn new(send_queue: ::tokio::sync::mpsc::Sender<(u64, #question_enum_name)>,
|
||||||
recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
|
recv_queue: ::tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>,
|
||||||
connection_task: Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
connection_task: ::std::option::Option<::std::sync::Arc<tokio::task::JoinHandle<()>>>,
|
||||||
ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self {
|
ready_notify: ::std::sync::Arc<tokio::sync::Notify>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
queries: #queries_struct_name::new(),
|
queries: #queries_struct_name::new(),
|
||||||
@@ -650,19 +684,19 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
pub async fn connect<A: #stream_addr_trait + ::std::fmt::Display>(addr: A) -> Result<Self, std::io::Error> {
|
pub async fn connect<A: #stream_addr_trait + ::std::fmt::Display>(addr: A) -> Result<Self, std::io::Error> {
|
||||||
#info("Connecting to server at address {}", addr);
|
#info("Connecting to server at address {}", addr);
|
||||||
let stream = #stream_type::connect(addr).await?;
|
let stream = #stream_type::connect(addr).await?;
|
||||||
let (send_queue, to_send) = tokio::sync::mpsc::channel(16);
|
let (send_queue, to_send) = ::tokio::sync::mpsc::channel(16);
|
||||||
let (to_recv, recv_queue) = tokio::sync::mpsc::channel(16);
|
let (to_recv, recv_queue) = ::tokio::sync::mpsc::channel(16);
|
||||||
let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new());
|
let ready_notify = ::std::sync::Arc::new(tokio::sync::Notify::new());
|
||||||
let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream);
|
let connection = #client_connection_struct_name::new(to_send, to_recv, ready_notify.clone(), stream);
|
||||||
let connection_task = tokio::spawn(connection.run());
|
let connection_task = ::tokio::spawn(connection.run());
|
||||||
Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task)), ready_notify))
|
Ok(Self::new(send_queue, recv_queue, ::std::option::Option::Some(::std::sync::Arc::new(connection_task)), ready_notify))
|
||||||
}
|
}
|
||||||
pub fn close(self) {
|
pub fn close(&mut self) {
|
||||||
if let Some(task) = self.connection_task {
|
if let ::std::option::Option::Some(task) = self.connection_task.take() {
|
||||||
task.abort();
|
task.abort();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
async fn send(&self, query: #question_enum_name) -> Result<u64, #error_enum_name> {
|
async fn send(&self, query: #question_enum_name) -> ::std::result::Result<u64, #error_enum_name> {
|
||||||
// Wait until the connection is ready
|
// Wait until the connection is ready
|
||||||
if !*self.ready.lock().await {
|
if !*self.ready.lock().await {
|
||||||
self.ready_notify.notified().await;
|
self.ready_notify.notified().await;
|
||||||
@@ -672,33 +706,38 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream
|
|||||||
let nonce = self.queries.len() as u64;
|
let nonce = self.queries.len() as u64;
|
||||||
let res = self.send_queue.send((nonce, query.clone())).await;
|
let res = self.send_queue.send((nonce, query.clone())).await;
|
||||||
match res {
|
match res {
|
||||||
Ok(_) => {
|
::std::result::Result::Ok(_) => {
|
||||||
self.queries.insert(nonce, query.into());
|
self.queries.insert(nonce, query.into());
|
||||||
Ok(nonce)
|
::std::result::Result::Ok(nonce)
|
||||||
}
|
}
|
||||||
Err(e) => Err(#error_enum_name::SendError(e)),
|
::std::result::Result::Err(e) => ::std::result::Result::Err(#error_enum_name::SendError(e)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
async fn recv_until(&self, id: u64) -> Result<#answer_enum_name, #error_enum_name> {
|
async fn recv_until(&self, id: u64) -> ::std::result::Result<#answer_enum_name, #error_enum_name> {
|
||||||
loop {
|
loop {
|
||||||
// Check if we've received the answer for the query we're looking for
|
// Check if we've received the answer for the query we're looking for
|
||||||
if let Some(query) = self.queries.get(&id) {
|
if let ::std::option::Option::Some(query) = self.queries.get(&id) {
|
||||||
if let Some(answer) = query.get_answer() {
|
if let ::std::option::Option::Some(answer) = query.get_answer() {
|
||||||
#info("Found answer for query {}", id);
|
#info("Found answer for query {}", id);
|
||||||
return Ok(answer);
|
return Ok(answer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
match self.recv_queue.recv().await {
|
match self.recv_queue.recv().await {
|
||||||
Some((nonce, answer)) => {
|
::std::option::Option::Some((nonce, answer)) => {
|
||||||
#info("Received answer for query {}", nonce);
|
#info("Received answer for query {}", nonce);
|
||||||
self.queries.set_answer(nonce, answer.clone());
|
self.queries.set_answer(nonce, answer.clone());
|
||||||
}
|
}
|
||||||
None => return Err(#error_enum_name::Closed),
|
::std::option::Option::None => return ::std::result::Result::Err(#error_enum_name::Closed),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#(#client_impl)*
|
#(#client_impl)*
|
||||||
}
|
}
|
||||||
|
impl ::std::ops::Drop for #client_struct_name {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let expanded = quote! {
|
let expanded = quote! {
|
||||||
|
|||||||
@@ -73,9 +73,8 @@ async fn e2e() {
|
|||||||
};
|
};
|
||||||
#[cfg(feature = "tcp")]
|
#[cfg(feature = "tcp")]
|
||||||
let address = format!("127.0.0.1:{}", 10000 + rand::random::<u64>() % 1000);
|
let address = format!("127.0.0.1:{}", 10000 + rand::random::<u64>() % 1000);
|
||||||
let server_task = tokio::spawn(TestProtocolServer::bind(TrivialServer, address.clone()));
|
let server = TestProtocolServer::bind(TrivialServer, address.clone()).await;
|
||||||
// Wait for the server to start, the developer is responsible for this in production
|
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; // Wait for the server to start
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
|
||||||
let client = TestProtocolClient::connect(address).await.unwrap();
|
let client = TestProtocolClient::connect(address).await.unwrap();
|
||||||
assert_eq!(client.addition(2, 5).await.unwrap(), 7);
|
assert_eq!(client.addition(2, 5).await.unwrap(), 7);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -90,5 +89,5 @@ async fn e2e() {
|
|||||||
"The number is 42"
|
"The number is 42"
|
||||||
);
|
);
|
||||||
client.void().await.unwrap();
|
client.void().await.unwrap();
|
||||||
server_task.abort();
|
server.close().await;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user