From 0d23ac1db73fbe70948937b5089b15a1a3333ca5 Mon Sep 17 00:00:00 2001 From: Troy Benson Date: Mon, 18 Nov 2024 19:19:46 +0000 Subject: [PATCH] settings and signal --- Cargo.lock | 645 +++++++++++++++++- Cargo.toml | 3 + crates/batching/Cargo.toml | 8 + crates/batching/src/batch.rs | 251 +++++++ crates/batching/src/dataloader.rs | 254 +++++++ crates/batching/src/lib.rs | 5 + crates/context/src/lib.rs | 98 ++- crates/foundations/src/batcher/dataloader.rs | 309 --------- crates/foundations/src/batcher/mod.rs | 448 ------------ crates/foundations/src/bootstrap.rs | 4 - crates/foundations/src/http/mod.rs | 1 - crates/foundations/src/http/server/builder.rs | 37 - crates/foundations/src/http/server/mod.rs | 282 -------- .../foundations/src/http/server/stream/mod.rs | 152 ----- .../src/http/server/stream/quic.rs | 544 --------------- .../foundations/src/http/server/stream/tcp.rs | 159 ----- .../foundations/src/http/server/stream/tls.rs | 251 ------- crates/foundations/src/lib.rs | 33 - crates/foundations/src/settings/cli.rs | 49 -- crates/foundations/src/settings/traits.rs | 154 ----- crates/foundations/src/signal.rs | 40 -- crates/h3-webtransport/src/session.rs | 6 +- crates/http/src/backend/tcp/config.rs | 8 +- crates/settings/Cargo.toml | 29 + crates/settings/src/lib.rs | 235 +++++++ crates/signal/Cargo.toml | 13 + crates/signal/src/lib.rs | 125 ++++ examples/settings/Cargo.toml | 19 + examples/settings/src/cli.rs | 16 + rustfmt.toml | 2 +- 30 files changed, 1681 insertions(+), 2499 deletions(-) create mode 100644 crates/batching/Cargo.toml create mode 100644 crates/batching/src/batch.rs create mode 100644 crates/batching/src/dataloader.rs create mode 100644 crates/batching/src/lib.rs delete mode 100644 crates/foundations/src/batcher/dataloader.rs delete mode 100644 crates/foundations/src/batcher/mod.rs delete mode 100644 crates/foundations/src/http/mod.rs delete mode 100644 crates/foundations/src/http/server/builder.rs delete mode 100644 
crates/foundations/src/http/server/mod.rs delete mode 100644 crates/foundations/src/http/server/stream/mod.rs delete mode 100644 crates/foundations/src/http/server/stream/quic.rs delete mode 100644 crates/foundations/src/http/server/stream/tcp.rs delete mode 100644 crates/foundations/src/http/server/stream/tls.rs delete mode 100644 crates/foundations/src/settings/traits.rs delete mode 100644 crates/foundations/src/signal.rs create mode 100644 crates/settings/Cargo.toml create mode 100644 crates/settings/src/lib.rs create mode 100644 crates/signal/Cargo.toml create mode 100644 crates/signal/src/lib.rs create mode 100644 examples/settings/Cargo.toml create mode 100644 examples/settings/src/cli.rs diff --git a/Cargo.lock b/Cargo.lock index f0764c840..0188a2214 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,18 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -26,6 +38,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -50,6 +68,67 @@ dependencies = [ "winapi", ] +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + 
"is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +dependencies = [ + "anstyle", + "windows-sys 0.59.0", +] + +[[package]] +name = "anyhow" +version = "1.0.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" + +[[package]] +name = "arraydeque" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236" + [[package]] name = "async-trait" version = "0.1.83" @@ -170,6 +249,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "bindgen" version = "0.69.5" @@ -198,6 +283,43 @@ name = "bitflags" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +dependencies = [ + "serde", +] + +[[package]] +name = 
"block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bon" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65b4408cb90c75f462c2428254f2a687c399d1feb22ebdc7511889d07be6cab0" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfc5494814aa273050386f95a7d1fa36dcc677fc796bffe54f34659678335bc0" +dependencies = [ + "darling", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn", +] [[package]] name = "bumpalo" @@ -280,6 +402,33 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "4.5.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_lex" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" + [[package]] name = "cmake" version = "0.1.51" @@ -289,6 +438,12 @@ dependencies = [ "cc", ] +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + [[package]] name = "combine" version = "4.6.7" @@ -299,6 +454,43 @@ dependencies = [ "memchr", ] +[[package]] +name = 
"config" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68578f196d2a33ff61b27fae256c3164f65e36382648e30666dde05b8cc9dfdf" +dependencies = [ + "json5", + "nom", + "pathdiff", + "ron", + "rust-ini", + "serde", + "serde_json", + "toml", + "yaml-rust2", +] + +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom", + "once_cell", + "tiny-keccak", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -315,6 +507,66 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" +dependencies = [ + "libc", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + 
"darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn", +] + [[package]] name = "derive_more" version = "1.0.0" @@ -336,6 +588,25 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "dlv-list" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "442039f5147480ba31067cb00ada1adae6892028e40e45fc5de7b7df6dcc1b5f" +dependencies = [ + "const-random", +] + [[package]] name = "dunce" version = "1.0.5" @@ -348,6 +619,15 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -493,6 +773,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] 
+ [[package]] name = "getrandom" version = "0.2.15" @@ -561,12 +851,31 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] + [[package]] name = "hashbrown" version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +[[package]] +name = "hashlink" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -686,6 +995,12 @@ dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "indexmap" version = "2.6.0" @@ -693,9 +1008,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.0", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.12.1" @@ -749,6 +1070,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "json5" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b0db21af676c1ce64250b5f40f3ce2cf27e4e47cb91ed91eb6fe9350b430c1" +dependencies = [ + "pest", + "pest_derive", + "serde", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ 
-832,6 +1164,18 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minijinja" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c37e1b517d1dcd0e51dc36c4567b9d5a29262b3ec8da6cb5d35e27a8fb529b5" +dependencies = [ + "aho-corasick", + "percent-encoding", + "serde", + "serde_json", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -934,6 +1278,16 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "ordered-multimap" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49203cdcae0030493bad186b28da2fa25645fa276a51b6fec8010d281e02ef79" +dependencies = [ + "dlv-list", + "hashbrown 0.14.5", +] + [[package]] name = "overload" version = "0.1.1" @@ -961,12 +1315,63 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pathdiff" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61c5ce1153ab5b689d0c074c4e7fc613e942dfb7dd9eea5ab202d2ad91fe361" + [[package]] name = "percent-encoding" version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pest" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879952a81a83930934cbf1786752d6dedc3b1f29e8f8fb2ad1d0a36f377cf442" +dependencies = [ + "memchr", + "thiserror 1.0.66", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d214365f632b123a47fd913301e14c946c61d1c183ee245fa76eb752e59a02dd" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb55586734301717aea2ac313f50b2eb8f60d2fc3dc01d190eefa2e625f60c4e" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pest_meta" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b75da2a70cf4d9cb76833c990ac9cd3923c9a8905a8929789ce347c84564d03d" +dependencies = [ + "once_cell", + "pest", + "sha2", +] + [[package]] name = "pin-project-lite" version = "0.2.15" @@ -1156,6 +1561,28 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ron" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94" +dependencies = [ + "base64", + "bitflags", + "serde", + "serde_derive", +] + +[[package]] +name = "rust-ini" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e0698206bcb8882bf2a9ecb4c1e7785db57ff052297085a6efd4fe42302068a" +dependencies = [ + "cfg-if", + "ordered-multimap", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -1306,6 +1733,14 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scuffle-batching" +version = "0.1.0" +dependencies = [ + "tokio", + "tokio-util", +] + [[package]] name = "scuffle-context" version = "0.1.0" @@ -1316,6 +1751,18 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "scuffle-foundations" +version = "0.0.0" +dependencies = [ + "anyhow", + "http", + "hyper", + "libc", + "tokio", + "tower-async", +] + [[package]] name = "scuffle-h3-webtransport" version = "0.1.0" @@ -1366,6 
+1813,30 @@ dependencies = [ "tracing-subscriber 0.3.18", ] +[[package]] +name = "scuffle-settings" +version = "0.1.0" +dependencies = [ + "bon", + "clap", + "config", + "minijinja", + "serde", + "serde_derive", + "thiserror 2.0.0", + "tracing", + "tracing-subscriber 0.3.18", +] + +[[package]] +name = "scuffle-signal" +version = "0.1.0" +dependencies = [ + "futures", + "libc", + "tokio", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -1432,6 +1903,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1444,6 +1924,27 @@ dependencies = [ "serde", ] +[[package]] +name = "settings-examples" +version = "0.1.0" +dependencies = [ + "scuffle-settings", + "serde", + "serde_derive", + "smart-default", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1459,6 +1960,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + [[package]] name = "slab" version = "0.4.9" @@ -1483,6 +1993,17 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = 
"smart-default" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eb01866308440fc64d6c44d9e86c5cc17adfe33c4d6eed55da9145044d0ffc1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "socket2" version = "0.5.7" @@ -1508,6 +2029,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -1587,6 +2114,15 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -1604,15 +2140,16 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.41.0" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", "libc", "mio", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -1664,6 +2201,40 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + [[package]] name = "tower" version = "0.5.1" @@ -1680,6 +2251,30 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-async" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d4b9f1f79aa6c44c843b7c52f55e0a4aa3f39f16fcf9b57e5f6ee6e9b92253" +dependencies = [ + "futures-core", + "futures-util", + "tower-async-layer", + "tower-async-service", +] + +[[package]] +name = "tower-async-layer" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640d292da3a02313e994a5b2d779e7ebe37dc5bf1493df57d92f790d90b3239" + +[[package]] +name = "tower-async-service" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d36d33ce13e35cb9b26ef69ec041f1a1a4c05a101064435b147fe612074a70af" + [[package]] name = "tower-layer" version = "0.3.3" @@ -1787,6 +2382,18 @@ dependencies = [ "tracing-log 0.2.0", ] +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + [[package]] name = "unicode-ident" version = "1.0.13" @@ -1805,12 +2412,24 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "valuable" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "walkdir" version = "2.5.0" @@ -2025,6 +2644,26 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +dependencies = [ + "memchr", +] + +[[package]] +name = "yaml-rust2" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8902160c4e6f2fb145dbe9d6760a75e3c9522d8bf796ed7047c85919ac7115f8" +dependencies = [ + "arraydeque", + "encoding_rs", + "hashlink", +] + [[package]] name = "zerocopy" version = "0.7.35" diff --git a/Cargo.toml b/Cargo.toml index 83bdad835..d7ab93296 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,9 @@ members = [ "crates/context", "crates/http", "crates/h3-webtransport", + "crates/foundations", + "crates/signal", + "crates/batching", "crates/settings", "examples/settings", ] resolver = "2" diff --git a/crates/batching/Cargo.toml b/crates/batching/Cargo.toml new file mode 100644 index 000000000..03340a76b --- /dev/null +++ b/crates/batching/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "scuffle-batching" +version = "0.1.0" +edition = 
"2021" + +[dependencies] +tokio = { version = "1", default-features = false, features = ["time", "sync", "rt"] } +tokio-util = "0.7" diff --git a/crates/batching/src/batch.rs b/crates/batching/src/batch.rs new file mode 100644 index 000000000..770c00f9a --- /dev/null +++ b/crates/batching/src/batch.rs @@ -0,0 +1,251 @@ +use std::future::Future; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; + +use tokio::sync::oneshot; + +/// A response to a batch request +pub struct BatchResponse { + send: oneshot::Sender, +} + +impl BatchResponse { + /// Create a new batch response + #[must_use] + pub fn new(send: oneshot::Sender) -> Self { + Self { send } + } + + /// Send a response back to the requester + #[inline(always)] + pub fn send(self, response: Resp) { + let _ = self.send.send(response); + } + + /// Send a successful response back to the requester + #[inline(always)] + pub fn send_ok(self, response: O) + where + Resp: From>, + { + self.send(Ok(response).into()) + } + + /// Send an error response back to the requestor + #[inline(always)] + pub fn send_err(self, error: E) + where + Resp: From>, + { + self.send(Err(error).into()) + } +} + +/// A trait for executing batches +pub trait BatchExecutor { + /// The incoming request type + type Request: Send + 'static; + /// The outgoing response type + type Response: Send + Sync + 'static; + + /// Execute a batch of requests + /// You must call `send` on the `BatchResponse` to send the response back to + /// the client + fn execute(&self, requests: Vec<(Self::Request, BatchResponse)>) -> impl Future + Send; +} + +/// A builder for a [`Batcher`] +#[derive(Clone, Copy, Debug)] +#[must_use = "builders must be used to create a batcher"] +pub struct BatcherBuilder { + batch_size: usize, + delay: std::time::Duration, +} + +impl Default for BatcherBuilder { + #[must_use] + fn default() -> Self { + Self::new() + } +} + +impl BatcherBuilder { + /// Create a new builder + pub fn new() -> Self { + Self { + batch_size: 100, + 
delay: std::time::Duration::from_millis(50), + } + } + + /// Set the batch size + pub fn batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the delay + pub fn delay(mut self, delay: std::time::Duration) -> Self { + self.delay = delay; + self + } + + /// Set the batch size + pub fn with_batch_size(&mut self, batch_size: usize) -> &mut Self { + self.batch_size = batch_size; + self + } + + /// Set the delay + pub fn with_delay(&mut self, delay: std::time::Duration) -> &mut Self { + self.delay = delay; + self + } + + /// Build the batcher + pub fn build(self, executor: E) -> Batcher + where + E: BatchExecutor + Send + Sync + 'static, + { + Batcher::new(executor, self.batch_size, self.delay) + } +} + +/// A batcher used to batch requests to a [`BatchExecutor`] +#[must_use = "batchers must be used to execute batches"] +pub struct Batcher +where + E: BatchExecutor + Send + Sync + 'static, +{ + _auto_spawn: tokio::task::JoinHandle<()>, + executor: Arc, + notify: Arc, + semaphore: Arc, + current_batch: Arc>>>, + batch_size: usize, + batch_id: AtomicU64, +} + +struct Batch +where + E: BatchExecutor + Send + Sync + 'static, +{ + id: u64, + items: Vec<(E::Request, BatchResponse)>, + _ticket: tokio::sync::OwnedSemaphorePermit, +} + +impl Batcher +where + E: BatchExecutor + Send + Sync + 'static, +{ + /// Create a new batcher + pub fn new(executor: E, batch_size: usize, delay: std::time::Duration) -> Self { + let semaphore = Arc::new(tokio::sync::Semaphore::new(batch_size.min(1))); + let notify = Arc::new(tokio::sync::Notify::new()); + let current_batch = Arc::new(tokio::sync::Mutex::new(None)); + let executor = Arc::new(executor); + + let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), notify.clone(), delay)); + + Self { + executor, + _auto_spawn: join_handle, + notify, + semaphore, + current_batch, + batch_size, + batch_id: AtomicU64::new(0), + } + } + + /// Create a builder for a [`Batcher`] 
+ pub fn builder() -> BatcherBuilder { + BatcherBuilder::new() + } + + /// Execute a single request + pub async fn execute(&self, items: E::Request) -> Option { + self.execute_many(std::iter::once(items)).await.pop()? + } + + /// Execute many requests + pub async fn execute_many(&self, items: impl IntoIterator) -> Vec> { + let mut batch = self.current_batch.lock().await; + + let mut responses = Vec::new(); + + for item in items { + if batch.is_none() { + batch.replace( + Batch::new( + self.batch_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed), + self.semaphore.clone(), + ) + .await, + ); + } + + let batch_mut = batch.as_mut().unwrap(); + let (tx, rx) = oneshot::channel(); + batch_mut.items.push((item, BatchResponse::new(tx))); + responses.push(rx); + + if batch_mut.items.len() >= self.batch_size { + batch.take().unwrap().spawn(self.executor.clone()).await; + self.notify.notify_one(); + } + } + + let mut results = Vec::new(); + for response in responses { + results.push(response.await.ok()); + } + + results + } +} + +async fn batch_loop( + executor: Arc, + current_batch: Arc>>>, + notify: Arc, + delay: std::time::Duration, +) where + E: BatchExecutor + Send + Sync + 'static, +{ + let mut pending_id = None; + loop { + tokio::time::timeout(delay, notify.notified()).await.ok(); + + let mut batch = current_batch.lock().await; + let Some(batch_id) = batch.as_ref().map(|b| b.id) else { + pending_id = None; + continue; + }; + + if pending_id != Some(batch_id) || batch.as_ref().unwrap().items.is_empty() { + pending_id = Some(batch_id); + continue; + } + + tokio::spawn(batch.take().unwrap().spawn(executor.clone())); + } +} + +impl Batch +where + E: BatchExecutor + Send + Sync + 'static, +{ + async fn new(id: u64, semaphore: Arc) -> Self { + Self { + id, + items: Vec::new(), + _ticket: semaphore.acquire_owned().await.unwrap(), + } + } + + async fn spawn(self, executor: Arc) { + executor.execute(self.items).await; + } +} diff --git a/crates/batching/src/dataloader.rs 
b/crates/batching/src/dataloader.rs new file mode 100644 index 000000000..efbbb2a7d --- /dev/null +++ b/crates/batching/src/dataloader.rs @@ -0,0 +1,254 @@ +use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; + +/// A trait for fetching data in batches +pub trait DataLoaderFetcher { + /// The incoming key type + type Key: Clone + Eq + std::hash::Hash + Send + Sync; + /// The outgoing value type + type Value: Clone + Send + Sync; + + /// Load a batch of keys + fn load(&self, keys: HashSet) -> impl Future>> + Send; +} + +/// A builder for a [`DataLoader`] +#[derive(Clone, Copy, Debug)] +#[must_use = "builders must be used to create a dataloader"] +pub struct DataLoaderBuilder { + batch_size: usize, + delay: std::time::Duration, +} + +impl Default for DataLoaderBuilder { + fn default() -> Self { + Self::new() + } +} + +impl DataLoaderBuilder { + /// Create a new builder + pub fn new() -> Self { + Self { + batch_size: 100, + delay: std::time::Duration::from_millis(50), + } + } + + /// Set the batch size + pub fn batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the delay + pub fn delay(mut self, delay: std::time::Duration) -> Self { + self.delay = delay; + self + } + + /// Build the dataloader + pub fn build(self, executor: E) -> DataLoader + where + E: DataLoaderFetcher + Send + Sync + 'static, + { + DataLoader::new(executor, self.batch_size, self.delay) + } +} + +/// A dataloader used to batch requests to a [`DataLoaderFetcher`] +#[must_use = "dataloaders must be used to load data"] +pub struct DataLoader +where + E: DataLoaderFetcher + Send + Sync + 'static, +{ + _auto_spawn: tokio::task::JoinHandle<()>, + executor: Arc, + notify: Arc, + semaphore: Arc, + current_batch: Arc>>>, + batch_size: usize, + batch_id: AtomicU64, +} + +impl DataLoader +where + E: DataLoaderFetcher + Send + Sync + 'static, +{ + /// Create a new dataloader + pub fn 
new(executor: E, batch_size: usize, delay: std::time::Duration) -> Self { + let semaphore = Arc::new(tokio::sync::Semaphore::new(batch_size.min(1))); + let notify = Arc::new(tokio::sync::Notify::new()); + let current_batch = Arc::new(tokio::sync::Mutex::new(None)); + let executor = Arc::new(executor); + + let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), notify.clone(), delay)); + + Self { + executor, + _auto_spawn: join_handle, + notify, + semaphore, + current_batch, + batch_size, + batch_id: AtomicU64::new(0), + } + } + + /// Create a builder for a [`DataLoader`] + pub fn builder() -> DataLoaderBuilder { + DataLoaderBuilder::new() + } + + /// Load a single key + /// Can return an error if the underlying [`DataLoaderFetcher`] returns an + /// error + /// + /// Returns `None` if the key is not found + pub async fn load(&self, items: E::Key) -> Result, ()> { + Ok(self.load_many(std::iter::once(items)).await?.into_values().next()) + } + + /// Load many keys + /// Can return an error if the underlying [`DataLoaderFetcher`] returns an + /// error + /// + /// Returns a map of keys to values which may be incomplete if any of the + /// keys were not found + pub async fn load_many(&self, items: impl IntoIterator) -> Result, ()> { + let mut batch = self.current_batch.lock().await; + + struct BatchWaiting { + id: u64, + keys: HashSet, + result: Arc>, + } + + let mut waiters = Vec::>::new(); + + for item in items { + if batch.is_none() { + batch.replace( + Batch::new( + self.batch_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed), + self.semaphore.clone(), + ) + .await, + ); + } + + let batch_mut = batch.as_mut().unwrap(); + batch_mut.items.insert(item.clone()); + + if waiters.is_empty() || waiters.last().unwrap().id != batch_mut.id { + waiters.push(BatchWaiting { + id: batch_mut.id, + keys: HashSet::new(), + result: batch_mut.result.clone(), + }); + } + + let waiting = waiters.last_mut().unwrap(); + waiting.keys.insert(item); + + if 
batch_mut.items.len() >= self.batch_size { + batch.take().unwrap().spawn(self.executor.clone()).await; + self.notify.notify_one(); + } + } + + let mut results = HashMap::new(); + for waiting in waiters { + let result = waiting.result.wait().await?; + results.extend(waiting.keys.into_iter().filter_map(|key| { + let value = result.get(&key)?.clone(); + Some((key, value)) + })); + } + + Ok(results) + } +} + +async fn batch_loop( + executor: Arc, + current_batch: Arc>>>, + notify: Arc, + delay: std::time::Duration, +) where + E: DataLoaderFetcher + Send + Sync + 'static, +{ + let mut pending_id = None; + loop { + tokio::time::timeout(delay, notify.notified()).await.ok(); + + let mut batch = current_batch.lock().await; + let Some(batch_id) = batch.as_ref().map(|b| b.id) else { + pending_id = None; + continue; + }; + + if pending_id != Some(batch_id) || batch.as_ref().unwrap().items.is_empty() { + pending_id = Some(batch_id); + continue; + } + + tokio::spawn(batch.take().unwrap().spawn(executor.clone())); + } +} + +struct BatchResult { + values: tokio::sync::OnceCell>>, + token: tokio_util::sync::CancellationToken, +} + +impl BatchResult { + fn new() -> Self { + Self { + values: tokio::sync::OnceCell::new(), + token: tokio_util::sync::CancellationToken::new(), + } + } + + async fn wait(&self) -> Result<&HashMap, ()> { + self.token.cancelled().await; + self.values.get().ok_or(())?.as_ref().ok_or(()) + } +} + +struct Batch +where + E: DataLoaderFetcher + Send + Sync + 'static, +{ + id: u64, + items: HashSet, + result: Arc>, + _ticket: tokio::sync::OwnedSemaphorePermit, +} + +impl Batch +where + E: DataLoaderFetcher + Send + Sync + 'static, +{ + async fn new(id: u64, semaphore: Arc) -> Self { + Self { + id, + items: HashSet::new(), + result: Arc::new(BatchResult::new()), + _ticket: semaphore.acquire_owned().await.unwrap(), + } + } + + async fn spawn(self, executor: Arc) { + let _drop_guard = self.result.token.clone().drop_guard(); + let result = 
executor.load(self.items).await; + match self.result.values.set(result) { + Ok(()) => {} + Err(_) => unreachable!( + "batch result already set, this is a bug please report it https://github.com/scuffletv/scuffle/issues" + ), + } + } +} diff --git a/crates/batching/src/lib.rs b/crates/batching/src/lib.rs new file mode 100644 index 000000000..2fd5916bd --- /dev/null +++ b/crates/batching/src/lib.rs @@ -0,0 +1,5 @@ +pub mod batch; +pub mod dataloader; + +pub use batch::{BatchExecutor, Batcher}; +pub use dataloader::{DataLoader, DataLoaderFetcher}; diff --git a/crates/context/src/lib.rs b/crates/context/src/lib.rs index 93887abee..dad3a3e08 100644 --- a/crates/context/src/lib.rs +++ b/crates/context/src/lib.rs @@ -56,6 +56,15 @@ impl ContextTrackerInner { } } +/// A context for cancelling futures and waiting for shutdown +/// +/// A context can be created from a handler or another context so to have a +/// hierarchy of contexts +/// +/// Contexts can then be attached to futures or streams in order to +/// automatically cancel them when the context is done, when invoking +/// `Handler::cancel`. The `Handler::shutdown` method will block until all +/// contexts have been dropped allowing for a graceful shutdown. 
#[derive(Debug)] pub struct Context { token: CancellationToken, @@ -73,11 +82,15 @@ impl Clone for Context { impl Context { #[must_use] + /// Create a new context using the global handler + /// Returns a tuple and a child handler pub fn new() -> (Self, Handler) { Handler::global().new_child() } #[must_use] + /// Create a new child context from this context + /// Returns a tuple and a child handler pub fn new_child(&self) -> (Self, Handler) { let token = self.token.child_token(); let tracker = ContextTrackerInner::new(); @@ -88,25 +101,29 @@ impl Context { token: token.clone(), }, Handler { - _token: Arc::new(TokenDropGuard(token)), + token: Arc::new(TokenDropGuard(token)), tracker, }, ) } #[must_use] + /// Returns the global context pub fn global() -> Self { Handler::global().context() } + /// Waits for the context to be done (the handler to be shutdown) pub async fn done(&self) { self.token.cancelled().await; } + /// The same as done but takes ownership of the context pub async fn into_done(self) { self.done().await; } + /// Returns true if the context is done #[must_use] pub fn is_done(&self) -> bool { self.token.is_cancelled() @@ -135,7 +152,7 @@ impl Drop for TokenDropGuard { #[derive(Debug, Clone)] pub struct Handler { - _token: Arc, + token: Arc, tracker: Arc, } @@ -147,56 +164,71 @@ impl Default for Handler { impl Handler { #[must_use] + /// Create a new handler pub fn new() -> Handler { let token = CancellationToken::new(); let tracker = ContextTrackerInner::new(); Handler { - _token: Arc::new(TokenDropGuard(token)), + token: Arc::new(TokenDropGuard(token)), tracker, } } #[must_use] + /// Returns the global handler pub fn global() -> &'static Self { static GLOBAL: std::sync::OnceLock = std::sync::OnceLock::new(); GLOBAL.get_or_init(Handler::new) } + /// Shutdown the handler and wait for all contexts to be done pub async fn shutdown(&self) { self.cancel(); self.done().await; } + /// Waits for the handler to be done (waiting for all contexts to be done) pub 
async fn done(&self) { - self._token.0.cancelled().await; + self.token.0.cancelled().await; self.tracker.wait().await; } #[must_use] + /// Create a new context from this handler pub fn context(&self) -> Context { Context { - token: self._token.child(), + token: self.token.child(), tracker: self.tracker.child(), } } #[must_use] + /// Create a new child context from this handler pub fn new_child(&self) -> (Context, Handler) { self.context().new_child() } + /// Cancel the handler pub fn cancel(&self) { self.tracker.stop(); - self._token.cancel(); + self.token.cancel(); } } pin_project_lite::pin_project! { - #[project = ContextRefProj] - pub enum ContextRef<'a> { - #[allow(private_interfaces)] + /// A reference to a context + /// Can either be owned or borrowed + pub struct ContextRef<'a> { + #[pin] + inner: ContextRefInner<'a>, + } +} + +pin_project_lite::pin_project! { + #[project = ContextRefInnerProj] + enum ContextRefInner<'a> { Owned { #[pin] fut: WaitForCancellationFutureOwned, tracker: ContextTracker, @@ -207,33 +239,41 @@ pin_project_lite::pin_project! { } } -impl ContextRef<'_> { - pub fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<()> { - match self.project() { - ContextRefProj::Owned { fut, .. } => fut.poll(cx), - ContextRefProj::Ref { fut } => fut.poll(cx), +impl<'a> std::future::Future for ContextRef<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + match self.project().inner.project() { + ContextRefInnerProj::Owned { fut, .. 
} => fut.poll(cx), + ContextRefInnerProj::Ref { fut } => fut.poll(cx), } } } impl From for ContextRef<'_> { fn from(ctx: Context) -> Self { - ContextRef::Owned { - fut: ctx.token.cancelled_owned(), - tracker: ctx.tracker, + ContextRef { + inner: ContextRefInner::Owned { + fut: ctx.token.cancelled_owned(), + tracker: ctx.tracker, + }, } } } impl<'a> From<&'a Context> for ContextRef<'a> { fn from(ctx: &'a Context) -> Self { - ContextRef::Ref { - fut: ctx.token.cancelled(), + ContextRef { + inner: ContextRefInner::Ref { + fut: ctx.token.cancelled(), + }, } } } pub trait ContextFutExt { + /// Wraps a future with a context, allowing the future to be cancelled when + /// the context is done fn with_context<'a>(self, ctx: impl Into>) -> FutureWithContext<'a, Fut> where Self: Sized; @@ -253,6 +293,8 @@ impl ContextFutExt for F { } pub trait ContextStreamExt { + /// Wraps a stream with a context, allowing the stream to be stopped when + /// the context is done fn with_context<'a>(self, ctx: impl Into>) -> StreamWithContext<'a, Stream> where Self: Sized; @@ -269,6 +311,9 @@ impl ContextStreamExt for F { } pin_project_lite::pin_project! { + /// A future with a context attached to it. + /// + /// This future will be cancelled when the context is done. pub struct FutureWithContext<'a, F> { #[pin] future: F, @@ -281,10 +326,10 @@ pin_project_lite::pin_project! 
{ impl Future for FutureWithContext<'_, F> { type Output = Option; - fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - let mut this = self.as_mut().project(); + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + let this = self.project(); - match (this.ctx.as_mut().poll(cx), this.future.poll(cx)) { + match (this.ctx.poll(cx), this.future.poll(cx)) { (_, Poll::Ready(v)) => std::task::Poll::Ready(Some(v)), (Poll::Ready(_), Poll::Pending) => std::task::Poll::Ready(None), _ => std::task::Poll::Pending, @@ -293,6 +338,9 @@ impl Future for FutureWithContext<'_, F> { } pin_project_lite::pin_project! { + /// A stream with a context attached to it. + /// + /// This stream will be cancelled when the context is done. pub struct StreamWithContext<'a, F> { #[pin] stream: F, @@ -305,10 +353,10 @@ pin_project_lite::pin_project! { impl Stream for StreamWithContext<'_, F> { type Item = F::Item; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - let mut this = self.as_mut().project(); + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.project(); - match (this.ctx.as_mut().poll(cx), this.stream.poll_next(cx)) { + match (this.ctx.poll(cx), this.stream.poll_next(cx)) { (_, Poll::Ready(v)) => std::task::Poll::Ready(v), (Poll::Ready(_), Poll::Pending) => std::task::Poll::Ready(None), _ => std::task::Poll::Pending, diff --git a/crates/foundations/src/batcher/dataloader.rs b/crates/foundations/src/batcher/dataloader.rs deleted file mode 100644 index f596f2516..000000000 --- a/crates/foundations/src/batcher/dataloader.rs +++ /dev/null @@ -1,309 +0,0 @@ -use std::collections::HashMap; -use std::hash::{BuildHasher, RandomState}; -use std::marker::PhantomData; - -use super::{BatchOperation, Batcher, BatcherConfig, BatcherDataloader, BatcherError}; - -#[allow(type_alias_bounds)] -pub type LoaderOutput, S: BuildHasher = RandomState> = 
Result, ()>; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd, Default)] -struct Unit; - -impl std::error::Error for Unit {} - -impl std::fmt::Display for Unit { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "unknown") - } -} - -impl From<()> for Unit { - fn from(_: ()) -> Self { - Self - } -} - -pub trait Loader { - type Key: Clone + Eq + std::hash::Hash + Send + Sync; - type Value: Clone + Send + Sync; - - fn config(&self) -> BatcherConfig { - BatcherConfig { - name: std::any::type_name::().to_string(), - concurrency: 10, - max_batch_size: 1000, - sleep_duration: std::time::Duration::from_millis(5), - } - } - - fn fetch(&self, keys: Vec) -> impl std::future::Future> + Send; -} - -pub struct DataLoader, S: BuildHasher + Default + Send + Sync = RandomState> { - batcher: Batcher>, -} - -impl + 'static + Send + Sync, S: BuildHasher + Default + Send + Sync + 'static> DataLoader { - pub fn new(loader: L) -> Self { - Self { - batcher: Batcher::new(Wrapper(loader, PhantomData)), - } - } - - #[tracing::instrument(skip_all, fields(name = %self.batcher.inner.name))] - pub async fn load(&self, key: L::Key) -> Result, ()> { - self.internal_load_many(std::iter::once(key.clone())) - .await - .map(|mut map| map.remove(&key)) - } - - #[tracing::instrument(skip_all, fields(name = %self.batcher.inner.name))] - pub async fn load_many(&self, keys: impl IntoIterator) -> Result, ()> { - self.internal_load_many(keys).await - } - - async fn internal_load_many(&self, keys: impl IntoIterator) -> LoaderOutput { - self.batcher.internal_execute_many(keys).await.map_err(|err| match err { - BatcherError::Batch(Unit) => {} - err => tracing::error!("failed to load data: {err}"), - }) - } -} - -struct Wrapper, S: BuildHasher + Default = RandomState>(L, PhantomData); - -impl, S: BuildHasher + Default + Send + Sync> BatchOperation for Wrapper { - type Error = Unit; - type Item = L::Key; - type Mode = BatcherDataloader; - type Response = 
L::Value; - - fn config(&self) -> BatcherConfig { - self.0.config() - } - - async fn process( - &self, - documents: >::Input, - ) -> Result<>::OperationOutput, Self::Error> - where - Self: Send + Sync, - { - self.0.fetch(documents.into_iter().collect()).await.map_err(|()| Unit) - } -} - -#[cfg(test)] -mod tests { - use std::collections::hash_map::RandomState; - use std::collections::HashMap; - - use super::{DataLoader, LoaderOutput}; - use crate::batcher::BatcherConfig; - - type DynBoxLoader = Box) -> HashMap + Sync + Send>; - - struct LoaderTest { - results: DynBoxLoader, - config: BatcherConfig, - } - - impl super::Loader for LoaderTest { - type Key = u64; - type Value = u64; - - fn config(&self) -> BatcherConfig { - self.config.clone() - } - - async fn fetch(&self, keys: Vec) -> LoaderOutput { - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; - Ok((self.results)(keys)) - } - } - - #[tokio::test] - async fn test_data_loader() { - let run_count = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)); - - let loader = LoaderTest { - results: Box::new(move |keys| { - let mut results = HashMap::new(); - - assert_eq!(run_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst), 0); - - assert_eq!(keys.len(), 250); - - for key in keys { - assert!(!results.contains_key(&key)); - - results.insert(key, key * 2); - } - - results - }), - config: BatcherConfig { - name: "test".to_string(), - concurrency: 10, - max_batch_size: 1000, - sleep_duration: std::time::Duration::from_millis(5), - }, - }; - - let dataloader = DataLoader::new(loader); - - let futures = (0..250) - .map(|i| dataloader.load(i as u64)) - .chain((0..250).map(|i| dataloader.load(i as u64))); - - let results = futures::future::join_all(futures).await; - - let expected = (0..250) - .map(|i| Ok(Some(i * 2))) - .chain((0..250).map(|i| Ok(Some(i * 2)))) - .collect::>(); - - assert_eq!(results, expected); - } - - #[tokio::test] - async fn test_data_loader_larger() { - let loader = LoaderTest { 
- results: Box::new(move |keys| { - let mut results = HashMap::new(); - - assert!(keys.len() <= 1000); - - for key in keys { - assert!(!results.contains_key(&key)); - - results.insert(key, key * 2); - } - - results - }), - config: BatcherConfig { - name: "test".to_string(), - concurrency: 10, - max_batch_size: 1000, - sleep_duration: std::time::Duration::from_millis(5), - }, - }; - - let dataloader = DataLoader::new(loader); - - const LIMIT: usize = 10_000; - - let results = futures::future::join_all((0..LIMIT).map(|i| dataloader.load(i as u64))).await; - - let expected = (0..LIMIT).map(|i| Ok(Some(i as u64 * 2))).collect::>(); - - assert_eq!(results, expected); - } - - #[tokio::test] - async fn test_data_loader_change_batch_size() { - let loader = LoaderTest { - results: Box::new(move |keys| { - let mut results = HashMap::new(); - - assert!(keys.len() <= 3000); - - for key in keys { - assert!(!results.contains_key(&key)); - - results.insert(key, key * 2); - } - - results - }), - config: BatcherConfig { - name: "test".to_string(), - concurrency: 10, - max_batch_size: 3000, - sleep_duration: std::time::Duration::from_millis(5), - }, - }; - - let dataloader = DataLoader::new(loader); - - let futures = (0..5000).map(|i| dataloader.load(i as u64)); - - let results = futures::future::join_all(futures).await; - - let expected = (0..5000).map(|i| Ok(Some(i * 2))).collect::>(); - - assert_eq!(results, expected); - } - - #[tokio::test] - async fn test_data_loader_change_duration() { - let loader = LoaderTest { - results: Box::new(move |keys| { - let mut results = HashMap::new(); - - assert!(keys.len() <= 1000); - - for key in keys { - assert!(!results.contains_key(&key)); - - results.insert(key, key * 2); - } - - results - }), - config: BatcherConfig { - name: "test".to_string(), - concurrency: 10, - max_batch_size: 1000, - sleep_duration: std::time::Duration::from_millis(100), - }, - }; - - let dataloader = DataLoader::new(loader); - - let futures = (0..250) - .map(|i| 
dataloader.load(i as u64)) - .chain((0..250).map(|i| dataloader.load(i as u64))); - - let results = futures::future::join_all(futures).await; - - let expected = (0..250) - .map(|i| Ok(Some(i * 2))) - .chain((0..250).map(|i| Ok(Some(i * 2)))) - .collect::>(); - - assert_eq!(results, expected); - } - - #[tokio::test] - async fn test_data_loader_value_deduplication() { - let run_count = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)); - - let loader = LoaderTest { - results: Box::new({ - let run_count = run_count.clone(); - move |keys| { - run_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - keys.iter().map(|&key| (key, key * 2)).collect() - } - }), - config: BatcherConfig { - name: "test".to_string(), - concurrency: 10, - max_batch_size: 1000, - sleep_duration: std::time::Duration::from_millis(5), - }, - }; - - let dataloader = DataLoader::new(loader); - - let futures = vec![dataloader.load(5), dataloader.load(5), dataloader.load(5)]; - - let results: Vec<_> = futures::future::join_all(futures).await; - - assert_eq!(results, vec![Ok(Some(10)), Ok(Some(10)), Ok(Some(10))]); - assert_eq!(run_count.load(std::sync::atomic::Ordering::SeqCst), 1); // Ensure the loader was only called once - } -} diff --git a/crates/foundations/src/batcher/mod.rs b/crates/foundations/src/batcher/mod.rs deleted file mode 100644 index 4a6ac0ceb..000000000 --- a/crates/foundations/src/batcher/mod.rs +++ /dev/null @@ -1,448 +0,0 @@ -use std::collections::{HashMap, HashSet}; -use std::hash::{BuildHasher, RandomState}; -use std::marker::PhantomData; -use std::sync::atomic::{AtomicU64, AtomicUsize}; -use std::sync::Arc; - -use tokio::sync::OnceCell; -use tracing::Instrument; - -pub mod dataloader; - -pub trait BatchMode: Sized { - type Input: Send + Sync; - type Output: Send + Sync; - type OutputItem: Send + Sync; - type OperationOutput: Send + Sync; - type FinalOutput: Send + Sync; - type Tracker: Send + Sync; - - fn new_input() -> Self::Input; - fn new_tracker() -> 
Self::Tracker; - fn new_output() -> Self::Output; - - fn input_add(input: &mut Self::Input, tracker: &mut Self::Tracker, item: T::Item); - fn input_len(input: &Self::Input) -> usize; - - fn tracked_output( - result: Option<&Result>>, - tracker: Self::Tracker, - output: &mut Self::Output, - ) -> Result<(), BatcherError>; - - fn final_output_into_iter( - output: Self::FinalOutput, - ) -> Result, BatcherError>; - - fn filter_item_iter(item: impl IntoIterator) -> impl IntoIterator; - - fn output_item_to_result(item: Self::OutputItem) -> Result>; - - fn output_into_final_output(output: Result>) -> Self::FinalOutput; -} - -pub struct BatcherNormalMode; - -impl BatchMode for BatcherNormalMode { - type FinalOutput = Self::Output; - type Input = Vec; - type OperationOutput = Vec>; - type Output = Vec; - type OutputItem = Result>; - type Tracker = std::ops::Range; - - fn new_input() -> Self::Input { - Vec::new() - } - - fn new_tracker() -> Self::Tracker { - 0..0 - } - - fn new_output() -> Self::Output { - Vec::new() - } - - fn input_add(input: &mut Self::Input, tracker: &mut Self::Tracker, item: T::Item) { - input.push(item); - tracker.end = input.len(); - } - - fn input_len(input: &Self::Input) -> usize { - input.len() - } - - fn tracked_output( - result: Option<&Result>>, - tracker: Self::Tracker, - output: &mut Self::Output, - ) -> Result<(), BatcherError> { - for i in tracker.into_iter() { - match result { - Some(Ok(r)) => output.push( - r.get(i) - .cloned() - .transpose() - .map_err(BatcherError::Batch) - .transpose() - .unwrap_or(Err(BatcherError::MissingResult)), - ), - Some(Err(e)) => output.push(Err(e.clone())), - None => output.push(Err(BatcherError::Panic)), - } - } - - Ok(()) - } - - fn final_output_into_iter( - output: Self::FinalOutput, - ) -> Result, BatcherError> { - Ok(output) - } - - fn filter_item_iter(item: impl IntoIterator) -> impl IntoIterator { - item - } - - fn output_item_to_result(item: Self::OutputItem) -> Result> { - item - } - - fn 
output_into_final_output( - output: Result::Error>>, - ) -> Self::FinalOutput { - output.expect("erro shouldnt be possible here") - } -} - -pub struct BatcherDataloader(PhantomData); - -impl BatchMode for BatcherDataloader -where - T::Item: Clone + std::hash::Hash + Eq, -{ - type FinalOutput = Result, BatcherError>; - type Input = HashSet; - type OperationOutput = HashMap; - type Output = Self::OperationOutput; - type OutputItem = T::Response; - type Tracker = Vec; - - fn new_input() -> Self::Input { - HashSet::default() - } - - fn new_tracker() -> Self::Tracker { - Vec::new() - } - - fn new_output() -> Self::Output { - HashMap::default() - } - - fn input_add(input: &mut Self::Input, tracker: &mut Self::Tracker, item: T::Item) { - input.insert(item.clone()); - tracker.push(item); - } - - fn input_len(input: &Self::Input) -> usize { - input.len() - } - - fn tracked_output( - result: Option<&Result>>, - tracker: Self::Tracker, - output: &mut Self::Output, - ) -> Result<(), BatcherError> { - for key in tracker.clone().into_iter() { - match result { - Some(Ok(res)) => { - if let Some(value) = res.get(&key).cloned() { - output.insert(key, value); - } - } - Some(Err(e)) => { - return Err(e.clone()); - } - None => { - return Err(BatcherError::Panic); - } - } - } - - Ok(()) - } - - fn final_output_into_iter( - output: Self::FinalOutput, - ) -> Result, BatcherError> { - output.map(|output| output.into_values()) - } - - fn filter_item_iter(item: impl IntoIterator) -> impl IntoIterator { - item - } - - fn output_item_to_result(item: Self::OutputItem) -> Result> { - Ok(item) - } - - fn output_into_final_output( - output: Result::Error>>, - ) -> Self::FinalOutput { - output - } -} - -pub trait BatchOperation { - type Item: Send + Sync; - type Response: Clone + Send + Sync; - type Error: Clone + std::fmt::Debug + Send + Sync; - type Mode: BatchMode; - - fn config(&self) -> BatcherConfig; - - fn process( - &self, - documents: >::Input, - ) -> impl 
std::future::Future>::OperationOutput, Self::Error>> + Send + '_ - where - Self: Send + Sync; -} - -pub struct Batcher { - inner: Arc>, - _auto_loader_abort: CancelOnDrop, -} - -struct CancelOnDrop(tokio::task::AbortHandle); - -impl Drop for CancelOnDrop { - fn drop(&mut self) { - self.0.abort(); - } -} - -struct BatcherInner { - semaphore: tokio::sync::Semaphore, - notify: tokio::sync::Notify, - sleep_duration: AtomicU64, - batch_id: AtomicU64, - max_batch_size: AtomicUsize, - operation: T, - name: String, - active_batch: tokio::sync::RwLock>>, -} - -struct Batch { - id: u64, - expires_at: tokio::time::Instant, - done: DropGuardCancellationToken, - ops: >::Input, - #[allow(clippy::type_complexity)] - results: Arc>::OperationOutput, BatcherError>>>, -} - -struct DropGuardCancellationToken(tokio_util::sync::CancellationToken); - -impl Drop for DropGuardCancellationToken { - fn drop(&mut self) { - self.0.cancel(); - } -} - -impl DropGuardCancellationToken { - fn new() -> Self { - Self(tokio_util::sync::CancellationToken::new()) - } - - fn child_token(&self) -> tokio_util::sync::CancellationToken { - self.0.child_token() - } -} - -struct BatchInsertWaiter { - id: u64, - done: tokio_util::sync::CancellationToken, - tracker: >::Tracker, - #[allow(clippy::type_complexity)] - results: Arc>::OperationOutput, BatcherError>>>, -} - -#[derive(thiserror::Error, Debug, Clone, PartialEq, Copy, Eq, Hash, Ord, PartialOrd)] -pub enum BatcherError { - #[error("failed to acquire semaphore")] - AcquireSemaphore, - #[error("panic in batch inserter")] - Panic, - #[error("missing result")] - MissingResult, - #[error("batch failed with: {0}")] - Batch(E), -} - -impl From for BatcherError { - fn from(value: E) -> Self { - Self::Batch(value) - } -} - -impl Batch { - #[tracing::instrument(skip_all, fields(name = %inner.name), level = "debug")] - async fn run(self, inner: Arc>) { - self.results - .get_or_init(|| async move { - let _ticket = inner - .semaphore - .acquire() - 
.instrument(tracing::debug_span!("Semaphore")) - .await - .map_err(|_| BatcherError::AcquireSemaphore)?; - inner.operation.process(self.ops).await.map_err(BatcherError::Batch) - }) - .await; - } -} - -#[derive(Clone)] -pub struct BatcherConfig { - pub name: String, - pub concurrency: usize, - pub max_batch_size: usize, - pub sleep_duration: std::time::Duration, -} - -impl BatcherInner { - fn spawn_batch(self: &Arc, batch: Batch) { - tokio::spawn(batch.run(self.clone())); - } - - fn new_batch(&self) -> Batch { - let id = self.batch_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let expires_at = tokio::time::Instant::now() - + tokio::time::Duration::from_nanos(self.sleep_duration.load(std::sync::atomic::Ordering::Relaxed)); - - Batch { - id, - expires_at, - ops: T::Mode::new_input(), - done: DropGuardCancellationToken::new(), - results: Arc::new(OnceCell::new()), - } - } - - async fn batch_inserts(self: &Arc, documents: impl IntoIterator) -> Vec> { - let mut waiters = vec![]; - let mut batch = self.active_batch.write().await; - let max_documents = self.max_batch_size.load(std::sync::atomic::Ordering::Relaxed); - - for document in T::Mode::filter_item_iter(documents) { - if batch - .as_ref() - .map(|b| T::Mode::input_len(&b.ops) >= max_documents) - .unwrap_or(true) - { - if let Some(b) = batch.take() { - self.spawn_batch(b); - } - - *batch = Some(self.new_batch()); - self.notify.notify_one(); - } - - let Some(b) = batch.as_mut() else { - unreachable!("batch should be Some"); - }; - - if waiters.last().map(|w: &BatchInsertWaiter| w.id != b.id).unwrap_or(true) { - waiters.push(BatchInsertWaiter { - id: b.id, - done: b.done.child_token(), - results: b.results.clone(), - tracker: T::Mode::new_tracker(), - }); - } - - let tracker = &mut waiters.last_mut().unwrap().tracker; - T::Mode::input_add(&mut b.ops, tracker, document); - } - - waiters - } -} - -impl Batcher { - pub fn new(operation: T) -> Self { - let config = operation.config(); - - let inner = 
Arc::new(BatcherInner { - semaphore: tokio::sync::Semaphore::new(config.concurrency), - notify: tokio::sync::Notify::new(), - batch_id: AtomicU64::new(0), - active_batch: tokio::sync::RwLock::new(None), - sleep_duration: AtomicU64::new(config.sleep_duration.as_nanos() as u64), - max_batch_size: AtomicUsize::new(config.max_batch_size), - operation, - name: config.name, - }); - - Self { - inner: inner.clone(), - _auto_loader_abort: CancelOnDrop( - tokio::task::spawn(async move { - loop { - inner.notify.notified().await; - let Some((id, expires_at)) = inner.active_batch.read().await.as_ref().map(|b| (b.id, b.expires_at)) - else { - continue; - }; - - if expires_at > tokio::time::Instant::now() { - tokio::time::sleep_until(expires_at).await; - } - - let mut batch = inner.active_batch.write().await; - if batch.as_ref().is_some_and(|b| b.id == id) { - inner.spawn_batch(batch.take().unwrap()); - } - } - }) - .abort_handle(), - ), - } - } - - #[tracing::instrument(skip_all, fields(name = %self.inner.name))] - pub async fn execute(&self, document: T::Item) -> Result> { - let output = self.internal_execute_many(std::iter::once(document)).await; - let iter = T::Mode::final_output_into_iter(output)?; - T::Mode::output_item_to_result(iter.into_iter().next().ok_or(BatcherError::MissingResult)?) 
- } - - #[tracing::instrument(skip_all, fields(name = %self.inner.name))] - pub async fn execute_many( - &self, - documents: impl IntoIterator, - ) -> >::FinalOutput { - self.internal_execute_many(documents).await - } - - pub(crate) async fn internal_execute_many( - &self, - documents: impl IntoIterator, - ) -> >::FinalOutput { - let waiters = self.inner.batch_inserts(documents).await; - - let mut results = >::new_output(); - - for waiter in waiters { - waiter.done.cancelled().await; - if let Err(e) = >::tracked_output(waiter.results.get(), waiter.tracker, &mut results) { - return >::output_into_final_output(Err(e)); - } - } - - >::output_into_final_output(Ok(results)) - } -} diff --git a/crates/foundations/src/bootstrap.rs b/crates/foundations/src/bootstrap.rs index 3da037f27..3c0433de1 100644 --- a/crates/foundations/src/bootstrap.rs +++ b/crates/foundations/src/bootstrap.rs @@ -70,10 +70,6 @@ pub trait Bootstrap: Sized + From { fn telemetry(&self) -> Option { None } - - fn additional_args() -> Vec { - vec![] - } } impl Bootstrap for () { diff --git a/crates/foundations/src/http/mod.rs b/crates/foundations/src/http/mod.rs deleted file mode 100644 index 74f47ad34..000000000 --- a/crates/foundations/src/http/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod server; diff --git a/crates/foundations/src/http/server/builder.rs b/crates/foundations/src/http/server/builder.rs deleted file mode 100644 index 61c13e9cb..000000000 --- a/crates/foundations/src/http/server/builder.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::net::SocketAddr; -use std::sync::Arc; - -use hyper_util::rt::TokioExecutor; - -pub enum BackendBuilder { - Tcp(TcpBackendBuilder), - #[cfg(feature = "http-tls")] - Tls(TlsBackendBuilder), - #[cfg(feature = "http3")] - Quic(QuicBackendBuilder), -} - -pub struct TcpBackendBuilder { - worker_count: usize, - keep_alive_timeout: Option, - http_builder: hyper_util::server::conn::auto::Builder, - make_listener: Box std::io::Result>, -} - -#[cfg(feature = "http-tls")] 
-pub struct TlsBackendBuilder { - worker_count: usize, - keep_alive_timeout: Option, - tls: Arc, - http_builder: hyper_util::server::conn::auto::Builder, - make_listener: Box std::io::Result>, -} - -#[cfg(feature = "http3")] -pub struct QuicBackendBuilder { - worker_count: usize, - keep_alive_timeout: Option, - config: quinn::ServerConfig, - http_builder: Arc, - make_listener: Box std::io::Result>, -} diff --git a/crates/foundations/src/http/server/mod.rs b/crates/foundations/src/http/server/mod.rs deleted file mode 100644 index ea519d714..000000000 --- a/crates/foundations/src/http/server/mod.rs +++ /dev/null @@ -1,282 +0,0 @@ -use std::net::SocketAddr; -use std::sync::Arc; - -mod builder; -pub mod stream; - -pub use axum; -use hyper_util::rt::TokioExecutor; -use tokio::spawn; -#[cfg(feature = "tracing")] -use tracing::Instrument; - -pub use self::builder::ServerBuilder; -#[cfg(feature = "http3")] -use self::stream::quic::QuicBackend; -use self::stream::tcp::TcpBackend; -#[cfg(feature = "http-tls")] -use self::stream::tls::TlsBackend; -use self::stream::{Backend, MakeService}; - -#[cfg(feature = "http3")] -#[derive(Clone)] -struct Quic { - h3: Arc, - config: quinn::ServerConfig, -} - -#[cfg(feature = "http3")] -impl std::fmt::Debug for Quic { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Quic").finish() - } -} - -pub struct Server { - bind: SocketAddr, - #[cfg(feature = "http-tls")] - insecure_bind: Option, - #[cfg(feature = "http-tls")] - tls: Option>, - http1_2: Arc>, - #[cfg(feature = "http3")] - quic: Option, - make_service: M, - backends: Vec, - handler: Option, - worker_count: usize, - keep_alive_timeout: Option, - make_tcp_socket: fn(SocketAddr) -> std::io::Result, - #[cfg(feature = "http3")] - make_udp_socket: fn(SocketAddr) -> std::io::Result, -} - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("io: {0}")] - Io(#[from] std::io::Error), - #[error("no bind address specified")] - NoBindAddress, - 
#[cfg(feature = "http3")] - #[error("quinn connection: {0}")] - QuinnConnection(#[from] quinn::ConnectionError), - #[cfg(feature = "http3")] - #[error("http3: {0}")] - Http3(#[from] h3::Error), - #[error("connection closed")] - ConnectionClosed, - #[error("axum: {0}")] - Axum(#[from] axum::Error), - #[error("{0}")] - Other(#[from] Box), - #[error("task join: {0}")] - TaskJoin(#[from] tokio::task::JoinError), -} - -impl Server<()> { - pub fn builder() -> ServerBuilder { - ServerBuilder::default() - } -} - -struct AbortOnDrop(Option>>); - -impl AbortOnDrop { - fn new(inner: tokio::task::JoinHandle>) -> Self { - Self(Some(inner)) - } - - fn into_inner(mut self) -> tokio::task::JoinHandle> { - let inner = self.0.take(); - inner.expect("inner task handle already taken") - } -} - -impl Drop for AbortOnDrop { - fn drop(&mut self) { - if let Some(inner) = self.0.take() { - inner.abort(); - } - } -} - -fn ip_mode(addr: SocketAddr) -> std::io::Result { - if addr.ip().is_ipv4() { - Ok(socket2::Domain::IPV4) - } else if addr.ip().is_ipv6() { - Ok(socket2::Domain::IPV6) - } else { - Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid ip address")) - } -} - -// fn make_tcp_listener(addr: SocketAddr) -> -// std::io::Result { let socket = -// socket2::Socket::new(ip_mode(addr)?, socket2::Type::STREAM, -// Some(socket2::Protocol::TCP))?; - -// socket.set_nonblocking(true)?; -// socket.set_reuse_address(true)?; -// socket.set_reuse_port(true)?; -// socket.bind(&socket2::SockAddr::from(addr))?; -// socket.listen(1024)?; - -// Ok(socket.into()) -// } - -// #[cfg(feature = "http3")] -// fn make_udp_socket(addr: SocketAddr) -> std::io::Result -// { let socket = socket2::Socket::new(ip_mode(addr)?, socket2::Type::DGRAM, -// Some(socket2::Protocol::UDP))?; - -// socket.set_nonblocking(true)?; -// socket.set_reuse_address(true)?; -// socket.set_reuse_port(true)?; -// socket.bind(&socket2::SockAddr::from(addr))?; - -// Ok(socket.into()) -// } - -impl Server { - pub async fn 
start(&mut self) -> Result<(), Error> { - self.backends.clear(); - if let Some(handler) = self.handler.take() { - handler.shutdown().await; - } - - let ctx = { - let (ctx, handler) = crate::context::Context::new(); - self.handler = Some(handler); - ctx - }; - - #[cfg(feature = "http-tls")] - if let Some(tls) = self.tls.clone() { - let acceptor = Arc::new(tokio_rustls::TlsAcceptor::from(tls)); - for i in 0..self.worker_count { - let tcp_listener = tokio::net::TcpListener::from_std((self.make_tcp_socket)(self.bind)?)?; - let make_service = self.make_service.clone(); - let backend = TlsBackend::new(tcp_listener, acceptor.clone(), self.http1_2.clone(), &ctx) - .with_keep_alive_timeout(self.keep_alive_timeout); - let span = tracing::info_span!("tls", addr = %self.bind, worker = i); - self.backends - .push(AbortOnDrop::new(spawn(backend.serve(make_service).instrument(span)))); - } - } else if self.insecure_bind.is_none() { - self.insecure_bind = Some(self.bind); - } - - #[cfg(feature = "http-tls")] - let bind = self.insecure_bind; - #[cfg(not(feature = "http-tls"))] - let bind = Some(self.bind); - - if let Some(addr) = bind { - for i in 0..self.worker_count { - let tcp_listener = tokio::net::TcpListener::from_std((self.make_tcp_socket)(addr)?)?; - let make_service = self.make_service.clone(); - let backend = TcpBackend::new(tcp_listener, self.http1_2.clone(), &ctx) - .with_keep_alive_timeout(self.keep_alive_timeout); - let span = tracing::info_span!("tcp", addr = %addr, worker = i); - self.backends - .push(AbortOnDrop::new(spawn(backend.serve(make_service).instrument(span)))); - } - } - - #[cfg(feature = "http3")] - if let Some(quic) = &self.quic { - for i in 0..self.worker_count { - let socket = (self.make_udp_socket)(self.bind)?; - let endpoint = quinn::Endpoint::new( - quinn::EndpointConfig::default(), - Some(quic.config.clone()), - socket, - quinn::default_runtime().unwrap(), - )?; - let make_service = self.make_service.clone(); - let backend = - 
QuicBackend::new(endpoint, quic.h3.clone(), &ctx).with_keep_alive_timeout(self.keep_alive_timeout); - let span = tracing::info_span!("quic", addr = %self.bind, worker = i); - self.backends - .push(AbortOnDrop::new(spawn(backend.serve(make_service).instrument(span)))); - } - } - - let mut binds = vec![]; - - #[cfg(feature = "http-tls")] - if let Some(insecure_bind) = self.insecure_bind { - binds.push(format!("http://{insecure_bind}")); - } - #[cfg(not(feature = "http-tls"))] - binds.push(format!("http://{bind}", bind = self.bind)); - - #[cfg(feature = "http-tls")] - if self.tls.is_some() { - binds.push(format!("https://{}", self.bind)); - } - - #[cfg(feature = "http3")] - if self.quic.is_some() { - binds.push(format!("https+quic://{}", self.bind)); - } - - tracing::info!( - worker_count = self.worker_count, - "listening on {binds}", - binds = binds.join(", ") - ); - - Ok(()) - } - - pub async fn start_and_wait(&mut self) -> Result<(), Error> { - self.start().await?; - self.wait().await - } - - pub async fn wait(&mut self) -> Result<(), Error> { - let Some(handler) = &self.handler else { - return Ok(()); - }; - - let result = futures::future::try_join_all(self.backends.iter_mut().map(|backend| async move { - let child_handler = backend.0.as_mut().unwrap().await??; - handler.cancel(); - child_handler.shutdown().await; - Ok::<_, Error>(()) - })) - .await; - - self.backends.clear(); - - handler.cancel(); - - let handler = self.handler.take().unwrap(); - - result?; - - handler.shutdown().await; - - Ok(()) - } - - pub async fn shutdown(&mut self) -> Result<(), Error> { - let Some(handler) = self.handler.take() else { - return Ok(()); - }; - - handler.cancel(); - - futures::future::try_join_all(self.backends.drain(..).map(|backend| async move { - let child_handler = backend.into_inner().await??; - child_handler.shutdown().await; - Ok::<_, Error>(()) - })) - .await?; - - handler.shutdown().await; - - Ok(()) - } -} diff --git 
a/crates/foundations/src/http/server/stream/mod.rs b/crates/foundations/src/http/server/stream/mod.rs deleted file mode 100644 index f81024188..000000000 --- a/crates/foundations/src/http/server/stream/mod.rs +++ /dev/null @@ -1,152 +0,0 @@ -#[cfg(feature = "http3")] -pub mod quic; -pub mod tcp; -#[cfg(feature = "http-tls")] -pub mod tls; - -use std::convert::Infallible; - -use super::Error; - -pub trait ConnectionLayer { - type Service: tower_async::Service, Response = http::Response>; - - /// Called when the service is ready to accept requests. - fn on_ready(&self) -> impl std::future::Future { - std::future::ready(()) - } - - /// Called when an error occurs. Some errors may be recoverable, others may - /// not. - fn on_error(&self, err: Error) -> impl std::future::Future { - let _ = err; - std::future::ready(()) - } - - /// Called when the connection is closed. - fn on_close(&self) -> impl std::future::Future { - std::future::ready(()) - } - - /// Called when the connection is hijacked. - /// When a connection is hijacked the on_close method will never be called. - /// And there will be no more requests. 
- fn on_hijack(&self) -> impl std::future::Future { - std::future::ready(()) - } - - fn layer(&self) -> impl std::future::Future>; -} - -#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub enum SocketKind { - Tcp, - #[cfg(feature = "http3")] - Quic, - #[cfg(feature = "http-tls")] - TcpWithTls, -} - -pub struct IncomingConnection<'a> { - remote_addr: std::net::SocketAddr, - local_addr: std::net::SocketAddr, - socket_kind: SocketKind, - raw_socket: &'a dyn std::any::Any, -} - -impl<'a> IncomingConnection<'a> { - pub fn remote_addr(&self) -> std::net::SocketAddr { - self.remote_addr - } - - pub fn local_addr(&self) -> std::net::SocketAddr { - self.local_addr - } - - pub fn downcast(&'a self) -> Option<&'a T> { - self.raw_socket.downcast_ref() - } -} - -struct LayerServiceHelper + Clone> { - layer: L, - _marker: std::marker::PhantomData, -} - -impl<'a, S, L: ConnectionLayer + Clone> tower_async::Service> for LayerServiceHelper { - type Error = Infallible; - type Response = Option; - - fn call(&self, _: IncomingConnection<'a>) -> impl std::future::Future> { - std::future::ready(Ok(Some(self.layer.clone()))) - } -} - -#[derive(Clone, Debug, Copy)] -pub enum EitherConnectionLayer { - A(A), - B(B), -} - -impl ConnectionLayer for EitherConnectionLayer -where - A: ConnectionLayer, - B: ConnectionLayer, - A::Service: tower_async::Service, - B::Service: tower_async::Service, -{ - async fn on_ready(&self) { - match self { - Self::A(a) => a.on_ready().await, - Self::B(b) => b.on_ready().await, - } - } - - async fn on_error(&self, err: Error) { - match self { - Self::A(a) => a.on_error(err).await, - Self::B(b) => b.on_error(err).await, - } - } - - async fn on_close(&self) { - match self { - Self::A(a) => a.on_close().await, - Self::B(b) => b.on_close().await, - } - } - - async fn on_hijack(&self) { - match self { - Self::A(a) => a.on_hijack().await, - Self::B(b) => b.on_hijack().await, - } - } -} - -impl tower_async::Layer for EitherConnectionLayer -where - 
A: tower_async::Layer, - B: tower_async::Layer, -{ - type Service = tower_async::util::Either; - - fn layer(&self, inner: S) -> Self::Service { - match self { - Self::A(a) => tower_async::util::Either::A(a.layer(inner)), - Self::B(b) => tower_async::util::Either::B(b.layer(inner)), - } - } -} - -fn is_fatal_tcp_error(err: &std::io::Error) -> bool { - matches!( - err.raw_os_error(), - Some(libc::EFAULT) - | Some(libc::EINVAL) - | Some(libc::ENFILE) - | Some(libc::EMFILE) - | Some(libc::ENOBUFS) - | Some(libc::ENOMEM) - ) -} diff --git a/crates/foundations/src/http/server/stream/quic.rs b/crates/foundations/src/http/server/stream/quic.rs deleted file mode 100644 index 2a4f77e7a..000000000 --- a/crates/foundations/src/http/server/stream/quic.rs +++ /dev/null @@ -1,544 +0,0 @@ -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use axum::body::Body; -use axum::extract::Request; -use axum::response::IntoResponse; -use bytes::{Buf, Bytes}; -use futures::future::poll_fn; -use futures::Future; -use h3::error::{Code, ErrorLevel}; -use h3::ext::Protocol; -use h3::server::{Builder, RequestStream}; -use h3_quinn::{BidiStream, RecvStream, SendStream}; -use http::Response; -use http_body::Body as HttpBody; -use quinn::Connecting; -#[cfg(not(feature = "runtime"))] -use tokio::spawn; -use tracing::Instrument; - -use super::{Backend, IncomingConnection, MakeService, ServiceHandler, SocketKind}; -use crate::context::ContextFutExt; -use crate::http::server::stream::{jitter, ActiveRequestsGuard}; -use crate::http::server::Error; -#[cfg(feature = "runtime")] -use crate::runtime::spawn; -#[cfg(feature = "opentelemetry")] -use crate::telemetry::opentelemetry::OpenTelemetrySpanExt; - -pub struct QuicBackend { - endpoint: quinn::Endpoint, - builder: Arc, - handler: crate::context::Handler, - keep_alive_timeout: Option, -} - -impl QuicBackend { - pub fn new(endpoint: quinn::Endpoint, builder: Arc, ctx: &crate::context::Context) -> Self { - Self { - endpoint, - 
builder, - handler: ctx.new_child().1, - keep_alive_timeout: None, - } - } - - pub fn with_keep_alive_timeout(mut self, timeout: impl Into>) -> Self { - self.keep_alive_timeout = timeout.into(); - self - } -} - -struct IncomingQuicConnection<'a> { - remote_addr: std::net::SocketAddr, - connection: &'a Connecting, -} - -impl IncomingConnection for IncomingQuicConnection<'_> { - fn socket_kind(&self) -> SocketKind { - SocketKind::Quic - } - - fn remote_addr(&self) -> std::net::SocketAddr { - self.remote_addr - } - - fn local_addr(&self) -> Option { - None - } - - fn downcast(&self) -> Option<&T> { - if std::any::TypeId::of::() == std::any::TypeId::of::() { - // Safety: Connection is valid for the lifetime of self and the type is correct. - Some(unsafe { &*(self.connection as *const Connecting as *const T) }) - } else { - None - } - } -} - -impl Backend for QuicBackend { - async fn serve(self, make_service: impl MakeService) -> Result { - tracing::debug!("listening for incoming connections on {:?}", self.endpoint.local_addr()?); - loop { - let ctx = self.handler.context(); - - tracing::trace!("waiting for incoming connection"); - - let Some(Some(connection)) = self.endpoint.accept().with_context(&ctx).await else { - break; - }; - - if !connection.remote_address_validated() { - if let Err(err) = connection.retry() { - tracing::debug!(error = %err, "failed to retry quic connection"); - } - - continue; - } - - let connection = match connection.accept() { - Ok(connection) => connection, - Err(e) => { - tracing::debug!(error = %e, "failed to accept quic connection"); - continue; - } - }; - - let span = tracing::trace_span!("connection", remote_addr = %connection.remote_address()); - let _guard = span.enter(); - tracing::trace!("connection accepted"); - - let Some(service) = make_service - .make_service(&IncomingQuicConnection { - remote_addr: connection.remote_address(), - connection: &connection, - }) - .await - else { - tracing::trace!("no service returned for 
connection, closing"); - continue; - }; - - tracing::trace!("spawning connection handler"); - - spawn( - Connection { - connection, - builder: self.builder.clone(), - service, - keep_alive_timeout: self.keep_alive_timeout, - parent_ctx: ctx, - } - .serve() - .in_current_span(), - ); - } - - Ok(self.handler) - } - - fn handler(&self) -> &crate::context::Handler { - &self.handler - } -} - -struct Connection { - connection: Connecting, - builder: Arc, - service: S, - keep_alive_timeout: Option, - parent_ctx: crate::context::Context, -} - -impl Connection { - async fn serve(self) { - tracing::trace!("connection handler started"); - let connection = match self.connection.with_context(&self.parent_ctx).await { - Some(Ok(connection)) => connection, - Some(Err(err)) => { - self.service.on_error(err.into()).await; - self.service.on_close().await; - return; - } - None => { - self.service.on_close().await; - return; - } - }; - - let ip_addr = connection.remote_address().ip(); - - let mut connection = match self - .builder - .build(h3_quinn::Connection::new(connection)) - .with_context(&self.parent_ctx) - .await - { - Some(Ok(connection)) => connection, - Some(Err(err)) => { - self.service.on_error(err.into()).await; - self.service.on_close().await; - return; - } - None => { - self.service.on_close().await; - return; - } - }; - - let (hijack_conn_tx, mut hijack_conn_rx) = tokio::sync::mpsc::channel::(1); - - self.service.on_ready().await; - #[cfg(feature = "opentelemetry")] - tracing::Span::current().make_root(); - tracing::trace!("connection ready"); - - let (_, handler) = self.parent_ctx.new_child(); - - // This handle is similar to the above however, unlike the above if this handle - // is cancelled, all futures for this connection are immediately cancelled. - // When the above is cancelled, the connection is allowed to finish. 
- let connection_handle = crate::context::Handler::new(); - - let active_requests = Arc::new(std::sync::atomic::AtomicUsize::new(0)); - - loop { - let (request, stream) = tokio::select! { - request = connection.accept() => { - match request { - Ok(Some(request)) => request, - // The connection was closed. - Ok(None) => { - tracing::trace!("connection closed"); - connection_handle.cancel(); - break; - }, - // An error occurred. - Err(err) => { - match err.get_error_level() { - ErrorLevel::ConnectionError => { - tracing::debug!(err = %err, "error accepting request"); - self.service.on_error(err.into()).await; - connection_handle.cancel(); - break; - } - ErrorLevel::StreamError => { - if let Some(Code::H3_NO_ERROR) = err.try_get_code() { - tracing::trace!("stream closed"); - } else { - tracing::debug!(err = %err, "stream error"); - self.service.on_error(err.into()).await; - } - continue; - } - } - } - } - }, - Some(_) = async { - if let Some(keep_alive_timeout) = self.keep_alive_timeout { - loop { - tokio::time::sleep(jitter(keep_alive_timeout)).await; - if active_requests.load(std::sync::atomic::Ordering::Relaxed) != 0 { - continue; - } - - break Some(()); - } - } else { - None - } - } => { - tracing::debug!("keep alive timeout"); - break; - } - // This happens when the connection has been upgraded to a WebTransport connection. 
- Some(send_hijack_conn) = hijack_conn_rx.recv() => { - tracing::trace!("connection hijacked"); - send_hijack_conn.send(connection).ok(); - self.service.on_hijack().await; - return; - }, - _ = self.parent_ctx.done() => break, - }; - - tracing::trace!("new request"); - let active_requests = ActiveRequestsGuard::new(active_requests.clone()); - - let service = self.service.clone(); - let stream = QuinnStream::new(stream); - - let mut request = request.map(|()| Body::from_stream(QuinnHttpBodyAdapter::new(stream.clone()))); - - let ctx = handler.context(); - - request.extensions_mut().insert(QuicConnectionState { - hijack_conn_tx: hijack_conn_tx.clone(), - stream: stream.clone(), - }); - request.extensions_mut().insert(SocketKind::Quic); - request.extensions_mut().insert(ctx.clone()); - request.extensions_mut().insert(ip_addr); - - let connection_context = connection_handle.context(); - - tokio::spawn( - async move { - if let Err(err) = serve_request(&service, request, stream).await { - service.on_error(err).await; - } - - drop(active_requests); - drop(ctx); - } - .with_context(connection_context) - .in_current_span(), - ); - } - - tracing::trace!("connection closing"); - - handler.shutdown().await; - - connection_handle.shutdown().await; - - self.service.on_close().await; - - tracing::trace!("connection closed"); - } -} - -async fn serve_request(service: &impl ServiceHandler, request: Request, mut stream: QuinnStream) -> Result<(), Error> { - let response = service.on_request(request).await.into_response(); - - let Some(send) = stream.get_send() else { - // The service was hijacked. 
- tracing::trace!("service hijacked, not sending response"); - return Ok(()); - }; - - let (parts, body) = response.into_parts(); - tracing::trace!(?parts, "sending response"); - send.send_response(Response::from_parts(parts, ())).await?; - - let mut body = std::pin::pin!(body); - - tracing::trace!("sending response body"); - - loop { - match poll_fn(|cx| body.as_mut().poll_frame(cx)).await.transpose()? { - Some(frame) => { - if frame.is_data() { - let data = frame.into_data().unwrap(); - tracing::trace!(size = data.len(), "sending data"); - send.send_data(data).await?; - } else if frame.is_trailers() { - tracing::trace!("sending trailers"); - send.send_trailers(frame.into_trailers().unwrap()).await?; - break; - } - } - None => { - send.finish().await?; - break; - } - } - } - - tracing::trace!("response body finished"); - - Ok(()) -} - -type SendQuicConnection = tokio::sync::oneshot::Sender>; - -#[derive(Clone)] -struct QuicConnectionState { - hijack_conn_tx: tokio::sync::mpsc::Sender, - stream: QuinnStream, -} - -enum SharedStream { - Bidi(Option, Bytes>>), - Recv(Option>), - Send(Option, Bytes>>), -} - -impl SharedStream { - fn take_bidi(&mut self) -> Option, Bytes>> { - match self { - SharedStream::Bidi(stream) => stream.take(), - _ => None, - } - } - - fn take_recv(&mut self) -> Option> { - match self { - SharedStream::Recv(stream) => stream.take(), - SharedStream::Bidi(stream) => { - let (send, recv) = stream.take()?.split(); - *self = SharedStream::Send(Some(send)); - Some(recv) - } - _ => None, - } - } - - fn take_send(&mut self) -> Option, Bytes>> { - match self { - SharedStream::Send(stream) => stream.take(), - SharedStream::Bidi(stream) => { - let (send, recv) = stream.take()?.split(); - *self = SharedStream::Recv(Some(recv)); - Some(send) - } - _ => None, - } - } -} - -enum QuinnStream { - Shared(Arc>), - LocalRecv(RequestStream), - LocalSend(RequestStream, Bytes>), - None, -} - -impl Clone for QuinnStream { - fn clone(&self) -> Self { - match self { - 
QuinnStream::Shared(stream) => QuinnStream::Shared(stream.clone()), - _ => QuinnStream::None, - } - } -} - -impl QuinnStream { - fn new(stream: RequestStream, Bytes>) -> Self { - QuinnStream::Shared(Arc::new(spin::Mutex::new(SharedStream::Bidi(Some(stream))))) - } - - fn take_bidi(&mut self) -> Option, Bytes>> { - match self { - QuinnStream::Shared(stream) => { - let stream = stream.lock().take_bidi()?; - *self = Self::None; - Some(stream) - } - _ => None, - } - } - - fn get_recv(&mut self) -> Option<&mut RequestStream> { - match self { - QuinnStream::Shared(stream) => { - let stream = stream.lock().take_recv()?; - *self = Self::LocalRecv(stream); - self.get_recv() - } - QuinnStream::LocalRecv(stream) => Some(stream), - _ => None, - } - } - - fn get_send(&mut self) -> Option<&mut RequestStream, Bytes>> { - match self { - QuinnStream::Shared(stream) => { - let stream = stream.lock().take_send()?; - *self = Self::LocalSend(stream); - self.get_send() - } - QuinnStream::LocalSend(stream) => Some(stream), - _ => None, - } - } -} - -struct QuinnHttpBodyAdapter { - stream: QuinnStream, -} - -impl QuinnHttpBodyAdapter { - fn new(stream: QuinnStream) -> Self { - Self { stream } - } -} - -impl futures::Stream for QuinnHttpBodyAdapter { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let stream = match self.stream.get_recv() { - Some(stream) => stream, - None => return Poll::Ready(None), - }; - - match std::pin::pin!(stream.recv_data()).poll(cx) { - Poll::Ready(Ok(Some(mut buf))) => Poll::Ready(Some(Ok(buf.copy_to_bytes(buf.remaining())))), - Poll::Ready(Ok(None)) => Poll::Ready(None), - Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))), - Poll::Pending => Poll::Pending, - } - } -} - -#[cfg(feature = "http3-webtransport")] -use future::IntoFuture; -#[cfg(feature = "http3-webtransport")] -use h3_webtransport::server::WebTransportSession; - -pub struct HijackedQuicConnection { - pub connection: h3::server::Connection, - pub 
stream: RequestStream, Bytes>, - pub request: Request<()>, -} - -impl HijackedQuicConnection { - #[cfg(feature = "http3-webtransport")] - pub async fn upgrade_webtransport(self) -> Result, Error> { - Ok(WebTransportSession::accept(self.request, self.stream, self.connection).await?) - } -} - -pub async fn is_webtransport(request: &Request) -> bool { - request.method() == http::Method::CONNECT - && request.extensions().get::() == Some(&Protocol::WEB_TRANSPORT) - && request.extensions().get::().is_some() -} - -pub async fn hijack_quic_connection(request: Request) -> Result { - let Some(web_transport_state) = request.extensions().get::() else { - tracing::debug!("request is not a quic connection"); - return Err(request); - }; - - let Some(stream) = web_transport_state.stream.clone().take_bidi() else { - tracing::debug!("request body has already been read"); - return Err(request); - }; - - let (send, recv) = tokio::sync::oneshot::channel(); - if web_transport_state.hijack_conn_tx.send(send).await.is_err() { - tracing::debug!("connection has already been hijacked"); - return Err(request); - } - - let connection = match recv.await { - Ok(connection) => connection, - Err(_) => { - tracing::debug!("connection was dropped"); - return Err(request); - } - }; - - let request = request.map(|_| {}); - - Ok(HijackedQuicConnection { - connection, - stream, - request, - }) -} diff --git a/crates/foundations/src/http/server/stream/tcp.rs b/crates/foundations/src/http/server/stream/tcp.rs deleted file mode 100644 index 5ff8c79b4..000000000 --- a/crates/foundations/src/http/server/stream/tcp.rs +++ /dev/null @@ -1,159 +0,0 @@ -use std::convert::Infallible; -use std::sync::Arc; - -use crate::http::server::stream::is_fatal_tcp_error; - -pub struct TcpBackend { - listener: tokio::net::TcpListener, - handler: crate::context::Handler, - inner: L, -} - -impl Backend for TcpBackend { - async fn serve(self, make_service: impl MakeService) -> Result { - tracing::debug!("listening for incoming 
connections on {:?}", self.listener.local_addr()?); - loop { - let ctx = self.handler.context(); - - tracing::trace!("waiting for incoming connection"); - - let Some(stream) = self.listener.accept().with_context(&ctx).await else { - break; - }; - - let (connection, addr) = match stream { - Ok((connection, addr)) => (connection, addr), - Err(err) if is_fatal_tcp_error(&err) => { - return Err(err.into()); - } - Err(err) => { - tracing::error!(err = %err, "failed to accept connection"); - continue; - } - }; - - let span = tracing::trace_span!("connection", remote_addr = %addr); - - let _guard = span.enter(); - - tracing::trace!("accepted connection"); - - let Some(service) = make_service - .make_service(&IncomingTcpConnection { - remote_addr: addr, - local_addr: self.listener.local_addr()?, - connection: &connection, - }) - .await - else { - tracing::trace!("no service returned for connection, closing"); - continue; - }; - - spawn( - Connection { - connection, - builder: self.builder.clone(), - service, - parent_ctx: ctx, - peer_addr: addr, - keep_alive_timeout: self.keep_alive_timeout, - } - .serve() - .in_current_span(), - ); - } - - Ok(self.handler) - } - - fn handler(&self) -> &crate::context::Handler { - &self.handler - } -} - -struct Connection { - connection: TcpStream, - builder: Arc>, - service: S, - peer_addr: std::net::SocketAddr, - parent_ctx: crate::context::Context, - keep_alive_timeout: Option, -} - -impl Connection { - async fn serve(self) { - self.service.on_ready().await; - #[cfg(feature = "opentelemetry")] - tracing::Span::current().make_root(); - tracing::trace!("connection ready"); - - let (_, handle) = self.parent_ctx.new_child(); - - let make_ctx = { - let handle = handle.clone(); - Arc::new(move || handle.context()) - }; - - let active_requests = Arc::new(std::sync::atomic::AtomicUsize::new(0)); - - let ip_addr = self.peer_addr.ip(); - - let service_fn = { - let service = self.service.clone(); - let span = tracing::Span::current(); - let 
active_requests = active_requests.clone(); - - service_fn(move |mut req: Request| { - let service = service.clone(); - let make_ctx = make_ctx.clone(); - let guard = ActiveRequestsGuard::new(active_requests.clone()); - async move { - let ctx = make_ctx(); - req.extensions_mut().insert(ctx.clone()); - req.extensions_mut().insert(SocketKind::Tcp); - req.extensions_mut().insert(ip_addr); - let resp = service.on_request(req.map(Body::new)).await.into_response(); - drop(ctx); - drop(guard); - Ok::<_, Infallible>(resp) - } - .instrument(span.clone()) - }) - }; - - let r = tokio::select! { - r = self.builder.serve_connection_with_upgrades(TokioIo::new(self.connection), service_fn) => r, - Some(_) = async { - if let Some(keep_alive_timeout) = self.keep_alive_timeout { - loop { - tokio::time::sleep(jitter(keep_alive_timeout)).await; - if active_requests.load(std::sync::atomic::Ordering::Relaxed) != 0 { - continue; - } - - break Some(()); - } - } else { - None - } - } => { - tracing::debug!("keep alive timeout"); - Ok(()) - } - _ = async { - self.parent_ctx.done().await; - handle.shutdown().await; - } => { - Ok(()) - } - }; - - if let Err(err) = r { - self.service.on_error(err.into()).await; - } - - self.service.on_close().await; - tracing::trace!("connection closed"); - } -} diff --git a/crates/foundations/src/http/server/stream/tls.rs b/crates/foundations/src/http/server/stream/tls.rs deleted file mode 100644 index c65f09bde..000000000 --- a/crates/foundations/src/http/server/stream/tls.rs +++ /dev/null @@ -1,251 +0,0 @@ -use std::convert::Infallible; -use std::sync::Arc; - -use axum::body::Body; -use axum::extract::Request; -use axum::response::IntoResponse; -use hyper::body::Incoming; -use hyper::service::service_fn; -use hyper_util::rt::{TokioExecutor, TokioIo}; -use hyper_util::server::conn::auto::Builder; -use tokio::net::{TcpListener, TcpStream}; -use tokio::spawn; -use tokio_rustls::TlsAcceptor; -use tracing::Instrument; - -use super::{Backend, IncomingConnection, 
MakeService, ServiceHandler, SocketKind}; -use crate::context::ContextFutExt; -use crate::http::server::stream::{is_fatal_tcp_error, jitter, ActiveRequestsGuard}; -#[cfg(feature = "opentelemetry")] -use crate::telemetry::opentelemetry::OpenTelemetrySpanExt; - -pub struct TlsBackend { - listener: TcpListener, - acceptor: Arc, - builder: Arc>, - handler: crate::context::Handler, - keep_alive_timeout: Option, -} - -impl TlsBackend { - pub fn new( - listener: TcpListener, - acceptor: Arc, - builder: Arc>, - ctx: &crate::context::Context, - ) -> Self { - Self { - listener, - acceptor, - builder, - handler: ctx.new_child().1, - keep_alive_timeout: None, - } - } - - pub fn with_keep_alive_timeout(mut self, timeout: impl Into>) -> Self { - self.keep_alive_timeout = timeout.into(); - self - } -} - -struct IncomingTlsConnection<'a> { - remote_addr: std::net::SocketAddr, - local_addr: std::net::SocketAddr, - connection: &'a TcpStream, -} - -impl IncomingConnection for IncomingTlsConnection<'_> { - fn socket_kind(&self) -> SocketKind { - SocketKind::TlsTcp - } - - fn remote_addr(&self) -> std::net::SocketAddr { - self.remote_addr - } - - fn local_addr(&self) -> Option { - Some(self.local_addr) - } - - fn downcast(&self) -> Option<&T> { - if std::any::TypeId::of::() == std::any::TypeId::of::() { - // Safety: We know that the type is TcpStream because we checked the type id. - // We also know that the reference is valid because it is a reference to a field - // of self. 
- Some(unsafe { &*(self.connection as *const TcpStream as *const T) }) - } else { - None - } - } -} - -impl Backend for TlsBackend { - async fn serve(self, make_service: impl MakeService) -> Result { - tracing::debug!("listening for incoming connections on {:?}", self.listener.local_addr()?); - loop { - let ctx = self.handler.context(); - - tracing::trace!("waiting for incoming connection"); - - let Some(stream) = self.listener.accept().with_context(&ctx).await else { - break; - }; - - let (connection, addr) = match stream { - Ok((connection, addr)) => (connection, addr), - Err(err) if is_fatal_tcp_error(&err) => { - return Err(err.into()); - } - Err(err) => { - tracing::error!(err = %err, "failed to accept connection"); - continue; - } - }; - - let span = tracing::trace_span!("connection", remote_addr = %addr); - let _guard = span.enter(); - - tracing::trace!("accepted connection"); - - let Some(service) = make_service - .make_service(&IncomingTlsConnection { - remote_addr: addr, - local_addr: self.listener.local_addr()?, - connection: &connection, - }) - .await - else { - tracing::trace!("no service for connection, closing"); - continue; - }; - - tracing::trace!("spawning connection handler"); - - spawn( - Connection { - connection, - builder: self.builder.clone(), - acceptor: self.acceptor.clone(), - service, - parent_ctx: ctx, - peer_addr: addr, - keep_alive_timeout: self.keep_alive_timeout, - } - .serve() - .in_current_span(), - ); - } - - Ok(self.handler) - } - - fn handler(&self) -> &crate::context::Handler { - &self.handler - } -} - -struct Connection { - connection: TcpStream, - builder: Arc>, - acceptor: Arc, - service: S, - peer_addr: std::net::SocketAddr, - keep_alive_timeout: Option, - parent_ctx: crate::context::Context, -} - -impl Connection { - async fn serve(self) { - #[cfg(feature = "opentelemetry")] - tracing::Span::current().begin_trace(); - - tracing::trace!("connection handler started"); - let connection = match 
self.acceptor.accept(self.connection).with_context(&self.parent_ctx).await { - Some(Ok(connection)) => connection, - Some(Err(err)) => { - tracing::debug!(err = %err, "error accepting connection"); - self.service.on_error(err.into()).await; - self.service.on_close().await; - return; - } - None => { - self.service.on_close().await; - return; - } - }; - - self.service.on_ready().await; - - tracing::trace!("connection ready"); - - let (_, handle) = self.parent_ctx.new_child(); - - let make_ctx = { - let handle = handle.clone(); - Arc::new(move || handle.context()) - }; - - let ip_addr = self.peer_addr.ip(); - - let active_requests = Arc::new(std::sync::atomic::AtomicUsize::new(0)); - - let service_fn = { - let service = self.service.clone(); - let make_ctx = make_ctx.clone(); - let span = tracing::Span::current(); - let active_requests = active_requests.clone(); - - service_fn(move |mut req: Request| { - let service = service.clone(); - let make_ctx = make_ctx.clone(); - let guard = ActiveRequestsGuard::new(active_requests.clone()); - async move { - let ctx = make_ctx(); - req.extensions_mut().insert(ctx.clone()); - req.extensions_mut().insert(SocketKind::TlsTcp); - req.extensions_mut().insert(ip_addr); - let resp = service.on_request(req.map(Body::new)).await.into_response(); - drop(ctx); - drop(guard); - Ok::<_, Infallible>(resp) - } - .instrument(span.clone()) - }) - }; - - let r = tokio::select! 
{ - r = self.builder.serve_connection_with_upgrades(TokioIo::new(connection), service_fn) => r, - Some(_) = async { - if let Some(keep_alive_timeout) = self.keep_alive_timeout { - loop { - tokio::time::sleep(jitter(keep_alive_timeout)).await; - if active_requests.load(std::sync::atomic::Ordering::Relaxed) != 0 { - continue; - } - - break Some(()); - } - } else { - None - } - } => { - tracing::debug!("keep alive timeout"); - Ok(()) - } - _ = async { - self.parent_ctx.done().await; - handle.shutdown().await; - } => { - Ok(()) - } - }; - - if let Err(err) = r { - self.service.on_error(err.into()).await; - } - - self.service.on_close().await; - tracing::trace!("connection closed"); - } -} diff --git a/crates/foundations/src/lib.rs b/crates/foundations/src/lib.rs index d2ce19559..c11edb6c3 100644 --- a/crates/foundations/src/lib.rs +++ b/crates/foundations/src/lib.rs @@ -1,44 +1,11 @@ -#[cfg(feature = "macros")] -pub use scuffle_foundations_macros::wrapped; - -#[cfg(feature = "macros")] -#[doc(hidden)] -pub mod macro_reexports { - #[cfg(feature = "cli")] - pub use const_str; - #[cfg(feature = "metrics")] - pub use once_cell; - #[cfg(feature = "metrics")] - pub use parking_lot; - #[cfg(feature = "metrics")] - pub use prometheus_client; - #[cfg(any(feature = "settings", feature = "metrics"))] - pub use serde; -} - pub type BootstrapResult = anyhow::Result; -#[cfg(feature = "settings")] -pub mod settings; - #[cfg(feature = "bootstrap")] pub mod bootstrap; #[cfg(feature = "_telemetry")] pub mod telemetry; -#[cfg(feature = "signal")] -pub mod signal; - -#[cfg(feature = "context")] -pub mod context; - -#[cfg(feature = "batcher")] -pub mod batcher; - -#[cfg(feature = "http")] -pub mod http; - #[derive(Debug, Clone, Copy, Default)] /// Information about the service. 
pub struct ServiceInfo { diff --git a/crates/foundations/src/settings/cli.rs b/crates/foundations/src/settings/cli.rs index e024b0ac0..78117ac59 100644 --- a/crates/foundations/src/settings/cli.rs +++ b/crates/foundations/src/settings/cli.rs @@ -8,55 +8,6 @@ const GENERATE_ARG_ID: &str = "generate"; const CONFIG_ARG_ID: &str = "config"; const ALLOW_TEMPLATE: &str = "jinja"; -pub use clap; - -#[derive(Debug)] -pub struct Cli { - settings: SettingsParser, - app: clap::Command, -} - -fn default_cmd() -> clap::Command { - clap::Command::new("") - .arg( - clap::Arg::new(CONFIG_ARG_ID) - .long(CONFIG_ARG_ID) - .short('c') - .help("The configuration file to use") - .value_name("FILE") - .action(ArgAction::Append), - ) - .arg( - clap::Arg::new(GENERATE_ARG_ID) - .long(GENERATE_ARG_ID) - .help("Generate a configuration file") - .value_name("FILE") - .action(ArgAction::Set) - .num_args(0..=1) - .default_missing_value("./config.toml"), - ) - .arg( - clap::Arg::new(ALLOW_TEMPLATE) - .long("jinja") - .help("Allows for the expansion of templates in the configuration file using Jinja syntax") - .action(ArgAction::Set) - .num_args(0..=1) - .default_missing_value("true"), - ) -} - -impl Default for Cli { - fn default() -> Self { - Self::new(&Default::default()) - } -} - -#[derive(Debug, Clone)] -pub struct Matches { - pub settings: S, - pub args: clap::ArgMatches, -} - impl Cli { pub fn new(default: &S) -> Self { Self { diff --git a/crates/foundations/src/settings/traits.rs b/crates/foundations/src/settings/traits.rs deleted file mode 100644 index 889dd9ad8..000000000 --- a/crates/foundations/src/settings/traits.rs +++ /dev/null @@ -1,154 +0,0 @@ -//! This module contains an auto-deref specialization to help with adding doc -//! comments to sub-types. You can read more about how it works here -//! 
https://lukaskalbertodt.github.io/2019/12/05/generalized-autoref-based-specialization.html - -use std::borrow::Cow; -use std::collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, LinkedList, VecDeque}; -use std::hash::Hash; - -use super::to_docs_string; - -pub trait Settings { - #[doc(hidden)] - fn add_docs( - &self, - parent_key: &[Cow<'static, str>], - docs: &mut HashMap>, Cow<'static, [Cow<'static, str>]>>, - ) { - let (_, _) = (parent_key, docs); - } - - fn docs(&self) -> HashMap>, Cow<'static, [Cow<'static, str>]>> { - let mut docs = HashMap::new(); - self.add_docs(&[], &mut docs); - docs - } - - fn to_docs_string(&self) -> Result - where - Self: serde::Serialize + Sized, - { - to_docs_string(self) - } -} - -#[doc(hidden)] -pub struct Wrapped(pub T); - -/// Default implementation for adding docs to a wrapped type. -impl Settings for Wrapped<&T> {} - -/// Specialization for adding docs to a type that implements SerdeDocs. -impl Settings for &Wrapped<&T> { - fn add_docs( - &self, - parent_key: &[Cow<'static, str>], - docs: &mut HashMap>, Cow<'static, [Cow<'static, str>]>>, - ) { - ::add_docs(self.0, parent_key, docs) - } -} - -/// Specialization for adding docs an array type that implements SerdeDocs. -macro_rules! impl_arr { - ($($n:literal)+) => { - $( - impl Settings for &Wrapped<&[T; $n]> { - fn add_docs(&self, parent_key: &[Cow<'static, str>], docs: &mut HashMap>, Cow<'static, [Cow<'static, str>]>>) { - let mut key = parent_key.to_vec(); - for (i, item) in self.0.iter().enumerate() { - key.push(i.to_string().into()); - item.add_docs(&key, docs); - key.pop(); - } - } - } - )+ - }; -} - -impl_arr!(0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32); - -/// Specialization for adding docs to a slice type that implements SerdeDocs. -macro_rules! 
impl_seq { - ($($impl_desc:tt)*) => { - impl $($impl_desc)* { - fn add_docs(&self, parent_key: &[Cow<'static, str>], docs: &mut HashMap>, Cow<'static, [Cow<'static, str>]>>) { - let mut key = parent_key.to_vec(); - for (i, item) in self.0.iter().enumerate() { - key.push(i.to_string().into()); - item.add_docs(&key, docs); - key.pop(); - } - } - } - }; -} - -impl_seq!( Settings for &Wrapped<&Vec>); -impl_seq!( Settings for &Wrapped<&VecDeque>); -impl_seq!( Settings for &Wrapped<&BinaryHeap>); -impl_seq!( Settings for &Wrapped<&LinkedList>); -impl_seq!( Settings for &Wrapped<&BTreeSet>); - -/// Specialization for adding docs to a map type that implements SerdeDocs. -macro_rules! impl_map { - ($($impl_desc:tt)*) => { - impl $($impl_desc)* { - fn add_docs(&self, parent_key: &[Cow<'static, str>], docs: &mut HashMap>, Cow<'static, [Cow<'static, str>]>>) { - let mut key = parent_key.to_vec(); - for (k, v) in self.0.iter() { - key.push(k.to_string().into()); - v.add_docs(&key, docs); - key.pop(); - } - } - } - }; -} - -/// Key types for those maps that implement SerdeDocs. -trait Keyable: Hash + PartialOrd + PartialEq + std::fmt::Display {} - -macro_rules! impl_keyable { - ($($t:ty)*) => { - $( - impl Keyable for $t {} - )* - }; -} - -impl_keyable!(String &'static str Cow<'static, str> usize u8 u16 u32 u64 u128 i8 i16 i32 i64 i128 bool char); - -impl_map!( Settings for &Wrapped<&HashMap>); -impl_map!( Settings for &Wrapped<&BTreeMap>); - -/// Specialization for adding docs to an option type that implements SerdeDocs. -impl Settings for &Wrapped<&Option> { - fn add_docs( - &self, - parent_key: &[Cow<'static, str>], - docs: &mut HashMap>, Cow<'static, [Cow<'static, str>]>>, - ) { - if let Some(inner) = self.0 { - inner.add_docs(parent_key, docs); - } - } -} - -/// Specialization for any type that derefs into a type that implements -/// SerdeDocs. 
-impl Settings for &&Wrapped<&R> -where - R: std::ops::Deref, -{ - fn add_docs( - &self, - parent_key: &[Cow<'static, str>], - docs: &mut HashMap>, Cow<'static, [Cow<'static, str>]>>, - ) { - (**self.0).add_docs(parent_key, docs); - } -} - -impl Settings for () {} diff --git a/crates/foundations/src/signal.rs b/crates/foundations/src/signal.rs deleted file mode 100644 index ac9e1efe6..000000000 --- a/crates/foundations/src/signal.rs +++ /dev/null @@ -1,40 +0,0 @@ -use futures::FutureExt; -use tokio::signal::unix::{Signal, SignalKind}; - -#[derive(Default)] -pub struct SignalHandler { - signals: Vec<(SignalKind, Signal)>, -} - -impl SignalHandler { - pub fn new() -> Self { - Self::default() - } - - pub fn with_signal(mut self, kind: SignalKind) -> Self { - if self.signals.iter().any(|(k, _)| k == &kind) { - return self; - } - - let signal = tokio::signal::unix::signal(kind).expect("failed to create signal"); - - self.signals.push((kind, signal)); - - self - } - - pub async fn recv(&mut self) -> Option { - if self.signals.is_empty() { - return None; - } - - let (item, _, _) = futures::future::select_all( - self.signals - .iter_mut() - .map(|(kind, signal)| Box::pin(signal.recv().map(|_| *kind))), - ) - .await; - - Some(item) - } -} diff --git a/crates/h3-webtransport/src/session.rs b/crates/h3-webtransport/src/session.rs index cb7d889a4..c9484181d 100644 --- a/crates/h3-webtransport/src/session.rs +++ b/crates/h3-webtransport/src/session.rs @@ -190,7 +190,7 @@ where } /// Polls to open a bidi stream - #[allow(clippy::type_complexity)] + #[allow(clippy::type_complexity)] pub fn poll_open_bi( &self, cx: &mut Context<'_>, @@ -211,7 +211,7 @@ where } /// Polls to open a uni stream - #[allow(clippy::type_complexity)] + #[allow(clippy::type_complexity)] pub fn poll_open_uni( &self, cx: &mut Context<'_>, @@ -266,7 +266,7 @@ where } /// Completes the WebTransport upgrade - #[allow(clippy::type_complexity)] + #[allow(clippy::type_complexity)] pub fn complete( response: &mut 
Response, stream: RequestStream, diff --git a/crates/http/src/backend/tcp/config.rs b/crates/http/src/backend/tcp/config.rs index f0c03386e..a758021d8 100644 --- a/crates/http/src/backend/tcp/config.rs +++ b/crates/http/src/backend/tcp/config.rs @@ -51,7 +51,7 @@ impl Http1Builder { allow_http10: false, } } - + pub fn with_recv_buffer_size(mut self, size: usize) -> Self { self.recv_buffer_size = size; self @@ -214,9 +214,9 @@ pub struct TcpServerConfigBuilder { } impl Default for TcpServerConfigBuilder { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl TcpServerConfigBuilder { diff --git a/crates/settings/Cargo.toml b/crates/settings/Cargo.toml new file mode 100644 index 000000000..3f0bb26ae --- /dev/null +++ b/crates/settings/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "scuffle-settings" +version = "0.1.0" +edition = "2021" + +[dependencies] +config = { version = "0.14", default-features = false } +clap = { version = "4", optional = true } +minijinja = { version = "2.5", optional = true, features = ["json", "custom_syntax", "urlencode"] } +serde = "1" +thiserror = "2" +bon = "3" + +[features] +cli = ["clap"] +ron = ["config/ron"] +toml = ["config/toml"] +yaml = ["config/yaml"] +json = ["config/json"] +json5 = ["config/json5"] +ini = ["config/ini"] +templates = ["minijinja"] + +default = ["toml", "json", "yaml", "json5", "ini", "ron", "cli", "templates"] + +[dev-dependencies] +serde_derive = "1" +tracing = "0.1" +tracing-subscriber = "0.3.18" diff --git a/crates/settings/src/lib.rs b/crates/settings/src/lib.rs new file mode 100644 index 000000000..5a75192ea --- /dev/null +++ b/crates/settings/src/lib.rs @@ -0,0 +1,235 @@ +use std::path::Path; + +#[derive(Debug, thiserror::Error)] +pub enum ConfigError { + #[error(transparent)] + Config(#[from] config::ConfigError), + #[cfg(feature = "cli")] + #[error(transparent)] + Clap(#[from] clap::Error), +} + +/// A struct used to define how the CLI should be generated 
+#[derive(Debug, Clone)] +pub struct Cli { + /// The name of the program + pub name: &'static str, + + /// The version of the program + pub version: &'static str, + + /// The about of the program + pub about: &'static str, + + /// The author of the program + pub author: &'static str, + + /// The arguments to add to the CLI + pub argv: Vec, +} + +#[macro_export] +/// A macro to create a CLI struct +/// This macro will automatically set the name, version, about, and author from +/// the environment variables at compile time +macro_rules! cli { + () => { + $crate::cli!(std::env::args().collect()) + }; + ($args:expr) => { + $crate::Cli { + name: env!("CARGO_PKG_NAME"), + version: env!("CARGO_PKG_VERSION"), + about: env!("CARGO_PKG_DESCRIPTION"), + author: env!("CARGO_PKG_AUTHORS"), + argv: $args, + } + }; +} + +#[derive(Debug, Clone, Copy)] +struct FormatWrapper; + +use std::borrow::Cow; + +#[cfg(not(feature = "templates"))] +fn template_text(text: &str, _: config::FileFormat) -> Result, Box> { + Ok(Cow::Borrowed(text)) +} + +#[cfg(feature = "templates")] +fn template_text(text: &str, _: config::FileFormat) -> Result, Box> { + use minijinja::syntax::SyntaxConfig; + + let mut env = minijinja::Environment::new(); + + env.add_global("env", std::env::vars().collect::>()); + env.set_syntax( + SyntaxConfig::builder() + .block_delimiters("{%", "%}") + .variable_delimiters("${{", "}}") + .comment_delimiters("{#", "#}") + .build() + .unwrap(), + ); + + Ok(Cow::Owned(env.template_from_str(text).unwrap().render(())?)) +} + +impl config::Format for FormatWrapper { + fn parse( + &self, + uri: Option<&String>, + text: &str, + ) -> Result, Box> { + match uri.and_then(|s| Path::new(s.as_str()).extension()).and_then(|s| s.to_str()) { + #[cfg(feature = "toml")] + Some("toml") => config::FileFormat::Toml.parse(uri, template_text(text, config::FileFormat::Toml)?.as_ref()), + #[cfg(feature = "json")] + Some("json") => config::FileFormat::Json.parse(uri, template_text(text, 
config::FileFormat::Json)?.as_ref()), + #[cfg(feature = "yaml")] + Some("yaml") | Some("yml") => { + config::FileFormat::Yaml.parse(uri, template_text(text, config::FileFormat::Yaml)?.as_ref()) + } + #[cfg(feature = "json5")] + Some("json5") => config::FileFormat::Json5.parse(uri, template_text(text, config::FileFormat::Json5)?.as_ref()), + #[cfg(feature = "ini")] + Some("ini") => config::FileFormat::Ini.parse(uri, template_text(text, config::FileFormat::Ini)?.as_ref()), + #[cfg(feature = "ron")] + Some("ron") => config::FileFormat::Ron.parse(uri, template_text(text, config::FileFormat::Ron)?.as_ref()), + _ => { + let formats = [ + #[cfg(feature = "toml")] + config::FileFormat::Toml, + #[cfg(feature = "json")] + config::FileFormat::Json, + #[cfg(feature = "yaml")] + config::FileFormat::Yaml, + #[cfg(feature = "json5")] + config::FileFormat::Json5, + #[cfg(feature = "ini")] + config::FileFormat::Ini, + #[cfg(feature = "ron")] + config::FileFormat::Ron, + ]; + + for format in formats { + if let Ok(map) = format.parse(uri, template_text(text, format)?.as_ref()) { + return Ok(map); + } + } + + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("No supported format found for file: {:?}", uri), + ))) + } + } + } +} + +impl config::FileStoredFormat for FormatWrapper { + fn file_extensions(&self) -> &'static [&'static str] { + &[ + #[cfg(feature = "toml")] + "toml", + #[cfg(feature = "json")] + "json", + #[cfg(feature = "yaml")] + "yaml", + #[cfg(feature = "yaml")] + "yml", + #[cfg(feature = "json5")] + "json5", + #[cfg(feature = "ini")] + "ini", + #[cfg(feature = "ron")] + "ron", + ] + } +} + +#[derive(Debug, Clone, bon::Builder)] +pub struct Options { + /// The CLI options + #[cfg(feature = "cli")] + pub cli: Option, + /// The default config file name (loaded if no other files are specified) + pub default_config_file: Option<&'static str>, + /// Environment variables prefix + pub env_prefix: Option<&'static str>, +} + +impl Default for Options 
{ + fn default() -> Self { + Self { + cli: None, + default_config_file: Some("config"), + env_prefix: Some("APP"), + } + } +} + +pub fn parse_settings(options: Options) -> Result { + let mut config = config::Config::builder(); + + #[allow(unused_mut)] + let mut added_files = false; + + #[cfg(feature = "cli")] + if let Some(cli) = options.cli { + let command = clap::Command::new(cli.name) + .version(cli.version) + .about(cli.about) + .author(cli.author) + .bin_name(cli.name) + .arg( + clap::Arg::new("config") + .short('c') + .long("config") + .value_name("FILE") + .help("Path to configuration file(s)") + .action(clap::ArgAction::Append), + ) + .arg( + clap::Arg::new("overrides") + .long("override") + .help("Provide an override for a configuration value, in the format KEY=VALUE") + .action(clap::ArgAction::Append), + ); + + let matches = command.get_matches_from(cli.argv); + + if let Some(config_files) = matches.get_many::("config") { + for path in config_files { + config = config.add_source(config::File::new(path, FormatWrapper)); + added_files = true; + } + } + + if let Some(overrides) = matches.get_many::("overrides") { + for ov in overrides { + let (key, value) = ov.split_once('=').ok_or_else(|| { + clap::Error::raw( + clap::error::ErrorKind::InvalidValue, + "Override must be in the format KEY=VALUE", + ) + })?; + + config = config.set_override(key, value)?; + } + } + } + + if !added_files { + if let Some(default_config_file) = options.default_config_file { + config = config.add_source(config::File::new(default_config_file, FormatWrapper).required(false)); + } + } + + if let Some(env_prefix) = options.env_prefix { + config = config.add_source(config::Environment::with_prefix(env_prefix)); + } + + Ok(config.build()?.try_deserialize()?) 
+} diff --git a/crates/signal/Cargo.toml b/crates/signal/Cargo.toml new file mode 100644 index 000000000..cb8d572c2 --- /dev/null +++ b/crates/signal/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "scuffle-signal" +version = "0.1.0" +edition = "2021" + +[dependencies] +tokio = { version = "1.41.1", default-features = false, features = ["signal"] } + +[dev-dependencies] +tokio = { version = "1.41.1", features = ["macros", "rt", "time"] } +libc = "0.2" +futures = "0.3" + diff --git a/crates/signal/src/lib.rs b/crates/signal/src/lib.rs new file mode 100644 index 000000000..e80f21e17 --- /dev/null +++ b/crates/signal/src/lib.rs @@ -0,0 +1,125 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::signal::unix::{Signal, SignalKind}; + +/// A handler for listening to multiple Unix signals, and providing a future for +/// receiving them. +/// +/// This is useful for applications that need to listen for multiple signals, +/// and want to react to them in a non-blocking way. Typically you would need to +/// use a tokio::select{} to listen for multiple signals, but this provides a +/// more ergonomic interface for doing so. +/// +/// After a signal is received you can poll the handler again to wait for +/// another signal. Dropping the handle will cancel the signal subscription +#[derive(Debug)] +#[must_use = "signal handlers must be used to wait for signals"] +pub struct SignalHandler { + signals: Vec<(SignalKind, Signal)>, +} + +impl Default for SignalHandler { + fn default() -> Self { + Self::new() + } +} + +impl SignalHandler { + /// Create a new `SignalHandler` with no signals. + pub const fn new() -> Self { + Self { signals: Vec::new() } + } + + /// Add a signal to the handler. + /// + /// If the signal is already in the handler, it will not be added again. 
+ pub fn with_signal(mut self, kind: SignalKind) -> Self { + if self.signals.iter().any(|(k, _)| k == &kind) { + return self; + } + + let signal = tokio::signal::unix::signal(kind).expect("failed to create signal"); + + self.signals.push((kind, signal)); + + self + } + + /// Add a signal to the handler. + /// + /// If the signal is already in the handler, it will not be added again. + pub fn add_signal(&mut self, kind: SignalKind) -> &mut Self { + if self.signals.iter().any(|(k, _)| k == &kind) { + return self; + } + + let signal = tokio::signal::unix::signal(kind).expect("failed to create signal"); + + self.signals.push((kind, signal)); + + self + } + + /// Wait for a signal to be received. + /// This is equivalent to calling (&mut handler).await, but is more + /// ergonomic if you want to not take ownership of the handler. + pub async fn recv(&mut self) -> SignalKind { + self.await + } +} + +impl std::future::Future for SignalHandler { + type Output = SignalKind; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + for (kind, signal) in self.signals.iter_mut() { + if signal.poll_recv(cx).is_ready() { + return Poll::Ready(*kind); + } + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn raise_signal(kind: SignalKind) { + // Safety: This is a test, and we control the process.
+ unsafe { + libc::raise(kind.as_raw_value()); + } + } + + #[tokio::test] + async fn test_signal_handler() { + let mut handler = SignalHandler::new() + .with_signal(SignalKind::user_defined1()) + .with_signal(SignalKind::user_defined2()); + + raise_signal(SignalKind::user_defined1()); + + let recv = tokio::time::timeout(tokio::time::Duration::from_millis(5), &mut handler) + .await + .unwrap(); + + assert_eq!(recv, SignalKind::user_defined1(), "expected SIGUSR1"); + + // We already received the signal, so polling again should return Poll::Pending + let recv = tokio::time::timeout(tokio::time::Duration::from_millis(5), &mut handler).await; + + assert!(recv.is_err(), "expected timeout"); + + raise_signal(SignalKind::user_defined2()); + + // We should be able to receive the signal again + let recv = tokio::time::timeout(tokio::time::Duration::from_millis(5), &mut handler) + .await + .unwrap(); + + assert_eq!(recv, SignalKind::user_defined2(), "expected SIGUSR2"); + } +} diff --git a/examples/settings/Cargo.toml b/examples/settings/Cargo.toml new file mode 100644 index 000000000..ffe24d9a9 --- /dev/null +++ b/examples/settings/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "settings-examples" +version = "0.1.0" +edition = "2021" + +[[example]] +name = "settings-cli" +required-features = ["cli"] +path = "src/cli.rs" + +[dependencies] +scuffle-settings = { path = "../../crates/settings" } +serde = "1" +serde_derive = "1" +smart-default = "0.7.1" + +[features] + +cli = ["scuffle-settings/cli"] \ No newline at end of file diff --git a/examples/settings/src/cli.rs b/examples/settings/src/cli.rs new file mode 100644 index 000000000..091963972 --- /dev/null +++ b/examples/settings/src/cli.rs @@ -0,0 +1,16 @@ +#[derive(Debug, serde_derive::Deserialize, smart_default::SmartDefault)] +#[serde(default)] +struct Config { + #[default = "baz"] + foo: String, + bar: i32, + baz: bool, +} + +fn main() { + let config = scuffle_settings::parse_settings::( +
scuffle_settings::Options::builder().cli(scuffle_settings::cli!()).build(), + ); + + println!("{:#?}", config); +} diff --git a/rustfmt.toml b/rustfmt.toml index 21051fb38..29aed9c28 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -7,4 +7,4 @@ style_edition = "2021" format_macro_matchers = true hard_tabs = true reorder_impl_items = true -max_width = 125 \ No newline at end of file +max_width = 125