From 15d44b6daa2244dfc7f25d7a11d203a1b49e4354 Mon Sep 17 00:00:00 2001 From: Kodi Craft Date: Sat, 22 Jun 2024 12:57:41 +0200 Subject: [PATCH] Add code to create a client connected to the network --- .gitea/workflows/build.yaml | 12 ++++++++---- Cargo.toml | 10 +++++++++- flake.nix | 20 ++++++++++++++++---- src/lib.rs | 36 ++++++++++++++++++++++++++++++++++-- tests/mod.rs | 4 ++-- 5 files changed, 69 insertions(+), 13 deletions(-) diff --git a/.gitea/workflows/build.yaml b/.gitea/workflows/build.yaml index cab32f9..68a3d8f 100644 --- a/.gitea/workflows/build.yaml +++ b/.gitea/workflows/build.yaml @@ -7,7 +7,11 @@ jobs: steps: - name: Checkout uses: actions/checkout@v3 - - name: Run clippy - run: nix develop . -c cargo clippy - - name: Build & test - run: nix build \ No newline at end of file + - name: Run clippy (tcp feature) + run: nix develop . -c cargo clippy --no-default-features --features tcp + - name: Run clippy (unix feature) + run: nix develop . -c cargo clippy --no-default-features --features unix + - name: Build & test (tcp feature) + run: nix build .#tcp + - name: Build & test (unix feature) + run: nix build .#unix \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 08c8d66..446e49e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,12 @@ version = "0.2.0" edition = "2021" license = "AGPL-3.0" publish = ["gitea"] +resolver = "2" + +[features] +default = ["tcp"] +tcp = ["tokio/net"] +unix = ["tokio/net"] [dependencies] proc-macro2 = "1.0.85" @@ -12,7 +18,9 @@ rand = "0.8.5" ron = "0.8.1" serde = { version = "1.0.203", features = ["serde_derive"] } syn = "2.0.66" -# TODO: rt and macros should be removed unless we do tests +tokio = { version = "1.38.0", features = ["sync", "io-util"] } + +[dev-dependencies] tokio = { version = "1.38.0", features = ["sync", "rt-multi-thread", "macros", "time", "io-util", "net"] } [lib] diff --git a/flake.nix b/flake.nix index 5830806..05b8334 100644 --- a/flake.nix +++ b/flake.nix @@ -14,10 +14,22 @@ naersk-lib = pkgs.callPackage naersk {}; in { - defaultPackage = naersk-lib.buildPackage { - src = ./.; - doCheck = true; - }; + packages = { + default = naersk-lib.buildPackage { + src = ./.; + doCheck = true; + }; + unix = naersk-lib.buildPackage { + src = ./.; + doCheck = true; + cargoOptions = x: x ++ ["--no-default-features" "--features" "unix"]; + }; + tcp = naersk-lib.buildPackage { + src = ./.; + doCheck = true; + cargoOptions = x: x ++ ["--no-default-features" "--features" "tcp"]; + }; + }; devShell = with pkgs; mkShell { nativeBuildInputs = with pkgs; [rustc cargo rustfmt pre-commit clippy cargo-expand]; }; diff --git a/src/lib.rs b/src/lib.rs index 3061801..2cfba2f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,13 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::{parse2, spanned::Spanned, DeriveInput, Field, Ident}; +#[cfg(all(feature = "tcp", feature = "unix"))] +compile_error!("You can only enable one of the 'tcp' or 'unix' features"); +#[cfg(all(not(feature = "tcp"), not(feature = "unix")))] +compile_error!("You must enable either the 'tcp' or 'unix' feature"); +#[cfg(all(feature = "unix", not(unix)))] +compile_error!("The 'unix' feature requires compiling for a unix target"); + #[proc_macro_derive(Protocol)] pub fn derive_protocol_derive(input: TokenStream) -> TokenStream { let expanded = derive_protocol(input.into()); @@ -212,7 +219,15 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream } }; // Create a struct to handle the connection from the client to the server - let stream_type = quote! { tokio::net::TcpStream }; // TODO: In the future we could support other stream types + #[cfg(feature = "tcp")] + let stream_type = quote! { tokio::net::TcpStream }; + #[cfg(feature = "tcp")] + let stream_addr_trait = quote! { tokio::net::ToSocketAddrs }; + #[cfg(feature = "unix")] + let stream_type = quote! { tokio::net::UnixStream }; + #[cfg(feature = "unix")] + let stream_addr_trait = quote! { std::convert::AsRef }; + let cc_struct = quote! { struct #client_connection_struct_name { to_send: tokio::sync::mpsc::Receiver<(u64, #question_enum_name)>, @@ -286,13 +301,30 @@ fn derive_protocol(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream queries: #queries_struct_name, send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: #client_recv_queue_wrapper, + connection_task: Option<::std::sync::Arc>>, } impl #client_struct_name { - pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>) -> Self { + pub fn new(send_queue: tokio::sync::mpsc::Sender<(u64, #question_enum_name)>, + recv_queue: tokio::sync::mpsc::Receiver<(u64, #answer_enum_name)>, + connection_task: Option<::std::sync::Arc>>) -> Self { Self { queries: #queries_struct_name::new(), recv_queue: #client_recv_queue_wrapper::new(recv_queue), send_queue, + connection_task, + } + } + pub async fn connect(addr: A) -> Result { + let stream = #stream_type::connect(addr).await?; + let (send_queue, to_send) = tokio::sync::mpsc::channel(16); + let (to_recv, recv_queue) = tokio::sync::mpsc::channel(16); + let connection = #client_connection_struct_name::new(to_send, to_recv, stream); + let connection_task = tokio::spawn(connection.run()); + Ok(Self::new(send_queue, recv_queue, Some(::std::sync::Arc::new(connection_task)))) + } + pub fn close(self) { + if let Some(task) = self.connection_task { + task.abort(); } } async fn send(&self, query: #question_enum_name) -> Result { diff --git a/tests/mod.rs b/tests/mod.rs index 75b1838..f56c77b 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -30,7 +30,7 @@ enum TestProtocol { async fn main() { let (qtx, qrx) = mpsc::channel(16); let (atx, arx) = mpsc::channel(16); - let client = TestProtocolClient::new(qtx, arx); + let client = TestProtocolClient::new(qtx, arx, None); let server = tokio::spawn(server_loop(qrx, atx)); let result = client.addition(2, 5).await.unwrap(); assert_eq!(result, 7); @@ -91,7 +91,7 @@ async fn server_loop( async fn heavy_async() { let (qtx, qrx) = mpsc::channel(16); let (atx, arx) = mpsc::channel(16); - let client = TestProtocolClient::new(qtx, arx); + let client = TestProtocolClient::new(qtx, arx, None); let server = tokio::spawn(server_loop(qrx, atx)); let mut tasks = Vec::new(); for i in 0..100 {