Skip to content

Commit

Permalink
[ISSUE-426] add prefetch event to state machine (TuGraph-family#427)
Browse files Browse the repository at this point in the history
* [ISSUE-426] add prefetch event to state machine

* fix ut
  • Loading branch information
xincai98 authored Dec 16, 2024
1 parent a5978f8 commit 956f6ff
Show file tree
Hide file tree
Showing 7 changed files with 437 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ public enum ScheduleStateType implements Serializable {
*/
START,

/**
* Shuffle prefetch state.
*/
PREFETCH,

/**
* Shuffle finish prefetch state.
*/
FINISH_PREFETCH,

/**
* Init state.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@

package com.antgroup.geaflow.runtime.core.scheduler;

import com.antgroup.geaflow.cluster.protocol.EventType;
import com.antgroup.geaflow.cluster.protocol.IEvent;
import com.antgroup.geaflow.cluster.resourcemanager.WorkerInfo;
import com.antgroup.geaflow.common.exception.GeaflowRuntimeException;
import com.antgroup.geaflow.common.tuple.Tuple;
import com.antgroup.geaflow.core.graph.ExecutionTask;
import com.antgroup.geaflow.runtime.core.protocol.FinishPrefetchEvent;
import com.antgroup.geaflow.runtime.core.protocol.PrefetchEvent;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
Expand All @@ -40,29 +37,6 @@ public Map<WorkerInfo, List<ExecutableEvent>> getEvents() {
}

public void markReady() {
for (Map.Entry<WorkerInfo, List<ExecutableEvent>> entry : this.worker2events.entrySet()) {
List<ExecutableEvent> events = entry.getValue();
List<ExecutableEvent> finishPrefetchEvents = new ArrayList<>();
for (ExecutableEvent executableEvent : events) {
IEvent event = executableEvent.getEvent();
if (event.getEventType() == EventType.PREFETCH) {
PrefetchEvent prefetchEvent = (PrefetchEvent) event;
FinishPrefetchEvent finishPrefetchEvent = new FinishPrefetchEvent(
prefetchEvent.getSchedulerId(),
prefetchEvent.getWorkerId(),
prefetchEvent.getCycleId(),
prefetchEvent.getIterationWindowId(),
executableEvent.getTask().getTaskId(),
executableEvent.getTask().getIndex(),
prefetchEvent.getPipelineId(),
prefetchEvent.getEdgeIds());
ExecutableEvent finishExecutableEvent = ExecutableEvent.build(
executableEvent.getWorker(), executableEvent.getTask(), finishPrefetchEvent);
finishPrefetchEvents.add(finishExecutableEvent);
}
}
events.addAll(finishPrefetchEvents);
}
this.workerIterator = this.worker2events.entrySet().iterator();
this.ready = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import com.antgroup.geaflow.cluster.protocol.IEvent;
import com.antgroup.geaflow.cluster.protocol.ScheduleStateType;
import com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys;
import com.antgroup.geaflow.common.exception.GeaflowRuntimeException;
import com.antgroup.geaflow.core.graph.ExecutionTask;
import com.antgroup.geaflow.core.graph.util.ExecutionTaskUtils;
Expand All @@ -28,6 +27,7 @@
import com.antgroup.geaflow.runtime.core.protocol.ExecuteComputeEvent;
import com.antgroup.geaflow.runtime.core.protocol.ExecuteFirstIterationEvent;
import com.antgroup.geaflow.runtime.core.protocol.FinishIterationEvent;
import com.antgroup.geaflow.runtime.core.protocol.FinishPrefetchEvent;
import com.antgroup.geaflow.runtime.core.protocol.InitCollectCycleEvent;
import com.antgroup.geaflow.runtime.core.protocol.InitCycleEvent;
import com.antgroup.geaflow.runtime.core.protocol.InitIterationEvent;
Expand All @@ -39,6 +39,7 @@
import com.antgroup.geaflow.runtime.core.protocol.PrefetchEvent;
import com.antgroup.geaflow.runtime.core.protocol.RollbackCycleEvent;
import com.antgroup.geaflow.runtime.core.protocol.StashWorkerEvent;
import com.antgroup.geaflow.runtime.core.scheduler.ExecutableEventIterator.ExecutableEvent;
import com.antgroup.geaflow.runtime.core.scheduler.context.ICycleSchedulerContext;
import com.antgroup.geaflow.runtime.core.scheduler.cycle.CollectExecutionNodeCycle;
import com.antgroup.geaflow.runtime.core.scheduler.cycle.ExecutionCycleType;
Expand All @@ -51,7 +52,10 @@
import com.antgroup.geaflow.shuffle.desc.OutputType;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

public class SchedulerEventBuilder {
Expand All @@ -63,7 +67,6 @@ public class SchedulerEventBuilder {
private final CycleResultManager resultManager;
private final boolean enableAffinity;
private final boolean isIteration;
private final boolean prefetch;
private final long schedulerId;

public SchedulerEventBuilder(ICycleSchedulerContext<ExecutionNodeCycle, ExecutionGraphCycle, ?> context,
Expand All @@ -75,12 +78,13 @@ public SchedulerEventBuilder(ICycleSchedulerContext<ExecutionNodeCycle, Executio
this.enableAffinity = context.getParentContext() != null
&& context.getParentContext().getCycle().getIterationCount() > 1;
this.isIteration = cycle.getVertexGroup().getCycleGroupMeta().isIterative();
this.prefetch = context.getConfig().getBoolean(ExecutionConfigKeys.SHUFFLE_PREFETCH);
this.schedulerId = schedulerId;
}

public ExecutableEventIterator build(ScheduleStateType state, long iterationId) {
switch (state) {
case PREFETCH:
return this.buildPrefetch();
case INIT:
return this.buildInitPipeline();
case ITERATION_INIT:
Expand All @@ -89,6 +93,8 @@ public ExecutableEventIterator build(ScheduleStateType state, long iterationId)
return buildExecute(iterationId);
case ITERATION_FINISH:
return this.finishIteration();
case FINISH_PREFETCH:
return this.buildFinishPrefetch();
case CLEAN_CYCLE:
return this.finishPipeline();
case ROLLBACK:
Expand All @@ -99,15 +105,44 @@ public ExecutableEventIterator build(ScheduleStateType state, long iterationId)

}

private ExecutableEventIterator buildPrefetch() {
ExecutableEventIterator iterator = this.buildChildrenPrefetchEvent();
return iterator;
}

private ExecutableEventIterator buildFinishPrefetch() {
ExecutableEventIterator events = new ExecutableEventIterator();
Map<Integer, ExecutableEvent> needFinishedPrefetchEvents =
this.context.getPrefetchEvents();
Iterator<Entry<Integer, ExecutableEvent>> iterator = needFinishedPrefetchEvents.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<Integer, ExecutableEvent> entry = iterator.next();
ExecutableEvent executableEvent = entry.getValue();
IEvent event = executableEvent.getEvent();
PrefetchEvent prefetchEvent = (PrefetchEvent) event;
FinishPrefetchEvent finishPrefetchEvent = new FinishPrefetchEvent(
prefetchEvent.getSchedulerId(),
prefetchEvent.getWorkerId(),
prefetchEvent.getCycleId(),
prefetchEvent.getIterationWindowId(),
executableEvent.getTask().getTaskId(),
executableEvent.getTask().getIndex(),
prefetchEvent.getPipelineId(),
prefetchEvent.getEdgeIds());
ExecutableEvent finishExecutableEvent = ExecutableEvent.build(
executableEvent.getWorker(), executableEvent.getTask(), finishPrefetchEvent);
events.addEvent(finishExecutableEvent);
iterator.remove();
}
return events;
}

private ExecutableEventIterator buildInitPipeline() {
ExecutableEventIterator iterator = new ExecutableEventIterator();
if (this.prefetch && !this.cycle.isIterative()) {
ExecutableEventIterator prefetchEvents = this.buildChildrenPrefetchEvent();
iterator.merge(prefetchEvents);
}
for (ExecutionTask task : this.cycle.getTasks()) {
IoDescriptor ioDescriptor =
IoDescriptorBuilder.buildPipelineIoDescriptor(task, this.cycle, this.resultManager, this.prefetch);
IoDescriptorBuilder.buildPipelineIoDescriptor(task, this.cycle,
this.resultManager, this.context.isPrefetch());
iterator.addEvent(task.getWorkerInfo(), task, buildInitOrPopEvent(task, ioDescriptor));
}
return iterator;
Expand All @@ -122,9 +157,14 @@ private ExecutableEventIterator buildChildrenPrefetchEvent() {
if (childCycle instanceof ExecutionNodeCycle) {
ExecutionNodeCycle childNodeCycle = (ExecutionNodeCycle) childCycle;
List<ExecutionTask> childHeadTasks = childNodeCycle.getCycleHeads();
Map<Integer, ExecutableEvent> needFinishedPrefetchEvents =
this.context.getPrefetchEvents();
for (ExecutionTask childHeadTask : childHeadTasks) {
PrefetchEvent prefetchEvent = this.buildPrefetchEvent(childNodeCycle, childHeadTask);
iterator.addEvent(childHeadTask.getWorkerInfo(), childHeadTask, prefetchEvent);
ExecutableEvent executableEvent = ExecutableEvent.build(childHeadTask.getWorkerInfo(),
childHeadTask, prefetchEvent);
iterator.addEvent(executableEvent);
needFinishedPrefetchEvents.put(childHeadTask.getTaskId(), executableEvent);
}
}
}
Expand Down Expand Up @@ -286,10 +326,6 @@ private ExecutableEventIterator finishPipeline() {

private ExecutableEventIterator finishIteration() {
ExecutableEventIterator iterator = new ExecutableEventIterator();
if (this.prefetch) {
ExecutableEventIterator prefetchEvents = this.buildChildrenPrefetchEvent();
iterator.merge(prefetchEvents);
}
for (ExecutionTask task : this.cycle.getTasks()) {
int workerId = task.getWorkerInfo().getWorkerIndex();
// Finish iteration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@

import com.antgroup.geaflow.cluster.resourcemanager.WorkerInfo;
import com.antgroup.geaflow.common.config.Configuration;
import com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys;
import com.antgroup.geaflow.ha.runtime.HighAvailableLevel;
import com.antgroup.geaflow.pipeline.callback.ICallbackFunction;
import com.antgroup.geaflow.runtime.core.scheduler.ExecutableEventIterator.ExecutableEvent;
import com.antgroup.geaflow.runtime.core.scheduler.cycle.IExecutionCycle;
import com.antgroup.geaflow.runtime.core.scheduler.io.CycleResultManager;
import com.antgroup.geaflow.runtime.core.scheduler.resource.IScheduledWorkerManager;
import com.antgroup.geaflow.runtime.core.scheduler.resource.ScheduledWorkerManagerFactory;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
Expand Down Expand Up @@ -51,6 +55,8 @@ public abstract class AbstractCycleSchedulerContext<
protected transient CycleResultManager cycleResultManager;
protected ICallbackFunction callbackFunction;
protected static ThreadLocal<Boolean> rollback = ThreadLocal.withInitial(() -> false);
protected transient Map<Integer, ExecutableEvent> prefetchEvents;
protected transient boolean prefetch;

public AbstractCycleSchedulerContext(C cycle, PCC parentContext) {
this.cycle = cycle;
Expand Down Expand Up @@ -94,6 +100,8 @@ public void init(long startIterationId) {
} else {
this.cycleResultManager = new CycleResultManager();
}
prefetch = cycle.getConfig().getBoolean(ExecutionConfigKeys.SHUFFLE_PREFETCH);
prefetchEvents = new HashMap<>();

LOGGER.info("{} init cycle context onTheFlyThreshold {}, currentIterationId {}, "
+ "iterationCount {}, finishIterationId {}, initialIterationId {}",
Expand Down Expand Up @@ -127,14 +135,21 @@ public void setCurrentIterationId(long iterationId) {
this.currentIterationId = iterationId;
}

@Override
public boolean isRecovered() {
return false;
}

@Override
public boolean isRollback() {
return rollback.get();
}

@Override
public boolean isPrefetch() {
return prefetch;
}

public void setRollback(boolean bool) {
rollback.set(bool);
}
Expand Down Expand Up @@ -200,6 +215,10 @@ public void setCallbackFunction(ICallbackFunction callbackFunction) {
this.callbackFunction = callbackFunction;
}

public Map<Integer, ExecutableEvent> getPrefetchEvents() {
return this.prefetchEvents;
}

@Override
public void finish(long windowId) {
if (callbackFunction != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

import com.antgroup.geaflow.cluster.resourcemanager.WorkerInfo;
import com.antgroup.geaflow.common.config.Configuration;
import com.antgroup.geaflow.runtime.core.scheduler.ExecutableEventIterator.ExecutableEvent;
import com.antgroup.geaflow.runtime.core.scheduler.cycle.IExecutionCycle;
import com.antgroup.geaflow.runtime.core.scheduler.io.CycleResultManager;
import com.antgroup.geaflow.runtime.core.scheduler.resource.IScheduledWorkerManager;
import java.io.Serializable;
import java.util.List;
import java.util.Map;

public interface ICycleSchedulerContext<
C extends IExecutionCycle,
Expand Down Expand Up @@ -54,6 +56,11 @@ public interface ICycleSchedulerContext<
*/
boolean isRollback();

/**
* Returns whether enable prefetch.
*/
boolean isPrefetch();

/**
* Returns current iteration id.
*/
Expand Down Expand Up @@ -124,6 +131,11 @@ public interface ICycleSchedulerContext<
*/
IScheduledWorkerManager<C> getSchedulerWorkerManager();

/**
* Returns prefetch events needed to be finished.
*/
Map<Integer, ExecutableEvent> getPrefetchEvents();

enum SchedulerState {
/**
* Init state.
Expand Down
Loading

0 comments on commit 956f6ff

Please sign in to comment.