Skip to content

Commit

Permalink
Merge pull request #93 from 56quarters/tls-update
Browse files Browse the repository at this point in the history
Update TLS related dependencies
  • Loading branch information
56quarters authored Dec 22, 2023
2 parents 69e6fd9 + 5b45819 commit df4f85e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 71 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## v0.7.1 - unreleased

- Add default 5 second timeout to network operations done by `mtop`. #90
- TLS related dependency updates. #93

## v0.7.0 - 2023-11-28

Expand Down
78 changes: 47 additions & 31 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions mtop-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ edition = "2021"

[dependencies]
pin-project-lite = "0.2.13"
rustls-pemfile = "1.0.2"
rustls-webpki = "0.101.2"
rustls-pemfile = "2.0.0"
rustls-webpki = "0.102.0"
tokio = { version = "1.14.0", features = ["full"] }
tokio-rustls = { version = "0.24.0" }
tokio-rustls = { version = "0.25.0" }
tracing = "0.1.11"
urlencoding = "2.1.2"
webpki-roots = "0.25.0"
webpki-roots = "0.26.0"

[lib]
name = "mtop_client"
Expand Down
62 changes: 26 additions & 36 deletions mtop-client/src/pool.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use crate::core::{Memcached, MtopError};
use std::collections::HashMap;
use std::fmt;
use std::fs::File;
use std::io::BufReader as StdBufReader;
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
use std::sync::Arc;
use std::{fmt, io};
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::runtime::Handle;
use tokio::sync::Mutex;
use tokio_rustls::rustls::{Certificate, ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerName};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::TlsConnector;
use webpki::TrustAnchor;

#[derive(Debug)]
pub struct PooledMemcached {
Expand Down Expand Up @@ -63,7 +63,7 @@ pub struct TLSConfig {
pub struct MemcachedPool {
clients: Mutex<HashMap<String, Memcached>>,
client_config: Option<Arc<ClientConfig>>,
server: Option<ServerName>,
server: Option<ServerName<'static>>,
config: PoolConfig,
}

Expand Down Expand Up @@ -115,9 +115,7 @@ impl MemcachedPool {
};

let trust_store = Self::trust_store(ca)?;
let builder = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(trust_store);
let builder = ClientConfig::builder().with_root_certificates(trust_store);

let config = match (client_cert, client_key) {
(Some(cert), Some(key)) => {
Expand All @@ -135,58 +133,49 @@ impl MemcachedPool {
Ok(config)
}

async fn load_cert(handle: &Handle, path: &PathBuf) -> Result<Vec<Certificate>, MtopError> {
let mut reader = File::open(path.clone())
async fn load_cert(handle: &Handle, path: &PathBuf) -> Result<Vec<CertificateDer<'static>>, MtopError> {
let mut reader = File::open(path)
.map(StdBufReader::new)
.map_err(|e| MtopError::configuration_cause(format!("unable to load cert {:?}", path), e))?;

Ok(handle
.spawn_blocking(move || rustls_pemfile::certs(&mut reader))
// Read all certs from the file in a separate thread and then convert the awkward
// Vec<Result<Cert>> type to a Result<Vec<Cert>> since we expect all certs to be valid
handle
.spawn_blocking(move || rustls_pemfile::certs(&mut reader).collect::<Vec<_>>())
.await
.unwrap()
.map_err(|e| MtopError::configuration_cause(format!("unable to parse cert {:?}", path), e))? // unwrap the spawn result, try the read result
.into_iter()
.map(Certificate)
.collect())
.collect::<Result<Vec<CertificateDer<'static>>, io::Error>>()
.map_err(|e| MtopError::configuration_cause(format!("unable to parse cert {:?}", path), e))
}

async fn load_key(handle: &Handle, path: &PathBuf) -> Result<PrivateKey, MtopError> {
let mut reader = File::open(path.clone())
async fn load_key(handle: &Handle, path: &PathBuf) -> Result<PrivateKeyDer<'static>, MtopError> {
let mut reader = File::open(path)
.map(StdBufReader::new)
.map_err(|e| MtopError::configuration_cause(format!("unable to load key {:?}", path), e))?;

// Read a single key in a separate thread returning an error if there is no key.
handle
.spawn_blocking(move || rustls_pemfile::pkcs8_private_keys(&mut reader))
.spawn_blocking(move || rustls_pemfile::private_key(&mut reader))
.await
.unwrap()
.map_err(|e| MtopError::configuration_cause(format!("unable to parse key {:?}", path), e))? // unwrap the spawn result, try the read result
.into_iter()
.next()
.map(PrivateKey)
.map_err(|e| MtopError::configuration_cause(format!("unable to parse key {:?}", path), e))?
.ok_or_else(|| MtopError::configuration(format!("no keys available in {:?}", path)))
}

fn trust_store(ca: Option<Vec<Certificate>>) -> Result<RootCertStore, MtopError> {
fn trust_store(ca: Option<Vec<CertificateDer<'static>>>) -> Result<RootCertStore, MtopError> {
let mut root_cert_store = RootCertStore::empty();

if let Some(ca_certs) = ca {
let mut anchors = Vec::with_capacity(ca_certs.len());
tracing::debug!(message = "adding custom CA certs for roots", num_certs = ca_certs.len());
for cert in ca_certs {
let anchor = TrustAnchor::try_from_cert_der(&cert.0)
.map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)
})
root_cert_store
.add(cert)
.map_err(|e| MtopError::internal_cause("unable to parse CA cert", e))?;
anchors.push(anchor);
}

tracing::debug!(message = "adding custom CA certs for roots", num_certs = anchors.len());
root_cert_store.add_trust_anchors(anchors.into_iter());
} else {
tracing::debug!(message = "using default CA certs for roots");
root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)
}));
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|c| c.to_owned()));
}

Ok(root_cert_store)
Expand All @@ -209,8 +198,9 @@ impl MemcachedPool {
}
}

fn host_to_server_name(host: &str) -> Result<ServerName, MtopError> {
fn host_to_server_name(host: &str) -> Result<ServerName<'static>, MtopError> {
ServerName::try_from(host)
.map(|s| s.to_owned())
.map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e))
}

Expand Down Expand Up @@ -248,7 +238,7 @@ where
Ok(Memcached::new(read, write))
}

async fn tls_connect<A>(host: A, server: ServerName, config: Arc<ClientConfig>) -> Result<Memcached, MtopError>
async fn tls_connect<A>(host: A, server: ServerName<'static>, config: Arc<ClientConfig>) -> Result<Memcached, MtopError>
where
A: ToSocketAddrs + fmt::Display,
{
Expand Down

0 comments on commit df4f85e

Please sign in to comment.