Skip to content

Commit

Permalink
Merge pull request #79 from dangerousben/feature/ssl-verification
Browse files Browse the repository at this point in the history
Fix SSL session verification.
  • Loading branch information
dangerousben authored Nov 14, 2018
2 parents dfac0ed + 2761ab4 commit 0f143dd
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 22 deletions.
5 changes: 3 additions & 2 deletions src/main/scala/com/twitter/finagle/Postgres.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import com.twitter.finagle.postgres.messages._
import com.twitter.finagle.postgres.values.ValueDecoder
import com.twitter.finagle.service.FailFastFactory.FailFast
import com.twitter.finagle.service._
import com.twitter.finagle.ssl.client.SslClientEngineFactory
import com.twitter.finagle.ssl.client.{ SslClientEngineFactory, SslClientSessionVerifier }
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.finagle.transport.{Transport, TransportContext}
import com.twitter.util.{Monitor => _, _}
Expand Down Expand Up @@ -77,6 +77,7 @@ object Postgres {

private def pipelineFactory(params: Stack.Params) = {
val SslClientEngineFactory.Param(sslFactory) = params[SslClientEngineFactory.Param]
val SslClientSessionVerifier.Param(sessionVerifier) = params[SslClientSessionVerifier.Param]
val Transport.ClientSsl(ssl) = params[Transport.ClientSsl]

new ChannelPipelineFactory {
Expand All @@ -85,7 +86,7 @@ object Postgres {

pipeline.addLast("binary_to_packet", new PacketDecoder(ssl.nonEmpty))
pipeline.addLast("packet_to_backend_messages", new BackendMessageDecoder(new BackendMessageParser))
pipeline.addLast("backend_messages_to_postgres_response", new PgClientChannelHandler(sslFactory, ssl, ssl.nonEmpty))
pipeline.addLast("backend_messages_to_postgres_response", new PgClientChannelHandler(sslFactory, sessionVerifier, ssl, ssl.nonEmpty))
pipeline
}
}
Expand Down
36 changes: 19 additions & 17 deletions src/main/scala/com/twitter/finagle/postgres/codec/PgCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import com.twitter.finagle.postgres.messages._
import com.twitter.finagle.postgres.values.Md5Encryptor
import com.twitter.finagle.ssl.client.{ HostnameVerifier, SslClientConfiguration, SslClientEngineFactory, SslClientSessionVerifier }
import com.twitter.logging.Logger
import com.twitter.util.Future
import com.twitter.util.{ Future, Try }
import javax.net.ssl.{SSLContext, SSLEngine, SSLSession, TrustManagerFactory}

import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers}
Expand Down Expand Up @@ -184,6 +184,7 @@ class PacketDecoder(@volatile var inSslNegotation: Boolean) extends FrameDecoder
*/
class PgClientChannelHandler(
sslEngineFactory: SslClientEngineFactory,
sessionVerifier: SslClientSessionVerifier,
sslConfig: Option[SslClientConfiguration],
val useSsl: Boolean
) extends SimpleChannelHandler {
Expand Down Expand Up @@ -211,29 +212,30 @@ class PgClientChannelHandler(

val pipeline = ctx.getPipeline

val addr = ctx.getChannel.getRemoteAddress
val inetAddr = addr match {
case i: InetSocketAddress => Some(i)
case _ => None
val (engine, verifier) = ctx.getChannel.getRemoteAddress match {
case i: InetSocketAddress =>
val address = Address(i)
val config = sslConfig.getOrElse(SslClientConfiguration(hostname = Some(i.getHostString)))
(sslEngineFactory(address, config).self, (s: SSLSession) => sessionVerifier(address, config, s))
case _ =>
(Ssl.client().self, (_: SSLSession) => true)
}

val engine = inetAddr.map(inet =>
sslConfig.map(sslEngineFactory(Address(inet), _)).getOrElse(Ssl.client(inet.getHostString, inet.getPort))
)
.getOrElse(Ssl.client())
.self

engine.setUseClientMode(true)

val sslHandler = new SslHandler(engine)
pipeline.addFirst("ssl", sslHandler)

val verifier: SSLSession => Boolean = inetAddr match {
case Some(inet) =>
session => HostnameVerifier(Address(inet), SslClientConfiguration(hostname = Some(inet.getHostName)),session)
case None =>
_ => true
}
sslHandler.handshake().addListener(new ChannelFutureListener {
override def operationComplete(f: ChannelFuture) = {
if (!Try(verifier(engine.getSession)).onFailure { err =>
logger.error(err, "Exception thrown during SSL session verification")
}.getOrElse(false)) {
logger.error("SSL session verification failed")
Channels.close(ctx.getChannel)
}
}
})

connection.receive(SwitchToSsl).foreach {
Channels.fireMessageReceived(ctx, _)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ object IntegrationSpec {
*
* If these are conditions are met, the integration tests will be run.
*
* The tests can be run with SSL by also setting the USE_PG_SSL variable to "1".
* The tests can be run with SSL by also setting the USE_PG_SSL variable to "1", and hostname verification can be added
* by setting PG_SSL_HOST.
*
*/
class IntegrationSpec extends Spec {
Expand All @@ -36,6 +37,7 @@ class IntegrationSpec extends Spec {
password = sys.env.get("PG_PASSWORD")
dbname <- sys.env.get("PG_DBNAME")
useSsl = sys.env.getOrElse("USE_PG_SSL", "0") == "1"
sslHost = sys.env.get("PG_SSL_HOST")
} yield {


Expand All @@ -46,7 +48,7 @@ class IntegrationSpec extends Spec {
.withCredentials(user, password)
.database(dbname)
.withSessionPool.maxSize(1)
.conditionally(useSsl, _.withTransport.tlsWithoutValidation)
.conditionally(useSsl, c => sslHost.fold(c.withTransport.tls)(c.withTransport.tls(_)))
.newRichClient(hostPort)

Await.result(Future[PostgresClientImpl] {
Expand All @@ -60,7 +62,7 @@ class IntegrationSpec extends Spec {
.withCredentials(user, password)
.database(dbname)
.withSessionPool.maxSize(1)
.conditionally(useSsl, _.withTransport.tlsWithoutValidation)
.conditionally(useSsl, c => sslHost.fold(c.withTransport.tls)(c.withTransport.tls(_)))
.newRichClient("badhost:5432")
}

Expand Down

0 comments on commit 0f143dd

Please sign in to comment.