From 700b92ad00c3b6cfcde767b7b400d7334fe60c29 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:45:57 -0700 Subject: [PATCH] Add rule validation in AnomalyDetector constructor (#1341) (#1342) * Add rule validation in AnomalyDetector constructor This commit introduces rule validation within the AnomalyDetector constructor. Any validation errors are now propagated and displayed on the frontend to ensure immediate feedback. Testing: * Verified that validation errors are properly propagated and shown on the frontend. * Added UTs to cover the new validation logic. * address Amit's comments --------- (cherry picked from commit 9cdbceef5821c5c6fc7437018a082d614a77a689) Signed-off-by: Kaituo Li Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] --- build.gradle | 3 - .../org/opensearch/ad/ml/ADModelManager.java | 26 +- .../opensearch/ad/model/AnomalyDetector.java | 121 +++++++++ .../ml/MemoryAwareConcurrentHashmap.java | 47 ---- .../timeseries/model/ValidationIssueType.java | 3 +- .../transport/SuggestConfigParamRequest.java | 10 +- .../ad/model/AnomalyDetectorTests.java | 236 ++++++++++++++++++ .../SuggestConfigParamRequestTests.java | 140 +++++++++++ .../SuggestConfigParamResponseTests.java | 147 +++++++++++ 9 files changed, 654 insertions(+), 79 deletions(-) create mode 100644 src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamRequestTests.java create mode 100644 src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamResponseTests.java diff --git a/build.gradle b/build.gradle index 4addff425..ca10d416d 100644 --- a/build.gradle +++ b/build.gradle @@ -699,9 +699,6 @@ List jacocoExclusions = [ // TODO: add test coverage (kaituo) 'org.opensearch.forecast.*', - 'org.opensearch.timeseries.transport.SuggestConfigParamResponse', - 'org.opensearch.timeseries.transport.SuggestConfigParamRequest', - 'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap', 'org.opensearch.timeseries.transport.ResultBulkTransportAction', 'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler', 'org.opensearch.timeseries.transport.handler.ResultIndexingHandler', diff --git a/src/main/java/org/opensearch/ad/ml/ADModelManager.java b/src/main/java/org/opensearch/ad/ml/ADModelManager.java index 354b02557..a8f0febd9 100644 --- a/src/main/java/org/opensearch/ad/ml/ADModelManager.java +++ b/src/main/java/org/opensearch/ad/ml/ADModelManager.java @@ -15,7 +15,6 @@ import java.time.Duration; import java.time.Instant; import java.util.Arrays; -import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -42,7 +41,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; -import org.opensearch.timeseries.AnalysisModelSize; import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; @@ -52,7 +50,6 @@ import org.opensearch.timeseries.ml.ModelColdStart; import org.opensearch.timeseries.ml.ModelManager; import org.opensearch.timeseries.ml.ModelState; -import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.DateUtils; @@ -69,9 +66,7 @@ * A facade managing ML operations and models. */ public class ADModelManager extends - ModelManager - implements - AnalysisModelSize { + ModelManager { protected static final String ENTITY_SAMPLE = "sp"; protected static final String ENTITY_RCF = "rcf"; protected static final String ENTITY_THRESHOLD = "th"; @@ -594,25 +589,6 @@ public List getPreviewResults(Features features, AnomalyDete }).collect(Collectors.toList()); } - /** - * Get all RCF partition's size corresponding to a detector. Thresholding models' size is a constant since they are small in size (KB). - * @param detectorId detector id - * @return a map of model id to its memory size - */ - @Override - public Map getModelSize(String detectorId) { - Map res = new HashMap<>(); - res.putAll(forests.getModelSize(detectorId)); - thresholds - .entrySet() - .stream() - .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(detectorId)) - .forEach(entry -> { - res.put(entry.getKey(), (long) memoryTracker.getThresholdModelBytes()); - }); - return res; - } - /** * Get a RCF model's total updates. * @param modelId the RCF model's id diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java index 9b057d000..2572299b1 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java @@ -22,6 +22,7 @@ import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -109,6 +110,7 @@ public Integer getShingleSize(Integer customShingleSize) { @Deprecated public static final String DETECTION_DATE_RANGE_FIELD = "detection_date_range"; public static final String RULES_FIELD = "rules"; + private static final String SUPPRESSION_RULE_ISSUE_PREFIX = "Suppression Rule Error: "; protected String detectorType; @@ -229,6 +231,8 @@ public AnomalyDetector( issueType = ValidationIssueType.CATEGORY; } + validateRules(features, rules); + checkAndThrowValidationErrors(ValidationAspect.DETECTOR); this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name(); @@ -720,4 +724,121 @@ private static Boolean onlyParseBooleanValue(XContentParser parser) throws IOExc } return null; } + + /** + * Validates each condition in the list of rules against the list of features. + * Checks that: + * - The feature name exists in the list of features. + * - The related feature is enabled. + * - The value is not NaN and is positive. + * + * @param features The list of available features. Must not be null. + * @param rules The list of rules containing conditions to validate. Can be null. + */ + private void validateRules(List features, List rules) { + // Null check for rules + if (rules == null || rules.isEmpty()) { + return; // No suppression rules to validate; consider as valid + } + + // Null check for features + if (features == null || features.isEmpty()) { + // Cannot proceed with validation if features are null but rules are not null + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "Features are not defined while suppression rules are provided."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // Create a map of feature names to their enabled status for quick lookup + Map featureEnabledMap = new HashMap<>(); + for (Feature feature : features) { + if (feature != null && feature.getName() != null) { + featureEnabledMap.put(feature.getName(), feature.getEnabled()); + } + } + + // Iterate over each rule + for (Rule rule : rules) { + if (rule == null || rule.getConditions() == null) { + // Invalid rule or conditions list is null + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "A suppression rule or its conditions are not properly defined."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // Iterate over each condition in the rule + for (Condition condition : rule.getConditions()) { + if (condition == null) { + // Invalid condition + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "A condition within a suppression rule is not properly defined."; + this.issueType = ValidationIssueType.RULE; + return; + } + + String featureName = condition.getFeatureName(); + + // Check if the feature name is null + if (featureName == null) { + // Feature name is required + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + "A condition is missing the feature name."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // Check if the feature exists + if (!featureEnabledMap.containsKey(featureName)) { + // Feature does not exist + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + + "Feature \"" + + featureName + + "\" specified in a suppression rule does not exist."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // Check if the feature is enabled + if (!featureEnabledMap.get(featureName)) { + // Feature is not enabled + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + + "Feature \"" + + featureName + + "\" specified in a suppression rule is not enabled."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // other threshold types may not have value operand + ThresholdType thresholdType = condition.getThresholdType(); + if (thresholdType == ThresholdType.ACTUAL_OVER_EXPECTED_MARGIN + || thresholdType == ThresholdType.EXPECTED_OVER_ACTUAL_MARGIN + || thresholdType == ThresholdType.ACTUAL_OVER_EXPECTED_RATIO + || thresholdType == ThresholdType.EXPECTED_OVER_ACTUAL_RATIO) { + // Check if the value is not NaN + double value = condition.getValue(); + if (Double.isNaN(value)) { + // Value is NaN + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + + "The threshold value for feature \"" + + featureName + + "\" is not a valid number."; + this.issueType = ValidationIssueType.RULE; + return; + } + + // Check if the value is positive + if (value <= 0) { + // Value is not positive + this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + + "The threshold value for feature \"" + + featureName + + "\" must be a positive number."; + this.issueType = ValidationIssueType.RULE; + return; + } + } + } + } + + // All checks passed + } } diff --git a/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java b/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java index b477f454a..cc723b5f4 100644 --- a/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java +++ b/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java @@ -11,9 +11,6 @@ package org.opensearch.timeseries.ml; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import org.opensearch.timeseries.MemoryTracker; @@ -55,48 +52,4 @@ public ModelState put(String key, ModelState value) } return previousAssociatedState; } - - /** - * Gets all of a config's model sizes hosted on a node - * - * @param configId config Id - * @return a map of model id to its memory size - */ - public Map getModelSize(String configId) { - Map res = new HashMap<>(); - super.entrySet() - .stream() - .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId)) - .forEach(entry -> { - Optional modelOptional = entry.getValue().getModel(); - if (modelOptional.isPresent()) { - res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(modelOptional.get())); - } - }); - return res; - } - - /** - * Checks if a model exists for the given config. - * @param configId Config Id - * @return `true` if the model exists, `false` otherwise. - */ - public boolean doesModelExist(String configId) { - return super.entrySet() - .stream() - .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId)) - .anyMatch(n -> true); - } - - public boolean hostIfPossible(String modelId, ModelState toUpdate) { - return Optional - .ofNullable(toUpdate) - .filter(state -> state.getModel().isPresent()) - .filter(state -> memoryTracker.isHostingAllowed(modelId, state.getModel().get())) - .map(state -> { - super.put(modelId, toUpdate); - return true; - }) - .orElse(false); - } } diff --git a/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java b/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java index bd4a86cee..55d039eb4 100644 --- a/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java +++ b/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java @@ -38,7 +38,8 @@ public enum ValidationIssueType implements Name { SUBAGGREGATION(SearchTopForecastResultRequest.SUBAGGREGATIONS_FIELD), RECENCY_EMPHASIS(Config.RECENCY_EMPHASIS_FIELD), DESCRIPTION(Config.DESCRIPTION_FIELD), - HISTORY(Config.HISTORY_INTERVAL_FIELD); + HISTORY(Config.HISTORY_INTERVAL_FIELD), + RULE(AnomalyDetector.RULES_FIELD); private String name; diff --git a/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java index 3c7b9f45a..ee17f163c 100644 --- a/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java @@ -33,9 +33,9 @@ public class SuggestConfigParamRequest extends ActionRequest { public SuggestConfigParamRequest(StreamInput in) throws IOException { super(in); context = in.readEnum(AnalysisType.class); - if (context.isAD()) { + if (getContext().isAD()) { config = new AnomalyDetector(in); - } else if (context.isForecast()) { + } else if (getContext().isForecast()) { config = new Forecaster(in); } else { throw new UnsupportedOperationException("This method is not supported"); @@ -55,7 +55,7 @@ public SuggestConfigParamRequest(AnalysisType context, Config config, String par @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeEnum(context); + out.writeEnum(getContext()); config.writeTo(out); out.writeString(param); out.writeTimeValue(requestTimeout); @@ -77,4 +77,8 @@ public String getParam() { public TimeValue getRequestTimeout() { return requestTimeout; } + + public AnalysisType getContext() { + return context; + } } diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java index b10c1afa4..902edb949 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java @@ -18,6 +18,7 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.concurrent.TimeUnit; @@ -1047,4 +1048,239 @@ public void testNullFixedValue() throws IOException { assertEquals("Got: " + e.getMessage(), "Enabled features are present, but no default fill values are provided.", e.getMessage()); assertEquals("Got :" + e.getType(), ValidationIssueType.IMPUTATION, e.getType()); } + + /** + * Test that validation passes when rules are null. + */ + public void testValidateRulesWithNullRules() throws IOException { + AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(null).build(); + + // Should pass validation; no exception should be thrown + assertNotNull(detector); + } + + /** + * Test that validation fails when features are null but rules are provided. + */ + public void testValidateRulesWithNullFeatures() throws IOException { + List rules = Arrays.asList(createValidRule()); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(0).setFeatureAttributes(null).setRules(rules).build(); + fail("Expected ValidationException due to features being null while rules are provided"); + } catch (ValidationException e) { + assertEquals("Suppression Rule Error: Features are not defined while suppression rules are provided.", e.getMessage()); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when a rule is null. + */ + public void testValidateRulesWithNullRule() throws IOException { + List rules = Arrays.asList((Rule) null); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(rules).build(); + fail("Expected ValidationException due to null rule"); + } catch (ValidationException e) { + assertEquals("Suppression Rule Error: A suppression rule or its conditions are not properly defined.", e.getMessage()); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when a rule's conditions are null. + */ + public void testValidateRulesWithNullConditions() throws IOException { + Rule rule = new Rule(Action.IGNORE_ANOMALY, null); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(rules).build(); + fail("Expected ValidationException due to rule with null conditions"); + } catch (ValidationException e) { + assertEquals("Suppression Rule Error: A suppression rule or its conditions are not properly defined.", e.getMessage()); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when a condition is null. + */ + public void testValidateRulesWithNullCondition() throws IOException { + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList((Condition) null)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(rules).build(); + fail("Expected ValidationException due to null condition in rule"); + } catch (ValidationException e) { + assertEquals("Suppression Rule Error: A condition within a suppression rule is not properly defined.", e.getMessage()); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when a condition's featureName is null. + */ + public void testValidateRulesWithNullFeatureName() throws IOException { + Condition condition = new Condition( + null, // featureName is null + ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, + Operator.LTE, + 0.5 + ); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(rules).build(); + fail("Expected ValidationException due to condition with null feature name"); + } catch (ValidationException e) { + assertEquals("Suppression Rule Error: A condition is missing the feature name.", e.getMessage()); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when a condition's featureName does not exist in features. + */ + public void testValidateRulesWithNonexistentFeatureName() throws IOException { + Condition condition = new Condition( + "nonexistentFeature", // featureName not in features + ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, + Operator.LTE, + 0.5 + ); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(rules).build(); + fail("Expected ValidationException due to condition with nonexistent feature name"); + } catch (ValidationException e) { + assertEquals( + "Suppression Rule Error: Feature \"nonexistentFeature\" specified in a suppression rule does not exist.", + e.getMessage() + ); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when the feature in condition is disabled. + */ + public void testValidateRulesWithDisabledFeature() throws IOException { + String featureName = "testFeature"; + Feature disabledFeature = TestHelpers.randomFeature(featureName, "agg", false); + + Condition condition = new Condition(featureName, ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, Operator.LTE, 0.5); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setFeatureAttributes(Arrays.asList(disabledFeature)).setRules(rules).build(); + fail("Expected ValidationException due to condition with disabled feature"); + } catch (ValidationException e) { + assertEquals( + "Suppression Rule Error: Feature \"" + featureName + "\" specified in a suppression rule is not enabled.", + e.getMessage() + ); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when the value in condition is NaN for specific threshold types. + */ + public void testValidateRulesWithNaNValue() throws IOException { + String featureName = "testFeature"; + Feature enabledFeature = TestHelpers.randomFeature(featureName, "agg", true); + + Condition condition = new Condition( + featureName, + ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, + Operator.LTE, + Double.NaN // Value is NaN + ); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setFeatureAttributes(Arrays.asList(enabledFeature)).setRules(rules).build(); + fail("Expected ValidationException due to NaN value in condition"); + } catch (ValidationException e) { + assertEquals( + "Suppression Rule Error: The threshold value for feature \"" + featureName + "\" is not a valid number.", + e.getMessage() + ); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation fails when the value in condition is not positive for specific threshold types. + */ + public void testValidateRulesWithNonPositiveValue() throws IOException { + String featureName = "testFeature"; + Feature enabledFeature = TestHelpers.randomFeature(featureName, "agg", true); + + Condition condition = new Condition( + featureName, + ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, + Operator.LTE, + -0.5 // Value is negative + ); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + try { + TestHelpers.AnomalyDetectorBuilder.newInstance(1).setFeatureAttributes(Arrays.asList(enabledFeature)).setRules(rules).build(); + fail("Expected ValidationException due to non-positive value in condition"); + } catch (ValidationException e) { + assertEquals( + "Suppression Rule Error: The threshold value for feature \"" + featureName + "\" must be a positive number.", + e.getMessage() + ); + assertEquals(ValidationIssueType.RULE, e.getType()); + } + } + + /** + * Test that validation passes when the threshold type is not one of the specified types and value is NaN. + */ + public void testValidateRulesWithOtherThresholdTypeAndNaNValue() throws IOException { + String featureName = "testFeature"; + Feature enabledFeature = TestHelpers.randomFeature(featureName, "agg", true); + + Condition condition = new Condition( + featureName, + null, // ThresholdType is null or another type not specified + Operator.LTE, + Double.NaN // Value is NaN, but should not be checked + ); + Rule rule = new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + List rules = Arrays.asList(rule); + + AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder + .newInstance(1) + .setFeatureAttributes(Arrays.asList(enabledFeature)) + .setRules(rules) + .build(); + + // Should pass validation; no exception should be thrown + assertNotNull(detector); + } + + /** + * Helper method to create a valid rule for testing. + * + * @return A valid Rule instance + */ + private Rule createValidRule() { + String featureName = "testFeature"; + Condition condition = new Condition(featureName, ThresholdType.ACTUAL_OVER_EXPECTED_RATIO, Operator.LTE, 0.5); + return new Rule(Action.IGNORE_ANOMALY, Arrays.asList(condition)); + } } diff --git a/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamRequestTests.java b/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamRequestTests.java new file mode 100644 index 000000000..e3c772c38 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamRequestTests.java @@ -0,0 +1,140 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.TestHelpers; + +public class SuggestConfigParamRequestTests extends OpenSearchTestCase { + private NamedWriteableRegistry registry; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + List namedWriteables = new ArrayList<>(); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, BoolQueryBuilder.NAME, BoolQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RangeQueryBuilder.NAME, RangeQueryBuilder::new)); + namedWriteables + .add( + new NamedWriteableRegistry.Entry( + AggregationBuilder.class, + ValueCountAggregationBuilder.NAME, + ValueCountAggregationBuilder::new + ) + ); + registry = new NamedWriteableRegistry(namedWriteables); + } + + /** + * Test serialization and deserialization of SuggestConfigParamRequest with AD context. + */ + public void testSerializationDeserialization_ADContext() throws IOException { + // Create an AnomalyDetector instance + AnomalyDetector detector = createTestAnomalyDetector(); + + AnalysisType context = AnalysisType.AD; + String param = "test-param"; + TimeValue requestTimeout = TimeValue.timeValueSeconds(30); + + SuggestConfigParamRequest originalRequest = new SuggestConfigParamRequest(context, detector, param, requestTimeout); + + // Serialize the request + BytesStreamOutput out = new BytesStreamOutput(); + originalRequest.writeTo(out); + + // Deserialize the request + StreamInput in = out.bytes().streamInput(); + + StreamInput input = new NamedWriteableAwareStreamInput(in, registry); + + SuggestConfigParamRequest deserializedRequest = new SuggestConfigParamRequest(input); + + // Verify the deserialized object + assertEquals(context, deserializedRequest.getContext()); + assertTrue(deserializedRequest.getConfig() instanceof AnomalyDetector); + AnomalyDetector deserializedDetector = (AnomalyDetector) deserializedRequest.getConfig(); + assertEquals(detector, deserializedDetector); + assertEquals(param, deserializedRequest.getParam()); + assertEquals(requestTimeout, deserializedRequest.getRequestTimeout()); + } + + /** + * Test serialization and deserialization of SuggestConfigParamRequest with Forecast context. + */ + public void testSerializationDeserialization_ForecastContext() throws IOException { + // Create a Forecaster instance using TestHelpers.ForecasterBuilder + Forecaster forecaster = createTestForecaster(); + + AnalysisType context = AnalysisType.FORECAST; + String param = "test-param"; + TimeValue requestTimeout = TimeValue.timeValueSeconds(30); + + SuggestConfigParamRequest originalRequest = new SuggestConfigParamRequest(context, forecaster, param, requestTimeout); + + // Serialize the request + BytesStreamOutput out = new BytesStreamOutput(); + originalRequest.writeTo(out); + + // Deserialize the request + StreamInput in = out.bytes().streamInput(); + StreamInput input = new NamedWriteableAwareStreamInput(in, registry); + + SuggestConfigParamRequest deserializedRequest = new SuggestConfigParamRequest(input); + + // Verify the deserialized object + assertEquals(context, deserializedRequest.getContext()); + assertTrue(deserializedRequest.getConfig() instanceof Forecaster); + Forecaster deserializedForecaster = (Forecaster) deserializedRequest.getConfig(); + assertEquals(forecaster, deserializedForecaster); + assertEquals(param, deserializedRequest.getParam()); + assertEquals(requestTimeout, deserializedRequest.getRequestTimeout()); + } + + // Helper methods to create test instances of AnomalyDetector and Forecaster + + private AnomalyDetector createTestAnomalyDetector() { + // Use TestHelpers.AnomalyDetectorBuilder to create a test AnomalyDetector instance + try { + return TestHelpers.AnomalyDetectorBuilder.newInstance(1).build(); + } catch (IOException e) { + fail("Failed to create test AnomalyDetector: " + e.getMessage()); + return null; + } + } + + private Forecaster createTestForecaster() { + // Use TestHelpers.ForecasterBuilder to create a Forecaster instance + try { + return TestHelpers.ForecasterBuilder.newInstance().build(); + } catch (IOException e) { + fail("Failed to create test Forecaster: " + e.getMessage()); + return null; + } + } +} diff --git a/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamResponseTests.java b/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamResponseTests.java new file mode 100644 index 000000000..7e083731e --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/transport/SuggestConfigParamResponseTests.java @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.time.temporal.ChronoUnit; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Mergeable; + +public class SuggestConfigParamResponseTests extends OpenSearchTestCase { + + /** + * Test the serialization and deserialization of SuggestConfigParamResponse. + * This covers both the writeTo(StreamOutput out) method and the + * SuggestConfigParamResponse(StreamInput in) constructor. + */ + public void testSerializationDeserialization() throws IOException { + // Create an instance of SuggestConfigParamResponse + IntervalTimeConfiguration interval = new IntervalTimeConfiguration(10, ChronoUnit.MINUTES); + Integer horizon = 12; + Integer history = 24; + + SuggestConfigParamResponse originalResponse = new SuggestConfigParamResponse(interval, horizon, history); + + // Serialize it to a BytesStreamOutput + BytesStreamOutput out = new BytesStreamOutput(); + originalResponse.writeTo(out); + + // Deserialize it from the StreamInput + StreamInput in = out.bytes().streamInput(); + SuggestConfigParamResponse deserializedResponse = new SuggestConfigParamResponse(in); + + // Assert that the deserialized object matches the original + assertEquals(originalResponse.getInterval(), deserializedResponse.getInterval()); + assertEquals(originalResponse.getHorizon(), deserializedResponse.getHorizon()); + assertEquals(originalResponse.getHistory(), deserializedResponse.getHistory()); + } + + /** + * Test the toXContent(XContentBuilder builder) method. + * This ensures that the response is correctly converted to XContent. + */ + public void testToXContent() throws IOException { + IntervalTimeConfiguration interval = new IntervalTimeConfiguration(10, ChronoUnit.MINUTES); + Integer horizon = 12; + Integer history = 24; + + SuggestConfigParamResponse response = new SuggestConfigParamResponse(interval, horizon, history); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + response.toXContent(builder); + String jsonString = builder.toString(); + + // Expected JSON string contains interval, horizon, history + assertTrue("actual json: " + jsonString, jsonString.contains("\"interval\"")); + assertTrue("actual json: " + jsonString, jsonString.contains("\"interval\":10")); + assertTrue("actual json: " + jsonString, jsonString.contains("\"unit\":\"Minutes\"")); + assertTrue("actual json: " + jsonString, jsonString.contains("\"horizon\":12")); + assertTrue("actual json: " + jsonString, jsonString.contains("\"history\":24")); + } + + /** + * Test the merge(Mergeable other) method when it returns early due to: + * - other being null + * - this being equal to other + * - getClass() != other.getClass() + */ + public void testMerge_ReturnEarly() { + IntervalTimeConfiguration interval = new IntervalTimeConfiguration(10, ChronoUnit.MINUTES); + Integer horizon = 12; + Integer history = 24; + + SuggestConfigParamResponse response = new SuggestConfigParamResponse(interval, horizon, history); + + // Case when other == null + response.merge(null); + + // Response should remain unchanged + assertEquals(interval, response.getInterval()); + assertEquals(horizon, response.getHorizon()); + assertEquals(history, response.getHistory()); + + // Case when this == other + response.merge(response); + + // Response should remain unchanged + assertEquals(interval, response.getInterval()); + assertEquals(horizon, response.getHorizon()); + assertEquals(history, response.getHistory()); + + // Case when getClass() != other.getClass() + Mergeable other = new Mergeable() { + @Override + public void merge(Mergeable other) { + // No operation + } + }; + + response.merge(other); + + // Response should remain unchanged + assertEquals(interval, response.getInterval()); + assertEquals(horizon, response.getHorizon()); + assertEquals(history, response.getHistory()); + } + + /** + * Test the merge(Mergeable other) method when otherProfile.getHistory() != null. + * This ensures that the history field is correctly updated from the other object. + */ + public void testMerge_OtherHasHistory() { + IntervalTimeConfiguration interval = new IntervalTimeConfiguration(10, ChronoUnit.MINUTES); + Integer horizon = 12; + Integer history = null; // Initial history is null + + SuggestConfigParamResponse response = new SuggestConfigParamResponse(interval, horizon, history); + + Integer otherHistory = 30; + + SuggestConfigParamResponse otherResponse = new SuggestConfigParamResponse(null, null, otherHistory); + + // Before merge, response.history is null + assertNull(response.getHistory()); + + // Merge + response.merge(otherResponse); + + // After merge, response.history should be updated + assertEquals(otherHistory, response.getHistory()); + + // Interval and horizon should remain unchanged + assertEquals(interval, response.getInterval()); + assertEquals(horizon, response.getHorizon()); + } +}