Skip to content

Commit

Permalink
interim changes for integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhmaurya committed Dec 12, 2024
1 parent 7c7437d commit 4d69db1
Show file tree
Hide file tree
Showing 19 changed files with 208 additions and 70 deletions.
1 change: 1 addition & 0 deletions modules/arrow-flight-rpc/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
*/

apply plugin: 'opensearch.publish'
apply plugin: 'opensearch.internal-cluster-test'

opensearchplugin {
description 'Arrow flight based Stream implementation'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.arrow.flight;

import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.OpenSearchFlightClient;
import org.apache.arrow.flight.Result;
import org.opensearch.arrow.flight.bootstrap.FlightService;
import org.opensearch.arrow.flight.bootstrap.client.FlightClientManager;
import org.opensearch.plugins.Plugin;
import org.opensearch.test.OpenSearchIntegTestCase;

import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;

@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 10)
public class ArrowFlightServerIT extends OpenSearchIntegTestCase {

private FlightClientManager flightClientManager;

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singleton(FlightStreamPlugin.class);
}

@Override
public void setUp() throws Exception {
super.setUp();
ensureGreen();
FlightService flightService = internalCluster().getInstance(FlightService.class);
flightClientManager = flightService.getFlightClientManager();
}

@Override
public void tearDown() throws Exception {
super.tearDown();
}

public void testArrowFlightEndpoint() throws Exception {
Action pingAction = new Action("ping");
OpenSearchFlightClient flightClient = flightClientManager.getFlightClient(flightClientManager.getLocalNodeId());
assertNotNull(flightClient);
Iterator<Result> results = flightClient.doAction(pingAction);
flightClient.close();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;

// TODO - add comment
/** Client for Flight services. */
/**
* Clone of {@link FlightClient} to support setting SslContext directly. It can be discarded once
* FlightClient supports setting SslContext directly.
*/
public class OpenSearchFlightClient implements AutoCloseable {
private static final int PENDING_REQUESTS = 5;
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;

// TODO - add comment
/**
* Clone of {@link FlightServer} to support setting SslContext directly. It can be discarded once
* FlightServer.Builder supports setting SslContext directly.
*/
public class OpenSearchFlightServer implements AutoCloseable {
private static final Logger logger = LogManager.getLogger(OpenSearchFlightServer.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.opensearch.arrow.flight;

import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
Expand All @@ -21,6 +22,7 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.plugins.ClusterPlugin;
import org.opensearch.plugins.NetworkPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SecureTransportSettingsProvider;
Expand All @@ -31,6 +33,7 @@
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.Transport;
import org.opensearch.transport.TransportService;
import org.opensearch.watcher.ResourceWatcherService;

import java.util.Collection;
Expand All @@ -42,7 +45,7 @@
* BaseFlightStreamPlugin is a plugin that implements the StreamManagerPlugin interface.
* It provides the necessary components for handling flight streams in the OpenSearch cluster.
*/
public abstract class BaseFlightStreamPlugin extends Plugin implements StreamManagerPlugin, NetworkPlugin {
public abstract class BaseFlightStreamPlugin extends Plugin implements StreamManagerPlugin, NetworkPlugin, ClusterPlugin {

/**
* Constructor for BaseFlightStreamPlugin.
Expand Down Expand Up @@ -109,7 +112,7 @@ public abstract Map<String, Supplier<Transport>> getSecureTransports(
* Returns the StreamManager instance for managing flight streams.
*/
@Override
public abstract StreamManager getStreamManager();
public abstract Supplier<StreamManager> getStreamManager();

/**
* Returns a list of ExecutorBuilder instances for building thread pools used for FlightServer
Expand All @@ -123,4 +126,7 @@ public abstract Map<String, Supplier<Transport>> getSecureTransports(
*/
@Override
public abstract List<Setting<?>> getSettings();

@Override
public abstract void onNodeStarted(DiscoveryNode localNode);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.Setting;
Expand Down Expand Up @@ -88,8 +89,8 @@ public Map<String, Supplier<Transport>> getSecureTransports(
}

@Override
public StreamManager getStreamManager() {
return null;
public Supplier<StreamManager> getStreamManager() {
return () -> null;
}

@Override
Expand All @@ -101,6 +102,11 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
public List<Setting<?>> getSettings() {
return List.of();
}

@Override
public void onNodeStarted(DiscoveryNode localNode) {

}
};
}
}
Expand Down Expand Up @@ -188,7 +194,7 @@ public Map<String, Supplier<Transport>> getSecureTransports(
* Gets the StreamManager instance for managing flight streams.
*/
@Override
public StreamManager getStreamManager() {
public Supplier<StreamManager> getStreamManager() {
return delegate.getStreamManager();
}

Expand All @@ -208,4 +214,9 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
public List<Setting<?>> getSettings() {
return delegate.getSettings();
}

@Override
public void onNodeStarted(DiscoveryNode localNode) {
delegate.onNodeStarted(localNode);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.arrow.flight.bootstrap;

import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.OpenSearchFlightServer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
Expand All @@ -24,6 +25,8 @@
import org.opensearch.arrow.flight.core.BaseFlightProducer;
import org.opensearch.arrow.flight.core.FlightStreamManager;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodeRole;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.SetOnce;
import org.opensearch.common.lifecycle.AbstractLifecycleComponent;
Expand All @@ -33,9 +36,10 @@

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;

/**
* FlightService manages the Arrow Flight server and client for OpenSearch.
Expand All @@ -48,9 +52,11 @@ public class FlightService extends AbstractLifecycleComponent {

private static OpenSearchFlightServer server;
private static BufferAllocator allocator;
private static FlightStreamManager streamManager;
private static Supplier<StreamManager> streamManager;
private static FlightClientManager clientManager;
private final SetOnce<ThreadPool> threadPool = new SetOnce<>();
private final SetOnce<ClusterService> clusterService = new SetOnce<>();

private final SetOnce<SecureTransportSettingsProvider> secureTransportSettingsProvider = new SetOnce<>();
private SslContextProvider sslContextProvider;

Expand All @@ -76,14 +82,8 @@ public FlightService(Settings settings) {
* @param threadPool The ThreadPool instance.
*/
public void initialize(ClusterService clusterService, ThreadPool threadPool) {
this.clusterService.trySet(clusterService);
this.threadPool.trySet(Objects.requireNonNull(threadPool));
if (ServerConfig.isSslEnabled()) {
sslContextProvider = new DefaultSslContextProvider(secureTransportSettingsProvider::get);
} else {
sslContextProvider = new DisabledSslContextProvider();
}
clientManager = new FlightClientManager(() -> allocator, Objects.requireNonNull(clusterService), sslContextProvider);
streamManager = new FlightStreamManager(() -> allocator, clientManager);
}

/**
Expand All @@ -94,27 +94,10 @@ public void setSecureTransportSettingsProvider(SecureTransportSettingsProvider s
this.secureTransportSettingsProvider.trySet(secureTransportSettingsProvider);
}

/**
* Starts the FlightService by initializing and starting the Arrow Flight server.
*/

@Override
protected void doStart() {
try {
allocator = AccessController.doPrivileged(
(PrivilegedExceptionAction<BufferAllocator>) () -> new RootAllocator(Integer.MAX_VALUE)
);

FlightProducer producer = new BaseFlightProducer(clientManager, streamManager, allocator);
FlightServerBuilder builder = new FlightServerBuilder(threadPool.get(), () -> allocator, producer, sslContextProvider);
server = builder.build();
server.start();
logger.info("Arrow Flight server started successfully:{}", ServerConfig.getServerLocation().getUri().toString());
} catch (IOException e) {
logger.error("Failed to start Arrow Flight server", e);
throw new RuntimeException("Failed to start Arrow Flight server", e);
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
// everything is lazily started in onNodeStart() after TransportService is started
}

/**
Expand All @@ -123,10 +106,18 @@ protected void doStart() {
@Override
protected void doStop() {
try {
server.shutdown();
streamManager.close();
clientManager.close();
server.close();
if (server != null) {
server.shutdown();
}
if (streamManager != null && streamManager.get() != null) {
streamManager.get().close();
}
if (clientManager != null) {
clientManager.close();
}
if (server != null) {
server.close();
}
logger.info("Arrow Flight service closed successfully");
} catch (Exception e) {
logger.error("Error while closing Arrow Flight service", e);
Expand All @@ -151,11 +142,52 @@ public FlightClientManager getFlightClientManager() {
return clientManager;
}

public void onNodeStart(DiscoveryNode localNode) {
if (isDedicatedClusterManagerNode(localNode)) {
return;
}
try {
allocator = AccessController.doPrivileged(
(PrivilegedExceptionAction<BufferAllocator>) () -> new RootAllocator(Integer.MAX_VALUE)
);
} catch (Exception e) {
throw new RuntimeException("Failed to initialize Arrow Flight server", e);
}

if (ServerConfig.isSslEnabled()) {
sslContextProvider = new DefaultSslContextProvider(secureTransportSettingsProvider::get);
} else {
sslContextProvider = new DisabledSslContextProvider();
}
clientManager = new FlightClientManager(() -> allocator, Objects.requireNonNull(clusterService.get()), sslContextProvider);
FlightStreamManager flightStreamManager = new FlightStreamManager(() -> allocator, clientManager);
streamManager = () -> flightStreamManager;

try {
FlightProducer producer = new BaseFlightProducer(clientManager, flightStreamManager, allocator);
Location serverLocation = ServerConfig.getLocation(localNode.getAddress().getAddress(), Integer.parseInt(localNode.getAttributes().get("transport.stream.port")));
FlightServerBuilder builder = new FlightServerBuilder(threadPool.get(), () -> allocator, producer, sslContextProvider, serverLocation);
server = builder.build();
server.start();
logger.info("Arrow Flight server started successfully:{}", serverLocation);
} catch (IOException e) {
logger.error("Failed to start Arrow Flight server", e);
throw new RuntimeException("Failed to start Arrow Flight server", e);
}
}

private boolean isDedicatedClusterManagerNode(DiscoveryNode node) {
Set<DiscoveryNodeRole> nodeRoles = node.getRoles();
return nodeRoles.size() == 1 &&
(nodeRoles.contains(DiscoveryNodeRole.CLUSTER_MANAGER_ROLE) ||
nodeRoles.contains(DiscoveryNodeRole.MASTER_ROLE));
}

/**
* Retrieves the StreamManager used by the FlightService.
* @return The StreamManager instance.
*/
public StreamManager getStreamManager() {
public Supplier<StreamManager> getStreamManager() {
return streamManager;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.Setting;
Expand All @@ -30,6 +31,7 @@
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.Transport;
import org.opensearch.transport.TransportService;
import org.opensearch.watcher.ResourceWatcherService;

import java.util.Collection;
Expand Down Expand Up @@ -109,15 +111,21 @@ public Map<String, Supplier<Transport>> getSecureTransports(
SecureTransportSettingsProvider secureTransportSettingsProvider,
Tracer tracer
) {

flightService.setSecureTransportSettingsProvider(secureTransportSettingsProvider);
return Collections.emptyMap();
}

@Override
public void onNodeStarted(DiscoveryNode localNode) {
flightService.onNodeStart(localNode);
}

/**
* Gets the StreamManager instance for managing flight streams.
*/
@Override
public StreamManager getStreamManager() {
public Supplier<StreamManager> getStreamManager() {
return flightService.getStreamManager();
}

Expand Down
Loading

0 comments on commit 4d69db1

Please sign in to comment.