Skip to content

Commit

Permalink
Simplified the RunContinuityPlugin to no longer rely on simulation-ba…
Browse files Browse the repository at this point in the history
…sed recording of plans into the simulation state. Fixed a small bug in the simulation where plans were being accounted as active when they have no consumer. (#197)
  • Loading branch information
shawnhatch authored Feb 3, 2024
1 parent c1ccc3e commit f5a4bc5
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 185 deletions.
10 changes: 7 additions & 3 deletions gcm/src/main/java/gov/hhs/aspr/ms/gcm/nucleus/Simulation.java
Original file line number Diff line number Diff line change
Expand Up @@ -883,10 +883,12 @@ private void loadExistingPlans() {
planRec.planner = planQueueData.getPlanner();
planRec.time = planQueueData.getTime();

if (planRec.isActive) {
activePlanCount++;
}

if (planRec.plan.getCallbackConsumer() != null) {
if (planRec.isActive) {
activePlanCount++;
}

planningQueue.add(planRec);
Map<Object, PlanRec> map;
if (planRec.key != null) {
Expand Down Expand Up @@ -1041,10 +1043,12 @@ public void execute() {
// initialize the actors by flushing the actor queue
executeActorQueue();


loadExistingPlans();

planningQueueMode = PlanningQueueMode.RUNNING;


while (activePlanCount > 0) {
if (forcedHaltPresent) {
if (planningQueue.peek().time > simulationHaltTime) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,72 +1,59 @@
package gov.hhs.aspr.ms.gcm.nucleus.testsupport.runcontinuityplugin;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.IntStream;

import org.apache.commons.math3.util.Pair;

import gov.hhs.aspr.ms.gcm.nucleus.ActorContext;
import gov.hhs.aspr.ms.gcm.nucleus.Plan;
import util.wrappers.MutableInteger;

public class RunContinuityActor implements Consumer<ActorContext> {
private MutableInteger completionCount = new MutableInteger();
private ActorContext actorContext;
private final RunContinuityPluginData runContinuityPluginData;
private Map<Integer, Pair<Double, Consumer<ActorContext>>> planMap = new LinkedHashMap<>();

public RunContinuityActor(RunContinuityPluginData runContinuityPluginData) {
this.runContinuityPluginData = runContinuityPluginData;
}

public void accept(ActorContext actorContext) {
this.actorContext = actorContext;
actorContext.setPlanDataConverter(RunContinuityPlanData.class, this::getConsumerFromPlanData);

completionCount.setValue(runContinuityPluginData.getCompletionCount());

if (!runContinuityPluginData.plansAreScheduled()) {
List<Pair<Double, Consumer<ActorContext>>> consumers = runContinuityPluginData.getConsumers();
for (int i = 0; i < consumers.size(); i++) {
Pair<Double, Consumer<ActorContext>> pair = consumers.get(i);
double time = pair.getFirst();
Consumer<ActorContext> consumer = pair.getSecond();
List<Pair<Double, Consumer<ActorContext>>> consumers = runContinuityPluginData.getConsumers();
IntStream.range(0,consumers.size()).forEach((i)->{

Pair<Double, Consumer<ActorContext>> pair = consumers.get(i);
planMap.put(i, pair);
double time = pair.getFirst();
Consumer<ActorContext> consumer = pair.getSecond();

RunContinuityPlanData continuityPluginData = new RunContinuityPlanData(i);
RunContinuityPlanData continuityPluginData = new RunContinuityPlanData(i);

Plan<ActorContext> plan = Plan.builder(ActorContext.class)//
.setTime(time)//
.setCallbackConsumer((c) -> executePlan(consumer))//
.setPlanData(continuityPluginData)//
.build();
Plan<ActorContext> plan = Plan.builder(ActorContext.class)//
.setTime(time)//
.setCallbackConsumer((c) -> {
planMap.remove(i);
consumer.accept(actorContext);
})//
.setPlanData(continuityPluginData)//
.build();

actorContext.addPlan(plan);
}
}
actorContext.subscribeToSimulationClose(this::recordState);
}
actorContext.addPlan(plan);

private Consumer<ActorContext> getConsumerFromPlanData(RunContinuityPlanData runContinuityPlanData) {
Consumer<ActorContext> consumer = runContinuityPluginData.getConsumers().get(runContinuityPlanData.getId())
.getSecond();
return (c) -> executePlan(consumer);
}
});

private void executePlan(Consumer<ActorContext> consumer) {
completionCount.increment();
consumer.accept(actorContext);
actorContext.subscribeToSimulationClose(this::recordState);
}

private void recordState(ActorContext actorContext) {
RunContinuityPluginData.Builder builder = RunContinuityPluginData.builder();

builder.setCompletionCount(completionCount.getValue());
List<Pair<Double, Consumer<ActorContext>>> consumers = runContinuityPluginData.getConsumers();
for (Pair<Double, Consumer<ActorContext>> pair : consumers) {
RunContinuityPluginData.Builder builder = RunContinuityPluginData.builder();
for (Pair<Double, Consumer<ActorContext>> pair : planMap.values()) {
double time = pair.getFirst();
Consumer<ActorContext> consumer = pair.getSecond();
builder.addContextConsumer(time, consumer);
}
builder.setPlansAreScheduled(true);
}
actorContext.releaseOutput(builder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,14 @@ public class RunContinuityPluginData implements PluginData {

private static class Data {

private boolean plansAreScheduled;

private int completionCount;

private List<Pair<Double, Consumer<ActorContext>>> consumers = new ArrayList<>();

private boolean locked;

private Data() {
}

private Data(Data data) {
completionCount = data.completionCount;
private Data(Data data) {
consumers.addAll(data.consumers);
locked = data.locked;
}
Expand Down Expand Up @@ -80,15 +75,6 @@ public RunContinuityPluginData build() {
return new RunContinuityPluginData(data);
}

/**
* Sets the plan scheduling state. Defaults to false.
*/
public Builder setPlansAreScheduled(boolean plansAreScheduled) {
ensureDataMutability();
data.plansAreScheduled = plansAreScheduled;
return this;
}

/**
* Schedules a context consumer
*/
Expand All @@ -98,14 +84,6 @@ public Builder addContextConsumer(final double time, final Consumer<ActorContext
return this;
}

/**
* Sets the completion count
*/
public Builder setCompletionCount(final int completionCount) {
ensureDataMutability();
data.completionCount = completionCount;
return this;
}

private void validateData() {
// do nothing
Expand All @@ -119,13 +97,6 @@ private RunContinuityPluginData(Data data) {
this.data = data;
}

/**
* Returns the completion count
*/
public int getCompletionCount() {
return data.completionCount;
}

/**
* Returns the list scheduled consumers
*/
Expand All @@ -138,20 +109,12 @@ public PluginDataBuilder getCloneBuilder() {
return new Builder(data);
}

/**
* Returns true if plans have been scheduled from the consumers
*/
public boolean plansAreScheduled() {
return data.plansAreScheduled;
}

/**
* Returns true if the completion count is greater than or equal to the number
* of contained consumers.
*/
public boolean allPlansComplete() {
return data.completionCount >= data.consumers.size();

return data.consumers.isEmpty();
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gov.hhs.aspr.ms.gcm.nucleus.testsupport.runcontinuityplugin;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.LinkedHashSet;
import java.util.Set;
Expand Down Expand Up @@ -54,7 +55,7 @@ public void testAccept() {
.execute();

runContinuityPluginData = outputConsumer.getOutputItem(RunContinuityPluginData.class).get();
assertEquals(expectedOutput.size(), runContinuityPluginData.getCompletionCount());
assertTrue(runContinuityPluginData.allPlansComplete());


Set<Double> actualOutput = new LinkedHashSet<>(outputConsumer.getOutputItems(Double.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,6 @@ public void testBuilder() {
assertNotNull(RunContinuityPluginData.builder());
}

@Test
@UnitTestMethod(target = RunContinuityPluginData.class, name = "getCompletionCount", args = {})
public void testGetCompletionCount() {
for (int i = 0; i < 10; i++) {
RunContinuityPluginData runContinuityPluginData = RunContinuityPluginData.builder().setCompletionCount(i)
.build();
assertEquals(i, runContinuityPluginData.getCompletionCount());
}

}

@Test
@UnitTestMethod(target = RunContinuityPluginData.class, name = "getConsumers", args = {})
public void testGetConsumers() {
Expand Down Expand Up @@ -67,10 +56,9 @@ public void testGetCloneBuilder() {

for (int i = 0; i < 10; i++) {

RunContinuityPluginData.Builder builder = RunContinuityPluginData.builder()//
.setCompletionCount(randomGenerator.nextInt(3))//
.setPlansAreScheduled(randomGenerator.nextBoolean());//
for (int j = 0; j < 3; j++) {
RunContinuityPluginData.Builder builder = RunContinuityPluginData.builder();//

for (int j = 0; j < i; j++) {
builder.addContextConsumer(randomGenerator.nextDouble(), (c) -> {
});
}
Expand All @@ -79,93 +67,29 @@ public void testGetCloneBuilder() {
RunContinuityPluginData cloneRunContinuityPluginData = //
(RunContinuityPluginData) runContinuityPluginData.getCloneBuilder().build();

assertEquals(runContinuityPluginData.getCompletionCount(),
cloneRunContinuityPluginData.getCompletionCount());

assertEquals(runContinuityPluginData.getConsumers(), cloneRunContinuityPluginData.getConsumers());

}
}

@Test
@UnitTestMethod(target = RunContinuityPluginData.class, name = "plansAreScheduled", args = {})
public void testPlansAreScheduled() {
RunContinuityPluginData runContinuityPluginData = RunContinuityPluginData.builder().setPlansAreScheduled(true)
.build();
assertTrue(runContinuityPluginData.plansAreScheduled());

runContinuityPluginData = RunContinuityPluginData.builder().setPlansAreScheduled(false).build();
assertFalse(runContinuityPluginData.plansAreScheduled());

}

@Test
@UnitTestMethod(target = RunContinuityPluginData.class, name = "allPlansComplete", args = {})
public void testAllPlansComplete() {
assertTrue(RunContinuityPluginData.builder()//
.setCompletionCount(0)//
.setPlansAreScheduled(true)//
.build().allPlansComplete());

assertTrue(RunContinuityPluginData.builder()//
.setCompletionCount(3)//
.setPlansAreScheduled(true)//
.build().allPlansComplete());


assertFalse(RunContinuityPluginData.builder()//
.setCompletionCount(0)//
.setPlansAreScheduled(true)//
.addContextConsumer(1.0, (c) -> {
})//
.build().allPlansComplete());

assertTrue(RunContinuityPluginData.builder()//
.setCompletionCount(1)//
.setPlansAreScheduled(true)//
.addContextConsumer(1.0, (c) -> {
})//
.build().allPlansComplete());


assertFalse(RunContinuityPluginData.builder()//
.setCompletionCount(1)//
.setPlansAreScheduled(true)//
.addContextConsumer(1.0, (c) -> {
})//
.addContextConsumer(1.0, (c) -> {
})//
.build().allPlansComplete());

assertTrue(RunContinuityPluginData.builder()//
.setCompletionCount(2)//
.setPlansAreScheduled(true)//
.addContextConsumer(1.0, (c) -> {
})//
.addContextConsumer(1.0, (c) -> {
})//
.build().allPlansComplete());



}

@Test
@UnitTestMethod(target = RunContinuityPluginData.Builder.class, name = "build", args = {}, tags = {UnitTag.LOCAL_PROXY})
@UnitTestMethod(target = RunContinuityPluginData.Builder.class, name = "build", args = {}, tags = {
UnitTag.LOCAL_PROXY })
public void testBuild() {
//covered by other tests
}

@Test
@UnitTestMethod(target = RunContinuityPluginData.Builder.class, name = "setPlansAreScheduled", args = {
boolean.class })
public void testSetPlansAreScheduled() {
RunContinuityPluginData runContinuityPluginData = RunContinuityPluginData.builder().setPlansAreScheduled(true)
.build();
assertTrue(runContinuityPluginData.plansAreScheduled());

runContinuityPluginData = RunContinuityPluginData.builder().setPlansAreScheduled(false).build();
assertFalse(runContinuityPluginData.plansAreScheduled());
// covered by other tests
}

@Test
Expand All @@ -191,14 +115,4 @@ public void testAddContextConsumer() {
assertEquals(expectedPairs, runContinuityPluginData.getConsumers());
}

@Test
@UnitTestMethod(target = RunContinuityPluginData.Builder.class, name = "setCompletionCount", args = { int.class })
public void testSetCompletionCount() {
for (int i = 0; i < 10; i++) {
RunContinuityPluginData runContinuityPluginData = RunContinuityPluginData.builder().setCompletionCount(i)
.build();
assertEquals(i, runContinuityPluginData.getCompletionCount());
}
}

}
Loading

0 comments on commit f5a4bc5

Please sign in to comment.