diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 0d5a76042..58e186d0b 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -61,7 +61,7 @@ jobs: if: github.event_name != 'pull_request' strategy: matrix: - project: [delta, bonfire, autumn, january] + project: [delta, bonfire, autumn, january, pushd] name: Build ${{ matrix.project }} image steps: # Configure build environment @@ -106,6 +106,10 @@ jobs: "january": { "path": "crates/services/january", "tag": "${{ github.repository_owner }}/january" + }, + "pushd": { + "path": "crates/daemons/pushd", + "tag": "${{ github.repository_owner }}/pushd" } } export_to: output diff --git a/Cargo.lock b/Cargo.lock index 6d80054b5..a3a8571b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,15 +23,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" -[[package]] -name = "aead" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b613b8e1e3cf911a086f53f03bf286f52fd7a7258e4fa606f0ef220d39d8877" -dependencies = [ - "generic-array 0.14.5", -] - [[package]] name = "aead" version = "0.5.2" @@ -42,18 +33,6 @@ dependencies = [ "generic-array 0.14.5", ] -[[package]] -name = "aes" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" -dependencies = [ - "cfg-if", - "cipher 0.3.0", - "cpufeatures", - "opaque-debug 0.3.0", -] - [[package]] name = "aes" version = "0.8.4" @@ -61,35 +40,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" dependencies = [ "cfg-if", - "cipher 0.4.4", + "cipher", "cpufeatures", ] -[[package]] -name = "aes-gcm" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc3be92e19a7ef47457b8e6f90707e12b6ac5d20c6f3866584fa3be0787d839f" -dependencies = [ - "aead 0.4.3", - "aes 0.7.5", - "cipher 0.3.0", - "ctr 0.7.0", - "ghash 0.4.4", - "subtle", -] - [[package]] name = "aes-gcm" version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" dependencies = [ - "aead 0.5.2", - "aes 0.8.4", - "cipher 0.4.4", - "ctr 0.9.2", - "ghash 0.5.1", + "aead", + "aes", + "cipher", + "ctr", + "ghash", "subtle", ] @@ -138,6 +103,31 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +[[package]] +name = "amqp_serde" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "787581044ca08ad61cdb3a4e21afba087fb8cdba3bb3e23ce69a7d091808014d" +dependencies = [ + "bytes 1.5.0", + "serde", + "serde_bytes_ng", +] + +[[package]] +name = "amqprs" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f1b4afcbd862e16c272b7625b6b057930b052d63c720bc90f6afab0d9abe8a8" +dependencies = [ + "amqp_serde", + "async-trait", + "bytes 1.5.0", + "serde", + "serde_bytes_ng", + "tokio 1.40.0", +] + [[package]] name = "android-tzdata" version = "0.1.1" @@ -264,7 +254,7 @@ dependencies = [ "futures-lite", "once_cell", "tokio 0.2.25", - "tokio 1.35.1", + "tokio 1.40.0", ] [[package]] @@ -450,16 +440,16 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" dependencies = [ - "hermit-abi", + "hermit-abi 0.1.19", "libc", "winapi", ] [[package]] name = "authifier" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30269caf0aaf1e1b542b150030e9688bf41d50026e09a51efd9408f332636c9d" +checksum = "9ba4df3b5df5cf1a08d4af71c407fb56a675b6aaf4d1fec704da32595497d73d" dependencies = [ "async-std", "async-trait", @@ -557,7 +547,7 @@ dependencies = [ "http 0.2.12", "ring 0.17.8", "time", - "tokio 1.35.1", + "tokio 1.40.0", "tracing", "url", "zeroize", @@ -740,7 +730,7 @@ checksum = "62220bc6e97f946ddd51b5f1361f78996e704677afc518a4ff66b7a72ea1378c" dependencies = [ "futures-util", "pin-project-lite 0.2.13", - "tokio 1.35.1", + "tokio 1.40.0", ] [[package]] @@ -838,7 +828,7 @@ dependencies = [ "pin-project-lite 0.2.13", "pin-utils", "rustls 0.21.12", - "tokio 1.35.1", + "tokio 1.40.0", "tracing", ] @@ -854,7 +844,7 @@ dependencies = [ "http 0.2.12", "http 1.1.0", "pin-project-lite 0.2.13", - "tokio 1.35.1", + "tokio 1.40.0", "tracing", "zeroize", ] @@ -881,8 +871,8 @@ dependencies = [ "ryu", "serde", "time", - "tokio 1.35.1", - "tokio-util 0.7.2", + "tokio 1.40.0", + "tokio-util", ] [[package]] @@ -927,7 +917,7 @@ dependencies = [ "matchit", "memchr", "mime", - "multer 3.1.0", + "multer", "percent-encoding", "pin-project-lite 0.2.13", "rustversion", @@ -936,7 +926,7 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sync_wrapper 1.0.1", - "tokio 1.35.1", + "tokio 1.40.0", "tower", "tower-layer", "tower-service", @@ -1014,7 +1004,7 @@ dependencies = [ "futures-util", "tempfile", "thiserror", - "tokio 1.35.1", + "tokio 1.40.0", "uuid 1.4.1", ] @@ -1296,7 +1286,7 @@ dependencies = [ "instant", "once_cell", "thiserror", - "tokio 1.35.1", + "tokio 1.40.0", ] [[package]] @@ -1376,15 +1366,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "cipher" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" -dependencies = [ - "generic-array 0.14.5", -] - [[package]] name = "cipher" version = "0.4.4" @@ -1432,8 +1413,8 @@ dependencies = [ "futures-core", "memchr", "pin-project-lite 0.2.13", - "tokio 1.35.1", - "tokio-util 0.7.2", + "tokio 1.40.0", + "tokio-util", ] [[package]] @@ -1493,18 +1474,11 @@ checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" [[package]] name = "cookie" -version = "0.16.0" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94d4706de1b0fa5b132270cddffa8585166037822e260a944fe161acd137ca05" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" dependencies = [ - "aes-gcm 0.9.2", - "base64 0.13.0", - "hkdf", - "hmac", "percent-encoding", - "rand 0.8.5", - "sha2", - "subtle", "time", "version_check", ] @@ -1685,22 +1659,13 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3b7eb4404b8195a9abb6356f4ac07d8ba267045c8d6d220ac4dc992e6cc75df" -[[package]] -name = "ctr" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a232f92a03f37dd7d7dd2adc67166c77e9cd88de5b019b9a9eecfaeaf7bfd481" -dependencies = [ - "cipher 0.3.0", -] - [[package]] name = "ctr" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" dependencies = [ - "cipher 0.4.4", + "cipher", ] [[package]] @@ -1870,7 +1835,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16a2561fd313df162315935989dceb8c99db4ee1933358270a57a3cfb8c957f3" dependencies = [ "crossbeam-queue", - "tokio 1.35.1", + "tokio 1.40.0", ] [[package]] @@ -1946,9 +1911,9 @@ dependencies = [ [[package]] name = "devise" -version = "0.3.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50c7580b072f1c8476148f16e0a0d5dedddab787da98d86c5082c5e9ed8ab595" +checksum = "f1d90b0c4c777a2cad215e3c7be59ac7c15adf45cf76317009b7d096d46f651d" dependencies = [ "devise_codegen", "devise_core", @@ -1956,9 +1921,9 @@ dependencies = [ [[package]] name = "devise_codegen" -version = "0.3.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "123c73e7a6e51b05c75fe1a1b2f4e241399ea5740ed810b0e3e6cacd9db5e7b2" +checksum = "71b28680d8be17a570a2334922518be6adc3f58ecc880cbb404eaeb8624fd867" dependencies = [ "devise_core", "quote 1.0.37", @@ -1966,15 +1931,15 @@ dependencies = [ [[package]] name = "devise_core" -version = "0.3.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841ef46f4787d9097405cac4e70fb8644fc037b526e8c14054247c0263c400d0" +checksum = "b035a542cf7abf01f2e3c4d5a7acbaebfefe120ae4efc7bde3df98186e4b8af7" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.6.0", "proc-macro2", "proc-macro2-diagnostics", "quote 1.0.37", - "syn 1.0.107", + "syn 2.0.76", ] [[package]] @@ -2420,9 +2385,9 @@ dependencies = [ "redis-protocol", "semver 1.0.23", "socket2 0.5.5", - "tokio 1.35.1", + "tokio 1.40.0", "tokio-stream", - "tokio-util 0.7.2", + "tokio-util", "url", "urlencoding", ] @@ -2514,7 +2479,7 @@ checksum = "45ec6fe3675af967e67c5536c0b9d44e34e6c52f86bedc4ea49c5317b8e94d06" dependencies = [ "futures-channel", "futures-task", - "tokio 1.35.1", + "tokio 1.40.0", ] [[package]] @@ -2639,16 +2604,6 @@ dependencies = [ "syn 1.0.107", ] -[[package]] -name = "ghash" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1583cc1656d7839fd3732b80cf4f38850336cdb9b8ded1cd399ca62958de3c99" -dependencies = [ - "opaque-debug 0.3.0", - "polyval 0.5.3", -] - [[package]] name = "ghash" version = "0.5.1" @@ -2656,7 +2611,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" dependencies = [ "opaque-debug 0.3.0", - "polyval 0.6.2", + "polyval", ] [[package]] @@ -2731,8 +2686,8 @@ dependencies = [ "http 0.2.12", "indexmap 2.0.1", "slab", - "tokio 1.35.1", - "tokio-util 0.7.2", + "tokio 1.40.0", + "tokio-util", "tracing", ] @@ -2750,8 +2705,8 @@ dependencies = [ "http 1.1.0", "indexmap 2.0.1", "slab", - "tokio 1.35.1", - "tokio-util 0.7.2", + "tokio 1.40.0", + "tokio-util", "tracing", ] @@ -2858,6 +2813,18 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "hermit-abi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" + [[package]] name = "hex" version = "0.4.3" @@ -3026,7 +2993,7 @@ dependencies = [ "itoa", "pin-project-lite 0.2.13", "socket2 0.5.5", - "tokio 1.35.1", + "tokio 1.40.0", "tower-service", "tracing", "want", @@ -3049,7 +3016,7 @@ dependencies = [ "itoa", "pin-project-lite 0.2.13", "smallvec", - "tokio 1.35.1", + "tokio 1.40.0", "want", ] @@ -3065,7 +3032,7 @@ dependencies = [ "log", "rustls 0.21.12", "rustls-native-certs", - "tokio 1.35.1", + "tokio 1.40.0", "tokio-rustls 0.24.1", ] @@ -3081,7 +3048,7 @@ dependencies = [ "hyper-util", "rustls 0.22.4", "rustls-pki-types", - "tokio 1.35.1", + "tokio 1.40.0", "tokio-rustls 0.25.0", "tower-service", "webpki-roots 0.26.3", @@ -3096,7 +3063,7 @@ dependencies = [ "bytes 1.5.0", "hyper 0.14.30", "native-tls", - "tokio 1.35.1", + "tokio 1.40.0", "tokio-native-tls", ] @@ -3111,7 +3078,7 @@ dependencies = [ "hyper 1.3.1", "hyper-util", "native-tls", - "tokio 1.35.1", + "tokio 1.40.0", "tokio-native-tls", "tower-service", ] @@ -3130,7 +3097,7 @@ dependencies = [ "hyper 1.3.1", "pin-project-lite 0.2.13", "socket2 0.5.5", - "tokio 1.35.1", + "tokio 1.40.0", "tower", "tower-service", "tracing", @@ -3327,6 +3294,17 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b" +[[package]] +name = "is-terminal" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +dependencies = [ + "hermit-abi 0.4.0", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "isahc" version = "1.7.2" @@ -4014,13 +3992,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.10" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ + "hermit-abi 0.3.9", "libc", "wasi", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -4038,7 +4017,7 @@ dependencies = [ "log", "metrics", "thiserror", - "tokio 1.35.1", + "tokio 1.40.0", "tracing", "tracing-subscriber", ] @@ -4114,9 +4093,9 @@ dependencies = [ "strsim 0.10.0", "take_mut", "thiserror", - "tokio 1.35.1", + "tokio 1.40.0", "tokio-rustls 0.23.4", - "tokio-util 0.7.2", + "tokio-util", "trust-dns-proto", "trust-dns-resolver", "typed-builder", @@ -4124,26 +4103,6 @@ dependencies = [ "webpki-roots 0.22.3", ] -[[package]] -name = "multer" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f8f35e687561d5c1667590911e6698a8cb714a134a7505718a182e7bc9d3836" -dependencies = [ - "bytes 1.5.0", - "encoding_rs", - "futures-util", - "http 0.2.12", - "httparse", - "log", - "memchr", - "mime", - "spin 0.9.8", - "tokio 1.35.1", - "tokio-util 0.6.10", - "version_check", -] - [[package]] name = "multer" version = "3.1.0" @@ -4158,6 +4117,8 @@ dependencies = [ "memchr", "mime", "spin 0.9.8", + "tokio 1.40.0", + "tokio-util", "version_check", ] @@ -4369,7 +4330,7 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" dependencies = [ - "hermit-abi", + "hermit-abi 0.1.19", "libc", ] @@ -4597,9 +4558,9 @@ dependencies = [ [[package]] name = "pear" -version = "0.2.3" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e44241c5e4c868e3eaa78b7c1848cadd6344ed4f54d029832d32b415a58702" +checksum = "bdeeaa00ce488657faba8ebf44ab9361f9365a97bd39ffb8a60663f57ff4b467" dependencies = [ "inlinable_string", "pear_codegen", @@ -4608,14 +4569,14 @@ dependencies = [ [[package]] name = "pear_codegen" -version = "0.2.3" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82a5ca643c2303ecb740d506539deba189e16f2754040a42901cd8105d0282d0" +checksum = "4bab5b985dc082b345f812b7df84e1bef27e7207b39e448439ba8bd69c93f147" dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote 1.0.37", - "syn 1.0.107", + "syn 2.0.76", ] [[package]] @@ -4894,18 +4855,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "polyval" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8419d2b623c7c0896ff2d5d96e2cb4ede590fed28fcc34934f4c33c036e620a1" -dependencies = [ - "cfg-if", - "cpufeatures", - "opaque-debug 0.3.0", - "universal-hash 0.4.0", -] - [[package]] name = "polyval" version = "0.6.2" @@ -4915,7 +4864,7 @@ dependencies = [ "cfg-if", "cpufeatures", "opaque-debug 0.3.0", - "universal-hash 0.5.1", + "universal-hash", ] [[package]] @@ -4985,13 +4934,13 @@ dependencies = [ [[package]] name = "proc-macro2-diagnostics" -version = "0.9.1" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bf29726d67464d49fa6224a1d07936a8c08bb3fba727c7493f6cf1616fdaada" +checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote 1.0.37", - "syn 1.0.107", + "syn 2.0.76", "version_check", "yansi", ] @@ -5333,8 +5282,8 @@ dependencies = [ "pin-project-lite 0.2.13", "ryu", "sha1_smol", - "tokio 1.35.1", - "tokio-util 0.7.2", + "tokio 1.40.0", + "tokio-util", "url", ] @@ -5353,8 +5302,8 @@ dependencies = [ "percent-encoding", "pin-project-lite 0.2.13", "ryu", - "tokio 1.35.1", - "tokio-util 0.7.2", + "tokio 1.40.0", + "tokio-util", "url", ] @@ -5494,7 +5443,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "tokio 1.35.1", + "tokio 1.40.0", "tokio-native-tls", "url", "wasm-bindgen", @@ -5535,7 +5484,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 0.1.2", "system-configuration", - "tokio 1.35.1", + "tokio 1.40.0", "tokio-native-tls", "tower-service", "url", @@ -5599,7 +5548,7 @@ dependencies = [ "simdutf8", "strum_macros", "tempfile", - "tokio 1.35.1", + "tokio 1.40.0", "tower-http", "tracing", "tracing-subscriber", @@ -5659,6 +5608,7 @@ dependencies = [ name = "revolt-database" version = "0.7.19" dependencies = [ + "amqprs", "async-lock 2.8.0", "async-recursion", "async-std", @@ -5707,6 +5657,7 @@ dependencies = [ name = "revolt-delta" version = "0.7.19" dependencies = [ + "amqprs", "async-channel 1.6.1", "async-std", "authifier", @@ -5753,7 +5704,7 @@ dependencies = [ name = "revolt-files" version = "0.7.19" dependencies = [ - "aes-gcm 0.10.3", + "aes-gcm", "aws-config", "aws-sdk-s3", "base64 0.22.1", @@ -5793,7 +5744,7 @@ dependencies = [ "serde", "serde_json", "tempfile", - "tokio 1.35.1", + "tokio 1.40.0", "tracing", "tracing-subscriber", "utoipa", @@ -5857,6 +5808,30 @@ dependencies = [ "redis-kiss", ] +[[package]] +name = "revolt-pushd" +version = "0.1.0" +dependencies = [ + "amqprs", + "async-trait", + "authifier", + "base64 0.22.1", + "fcm_v1", + "isahc", + "iso8601-timestamp 0.2.11", + "log", + "revolt-config", + "revolt-database", + "revolt-models", + "revolt_a2", + "revolt_optional_struct", + "serde", + "serde_json", + "tokio 1.40.0", + "ulid 1.1.3", + "web-push", +] + [[package]] name = "revolt-result" version = "0.7.19" @@ -5892,7 +5867,7 @@ dependencies = [ "serde", "serde_json", "thiserror", - "tokio 1.35.1", + "tokio 1.40.0", ] [[package]] @@ -5925,9 +5900,9 @@ dependencies = [ [[package]] name = "revolt_rocket_okapi" -version = "0.9.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "275e1e9bd3343f75225cafa64f4bfb939c8b21c5f861141180fc0e24769ff6cf" +checksum = "cb113b281380c12c185c8d98c4887627ae6f7add16a510073382518ce34e42db" dependencies = [ "either", "log", @@ -6026,23 +6001,22 @@ dependencies = [ [[package]] name = "rocket" -version = "0.5.0-rc.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98ead083fce4a405feb349cf09abdf64471c6077f14e0ce59364aa90d4b99317" +checksum = "a516907296a31df7dc04310e7043b61d71954d703b603cc6867a026d7e72d73f" dependencies = [ "async-stream", "async-trait", "atomic", - "atty", "binascii", "bytes 1.5.0", "either", "figment", "futures", - "indexmap 1.9.3", + "indexmap 2.0.1", "log", "memchr", - "multer 2.0.2", + "multer", "num_cpus", "parking_lot", "pin-project-lite 0.2.13", @@ -6055,9 +6029,9 @@ dependencies = [ "state", "tempfile", "time", - "tokio 1.35.1", + "tokio 1.40.0", "tokio-stream", - "tokio-util 0.7.2", + "tokio-util", "ubyte", "version_check", "yansi", @@ -6065,9 +6039,9 @@ dependencies = [ [[package]] name = "rocket_authifier" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f89a12311f60e9288833fc3ce6029bce5d5c61870ceef74d4a50668a8b520ad" +checksum = "810753b79106c44a4e76247fc7576b660663133a9e8f4b0afeb303589ec51d59" dependencies = [ "authifier", "iso8601-timestamp 0.1.10", @@ -6081,24 +6055,25 @@ dependencies = [ [[package]] name = "rocket_codegen" -version = "0.5.0-rc.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6aeb6bb9c61e9cd2c00d70ea267bf36f76a4cc615e5908b349c2f9d93999b47" +checksum = "575d32d7ec1a9770108c879fc7c47815a80073f96ca07ff9525a94fcede1dd46" dependencies = [ "devise", "glob", - "indexmap 1.9.3", + "indexmap 2.0.1", "proc-macro2", "quote 1.0.37", "rocket_http", - "syn 1.0.107", + "syn 2.0.76", "unicode-xid 0.2.3", + "version_check", ] [[package]] name = "rocket_cors" -version = "0.6.0-alpha1" -source = "git+https://github.com/lawliet89/rocket_cors?rev=c17e8145baa4790319fdb6a473e465b960f55e7c#c17e8145baa4790319fdb6a473e465b960f55e7c" +version = "0.6.0" +source = "git+https://github.com/lawliet89/rocket_cors?rev=072d90359b23e9b291df6b672c07c93de9c46011#072d90359b23e9b291df6b672c07c93de9c46011" dependencies = [ "http 0.2.12", "log", @@ -6113,9 +6088,9 @@ dependencies = [ [[package]] name = "rocket_empty" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c0922e47f981204fee38578a8efcf47a5c1a4ac0eb7f59e6bdfa3e61c8e3d69" +checksum = "97a55000e1ef5f4a9b20ae3d9de2a0bd22620c78ebd1aa568776ae12276125a6" dependencies = [ "revolt_okapi", "revolt_rocket_okapi", @@ -6124,16 +6099,16 @@ dependencies = [ [[package]] name = "rocket_http" -version = "0.5.0-rc.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ded65d127954de3c12471630bf4b81a2792f065984461e65b91d0fdaafc17a2" +checksum = "e274915a20ee3065f611c044bd63c40757396b6dbc057d6046aec27f14f882b9" dependencies = [ "cookie", "either", "futures", "http 0.2.12", "hyper 0.14.30", - "indexmap 1.9.3", + "indexmap 2.0.1", "log", "memchr", "pear", @@ -6145,7 +6120,7 @@ dependencies = [ "stable-pattern", "state", "time", - "tokio 1.35.1", + "tokio 1.40.0", "uncased", ] @@ -6589,7 +6564,7 @@ dependencies = [ "sentry-debug-images", "sentry-panic", "sentry-tracing", - "tokio 1.35.1", + "tokio 1.40.0", "ureq", ] @@ -6700,6 +6675,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_bytes_ng" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb0ebce8684e2253f964e8b6ce51f0ccc6666bbb448fb4a6788088bda6544b6" +dependencies = [ + "serde", +] + [[package]] name = "serde_derive" version = "1.0.209" @@ -7027,9 +7011,9 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "state" -version = "0.5.3" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbe866e1e51e8260c9eed836a042a5e7f6726bb2b411dffeaa712e19c388f23b" +checksum = "2b8c4a4445d81357df8b1a650d0d0d6fbbbfe99d064aa5e02f3e4022061476d8" dependencies = [ "loom", ] @@ -7419,28 +7403,27 @@ dependencies = [ [[package]] name = "tokio" -version = "1.35.1" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" dependencies = [ "backtrace", "bytes 1.5.0", "libc", "mio", - "num_cpus", "parking_lot", "pin-project-lite 0.2.13", "signal-hook-registry", "socket2 0.5.5", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote 1.0.37", @@ -7454,7 +7437,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7d995660bd2b7f8c1568414c1126076c13fbb725c40112dc0120b78eb9b717b" dependencies = [ "native-tls", - "tokio 1.35.1", + "tokio 1.40.0", ] [[package]] @@ -7464,7 +7447,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" dependencies = [ "rustls 0.20.6", - "tokio 1.35.1", + "tokio 1.40.0", "webpki", ] @@ -7475,7 +7458,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ "rustls 0.21.12", - "tokio 1.35.1", + "tokio 1.40.0", ] [[package]] @@ -7486,7 +7469,7 @@ checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" dependencies = [ "rustls 0.22.4", "rustls-pki-types", - "tokio 1.35.1", + "tokio 1.40.0", ] [[package]] @@ -7497,21 +7480,7 @@ checksum = "50145484efff8818b5ccd256697f36863f587da82cf8b409c53adf1e840798e3" dependencies = [ "futures-core", "pin-project-lite 0.2.13", - "tokio 1.35.1", -] - -[[package]] -name = "tokio-util" -version = "0.6.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36943ee01a6d67977dd3f84a5a1d2efeb4ada3a1ae771cadfaa535d9d9fc6507" -dependencies = [ - "bytes 1.5.0", - "futures-core", - "futures-sink", - "log", - "pin-project-lite 0.2.13", - "tokio 1.35.1", + "tokio 1.40.0", ] [[package]] @@ -7525,7 +7494,7 @@ dependencies = [ "futures-io", "futures-sink", "pin-project-lite 0.2.13", - "tokio 1.35.1", + "tokio 1.40.0", "tracing", ] @@ -7594,7 +7563,7 @@ dependencies = [ "futures-util", "pin-project", "pin-project-lite 0.2.13", - "tokio 1.35.1", + "tokio 1.40.0", "tower-layer", "tower-service", "tracing", @@ -7727,7 +7696,7 @@ dependencies = [ "smallvec", "thiserror", "tinyvec", - "tokio 1.35.1", + "tokio 1.40.0", "url", ] @@ -7747,7 +7716,7 @@ dependencies = [ "resolv-conf", "smallvec", "thiserror", - "tokio 1.35.1", + "tokio 1.40.0", "trust-dns-proto", ] @@ -7962,16 +7931,6 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "957e51f3646910546462e67d5f7599b9e4fb8acdd304b087a6494730f9eebf04" -[[package]] -name = "universal-hash" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8326b2c654932e3e4f9196e69d08fdf7cfd718e1dc6f66b347e6024a0c961402" -dependencies = [ - "generic-array 0.14.5", - "subtle", -] - [[package]] name = "universal-hash" version = "0.5.1" @@ -8728,9 +8687,12 @@ dependencies = [ [[package]] name = "yansi" -version = "0.5.1" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +dependencies = [ + "is-terminal", +] [[package]] name = "yup-oauth2" @@ -8754,7 +8716,7 @@ dependencies = [ "serde", "serde_json", "time", - "tokio 1.35.1", + "tokio 1.40.0", "tower-service", "url", ] diff --git a/Cargo.toml b/Cargo.toml index 61f6fa514..2779c26bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,16 +1,15 @@ [workspace] resolver = "2" + members = [ "crates/delta", "crates/bonfire", "crates/core/*", "crates/services/*", "crates/bindings/*", + "crates/daemons/pushd", ] [patch.crates-io] -# mobc-redis = { git = "https://github.com/insertish/mobc", rev = "8b880bb59f2ba80b4c7bc40c649c113d8857a186" } redis22 = { package = "redis", version = "0.22.3", git = "https://github.com/revoltchat/redis-rs", rev = "1a41faf356fd21aebba71cea7eb7eb2653e5f0ef" } redis23 = { package = "redis", version = "0.23.1", git = "https://github.com/revoltchat/redis-rs", rev = "f8ca28ab85da59d2ccde526b4d2fb390eff5a5f9" } -# authifier = { package = "authifier", version = "1.0.8", path = "../authifier/crates/authifier" } -# rocket_authifier = { package = "rocket_authifier", version = "1.0.8", path = "../authifier/crates/rocket_authifier" } diff --git a/Dockerfile b/Dockerfile index 57ef759ed..10112b667 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,6 +29,7 @@ COPY crates/core/presence/Cargo.toml ./crates/core/presence/ COPY crates/core/result/Cargo.toml ./crates/core/result/ COPY crates/services/autumn/Cargo.toml ./crates/services/autumn/ COPY crates/services/january/Cargo.toml ./crates/services/january/ +COPY crates/daemons/pushd/Cargo.toml ./crates/daemons/pushd/ RUN sh /tmp/build-image-layer.sh deps # Build all apps diff --git a/README.md b/README.md index 490aa9059..ea8cab84b 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ The services and libraries that power the Revolt service.
| `services/january` | [crates/services/january](crates/services/january) | Proxy server | ![License](https://img.shields.io/badge/license-AGPL--3.0--or--later-blue) | | `services/autumn` | [crates/services/autumn](crates/services/autumn) | File server | ![License](https://img.shields.io/badge/license-AGPL--3.0--or--later-blue) | | `bindings/node` | [crates/bindings/node](crates/bindings/node) | Node.js bindings for the Revolt software | ![License](https://img.shields.io/badge/license-AGPL--3.0--or--later-blue) | +| `daemons/pushd` | [crates/daemons/pushd](crates/daemons/pushd) | Push notification daemon server | ![License](https://img.shields.io/badge/license-AGPL--3.0--or--later-blue) |
@@ -55,11 +56,12 @@ As a heads-up, the development environment uses the following ports: | Service | Port | | ------------------------- | :------------: | -| MongoDB | 14017 | -| Redis | 14079 | +| MongoDB | 27017 | +| Redis | 6379 | | MinIO | 14009 | | Maildev | 14025
14080 | | Revolt Web App | 14701 | +| RabbitMQ | 5672
15672 | | `crates/delta` | 14702 | | `crates/bonfire` | 14703 | | `crates/services/autumn` | 14704 | @@ -106,6 +108,8 @@ cargo run --bin revolt-bonfire cargo run --bin revolt-autumn # run the proxy server cargo run --bin revolt-january +# run the push daemon (not usually needed in regular development) +cargo run --bin revolt-pushd # hint: # mold -run diff --git a/compose.yml b/compose.yml index 53d531ce4..31af45bad 100644 --- a/compose.yml +++ b/compose.yml @@ -3,13 +3,13 @@ services: redis: image: eqalpha/keydb ports: - - "14079:6379" + - "6379:6379" # MongoDB database: image: mongo ports: - - "14017:27017" + - "27017:27017" volumes: - ./.data/db:/data/db @@ -33,14 +33,25 @@ services: depends_on: - minio entrypoint: > - /bin/sh -c " - while ! /usr/bin/mc ready minio; do + /bin/sh -c "while ! /usr/bin/mc ready minio; do /usr/bin/mc config host add minio http://minio:9000 minioautumn minioautumn; echo 'Waiting minio...' && sleep 1; - done; - /usr/bin/mc mb minio/revolt-uploads; - exit 0; - " + done; /usr/bin/mc mb minio/revolt-uploads; exit 0;" + + # Rabbit + rabbit: + image: rabbitmq:3-management + environment: + RABBITMQ_DEFAULT_USER: rabbituser + RABBITMQ_DEFAULT_PASS: rabbitpass + volumes: + - ./.data/rabbit:/var/lib/rabbitmq + #- ./rabbit_plugins:/opt/rabbitmq/plugins/ + #- ./rabbit_enabled_plugins:/etc/rabbitmq/enabled_plugins + # uncomment this if you need to enable other plugins + ports: + - "5672:5672" + - "15672:15672" # management UI, for development # Mock SMTP server maildev: diff --git a/crates/bindings/node/src/lib.rs b/crates/bindings/node/src/lib.rs index a2030d73c..85b391834 100644 --- a/crates/bindings/node/src/lib.rs +++ b/crates/bindings/node/src/lib.rs @@ -7,25 +7,25 @@ use neon::prelude::*; use revolt_database::{Database, DatabaseInfo}; fn js_init(mut cx: FunctionContext) -> JsResult { - static INIT: OnceLock<()> = OnceLock::new(); - if INIT.get().is_none() { - INIT.get_or_init(|| { - async_std::task::block_on(async { - revolt_config::configure!(api); - - match DatabaseInfo::Auto.connect().await { - Ok(db) => { - let authifier_db = db.clone().to_authifier().await.database; - revolt_database::tasks::start_workers(db, authifier_db); - Ok(()) - } - Err(err) => Err(err), - } - }) - .or_else(|err| cx.throw_error(err)) - .unwrap(); - }); - } + // static INIT: OnceLock<()> = OnceLock::new(); + // if INIT.get().is_none() { + // INIT.get_or_init(|| { + // async_std::task::block_on(async { + // revolt_config::configure!(api); + + // match DatabaseInfo::Auto.connect().await { + // Ok(db) => { + // let authifier_db = db.clone().to_authifier().await.database; + // revolt_database::tasks::start_workers(db, authifier_db); + // Ok(()) + // } + // Err(err) => Err(err), + // } + // }) + // .or_else(|err| cx.throw_error(err)) + // .unwrap(); + // }); + // } Ok(cx.undefined()) } diff --git a/crates/bonfire/Cargo.toml b/crates/bonfire/Cargo.toml index 151d1ded7..c91723b32 100644 --- a/crates/bonfire/Cargo.toml +++ b/crates/bonfire/Cargo.toml @@ -36,7 +36,7 @@ async-std = { version = "1.8.0", features = [ ] } # core -authifier = { version = "1.0.8" } +authifier = { version = "1.0.9" } revolt-result = { path = "../core/result" } revolt-models = { path = "../core/models" } revolt-config = { path = "../core/config" } diff --git a/crates/core/config/Revolt.test.toml b/crates/core/config/Revolt.test.toml index 084a46ee0..7ac98ee26 100644 --- a/crates/core/config/Revolt.test.toml +++ b/crates/core/config/Revolt.test.toml @@ -1,3 +1,9 @@ [database] mongodb = "mongodb://localhost" redis = "redis://localhost/" + +[rabbit] +host = "127.0.0.1" +port = 5672 +username = "rabbituser" +password = "rabbitpass" diff --git a/crates/core/config/Revolt.toml b/crates/core/config/Revolt.toml index 359c0e998..e18165972 100644 --- a/crates/core/config/Revolt.toml +++ b/crates/core/config/Revolt.toml @@ -20,6 +20,12 @@ january = "http://local.revolt.chat/january" voso_legacy = "" voso_legacy_ws = "" +[rabbit] +host = "127.0.0.1" +port = 5672 +username = "guest" +password = "guest" + [api] [api.registration] @@ -38,17 +44,45 @@ from_address = "noreply@example.com" # port = 587 # use_tls = true -[api.vapid] -# Generate your own keys: -# 1. Run `openssl ecparam -name prime256v1 -genkey -noout -out vapid_private.pem` -# 2. Find `private_key` using `base64 vapid_private.pem` -# 3. Find `public_key` using `openssl ec -in vapid_private.pem -outform DER|tail -c 65|base64|tr '/+' '_-'|tr -d '\n'` + +[api.security] +# Authifier Shield API key +authifier_shield_key = "" +# Legacy voice server management token +voso_legacy_token = "" +# Whether services are behind the Cloudflare network +trust_cloudflare = false + +[api.security.captcha] +# hCaptcha configuration +hcaptcha_key = "" +hcaptcha_sitekey = "" + +[api.workers] +# Maximum concurrent connections (to proxy server) +max_concurrent_connections = 50 + +[pushd] +# this changes the names of the queues to not overlap +# prod/beta if they happen to be on the same exchange/instance. +# Usually they have to be, so that messages sent from one or the other get sent to everyone +production = true + +# none of these should need changing +exchange = "revolt.notifications" +message_queue = "notifications.origin.message" +fr_accepted_queue = "notifications.ingest.fr_accepted" # friend request accepted +fr_received_queue = "notifications.ingest.fr_received" # friend request received +generic_queue = "notifications.ingest.generic" # generic messages (title + body) +ack_queue = "notifications.process.ack" # updates badges for apple devices + +[pushd.vapid] +queue = "notifications.outbound.vapid" private_key = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUJSUWpyTWxLRnBiVWhsUHpUbERvcEliYk1yeVNrNXpKYzVYVzIxSjJDS3hvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFWnkrQkg2TGJQZ2hEa3pEempXOG0rUXVPM3pCajRXT1phdkR6ZU00c0pqbmFwd1psTFE0WAp1ZDh2TzVodU94QWhMQlU3WWRldVovWHlBdFpWZmNyQi9BPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo" public_key = "BGcvgR-i2z4IQ5Mw841vJvkLjt8wY-FjmWrw83jOLCY52qcGZS0OF7nfLzuYbjsQISwVO2HXrmf18gLWVX3Kwfw=" -[api.fcm] -# Google Firebase Cloud Messaging Service Account Key -# Obtained from the cloud messaging console +[pushd.fcm] +queue = "notifications.outbound.fcm" key_type = "" project_id = "" private_key_id = "" @@ -60,29 +94,13 @@ token_uri = "" auth_provider_x509_cert_url = "" client_x509_cert_url = "" -[api.apn] -# Apple Push Notifications keys for sending notifications +[pushd.apn] sandbox = false +queue = "notifications.outbound.apn" pkcs8 = "" key_id = "" team_id = "" -[api.security] -# Authifier Shield API key -authifier_shield_key = "" -# Legacy voice server management token -voso_legacy_token = "" -# Whether services are behind the Cloudflare network -trust_cloudflare = false - -[api.security.captcha] -# hCaptcha configuration -hcaptcha_key = "" -hcaptcha_sitekey = "" - -[api.workers] -# Maximum concurrent connections (to proxy server) -max_concurrent_connections = 50 [files] # Encryption key for stored files @@ -149,10 +167,11 @@ region = "minio" access_key_id = "minioautumn" # S3 protocol access key secret_access_key = "minioautumn" -# Bucket to upload to by default default_bucket = "revolt-uploads" + [features] +# Bucket to upload to by default # Feature gate options webhooks_enabled = false @@ -228,6 +247,11 @@ icons = 2_500_000 banners = 6_000_000 emojis = 500_000 +[features.advanced] +# The max amount of messages the rabbitmq provider/db mention adder job will delay for before forcing handling of a channel. +# default: 5 +process_message_delay_limit = 5 + [sentry] # Configuration for Sentry error reporting api = "" diff --git a/crates/core/config/src/lib.rs b/crates/core/config/src/lib.rs index 8913dfb78..b103a0718 100644 --- a/crates/core/config/src/lib.rs +++ b/crates/core/config/src/lib.rs @@ -79,6 +79,14 @@ pub struct Database { pub redis: String, } +#[derive(Deserialize, Debug, Clone)] +pub struct Rabbit { + pub host: String, + pub port: u16, + pub username: String, + pub password: String, +} + #[derive(Deserialize, Debug, Clone)] pub struct Hosts { pub app: String, @@ -107,13 +115,15 @@ pub struct ApiSmtp { } #[derive(Deserialize, Debug, Clone)] -pub struct ApiVapid { +pub struct PushVapid { + pub queue: String, pub private_key: String, pub public_key: String, } #[derive(Deserialize, Debug, Clone)] -pub struct ApiFcm { +pub struct PushFcm { + pub queue: String, pub key_type: String, pub project_id: String, pub private_key_id: String, @@ -127,7 +137,8 @@ pub struct ApiFcm { } #[derive(Deserialize, Debug, Clone)] -pub struct ApiApn { +pub struct PushApn { + pub queue: String, pub sandbox: bool, pub pkcs8: String, pub key_id: String, @@ -157,13 +168,54 @@ pub struct ApiWorkers { pub struct Api { pub registration: ApiRegistration, pub smtp: ApiSmtp, - pub vapid: ApiVapid, - pub fcm: ApiFcm, - pub apn: ApiApn, pub security: ApiSecurity, pub workers: ApiWorkers, } +#[derive(Deserialize, Debug, Clone)] +pub struct Pushd { + pub production: bool, + pub exchange: String, + pub message_queue: String, + pub fr_accepted_queue: String, + pub fr_received_queue: String, + pub generic_queue: String, + pub ack_queue: String, + + pub vapid: PushVapid, + pub fcm: PushFcm, + pub apn: PushApn, +} + +impl Pushd { + fn get_routing_key(&self, key: String) -> String { + match self.production { + true => key + "-prd", + false => key + "-tst", + } + } + + pub fn get_ack_routing_key(&self) -> String { + self.get_routing_key(self.ack_queue.clone()) + } + + pub fn get_message_routing_key(&self) -> String { + self.get_routing_key(self.message_queue.clone()) + } + + pub fn get_fr_accepted_routing_key(&self) -> String { + self.get_routing_key(self.fr_accepted_queue.clone()) + } + + pub fn get_fr_received_routing_key(&self) -> String { + self.get_routing_key(self.fr_received_queue.clone()) + } + + pub fn get_generic_routing_key(&self) -> String { + self.get_routing_key(self.generic_queue.clone()) + } +} + #[derive(Deserialize, Debug, Clone)] pub struct FilesLimit { pub min_file_size: usize, @@ -233,10 +285,26 @@ pub struct FeaturesLimitsCollection { pub roles: HashMap, } +#[derive(Deserialize, Debug, Clone)] +pub struct FeaturesAdvanced { + #[serde(default)] + pub process_message_delay_limit: u16, +} + +impl Default for FeaturesAdvanced { + fn default() -> Self { + Self { + process_message_delay_limit: 5, + } + } +} + #[derive(Deserialize, Debug, Clone)] pub struct Features { pub limits: FeaturesLimitsCollection, pub webhooks_enabled: bool, + #[serde(default)] + pub advanced: FeaturesAdvanced, } #[derive(Deserialize, Debug, Clone)] @@ -250,8 +318,10 @@ pub struct Sentry { #[derive(Deserialize, Debug, Clone)] pub struct Settings { pub database: Database, + pub rabbit: Rabbit, pub hosts: Hosts, pub api: Api, + pub pushd: Pushd, pub files: Files, pub features: Features, pub sentry: Sentry, diff --git a/crates/core/database/Cargo.toml b/crates/core/database/Cargo.toml index d1b280685..2be799feb 100644 --- a/crates/core/database/Cargo.toml +++ b/crates/core/database/Cargo.toml @@ -84,11 +84,11 @@ axum = { version = "0.7.5", optional = true } # Rocket Impl schemars = { version = "0.8.8", optional = true } -rocket = { version = "0.5.0-rc.2", default-features = false, features = [ +rocket = { version = "0.5.1", default-features = false, features = [ "json", ], optional = true } revolt_okapi = { version = "0.9.1", optional = true } -revolt_rocket_okapi = { version = "0.9.1", optional = true } +revolt_rocket_okapi = { version = "0.10.0", optional = true } # Notifications fcm_v1 = "0.3.0" @@ -96,4 +96,7 @@ web-push = "0.10.0" revolt_a2 = { version = "0.10", default-features = false, features = ["ring"] } # Authifier -authifier = { version = "1.0.8" } +authifier = { version = "1.0.9", features = ["rocket_impl"] } + +# RabbitMQ +amqprs = { version = "1.7.0" } diff --git a/crates/core/database/src/amqp/amqp.rs b/crates/core/database/src/amqp/amqp.rs new file mode 100644 index 000000000..39b4820eb --- /dev/null +++ b/crates/core/database/src/amqp/amqp.rs @@ -0,0 +1,211 @@ +use std::collections::HashSet; + +use crate::events::rabbit::*; +use crate::User; +use amqprs::channel::BasicPublishArguments; +use amqprs::{channel::Channel, connection::Connection, error::Error as AMQPError}; +use amqprs::{BasicProperties, FieldTable}; +use revolt_models::v0::PushNotification; +use revolt_presence::filter_online; + +use serde_json::to_string; + +#[derive(Clone)] +pub struct AMQP { + #[allow(unused)] + connection: Connection, + channel: Channel, +} + +impl AMQP { + pub fn new(connection: Connection, channel: Channel) -> AMQP { + AMQP { + connection, + channel, + } + } + + pub async fn friend_request_accepted( + &self, + accepted_request_user: &User, + sent_request_user: &User, + ) -> Result<(), AMQPError> { + let config = revolt_config::config().await; + let payload = FRAcceptedPayload { + accepted_user: accepted_request_user.to_owned(), + user: sent_request_user.id.clone(), + }; + let payload = to_string(&payload).unwrap(); + + debug!( + "Sending friend request accept payload on channel {}: {}", + config.pushd.get_fr_accepted_routing_key(), + payload + ); + self.channel + .basic_publish( + BasicProperties::default() + .with_content_type("application/json") + .with_persistence(true) + .finish(), + payload.into(), + BasicPublishArguments::new( + &config.pushd.exchange, + &config.pushd.get_fr_accepted_routing_key(), + ), + ) + .await + } + + pub async fn friend_request_received( + &self, + received_request_user: &User, + sent_request_user: &User, + ) -> Result<(), AMQPError> { + let config = revolt_config::config().await; + let payload = FRReceivedPayload { + from_user: sent_request_user.to_owned(), + user: received_request_user.id.clone(), + }; + let payload = to_string(&payload).unwrap(); + + debug!( + "Sending friend request received payload on channel {}: {}", + config.pushd.get_fr_received_routing_key(), + payload + ); + + self.channel + .basic_publish( + BasicProperties::default() + .with_content_type("application/json") + .with_persistence(true) + .finish(), + payload.into(), + BasicPublishArguments::new( + &config.pushd.exchange, + &config.pushd.get_fr_received_routing_key(), + ), + ) + .await + } + + pub async fn generic_message( + &self, + user: &User, + title: String, + body: String, + icon: Option, + ) -> Result<(), AMQPError> { + let config = revolt_config::config().await; + let payload = GenericPayload { + title, + body, + icon, + user: user.to_owned(), + }; + let payload = to_string(&payload).unwrap(); + + debug!( + "Sending generic payload on channel {}: {}", + config.pushd.get_generic_routing_key(), + payload + ); + + self.channel + .basic_publish( + BasicProperties::default() + .with_content_type("application/json") + .with_persistence(true) + .finish(), + payload.into(), + BasicPublishArguments::new( + &config.pushd.exchange, + &config.pushd.get_generic_routing_key(), + ), + ) + .await + } + + pub async fn message_sent( + &self, + recipients: Vec, + payload: PushNotification, + ) -> Result<(), AMQPError> { + if recipients.is_empty() { + return Ok(()); + } + + let config = revolt_config::config().await; + + let online_ids = filter_online(&recipients).await; + let recipients = (&recipients.into_iter().collect::>() - &online_ids) + .into_iter() + .collect::>(); + + let payload = MessageSentPayload { + notification: payload, + users: recipients, + }; + let payload = to_string(&payload).unwrap(); + + debug!( + "Sending message payload on channel {}: {}", + config.pushd.get_message_routing_key(), + payload + ); + + self.channel + .basic_publish( + BasicProperties::default() + .with_content_type("application/json") + .with_persistence(true) + .finish(), + payload.into(), + BasicPublishArguments::new( + &config.pushd.exchange, + &config.pushd.get_message_routing_key(), + ), + ) + .await + } + + pub async fn ack_message( + &self, + user_id: String, + channel_id: String, + message_id: String, + ) -> Result<(), AMQPError> { + let config = revolt_config::config().await; + + let payload = AckPayload { + user_id: user_id.clone(), + channel_id: channel_id.clone(), + message_id, + }; + let payload = to_string(&payload).unwrap(); + + info!( + "Sending ack payload on channel {}: {}", + config.pushd.ack_queue, payload + ); + + let mut headers = FieldTable::new(); + headers.insert( + "x-deduplication-header".try_into().unwrap(), + format!("{}-{}", &user_id, &channel_id).into(), + ); + + self.channel + .basic_publish( + BasicProperties::default() + .with_content_type("application/json") + .with_persistence(true) + //.with_headers(headers) + .finish(), + payload.into(), + BasicPublishArguments::new(&config.pushd.exchange, &config.pushd.ack_queue), + ) + .await + } +} diff --git a/crates/core/database/src/amqp/mod.rs b/crates/core/database/src/amqp/mod.rs new file mode 100644 index 000000000..9ed09d345 --- /dev/null +++ b/crates/core/database/src/amqp/mod.rs @@ -0,0 +1,2 @@ +#[allow(clippy::module_inception)] +pub mod amqp; diff --git a/crates/core/database/src/events/mod.rs b/crates/core/database/src/events/mod.rs index c07f47e0f..608984357 100644 --- a/crates/core/database/src/events/mod.rs +++ b/crates/core/database/src/events/mod.rs @@ -1,2 +1,3 @@ pub mod client; +pub mod rabbit; pub mod server; diff --git a/crates/core/database/src/events/rabbit.rs b/crates/core/database/src/events/rabbit.rs new file mode 100644 index 000000000..613047521 --- /dev/null +++ b/crates/core/database/src/events/rabbit.rs @@ -0,0 +1,59 @@ +use std::collections::HashMap; + +use revolt_models::v0::PushNotification; +use serde::{Deserialize, Serialize}; + +use crate::User; + +#[derive(Serialize, Deserialize)] +pub struct MessageSentPayload { + pub notification: PushNotification, + pub users: Vec, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct FRAcceptedPayload { + pub accepted_user: User, + pub user: String, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct FRReceivedPayload { + pub from_user: User, + pub user: String, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct GenericPayload { + pub title: String, + pub body: String, + pub icon: Option, + pub user: User, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type", content = "data")] +#[allow(clippy::large_enum_variant)] +pub enum PayloadKind { + MessageNotification(PushNotification), + FRAccepted(FRAcceptedPayload), + FRReceived(FRReceivedPayload), + BadgeUpdate(usize), + Generic(GenericPayload), +} + +#[derive(Serialize, Deserialize)] +pub struct PayloadToService { + pub notification: PayloadKind, + pub user_id: String, + pub session_id: String, + pub token: String, + pub extras: HashMap, +} + +#[derive(Serialize, Deserialize)] +pub struct AckPayload { + pub user_id: String, + pub channel_id: String, + pub message_id: String, +} diff --git a/crates/core/database/src/lib.rs b/crates/core/database/src/lib.rs index e48821b40..48f93f81a 100644 --- a/crates/core/database/src/lib.rs +++ b/crates/core/database/src/lib.rs @@ -85,6 +85,9 @@ pub use models::*; pub mod events; pub mod tasks; +mod amqp; +pub use amqp::amqp::AMQP; + /// Utility function to check if a boolean value is false pub fn if_false(t: &bool) -> bool { !t diff --git a/crates/core/database/src/models/channel_unreads/ops.rs b/crates/core/database/src/models/channel_unreads/ops.rs index 4d115069e..6a6e98af3 100644 --- a/crates/core/database/src/models/channel_unreads/ops.rs +++ b/crates/core/database/src/models/channel_unreads/ops.rs @@ -26,6 +26,9 @@ pub trait AbstractChannelUnreads: Sync + Send { message_ids: &[String], ) -> Result<()>; + /// Fetch all unreads with mentions for a user. + async fn fetch_unread_mentions(&self, user_id: &str) -> Result>; + /// Fetch all channel unreads for a user. async fn fetch_unreads(&self, user_id: &str) -> Result>; diff --git a/crates/core/database/src/models/channel_unreads/ops/mongodb.rs b/crates/core/database/src/models/channel_unreads/ops/mongodb.rs index e2dfedb17..c237afc65 100644 --- a/crates/core/database/src/models/channel_unreads/ops/mongodb.rs +++ b/crates/core/database/src/models/channel_unreads/ops/mongodb.rs @@ -123,6 +123,18 @@ impl AbstractChannelUnreads for MongoDb { ) } + async fn fetch_unread_mentions(&self, user_id: &str) -> Result> { + query! { + self, + find, + COL, + doc! { + "_id.user": user_id, + "mentions": {"$ne": null} + } + } + } + /// Fetch unread for a specific user in a channel. async fn fetch_unread(&self, user_id: &str, channel_id: &str) -> Result> { query!( @@ -135,5 +147,4 @@ impl AbstractChannelUnreads for MongoDb { } ) } - } diff --git a/crates/core/database/src/models/channel_unreads/ops/reference.rs b/crates/core/database/src/models/channel_unreads/ops/reference.rs index b8a95d03e..914ac95db 100644 --- a/crates/core/database/src/models/channel_unreads/ops/reference.rs +++ b/crates/core/database/src/models/channel_unreads/ops/reference.rs @@ -78,6 +78,15 @@ impl AbstractChannelUnreads for ReferenceDb { Ok(()) } + async fn fetch_unread_mentions(&self, user_id: &str) -> Result> { + let unreads = self.channel_unreads.lock().await; + Ok(unreads + .values() + .filter(|unread| unread.id.user == user_id && unread.mentions.is_some()) + .cloned() + .collect()) + } + /// Fetch all channel unreads for a user. async fn fetch_unreads(&self, user_id: &str) -> Result> { let unreads = self.channel_unreads.lock().await; @@ -92,9 +101,11 @@ impl AbstractChannelUnreads for ReferenceDb { async fn fetch_unread(&self, user_id: &str, channel_id: &str) -> Result> { let unreads = self.channel_unreads.lock().await; - Ok(unreads.get(&ChannelCompositeKey { - channel: channel_id.to_string(), - user: user_id.to_string() - }).cloned()) + Ok(unreads + .get(&ChannelCompositeKey { + channel: channel_id.to_string(), + user: user_id.to_string(), + }) + .cloned()) } } diff --git a/crates/core/database/src/models/channels/model.rs b/crates/core/database/src/models/channels/model.rs index cd065818d..f981ff1c1 100644 --- a/crates/core/database/src/models/channels/model.rs +++ b/crates/core/database/src/models/channels/model.rs @@ -9,7 +9,7 @@ use ulid::Ulid; use crate::{ events::client::EventV1, tasks::ack::AckEvent, Database, File, IntoDocumentPath, PartialServer, - Server, SystemMessage, User, + Server, SystemMessage, User, AMQP, }; auto_derived!( @@ -337,6 +337,7 @@ impl Channel { pub async fn add_user_to_group( &mut self, db: &Database, + amqp: &AMQP, user: &User, by_id: &str, ) -> Result<()> { @@ -373,6 +374,7 @@ impl Channel { .into_message(id.to_string()) .send( db, + Some(amqp), MessageAuthor::System { username: &user.username, avatar: user.avatar.as_ref().map(|file| file.id.as_ref()), @@ -639,7 +641,7 @@ impl Channel { .private(user.to_string()) .await; - crate::tasks::ack::queue( + crate::tasks::ack::queue_ack( self.id().to_string(), user.to_string(), AckEvent::AckMessage { @@ -655,6 +657,7 @@ impl Channel { pub async fn remove_user_from_group( &self, db: &Database, + amqp: &AMQP, user: &User, by_id: Option<&str>, silent: bool, @@ -686,6 +689,7 @@ impl Channel { .into_message(id.to_string()) .send( db, + Some(amqp), MessageAuthor::System { username: name, avatar: None, @@ -725,6 +729,7 @@ impl Channel { .into_message(id.to_string()) .send( db, + Some(amqp), MessageAuthor::System { username: &user.username, avatar: user.avatar.as_ref().map(|file| file.id.as_ref()), diff --git a/crates/core/database/src/models/messages/model.rs b/crates/core/database/src/models/messages/model.rs index e813fa94f..725d46d1d 100644 --- a/crates/core/database/src/models/messages/model.rs +++ b/crates/core/database/src/models/messages/model.rs @@ -15,8 +15,8 @@ use validator::Validate; use crate::{ events::client::EventV1, tasks::{self, ack::AckEvent}, - util::idempotency::IdempotencyKey, - Channel, Database, Emoji, File, User, + util::{bulk_permissions::BulkDatabasePermissionQuery, idempotency::IdempotencyKey}, + Channel, Database, Emoji, File, User, AMQP, }; auto_derived_partial!( @@ -230,6 +230,7 @@ impl Message { #[allow(clippy::too_many_arguments)] pub async fn create_from_api( db: &Database, + amqp: Option<&AMQP>, channel: Channel, data: DataMessageSend, author: MessageAuthor<'_>, @@ -337,35 +338,52 @@ impl Message { } } + // Validate the mentions go to users in the channel/server if !mentions.is_empty() { - // FIXME: temp fix to stop spam attacks match channel { Channel::DirectMessage { ref recipients, .. } | Channel::Group { ref recipients, .. } => { let recipients_hash: HashSet<&String, RandomState> = - HashSet::from_iter(recipients.iter()); - + HashSet::from_iter(recipients); mentions.retain(|m| recipients_hash.contains(m)); } Channel::TextChannel { ref server, .. } | Channel::VoiceChannel { ref server, .. } => { let mentions_vec = Vec::from_iter(mentions.iter().cloned()); + let valid_members = db.fetch_members(server.as_str(), &mentions_vec[..]).await; if let Ok(valid_members) = valid_members { - let valid_ids: HashSet = HashSet::from_iter( - valid_members.iter().map(|member| member.id.user.clone()), - ); - mentions.retain(|m| valid_ids.contains(m)); + let valid_mentions: HashSet<&String, RandomState> = + HashSet::from_iter(valid_members.iter().map(|m| &m.id.user)); + + mentions.retain(|m| valid_mentions.contains(m)); // quick pass, validate mentions are in the server + + if !mentions.is_empty() { + // if there are still mentions, drill down to a channel-level + let member_channel_view_perms = + BulkDatabasePermissionQuery::from_server_id(db, server) + .await + .channel(&channel) + .members(&valid_members) + .members_can_see_channel() + .await; + + mentions + .retain(|m| *member_channel_view_perms.get(m).unwrap_or(&false)); + } } else { revolt_config::capture_error(&valid_members.unwrap_err()); + return Err(create_error!(InternalError)); } } - Channel::SavedMessages { .. } => mentions.clear(), + Channel::SavedMessages { .. } => { + mentions.clear(); + } } + } - if !mentions.is_empty() { - message.mentions.replace(mentions.into_iter().collect()); - } + if !mentions.is_empty() { + message.mentions.replace(mentions.into_iter().collect()); } if !replies.is_empty() { @@ -418,7 +436,7 @@ impl Message { // Send the message message - .send(db, author, user, member, &channel, generate_embeds) + .send(db, amqp, author, user, member, &channel, generate_embeds) .await?; Ok(message) @@ -432,6 +450,9 @@ impl Message { member: Option, is_dm: bool, generate_embeds: bool, + // This determines if this function should queue the mentions task or if somewhere else will. + // If this is true, you MUST call tasks::ack::queue yourself. + mentions_elsewhere: bool, ) -> Result<()> { db.insert_message(self).await?; @@ -444,13 +465,12 @@ impl Message { tasks::last_message_id::queue(self.channel.to_string(), self.id.to_string(), is_dm).await; // Add mentions for affected users - if let Some(mentions) = &self.mentions { - for user in mentions { - tasks::ack::queue( + if !mentions_elsewhere { + if let Some(mentions) = &self.mentions { + tasks::ack::queue_message( self.channel.to_string(), - user.to_string(), - AckEvent::AddMention { - ids: vec![self.id.to_string()], + AckEvent::ProcessMessage { + messages: vec![(None, self.clone(), mentions.clone(), true)], }, ) .await; @@ -473,9 +493,11 @@ impl Message { } /// Send a message + #[allow(clippy::too_many_arguments)] pub async fn send( &mut self, db: &Database, + amqp: Option<&AMQP>, // this is optional mostly for tests. author: MessageAuthor<'_>, user: Option, member: Option, @@ -488,26 +510,36 @@ impl Message { member.clone(), matches!(channel, Channel::DirectMessage { .. }), generate_embeds, + true, ) .await?; if !self.has_suppressed_notifications() { - // Push out Web Push notifications - crate::tasks::web_push::queue( - { - match channel { - Channel::DirectMessage { recipients, .. } - | Channel::Group { recipients, .. } => recipients.clone(), - Channel::TextChannel { .. } => self.mentions.clone().unwrap_or_default(), - _ => vec![], - } + // send Push notifications + tasks::ack::queue_message( + self.channel.to_string(), + AckEvent::ProcessMessage { + messages: vec![( + Some( + PushNotification::from( + self.clone().into_model(user, member), + Some(author), + channel.to_owned().into(), + ) + .await, + ), + self.clone(), + match channel { + Channel::DirectMessage { recipients, .. } + | Channel::Group { recipients, .. } => recipients.clone(), + Channel::TextChannel { .. } => { + self.mentions.clone().unwrap_or_default() + } + _ => vec![], + }, + self.has_suppressed_notifications(), + )], }, - PushNotification::from( - self.clone().into_model(user, member), - Some(author), - channel.id(), - ) - .await, ) .await; } diff --git a/crates/core/database/src/models/server_members/model.rs b/crates/core/database/src/models/server_members/model.rs index 88f953cf8..72e161629 100644 --- a/crates/core/database/src/models/server_members/model.rs +++ b/crates/core/database/src/models/server_members/model.rs @@ -150,7 +150,7 @@ impl Member { id: user.id.clone(), } .into_message(id.to_string()) - .send_without_notifications(db, None, None, false, false) + .send_without_notifications(db, None, None, false, false, false) .await .ok(); } @@ -251,7 +251,7 @@ impl Member { } .into_message(id.to_string()) // TODO: support notifications here in the future? - .send_without_notifications(db, None, None, false, false) + .send_without_notifications(db, None, None, false, false, false) .await .ok(); } diff --git a/crates/core/database/src/models/server_members/ops/reference.rs b/crates/core/database/src/models/server_members/ops/reference.rs index 88c9c7b25..f038477ec 100644 --- a/crates/core/database/src/models/server_members/ops/reference.rs +++ b/crates/core/database/src/models/server_members/ops/reference.rs @@ -53,17 +53,17 @@ impl AbstractServerMembers for ReferenceDb { /// Fetch multiple members by their ids async fn fetch_members<'a>(&self, server_id: &str, ids: &'a [String]) -> Result> { let server_members = self.server_members.lock().await; - ids.iter() - .map(|id| { + Ok(ids + .iter() + .filter_map(|id| { server_members .get(&MemberCompositeKey { server: server_id.to_string(), user: id.to_string(), }) .cloned() - .ok_or_else(|| create_error!(NotFound)) }) - .collect() + .collect()) } /// Fetch member count of a server diff --git a/crates/core/database/src/models/users/model.rs b/crates/core/database/src/models/users/model.rs index 9a250ad67..059250282 100644 --- a/crates/core/database/src/models/users/model.rs +++ b/crates/core/database/src/models/users/model.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, str::FromStr, time::Duration}; -use crate::{events::client::EventV1, Database, File, RatelimitEvent}; +use crate::{events::client::EventV1, Database, File, RatelimitEvent, AMQP}; use authifier::config::{EmailVerificationConfig, Template}; use iso8601_timestamp::Timestamp; @@ -497,7 +497,12 @@ impl User { } /// Add another user as a friend - pub async fn add_friend(&mut self, db: &Database, target: &mut User) -> Result<()> { + pub async fn add_friend( + &mut self, + db: &Database, + amqp: &AMQP, + target: &mut User, + ) -> Result<()> { match self.relationship_with(&target.id) { RelationshipStatus::User => Err(create_error!(NoEffect)), RelationshipStatus::Friend => Err(create_error!(AlreadyFriends)), @@ -506,6 +511,8 @@ impl User { RelationshipStatus::BlockedOther => Err(create_error!(BlockedByOther)), RelationshipStatus::Incoming => { // Accept incoming friend request + _ = amqp.friend_request_accepted(self, target).await; + self.apply_relationship( db, target, @@ -534,6 +541,8 @@ impl User { })); } + _ = amqp.friend_request_received(target, self).await; + // Send the friend request self.apply_relationship( db, diff --git a/crates/core/database/src/models/users/rocket.rs b/crates/core/database/src/models/users/rocket.rs index ab8866947..57b634b15 100644 --- a/crates/core/database/src/models/users/rocket.rs +++ b/crates/core/database/src/models/users/rocket.rs @@ -38,7 +38,7 @@ impl<'r> FromRequest<'r> for User { if let Some(user) = user { Outcome::Success(user.clone()) } else { - Outcome::Failure((Status::Unauthorized, authifier::Error::InvalidSession)) + Outcome::Error((Status::Unauthorized, authifier::Error::InvalidSession)) } } } diff --git a/crates/core/database/src/tasks/ack.rs b/crates/core/database/src/tasks/ack.rs index b5c90fdd7..89971f661 100644 --- a/crates/core/database/src/tasks/ack.rs +++ b/crates/core/database/src/tasks/ack.rs @@ -1,21 +1,27 @@ // Queue Type: Debounced -use crate::Database; +use crate::{Database, Message, AMQP}; use deadqueue::limited::Queue; use once_cell::sync::Lazy; -use std::{collections::HashMap, time::Duration}; +use revolt_models::v0::PushNotification; +use rocket::form::validate::Contains; +use std::{ + collections::{HashMap, HashSet}, + time::Duration, +}; +use validator::HasLen; use revolt_result::Result; -use super::{apple_notifications::{self, ApnJob}, DelayedTask}; +use super::DelayedTask; /// Enumeration of possible events #[derive(Debug, Eq, PartialEq)] pub enum AckEvent { - /// Add mentions for a user in a channel - AddMention { - /// Message IDs - ids: Vec, + /// Add mentions for a channel + ProcessMessage { + /// push notification, message, recipients, push silenced + messages: Vec<(Option, Message, Vec, bool)>, }, /// Acknowledge message in a channel for a user @@ -30,7 +36,7 @@ struct Data { /// Channel to ack channel: String, /// User to ack for - user: String, + user: Option, /// Event event: AckEvent, } @@ -43,21 +49,49 @@ struct Task { static Q: Lazy> = Lazy::new(|| Queue::new(10_000)); /// Queue a new task for a worker -pub async fn queue(channel: String, user: String, event: AckEvent) { +pub async fn queue_ack(channel: String, user: String, event: AckEvent) { Q.try_push(Data { channel, - user, + user: Some(user), event, }) .ok(); - info!("Queue is using {} slots from {}.", Q.len(), Q.capacity()); + info!( + "Queue is using {} slots from {}. Queued type: ACK", + Q.len(), + Q.capacity() + ); } -pub async fn handle_ack_event(event: &AckEvent, db: &Database, authifier_db: &authifier::Database, user: &str, channel: &str) -> Result<()> { +pub async fn queue_message(channel: String, event: AckEvent) { + Q.try_push(Data { + channel, + user: None, + event, + }) + .ok(); + + info!( + "Queue is using {} slots from {}. Queued type: MENTION", + Q.len(), + Q.capacity() + ); +} + +pub async fn handle_ack_event( + event: &AckEvent, + db: &Database, + amqp: &AMQP, + user: &Option, + channel: &str, +) -> Result<()> { match &event { #[allow(clippy::disallowed_methods)] // event is sent by higher level function AckEvent::AckMessage { id } => { + let user = user.as_ref().unwrap(); + let user: &str = user.as_str(); + let unread = db.fetch_unread(user, channel).await?; let updated = db.acknowledge_message(channel, user, id).await?; @@ -68,21 +102,67 @@ pub async fn handle_ack_event(event: &AckEvent, db: &Database, authifier_db: &au let mentions_acked = before_mentions - after_mentions; if mentions_acked > 0 { - if let Ok(sessions) = authifier_db.find_sessions(user).await { - for session in sessions { - if let Some(sub) = session.subscription { - if sub.endpoint == "apn" { - apple_notifications::queue(ApnJob::from_ack(session.id, user.to_string(), sub.auth)).await; - } - } - } + if let Err(err) = amqp + .ack_message(user.to_string(), channel.to_string(), id.to_owned()) + .await + { + revolt_config::capture_error(&err); } }; + } + } + AckEvent::ProcessMessage { messages } => { + let mut users: HashSet<&String> = HashSet::new(); + debug!( + "Processing {} messages from channel {}", + messages.len(), + messages[0].1.channel + ); + + // find all the users we'll be notifying + messages + .iter() + .for_each(|(_, _, recipents, _)| users.extend(recipents.iter())); + + debug!("Found {} users to notify.", users.len()); + + for user in users { + let message_ids: Vec = messages + .iter() + .filter(|(_, _, recipients, _)| recipients.contains(user)) + .map(|(_, message, _, _)| message.id.clone()) + .collect(); + + if !message_ids.is_empty() { + db.add_mention_to_unread(channel, user, &message_ids) + .await?; + } + debug!("Added {} mentions for user {}", message_ids.len(), &user); + } + for (push, _, recipients, silenced) in messages { + if *silenced || recipients.is_empty() || push.is_none() { + debug!( + "Rejecting push: silenced: {}, recipient count: {}, push exists: {:?}", + *silenced, + recipients.length(), + push + ); + continue; + } + + debug!( + "Sending push event to AMQP; message {} for {} users", + push.as_ref().unwrap().message.id, + recipients.len() + ); + if let Err(err) = amqp + .message_sent(recipients.clone(), push.clone().unwrap()) + .await + { + revolt_config::capture_error(&err); + } } - }, - AckEvent::AddMention { ids } => { - db.add_mention_to_unread(channel, user, ids).await?; } }; @@ -90,9 +170,9 @@ pub async fn handle_ack_event(event: &AckEvent, db: &Database, authifier_db: &au } /// Start a new worker -pub async fn worker(db: Database, authifier_db: authifier::Database) { - let mut tasks = HashMap::<(String, String), DelayedTask>::new(); - let mut keys = vec![]; +pub async fn worker(db: Database, amqp: AMQP) { + let mut tasks = HashMap::<(Option, String, u8), DelayedTask>::new(); + let mut keys: Vec<(Option, String, u8)> = vec![]; loop { // Find due tasks. @@ -106,12 +186,13 @@ pub async fn worker(db: Database, authifier_db: authifier::Database) { for key in &keys { if let Some(task) = tasks.remove(key) { let Task { event } = task.data; - let (user, channel) = key; + let (user, channel, _) = key; - if let Err(err) = handle_ack_event(&event, &db, &authifier_db, user, channel).await { - error!("{err:?} for {event:?}. ({user}, {channel})"); + if let Err(err) = handle_ack_event(&event, &db, &amqp, user, channel).await { + revolt_config::capture_error(&err); + error!("{err:?} for {event:?}. ({user:?}, {channel})"); } else { - info!("User {user} ack in {channel} with {event:?}"); + info!("User {user:?} ack in {channel} with {event:?}"); } } } @@ -126,20 +207,41 @@ pub async fn worker(db: Database, authifier_db: authifier::Database) { mut event, }) = Q.try_pop() { - let key = (user, channel); + let key: (Option, String, u8) = ( + user, + channel, + match &event { + AckEvent::AckMessage { .. } => 0, + AckEvent::ProcessMessage { .. } => 1, + }, + ); if let Some(task) = tasks.get_mut(&key) { - task.delay(); - match &mut event { - AckEvent::AddMention { ids } => { - if let AckEvent::AddMention { ids: existing } = &mut task.data.event { - existing.append(ids); + AckEvent::ProcessMessage { messages: new_data } => { + if let AckEvent::ProcessMessage { messages: existing } = + &mut task.data.event + { + // add the new message to the list of messages to be processed. + existing.append(new_data); + + // put a cap on the amount of messages that can be queued, for particularly active channels + if (existing.length() as u16) + < revolt_config::config() + .await + .features + .advanced + .process_message_delay_limit + { + task.delay(); + } } else { - task.data.event = event; + panic!("Somehow got an ack message in the add mention arm"); } } AckEvent::AckMessage { .. } => { + // replace the last acked message with the new acked message task.data.event = event; + task.delay(); } } } else { diff --git a/crates/core/database/src/tasks/apple_notifications.rs b/crates/core/database/src/tasks/apple_notifications.rs deleted file mode 100644 index 918967f79..000000000 --- a/crates/core/database/src/tasks/apple_notifications.rs +++ /dev/null @@ -1,314 +0,0 @@ -use std::io::Cursor; - -use base64::{ - engine::{self}, - Engine as _, -}; -use deadqueue::limited::Queue; -use once_cell::sync::Lazy; -use revolt_a2::{ - request::{ - notification::{DefaultAlert, NotificationOptions}, - payload::{APSAlert, APSSound, PayloadLike, APS}, - }, - Client, ClientConfig, Endpoint, Error, ErrorBody, ErrorReason, Priority, PushType, Response, -}; -use revolt_config::config; -use revolt_models::v0::{Message, PushNotification}; - -use crate::Database; - -/// Payload information, before assembly -#[derive(Debug)] -#[allow(non_snake_case)] -pub struct ApnPayload { - message: Message, - url: String, - authorAvatar: String, - authorDisplayName: String, - channelName: String, -} - -#[derive(Serialize, Debug)] -#[allow(non_snake_case)] -struct Payload<'a> { - aps: APS<'a>, - #[serde(skip_serializing)] - options: NotificationOptions<'a>, - #[serde(skip_serializing)] - device_token: &'a str, - - message: &'a Message, - url: &'a str, - authorAvatar: &'a str, - authorDisplayName: &'a str, - channelName: &'a str, -} - -impl<'a> PayloadLike for Payload<'a> { - fn get_device_token(&self) -> &'a str { - self.device_token - } - fn get_options(&self) -> &NotificationOptions { - &self.options - } -} - -/// Task information -#[derive(Debug)] -pub struct AlertJob { - /// Session Id - session_id: String, - - /// Device token - device_token: String, - - /// User Id - user_id: String, - - /// Title - title: String, - - /// Body - body: String, - - /// Thread Id - thread_id: String, - - /// Category (informs the client what kind of notification is being sent.) - category: String, - - /// Payload used by the iOS client to modify the notification - custom_payload: ApnPayload, -} - -impl AlertJob { - fn format_title(notification: &PushNotification) -> String { - // ideally this changes depending on context - // in a server, it would look like "Sendername, #channelname in servername" - // in a group, it would look like "Sendername in groupname" - // in a dm it should just be "Sendername". - // not sure how feasible all those are given the PushNotification object as it currently stands. - format!( - "{} in {}", - notification.author, notification.message.channel - ) // TODO: this absolutely needs a channel name - } -} - -#[derive(Debug)] -pub struct BadgeJob { - /// Session Id - session_id: String, - - /// Device token - device_token: String, - - /// User Id - user_id: String, -} - -#[derive(Debug)] -pub enum JobType { - Alert(AlertJob), - Badge(BadgeJob), -} - -#[derive(Debug)] -pub struct ApnJob { - job_type: JobType, -} - -impl ApnJob { - pub fn from_notification( - session_id: String, - user_id: String, - device_token: String, - notification: &PushNotification, - ) -> ApnJob { - ApnJob { - job_type: JobType::Alert(AlertJob { - session_id, - device_token, - user_id, - title: AlertJob::format_title(notification), - body: notification.body.to_string(), - thread_id: notification.tag.to_string(), - category: "ALERT_MESSAGE".to_string(), - custom_payload: ApnPayload { - message: notification.message.clone(), - url: notification.url.clone(), - authorAvatar: notification.icon.clone(), - authorDisplayName: notification.author.clone(), - channelName: "#fetchchannelnamehere".to_string(), // TODO: get actual channel name - }, - }), - } - } - - pub fn from_ack(session_id: String, user_id: String, device_token: String) -> ApnJob { - ApnJob { - job_type: JobType::Badge(BadgeJob { - session_id, - device_token, - user_id, - }), - } - } -} - -enum AssembledPayload<'a> { - Alert(Payload<'a>), - Default(revolt_a2::request::payload::Payload<'a>), -} - -static Q: Lazy> = Lazy::new(|| Queue::new(10_000)); - -/// Queue a new task for a worker -pub async fn queue(task: ApnJob) { - Q.try_push(task).ok(); - info!("Queue is using {} slots from {}.", Q.len(), Q.capacity()); -} - -async fn get_badge_count(db: &Database, user: &str) -> Option { - if let Ok(unreads) = db.fetch_unreads(user).await { - let mut mention_count = 0; - for channel in unreads { - if let Some(mentions) = channel.mentions { - mention_count += mentions.len() as u32 - } - } - - return Some(mention_count); - } - None -} - -/// Start a new worker -pub async fn worker(db: Database) { - let config = config().await; - if config.api.apn.pkcs8.is_empty() - || config.api.apn.key_id.is_empty() - || config.api.apn.team_id.is_empty() - { - eprintln!("Missing APN keys."); - return; - } - - let endpoint = if config.api.apn.sandbox { - Endpoint::Sandbox - } else { - Endpoint::Production - }; - - let pkcs8 = engine::general_purpose::STANDARD - .decode(config.api.apn.pkcs8) - .expect("valid `pcks8`"); - - let client_config = ClientConfig::new(endpoint); - - let client = Client::token( - &mut Cursor::new(pkcs8), - config.api.apn.key_id, - config.api.apn.team_id, - client_config, - ) - .expect("could not create APN client"); - - let payload_options = NotificationOptions { - apns_id: None, - apns_push_type: Some(PushType::Alert), - apns_expiration: None, - apns_priority: Some(Priority::High), - apns_topic: Some("chat.revolt.app"), - apns_collapse_id: None, - }; - - loop { - let task = Q.pop().await; - let payload: AssembledPayload; - - match task.job_type { - JobType::Alert(ref alert) => { - payload = AssembledPayload::Alert(Payload { - aps: APS { - alert: Some(APSAlert::Default(DefaultAlert { - title: Some(&alert.title), - subtitle: None, - body: Some(&alert.body), - title_loc_key: None, - title_loc_args: None, - action_loc_key: None, - loc_key: None, - loc_args: None, - launch_image: None, - })), - badge: get_badge_count(&db, &alert.user_id).await, - sound: Some(APSSound::Sound("default")), - thread_id: Some(&alert.thread_id), - content_available: None, - category: Some(&alert.category), - mutable_content: Some(1), - url_args: None, - }, - device_token: &alert.device_token, - options: payload_options.clone(), - message: &alert.custom_payload.message, - url: &alert.custom_payload.url, - authorAvatar: &alert.custom_payload.authorAvatar, - authorDisplayName: &alert.custom_payload.authorDisplayName, - channelName: &alert.custom_payload.channelName, - }); - } - JobType::Badge(ref alert) => { - payload = AssembledPayload::Default(revolt_a2::request::payload::Payload { - aps: APS { - alert: None, - badge: get_badge_count(&db, &alert.user_id).await, - sound: None, - thread_id: None, - content_available: None, - category: None, - mutable_content: None, - url_args: None, - }, - device_token: &alert.device_token, - options: payload_options.clone(), - data: std::collections::BTreeMap::new(), - }) - } - } - - let resp = match payload { - AssembledPayload::Alert(p) => client.send(p).await, - AssembledPayload::Default(p) => client.send(p).await, - }; - //println!("response from APNS: {:?}", resp); - - if let Err(err) = resp { - match err { - Error::ResponseError(Response { - error: - Some(ErrorBody { - reason: ErrorReason::BadDeviceToken | ErrorReason::Unregistered, - .. - }), - .. - }) => { - if let Err(err) = db - .remove_push_subscription_by_session_id(match task.job_type { - JobType::Alert(ref a) => &a.session_id.as_str(), - JobType::Badge(ref a) => &a.session_id.as_str(), - }) - .await - { - revolt_config::capture_error(&err); - } - } - err => { - revolt_config::capture_error(&err); - } - } - } - } -} diff --git a/crates/core/database/src/tasks/mod.rs b/crates/core/database/src/tasks/mod.rs index d89887563..81bfd4e1f 100644 --- a/crates/core/database/src/tasks/mod.rs +++ b/crates/core/database/src/tasks/mod.rs @@ -1,6 +1,6 @@ //! Semi-important background task management -use crate::Database; +use crate::{Database, AMQP}; use async_std::task; use std::time::Instant; @@ -8,22 +8,18 @@ use std::time::Instant; const WORKER_COUNT: usize = 5; pub mod ack; -pub mod apple_notifications; pub mod authifier_relay; pub mod last_message_id; pub mod process_embeds; -pub mod web_push; /// Spawn background workers -pub fn start_workers(db: Database, authifier_db: authifier::Database) { +pub fn start_workers(db: Database, amqp: AMQP) { task::spawn(authifier_relay::worker()); - task::spawn(apple_notifications::worker(db.clone())); for _ in 0..WORKER_COUNT { - task::spawn(ack::worker(db.clone(), authifier_db.clone())); + task::spawn(ack::worker(db.clone(), amqp.clone())); task::spawn(last_message_id::worker(db.clone())); task::spawn(process_embeds::worker(db.clone())); - task::spawn(web_push::worker(db.clone(), authifier_db.clone())); } } diff --git a/crates/core/database/src/tasks/web_push.rs b/crates/core/database/src/tasks/web_push.rs deleted file mode 100644 index 72e22ecf0..000000000 --- a/crates/core/database/src/tasks/web_push.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::{ - collections::{HashMap, HashSet}, - time::Duration, -}; - -use authifier::Database as AuthifierDatabase; -use base64::{ - engine::{self}, - Engine as _, -}; -use deadqueue::limited::Queue; -use fcm_v1::auth::{Authenticator, ServiceAccountKey}; -use once_cell::sync::Lazy; -use revolt_config::{config, report_internal_error}; -use revolt_models::v0::PushNotification; -use revolt_presence::filter_online; -use serde_json::json; -use web_push::{ - ContentEncoding, IsahcWebPushClient, SubscriptionInfo, SubscriptionKeys, VapidSignatureBuilder, - WebPushClient, WebPushMessageBuilder, -}; - -use crate::Database; - -use super::apple_notifications; - -/// Task information -#[derive(Debug)] -struct PushTask { - /// User IDs of the targets that are to receive this notification - recipients: Vec, - /// Push Notification - payload: PushNotification, -} - -static Q: Lazy> = Lazy::new(|| Queue::new(10_000)); - -/// Queue a new task for a worker -pub async fn queue(recipients: Vec, payload: PushNotification) { - if recipients.is_empty() { - return; - } - - let online_ids = filter_online(&recipients).await; - let recipients = (&recipients.into_iter().collect::>() - &online_ids) - .into_iter() - .collect::>(); - - Q.try_push(PushTask { - recipients, - payload, - }) - .ok(); - - info!("Queue is using {} slots from {}.", Q.len(), Q.capacity()); -} - -/// Start a new worker -pub async fn worker(db: Database, authifier_db: AuthifierDatabase) { - let config = config().await; - - let web_push_client = IsahcWebPushClient::new().unwrap(); - let fcm_client = if config.api.fcm.key_type.is_empty() { - None - } else { - Some(fcm_v1::Client::new( - Authenticator::service_account::<&str>(ServiceAccountKey { - key_type: Some(config.api.fcm.key_type), - project_id: Some(config.api.fcm.project_id.clone()), - private_key_id: Some(config.api.fcm.private_key_id), - private_key: config.api.fcm.private_key, - client_email: config.api.fcm.client_email, - client_id: Some(config.api.fcm.client_id), - auth_uri: Some(config.api.fcm.auth_uri), - token_uri: config.api.fcm.token_uri, - auth_provider_x509_cert_url: Some(config.api.fcm.auth_provider_x509_cert_url), - client_x509_cert_url: Some(config.api.fcm.client_x509_cert_url), - }) - .await - .unwrap(), - config.api.fcm.project_id, - false, - Duration::from_secs(5), - )) - }; - - let web_push_private_key = engine::general_purpose::URL_SAFE_NO_PAD - .decode(config.api.vapid.private_key) - .expect("valid `VAPID_PRIVATE_KEY`"); - - loop { - let task = Q.pop().await; - - if let Ok(sessions) = authifier_db - .find_sessions_with_subscription(&task.recipients) - .await - { - for session in sessions { - if let Some(sub) = session.subscription { - if sub.endpoint == "fcm" { - // Use Firebase Cloud Messaging - if let Some(client) = &fcm_client { - let message = fcm_v1::message::Message { - token: Some(sub.auth), - data: Some(HashMap::from([( - "payload".to_owned(), - serde_json::Value::String(json!(&task.payload).to_string()), - )])), - ..Default::default() - }; - - if let Err(err) = client.send(&message).await { - error!("Failed to send FCM notification! {:?}", err); - - if let fcm_v1::Error::FCM(fcm_error) = err { - if fcm_error.contains("404 (Not Found)") { - println!("Unregistering {:?}", session.id); - - report_internal_error!( - db.remove_push_subscription_by_session_id(&session.id) - .await - ) - .ok(); - } - } - } else { - info!("Sent FCM notification to {:?}.", session.id); - } - } else { - info!("No FCM token was specified!"); - } - } else if sub.endpoint == "apn" { - apple_notifications::queue(apple_notifications::ApnJob::from_notification( - session.id, - session.user_id, - sub.auth, - &task.payload, - )) - .await; - } else { - // Use Web Push Standard - let subscription = SubscriptionInfo { - endpoint: sub.endpoint, - keys: SubscriptionKeys { - auth: sub.auth, - p256dh: sub.p256dh, - }, - }; - - match VapidSignatureBuilder::from_pem( - std::io::Cursor::new(&web_push_private_key), - &subscription, - ) { - Ok(sig_builder) => match sig_builder.build() { - Ok(signature) => { - let mut builder = WebPushMessageBuilder::new(&subscription); - builder.set_vapid_signature(signature); - - let payload = json!(task.payload).to_string(); - builder - .set_payload(ContentEncoding::AesGcm, payload.as_bytes()); - - match builder.build() { - Ok(msg) => match web_push_client.send(msg).await { - Ok(_) => { - info!( - "Sent Web Push notification to {:?}.", - session.id - ) - } - Err(err) => { - error!("Hit error sending Web Push! {:?}", err) - } - }, - Err(err) => { - error!( - "Failed to build message for {}! {:?}", - session.user_id, err - ) - } - } - } - Err(err) => error!( - "Failed to build signature for {}! {:?}", - session.user_id, err - ), - }, - Err(err) => error!( - "Failed to create signature builder for {}! {:?}", - session.user_id, err - ), - } - } - } - } - } - } -} diff --git a/crates/core/database/src/util/bulk_permissions.rs b/crates/core/database/src/util/bulk_permissions.rs new file mode 100644 index 000000000..400359772 --- /dev/null +++ b/crates/core/database/src/util/bulk_permissions.rs @@ -0,0 +1,337 @@ +use std::{collections::HashMap, hash::RandomState}; + +use revolt_permissions::{ + ChannelPermission, ChannelType, Override, OverrideField, PermissionValue, ALLOW_IN_TIMEOUT, + DEFAULT_PERMISSION_DIRECT_MESSAGE, +}; + +use crate::{Channel, Database, Member, Server, User}; + +#[derive(Clone)] +pub struct BulkDatabasePermissionQuery<'a> { + #[allow(dead_code)] + database: &'a Database, + + server: Server, + channel: Option, + users: Option>, + members: Option>, + + // In case the users or members are fetched as part of the permissions checking operation + pub(crate) cached_users: Option>, + pub(crate) cached_members: Option>, + + cached_member_perms: Option>, +} + +impl<'z, 'x> BulkDatabasePermissionQuery<'x> { + pub async fn members_can_see_channel(&'z mut self) -> HashMap + where + 'z: 'x, + { + let member_perms = if self.cached_member_perms.is_some() { + // This isn't done as an if let to prevent borrow checker errors with the mut self call when the perms aren't cached. + let perms = self.cached_member_perms.as_ref().unwrap(); + perms + .iter() + .map(|(m, p)| { + ( + m.clone(), + p.has_channel_permission(ChannelPermission::ViewChannel), + ) + }) + .collect() + } else { + calculate_members_permissions(self) + .await + .iter() + .map(|(m, p)| { + ( + m.clone(), + p.has_channel_permission(ChannelPermission::ViewChannel), + ) + }) + .collect() + }; + member_perms + } +} + +impl<'z> BulkDatabasePermissionQuery<'z> { + pub fn new(database: &Database, server: Server) -> BulkDatabasePermissionQuery<'_> { + BulkDatabasePermissionQuery { + database, + server, + channel: None, + users: None, + members: None, + cached_members: None, + cached_users: None, + cached_member_perms: None, + } + } + + pub async fn from_server_id<'a>( + db: &'a Database, + server: &str, + ) -> BulkDatabasePermissionQuery<'a> { + BulkDatabasePermissionQuery { + database: db, + server: db.fetch_server(server).await.unwrap(), + channel: None, + users: None, + members: None, + cached_members: None, + cached_users: None, + cached_member_perms: None, + } + } + + pub fn channel(self, channel: &'z Channel) -> BulkDatabasePermissionQuery { + BulkDatabasePermissionQuery { + channel: Some(channel.clone()), + ..self + } + } + + pub fn members(self, members: &'z [Member]) -> BulkDatabasePermissionQuery { + BulkDatabasePermissionQuery { + members: Some(members.to_owned()), + ..self + } + } + + pub fn users(self, users: &'z [User]) -> BulkDatabasePermissionQuery { + BulkDatabasePermissionQuery { + users: Some(users.to_owned()), + ..self + } + } + + /// Get the default channel permissions + /// Group channel defaults should be mapped to an allow-only override + #[allow(dead_code)] + async fn get_default_channel_permissions(&mut self) -> Override { + if let Some(channel) = &self.channel { + match channel { + Channel::Group { permissions, .. } => Override { + allow: permissions.unwrap_or(*DEFAULT_PERMISSION_DIRECT_MESSAGE as i64) as u64, + deny: 0, + }, + Channel::TextChannel { + default_permissions, + .. + } + | Channel::VoiceChannel { + default_permissions, + .. + } => default_permissions.unwrap_or_default().into(), + _ => Default::default(), + } + } else { + Default::default() + } + } + + #[allow(dead_code)] + fn get_channel_type(&mut self) -> ChannelType { + if let Some(channel) = &self.channel { + match channel { + Channel::DirectMessage { .. } => ChannelType::DirectMessage, + Channel::Group { .. } => ChannelType::Group, + Channel::SavedMessages { .. } => ChannelType::SavedMessages, + Channel::TextChannel { .. } | Channel::VoiceChannel { .. } => { + ChannelType::ServerChannel + } + } + } else { + ChannelType::Unknown + } + } + + /// Get the ordered role overrides (from lowest to highest) for this member in this channel + #[allow(dead_code)] + async fn get_channel_role_overrides(&mut self) -> &HashMap { + if let Some(channel) = &self.channel { + match channel { + Channel::TextChannel { + role_permissions, .. + } + | Channel::VoiceChannel { + role_permissions, .. + } => role_permissions, + _ => panic!("Not supported for non-server channels"), + } + } else { + panic!("No channel added to query") + } + } +} + +/// Calculate members permissions in a server channel. +async fn calculate_members_permissions<'a>( + query: &'a mut BulkDatabasePermissionQuery<'a>, +) -> HashMap { + let mut resp = HashMap::new(); + + let (_, channel_role_permissions, channel_default_permissions) = match query + .channel + .as_ref() + .expect("A channel must be assigned to calculate channel permissions") + .clone() + { + Channel::TextChannel { + id, + role_permissions, + default_permissions, + .. + } + | Channel::VoiceChannel { + id, + role_permissions, + default_permissions, + .. + } => (id, role_permissions, default_permissions), + _ => panic!("Calculation of member permissions must be done on a server channel"), + }; + + if query.users.is_none() { + let ids: Vec = query + .members + .as_ref() + .expect("No users or members added to the query") + .iter() + .map(|m| m.id.user.clone()) + .collect(); + + query.cached_users = Some( + query + .database + .fetch_users(&ids[..]) + .await + .expect("Failed to get data from the db"), + ); + + query.users = Some(query.cached_users.as_ref().unwrap().to_vec()) + } + + let users = query.users.as_ref().unwrap(); + + if query.members.is_none() { + let ids: Vec = query + .users + .as_ref() + .expect("No users or members added to the query") + .iter() + .map(|m| m.id.clone()) + .collect(); + + query.cached_members = Some( + query + .database + .fetch_members(&query.server.id, &ids[..]) + .await + .expect("Failed to get data from the db"), + ); + query.members = Some(query.cached_members.as_ref().unwrap().to_vec()) + } + + let members: HashMap<&String, &Member, RandomState> = HashMap::from_iter( + query + .members + .as_ref() + .unwrap() + .iter() + .map(|m| (&m.id.user, m)), + ); + + for user in users { + let member = members.get(&user.id); + + // User isn't a part of the server + if member.is_none() { + resp.insert(user.id.clone(), 0_u64.into()); + continue; + } + + let member = *member.unwrap(); + + if user.privileged { + resp.insert( + user.id.clone(), + PermissionValue::from(ChannelPermission::GrantAllSafe), + ); + continue; + } + + if user.id == query.server.owner { + resp.insert( + user.id.clone(), + PermissionValue::from(ChannelPermission::GrantAllSafe), + ); + continue; + } + + // Get the user's server permissions + let mut permission = calculate_server_permissions(&query.server, user, member); + + if let Some(defaults) = channel_default_permissions { + permission.apply(defaults.into()); + } + + // Get the applicable role overrides + let mut roles = channel_role_permissions + .iter() + .filter(|(id, _)| member.roles.contains(id)) + .filter_map(|(id, permission)| { + query.server.roles.get(id).map(|role| { + let v: Override = (*permission).into(); + (role.rank, v) + }) + }) + .collect::>(); + + roles.sort_by(|a, b| b.0.cmp(&a.0)); + let overrides = roles.into_iter().map(|(_, v)| v); + + for role_override in overrides { + permission.apply(role_override) + } + + resp.insert(user.id.clone(), permission); + } + + resp +} + +/// Calculates a member's server permissions +fn calculate_server_permissions(server: &Server, user: &User, member: &Member) -> PermissionValue { + if user.privileged || server.owner == user.id { + return ChannelPermission::GrantAllSafe.into(); + } + + let mut permissions: PermissionValue = server.default_permissions.into(); + + let mut roles = server + .roles + .iter() + .filter(|(id, _)| member.roles.contains(id)) + .map(|(_, role)| { + let v: Override = role.permissions.into(); + (role.rank, v) + }) + .collect::>(); + + roles.sort_by(|a, b| b.0.cmp(&a.0)); + let role_overrides: Vec = roles.into_iter().map(|(_, v)| v).collect(); + + for role in role_overrides { + permissions.apply(role); + } + + if member.in_timeout() { + permissions.restrict(*ALLOW_IN_TIMEOUT); + } + + permissions +} diff --git a/crates/core/database/src/util/idempotency.rs b/crates/core/database/src/util/idempotency.rs index d95236d77..df68000d9 100644 --- a/crates/core/database/src/util/idempotency.rs +++ b/crates/core/database/src/util/idempotency.rs @@ -102,7 +102,7 @@ impl<'r> FromRequest<'r> for IdempotencyKey { .map(|k| k.to_string()) { if key.len() > 64 { - return Outcome::Failure(( + return Outcome::Error(( Status::BadRequest, create_error!(FailedValidation { error: "idempotency key too long".to_string(), @@ -113,7 +113,7 @@ impl<'r> FromRequest<'r> for IdempotencyKey { let idempotency = IdempotencyKey { key }; let mut cache = TOKEN_CACHE.lock().await; if cache.get(&idempotency.key).is_some() { - return Outcome::Failure((Status::Conflict, create_error!(DuplicateNonce))); + return Outcome::Error((Status::Conflict, create_error!(DuplicateNonce))); } cache.put(idempotency.key.clone(), ()); diff --git a/crates/core/database/src/util/mod.rs b/crates/core/database/src/util/mod.rs index 26cf436e6..1baf7d8ba 100644 --- a/crates/core/database/src/util/mod.rs +++ b/crates/core/database/src/util/mod.rs @@ -1,4 +1,5 @@ pub mod bridge; +pub mod bulk_permissions; pub mod idempotency; pub mod permissions; pub mod reference; diff --git a/crates/core/models/src/v0/channels.rs b/crates/core/models/src/v0/channels.rs index dfabb2af7..154ca48ca 100644 --- a/crates/core/models/src/v0/channels.rs +++ b/crates/core/models/src/v0/channels.rs @@ -306,4 +306,18 @@ impl Channel { | Channel::VoiceChannel { id, .. } => id, } } + + /// This returns a Result because the recipient name can't be determined here without a db call, + /// which can't be done since this is models, which can't reference the database crate. + /// + /// If it returns Err, you need to fetch the name from the db. + pub fn name(&self) -> Result<&str, ()> { + match self { + Channel::DirectMessage { .. } => Err(()), + Channel::SavedMessages { .. } => Ok("Saved Messages"), + Channel::TextChannel { name, .. } + | Channel::Group { name, .. } + | Channel::VoiceChannel { name, .. } => Ok(name), + } + } } diff --git a/crates/core/models/src/v0/messages.rs b/crates/core/models/src/v0/messages.rs index ef67b09fd..b8c53aeed 100644 --- a/crates/core/models/src/v0/messages.rs +++ b/crates/core/models/src/v0/messages.rs @@ -13,7 +13,7 @@ use rocket::{FromForm, FromFormField}; use iso8601_timestamp::Timestamp; -use super::{Embed, File, Member, MessageWebhook, User, Webhook, RE_COLOUR}; +use super::{Channel, Embed, File, Member, MessageWebhook, User, Webhook, RE_COLOUR}; pub static RE_MENTION: Lazy = Lazy::new(|| Regex::new(r"<@([0-9A-HJKMNP-TV-Z]{26})>").unwrap()); @@ -209,6 +209,8 @@ auto_derived!( pub url: String, /// The message object itself, to send to clients for processing pub message: Message, + /// The channel object itself, for clients to process + pub channel: Channel, } /// Representation of a text embed before it is sent. @@ -369,7 +371,7 @@ auto_derived!( /// Optional fields on message pub enum FieldsMessage { - Pinned + Pinned, } ); @@ -442,7 +444,7 @@ impl From for String { impl PushNotification { /// Create a new notification from a given message, author and channel ID - pub async fn from(msg: Message, author: Option>, channel_id: &str) -> Self { + pub async fn from(msg: Message, author: Option>, channel: Channel) -> Self { let config = config().await; let icon = if let Some(author) = &author { @@ -496,10 +498,11 @@ impl PushNotification { icon, image, body, - tag: channel_id.to_string(), + tag: channel.id().to_string(), timestamp, - url: format!("{}/channel/{}/{}", config.hosts.app, channel_id, msg.id), + url: format!("{}/channel/{}/{}", config.hosts.app, channel.id(), msg.id), message: msg, + channel, } } } diff --git a/crates/core/result/Cargo.toml b/crates/core/result/Cargo.toml index 5129ed3b2..daf49fac7 100644 --- a/crates/core/result/Cargo.toml +++ b/crates/core/result/Cargo.toml @@ -29,7 +29,7 @@ utoipa = { version = "4.2.3", optional = true } # Rocket rocket = { optional = true, version = "0.5.0-rc.2", default-features = false } -revolt_rocket_okapi = { version = "0.9.1", optional = true } +revolt_rocket_okapi = { version = "0.10.0", optional = true } revolt_okapi = { version = "0.9.1", optional = true } # Axum diff --git a/crates/daemons/pushd/Cargo.toml b/crates/daemons/pushd/Cargo.toml new file mode 100644 index 000000000..f27fef600 --- /dev/null +++ b/crates/daemons/pushd/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "revolt-pushd" +version = "0.1.0" +edition = "2021" + +[dependencies] +revolt-config = { version = "0.7.15", path = "../../core/config" } +revolt-database = { version = "0.7.15", path = "../../core/database" } +revolt-models = { version = "0.7.15", path = "../../core/models", features = [ + "validator", +] } + +amqprs = { version = "1.7.0" } +fcm_v1 = "0.3.0" +web-push = "0.10.0" +isahc = { optional = true, version = "1.7", features = ["json"] } +revolt_a2 = { version = "0.10", default-features = false, features = ["ring"] } +tokio = "1.39.2" +async-trait = "0.1.81" +ulid = "1.0.0" + +authifier = "1.0.8" + +log = "0.4.11" + +#serialization +serde_json = "1" +revolt_optional_struct = "0.2.0" +serde = { version = "1", features = ["derive"] } +iso8601-timestamp = { version = "0.2.10", features = ["serde", "bson"] } +base64 = "0.22.1" diff --git a/crates/daemons/pushd/Dockerfile b/crates/daemons/pushd/Dockerfile new file mode 100644 index 000000000..a1f80a39c --- /dev/null +++ b/crates/daemons/pushd/Dockerfile @@ -0,0 +1,9 @@ +# Build Stage +FROM ghcr.io/revoltchat/base:latest AS builder + +# Bundle Stage +FROM gcr.io/distroless/cc-debian12:nonroot +COPY --from=builder /home/rust/src/target/release/revolt-pushd ./ + +USER nonroot +CMD ["./revolt-pushd"] \ No newline at end of file diff --git a/crates/daemons/pushd/Pushd Flowchart.graffle b/crates/daemons/pushd/Pushd Flowchart.graffle new file mode 100644 index 000000000..961d5ff61 Binary files /dev/null and b/crates/daemons/pushd/Pushd Flowchart.graffle differ diff --git a/crates/daemons/pushd/src/consumers/inbound/ack.rs b/crates/daemons/pushd/src/consumers/inbound/ack.rs new file mode 100644 index 000000000..b9fe08861 --- /dev/null +++ b/crates/daemons/pushd/src/consumers/inbound/ack.rs @@ -0,0 +1,138 @@ +use crate::consumers::inbound::internal::*; +use amqprs::{ + channel::{BasicPublishArguments, Channel}, + connection::Connection, + consumer::AsyncConsumer, + BasicProperties, Deliver, +}; +use async_trait::async_trait; +use revolt_database::{events::rabbit::*, Database}; + +pub struct AckConsumer { + #[allow(dead_code)] + db: Database, + authifier_db: authifier::Database, + conn: Option, + channel: Option, +} + +impl Channeled for AckConsumer { + fn get_connection(&self) -> Option<&Connection> { + if self.conn.is_none() { + None + } else { + Some(self.conn.as_ref().unwrap()) + } + } + + fn get_channel(&self) -> Option<&Channel> { + if self.channel.is_none() { + None + } else { + Some(self.channel.as_ref().unwrap()) + } + } + + fn set_connection(&mut self, conn: Connection) { + self.conn = Some(conn); + } + + fn set_channel(&mut self, channel: Channel) { + self.channel = Some(channel) + } +} + +impl AckConsumer { + pub fn new(db: Database, authifier_db: authifier::Database) -> AckConsumer { + AckConsumer { + db, + authifier_db, + conn: None, + channel: None, + } + } +} + +#[allow(unused_variables)] +#[async_trait] +impl AsyncConsumer for AckConsumer { + /// This consumer processes all acks the platform receives, and sends relevant badge updates to apple platforms. + async fn consume( + &mut self, + channel: &Channel, + deliver: Deliver, + basic_properties: BasicProperties, + content: Vec, + ) { + let content = String::from_utf8(content).unwrap(); + let payload: AckPayload = serde_json::from_str(content.as_str()).unwrap(); + + // Step 1: fetch unreads and don't continue if there's no unreads + #[allow(clippy::disallowed_methods)] + let unreads = self.db.fetch_unread_mentions(&payload.user_id).await; + + if let Ok(u) = &unreads { + if u.is_empty() { + return; + } + } else { + return; + } + + if let Ok(sessions) = self.authifier_db.find_sessions(&payload.user_id).await { + let config = revolt_config::config().await; + // Step 2: find any apple sessions, since we don't need to calculate this for anything else. + // If there's no apple sessions, we can return early + let apple_sessions: Vec<&authifier::models::Session> = sessions + .iter() + .filter(|session| { + if let Some(sub) = &session.subscription { + sub.endpoint == "apn" + } else { + false + } + }) + .collect(); + + if apple_sessions.is_empty() { + return; + } + + // Step 3: calculate the actual mention count, since we have to send it out + let mut mention_count = 0; + for u in &unreads.unwrap() { + mention_count += u.mentions.as_ref().unwrap().len() + } + + // Step 4: loop through each apple session and send the badge update + for session in apple_sessions { + let service_payload = PayloadToService { + notification: PayloadKind::BadgeUpdate(mention_count), + user_id: payload.user_id.clone(), + session_id: session.id.clone(), + token: session.subscription.as_ref().unwrap().auth.clone(), + extras: Default::default(), + }; + let raw_service_payload = serde_json::to_string(&service_payload); + + if let Ok(p) = raw_service_payload { + let args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.apn.queue.as_str(), + ) + .finish(); + + log::debug!( + "Publishing ack to apn session {}", + session.subscription.as_ref().unwrap().auth + ); + + publish_message(self, p.into(), args).await; + } else { + log::warn!("Failed to serialize ack badge update payload!"); + revolt_config::capture_error(&raw_service_payload.unwrap_err()); + } + } + } + } +} diff --git a/crates/daemons/pushd/src/consumers/inbound/fr_accepted.rs b/crates/daemons/pushd/src/consumers/inbound/fr_accepted.rs new file mode 100644 index 000000000..ff084a83a --- /dev/null +++ b/crates/daemons/pushd/src/consumers/inbound/fr_accepted.rs @@ -0,0 +1,121 @@ +use std::collections::HashMap; + +use crate::consumers::inbound::internal::*; +use amqprs::{ + channel::{BasicPublishArguments, Channel}, + connection::Connection, + consumer::AsyncConsumer, + BasicProperties, Deliver, +}; +use async_trait::async_trait; +use log::debug; +use revolt_database::{events::rabbit::*, Database}; + +pub struct FRAcceptedConsumer { + #[allow(dead_code)] + db: Database, + authifier_db: authifier::Database, + conn: Option, + channel: Option, +} + +impl Channeled for FRAcceptedConsumer { + fn get_connection(&self) -> Option<&Connection> { + if self.conn.is_none() { + None + } else { + Some(self.conn.as_ref().unwrap()) + } + } + + fn get_channel(&self) -> Option<&Channel> { + if self.channel.is_none() { + None + } else { + Some(self.channel.as_ref().unwrap()) + } + } + + fn set_connection(&mut self, conn: Connection) { + self.conn = Some(conn); + } + + fn set_channel(&mut self, channel: Channel) { + self.channel = Some(channel) + } +} + +impl FRAcceptedConsumer { + pub fn new(db: Database, authifier_db: authifier::Database) -> FRAcceptedConsumer { + FRAcceptedConsumer { + db, + authifier_db, + conn: None, + channel: None, + } + } +} + +#[allow(unused_variables)] +#[async_trait] +impl AsyncConsumer for FRAcceptedConsumer { + /// This consumer handles delegating messages into their respective platform queues. + async fn consume( + &mut self, + channel: &Channel, + deliver: Deliver, + basic_properties: BasicProperties, + content: Vec, + ) { + let content = String::from_utf8(content).unwrap(); + let payload: FRAcceptedPayload = serde_json::from_str(content.as_str()).unwrap(); + + debug!("Received FR accept event"); + + if let Ok(sessions) = self.authifier_db.find_sessions(&payload.user).await { + let config = revolt_config::config().await; + for session in sessions { + if let Some(sub) = session.subscription { + let mut sendable = PayloadToService { + notification: PayloadKind::FRAccepted(payload.clone()), + token: sub.auth, + user_id: session.user_id, + session_id: session.id, + extras: HashMap::new(), + }; + + let args: BasicPublishArguments; + + if sub.endpoint == "apn" { + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.apn.queue.as_str(), + ) + .finish(); + } else if sub.endpoint == "fcm" { + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.fcm.queue.as_str(), + ) + .finish(); + } else { + // web push (vapid) + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.vapid.queue.as_str(), + ) + .finish(); + sendable.extras.insert("p265dh".to_string(), sub.p256dh); + sendable + .extras + .insert("endpoint".to_string(), sub.endpoint.clone()); + } + + let payload = serde_json::to_string(&sendable).unwrap(); + + publish_message(self, payload.into(), args).await; + } + } + } + } +} diff --git a/crates/daemons/pushd/src/consumers/inbound/fr_received.rs b/crates/daemons/pushd/src/consumers/inbound/fr_received.rs new file mode 100644 index 000000000..c52dfec1b --- /dev/null +++ b/crates/daemons/pushd/src/consumers/inbound/fr_received.rs @@ -0,0 +1,121 @@ +use std::collections::HashMap; + +use crate::consumers::inbound::internal::*; +use amqprs::{ + channel::{BasicPublishArguments, Channel}, + connection::Connection, + consumer::AsyncConsumer, + BasicProperties, Deliver, +}; +use async_trait::async_trait; +use log::debug; +use revolt_database::{events::rabbit::*, Database}; + +pub struct FRReceivedConsumer { + #[allow(dead_code)] + db: Database, + authifier_db: authifier::Database, + conn: Option, + channel: Option, +} + +impl Channeled for FRReceivedConsumer { + fn get_connection(&self) -> Option<&Connection> { + if self.conn.is_none() { + None + } else { + Some(self.conn.as_ref().unwrap()) + } + } + + fn get_channel(&self) -> Option<&Channel> { + if self.channel.is_none() { + None + } else { + Some(self.channel.as_ref().unwrap()) + } + } + + fn set_connection(&mut self, conn: Connection) { + self.conn = Some(conn); + } + + fn set_channel(&mut self, channel: Channel) { + self.channel = Some(channel) + } +} + +impl FRReceivedConsumer { + pub fn new(db: Database, authifier_db: authifier::Database) -> FRReceivedConsumer { + FRReceivedConsumer { + db, + authifier_db, + conn: None, + channel: None, + } + } +} + +#[allow(unused_variables)] +#[async_trait] +impl AsyncConsumer for FRReceivedConsumer { + /// This consumer handles delegating messages into their respective platform queues. + async fn consume( + &mut self, + channel: &Channel, + deliver: Deliver, + basic_properties: BasicProperties, + content: Vec, + ) { + let content = String::from_utf8(content).unwrap(); + let payload: FRReceivedPayload = serde_json::from_str(content.as_str()).unwrap(); + + debug!("Received FR received event"); + + if let Ok(sessions) = self.authifier_db.find_sessions(&payload.user).await { + let config = revolt_config::config().await; + for session in sessions { + if let Some(sub) = session.subscription { + let mut sendable = PayloadToService { + notification: PayloadKind::FRReceived(payload.clone()), + token: sub.auth, + user_id: session.user_id, + session_id: session.id, + extras: HashMap::new(), + }; + + let args: BasicPublishArguments; + + if sub.endpoint == "apn" { + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.apn.queue.as_str(), + ) + .finish(); + } else if sub.endpoint == "fcm" { + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.fcm.queue.as_str(), + ) + .finish(); + } else { + // web push (vapid) + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.vapid.queue.as_str(), + ) + .finish(); + sendable.extras.insert("p265dh".to_string(), sub.p256dh); + sendable + .extras + .insert("endpoint".to_string(), sub.endpoint.clone()); + } + + let payload = serde_json::to_string(&sendable).unwrap(); + + publish_message(self, payload.into(), args).await; + } + } + } + } +} diff --git a/crates/daemons/pushd/src/consumers/inbound/generic.rs b/crates/daemons/pushd/src/consumers/inbound/generic.rs new file mode 100644 index 000000000..58070f664 --- /dev/null +++ b/crates/daemons/pushd/src/consumers/inbound/generic.rs @@ -0,0 +1,127 @@ +use std::collections::HashMap; + +use crate::consumers::inbound::internal::*; +use amqprs::{ + channel::{BasicPublishArguments, Channel}, + connection::Connection, + consumer::AsyncConsumer, + BasicProperties, Deliver, +}; +use async_trait::async_trait; +use log::debug; +use revolt_database::{events::rabbit::*, Database}; + +pub struct GenericConsumer { + #[allow(dead_code)] + db: Database, + authifier_db: authifier::Database, + conn: Option, + channel: Option, +} + +impl Channeled for GenericConsumer { + fn get_connection(&self) -> Option<&Connection> { + if self.conn.is_none() { + None + } else { + Some(self.conn.as_ref().unwrap()) + } + } + + fn get_channel(&self) -> Option<&Channel> { + if self.channel.is_none() { + None + } else { + Some(self.channel.as_ref().unwrap()) + } + } + + fn set_connection(&mut self, conn: Connection) { + self.conn = Some(conn); + } + + fn set_channel(&mut self, channel: Channel) { + self.channel = Some(channel) + } +} + +impl GenericConsumer { + pub fn new(db: Database, authifier_db: authifier::Database) -> GenericConsumer { + GenericConsumer { + db, + authifier_db, + conn: None, + channel: None, + } + } +} + +#[allow(unused_variables)] +#[async_trait] +impl AsyncConsumer for GenericConsumer { + /// This consumer handles delegating messages into their respective platform queues. + async fn consume( + &mut self, + channel: &Channel, + deliver: Deliver, + basic_properties: BasicProperties, + content: Vec, + ) { + let content = String::from_utf8(content).unwrap(); + let payload: MessageSentPayload = serde_json::from_str(content.as_str()).unwrap(); + + debug!("Received message event on origin"); + + if let Ok(sessions) = self + .authifier_db + .find_sessions_with_subscription(&payload.users) + .await + { + let config = revolt_config::config().await; + for session in sessions { + if let Some(sub) = session.subscription { + let mut sendable = PayloadToService { + notification: PayloadKind::MessageNotification( + payload.notification.clone(), + ), + token: sub.auth, + user_id: session.user_id, + session_id: session.id, + extras: HashMap::new(), + }; + + let args: BasicPublishArguments; + + if sub.endpoint == "apn" { + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.apn.queue.as_str(), + ) + .finish(); + } else if sub.endpoint == "fcm" { + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.fcm.queue.as_str(), + ) + .finish(); + } else { + // web push (vapid) + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.vapid.queue.as_str(), + ) + .finish(); + sendable.extras.insert("p265dh".to_string(), sub.p256dh); + sendable + .extras + .insert("endpoint".to_string(), sub.endpoint.clone()); + } + + let payload = serde_json::to_string(&sendable).unwrap(); + + publish_message(self, payload.into(), args).await; + } + } + } + } +} diff --git a/crates/daemons/pushd/src/consumers/inbound/internal.rs b/crates/daemons/pushd/src/consumers/inbound/internal.rs new file mode 100644 index 000000000..387c08b52 --- /dev/null +++ b/crates/daemons/pushd/src/consumers/inbound/internal.rs @@ -0,0 +1,53 @@ +use amqprs::{ + channel::{BasicPublishArguments, Channel}, + connection::{Connection, OpenConnectionArguments}, + BasicProperties, +}; +use log::{debug, warn}; + +pub(crate) trait Channeled { + #[allow(unused)] + fn get_connection(&self) -> Option<&Connection>; + fn get_channel(&self) -> Option<&Channel>; + fn set_connection(&mut self, conn: Connection); + fn set_channel(&mut self, channel: Channel); +} + +pub(crate) async fn make_channel(consumer: &mut T) { + let config = revolt_config::config().await; + + let args = OpenConnectionArguments::new( + &config.rabbit.host, + config.rabbit.port, + &config.rabbit.username, + &config.rabbit.password, + ); + let conn = amqprs::connection::Connection::open(&args).await.unwrap(); + + let channel = conn.open_channel(None).await.unwrap(); + + consumer.set_connection(conn); + consumer.set_channel(channel); +} + +pub(crate) async fn publish_message( + consumer: &mut T, + payload: Vec, + args: BasicPublishArguments, +) { + let routing_key = &args.routing_key.clone(); + let mut channel = consumer.get_channel(); + if channel.is_none() { + make_channel(consumer).await; + channel = consumer.get_channel(); + } + + if let Some(chnl) = channel { + chnl.basic_publish(BasicProperties::default(), payload.clone(), args.clone()) + .await + .unwrap(); + debug!("Sent message to queue for target {}", routing_key); + } else { + warn!("Failed to unwrap channel (including attempt to make a channel)!") + } +} diff --git a/crates/daemons/pushd/src/consumers/inbound/message.rs b/crates/daemons/pushd/src/consumers/inbound/message.rs new file mode 100644 index 000000000..99e61cbd5 --- /dev/null +++ b/crates/daemons/pushd/src/consumers/inbound/message.rs @@ -0,0 +1,127 @@ +use std::collections::HashMap; + +use crate::consumers::inbound::internal::*; +use amqprs::{ + channel::{BasicPublishArguments, Channel}, + connection::Connection, + consumer::AsyncConsumer, + BasicProperties, Deliver, +}; +use async_trait::async_trait; +use log::debug; +use revolt_database::{events::rabbit::*, Database}; + +pub struct MessageConsumer { + #[allow(dead_code)] + db: Database, + authifier_db: authifier::Database, + conn: Option, + channel: Option, +} + +impl Channeled for MessageConsumer { + fn get_connection(&self) -> Option<&Connection> { + if self.conn.is_none() { + None + } else { + Some(self.conn.as_ref().unwrap()) + } + } + + fn get_channel(&self) -> Option<&Channel> { + if self.channel.is_none() { + None + } else { + Some(self.channel.as_ref().unwrap()) + } + } + + fn set_connection(&mut self, conn: Connection) { + self.conn = Some(conn); + } + + fn set_channel(&mut self, channel: Channel) { + self.channel = Some(channel) + } +} + +impl MessageConsumer { + pub fn new(db: Database, authifier_db: authifier::Database) -> MessageConsumer { + MessageConsumer { + db, + authifier_db, + conn: None, + channel: None, + } + } +} + +#[allow(unused_variables)] +#[async_trait] +impl AsyncConsumer for MessageConsumer { + /// This consumer handles delegating messages into their respective platform queues. + async fn consume( + &mut self, + channel: &Channel, + deliver: Deliver, + basic_properties: BasicProperties, + content: Vec, + ) { + let content = String::from_utf8(content).unwrap(); + let payload: MessageSentPayload = serde_json::from_str(content.as_str()).unwrap(); + + debug!("Received message event on origin"); + + if let Ok(sessions) = self + .authifier_db + .find_sessions_with_subscription(&payload.users) + .await + { + let config = revolt_config::config().await; + for session in sessions { + if let Some(sub) = session.subscription { + let mut sendable = PayloadToService { + notification: PayloadKind::MessageNotification( + payload.notification.clone(), + ), + token: sub.auth, + user_id: session.user_id, + session_id: session.id, + extras: HashMap::new(), + }; + + let args: BasicPublishArguments; + + if sub.endpoint == "apn" { + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.apn.queue.as_str(), + ) + .finish(); + } else if sub.endpoint == "fcm" { + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.fcm.queue.as_str(), + ) + .finish(); + } else { + // web push (vapid) + args = BasicPublishArguments::new( + config.pushd.exchange.as_str(), + config.pushd.vapid.queue.as_str(), + ) + .finish(); + sendable.extras.insert("p265dh".to_string(), sub.p256dh); + sendable + .extras + .insert("endpoint".to_string(), sub.endpoint.clone()); + } + + let payload = serde_json::to_string(&sendable).unwrap(); + + publish_message(self, payload.into(), args).await; + } + } + } + } +} diff --git a/crates/daemons/pushd/src/consumers/inbound/mod.rs b/crates/daemons/pushd/src/consumers/inbound/mod.rs new file mode 100644 index 000000000..a340143d7 --- /dev/null +++ b/crates/daemons/pushd/src/consumers/inbound/mod.rs @@ -0,0 +1,6 @@ +pub mod ack; +pub mod fr_accepted; +pub mod fr_received; +pub mod generic; +mod internal; +pub mod message; diff --git a/crates/daemons/pushd/src/consumers/mod.rs b/crates/daemons/pushd/src/consumers/mod.rs new file mode 100644 index 000000000..5756441b7 --- /dev/null +++ b/crates/daemons/pushd/src/consumers/mod.rs @@ -0,0 +1,2 @@ +pub mod inbound; +pub mod outbound; diff --git a/crates/daemons/pushd/src/consumers/outbound/apn.rs b/crates/daemons/pushd/src/consumers/outbound/apn.rs new file mode 100644 index 000000000..3b6ac3a71 --- /dev/null +++ b/crates/daemons/pushd/src/consumers/outbound/apn.rs @@ -0,0 +1,338 @@ +use std::{borrow::Cow, collections::BTreeMap, io::Cursor}; + +use amqprs::{channel::Channel as AmqpChannel, consumer::AsyncConsumer, BasicProperties, Deliver}; +use async_trait::async_trait; +use base64::{ + engine::{self}, + Engine as _, +}; +use revolt_a2::{ + request::{ + notification::{DefaultAlert, NotificationOptions}, + payload::{APSAlert, APSSound, Payload, PayloadLike, APS}, + }, + Client, ClientConfig, Endpoint, Error, ErrorBody, ErrorReason, Priority, PushType, Response, +}; +use revolt_database::{events::rabbit::*, Database}; +use revolt_models::v0::{Channel, Message, PushNotification}; +use serde::Serialize; + +// region: payload + +#[derive(Serialize, Debug)] +struct MessagePayload<'a> { + aps: APS<'a>, + #[serde(skip_serializing)] + options: NotificationOptions<'a>, + #[serde(skip_serializing)] + device_token: &'a str, + + message: &'a Message, + url: &'a str, + #[serde(rename = "camelCase")] + author_avatar: &'a str, + #[serde(rename = "camelCase")] + author_display_name: &'a str, + #[serde(rename = "camelCase")] + channel_name: &'a str, +} + +impl<'a> PayloadLike for MessagePayload<'a> { + fn get_device_token(&self) -> &'a str { + self.device_token + } + fn get_options(&self) -> &NotificationOptions { + &self.options + } +} + +// region: consumer + +pub struct ApnsOutboundConsumer { + #[allow(dead_code)] + db: Database, + client: Client, +} + +impl ApnsOutboundConsumer { + fn format_title(&self, notification: &PushNotification) -> String { + // ideally this changes depending on context + // in a server, it would look like "Sendername, #channelname in servername" + // in a group, it would look like "Sendername in groupname" + // in a dm it should just be "Sendername". + // not sure how feasible all those are given the PushNotification object as it currently stands. + + match ¬ification.channel { + Channel::DirectMessage { .. } => notification.author.clone(), + Channel::Group { name, .. } => format!("{}, #{}", notification.author, name), + Channel::TextChannel { name, .. } | Channel::VoiceChannel { name, .. } => { + format!("{} in #{}", notification.author, name) + } + _ => "Unknown".to_string(), + } + } + + async fn get_badge_count(&self, user: &str) -> Option { + if let Ok(unreads) = self.db.fetch_unread_mentions(user).await { + let mut mention_count = 0; + for channel in unreads { + if let Some(mentions) = channel.mentions { + mention_count += mentions.len() as u32 + } + } + + println!("Got badge count for APN: {}", mention_count); + + return Some(mention_count); + } + None + } +} + +impl ApnsOutboundConsumer { + pub async fn new(db: Database) -> Result { + let config = revolt_config::config().await; + + if config.pushd.apn.pkcs8.is_empty() + || config.pushd.apn.key_id.is_empty() + || config.pushd.apn.team_id.is_empty() + { + return Err("Missing APN keys."); + } + + let endpoint = if config.pushd.apn.sandbox { + Endpoint::Sandbox + } else { + Endpoint::Production + }; + + let pkcs8 = engine::general_purpose::STANDARD + .decode(config.pushd.apn.pkcs8.clone()) + .expect("valid `pcks8`"); + + let client_config = ClientConfig::new(endpoint); + + let client = Client::token( + &mut Cursor::new(pkcs8), + config.pushd.apn.key_id.clone(), + config.pushd.apn.team_id.clone(), + client_config, + ) + .expect("could not create APN client"); + + Ok(ApnsOutboundConsumer { db, client }) + } +} + +#[allow(unused_variables)] +#[async_trait] +impl AsyncConsumer for ApnsOutboundConsumer { + async fn consume( + &mut self, + channel: &AmqpChannel, + deliver: Deliver, + basic_properties: BasicProperties, + content: Vec, + ) { + let content = String::from_utf8(content).unwrap(); + let payload: PayloadToService = serde_json::from_str(content.as_str()).unwrap(); + + let payload_options = NotificationOptions { + apns_id: None, + apns_push_type: Some(PushType::Alert), + apns_expiration: None, + apns_priority: Some(Priority::High), + apns_topic: Some("chat.revolt.app"), + apns_collapse_id: None, + }; + + let resp: Result; + + match payload.notification { + PayloadKind::FRReceived(alert) => { + let loc_args = vec![Cow::from( + alert + .from_user + .display_name + .or(Some(format!( + "{}#{}", + alert.from_user.username, alert.from_user.discriminator + ))) + .clone() + .unwrap(), + )]; + + let apn_payload = Payload { + aps: APS { + alert: Some(APSAlert::Default(DefaultAlert { + title: None, + subtitle: None, + body: None, + title_loc_key: None, + title_loc_args: None, + action_loc_key: None, + loc_key: Some("push.fr.received"), + loc_args: Some(loc_args), + launch_image: None, + })), + badge: self.get_badge_count(&payload.user_id).await, + sound: Some(APSSound::Sound("default")), + thread_id: None, + content_available: None, + category: None, + mutable_content: Some(1), + url_args: None, + }, + device_token: &payload.token, + options: payload_options.clone(), + data: BTreeMap::new(), + }; + + resp = self.client.send(apn_payload).await; + } + + PayloadKind::FRAccepted(alert) => { + let loc_args = vec![Cow::from( + alert + .accepted_user + .display_name + .or(Some(format!( + "{}#{}", + alert.accepted_user.username, alert.accepted_user.discriminator + ))) + .clone() + .unwrap(), + )]; + + let apn_payload = Payload { + aps: APS { + alert: Some(APSAlert::Default(DefaultAlert { + title: None, + subtitle: None, + body: None, + title_loc_key: None, + title_loc_args: None, + action_loc_key: None, + loc_key: Some("push.fr.accepted"), + loc_args: Some(loc_args), + launch_image: None, + })), + badge: self.get_badge_count(&payload.user_id).await, + sound: Some(APSSound::Sound("default")), + thread_id: None, + content_available: None, + category: None, + mutable_content: Some(1), + url_args: None, + }, + device_token: &payload.token, + options: payload_options.clone(), + data: BTreeMap::new(), + }; + + resp = self.client.send(apn_payload).await; + } + PayloadKind::Generic(alert) => { + let apn_payload = Payload { + aps: APS { + alert: Some(APSAlert::Default(DefaultAlert { + title: Some(&alert.title), + subtitle: None, + body: Some(&alert.body), + title_loc_key: None, + title_loc_args: None, + action_loc_key: None, + loc_key: None, + loc_args: None, + launch_image: None, + })), + badge: self.get_badge_count(&payload.user_id).await, + sound: Some(APSSound::Sound("default")), + thread_id: None, + content_available: None, + category: None, + mutable_content: Some(1), + url_args: None, + }, + device_token: &payload.token, + options: payload_options.clone(), + data: BTreeMap::new(), + }; + + resp = self.client.send(apn_payload).await; + } + + PayloadKind::MessageNotification(alert) => { + let title = self.format_title(&alert); + let apn_payload = MessagePayload { + aps: APS { + alert: Some(APSAlert::Default(DefaultAlert { + title: Some(&title), + subtitle: None, + body: Some(&alert.body), + title_loc_key: None, + title_loc_args: None, + action_loc_key: None, + loc_key: None, + loc_args: None, + launch_image: None, + })), + badge: self.get_badge_count(&payload.user_id).await, + sound: Some(APSSound::Sound("default")), + thread_id: Some(alert.channel.id()), + content_available: None, + category: None, + mutable_content: Some(1), + url_args: None, + }, + device_token: &payload.token, + options: payload_options.clone(), + message: &alert.message, + url: &alert.url, + author_avatar: &alert.icon, + author_display_name: &alert.author, + channel_name: alert.channel.name().unwrap_or(&title), + }; + + resp = self.client.send(apn_payload).await; + } + PayloadKind::BadgeUpdate(badge) => { + let apn_payload = Payload { + aps: APS { + badge: Some(badge as u32), + ..Default::default() + }, + device_token: &payload.token, + options: payload_options.clone(), + data: BTreeMap::new(), + }; + + resp = self.client.send(apn_payload).await; + } + } + + if let Err(err) = resp { + match err { + Error::ResponseError(Response { + error: + Some(ErrorBody { + reason: ErrorReason::BadDeviceToken | ErrorReason::Unregistered, + .. + }), + .. + }) => { + if let Err(err) = self + .db + .remove_push_subscription_by_session_id(&payload.session_id) + .await + { + revolt_config::capture_error(&err); + } + } + err => { + revolt_config::capture_error(&err); + } + } + } + } +} diff --git a/crates/daemons/pushd/src/consumers/outbound/fcm.rs b/crates/daemons/pushd/src/consumers/outbound/fcm.rs new file mode 100644 index 000000000..61700bbb9 --- /dev/null +++ b/crates/daemons/pushd/src/consumers/outbound/fcm.rs @@ -0,0 +1,199 @@ +use std::{collections::HashMap, time::Duration}; + +use amqprs::{channel::Channel as AmqpChannel, consumer::AsyncConsumer, BasicProperties, Deliver}; + +use async_trait::async_trait; +use fcm_v1::{ + android::AndroidConfig, + auth::{Authenticator, ServiceAccountKey}, + message::{Message, Notification}, + Client, Error as FcmError, +}; +use revolt_database::{events::rabbit::*, Database}; +use revolt_models::v0::{Channel, PushNotification}; +use serde_json::Value; + +pub struct FcmOutboundConsumer { + db: Database, + client: Client, +} + +impl FcmOutboundConsumer { + fn format_title(&self, notification: &PushNotification) -> String { + // ideally this changes depending on context + // in a server, it would look like "Sendername, #channelname in servername" + // in a group, it would look like "Sendername in groupname" + // in a dm it should just be "Sendername". + // not sure how feasible all those are given the PushNotification object as it currently stands. + + match ¬ification.channel { + Channel::DirectMessage { .. } => notification.author.clone(), + Channel::Group { name, .. } => format!("{}, #{}", notification.author, name), + Channel::TextChannel { name, .. } | Channel::VoiceChannel { name, .. } => { + format!("{} in #{}", notification.author, name) + } + _ => "Unknown".to_string(), + } + } +} + +impl FcmOutboundConsumer { + pub async fn new(db: Database) -> Result { + let config = revolt_config::config().await; + + Ok(FcmOutboundConsumer { + db, + client: Client::new( + Authenticator::service_account::<&str>(ServiceAccountKey { + key_type: Some(config.pushd.fcm.key_type), + project_id: Some(config.pushd.fcm.project_id.clone()), + private_key_id: Some(config.pushd.fcm.private_key_id), + private_key: config.pushd.fcm.private_key, + client_email: config.pushd.fcm.client_email, + client_id: Some(config.pushd.fcm.client_id), + auth_uri: Some(config.pushd.fcm.auth_uri), + token_uri: config.pushd.fcm.token_uri, + auth_provider_x509_cert_url: Some(config.pushd.fcm.auth_provider_x509_cert_url), + client_x509_cert_url: Some(config.pushd.fcm.client_x509_cert_url), + }) + .await + .unwrap(), + config.pushd.fcm.project_id, + false, + Duration::from_secs(5), + ), + }) + } +} + +#[allow(unused_variables)] +#[async_trait] +impl AsyncConsumer for FcmOutboundConsumer { + async fn consume( + &mut self, + channel: &AmqpChannel, + deliver: Deliver, + basic_properties: BasicProperties, + content: Vec, + ) { + let content = String::from_utf8(content).unwrap(); + let payload: PayloadToService = serde_json::from_str(content.as_str()).unwrap(); + + let config = revolt_config::config().await; + + #[allow(clippy::needless_late_init)] + let resp: Result; + + match payload.notification { + PayloadKind::FRReceived(alert) => { + let name = alert + .from_user + .display_name + .or(Some(format!( + "{}#{}", + alert.from_user.username, alert.from_user.discriminator + ))) + .clone() + .unwrap(); + + let mut data = HashMap::new(); + data.insert( + "type".to_string(), + Value::String("push.fr.receive".to_string()), + ); + data.insert("id".to_string(), Value::String(alert.from_user.id)); + data.insert("username".to_string(), Value::String(name)); + + let msg = Message { + token: Some(payload.token), + data: Some(data), + ..Default::default() + }; + + resp = self.client.send(&msg).await; + } + + PayloadKind::FRAccepted(alert) => { + let name = alert + .accepted_user + .display_name + .or(Some(format!( + "{}#{}", + alert.accepted_user.username, alert.accepted_user.discriminator + ))) + .clone() + .unwrap(); + + let mut data: HashMap = HashMap::new(); + data.insert( + "type".to_string(), + Value::String("push.fr.accept".to_string()), + ); + data.insert("id".to_string(), Value::String(alert.accepted_user.id)); + data.insert("username".to_string(), Value::String(name)); + + let msg = Message { + token: Some(payload.token), + data: Some(data), + ..Default::default() + }; + + resp = self.client.send(&msg).await; + } + PayloadKind::Generic(alert) => { + let msg = Message { + token: Some(payload.token), + notification: Some(Notification { + title: Some(alert.title), + body: Some(alert.body), + image: alert.icon, + }), + ..Default::default() + }; + + resp = self.client.send(&msg).await; + } + + PayloadKind::MessageNotification(alert) => { + let title = self.format_title(&alert); + + let msg = Message { + token: Some(payload.token), + notification: Some(Notification { + title: Some(title), + body: Some(alert.body), + image: Some(alert.icon), + }), + android: Some(AndroidConfig { + collapse_key: Some(alert.tag), + ..Default::default() + }), + ..Default::default() + }; + + resp = self.client.send(&msg).await; + } + + PayloadKind::BadgeUpdate(_) => { + panic!("FCM cannot handle badge updates, and they should not be sent here.") + } + } + + if let Err(err) = resp { + match err { + FcmError::Auth => { + if let Err(err) = self + .db + .remove_push_subscription_by_session_id(&payload.session_id) + .await + { + revolt_config::capture_error(&err); + } + } + err => { + revolt_config::capture_error(&err); + } + } + } + } +} diff --git a/crates/daemons/pushd/src/consumers/outbound/mod.rs b/crates/daemons/pushd/src/consumers/outbound/mod.rs new file mode 100644 index 000000000..cfff0a268 --- /dev/null +++ b/crates/daemons/pushd/src/consumers/outbound/mod.rs @@ -0,0 +1,3 @@ +pub mod apn; +pub mod fcm; +pub mod vapid; diff --git a/crates/daemons/pushd/src/consumers/outbound/vapid.rs b/crates/daemons/pushd/src/consumers/outbound/vapid.rs new file mode 100644 index 000000000..fb735d5eb --- /dev/null +++ b/crates/daemons/pushd/src/consumers/outbound/vapid.rs @@ -0,0 +1,149 @@ +use std::collections::HashMap; + +use amqprs::{channel::Channel as AmqpChannel, consumer::AsyncConsumer, BasicProperties, Deliver}; + +use async_trait::async_trait; +use base64::{ + engine::{self}, + Engine as _, +}; +use revolt_database::{events::rabbit::*, Database}; +// use revolt_models::v0::{Channel, PushNotification}; +use web_push::{ + ContentEncoding, IsahcWebPushClient, SubscriptionInfo, SubscriptionKeys, VapidSignatureBuilder, + WebPushClient, WebPushError, WebPushMessageBuilder, +}; + +pub struct VapidOutboundConsumer { + db: Database, + client: IsahcWebPushClient, + pkey: Vec, +} + +impl VapidOutboundConsumer { + pub async fn new(db: Database) -> Result { + let config = revolt_config::config().await; + + if config.pushd.vapid.private_key.is_empty() | config.pushd.vapid.public_key.is_empty() { + return Err("No Vapid keys present"); + } + + let web_push_private_key = engine::general_purpose::URL_SAFE_NO_PAD + .decode(config.pushd.vapid.private_key) + .expect("valid `VAPID_PRIVATE_KEY`"); + + Ok(VapidOutboundConsumer { + db, + client: IsahcWebPushClient::new().unwrap(), + pkey: web_push_private_key, + }) + } +} + +#[allow(unused_variables)] +#[async_trait] +impl AsyncConsumer for VapidOutboundConsumer { + async fn consume( + &mut self, + channel: &AmqpChannel, + deliver: Deliver, + basic_properties: BasicProperties, + content: Vec, + ) { + let content = String::from_utf8(content).unwrap(); + let payload: PayloadToService = serde_json::from_str(content.as_str()).unwrap(); + + let config = revolt_config::config().await; + + let subscription = SubscriptionInfo { + endpoint: payload.extras.get("endpoint").unwrap().clone(), + keys: SubscriptionKeys { + auth: payload.token, + p256dh: payload.extras.get("p256dh").unwrap().clone(), + }, + }; + + #[allow(clippy::needless_late_init)] + let payload_body: String; + + match payload.notification { + PayloadKind::FRReceived(alert) => { + let name = alert + .from_user + .display_name + .or(Some(format!( + "{}#{}", + alert.from_user.username, alert.from_user.discriminator + ))) + .clone() + .unwrap(); + + let mut body = HashMap::new(); + body.insert("body", format!("{} sent you a friend request", name)); + + payload_body = serde_json::to_string(&body).unwrap(); + } + PayloadKind::FRAccepted(alert) => { + let name = alert + .accepted_user + .display_name + .or(Some(format!( + "{}#{}", + alert.accepted_user.username, alert.accepted_user.discriminator + ))) + .clone() + .unwrap(); + + let mut body = HashMap::new(); + body.insert("body", format!("{} accepted your friend request", name)); + + payload_body = serde_json::to_string(&body).unwrap(); + } + PayloadKind::Generic(alert) => { + payload_body = serde_json::to_string(&alert).unwrap(); + } + PayloadKind::MessageNotification(alert) => { + payload_body = serde_json::to_string(&alert).unwrap(); + } + PayloadKind::BadgeUpdate(_) => { + panic!("Vapid cannot handle badge updates, and they should not be sent here.") + } + } + + match VapidSignatureBuilder::from_pem(std::io::Cursor::new(&self.pkey), &subscription) { + Ok(sig_builder) => match sig_builder.build() { + Ok(signature) => { + let mut builder = WebPushMessageBuilder::new(&subscription); + builder.set_vapid_signature(signature); + + builder.set_payload(ContentEncoding::AesGcm, payload_body.as_bytes()); + + match builder.build() { + Ok(msg) => { + if let Err(err) = self.client.send(msg).await { + if err == WebPushError::Unauthorized { + if let Err(err) = self + .db + .remove_push_subscription_by_session_id(&payload.session_id) + .await + { + revolt_config::capture_error(&err); + } + } + } + } + Err(err) => { + revolt_config::capture_error(&err); + } + } + } + Err(err) => { + revolt_config::capture_error(&err); + } + }, + Err(err) => { + revolt_config::capture_error(&err); + } + } + } +} diff --git a/crates/daemons/pushd/src/main.rs b/crates/daemons/pushd/src/main.rs new file mode 100644 index 000000000..c63b8d9a1 --- /dev/null +++ b/crates/daemons/pushd/src/main.rs @@ -0,0 +1,233 @@ +use amqprs::{ + channel::{ + BasicConsumeArguments, Channel, ExchangeDeclareArguments, QueueBindArguments, + QueueDeclareArguments, + }, + connection::{Connection, OpenConnectionArguments}, + consumer::AsyncConsumer, + FieldTable, +}; +use revolt_config::{config, Settings}; +use tokio::sync::Notify; + +mod consumers; +use consumers::{ + inbound::{ + ack::AckConsumer, fr_accepted::FRAcceptedConsumer, fr_received::FRReceivedConsumer, + generic::GenericConsumer, message::MessageConsumer, + }, + outbound::{apn::ApnsOutboundConsumer, fcm::FcmOutboundConsumer, vapid::VapidOutboundConsumer}, +}; + +#[tokio::main(flavor = "multi_thread", worker_threads = 2)] +async fn main() { + let config = config().await; + + // Setup database + let db = revolt_database::DatabaseInfo::Auto.connect().await.unwrap(); + let authifier: authifier::Database; + + if let Some(client) = match &db { + revolt_database::Database::Reference(_) => None, + revolt_database::Database::MongoDb(mongo) => Some(mongo), + } { + authifier = + authifier::Database::MongoDb(authifier::database::MongoDb(client.database("revolt"))); + } else { + panic!("Mongo is not in use, can't connect via authifier!") + } + + let mut connections: Vec<(Channel, Connection)> = Vec::new(); + + // An explainer of how this works: + // The inbound connections are on separate routing keys, such that they only receive the proper payload + // from their respective api (prod or test). + // However, the outbound queues that go to the services are routed to receive from both, so that messages + // sent from beta are still notified on prod, and vice versa. + + // This'll require some interesting shimming if we need to add more events once this is in prod (different payloads between prod and test), + // but that sounds like a problem for future us. + + // inbound: generic + connections.push( + make_queue_and_consume( + &config, + &config.pushd.generic_queue, + config.pushd.get_generic_routing_key().as_str(), + None, + GenericConsumer::new(db.clone(), authifier.clone()), + ) + .await, + ); + + // inbound: messages + connections.push( + make_queue_and_consume( + &config, + &config.pushd.message_queue, + config.pushd.get_message_routing_key().as_str(), + None, + MessageConsumer::new(db.clone(), authifier.clone()), + ) + .await, + ); + + // inbound: FR received + connections.push( + make_queue_and_consume( + &config, + &config.pushd.fr_received_queue, + config.pushd.get_fr_received_routing_key().as_str(), + None, + FRReceivedConsumer::new(db.clone(), authifier.clone()), + ) + .await, + ); + + // inbound: FR accepted + connections.push( + make_queue_and_consume( + &config, + &config.pushd.fr_accepted_queue, + config.pushd.get_fr_accepted_routing_key().as_str(), + None, + FRAcceptedConsumer::new(db.clone(), authifier.clone()), + ) + .await, + ); + + if !config.pushd.apn.pkcs8.is_empty() { + connections.push( + make_queue_and_consume( + &config, + &config.pushd.apn.queue, + &config.pushd.apn.queue, + None, + ApnsOutboundConsumer::new(db.clone()).await.unwrap(), + ) + .await, + ); + + let mut table = FieldTable::new(); + table.insert("x-message-deduplication".try_into().unwrap(), "true".into()); + + connections.push( + make_queue_and_consume( + &config, + &config.pushd.ack_queue, + &config.pushd.ack_queue, + Some(table), + AckConsumer::new(db.clone(), authifier.clone()), + ) + .await, + ); + } + + if !config.pushd.fcm.auth_uri.is_empty() { + connections.push( + make_queue_and_consume( + &config, + &config.pushd.fcm.queue, + &config.pushd.fcm.queue, + None, + FcmOutboundConsumer::new(db.clone()).await.unwrap(), + ) + .await, + ) + } + + if !config.pushd.vapid.public_key.is_empty() { + connections.push( + make_queue_and_consume( + &config, + &config.pushd.vapid.queue, + &config.pushd.vapid.queue, + None, + VapidOutboundConsumer::new(db.clone()).await.unwrap(), + ) + .await, + ) + } + + let guard = Notify::new(); + guard.notified().await; + + for (channel, conn) in connections { + channel.close().await.expect("Unable to close channel"); + conn.close().await.expect("Unable to close connection"); + } +} + +async fn make_queue_and_consume( + config: &Settings, + queue_name: &str, + routing_key: &str, + queue_args: Option, + consumer: F, +) -> (Channel, Connection) +where + F: AsyncConsumer + Send + 'static, +{ + let connection = Connection::open(&OpenConnectionArguments::new( + &config.rabbit.host, + config.rabbit.port, + &config.rabbit.username, + &config.rabbit.password, + )) + .await + .unwrap(); + + let channel = connection.open_channel(None).await.unwrap(); + + channel + .exchange_declare( + ExchangeDeclareArguments::new(&config.pushd.exchange, "direct") + .durable(true) + .finish(), + ) + .await + .expect("Failed to declare pushd exchange"); + + let mut queue_name = queue_name.to_string(); + + if config.pushd.production { + queue_name += "-prd"; + } else { + queue_name += "-tst"; + } + + let queue_name = queue_name.as_str(); + + let mut args = QueueDeclareArguments::new(queue_name); + args.durable(true); + + if let Some(arg) = queue_args { + args.arguments(arg); + } + + let args = args.finish(); + _ = channel.queue_declare(args).await.unwrap().unwrap(); + + channel + .queue_bind(QueueBindArguments::new( + queue_name, + &config.pushd.exchange, + routing_key, + )) + .await + .expect( + "This probably means the revolt.notifications exchange does not exist in rabbitmq!", + ); + + let args = BasicConsumeArguments::new(queue_name, "") + .manual_ack(false) + .finish(); + + channel.basic_consume(consumer, args).await.unwrap(); + log::info!( + "Consuming routing key {} as queue {}", + routing_key, + queue_name + ); + (channel, connection) +} diff --git a/crates/delta/Cargo.toml b/crates/delta/Cargo.toml index 05857ae07..d4ecc97ee 100644 --- a/crates/delta/Cargo.toml +++ b/crates/delta/Cargo.toml @@ -52,20 +52,21 @@ async-std = { version = "1.8.0", features = [ lettre = "0.10.0-alpha.4" # web -rocket = { version = "0.5.0-rc.2", default-features = false, features = [ - "json", -] } -rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", rev = "c17e8145baa4790319fdb6a473e465b960f55e7c" } +rocket = { version = "0.5.1", default-features = false, features = ["json"] } +rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", rev = "072d90359b23e9b291df6b672c07c93de9c46011" } rocket_empty = { version = "0.1.1", features = ["schema"] } -rocket_authifier = { version = "1.0.8" } +rocket_authifier = { version = "1.0.9" } rocket_prometheus = "0.10.0-rc.3" # spec generation schemars = "0.8.8" -revolt_rocket_okapi = { version = "0.9.1", features = ["swagger"] } +revolt_rocket_okapi = { version = "0.10.0", features = ["swagger"] } + +# rabbit +amqprs = { version = "1.7.0" } # core -authifier = "1.0.8" +authifier = "1.0.9" revolt-config = { path = "../core/config" } revolt-database = { path = "../core/database", features = [ "rocket-impl", diff --git a/crates/delta/src/main.rs b/crates/delta/src/main.rs index 958254cd6..c25cf7fa8 100644 --- a/crates/delta/src/main.rs +++ b/crates/delta/src/main.rs @@ -10,12 +10,17 @@ pub mod util; use revolt_config::config; use revolt_database::events::client::EventV1; +use revolt_database::AMQP; use rocket::{Build, Rocket}; use rocket_cors::{AllowedOrigins, CorsOptions}; use rocket_prometheus::PrometheusMetrics; use std::net::Ipv4Addr; use std::str::FromStr; +use amqprs::{ + channel::ExchangeDeclareArguments, + connection::{Connection, OpenConnectionArguments}, +}; use async_std::channel::unbounded; use authifier::AuthifierEvent; use rocket::data::ToByteUnit; @@ -32,7 +37,7 @@ pub async fn web() -> Rocket { db.migrate_database().await.unwrap(); // Setup Authifier event channel - let (sender, receiver) = unbounded(); + let (_, receiver) = unbounded(); // Setup Authifier let authifier = db.clone().to_authifier().await; @@ -53,9 +58,6 @@ pub async fn web() -> Rocket { } }); - // Launch background task workers - revolt_database::tasks::start_workers(db.clone(), authifier.database.clone()); - // Configure CORS let cors = CorsOptions { allowed_origins: AllowedOrigins::All, @@ -79,6 +81,31 @@ pub async fn web() -> Rocket { ) .into(); + // Configure Rabbit + let connection = Connection::open(&OpenConnectionArguments::new( + &config.rabbit.host, + config.rabbit.port, + &config.rabbit.username, + &config.rabbit.password, + )) + .await + .unwrap(); + let channel = connection.open_channel(None).await.unwrap(); + + channel + .exchange_declare( + ExchangeDeclareArguments::new(&config.pushd.exchange, "direct") + .durable(true) + .finish(), + ) + .await + .expect("Failed to declare exchange"); + + let amqp = AMQP::new(connection, channel); + + // Launch background task workers + revolt_database::tasks::start_workers(db.clone(), amqp.clone()); + // Configure Rocket let rocket = rocket::build(); let prometheus = PrometheusMetrics::new(); @@ -91,6 +118,7 @@ pub async fn web() -> Rocket { .mount("/swagger/", swagger) .manage(authifier) .manage(db) + .manage(amqp) .manage(cors.clone()) .attach(util::ratelimiter::RatelimitFairing) .attach(cors) diff --git a/crates/delta/src/routes/bots/invite.rs b/crates/delta/src/routes/bots/invite.rs index 120c36dab..10c09eed1 100644 --- a/crates/delta/src/routes/bots/invite.rs +++ b/crates/delta/src/routes/bots/invite.rs @@ -1,6 +1,6 @@ use revolt_database::util::permissions::DatabasePermissionQuery; -use revolt_database::Member; use revolt_database::{util::reference::Reference, Database, User}; +use revolt_database::{Member, AMQP}; use revolt_models::v0; use revolt_permissions::{ calculate_channel_permissions, calculate_server_permissions, ChannelPermission, @@ -18,6 +18,7 @@ use rocket_empty::EmptyResponse; #[post("//invite", data = "")] pub async fn invite_bot( db: &State, + amqp: &State, user: User, target: Reference, dest: Json, @@ -55,7 +56,7 @@ pub async fn invite_bot( .throw_if_lacking_channel_permission(ChannelPermission::InviteOthers)?; channel - .add_user_to_group(db, &bot_user, &user.id) + .add_user_to_group(db, amqp, &bot_user, &user.id) .await .map(|_| EmptyResponse) } @@ -93,9 +94,12 @@ mod test { .client .post(format!("/bots/{}/invite", bot.id)) .header(ContentType::JSON) - .body(json!(v0::InviteBotDestination::Group { - group: group.id().to_string() - }).to_string()) + .body( + json!(v0::InviteBotDestination::Group { + group: group.id().to_string() + }) + .to_string(), + ) .header(Header::new("x-session-token", session.token.to_string())) .dispatch() .await; diff --git a/crates/delta/src/routes/channels/channel_delete.rs b/crates/delta/src/routes/channels/channel_delete.rs index ae09ee70f..86dea166c 100644 --- a/crates/delta/src/routes/channels/channel_delete.rs +++ b/crates/delta/src/routes/channels/channel_delete.rs @@ -1,6 +1,6 @@ use revolt_database::{ util::{permissions::DatabasePermissionQuery, reference::Reference}, - Channel, Database, PartialChannel, User, + Channel, Database, PartialChannel, User, AMQP, }; use revolt_models::v0; use revolt_permissions::{calculate_channel_permissions, ChannelPermission}; @@ -15,6 +15,7 @@ use rocket_empty::EmptyResponse; #[delete("/?")] pub async fn delete( db: &State, + amqp: &State, user: User, target: Reference, options: v0::OptionsChannelDelete, @@ -39,7 +40,13 @@ pub async fn delete( .await .map(|_| EmptyResponse), Channel::Group { .. } => channel - .remove_user_from_group(db, &user, None, options.leave_silently.unwrap_or_default()) + .remove_user_from_group( + db, + amqp, + &user, + None, + options.leave_silently.unwrap_or_default(), + ) .await .map(|_| EmptyResponse), Channel::TextChannel { .. } | Channel::VoiceChannel { .. } => { diff --git a/crates/delta/src/routes/channels/channel_edit.rs b/crates/delta/src/routes/channels/channel_edit.rs index d53c14fd4..1977ec9ba 100644 --- a/crates/delta/src/routes/channels/channel_edit.rs +++ b/crates/delta/src/routes/channels/channel_edit.rs @@ -1,6 +1,6 @@ use revolt_database::{ util::{permissions::DatabasePermissionQuery, reference::Reference}, - Channel, Database, File, PartialChannel, SystemMessage, User, + Channel, Database, File, PartialChannel, SystemMessage, User, AMQP, }; use revolt_models::v0; use revolt_permissions::{calculate_channel_permissions, ChannelPermission}; @@ -15,6 +15,7 @@ use validator::Validate; #[patch("/", data = "")] pub async fn edit( db: &State, + amqp: &State, user: User, target: Reference, data: Json, @@ -73,7 +74,15 @@ pub async fn edit( return Err(create_error!(InvalidOperation)); } .into_message(channel.id().to_string()) - .send(db, user.as_author_for_system(), None, None, &channel, false) + .send( + db, + Some(amqp), + user.as_author_for_system(), + None, + None, + &channel, + false, + ) .await .ok(); } @@ -151,7 +160,15 @@ pub async fn edit( by: user.id.clone(), } .into_message(channel.id().to_string()) - .send(db, user.as_author_for_system(), None, None, &channel, false) + .send( + db, + Some(amqp), + user.as_author_for_system(), + None, + None, + &channel, + false, + ) .await .ok(); } @@ -161,7 +178,15 @@ pub async fn edit( by: user.id.clone(), } .into_message(channel.id().to_string()) - .send(db, user.as_author_for_system(), None, None, &channel, false) + .send( + db, + Some(amqp), + user.as_author_for_system(), + None, + None, + &channel, + false, + ) .await .ok(); } @@ -171,7 +196,15 @@ pub async fn edit( by: user.id.clone(), } .into_message(channel.id().to_string()) - .send(db, user.as_author_for_system(), None, None, &channel, false) + .send( + db, + Some(amqp), + user.as_author_for_system(), + None, + None, + &channel, + false, + ) .await .ok(); } diff --git a/crates/delta/src/routes/channels/group_add_member.rs b/crates/delta/src/routes/channels/group_add_member.rs index 16ee84c73..79d9e6335 100644 --- a/crates/delta/src/routes/channels/group_add_member.rs +++ b/crates/delta/src/routes/channels/group_add_member.rs @@ -1,6 +1,6 @@ use revolt_database::{ util::{permissions::DatabasePermissionQuery, reference::Reference}, - Channel, Database, User, + Channel, Database, User, AMQP, }; use revolt_permissions::{calculate_channel_permissions, ChannelPermission}; use revolt_result::{create_error, Result}; @@ -15,6 +15,7 @@ use rocket_empty::EmptyResponse; #[put("//recipients/")] pub async fn add_member( db: &State, + amqp: &State, user: User, group_id: Reference, member_id: Reference, @@ -38,7 +39,7 @@ pub async fn add_member( } channel - .add_user_to_group(db, &member, &user.id) + .add_user_to_group(db, amqp, &member, &user.id) .await .map(|_| EmptyResponse) } diff --git a/crates/delta/src/routes/channels/group_remove_member.rs b/crates/delta/src/routes/channels/group_remove_member.rs index 3a6f00ef9..90cccfefa 100644 --- a/crates/delta/src/routes/channels/group_remove_member.rs +++ b/crates/delta/src/routes/channels/group_remove_member.rs @@ -1,4 +1,4 @@ -use revolt_database::{util::reference::Reference, Channel, Database, User}; +use revolt_database::{util::reference::Reference, Channel, Database, User, AMQP}; use revolt_permissions::ChannelPermission; use revolt_result::{create_error, Result}; @@ -12,6 +12,7 @@ use rocket_empty::EmptyResponse; #[delete("//recipients/")] pub async fn remove_member( db: &State, + amqp: &State, user: User, target: Reference, member: Reference, @@ -42,7 +43,7 @@ pub async fn remove_member( } channel - .remove_user_from_group(db, &member, Some(&user.id), false) + .remove_user_from_group(db, amqp, &member, Some(&user.id), false) .await .map(|_| EmptyResponse) } diff --git a/crates/delta/src/routes/channels/message_pin.rs b/crates/delta/src/routes/channels/message_pin.rs index 5dc318f87..d8df7febd 100644 --- a/crates/delta/src/routes/channels/message_pin.rs +++ b/crates/delta/src/routes/channels/message_pin.rs @@ -1,5 +1,6 @@ use revolt_database::{ - util::{permissions::DatabasePermissionQuery, reference::Reference}, Database, PartialMessage, SystemMessage, User + util::{permissions::DatabasePermissionQuery, reference::Reference}, + Database, PartialMessage, SystemMessage, User, AMQP, }; use revolt_models::v0::MessageAuthor; use revolt_permissions::{calculate_channel_permissions, ChannelPermission}; @@ -14,6 +15,7 @@ use rocket_empty::EmptyResponse; #[post("//messages//pin")] pub async fn message_pin( db: &State, + amqp: &State, user: User, target: Reference, msg: Reference, @@ -28,30 +30,38 @@ pub async fn message_pin( let mut message = msg.as_message_in_channel(db, channel.id()).await?; if message.pinned.unwrap_or_default() { - return Err(create_error!(AlreadyPinned)) + return Err(create_error!(AlreadyPinned)); } - message.update(db, PartialMessage { - pinned: Some(true), - ..Default::default() - }, vec![]).await?; + message + .update( + db, + PartialMessage { + pinned: Some(true), + ..Default::default() + }, + vec![], + ) + .await?; SystemMessage::MessagePinned { id: message.id.clone(), - by: user.id.clone() + by: user.id.clone(), } .into_message(channel.id().to_string()) .send( db, + Some(amqp), MessageAuthor::System { username: &user.username, - avatar: user.avatar.as_ref().map(|file| file.id.as_ref()) + avatar: user.avatar.as_ref().map(|file| file.id.as_ref()), }, None, None, &channel, - false - ).await?; + false, + ) + .await?; Ok(EmptyResponse) } @@ -59,7 +69,11 @@ pub async fn message_pin( #[cfg(test)] mod test { use crate::{rocket, util::test::TestHarness}; - use revolt_database::{events::client::EventV1, util::{idempotency::IdempotencyKey, reference::Reference}, Member, Message, Server}; + use revolt_database::{ + events::client::EventV1, + util::{idempotency::IdempotencyKey, reference::Reference}, + Member, Message, Server, + }; use revolt_models::v0::{self, SystemMessage}; use rocket::http::{Header, Status}; @@ -75,24 +89,29 @@ mod test { ..Default::default() }, &user, - true - ).await.expect("Failed to create test server"); + true, + ) + .await + .expect("Failed to create test server"); - let (member, channels) = Member::create(&harness.db, &server, &user, Some(channels)).await.expect("Failed to create member"); + let (member, channels) = Member::create(&harness.db, &server, &user, Some(channels)) + .await + .expect("Failed to create member"); let channel = &channels[0]; let message = Message::create_from_api( &harness.db, + None, channel.clone(), v0::DataMessageSend { - content:Some("Test message".to_string()), + content: Some("Test message".to_string()), nonce: None, attachments: None, replies: None, embeds: None, masquerade: None, interactions: None, - flags: None + flags: None, }, v0::MessageAuthor::User(&user.clone().into(&harness.db, Some(&user)).await), Some(user.clone().into(&harness.db, Some(&user)).await), @@ -100,14 +119,18 @@ mod test { user.limits().await, IdempotencyKey::unchecked_from_string("0".to_string()), false, - false + false, ) .await .expect("Failed to create message"); let response = harness .client - .post(format!("/channels/{}/messages/{}/pin", channel.id(), &message.id)) + .post(format!( + "/channels/{}/messages/{}/pin", + channel.id(), + &message.id + )) .header(Header::new("x-session-token", session.token.to_string())) .dispatch() .await; @@ -115,34 +138,37 @@ mod test { assert_eq!(response.status(), Status::NoContent); drop(response); - harness.wait_for_event(channel.id(), |event| { - match event { - EventV1::Message(message) => { - match &message.system { - Some(SystemMessage::MessagePinned { by, .. }) => { - assert_eq!(by, &user.id); + harness + .wait_for_event(channel.id(), |event| match event { + EventV1::Message(message) => match &message.system { + Some(SystemMessage::MessagePinned { by, .. }) => { + assert_eq!(by, &user.id); - true - }, - _ => false + true } + _ => false, }, - _ => false - } - }).await; + _ => false, + }) + .await; - harness.wait_for_event(channel.id(), |event| { - match event { - EventV1::MessageUpdate { id, channel: channel_id, data, .. } => { + harness + .wait_for_event(channel.id(), |event| match event { + EventV1::MessageUpdate { + id, + channel: channel_id, + data, + .. + } => { assert_eq!(id, &message.id); assert_eq!(channel_id, channel.id()); assert_eq!(data.pinned, Some(true)); true - }, - _ => false - } - }).await; + } + _ => false, + }) + .await; let updated_message = Reference::from_unchecked(message.id) .as_message(&harness.db) diff --git a/crates/delta/src/routes/channels/message_send.rs b/crates/delta/src/routes/channels/message_send.rs index 93a290579..55c133088 100644 --- a/crates/delta/src/routes/channels/message_send.rs +++ b/crates/delta/src/routes/channels/message_send.rs @@ -3,7 +3,7 @@ use revolt_database::util::permissions::DatabasePermissionQuery; use revolt_database::{ util::idempotency::IdempotencyKey, util::reference::Reference, Database, User, }; -use revolt_database::{Interactions, Message}; +use revolt_database::{Interactions, Message, AMQP}; use revolt_models::v0; use revolt_permissions::PermissionQuery; use revolt_permissions::{calculate_channel_permissions, ChannelPermission}; @@ -19,6 +19,7 @@ use validator::Validate; #[post("//messages", data = "")] pub async fn message_send( db: &State, + amqp: &State, user: User, target: Reference, data: Json, @@ -93,6 +94,7 @@ pub async fn message_send( Ok(Json( Message::create_from_api( db, + Some(amqp), channel, data, v0::MessageAuthor::User(&author), @@ -107,3 +109,218 @@ pub async fn message_send( .into_model(Some(model_user), model_member), )) } + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use crate::{rocket, util::test::TestHarness}; + use revolt_database::{ + util::{idempotency::IdempotencyKey, reference::Reference}, + Channel, Member, Message, PartialChannel, PartialMember, Role, Server, + }; + use revolt_models::v0::{self, DataCreateServerChannel}; + use revolt_permissions::{ChannelPermission, OverrideField}; + + #[rocket::async_test] + async fn message_mention_constraints() { + let harness = TestHarness::new().await; + let (_, _, user) = harness.new_user().await; + let (_, _, second_user) = harness.new_user().await; + + let (server, channels) = Server::create( + &harness.db, + v0::DataCreateServer { + name: "Test Server".to_string(), + ..Default::default() + }, + &user, + true, + ) + .await + .expect("Failed to create test server"); + + let server_mut: &mut Server = &mut server.clone(); + let mut locked_channel = Channel::create_server_channel( + &harness.db, + server_mut, + DataCreateServerChannel { + channel_type: v0::LegacyServerChannelType::Text, + name: "Hidden Channel".to_string(), + description: None, + nsfw: Some(false), + }, + true, + ) + .await + .expect("Failed to make new channel"); + + let role = Role { + name: "Show Hidden Channel".to_string(), + permissions: OverrideField { a: 0, d: 0 }, + colour: None, + hoist: false, + rank: 5, + }; + + let role_id = role + .create(&harness.db, &server.id) + .await + .expect("Failed to create the role"); + + let mut overrides = HashMap::new(); + overrides.insert( + role_id.clone(), + OverrideField { + a: (ChannelPermission::ViewChannel) as i64, + d: 0, + }, + ); + + let partial = PartialChannel { + name: None, + owner: None, + description: None, + icon: None, + nsfw: None, + active: None, + permissions: None, + role_permissions: Some(overrides), + default_permissions: Some(OverrideField { + a: 0, + d: ChannelPermission::ViewChannel as i64, + }), + last_message_id: None, + }; + locked_channel + .update(&harness.db, partial, vec![]) + .await + .expect("Failed to update the channel permissions for special role"); + + Member::create(&harness.db, &server, &user, Some(channels.clone())) + .await + .expect("Failed to create member"); + let member = Reference::from_unchecked(user.id.clone()) + .as_member(&harness.db, &server.id) + .await + .expect("Failed to get member"); + + // Second user is not part of the server + let message = Message::create_from_api( + &harness.db, + Some(&harness.amqp), + locked_channel.clone(), + v0::DataMessageSend { + content: Some(format!("<@{}>", second_user.id)), + nonce: None, + attachments: None, + replies: None, + embeds: None, + masquerade: None, + interactions: None, + flags: None, + }, + v0::MessageAuthor::User(&user.clone().into(&harness.db, Some(&user)).await), + Some(user.clone().into(&harness.db, Some(&user)).await), + Some(member.clone().into()), + user.limits().await, + IdempotencyKey::unchecked_from_string("0".to_string()), + false, + true, + ) + .await + .expect("Failed to create message"); + + // The mention should not go through here + assert!( + message.mentions.is_none() || message.mentions.unwrap().is_empty(), + "Mention failed to be scrubbed when the user is not part of the server" + ); + + Member::create(&harness.db, &server, &second_user, Some(channels.clone())) + .await + .expect("Failed to create second member"); + let mut second_member = Reference::from_unchecked(second_user.id.clone()) + .as_member(&harness.db, &server.id) + .await + .expect("Failed to get second member"); + + // Second user cannot see the channel + let message = Message::create_from_api( + &harness.db, + Some(&harness.amqp), + locked_channel.clone(), + v0::DataMessageSend { + content: Some(format!("<@{}>", second_user.id)), + nonce: None, + attachments: None, + replies: None, + embeds: None, + masquerade: None, + interactions: None, + flags: None, + }, + v0::MessageAuthor::User(&user.clone().into(&harness.db, Some(&user)).await), + Some(user.clone().into(&harness.db, Some(&user)).await), + Some(member.clone().into()), + user.limits().await, + IdempotencyKey::unchecked_from_string("1".to_string()), + false, + true, + ) + .await + .expect("Failed to create message"); + + // The mention should not go through here + assert!( + message.mentions.is_none() || message.mentions.unwrap().is_empty(), + "Mention failed to be scrubbed when the user cannot see the channel" + ); + + let second_member_roles = vec![role_id.clone()]; + let partial = PartialMember { + id: None, + joined_at: None, + nickname: None, + avatar: None, + timeout: None, + roles: Some(second_member_roles), + }; + second_member + .update(&harness.db, partial, vec![]) + .await + .expect("Failed to update the second user's roles"); + + // This time the mention SHOULD go through + let message = Message::create_from_api( + &harness.db, + Some(&harness.amqp), + locked_channel.clone(), + v0::DataMessageSend { + content: Some(format!("<@{}>", second_user.id)), + nonce: None, + attachments: None, + replies: None, + embeds: None, + masquerade: None, + interactions: None, + flags: None, + }, + v0::MessageAuthor::User(&user.clone().into(&harness.db, Some(&user)).await), + Some(user.clone().into(&harness.db, Some(&user)).await), + Some(member.clone().into()), + user.limits().await, + IdempotencyKey::unchecked_from_string("2".to_string()), + false, + true, + ) + .await + .expect("Failed to create message"); + + // The mention SHOULD go through here + assert!( + message.mentions.is_some() && !message.mentions.unwrap().is_empty(), + "Mention was scrubbed when the user can see the channel" + ); + } +} diff --git a/crates/delta/src/routes/channels/message_unpin.rs b/crates/delta/src/routes/channels/message_unpin.rs index 4084fc119..67154842d 100644 --- a/crates/delta/src/routes/channels/message_unpin.rs +++ b/crates/delta/src/routes/channels/message_unpin.rs @@ -1,5 +1,6 @@ use revolt_database::{ - util::{permissions::DatabasePermissionQuery, reference::Reference}, Database, FieldsMessage, PartialMessage, SystemMessage, User + util::{permissions::DatabasePermissionQuery, reference::Reference}, + Database, FieldsMessage, PartialMessage, SystemMessage, User, AMQP, }; use revolt_models::v0::MessageAuthor; use revolt_permissions::{calculate_channel_permissions, ChannelPermission}; @@ -14,6 +15,7 @@ use rocket_empty::EmptyResponse; #[delete("//messages//pin")] pub async fn message_unpin( db: &State, + amqp: &State, user: User, target: Reference, msg: Reference, @@ -28,27 +30,31 @@ pub async fn message_unpin( let mut message = msg.as_message_in_channel(db, channel.id()).await?; if !message.pinned.unwrap_or_default() { - return Err(create_error!(NotPinned)) + return Err(create_error!(NotPinned)); } - message.update(db, PartialMessage::default(), vec![FieldsMessage::Pinned]).await?; + message + .update(db, PartialMessage::default(), vec![FieldsMessage::Pinned]) + .await?; SystemMessage::MessageUnpinned { id: message.id.clone(), - by: user.id.clone() + by: user.id.clone(), } .into_message(channel.id().to_string()) .send( db, + Some(amqp), MessageAuthor::System { username: &user.username, - avatar: user.avatar.as_ref().map(|file| file.id.as_ref()) + avatar: user.avatar.as_ref().map(|file| file.id.as_ref()), }, None, None, &channel, - false - ).await?; + false, + ) + .await?; Ok(EmptyResponse) } @@ -56,7 +62,11 @@ pub async fn message_unpin( #[cfg(test)] mod test { use crate::{rocket, util::test::TestHarness}; - use revolt_database::{events::client::EventV1, util::{idempotency::IdempotencyKey, reference::Reference}, Member, Message, PartialMessage, Server}; + use revolt_database::{ + events::client::EventV1, + util::{idempotency::IdempotencyKey, reference::Reference}, + Member, Message, PartialMessage, Server, + }; use revolt_models::v0::{self, FieldsMessage, SystemMessage}; use rocket::http::{Header, Status}; @@ -72,26 +82,34 @@ mod test { ..Default::default() }, &user, - true - ).await.expect("Failed to create test server"); + true, + ) + .await + .expect("Failed to create test server"); let channel = &channels[0]; - Member::create(&harness.db, &server, &user, Some(channels.clone())).await.expect("Failed to create member"); - let member = Reference::from_unchecked(user.id.clone()).as_member(&harness.db, &server.id).await.expect("Failed to get member"); + Member::create(&harness.db, &server, &user, Some(channels.clone())) + .await + .expect("Failed to create member"); + let member = Reference::from_unchecked(user.id.clone()) + .as_member(&harness.db, &server.id) + .await + .expect("Failed to get member"); let message = Message::create_from_api( &harness.db, + None, channel.clone(), v0::DataMessageSend { - content:Some("Test message".to_string()), + content: Some("Test message".to_string()), nonce: None, attachments: None, replies: None, embeds: None, masquerade: None, interactions: None, - flags: None + flags: None, }, v0::MessageAuthor::User(&user.clone().into(&harness.db, Some(&user)).await), Some(user.clone().into(&harness.db, Some(&user)).await), @@ -99,23 +117,31 @@ mod test { user.limits().await, IdempotencyKey::unchecked_from_string("0".to_string()), false, - false + false, ) .await .expect("Failed to create message"); - harness.db.update_message( - &message.id, - &PartialMessage { - pinned: Some(true), - ..Default::default() - }, - vec![] - ).await.expect("Failed to update message"); + harness + .db + .update_message( + &message.id, + &PartialMessage { + pinned: Some(true), + ..Default::default() + }, + vec![], + ) + .await + .expect("Failed to update message"); let response = harness .client - .delete(format!("/channels/{}/messages/{}/pin", channel.id(), &message.id)) + .delete(format!( + "/channels/{}/messages/{}/pin", + channel.id(), + &message.id + )) .header(Header::new("x-session-token", session.token.to_string())) .dispatch() .await; @@ -123,33 +149,31 @@ mod test { assert_eq!(response.status(), Status::NoContent); drop(response); - harness.wait_for_event(channel.id(), |event| { - match event { - EventV1::Message(message) => { - match &message.system { - Some(SystemMessage::MessageUnpinned { by, .. }) => { - assert_eq!(by, &user.id); + harness + .wait_for_event(channel.id(), |event| match event { + EventV1::Message(message) => match &message.system { + Some(SystemMessage::MessageUnpinned { by, .. }) => { + assert_eq!(by, &user.id); - true - }, - _ => false + true } + _ => false, }, - _ => false - } - }).await; + _ => false, + }) + .await; - harness.wait_for_event(channel.id(), |event| { - match event { + harness + .wait_for_event(channel.id(), |event| match event { EventV1::MessageUpdate { id, clear, .. } => { assert_eq!(&message.id, id); assert_eq!(clear, &[FieldsMessage::Pinned]); true - }, - _ => false - } - }).await; + } + _ => false, + }) + .await; let updated_message = Reference::from_unchecked(message.id) .as_message(&harness.db) diff --git a/crates/delta/src/routes/invites/invite_join.rs b/crates/delta/src/routes/invites/invite_join.rs index 2fa872d3f..83b591951 100644 --- a/crates/delta/src/routes/invites/invite_join.rs +++ b/crates/delta/src/routes/invites/invite_join.rs @@ -1,4 +1,4 @@ -use revolt_database::{util::reference::Reference, Channel, Database, Invite, Member, User}; +use revolt_database::{util::reference::Reference, Channel, Database, Invite, Member, User, AMQP}; use revolt_models::v0::{self, InviteJoinResponse}; use revolt_result::{create_error, Result}; use rocket::{serde::json::Json, State}; @@ -10,6 +10,7 @@ use rocket::{serde::json::Json, State}; #[post("/")] pub async fn join( db: &State, + amqp: &State, user: User, target: Reference, ) -> Result> { @@ -34,7 +35,7 @@ pub async fn join( channel, creator, .. } => { let mut channel = db.fetch_channel(channel).await?; - channel.add_user_to_group(db, &user, creator).await?; + channel.add_user_to_group(db, amqp, &user, creator).await?; if let Channel::Group { recipients, .. } = &channel { Ok(Json(InviteJoinResponse::Group { users: User::fetch_many_ids_as_mutuals(db, &user, recipients).await?, diff --git a/crates/delta/src/routes/root.rs b/crates/delta/src/routes/root.rs index 1e237bff6..a67c596a6 100644 --- a/crates/delta/src/routes/root.rs +++ b/crates/delta/src/routes/root.rs @@ -114,7 +114,7 @@ pub async fn root() -> Result> { }, ws: config.hosts.events, app: config.hosts.app, - vapid: config.api.vapid.public_key, + vapid: config.pushd.vapid.public_key, build: BuildInformation { commit_sha: option_env!("VERGEN_GIT_SHA") .unwrap_or_else(|| "") diff --git a/crates/delta/src/routes/users/add_friend.rs b/crates/delta/src/routes/users/add_friend.rs index f544c766a..b7728d0ea 100644 --- a/crates/delta/src/routes/users/add_friend.rs +++ b/crates/delta/src/routes/users/add_friend.rs @@ -1,5 +1,5 @@ use revolt_database::util::reference::Reference; -use revolt_database::{Database, User}; +use revolt_database::{Database, User, AMQP}; use revolt_models::v0; use revolt_result::{create_error, Result}; use rocket::serde::json::Json; @@ -12,6 +12,7 @@ use rocket::State; #[put("//friend")] pub async fn add( db: &State, + amqp: &State, mut user: User, target: Reference, ) -> Result> { @@ -21,6 +22,6 @@ pub async fn add( return Err(create_error!(IsBot)); } - user.add_friend(db, &mut target).await?; + user.add_friend(db, amqp, &mut target).await?; Ok(Json(target.into(db, &user).await)) } diff --git a/crates/delta/src/routes/users/send_friend_request.rs b/crates/delta/src/routes/users/send_friend_request.rs index 552219f25..a87334af2 100644 --- a/crates/delta/src/routes/users/send_friend_request.rs +++ b/crates/delta/src/routes/users/send_friend_request.rs @@ -1,5 +1,5 @@ -use revolt_database::util::reference::Reference; -use revolt_database::{Database, User}; +// use revolt_database::util::reference::Reference; +use revolt_database::{Database, User, AMQP}; use revolt_models::v0; use revolt_result::{create_error, Result}; use rocket::serde::json::Json; @@ -12,6 +12,7 @@ use rocket::State; #[post("/friend", data = "")] pub async fn send_friend_request( db: &State, + amqp: &State, mut user: User, data: Json, ) -> Result> { @@ -22,7 +23,7 @@ pub async fn send_friend_request( return Err(create_error!(IsBot)); } - user.add_friend(db, &mut target).await?; + user.add_friend(db, amqp, &mut target).await?; Ok(Json(target.into(db, &user).await)) } else { Err(create_error!(InvalidProperty)) diff --git a/crates/delta/src/routes/webhooks/webhook_execute.rs b/crates/delta/src/routes/webhooks/webhook_execute.rs index 29b33dde3..e43ee9348 100644 --- a/crates/delta/src/routes/webhooks/webhook_execute.rs +++ b/crates/delta/src/routes/webhooks/webhook_execute.rs @@ -1,7 +1,7 @@ use revolt_config::config; use revolt_database::{ util::{idempotency::IdempotencyKey, reference::Reference}, - Database, Message, + Database, Message, AMQP, }; use revolt_models::v0; use revolt_permissions::{ChannelPermission, PermissionValue}; @@ -17,6 +17,7 @@ use validator::Validate; #[post("//", data = "")] pub async fn webhook_execute( db: &State, + amqp: &State, webhook_id: Reference, token: String, data: Json, @@ -56,6 +57,7 @@ pub async fn webhook_execute( Ok(Json( Message::create_from_api( db, + Some(amqp), channel, data, v0::MessageAuthor::Webhook(&webhook.into()), diff --git a/crates/delta/src/routes/webhooks/webhook_execute_github.rs b/crates/delta/src/routes/webhooks/webhook_execute_github.rs index 85bf2e429..a5b0d12b9 100644 --- a/crates/delta/src/routes/webhooks/webhook_execute_github.rs +++ b/crates/delta/src/routes/webhooks/webhook_execute_github.rs @@ -1,4 +1,4 @@ -use revolt_database::{util::reference::Reference, Database, Message}; +use revolt_database::{util::reference::Reference, Database, Message, AMQP}; use revolt_models::v0::{MessageAuthor, SendableEmbed, Webhook}; use revolt_result::{create_error, Error, Result}; use revolt_rocket_okapi::{ @@ -635,7 +635,10 @@ impl<'r> FromRequest<'r> for EventHeader<'r> { async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { let headers = request.headers(); let Some(event) = headers.get_one("X-GitHub-Event") else { - return rocket::request::Outcome::Failure((Status::BadRequest, create_error!(InvalidOperation))) + return rocket::request::Outcome::Error(( + Status::BadRequest, + create_error!(InvalidOperation), + )); }; rocket::request::Outcome::Success(Self(event)) @@ -747,6 +750,7 @@ fn convert_event(data: &str, event_name: &str) -> Result { #[post("///github", data = "")] pub async fn webhook_execute_github( db: &State, + amqp: &State, webhook_id: Reference, token: String, event: EventHeader<'_>, @@ -784,7 +788,9 @@ pub async fn webhook_execute_github( r#ref, .. }) => { - let Some(branch) = r#ref.split('/').nth(2) else { return Ok(()) }; + let Some(branch) = r#ref.split('/').nth(2) else { + return Ok(()); + }; if forced { let description = format!( @@ -1074,6 +1080,7 @@ pub async fn webhook_execute_github( message .send( db, + Some(amqp), MessageAuthor::Webhook(&webhook.into()), None, None, diff --git a/crates/delta/src/util/ratelimiter.rs b/crates/delta/src/util/ratelimiter.rs index ce7317ffc..c95f818df 100644 --- a/crates/delta/src/util/ratelimiter.rs +++ b/crates/delta/src/util/ratelimiter.rs @@ -235,7 +235,7 @@ impl<'r> FromRequest<'r> for Ratelimiter { match ratelimiter { Ok(ratelimiter) => Outcome::Success(*ratelimiter), - Err(ratelimiter) => Outcome::Failure((Status::TooManyRequests, *ratelimiter)), + Err(ratelimiter) => Outcome::Error((Status::TooManyRequests, *ratelimiter)), } } } @@ -264,7 +264,7 @@ impl Fairing for RatelimitFairing { async fn on_request(&self, request: &mut Request<'_>, _: &mut Data<'_>) { use rocket::outcome::Outcome; - if let Outcome::Failure(_) = request.guard::().await { + if let Outcome::Error(_) = request.guard::().await { info!( "User rate-limited on route {}! (IP = {:?})", request.uri(), @@ -278,7 +278,7 @@ impl Fairing for RatelimitFairing { async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) { let guard = request.guard::().await; - let (Outcome::Success(ratelimiter) | Outcome::Failure((_, ratelimiter))) = guard else { + let (Outcome::Success(ratelimiter) | Outcome::Error((_, ratelimiter))) = guard else { unreachable!() }; let Ratelimiter { @@ -293,7 +293,7 @@ impl Fairing for RatelimitFairing { response.set_raw_header("X-RateLimit-Remaining", remaining.to_string()); response.set_raw_header("X-RateLimit-Reset-After", reset.to_string()); - if guard.is_failure() { + if guard.is_error() { response.set_status(Status::TooManyRequests); } } @@ -313,7 +313,7 @@ impl<'r> FromRequest<'r> for RatelimitInformation { async fn from_request(request: &'r rocket::Request<'_>) -> Outcome { let info = match request.guard::().await { Outcome::Success(ratelimiter) => RatelimitInformation::Success(ratelimiter), - Outcome::Failure((_, ratelimiter)) => RatelimitInformation::Failure { + Outcome::Error((_, ratelimiter)) => RatelimitInformation::Failure { retry_after: ratelimiter.reset, }, _ => unreachable!(), diff --git a/crates/delta/src/util/test.rs b/crates/delta/src/util/test.rs index 918135043..325a7ccf5 100644 --- a/crates/delta/src/util/test.rs +++ b/crates/delta/src/util/test.rs @@ -5,7 +5,7 @@ use authifier::{ use futures::StreamExt; use rand::Rng; use redis_kiss::redis::aio::PubSub; -use revolt_database::{events::client::EventV1, Database, User}; +use revolt_database::{events::client::EventV1, Database, User, AMQP}; use revolt_models::v0; use rocket::local::asynchronous::Client; @@ -13,12 +13,15 @@ pub struct TestHarness { pub client: Client, authifier: Authifier, pub db: Database, + pub amqp: AMQP, sub: PubSub, event_buffer: Vec<(String, EventV1)>, } impl TestHarness { pub async fn new() -> TestHarness { + let config = revolt_config::config().await; + let client = Client::tracked(crate::web().await) .await .expect("valid rocket instance"); @@ -41,10 +44,25 @@ impl TestHarness { .expect("`Authifier`") .clone(); + let connection = amqprs::connection::Connection::open( + &amqprs::connection::OpenConnectionArguments::new( + &config.rabbit.host, + config.rabbit.port, + &config.rabbit.username, + &config.rabbit.password, + ), + ) + .await + .unwrap(); + let channel = connection.open_channel(None).await.unwrap(); + + let amqp = AMQP::new(connection, channel); + TestHarness { client, authifier, db, + amqp, sub, event_buffer: vec![], } diff --git a/scripts/build-image-layer.sh b/scripts/build-image-layer.sh index 5043b09ba..4b2e18c47 100644 --- a/scripts/build-image-layer.sh +++ b/scripts/build-image-layer.sh @@ -33,12 +33,14 @@ deps() { crates/core/presence/src \ crates/core/result/src \ crates/services/autumn/src \ - crates/services/january/src + crates/services/january/src \ + crates/daemons/pushd/src echo 'fn main() { panic!("stub"); }' | tee crates/bonfire/src/main.rs | tee crates/delta/src/main.rs | tee crates/services/autumn/src/main.rs | - tee crates/services/january/src/main.rs + tee crates/services/january/src/main.rs | + tee crates/daemons/pushd/src/main.rs echo '' | tee crates/bindings/node/src/lib.rs | tee crates/core/config/src/lib.rs | @@ -60,6 +62,7 @@ apps() { touch -am \ crates/bonfire/src/main.rs \ crates/delta/src/main.rs \ + crates/daemons/pushd/src/main.rs \ crates/core/config/src/lib.rs \ crates/core/database/src/lib.rs \ crates/core/models/src/lib.rs \