From eb1289114c31a37c404f1a94bb8b878a6a211094 Mon Sep 17 00:00:00 2001 From: Julien Viet Date: Fri, 15 Sep 2023 21:56:50 +0200 Subject: [PATCH] SSLHelper cache with eviction --- .../io/vertx/core/net/impl/SSLHelper.java | 218 +++++++++++------- .../io/vertx/core/net/impl/SSLHelperTest.java | 21 +- 2 files changed, 147 insertions(+), 92 deletions(-) diff --git a/src/main/java/io/vertx/core/net/impl/SSLHelper.java b/src/main/java/io/vertx/core/net/impl/SSLHelper.java index 5b92258678c..75a64a5baa6 100755 --- a/src/main/java/io/vertx/core/net/impl/SSLHelper.java +++ b/src/main/java/io/vertx/core/net/impl/SSLHelper.java @@ -13,6 +13,7 @@ import io.netty.handler.ssl.OpenSsl; import io.vertx.core.Future; +import io.vertx.core.Promise; import io.vertx.core.VertxException; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.ClientAuth; @@ -25,7 +26,6 @@ import java.security.cert.CRL; import java.security.cert.CertificateFactory; import java.util.*; -import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -36,55 +36,6 @@ */ public class SSLHelper { - private final static class ConfigKey { - private final KeyCertOptions keyCertOptions; - private final TrustOptions trustOptions; - private final List crlValues; - public ConfigKey(SSLOptions options) { - this(options.getKeyCertOptions(), options.getTrustOptions(), options.getCrlValues()); - } - public ConfigKey(KeyCertOptions keyCertOptions, TrustOptions trustOptions, List crlValues) { - this.keyCertOptions = keyCertOptions; - this.trustOptions = trustOptions; - this.crlValues = crlValues != null ? new ArrayList<>(crlValues) : null; - } - - @Override - public boolean equals(Object obj) { - if (obj == this) { - return true; - } - if (obj instanceof ConfigKey) { - ConfigKey that = (ConfigKey) obj; - return Objects.equals(keyCertOptions, that.keyCertOptions) && Objects.equals(trustOptions, that.trustOptions) && Objects.equals(crlValues, that.crlValues); - } - return false; - } - - @Override - public int hashCode() { - int hashCode = Objects.hashCode(keyCertOptions); - hashCode = 31 * hashCode + Objects.hashCode(trustOptions); - hashCode = 31 * hashCode + Objects.hashCode(crlValues); - return hashCode; - } - } - - private final static class Config { - private final KeyManagerFactory keyManagerFactory; - private final TrustManagerFactory trustManagerFactory; - private final Function keyManagerFactoryMapper; - private final Function trustManagerMapper; - private final List crls; - public Config(KeyManagerFactory keyManagerFactory, TrustManagerFactory trustManagerFactory, Function keyManagerFactoryMapper, Function trustManagerMapper, List crls) { - this.keyManagerFactory = keyManagerFactory; - this.trustManagerFactory = trustManagerFactory; - this.keyManagerFactoryMapper = keyManagerFactoryMapper; - this.trustManagerMapper = trustManagerMapper; - this.crls = crls; - } - } - private static final Config NULL_CONFIG = new Config(null, null, null, null, null); static final EnumMap CLIENT_AUTH_MAPPING = new EnumMap<>(ClientAuth.class); @@ -138,18 +89,33 @@ public static SSLEngineOptions resolveEngineOptions(SSLEngineOptions engineOptio private final Supplier supplier; private final boolean useWorkerPool; - private final Map> configMap = new ConcurrentHashMap<>(); - private final Map> sslChannelProviderMap = new ConcurrentHashMap<>(); + private final Map> configMap; + private final Map> sslChannelProviderMap; - public SSLHelper(SSLEngineOptions sslEngineOptions) { + public SSLHelper(SSLEngineOptions sslEngineOptions, int cacheMaxSize) { + this.configMap = new LruCache<>(cacheMaxSize); + this.sslChannelProviderMap = new LruCache<>(cacheMaxSize); this.supplier = sslEngineOptions::sslContextFactory; this.useWorkerPool = sslEngineOptions.getUseWorkerThread(); } + public SSLHelper(SSLEngineOptions sslEngineOptions) { + this(sslEngineOptions, 256); + } + public Future resolveSslChannelProvider(SSLOptions options, String endpointIdentificationAlgorithm, boolean useSNI, ClientAuth clientAuth, List applicationProtocols, ContextInternal ctx) { - // return buildChannelProvider(options, ctx); - // Two level caching ... for now - return sslChannelProviderMap.computeIfAbsent(new ConfigKey(options), o -> buildChannelProvider(options, endpointIdentificationAlgorithm, useSNI, clientAuth, applicationProtocols, ctx)); + Promise promise; + ConfigKey k = new ConfigKey(options); + synchronized (this) { + Future v = sslChannelProviderMap.get(k); + if (v != null) { + return v; + } + promise = Promise.promise(); + sslChannelProviderMap.put(k, promise.future()); + } + buildChannelProvider(options, endpointIdentificationAlgorithm, useSNI, clientAuth, applicationProtocols, ctx).onComplete(promise); + return promise.future(); } /** @@ -199,38 +165,112 @@ private Future buildConfig(SSLOptions sslOptions, ContextInternal ctx) { if (sslOptions.getTrustOptions() == null && sslOptions.getKeyCertOptions() == null) { return Future.succeededFuture(NULL_CONFIG); } - return configMap.computeIfAbsent(new ConfigKey(sslOptions), o -> { - return ctx.executeBlockingInternal(() -> { - KeyManagerFactory keyManagerFactory = null; - Function keyManagerFactoryMapper = null; - TrustManagerFactory trustManagerFactory = null; - Function trustManagerMapper = null; - List crls = new ArrayList<>(); - if (sslOptions.getKeyCertOptions() != null) { - keyManagerFactory = sslOptions.getKeyCertOptions().getKeyManagerFactory(ctx.owner()); - keyManagerFactoryMapper = sslOptions.getKeyCertOptions().keyManagerFactoryMapper(ctx.owner()); - } - if (sslOptions.getTrustOptions() != null) { - trustManagerFactory = sslOptions.getTrustOptions().getTrustManagerFactory(ctx.owner()); - trustManagerMapper = sslOptions.getTrustOptions().trustManagerMapper(ctx.owner()); - } - List tmp = new ArrayList<>(); - if (sslOptions.getCrlPaths() != null) { - tmp.addAll(sslOptions.getCrlPaths() - .stream() - .map(path -> ctx.owner().resolveFile(path).getAbsolutePath()) - .map(ctx.owner().fileSystem()::readFileBlocking) - .collect(Collectors.toList())); - } - if (sslOptions.getCrlValues() != null) { - tmp.addAll(sslOptions.getCrlValues()); - } - CertificateFactory certificatefactory = CertificateFactory.getInstance("X.509"); - for (Buffer crlValue : tmp) { - crls.addAll(certificatefactory.generateCRLs(new ByteArrayInputStream(crlValue.getBytes()))); - } - return new Config(keyManagerFactory, trustManagerFactory, keyManagerFactoryMapper, trustManagerMapper, crls); - }); - }); + Promise promise = Promise.promise(); + ConfigKey k = new ConfigKey(sslOptions); + synchronized (this) { + Future fut = configMap.get(k); + if (fut != null) { + return fut; + } + configMap.put(k, promise.future()); + } + ctx.executeBlockingInternal(() -> { + KeyManagerFactory keyManagerFactory = null; + Function keyManagerFactoryMapper = null; + TrustManagerFactory trustManagerFactory = null; + Function trustManagerMapper = null; + List crls = new ArrayList<>(); + if (sslOptions.getKeyCertOptions() != null) { + keyManagerFactory = sslOptions.getKeyCertOptions().getKeyManagerFactory(ctx.owner()); + keyManagerFactoryMapper = sslOptions.getKeyCertOptions().keyManagerFactoryMapper(ctx.owner()); + } + if (sslOptions.getTrustOptions() != null) { + trustManagerFactory = sslOptions.getTrustOptions().getTrustManagerFactory(ctx.owner()); + trustManagerMapper = sslOptions.getTrustOptions().trustManagerMapper(ctx.owner()); + } + List tmp = new ArrayList<>(); + if (sslOptions.getCrlPaths() != null) { + tmp.addAll(sslOptions.getCrlPaths() + .stream() + .map(path -> ctx.owner().resolveFile(path).getAbsolutePath()) + .map(ctx.owner().fileSystem()::readFileBlocking) + .collect(Collectors.toList())); + } + if (sslOptions.getCrlValues() != null) { + tmp.addAll(sslOptions.getCrlValues()); + } + CertificateFactory certificatefactory = CertificateFactory.getInstance("X.509"); + for (Buffer crlValue : tmp) { + crls.addAll(certificatefactory.generateCRLs(new ByteArrayInputStream(crlValue.getBytes()))); + } + return new Config(keyManagerFactory, trustManagerFactory, keyManagerFactoryMapper, trustManagerMapper, crls); + }).onComplete(promise); + return promise.future(); + } + + private static class LruCache extends LinkedHashMap { + + private final int maxSize; + + public LruCache(int maxSize) { + if (maxSize < 1) { + throw new UnsupportedOperationException(); + } + this.maxSize = maxSize; + } + + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > maxSize; + } + } + + private final static class ConfigKey { + private final KeyCertOptions keyCertOptions; + private final TrustOptions trustOptions; + private final List crlValues; + public ConfigKey(SSLOptions options) { + this(options.getKeyCertOptions(), options.getTrustOptions(), options.getCrlValues()); + } + public ConfigKey(KeyCertOptions keyCertOptions, TrustOptions trustOptions, List crlValues) { + this.keyCertOptions = keyCertOptions; + this.trustOptions = trustOptions; + this.crlValues = crlValues != null ? new ArrayList<>(crlValues) : null; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj instanceof ConfigKey) { + ConfigKey that = (ConfigKey) obj; + return Objects.equals(keyCertOptions, that.keyCertOptions) && Objects.equals(trustOptions, that.trustOptions) && Objects.equals(crlValues, that.crlValues); + } + return false; + } + + @Override + public int hashCode() { + int hashCode = Objects.hashCode(keyCertOptions); + hashCode = 31 * hashCode + Objects.hashCode(trustOptions); + hashCode = 31 * hashCode + Objects.hashCode(crlValues); + return hashCode; + } + } + + private final static class Config { + private final KeyManagerFactory keyManagerFactory; + private final TrustManagerFactory trustManagerFactory; + private final Function keyManagerFactoryMapper; + private final Function trustManagerMapper; + private final List crls; + public Config(KeyManagerFactory keyManagerFactory, TrustManagerFactory trustManagerFactory, Function keyManagerFactoryMapper, Function trustManagerMapper, List crls) { + this.keyManagerFactory = keyManagerFactory; + this.trustManagerFactory = trustManagerFactory; + this.keyManagerFactoryMapper = keyManagerFactoryMapper; + this.trustManagerMapper = trustManagerMapper; + this.crls = crls; + } } } diff --git a/src/test/java/io/vertx/core/net/impl/SSLHelperTest.java b/src/test/java/io/vertx/core/net/impl/SSLHelperTest.java index 921d8481269..baaf95ea4bf 100755 --- a/src/test/java/io/vertx/core/net/impl/SSLHelperTest.java +++ b/src/test/java/io/vertx/core/net/impl/SSLHelperTest.java @@ -13,13 +13,12 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.handler.ssl.*; +import io.vertx.core.Future; import io.vertx.core.http.ClientAuth; import io.vertx.core.http.HttpServerOptions; import io.vertx.core.impl.ContextInternal; import io.vertx.core.json.JsonObject; -import io.vertx.core.net.OpenSSLEngineOptions; -import io.vertx.core.net.SSLEngineOptions; -import io.vertx.core.net.SSLOptions; +import io.vertx.core.net.*; import io.vertx.test.core.VertxTestBase; import io.vertx.test.tls.Cert; import io.vertx.test.tls.Trust; @@ -144,6 +143,22 @@ public void testPreserveEnabledSecureTransportProtocolOrder() throws Exception { assertEquals(new ArrayList<>(new HttpServerOptions(json).getEnabledSecureTransportProtocols()), expectedProtocols); } + @Test + public void testCache() throws Exception { + ContextInternal ctx = (ContextInternal) vertx.getOrCreateContext(); + SSLHelper helper = new SSLHelper(new JdkSSLEngineOptions(), 4); + SSLOptions options = new SSLOptions().setKeyCertOptions(Cert.SERVER_JKS.get()); + SslChannelProvider f1 = awaitFuture(helper.resolveSslChannelProvider(options, "", false, ClientAuth.NONE, null, ctx)); + SslChannelProvider f2 = awaitFuture(helper.resolveSslChannelProvider(options, "", false, ClientAuth.NONE, null, ctx)); + assertSame(f1, f2); + awaitFuture(helper.resolveSslChannelProvider(new SSLOptions().setKeyCertOptions(Cert.SERVER_PKCS12.get()), "", false, ClientAuth.NONE, null, ctx)); + awaitFuture(helper.resolveSslChannelProvider(new SSLOptions().setKeyCertOptions(Cert.SERVER_PEM.get()), "", false, ClientAuth.NONE, null, ctx)); + awaitFuture(helper.resolveSslChannelProvider(new SSLOptions().setKeyCertOptions(Cert.CLIENT_PEM.get()), "", false, ClientAuth.NONE, null, ctx)); + awaitFuture(helper.resolveSslChannelProvider(new SSLOptions().setKeyCertOptions(Cert.SNI_PEM.get()), "", false, ClientAuth.NONE, null, ctx)); + f2 = awaitFuture(helper.resolveSslChannelProvider(options, "", false, ClientAuth.NONE, null, ctx)); + assertNotSame(f1, f2); + } + @Test public void testDefaultVersions() { testTLSVersions(new SSLOptions(), engine -> {