From 1507dd4b4045bd8e8cdf505d07041ce304788ca0 Mon Sep 17 00:00:00 2001 From: Tyler Ohlsen Date: Wed, 21 Feb 2024 15:59:34 -0800 Subject: [PATCH] Inject NamedWriteableRegistry in AD node client (#1164) Signed-off-by: Tyler Ohlsen --- .../ad/client/AnomalyDetectionNodeClient.java | 11 +++++-- .../transport/GetAnomalyDetectorResponse.java | 17 +++++++--- .../AnomalyDetectionNodeClientTests.java | 3 +- .../GetAnomalyDetectorResponseTests.java | 31 +++++++++++++++++++ 4 files changed, 54 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java index 60bb274ab..714ad353f 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java @@ -17,12 +17,15 @@ import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; public class AnomalyDetectionNodeClient implements AnomalyDetectionClient { private final Client client; + private final NamedWriteableRegistry namedWriteableRegistry; - public AnomalyDetectionNodeClient(Client client) { + public AnomalyDetectionNodeClient(Client client, NamedWriteableRegistry namedWriteableRegistry) { this.client = client; + this.namedWriteableRegistry = namedWriteableRegistry; } @Override @@ -46,6 +49,9 @@ public void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionL // We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic // ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins. + // Additionally, we need to inject the configured NamedWriteableRegistry so NamedWriteables (present in sub-fields of + // GetAnomalyDetectorResponse) are able to be re-serialized and prevent errors like the following: + // "can't read named writeable from StreamInput" private ActionListener getAnomalyDetectorResponseActionListener( ActionListener listener ) { @@ -53,7 +59,8 @@ private ActionListener getAnomalyDetectorResponseAct listener.onResponse(getAnomalyDetectorResponse); }, listener::onFailure); ActionListener actionListener = wrapActionListener(internalListener, actionResponse -> { - GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse.fromActionResponse(actionResponse); + GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse + .fromActionResponse(actionResponse, this.namedWriteableRegistry); return response; }); return actionListener; diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java index f3808dab2..5db241377 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java @@ -22,6 +22,8 @@ import org.opensearch.ad.model.EntityProfile; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -218,16 +220,21 @@ public AnomalyDetector getDetector() { return detector; } - public static GetAnomalyDetectorResponse fromActionResponse(ActionResponse actionResponse) { + public static GetAnomalyDetectorResponse fromActionResponse( + ActionResponse actionResponse, + NamedWriteableRegistry namedWriteableRegistry + ) { if (actionResponse instanceof GetAnomalyDetectorResponse) { return (GetAnomalyDetectorResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos); actionResponse.writeTo(osso); - try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new GetAnomalyDetectorResponse(input); - } + InputStreamStreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray())); + NamedWriteableAwareStreamInput namedWriteableAwareInput = new NamedWriteableAwareStreamInput(input, namedWriteableRegistry); + return new GetAnomalyDetectorResponse(namedWriteableAwareInput); } catch (IOException e) { throw new UncheckedIOException("failed to parse ActionResponse into GetAnomalyDetectorResponse", e); } diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java index c142e5e3d..614bf445a 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java @@ -40,6 +40,7 @@ import org.opensearch.client.Client; import org.opensearch.common.lucene.uid.Versions; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -64,7 +65,7 @@ public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTest @Before public void setup() { clientSpy = spy(client()); - adClient = new AnomalyDetectionNodeClient(clientSpy); + adClient = new AnomalyDetectionNodeClient(clientSpy, mock(NamedWriteableRegistry.class)); } @Test diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java index ace2c3c8c..236cd2b58 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java @@ -17,8 +17,11 @@ import java.util.Collection; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.rest.RestStatus; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; @@ -76,6 +79,21 @@ public void testSerializationWithJobAndTask() throws IOException { assertEquals(response.getDetector(), parsedResponse.getDetector()); } + public void testFromActionResponse() throws IOException { + GetAnomalyDetectorResponse response = createGetAnomalyDetectorResponse(true, true); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + + GetAnomalyDetectorResponse reserializedResponse = GetAnomalyDetectorResponse + .fromActionResponse((ActionResponse) response, writableRegistry()); + assertEquals(response, reserializedResponse); + + ActionResponse invalidActionResponse = new TestActionResponse(input); + assertThrows(Exception.class, () -> GetAnomalyDetectorResponse.fromActionResponse(invalidActionResponse, writableRegistry())); + + } + private GetAnomalyDetectorResponse createGetAnomalyDetectorResponse(boolean returnJob, boolean returnTask) throws IOException { GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( randomLong(), @@ -95,4 +113,17 @@ private GetAnomalyDetectorResponse createGetAnomalyDetectorResponse(boolean retu ); return response; } + + // A test ActionResponse class with an inactive writeTo class. Used to ensure exceptions + // are thrown when parsing implementations of such class. + private class TestActionResponse extends ActionResponse { + public TestActionResponse(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + return; + } + } }