diff --git a/modules/arrow-flight-rpc/build.gradle b/modules/arrow-flight-rpc/build.gradle index 0d6b272608479..454a44a033814 100644 --- a/modules/arrow-flight-rpc/build.gradle +++ b/modules/arrow-flight-rpc/build.gradle @@ -10,6 +10,7 @@ */ apply plugin: 'opensearch.publish' +apply plugin: 'opensearch.internal-cluster-test' opensearchplugin { description 'Arrow flight based Stream implementation' diff --git a/modules/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java b/modules/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java new file mode 100644 index 0000000000000..7c4691d61cfcf --- /dev/null +++ b/modules/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight; + +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.OpenSearchFlightClient; +import org.apache.arrow.flight.Result; +import org.opensearch.arrow.flight.bootstrap.FlightService; +import org.opensearch.arrow.flight.bootstrap.client.FlightClientManager; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.OpenSearchIntegTestCase; + +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 10) +public class ArrowFlightServerIT extends OpenSearchIntegTestCase { + + private FlightClientManager flightClientManager; + + @Override + protected Collection> nodePlugins() { + return Collections.singleton(FlightStreamPlugin.class); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + ensureGreen(); + FlightService flightService = internalCluster().getInstance(FlightService.class); + flightClientManager = flightService.getFlightClientManager(); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + } + + public void testArrowFlightEndpoint() throws Exception { + Action pingAction = new Action("ping"); + OpenSearchFlightClient flightClient = flightClientManager.getFlightClient(flightClientManager.getLocalNodeId()); + assertNotNull(flightClient); + Iterator results = flightClient.doAction(pingAction); + flightClient.close(); + } + +} diff --git a/modules/arrow-flight-rpc/src/main/java/org/apache/arrow/flight/OpenSearchFlightClient.java b/modules/arrow-flight-rpc/src/main/java/org/apache/arrow/flight/OpenSearchFlightClient.java index 3d4ad93906a58..782bbbc84761d 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/apache/arrow/flight/OpenSearchFlightClient.java +++ b/modules/arrow-flight-rpc/src/main/java/org/apache/arrow/flight/OpenSearchFlightClient.java @@ -74,8 +74,10 @@ import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; -// TODO - add comment -/** Client for Flight services. */ +/** + * Clone of {@link FlightClient} to support setting SslContext directly. It can be discarded once + * FlightClient supports setting SslContext directly. + */ public class OpenSearchFlightClient implements AutoCloseable { private static final int PENDING_REQUESTS = 5; /** diff --git a/modules/arrow-flight-rpc/src/main/java/org/apache/arrow/flight/OpenSearchFlightServer.java b/modules/arrow-flight-rpc/src/main/java/org/apache/arrow/flight/OpenSearchFlightServer.java index 1a09e6684b946..38da9e4884f47 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/apache/arrow/flight/OpenSearchFlightServer.java +++ b/modules/arrow-flight-rpc/src/main/java/org/apache/arrow/flight/OpenSearchFlightServer.java @@ -52,7 +52,10 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; -// TODO - add comment +/** + * Clone of {@link FlightServer} to support setting SslContext directly. It can be discarded once + * FlightServer.Builder supports setting SslContext directly. + */ public class OpenSearchFlightServer implements AutoCloseable { private static final Logger logger = LogManager.getLogger(OpenSearchFlightServer.class); diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/BaseFlightStreamPlugin.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/BaseFlightStreamPlugin.java index 79d9e6bd2daee..de002e229d8d0 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/BaseFlightStreamPlugin.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/BaseFlightStreamPlugin.java @@ -8,6 +8,7 @@ package org.opensearch.arrow.flight; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.arrow.spi.StreamManager; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; @@ -21,6 +22,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.plugins.ClusterPlugin; import org.opensearch.plugins.NetworkPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SecureTransportSettingsProvider; @@ -31,6 +33,7 @@ import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportService; import org.opensearch.watcher.ResourceWatcherService; import java.util.Collection; @@ -42,7 +45,7 @@ * BaseFlightStreamPlugin is a plugin that implements the StreamManagerPlugin interface. * It provides the necessary components for handling flight streams in the OpenSearch cluster. */ -public abstract class BaseFlightStreamPlugin extends Plugin implements StreamManagerPlugin, NetworkPlugin { +public abstract class BaseFlightStreamPlugin extends Plugin implements StreamManagerPlugin, NetworkPlugin, ClusterPlugin { /** * Constructor for BaseFlightStreamPlugin. @@ -109,7 +112,7 @@ public abstract Map> getSecureTransports( * Returns the StreamManager instance for managing flight streams. */ @Override - public abstract StreamManager getStreamManager(); + public abstract Supplier getStreamManager(); /** * Returns a list of ExecutorBuilder instances for building thread pools used for FlightServer @@ -123,4 +126,7 @@ public abstract Map> getSecureTransports( */ @Override public abstract List> getSettings(); + + @Override + public abstract void onNodeStarted(DiscoveryNode localNode); } diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/FlightStreamPlugin.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/FlightStreamPlugin.java index ba8020469ff79..fa68dad7ef927 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/FlightStreamPlugin.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/FlightStreamPlugin.java @@ -12,6 +12,7 @@ import org.opensearch.arrow.spi.StreamManager; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.Setting; @@ -88,8 +89,8 @@ public Map> getSecureTransports( } @Override - public StreamManager getStreamManager() { - return null; + public Supplier getStreamManager() { + return () -> null; } @Override @@ -101,6 +102,11 @@ public List> getExecutorBuilders(Settings settings) { public List> getSettings() { return List.of(); } + + @Override + public void onNodeStarted(DiscoveryNode localNode) { + + } }; } } @@ -188,7 +194,7 @@ public Map> getSecureTransports( * Gets the StreamManager instance for managing flight streams. */ @Override - public StreamManager getStreamManager() { + public Supplier getStreamManager() { return delegate.getStreamManager(); } @@ -208,4 +214,9 @@ public List> getExecutorBuilders(Settings settings) { public List> getSettings() { return delegate.getSettings(); } + + @Override + public void onNodeStarted(DiscoveryNode localNode) { + delegate.onNodeStarted(localNode); + } } diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java index fcde449cdfda5..0e0fb8f2752e0 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java @@ -9,6 +9,7 @@ package org.opensearch.arrow.flight.bootstrap; import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Location; import org.apache.arrow.flight.OpenSearchFlightServer; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -24,6 +25,8 @@ import org.opensearch.arrow.flight.core.BaseFlightProducer; import org.opensearch.arrow.flight.core.FlightStreamManager; import org.opensearch.arrow.spi.StreamManager; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.SetOnce; import org.opensearch.common.lifecycle.AbstractLifecycleComponent; @@ -33,9 +36,10 @@ import java.io.IOException; import java.security.AccessController; -import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.Objects; +import java.util.Set; +import java.util.function.Supplier; /** * FlightService manages the Arrow Flight server and client for OpenSearch. @@ -48,9 +52,11 @@ public class FlightService extends AbstractLifecycleComponent { private static OpenSearchFlightServer server; private static BufferAllocator allocator; - private static FlightStreamManager streamManager; + private static Supplier streamManager; private static FlightClientManager clientManager; private final SetOnce threadPool = new SetOnce<>(); + private final SetOnce clusterService = new SetOnce<>(); + private final SetOnce secureTransportSettingsProvider = new SetOnce<>(); private SslContextProvider sslContextProvider; @@ -76,14 +82,8 @@ public FlightService(Settings settings) { * @param threadPool The ThreadPool instance. */ public void initialize(ClusterService clusterService, ThreadPool threadPool) { + this.clusterService.trySet(clusterService); this.threadPool.trySet(Objects.requireNonNull(threadPool)); - if (ServerConfig.isSslEnabled()) { - sslContextProvider = new DefaultSslContextProvider(secureTransportSettingsProvider::get); - } else { - sslContextProvider = new DisabledSslContextProvider(); - } - clientManager = new FlightClientManager(() -> allocator, Objects.requireNonNull(clusterService), sslContextProvider); - streamManager = new FlightStreamManager(() -> allocator, clientManager); } /** @@ -94,27 +94,10 @@ public void setSecureTransportSettingsProvider(SecureTransportSettingsProvider s this.secureTransportSettingsProvider.trySet(secureTransportSettingsProvider); } - /** - * Starts the FlightService by initializing and starting the Arrow Flight server. - */ + @Override protected void doStart() { - try { - allocator = AccessController.doPrivileged( - (PrivilegedExceptionAction) () -> new RootAllocator(Integer.MAX_VALUE) - ); - - FlightProducer producer = new BaseFlightProducer(clientManager, streamManager, allocator); - FlightServerBuilder builder = new FlightServerBuilder(threadPool.get(), () -> allocator, producer, sslContextProvider); - server = builder.build(); - server.start(); - logger.info("Arrow Flight server started successfully:{}", ServerConfig.getServerLocation().getUri().toString()); - } catch (IOException e) { - logger.error("Failed to start Arrow Flight server", e); - throw new RuntimeException("Failed to start Arrow Flight server", e); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } + // everything is lazily started in onNodeStart() after TransportService is started } /** @@ -123,10 +106,18 @@ protected void doStart() { @Override protected void doStop() { try { - server.shutdown(); - streamManager.close(); - clientManager.close(); - server.close(); + if (server != null) { + server.shutdown(); + } + if (streamManager != null && streamManager.get() != null) { + streamManager.get().close(); + } + if (clientManager != null) { + clientManager.close(); + } + if (server != null) { + server.close(); + } logger.info("Arrow Flight service closed successfully"); } catch (Exception e) { logger.error("Error while closing Arrow Flight service", e); @@ -151,11 +142,52 @@ public FlightClientManager getFlightClientManager() { return clientManager; } + public void onNodeStart(DiscoveryNode localNode) { + if (isDedicatedClusterManagerNode(localNode)) { + return; + } + try { + allocator = AccessController.doPrivileged( + (PrivilegedExceptionAction) () -> new RootAllocator(Integer.MAX_VALUE) + ); + } catch (Exception e) { + throw new RuntimeException("Failed to initialize Arrow Flight server", e); + } + + if (ServerConfig.isSslEnabled()) { + sslContextProvider = new DefaultSslContextProvider(secureTransportSettingsProvider::get); + } else { + sslContextProvider = new DisabledSslContextProvider(); + } + clientManager = new FlightClientManager(() -> allocator, Objects.requireNonNull(clusterService.get()), sslContextProvider); + FlightStreamManager flightStreamManager = new FlightStreamManager(() -> allocator, clientManager); + streamManager = () -> flightStreamManager; + + try { + FlightProducer producer = new BaseFlightProducer(clientManager, flightStreamManager, allocator); + Location serverLocation = ServerConfig.getLocation(localNode.getAddress().getAddress(), Integer.parseInt(localNode.getAttributes().get("transport.stream.port"))); + FlightServerBuilder builder = new FlightServerBuilder(threadPool.get(), () -> allocator, producer, sslContextProvider, serverLocation); + server = builder.build(); + server.start(); + logger.info("Arrow Flight server started successfully:{}", serverLocation); + } catch (IOException e) { + logger.error("Failed to start Arrow Flight server", e); + throw new RuntimeException("Failed to start Arrow Flight server", e); + } + } + + private boolean isDedicatedClusterManagerNode(DiscoveryNode node) { + Set nodeRoles = node.getRoles(); + return nodeRoles.size() == 1 && + (nodeRoles.contains(DiscoveryNodeRole.CLUSTER_MANAGER_ROLE) || + nodeRoles.contains(DiscoveryNodeRole.MASTER_ROLE)); + } + /** * Retrieves the StreamManager used by the FlightService. * @return The StreamManager instance. */ - public StreamManager getStreamManager() { + public Supplier getStreamManager() { return streamManager; } diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPluginImpl.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPluginImpl.java index f86e3cf0c5000..b7e030872fa73 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPluginImpl.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPluginImpl.java @@ -13,6 +13,7 @@ import org.opensearch.arrow.spi.StreamManager; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.Setting; @@ -30,6 +31,7 @@ import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportService; import org.opensearch.watcher.ResourceWatcherService; import java.util.Collection; @@ -109,15 +111,21 @@ public Map> getSecureTransports( SecureTransportSettingsProvider secureTransportSettingsProvider, Tracer tracer ) { + flightService.setSecureTransportSettingsProvider(secureTransportSettingsProvider); return Collections.emptyMap(); } + @Override + public void onNodeStarted(DiscoveryNode localNode) { + flightService.onNodeStart(localNode); + } + /** * Gets the StreamManager instance for managing flight streams. */ @Override - public StreamManager getStreamManager() { + public Supplier getStreamManager() { return flightService.getStreamManager(); } diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/client/FlightClientManager.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/client/FlightClientManager.java index 58b27189cba75..53f23d6a599f2 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/client/FlightClientManager.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/client/FlightClientManager.java @@ -20,6 +20,7 @@ import org.opensearch.cluster.ClusterStateListener; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.FeatureFlags; import java.util.Map; import java.util.Objects; @@ -27,6 +28,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Supplier; +import static org.opensearch.common.util.FeatureFlags.ARROW_STREAMS_SETTING; + /** * Manages Flight client connections to OpenSearch nodes in a cluster. * This class maintains a pool of Flight clients for internode communication, @@ -118,15 +121,13 @@ private FlightClientHolder buildFlightClient(String nodeId) { if (node.getVersion().before(minVersion)) { return null; } - - String arrowStreamsEnabled = node.getAttributes().get("arrow.streams.enabled"); - if (!"true".equals(arrowStreamsEnabled)) { + if (!FeatureFlags.isEnabled(ARROW_STREAMS_SETTING)) { return null; } String clientPort = node.getAttributes().get("transport.stream.port"); FlightClientBuilder builder = new FlightClientBuilder( - node.getHostAddress(), + node.getAddress().getAddress(), Integer.parseInt(clientPort), allocator.get(), sslContextProvider diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/server/FlightServerBuilder.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/server/FlightServerBuilder.java index f8fde465febdc..3a957e995e0ae 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/server/FlightServerBuilder.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/server/FlightServerBuilder.java @@ -31,6 +31,7 @@ public class FlightServerBuilder { private final Supplier allocator; private final FlightProducer producer; private final SslContextProvider sslContextProvider; + private final Location location; /** * Creates a new FlightServerBuilder instance with the specified configurations. @@ -44,12 +45,14 @@ public FlightServerBuilder( ThreadPool threadPool, Supplier allocator, FlightProducer producer, - SslContextProvider sslContextProvider + SslContextProvider sslContextProvider, + Location location ) { this.threadPool = threadPool; this.allocator = allocator; this.producer = producer; this.sslContextProvider = sslContextProvider; + this.location = location; } /** @@ -57,7 +60,6 @@ public FlightServerBuilder( * @return A configured OpenSearchFlightServer instance */ public OpenSearchFlightServer build() throws IOException { - final Location location = ServerConfig.getServerLocation(); ExecutorService executorService = threadPool.executor(FLIGHT_THREAD_POOL_NAME); OpenSearchFlightServer.Builder builder = OpenSearchFlightServer.builder(allocator.get(), location, producer); builder.executor(executorService); diff --git a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/server/ServerConfig.java b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/server/ServerConfig.java index 21a352ee83155..462236b4394c6 100644 --- a/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/server/ServerConfig.java +++ b/modules/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/server/ServerConfig.java @@ -30,7 +30,7 @@ public class ServerConfig { */ public ServerConfig() {} - static final Setting STREAM_PORT = Setting.intSetting( + public static final Setting STREAM_PORT = Setting.intSetting( "node.attr.transport.stream.port", 9880, 1024, @@ -91,7 +91,7 @@ public ServerConfig() {} static final String FLIGHT_THREAD_POOL_NAME = "flight-server"; private static final String host = "localhost"; - private static int port; + public static int port; private static boolean enableSsl; private static ScalingExecutorBuilder executorBuilder; @@ -166,7 +166,7 @@ public static List> getSettings() { }; } - private static Location getLocation(String address, int port) { + public static Location getLocation(String address, int port) { if (enableSsl) { return Location.forGrpcTls(address, port); } @@ -176,7 +176,7 @@ private static Location getLocation(String address, int port) { private static class Netty4Configs { public static final Setting NETTY_ALLOCATOR_NUM_DIRECT_ARENAS = Setting.intSetting( "io.netty.allocator.numDirectArenas", - 1, // TODO - 2 * the number of available processors + 1, // TODO - 2 * the number of available processors; to be confirmed and set after running benchmarks 1, Setting.Property.NodeScope ); diff --git a/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java b/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java index 24757e712a1bb..2d538740bd204 100644 --- a/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java +++ b/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.Collection; import java.util.List; +import java.util.function.Supplier; import static org.opensearch.common.util.FeatureFlags.ARROW_STREAMS_SETTING; import static org.mockito.Mockito.mock; @@ -69,7 +70,7 @@ public void testPluginEnableAndDisable() throws IOException { assertNotNull(executorBuilders); assertFalse(executorBuilders.isEmpty()); - StreamManager streamManager = plugin.getStreamManager(); + Supplier streamManager = plugin.getStreamManager(); assertNotNull(streamManager); List> settings = plugin.getSettings(); diff --git a/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/server/FlightServerBuilderTests.java b/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/server/FlightServerBuilderTests.java index d9e7a276d30d4..6215ba54923f5 100644 --- a/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/server/FlightServerBuilderTests.java +++ b/modules/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/server/FlightServerBuilderTests.java @@ -43,7 +43,7 @@ public void tearDown() throws Exception { } public void testBuilderConstructorWithValidInputs() throws IOException { - FlightServerBuilder newBuilder = new FlightServerBuilder(threadPool, () -> allocator, producer, mock(SslContextProvider.class)); + FlightServerBuilder newBuilder = new FlightServerBuilder(threadPool, () -> allocator, producer, mock(SslContextProvider.class), null); assertNotNull(newBuilder); assertNotNull(newBuilder.build()); } @@ -51,14 +51,14 @@ public void testBuilderConstructorWithValidInputs() throws IOException { public void testBuilderConstructorWithNullThreadPool() { expectThrows( NullPointerException.class, - () -> (new FlightServerBuilder(null, () -> allocator, producer, mock(SslContextProvider.class))).build() + () -> (new FlightServerBuilder(null, () -> allocator, producer, mock(SslContextProvider.class), null)).build() ); } public void testBuilderConstructorWithNullAllocator() { expectThrows( NullPointerException.class, - () -> (new FlightServerBuilder(threadPool, null, producer, mock(SslContextProvider.class))).build() + () -> (new FlightServerBuilder(threadPool, null, producer, mock(SslContextProvider.class), null)).build() ); } @@ -66,7 +66,7 @@ public void testBuilderConstructorWithSslNull() { SslContextProvider sslContextProvider = mock(SslContextProvider.class); when(sslContextProvider.isSslEnabled()).thenReturn(true); when(sslContextProvider.getServerSslContext()).thenReturn(null); - FlightServerBuilder newBuilder = new FlightServerBuilder(threadPool, () -> allocator, producer, sslContextProvider); + FlightServerBuilder newBuilder = new FlightServerBuilder(threadPool, () -> allocator, producer, sslContextProvider, null); assertNotNull(newBuilder); expectThrows(NullPointerException.class, newBuilder::build); } diff --git a/server/src/main/java/org/opensearch/arrow/spi/StreamManagerWrapper.java b/server/src/main/java/org/opensearch/arrow/spi/StreamManagerWrapper.java index 4f308d641181e..033dfb0af5468 100644 --- a/server/src/main/java/org/opensearch/arrow/spi/StreamManagerWrapper.java +++ b/server/src/main/java/org/opensearch/arrow/spi/StreamManagerWrapper.java @@ -20,16 +20,17 @@ import org.opensearch.tasks.TaskManager; import java.io.IOException; +import java.util.function.Supplier; /** * Wraps a StreamManager to make it work with the TaskManager. */ public class StreamManagerWrapper implements StreamManager { - private final StreamManager streamManager; + private final Supplier streamManager; private final TaskManager taskManager; - public StreamManagerWrapper(StreamManager streamManager, TaskManager taskManager) { + public StreamManagerWrapper(Supplier streamManager, TaskManager taskManager) { super(); this.streamManager = streamManager; this.taskManager = taskManager; @@ -38,24 +39,24 @@ public StreamManagerWrapper(StreamManager streamManager, TaskManager taskManager @Override public StreamTicket registerStream(StreamProducer producer, TaskId parentTaskId) { StreamProducerTaskWrapper wrappedProducer = new StreamProducerTaskWrapper(producer, taskManager, parentTaskId); - StreamTicket ticket = streamManager.registerStream(wrappedProducer, parentTaskId); + StreamTicket ticket = streamManager.get().registerStream(wrappedProducer, parentTaskId); wrappedProducer.setDescription(ticket.toString()); return ticket; } @Override public StreamReader getStreamReader(StreamTicket ticket) { - return streamManager.getStreamReader(ticket); + return streamManager.get().getStreamReader(ticket); } @Override public StreamTicketFactory getStreamTicketFactory() { - return streamManager.getStreamTicketFactory(); + return streamManager.get().getStreamTicketFactory(); } @Override public void close() throws Exception { - streamManager.close(); + streamManager.get().close(); } static class StreamProducerTaskWrapper implements StreamProducer { diff --git a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java index 4f0462f0b5cdd..4be45aed70023 100644 --- a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java +++ b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java @@ -129,7 +129,7 @@ public class FeatureFlags { ); public static final String ARROW_STREAMS = "opensearch.experimental.feature.arrow.streams.enabled"; - public static final Setting ARROW_STREAMS_SETTING = Setting.boolSetting(ARROW_STREAMS, true, Property.NodeScope); + public static final Setting ARROW_STREAMS_SETTING = Setting.boolSetting(ARROW_STREAMS, false, Property.NodeScope); private static final List> ALL_FEATURE_FLAG_SETTINGS = List.of( REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING, diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 3a04f51517334..16b2c75b419db 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -311,6 +311,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.function.Supplier; import java.util.function.UnaryOperator; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -1392,9 +1393,10 @@ protected Node( ); } if (!streamManagerPlugins.isEmpty()) { - if (streamManagerPlugins.get(0).getStreamManager() != null) { + Supplier baseStreamManager = streamManagerPlugins.get(0).getStreamManager(); + if (baseStreamManager != null) { streamManager = new StreamManagerWrapper( - streamManagerPlugins.get(0).getStreamManager(), + baseStreamManager, transportService.getTaskManager() ); logger.info("StreamManager initialized"); @@ -1802,6 +1804,7 @@ public void onTimeout(TimeValue timeout) { writePortsFile("http", http.boundAddress()); } + logger.info("started"); pluginsService.filterPlugins(ClusterPlugin.class).forEach(plugin -> plugin.onNodeStarted(clusterService.localNode())); diff --git a/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java b/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java index 62a3f2327acd7..aefde6b4842d0 100644 --- a/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java +++ b/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java @@ -10,6 +10,8 @@ import org.opensearch.arrow.spi.StreamManager; +import java.util.function.Supplier; + /** * An interface for OpenSearch plugins to implement to provide a StreamManager. * This interface is used by the Arrow Flight plugin to get the StreamManager instance. @@ -22,5 +24,5 @@ public interface StreamManagerPlugin { * * @return The StreamManager instance */ - StreamManager getStreamManager(); + Supplier getStreamManager(); } diff --git a/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java b/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java index fa5fb736f518f..68d9530e84972 100644 --- a/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java +++ b/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java @@ -165,6 +165,7 @@ import static org.opensearch.test.NodeRoles.onlyRoles; import static org.opensearch.test.NodeRoles.removeRoles; import static org.opensearch.test.OpenSearchTestCase.assertBusy; +import static org.opensearch.test.OpenSearchTestCase.getBaseStreamPort; import static org.opensearch.test.OpenSearchTestCase.randomBoolean; import static org.opensearch.test.OpenSearchTestCase.randomFrom; import static org.hamcrest.Matchers.equalTo; @@ -237,6 +238,8 @@ public final class InternalTestCluster extends TestCluster { static final int DEFAULT_MIN_NUM_CLIENT_NODES = 0; static final int DEFAULT_MAX_NUM_CLIENT_NODES = 1; + private static final AtomicInteger FLIGHT_PORT_COUNTER = new AtomicInteger(0); + /* Sorted map to make traverse order reproducible. * The map of nodes is never mutated so individual reads are safe without synchronization. * Updates are intended to follow a copy-on-write approach. */ @@ -755,7 +758,7 @@ private Settings getNodeSettings( final Settings.Builder updatedSettings = Settings.builder(); updatedSettings.put(Environment.PATH_HOME_SETTING.getKey(), baseDir); - + updatedSettings.put("node.attr.transport.stream.port", getBaseStreamPort() + FLIGHT_PORT_COUNTER.getAndIncrement()); if (numDataPaths > 1) { updatedSettings.putList( Environment.PATH_DATA_SETTING.getKey(), diff --git a/test/framework/src/main/java/org/opensearch/test/OpenSearchTestCase.java b/test/framework/src/main/java/org/opensearch/test/OpenSearchTestCase.java index b180187303a60..c6d215b443545 100644 --- a/test/framework/src/main/java/org/opensearch/test/OpenSearchTestCase.java +++ b/test/framework/src/main/java/org/opensearch/test/OpenSearchTestCase.java @@ -1768,7 +1768,7 @@ public static String getPortRange() { return getBasePort() + "-" + (getBasePort() + 99); // upper bound is inclusive } - protected static int getBasePort() { + private static int generateBasePort(int start) { // some tests use MockTransportService to do network based testing. Yet, we run tests in multiple JVMs that means // concurrent tests could claim port that another JVM just released and if that test tries to simulate a disconnect it might // be smart enough to re-connect depending on what is tested. To reduce the risk, since this is very hard to debug we use @@ -1792,7 +1792,15 @@ protected static int getBasePort() { startAt = (int) Math.floorMod(workerId - 1, 223L) + 1; } assert startAt >= 0 : "Unexpected test worker Id, resulting port range would be negative"; - return 10300 + (startAt * 100); + return start + (startAt * 100); + } + + protected static int getBaseStreamPort() { + return generateBasePort(9880); + } + + protected static int getBasePort() { + return generateBasePort(10300); } protected static InetAddress randomIp(boolean v4) {