Skip to content

Commit

Permalink
SSLHelper cache with eviction
Browse files Browse the repository at this point in the history
  • Loading branch information
vietj committed Sep 15, 2023
1 parent 6582b64 commit eb12891
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 92 deletions.
218 changes: 129 additions & 89 deletions src/main/java/io/vertx/core/net/impl/SSLHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import io.netty.handler.ssl.OpenSsl;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.VertxException;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.ClientAuth;
Expand All @@ -25,7 +26,6 @@
import java.security.cert.CRL;
import java.security.cert.CertificateFactory;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand All @@ -36,55 +36,6 @@
*/
public class SSLHelper {

private final static class ConfigKey {
private final KeyCertOptions keyCertOptions;
private final TrustOptions trustOptions;
private final List<Buffer> crlValues;
public ConfigKey(SSLOptions options) {
this(options.getKeyCertOptions(), options.getTrustOptions(), options.getCrlValues());
}
public ConfigKey(KeyCertOptions keyCertOptions, TrustOptions trustOptions, List<Buffer> crlValues) {
this.keyCertOptions = keyCertOptions;
this.trustOptions = trustOptions;
this.crlValues = crlValues != null ? new ArrayList<>(crlValues) : null;
}

@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (obj instanceof ConfigKey) {
ConfigKey that = (ConfigKey) obj;
return Objects.equals(keyCertOptions, that.keyCertOptions) && Objects.equals(trustOptions, that.trustOptions) && Objects.equals(crlValues, that.crlValues);
}
return false;
}

@Override
public int hashCode() {
int hashCode = Objects.hashCode(keyCertOptions);
hashCode = 31 * hashCode + Objects.hashCode(trustOptions);
hashCode = 31 * hashCode + Objects.hashCode(crlValues);
return hashCode;
}
}

private final static class Config {
private final KeyManagerFactory keyManagerFactory;
private final TrustManagerFactory trustManagerFactory;
private final Function<String, KeyManagerFactory> keyManagerFactoryMapper;
private final Function<String, TrustManager[]> trustManagerMapper;
private final List<CRL> crls;
public Config(KeyManagerFactory keyManagerFactory, TrustManagerFactory trustManagerFactory, Function<String, KeyManagerFactory> keyManagerFactoryMapper, Function<String, TrustManager[]> trustManagerMapper, List<CRL> crls) {
this.keyManagerFactory = keyManagerFactory;
this.trustManagerFactory = trustManagerFactory;
this.keyManagerFactoryMapper = keyManagerFactoryMapper;
this.trustManagerMapper = trustManagerMapper;
this.crls = crls;
}
}

private static final Config NULL_CONFIG = new Config(null, null, null, null, null);

static final EnumMap<ClientAuth, io.netty.handler.ssl.ClientAuth> CLIENT_AUTH_MAPPING = new EnumMap<>(ClientAuth.class);
Expand Down Expand Up @@ -138,18 +89,33 @@ public static SSLEngineOptions resolveEngineOptions(SSLEngineOptions engineOptio

private final Supplier<SslContextFactory> supplier;
private final boolean useWorkerPool;
private final Map<ConfigKey, Future<Config>> configMap = new ConcurrentHashMap<>();
private final Map<ConfigKey, Future<SslChannelProvider>> sslChannelProviderMap = new ConcurrentHashMap<>();
private final Map<ConfigKey, Future<Config>> configMap;
private final Map<ConfigKey, Future<SslChannelProvider>> sslChannelProviderMap;

public SSLHelper(SSLEngineOptions sslEngineOptions) {
public SSLHelper(SSLEngineOptions sslEngineOptions, int cacheMaxSize) {
this.configMap = new LruCache<>(cacheMaxSize);
this.sslChannelProviderMap = new LruCache<>(cacheMaxSize);
this.supplier = sslEngineOptions::sslContextFactory;
this.useWorkerPool = sslEngineOptions.getUseWorkerThread();
}

public SSLHelper(SSLEngineOptions sslEngineOptions) {
this(sslEngineOptions, 256);
}

public Future<SslChannelProvider> resolveSslChannelProvider(SSLOptions options, String endpointIdentificationAlgorithm, boolean useSNI, ClientAuth clientAuth, List<String> applicationProtocols, ContextInternal ctx) {
// return buildChannelProvider(options, ctx);
// Two level caching ... for now
return sslChannelProviderMap.computeIfAbsent(new ConfigKey(options), o -> buildChannelProvider(options, endpointIdentificationAlgorithm, useSNI, clientAuth, applicationProtocols, ctx));
Promise<SslChannelProvider> promise;
ConfigKey k = new ConfigKey(options);
synchronized (this) {
Future<SslChannelProvider> v = sslChannelProviderMap.get(k);
if (v != null) {
return v;
}
promise = Promise.promise();
sslChannelProviderMap.put(k, promise.future());
}
buildChannelProvider(options, endpointIdentificationAlgorithm, useSNI, clientAuth, applicationProtocols, ctx).onComplete(promise);
return promise.future();
}

/**
Expand Down Expand Up @@ -199,38 +165,112 @@ private Future<Config> buildConfig(SSLOptions sslOptions, ContextInternal ctx) {
if (sslOptions.getTrustOptions() == null && sslOptions.getKeyCertOptions() == null) {
return Future.succeededFuture(NULL_CONFIG);
}
return configMap.computeIfAbsent(new ConfigKey(sslOptions), o -> {
return ctx.executeBlockingInternal(() -> {
KeyManagerFactory keyManagerFactory = null;
Function<String, KeyManagerFactory> keyManagerFactoryMapper = null;
TrustManagerFactory trustManagerFactory = null;
Function<String, TrustManager[]> trustManagerMapper = null;
List<CRL> crls = new ArrayList<>();
if (sslOptions.getKeyCertOptions() != null) {
keyManagerFactory = sslOptions.getKeyCertOptions().getKeyManagerFactory(ctx.owner());
keyManagerFactoryMapper = sslOptions.getKeyCertOptions().keyManagerFactoryMapper(ctx.owner());
}
if (sslOptions.getTrustOptions() != null) {
trustManagerFactory = sslOptions.getTrustOptions().getTrustManagerFactory(ctx.owner());
trustManagerMapper = sslOptions.getTrustOptions().trustManagerMapper(ctx.owner());
}
List<Buffer> tmp = new ArrayList<>();
if (sslOptions.getCrlPaths() != null) {
tmp.addAll(sslOptions.getCrlPaths()
.stream()
.map(path -> ctx.owner().resolveFile(path).getAbsolutePath())
.map(ctx.owner().fileSystem()::readFileBlocking)
.collect(Collectors.toList()));
}
if (sslOptions.getCrlValues() != null) {
tmp.addAll(sslOptions.getCrlValues());
}
CertificateFactory certificatefactory = CertificateFactory.getInstance("X.509");
for (Buffer crlValue : tmp) {
crls.addAll(certificatefactory.generateCRLs(new ByteArrayInputStream(crlValue.getBytes())));
}
return new Config(keyManagerFactory, trustManagerFactory, keyManagerFactoryMapper, trustManagerMapper, crls);
});
});
Promise<Config> promise = Promise.promise();
ConfigKey k = new ConfigKey(sslOptions);
synchronized (this) {
Future<Config> fut = configMap.get(k);
if (fut != null) {
return fut;
}
configMap.put(k, promise.future());
}
ctx.executeBlockingInternal(() -> {
KeyManagerFactory keyManagerFactory = null;
Function<String, KeyManagerFactory> keyManagerFactoryMapper = null;
TrustManagerFactory trustManagerFactory = null;
Function<String, TrustManager[]> trustManagerMapper = null;
List<CRL> crls = new ArrayList<>();
if (sslOptions.getKeyCertOptions() != null) {
keyManagerFactory = sslOptions.getKeyCertOptions().getKeyManagerFactory(ctx.owner());
keyManagerFactoryMapper = sslOptions.getKeyCertOptions().keyManagerFactoryMapper(ctx.owner());
}
if (sslOptions.getTrustOptions() != null) {
trustManagerFactory = sslOptions.getTrustOptions().getTrustManagerFactory(ctx.owner());
trustManagerMapper = sslOptions.getTrustOptions().trustManagerMapper(ctx.owner());
}
List<Buffer> tmp = new ArrayList<>();
if (sslOptions.getCrlPaths() != null) {
tmp.addAll(sslOptions.getCrlPaths()
.stream()
.map(path -> ctx.owner().resolveFile(path).getAbsolutePath())
.map(ctx.owner().fileSystem()::readFileBlocking)
.collect(Collectors.toList()));
}
if (sslOptions.getCrlValues() != null) {
tmp.addAll(sslOptions.getCrlValues());
}
CertificateFactory certificatefactory = CertificateFactory.getInstance("X.509");
for (Buffer crlValue : tmp) {
crls.addAll(certificatefactory.generateCRLs(new ByteArrayInputStream(crlValue.getBytes())));
}
return new Config(keyManagerFactory, trustManagerFactory, keyManagerFactoryMapper, trustManagerMapper, crls);
}).onComplete(promise);
return promise.future();
}

private static class LruCache<K, V> extends LinkedHashMap<K, V> {

private final int maxSize;

public LruCache(int maxSize) {
if (maxSize < 1) {
throw new UnsupportedOperationException();
}
this.maxSize = maxSize;
}

@Override
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
return size() > maxSize;
}
}

private final static class ConfigKey {
private final KeyCertOptions keyCertOptions;
private final TrustOptions trustOptions;
private final List<Buffer> crlValues;
public ConfigKey(SSLOptions options) {
this(options.getKeyCertOptions(), options.getTrustOptions(), options.getCrlValues());
}
public ConfigKey(KeyCertOptions keyCertOptions, TrustOptions trustOptions, List<Buffer> crlValues) {
this.keyCertOptions = keyCertOptions;
this.trustOptions = trustOptions;
this.crlValues = crlValues != null ? new ArrayList<>(crlValues) : null;
}

@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (obj instanceof ConfigKey) {
ConfigKey that = (ConfigKey) obj;
return Objects.equals(keyCertOptions, that.keyCertOptions) && Objects.equals(trustOptions, that.trustOptions) && Objects.equals(crlValues, that.crlValues);
}
return false;
}

@Override
public int hashCode() {
int hashCode = Objects.hashCode(keyCertOptions);
hashCode = 31 * hashCode + Objects.hashCode(trustOptions);
hashCode = 31 * hashCode + Objects.hashCode(crlValues);
return hashCode;
}
}

private final static class Config {
private final KeyManagerFactory keyManagerFactory;
private final TrustManagerFactory trustManagerFactory;
private final Function<String, KeyManagerFactory> keyManagerFactoryMapper;
private final Function<String, TrustManager[]> trustManagerMapper;
private final List<CRL> crls;
public Config(KeyManagerFactory keyManagerFactory, TrustManagerFactory trustManagerFactory, Function<String, KeyManagerFactory> keyManagerFactoryMapper, Function<String, TrustManager[]> trustManagerMapper, List<CRL> crls) {
this.keyManagerFactory = keyManagerFactory;
this.trustManagerFactory = trustManagerFactory;
this.keyManagerFactoryMapper = keyManagerFactoryMapper;
this.trustManagerMapper = trustManagerMapper;
this.crls = crls;
}
}
}
21 changes: 18 additions & 3 deletions src/test/java/io/vertx/core/net/impl/SSLHelperTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@

import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.ssl.*;
import io.vertx.core.Future;
import io.vertx.core.http.ClientAuth;
import io.vertx.core.http.HttpServerOptions;
import io.vertx.core.impl.ContextInternal;
import io.vertx.core.json.JsonObject;
import io.vertx.core.net.OpenSSLEngineOptions;
import io.vertx.core.net.SSLEngineOptions;
import io.vertx.core.net.SSLOptions;
import io.vertx.core.net.*;
import io.vertx.test.core.VertxTestBase;
import io.vertx.test.tls.Cert;
import io.vertx.test.tls.Trust;
Expand Down Expand Up @@ -144,6 +143,22 @@ public void testPreserveEnabledSecureTransportProtocolOrder() throws Exception {
assertEquals(new ArrayList<>(new HttpServerOptions(json).getEnabledSecureTransportProtocols()), expectedProtocols);
}

@Test
public void testCache() throws Exception {
ContextInternal ctx = (ContextInternal) vertx.getOrCreateContext();
SSLHelper helper = new SSLHelper(new JdkSSLEngineOptions(), 4);
SSLOptions options = new SSLOptions().setKeyCertOptions(Cert.SERVER_JKS.get());
SslChannelProvider f1 = awaitFuture(helper.resolveSslChannelProvider(options, "", false, ClientAuth.NONE, null, ctx));
SslChannelProvider f2 = awaitFuture(helper.resolveSslChannelProvider(options, "", false, ClientAuth.NONE, null, ctx));
assertSame(f1, f2);
awaitFuture(helper.resolveSslChannelProvider(new SSLOptions().setKeyCertOptions(Cert.SERVER_PKCS12.get()), "", false, ClientAuth.NONE, null, ctx));
awaitFuture(helper.resolveSslChannelProvider(new SSLOptions().setKeyCertOptions(Cert.SERVER_PEM.get()), "", false, ClientAuth.NONE, null, ctx));
awaitFuture(helper.resolveSslChannelProvider(new SSLOptions().setKeyCertOptions(Cert.CLIENT_PEM.get()), "", false, ClientAuth.NONE, null, ctx));
awaitFuture(helper.resolveSslChannelProvider(new SSLOptions().setKeyCertOptions(Cert.SNI_PEM.get()), "", false, ClientAuth.NONE, null, ctx));
f2 = awaitFuture(helper.resolveSslChannelProvider(options, "", false, ClientAuth.NONE, null, ctx));
assertNotSame(f1, f2);
}

@Test
public void testDefaultVersions() {
testTLSVersions(new SSLOptions(), engine -> {
Expand Down

0 comments on commit eb12891

Please sign in to comment.