From 2761ab4bb44d03a480fe48f5cfab15627cd423d9 Mon Sep 17 00:00:00 2001 From: Ben Spencer Date: Sat, 5 May 2018 13:09:33 +0100 Subject: [PATCH] Fix SSL session verification. --- .../scala/com/twitter/finagle/Postgres.scala | 5 +-- .../finagle/postgres/codec/PgCodec.scala | 36 ++++++++++--------- .../integration/IntegrationSpec.scala | 8 +++-- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/main/scala/com/twitter/finagle/Postgres.scala b/src/main/scala/com/twitter/finagle/Postgres.scala index 99f60627..6a45e262 100644 --- a/src/main/scala/com/twitter/finagle/Postgres.scala +++ b/src/main/scala/com/twitter/finagle/Postgres.scala @@ -11,7 +11,7 @@ import com.twitter.finagle.postgres.messages._ import com.twitter.finagle.postgres.values.ValueDecoder import com.twitter.finagle.service.FailFastFactory.FailFast import com.twitter.finagle.service._ -import com.twitter.finagle.ssl.client.SslClientEngineFactory +import com.twitter.finagle.ssl.client.{ SslClientEngineFactory, SslClientSessionVerifier } import com.twitter.finagle.stats.StatsReceiver import com.twitter.finagle.transport.{Transport, TransportContext} import com.twitter.util.{Monitor => _, _} @@ -77,6 +77,7 @@ object Postgres { private def pipelineFactory(params: Stack.Params) = { val SslClientEngineFactory.Param(sslFactory) = params[SslClientEngineFactory.Param] + val SslClientSessionVerifier.Param(sessionVerifier) = params[SslClientSessionVerifier.Param] val Transport.ClientSsl(ssl) = params[Transport.ClientSsl] new ChannelPipelineFactory { @@ -85,7 +86,7 @@ object Postgres { pipeline.addLast("binary_to_packet", new PacketDecoder(ssl.nonEmpty)) pipeline.addLast("packet_to_backend_messages", new BackendMessageDecoder(new BackendMessageParser)) - pipeline.addLast("backend_messages_to_postgres_response", new PgClientChannelHandler(sslFactory, ssl, ssl.nonEmpty)) + pipeline.addLast("backend_messages_to_postgres_response", new PgClientChannelHandler(sslFactory, sessionVerifier, ssl, ssl.nonEmpty)) pipeline } } diff --git a/src/main/scala/com/twitter/finagle/postgres/codec/PgCodec.scala b/src/main/scala/com/twitter/finagle/postgres/codec/PgCodec.scala index 1190d94a..d86280dc 100644 --- a/src/main/scala/com/twitter/finagle/postgres/codec/PgCodec.scala +++ b/src/main/scala/com/twitter/finagle/postgres/codec/PgCodec.scala @@ -10,7 +10,7 @@ import com.twitter.finagle.postgres.messages._ import com.twitter.finagle.postgres.values.Md5Encryptor import com.twitter.finagle.ssl.client.{ HostnameVerifier, SslClientConfiguration, SslClientEngineFactory, SslClientSessionVerifier } import com.twitter.logging.Logger -import com.twitter.util.Future +import com.twitter.util.{ Future, Try } import javax.net.ssl.{SSLContext, SSLEngine, SSLSession, TrustManagerFactory} import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers} @@ -184,6 +184,7 @@ class PacketDecoder(@volatile var inSslNegotation: Boolean) extends FrameDecoder */ class PgClientChannelHandler( sslEngineFactory: SslClientEngineFactory, + sessionVerifier: SslClientSessionVerifier, sslConfig: Option[SslClientConfiguration], val useSsl: Boolean ) extends SimpleChannelHandler { @@ -211,29 +212,30 @@ class PgClientChannelHandler( val pipeline = ctx.getPipeline - val addr = ctx.getChannel.getRemoteAddress - val inetAddr = addr match { - case i: InetSocketAddress => Some(i) - case _ => None + val (engine, verifier) = ctx.getChannel.getRemoteAddress match { + case i: InetSocketAddress => + val address = Address(i) + val config = sslConfig.getOrElse(SslClientConfiguration(hostname = Some(i.getHostString))) + (sslEngineFactory(address, config).self, (s: SSLSession) => sessionVerifier(address, config, s)) + case _ => + (Ssl.client().self, (_: SSLSession) => true) } - val engine = inetAddr.map(inet => - sslConfig.map(sslEngineFactory(Address(inet), _)).getOrElse(Ssl.client(inet.getHostString, inet.getPort)) - ) - .getOrElse(Ssl.client()) - .self - engine.setUseClientMode(true) val sslHandler = new SslHandler(engine) pipeline.addFirst("ssl", sslHandler) - val verifier: SSLSession => Boolean = inetAddr match { - case Some(inet) => - session => HostnameVerifier(Address(inet), SslClientConfiguration(hostname = Some(inet.getHostName)),session) - case None => - _ => true - } + sslHandler.handshake().addListener(new ChannelFutureListener { + override def operationComplete(f: ChannelFuture) = { + if (!Try(verifier(engine.getSession)).onFailure { err => + logger.error(err, "Exception thrown during SSL session verification") + }.getOrElse(false)) { + logger.error("SSL session verification failed") + Channels.close(ctx.getChannel) + } + } + }) connection.receive(SwitchToSsl).foreach { Channels.fireMessageReceived(ctx, _) diff --git a/src/test/scala/com/twitter/finagle/postgres/integration/IntegrationSpec.scala b/src/test/scala/com/twitter/finagle/postgres/integration/IntegrationSpec.scala index 54e76670..60d8de26 100644 --- a/src/test/scala/com/twitter/finagle/postgres/integration/IntegrationSpec.scala +++ b/src/test/scala/com/twitter/finagle/postgres/integration/IntegrationSpec.scala @@ -25,7 +25,8 @@ object IntegrationSpec { * * If these are conditions are met, the integration tests will be run. * - * The tests can be run with SSL by also setting the USE_PG_SSL variable to "1". + * The tests can be run with SSL by also setting the USE_PG_SSL variable to "1", and hostname verification can be added + * by setting PG_SSL_HOST. * */ class IntegrationSpec extends Spec { @@ -36,6 +37,7 @@ class IntegrationSpec extends Spec { password = sys.env.get("PG_PASSWORD") dbname <- sys.env.get("PG_DBNAME") useSsl = sys.env.getOrElse("USE_PG_SSL", "0") == "1" + sslHost = sys.env.get("PG_SSL_HOST") } yield { @@ -46,7 +48,7 @@ class IntegrationSpec extends Spec { .withCredentials(user, password) .database(dbname) .withSessionPool.maxSize(1) - .conditionally(useSsl, _.withTransport.tlsWithoutValidation) + .conditionally(useSsl, c => sslHost.fold(c.withTransport.tls)(c.withTransport.tls(_))) .newRichClient(hostPort) Await.result(Future[PostgresClientImpl] { @@ -60,7 +62,7 @@ class IntegrationSpec extends Spec { .withCredentials(user, password) .database(dbname) .withSessionPool.maxSize(1) - .conditionally(useSsl, _.withTransport.tlsWithoutValidation) + .conditionally(useSsl, c => sslHost.fold(c.withTransport.tls)(c.withTransport.tls(_))) .newRichClient("badhost:5432") }