Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in PageListener #1351

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ protected void doExecute(Task task, ResultBulkRequestType request, ActionListene
// all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure).
long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes();
float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits;
@SuppressWarnings("rawtypes")
List<? extends ResultWriteRequest> results = request.getResults();

if (results == null || results.size() < 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ class PageListener implements ActionListener<CompositeRetriever.Page> {
private String taskId;
private AtomicInteger receivedPages;
private AtomicInteger sentOutPages;
// By introducing pagesInFlight and incrementing it in the main thread before asynchronous processing begins,
// we ensure that the count of in-flight pages is accurate at all times. This allows us to reliably determine
// when all pages have been processed.
private AtomicInteger pagesInFlight;

PageListener(PageIterator pageIterator, Config config, long dataStartTime, long dataEndTime, String taskId) {
this.pageIterator = pageIterator;
Expand All @@ -220,14 +224,21 @@ class PageListener implements ActionListener<CompositeRetriever.Page> {
this.taskId = taskId;
this.receivedPages = new AtomicInteger();
this.sentOutPages = new AtomicInteger();
this.pagesInFlight = new AtomicInteger();
}

@Override
public void onResponse(CompositeRetriever.Page entityFeatures) {
// start processing next page after sending out features for previous page
if (pageIterator.hasNext()) {
pageIterator.next(this);
} else if (config.getImputationOption() != null) {
scheduleImputeHCTask();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we first going inside here and then incrementing the pages inFlight, shouldn't we first increment?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, changed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not 100% sure on this actually, then this pagesInFlight.get() == 0 will never be reached? I was just thinking of first case also if its 0 it might pass right away

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we finished processing all of the inflight requests, pagesInFlight.get() == 0, right?

}

// Increment pagesInFlight to track the processing of this page
pagesInFlight.incrementAndGet();

if (entityFeatures != null && false == entityFeatures.isEmpty()) {
LOG
.info(
Expand Down Expand Up @@ -309,19 +320,15 @@ public void onResponse(CompositeRetriever.Page entityFeatures) {
} catch (Exception e) {
LOG.error("Unexpected exception", e);
handleException(e);
} finally {
// Decrement pagesInFlight after processing is complete
pagesInFlight.decrementAndGet();
}
});
}

if (!pageIterator.hasNext() && config.getImputationOption() != null) {
if (sentOutPages.get() > 0) {
// at least 1 page sent out. Wait until all responses are back.
scheduleImputeHCTask();
} else {
// no data in current interval. Send out impute request right away.
imputeHC(dataStartTime, dataEndTime, configId, taskId);
}

} else {
// No entity features to process
// Decrement pagesInFlight immediately
pagesInFlight.decrementAndGet();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, here every page is from the feature aggregation results, we want to make sure we received every page and then sending it out to processing (means sending the aggregated feature data to the correct model and doing .process())? then after we check each entity if data was received and send impute call to place the imputed value?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

}
}

Expand Down Expand Up @@ -358,7 +365,10 @@ private void scheduleImputeHCTask() {

@Override
public void run() {
if (sentOutPages.get() == receivedPages.get()) {
// By using pagesInFlight in the condition within scheduleImputeHCTask, we ensure that imputeHC
// is executed only after all pages have been processed (pagesInFlight.get() == 0) and all
// responses have been received (sentOutPages.get() == receivedPages.get()).
if (pagesInFlight.get() == 0 && sentOutPages.get() == receivedPages.get()) {
if (!sent.get()) {
// since we don't know when cancel will succeed, need sent to ensure imputeHC is only called once
sent.set(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ protected String genDetector(
long windowDelayMinutes,
boolean hc,
ImputationMethod imputation,
long trainTimeMillis
long trainTimeMillis,
String name
) {
StringBuilder sb = new StringBuilder();
// common part
Expand Down
35 changes: 30 additions & 5 deletions src/test/java/org/opensearch/ad/e2e/MissingIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,27 @@ protected TrainResult createAndStartRealTimeDetector(
List<JsonObject> data,
ImputationMethod imputation,
boolean hc,
long trainTimeMillis
long trainTimeMillis,
String name
) throws Exception {
TrainResult trainResult = createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis);
TrainResult trainResult = createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, name);
List<JsonObject> result = startRealTimeDetector(trainResult, numberOfEntities, intervalMinutes, true);
recordLastSeenFromResult(result);

return trainResult;
}

protected TrainResult createAndStartRealTimeDetector(
int numberOfEntities,
int trainTestSplit,
List<JsonObject> data,
ImputationMethod imputation,
boolean hc,
long trainTimeMillis
) throws Exception {
return createAndStartRealTimeDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, "test");
}

protected TrainResult createAndStartHistoricalDetector(
int numberOfEntities,
int trainTestSplit,
Expand Down Expand Up @@ -115,12 +127,13 @@ protected TrainResult createDetector(
List<JsonObject> data,
ImputationMethod imputation,
boolean hc,
long trainTimeMillis
long trainTimeMillis,
String name
) throws Exception {
Instant trainTime = Instant.ofEpochMilli(trainTimeMillis);

Duration windowDelay = getWindowDelay(trainTimeMillis);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), hc, imputation, trainTimeMillis);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), hc, imputation, trainTimeMillis, name);

RestClient client = client();
String detectorId = createDetector(client, detector);
Expand All @@ -129,6 +142,17 @@ protected TrainResult createDetector(
return new TrainResult(detectorId, data, trainTestSplit * numberOfEntities, windowDelay, trainTime, "timestamp");
}

protected TrainResult createDetector(
int numberOfEntities,
int trainTestSplit,
List<JsonObject> data,
ImputationMethod imputation,
boolean hc,
long trainTimeMillis
) throws Exception {
return createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, "test");
}

protected Duration getWindowDelay(long trainTimeMillis) {
/*
* AD accepts windowDelay in the unit of minutes. Thus, we need to convert the delay in minutes. This will
Expand Down Expand Up @@ -156,7 +180,8 @@ protected abstract String genDetector(
long windowDelayMinutes,
boolean hc,
ImputationMethod imputation,
long trainTimeMillis
long trainTimeMillis,
String name
);

protected abstract AbstractSyntheticDataTest.GenData genData(
Expand Down
75 changes: 71 additions & 4 deletions src/test/java/org/opensearch/ad/e2e/MissingMultiFeatureIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,80 @@ public void testHCPrevious() throws Exception {
);
}

/**
* test we start two HC detector with zero imputation consecutively.
* We expect there is no out of order error from RCF.
* @throws Exception
*/
public void testDoubleHCZero() throws Exception {
lastSeen.clear();
int numberOfEntities = 2;

AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.NO_MISSING_DATA;
ImputationMethod method = ImputationMethod.ZERO;

AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode);

// only ingest train data to avoid validation error as we use latest data time as starting point.
// otherwise, we will have too many missing points.
ingestUniformSingleFeatureData(
trainTestSplit + numberOfEntities * 6, // we only need a few to verify and trigger train.
dataGenerated.data
);

TrainResult trainResult1 = createAndStartRealTimeDetector(
numberOfEntities,
trainTestSplit,
dataGenerated.data,
method,
true,
dataGenerated.testStartTime,
"test1"
);

TrainResult trainResult2 = createAndStartRealTimeDetector(
numberOfEntities,
trainTestSplit,
dataGenerated.data,
method,
true,
dataGenerated.testStartTime,
"test2"
);

runTest(
dataGenerated.testStartTime,
dataGenerated,
trainResult1.windowDelay,
trainResult1.detectorId,
numberOfEntities,
mode,
method,
3,
true
);

runTest(
dataGenerated.testStartTime,
dataGenerated,
trainResult2.windowDelay,
trainResult2.detectorId,
numberOfEntities,
mode,
method,
3,
true
);
}

@Override
protected String genDetector(
int trainTestSplit,
long windowDelayMinutes,
boolean hc,
ImputationMethod imputation,
long trainTimeMillis
long trainTimeMillis,
String name
) {
StringBuilder sb = new StringBuilder();

Expand Down Expand Up @@ -185,7 +252,7 @@ protected String genDetector(
// common part
sb
.append(
"{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\""
"{ \"name\": \"%s\", \"description\": \"test\", \"time_field\": \"timestamp\""
+ ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_id\": \"feature2\", \"feature_name\": \"feature 2\", \"feature_enabled\": "
+ "\"true\", \"aggregation_query\": { \"Feature2\": { \"avg\": { \"field\": \"data\" } } } },"
+ featureWithFilter
Expand Down Expand Up @@ -226,9 +293,9 @@ protected String genDetector(
sb.append("\"schema_version\": 0}");

if (hc) {
return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1, categoricalField);
return String.format(Locale.ROOT, sb.toString(), name, datasetName, intervalMinutes, trainTestSplit - 1, categoricalField);
} else {
return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1);
return String.format(Locale.ROOT, sb.toString(), name, datasetName, intervalMinutes, trainTestSplit - 1);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void testSingleStream() throws Exception {
);

Duration windowDelay = getWindowDelay(dataGenerated.testStartTime);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), false, method, dataGenerated.testStartTime);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), false, method, dataGenerated.testStartTime, "test");

Instant begin = Instant.ofEpochMilli(dataGenerated.data.get(0).get("timestamp").getAsLong());
Instant end = Instant.ofEpochMilli(dataGenerated.data.get(dataGenerated.data.size() - 1).get("timestamp").getAsLong());
Expand Down Expand Up @@ -63,7 +63,7 @@ public void testHC() throws Exception {
);

Duration windowDelay = getWindowDelay(dataGenerated.testStartTime);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), true, method, dataGenerated.testStartTime);
String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), true, method, dataGenerated.testStartTime, "test");

Instant begin = Instant.ofEpochMilli(dataGenerated.data.get(0).get("timestamp").getAsLong());
Instant end = Instant.ofEpochMilli(dataGenerated.data.get(dataGenerated.data.size() - 1).get("timestamp").getAsLong());
Expand Down
Loading