From 8a62c618991a26fea224a1cda1dc5ea95141895d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Fri, 13 Dec 2024 15:47:27 +0000 Subject: [PATCH] support tls_cert/tls_key in nexus --- nexus/Cargo.lock | 3 +++ nexus/server/Cargo.toml | 3 +++ nexus/server/src/main.rs | 37 ++++++++++++++++++++++++++++++++++--- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/nexus/Cargo.lock b/nexus/Cargo.lock index 5ecd6f63a..30aab8f2f 100644 --- a/nexus/Cargo.lock +++ b/nexus/Cargo.lock @@ -2864,11 +2864,14 @@ dependencies = [ "postgres", "pt", "rand", + "rustls-pemfile 2.2.0", + "rustls-pki-types", "serde_json", "similar", "sqlparser", "time", "tokio", + "tokio-rustls 0.26.1", "tracing", "tracing-appender", "tracing-subscriber", diff --git a/nexus/server/Cargo.toml b/nexus/server/Cargo.toml index 8eccd78c2..3449aa780 100644 --- a/nexus/server/Cargo.toml +++ b/nexus/server/Cargo.toml @@ -52,8 +52,11 @@ pt = { path = "../pt" } sqlparser = { workspace = true, features = ["visitor"] } serde_json = "1.0" rand = "0.8" +rustls-pemfile = "2.0" +rustls-pki-types = "1.0" time = "0.3" tokio.workspace = true +tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12"]} tracing.workspace = true tracing-appender = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/nexus/server/src/main.rs b/nexus/server/src/main.rs index d4cc72480..20242a59e 100644 --- a/nexus/server/src/main.rs +++ b/nexus/server/src/main.rs @@ -1,6 +1,8 @@ use std::{ collections::{HashMap, HashSet}, fmt::Write, + fs::File, + io, sync::Arc, time::Duration, }; @@ -26,14 +28,13 @@ use pgwire::{ AuthSource, LoginInfo, Password, ServerParameterProvider, }, copy::NoopCopyHandler, - NoopErrorHandler, portal::Portal, query::{ExtendedQueryHandler, SimpleQueryHandler}, results::{ DescribePortalResponse, DescribeResponse, DescribeStatementResponse, Response, Tag, }, stmt::StoredStatement, - ClientInfo, PgWireServerHandlers, Type, + ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type, }, error::{ErrorInfo, PgWireError, PgWireResult}, tokio::process_socket, @@ -43,9 +44,13 @@ use pt::{ peerdb_peers::{peer::Config, Peer}, }; use rand::Rng; +use rustls_pemfile::{certs, pkcs8_private_keys}; +use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use tokio::signal::unix::{signal, SignalKind}; use tokio::sync::Mutex; use tokio::{io::AsyncWriteExt, net::TcpListener}; +use tokio_rustls::rustls::ServerConfig; +use tokio_rustls::TlsAcceptor; use tracing_appender::non_blocking::WorkerGuard; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; @@ -1041,6 +1046,29 @@ async fn run_migrations<'a>( Err(anyhow::anyhow!("Failed to connect to catalog")) } +fn setup_tls(args: &Args) -> Result, io::Error> { + if let (Some(tls_cert), Some(tls_key)) = (args.tls_cert.as_deref(), args.tls_key.as_deref()) { + let cert = certs(&mut io::BufReader::new(File::open(tls_cert)?)) + .collect::, io::Error>>()?; + + let key = pkcs8_private_keys(&mut io::BufReader::new(File::open(tls_key)?)) + .map(|key| key.map(PrivateKeyDer::from)) + .collect::, io::Error>>()? + .remove(0); + + let mut config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert, key) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; + + config.alpn_protocols = vec![b"postgresql".to_vec()]; + + Ok(Some(TlsAcceptor::from(Arc::new(config)))) + } else { + Ok(None) + } +} + pub struct Handlers { authenticator: ( Arc, @@ -1107,6 +1135,8 @@ pub async fn main() -> anyhow::Result<()> { Arc::new(NexusServerParameterProvider), ); + let tls_acceptor = setup_tls(&args)?.map(Arc::new); + let peer_conns = { let conn_str = catalog_config.to_pg_connection_string(); let pconns = PeerConnections::new(&conn_str)?; @@ -1137,6 +1167,7 @@ pub async fn main() -> anyhow::Result<()> { let authenticator = authenticator.clone(); let pg_config = catalog_config.to_postgres_config(); let kms_key_id = args.kms_key_id.clone(); + let tls_acceptor = tls_acceptor.clone(); tokio::task::spawn(async move { match Catalog::new(pg_config, &kms_key_id).await { @@ -1152,7 +1183,7 @@ pub async fn main() -> anyhow::Result<()> { )); process_socket( socket, - None, + tls_acceptor, Arc::new(Handlers { nexus, authenticator,