Lift SyncFlowWorkflow into activity entirely (#2371)
serprex authored Dec 24, 2024
1 parent c9fba56 commit f388a3f
Showing 9 changed files with 136 additions and 294 deletions.
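This commit lifts the loop that SyncFlowWorkflow used to drive into the SyncFlow activity itself. The session-keyed CdcCache map and its RWMutex are gone: maintainPull (no longer the exported MaintainPull activity; UnmaintainPull is removed entirely) now returns a CdcState directly, with the normalize loop and the replication keep-alive running as goroutines under an errgroup while the activity drives SyncRecords/SyncPg in the foreground until cancellation, a sync error, or the configured NumberOfSyncs. Below is a minimal, self-contained sketch of that coordination pattern using illustrative names and fake batch IDs; it is not the PeerDB code.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

func main() {
	ctx := context.Background()
	syncDone := make(chan struct{}) // closed by the foreground loop when syncing ends
	normalize := make(chan int, 4)  // stands in for chan NormalizeBatchRequest

	group, groupCtx := errgroup.WithContext(ctx)

	group.Go(func() error {
		// stand-in for normalizeLoop: consume batch IDs until told to stop;
		// it never returns an error, so it cannot cancel the sync side
		for {
			select {
			case batchID := <-normalize:
				fmt.Println("normalizing batch", batchID)
			case <-syncDone:
				return nil
			case <-groupCtx.Done():
				return nil
			}
		}
	})

	group.Go(func() error {
		// stand-in for maintainReplConn: periodic keep-alive; returning an error
		// here cancels groupCtx and ends the foreground loop
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				// the real code calls srcConn.ReplPing and records an activity heartbeat
			case <-syncDone:
				return nil
			case <-groupCtx.Done():
				return nil
			}
		}
	})

	// foreground loop, analogous to SyncFlow calling SyncRecords/SyncPg
	for i := 1; i <= 3 && groupCtx.Err() == nil; i++ {
		normalize <- i // hand the "synced" batch off for normalization
	}

	close(syncDone) // a single close winds down both background goroutines
	if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
		fmt.Println("background work failed:", err)
	}
}
```

Because errgroup.WithContext cancels groupCtx as soon as either goroutine returns an error, a failed ReplPing now ends the foreground loop and is returned from SyncFlow, replacing the old non-retryable "connection to source down" path.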
166 changes: 97 additions & 69 deletions flow/activities/flowable.go
@@ -5,9 +5,7 @@ import (
"errors"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"

"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5"
@@ -16,7 +14,7 @@ import (
"go.opentelemetry.io/otel/metric"
"go.temporal.io/sdk/activity"
"go.temporal.io/sdk/log"
"go.temporal.io/sdk/temporal"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"

"github.com/PeerDB-io/peer-flow/alerting"
@@ -43,19 +41,17 @@ type NormalizeBatchRequest struct {
BatchID int64
}

type CdcCacheEntry struct {
connector connectors.CDCPullConnectorCore
syncDone chan struct{}
normalize chan NormalizeBatchRequest
normalizeDone chan struct{}
type CdcState struct {
connector connectors.CDCPullConnectorCore
syncDone chan struct{}
normalize chan NormalizeBatchRequest
errGroup *errgroup.Group
}

type FlowableActivity struct {
CatalogPool *pgxpool.Pool
Alerter *alerting.Alerter
CdcCache map[string]CdcCacheEntry
OtelManager *otel_metrics.OtelManager
CdcCacheRw sync.RWMutex
}

func (a *FlowableActivity) CheckConnection(
@@ -253,91 +249,125 @@ func (a *FlowableActivity) CreateNormalizedTable(
}, nil
}

func (a *FlowableActivity) MaintainPull(
func (a *FlowableActivity) maintainPull(
ctx context.Context,
config *protos.FlowConnectionConfigs,
sessionID string,
) error {
) (CdcState, context.Context, error) {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
srcConn, err := connectors.GetByNameAs[connectors.CDCPullConnector](ctx, config.Env, a.CatalogPool, config.SourceName)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return err
return CdcState{}, nil, err
}
defer connectors.CloseConnector(ctx, srcConn)

if err := srcConn.SetupReplConn(ctx); err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return err
connectors.CloseConnector(ctx, srcConn)
return CdcState{}, nil, err
}

normalizeBufferSize, err := peerdbenv.PeerDBNormalizeChannelBufferSize(ctx, config.Env)
if err != nil {
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return err
connectors.CloseConnector(ctx, srcConn)
return CdcState{}, nil, err
}

// syncDone will be closed by UnmaintainPull,
// whereas normalizeDone will be closed by the normalize goroutine
// syncDone will be closed by SyncFlow,
// whereas normalizeDone will be closed by normalizing goroutine
// Wait on normalizeDone at end to not interrupt final normalize
syncDone := make(chan struct{})
normalize := make(chan NormalizeBatchRequest, normalizeBufferSize)
normalizeDone := make(chan struct{})
a.CdcCacheRw.Lock()
a.CdcCache[sessionID] = CdcCacheEntry{
connector: srcConn,
syncDone: syncDone,
normalize: normalize,
normalizeDone: normalizeDone,
}
a.CdcCacheRw.Unlock()

ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()

go a.normalizeLoop(ctx, config, syncDone, normalize, normalizeDone)

for {
select {
case <-ticker.C:
activity.RecordHeartbeat(ctx, "keep session alive")
if err := srcConn.ReplPing(ctx); err != nil {
a.CdcCacheRw.Lock()
delete(a.CdcCache, sessionID)
a.CdcCacheRw.Unlock()
a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
return temporal.NewNonRetryableApplicationError("connection to source down", "disconnect", err)

group, groupCtx := errgroup.WithContext(ctx)
group.Go(func() error {
// returning error signals sync to stop, normalize can recover connections without interrupting sync, so never return error
a.normalizeLoop(groupCtx, config, syncDone, normalize)
return nil
})
group.Go(func() error {
defer connectors.CloseConnector(groupCtx, srcConn)
if err := a.maintainReplConn(groupCtx, config.FlowJobName, srcConn, syncDone); err != nil {
a.Alerter.LogFlowError(groupCtx, config.FlowJobName, err)
return err
}
return nil
})

return CdcState{
connector: srcConn,
syncDone: syncDone,
normalize: normalize,
errGroup: group,
}, groupCtx, nil
}

func (a *FlowableActivity) SyncFlow(
ctx context.Context,
config *protos.FlowConnectionConfigs,
options *protos.SyncFlowOptions,
) error {
ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
logger := activity.GetLogger(ctx)

cdcState, groupCtx, err := a.maintainPull(ctx, config)
if err != nil {
logger.Error("MaintainPull failed", slog.Any("error", err))
return err
}

currentSyncFlowNum := int32(0)
totalRecordsSynced := int64(0)

for groupCtx.Err() == nil {
currentSyncFlowNum += 1
logger.Info("executing sync flow", slog.Int("count", int(currentSyncFlowNum)))

var numRecordsSynced int64
var syncErr error
if config.System == protos.TypeSystem_Q {
numRecordsSynced, syncErr = a.SyncRecords(groupCtx, config, options, cdcState)
} else {
numRecordsSynced, syncErr = a.SyncPg(groupCtx, config, options, cdcState)
}

if syncErr != nil {
if groupCtx.Err() != nil {
// need to return ctx.Err(), avoid returning syncErr that's wrapped context canceled
break
}
logger.Error("failed to sync records", slog.Any("error", syncErr))
close(cdcState.syncDone)
return errors.Join(syncErr, cdcState.errGroup.Wait())
} else {
totalRecordsSynced += numRecordsSynced
logger.Info("synced records",
slog.Int64("numRecordsSynced", numRecordsSynced), slog.Int64("totalRecordsSynced", totalRecordsSynced))

if options.NumberOfSyncs > 0 && currentSyncFlowNum >= options.NumberOfSyncs {
break
}
case <-syncDone:
return nil
case <-ctx.Done():
a.CdcCacheRw.Lock()
delete(a.CdcCache, sessionID)
a.CdcCacheRw.Unlock()
return nil
}
}
}

func (a *FlowableActivity) UnmaintainPull(ctx context.Context, sessionID string) error {
var normalizeDone chan struct{}
a.CdcCacheRw.Lock()
if entry, ok := a.CdcCache[sessionID]; ok {
close(entry.syncDone)
delete(a.CdcCache, sessionID)
normalizeDone = entry.normalizeDone
close(cdcState.syncDone)
waitErr := cdcState.errGroup.Wait()
if err := ctx.Err(); err != nil {
logger.Info("sync canceled", slog.Any("error", err))
return err
} else if waitErr != nil {
logger.Error("sync failed", slog.Any("error", waitErr))
return waitErr
}
a.CdcCacheRw.Unlock()
<-normalizeDone
return nil
}

func (a *FlowableActivity) SyncRecords(
ctx context.Context,
config *protos.FlowConnectionConfigs,
options *protos.SyncFlowOptions,
sessionID string,
) (model.SyncRecordsResult, error) {
cdcState CdcState,
) (int64, error) {
var adaptStream func(stream *model.CDCStream[model.RecordItems]) (*model.CDCStream[model.RecordItems], error)
if config.Script != "" {
var onErr context.CancelCauseFunc
@@ -368,22 +398,20 @@ func (a *FlowableActivity) SyncRecords(
return stream, nil
}
}
numRecords, err := syncCore(ctx, a, config, options, sessionID, adaptStream,
return syncCore(ctx, a, config, options, cdcState, adaptStream,
connectors.CDCPullConnector.PullRecords,
connectors.CDCSyncConnector.SyncRecords)
return model.SyncRecordsResult{NumRecordsSynced: numRecords}, err
}

func (a *FlowableActivity) SyncPg(
ctx context.Context,
config *protos.FlowConnectionConfigs,
options *protos.SyncFlowOptions,
sessionID string,
) (model.SyncRecordsResult, error) {
numRecords, err := syncCore(ctx, a, config, options, sessionID, nil,
cdcState CdcState,
) (int64, error) {
return syncCore(ctx, a, config, options, cdcState, nil,
connectors.CDCPullPgConnector.PullPg,
connectors.CDCSyncPgConnector.SyncPg)
return model.SyncRecordsResult{NumRecordsSynced: numRecords}, err
}

func (a *FlowableActivity) StartNormalize(
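In SyncFlow above, shutdown is coordinated through a single close(cdcState.syncDone): both goroutines started in maintainPull (the normalize loop and maintainReplConn) select on that channel, so one close stops them both, and errGroup.Wait() then surfaces any background error, joined with the sync error via errors.Join on the failure path or weighed against ctx.Err() on the normal one. A tiny illustration of that close-as-broadcast idiom with stand-in workers rather than the PeerDB goroutines:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	done := make(chan struct{})
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			<-done // a receive on a closed channel returns immediately, for every waiter
			fmt.Println("worker", id, "stopping")
		}(i)
	}
	close(done) // one close releases both workers
	wg.Wait()
}
```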
70 changes: 25 additions & 45 deletions flow/activities/flowable_core.go
@@ -5,7 +5,6 @@ import (
"context"
"fmt"
"log/slog"
"reflect"
"sync/atomic"
"time"

@@ -50,43 +49,6 @@ func heartbeatRoutine(
)
}

func waitForCdcCache[TPull connectors.CDCPullConnectorCore](
ctx context.Context, a *FlowableActivity, sessionID string,
) (TPull, chan NormalizeBatchRequest, error) {
var none TPull
logger := activity.GetLogger(ctx)
attempt := 0
waitInterval := time.Second
// try for 5 minutes, once per second
// after that, try indefinitely every minute
for {
a.CdcCacheRw.RLock()
entry, ok := a.CdcCache[sessionID]
a.CdcCacheRw.RUnlock()
if ok {
if conn, ok := entry.connector.(TPull); ok {
return conn, entry.normalize, nil
}
return none, nil, fmt.Errorf("expected %s, cache held %T", reflect.TypeFor[TPull]().Name(), entry.connector)
}
activity.RecordHeartbeat(ctx, fmt.Sprintf("wait %s for source connector", waitInterval))
attempt += 1
if attempt > 2 {
logger.Info("waiting on source connector setup",
slog.Int("attempt", attempt), slog.String("sessionID", sessionID))
}
if err := ctx.Err(); err != nil {
return none, nil, err
}
time.Sleep(waitInterval)
if attempt == 300 {
logger.Info("source connector not setup in time, transition to slow wait",
slog.String("sessionID", sessionID))
waitInterval = time.Minute
}
}
}

func (a *FlowableActivity) getTableNameSchemaMapping(ctx context.Context, flowName string) (map[string]*protos.TableSchema, error) {
rows, err := a.CatalogPool.Query(ctx, "select table_name, table_schema from table_schema_mapping where flow_name = $1", flowName)
if err != nil {
@@ -142,7 +104,7 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon
a *FlowableActivity,
config *protos.FlowConnectionConfigs,
options *protos.SyncFlowOptions,
sessionID string,
cdcState CdcState,
adaptStream func(*model.CDCStream[Items]) (*model.CDCStream[Items], error),
pull func(TPull, context.Context, *pgxpool.Pool, *otel_metrics.OtelManager, *model.PullRecordsRequest[Items]) error,
sync func(TSync, context.Context, *model.SyncRecordsRequest[Items]) (*model.SyncResponse, error),
@@ -160,10 +122,8 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon
tblNameMapping[v.SourceTableIdentifier] = model.NewNameAndExclude(v.DestinationTableIdentifier, v.Exclude)
}

srcConn, normChan, err := waitForCdcCache[TPull](ctx, a, sessionID)
if err != nil {
return 0, err
}
srcConn := cdcState.connector.(TPull)
normChan := cdcState.normalize
if err := srcConn.ConnectionActive(ctx); err != nil {
return 0, temporal.NewNonRetryableApplicationError("connection to source down", "disconnect", nil)
}
@@ -640,15 +600,35 @@ func replicateXminPartition[TRead any, TWrite any, TSync connectors.QRepSyncConn
return currentSnapshotXmin, nil
}

func (a *FlowableActivity) maintainReplConn(
ctx context.Context, flowName string, srcConn connectors.CDCPullConnector, syncDone <-chan struct{},
) error {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()

for {
select {
case <-ticker.C:
activity.RecordHeartbeat(ctx, "keep session alive")
if err := srcConn.ReplPing(ctx); err != nil {
a.Alerter.LogFlowError(ctx, flowName, err)
return fmt.Errorf("connection to source down: %w", err)
}
case <-syncDone:
return nil
case <-ctx.Done():
return nil
}
}
}

// Suitable to be run as goroutine
func (a *FlowableActivity) normalizeLoop(
ctx context.Context,
config *protos.FlowConnectionConfigs,
syncDone <-chan struct{},
normalize <-chan NormalizeBatchRequest,
normalizeDone chan struct{},
) {
defer close(normalizeDone)
logger := activity.GetLogger(ctx)

for {
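With the connector now carried in CdcState within a single activity invocation, the removed waitForCdcCache polling loop (once a second for five minutes, then once a minute) has no replacement: syncCore simply narrows the stored connector with a direct type assertion, cdcState.connector.(TPull). Below is a minimal sketch of that generic narrowing, using hypothetical connector types and the comma-ok form rather than the direct assertion the diff uses; it is not PeerDB's connector hierarchy.

```go
package main

import "fmt"

type pullCore interface{ Close() }

type pgPull struct{}

func (pgPull) Close()         {}
func (pgPull) PullPg() string { return "pulling via pgoutput" }

// cdcState mirrors the idea of the real CdcState: the connector is stored
// behind the broadest interface.
type cdcState struct{ connector pullCore }

// connectorAs narrows the stored connector to the pull interface T the caller needs.
func connectorAs[T pullCore](s cdcState) (T, error) {
	conn, ok := s.connector.(T)
	if !ok {
		var zero T
		return zero, fmt.Errorf("cache held %T, not the requested connector type", s.connector)
	}
	return conn, nil
}

// pgPuller is the narrower interface a Postgres-specific sync path would ask for.
type pgPuller interface {
	pullCore
	PullPg() string
}

func main() {
	s := cdcState{connector: pgPull{}}
	conn, err := connectorAs[pgPuller](s)
	if err != nil {
		panic(err)
	}
	fmt.Println(conn.PullPg()) // prints: pulling via pgoutput
}
```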
1 change: 0 additions & 1 deletion flow/cmd/worker.go
@@ -167,7 +166,6 @@ func WorkerSetup(opts *WorkerSetupOptions) (*workerSetupResponse, error) {
w.RegisterActivity(&activities.FlowableActivity{
CatalogPool: conn,
Alerter: alerting.NewAlerter(context.Background(), conn),
CdcCache: make(map[string]activities.CdcCacheEntry),
OtelManager: otelManager,
})

2 changes: 1 addition & 1 deletion flow/e2e/postgres/peer_flow_pg_test.go
@@ -1002,7 +1002,7 @@ func (s PeerFlowE2ETestSuitePG) Test_CustomSync() {
_, err = s.Conn().Exec(context.Background(), fmt.Sprintf(
"INSERT INTO %s(key, value) VALUES ('test_key', 'test_value')", srcTableName))
e2e.EnvNoError(s.t, env, err)
e2e.EnvWaitFor(s.t, env, 1*time.Minute, "paused workflow", func() bool {
e2e.EnvWaitFor(s.t, env, 3*time.Minute, "paused workflow", func() bool {
return e2e.EnvGetFlowStatus(s.t, env) == protos.FlowStatus_STATUS_PAUSED
})

4 changes: 0 additions & 4 deletions flow/model/model.go
@@ -169,10 +169,6 @@ type SyncResponse struct {
CurrentSyncBatchID int64
}

type SyncRecordsResult struct {
NumRecordsSynced int64
}

type NormalizeResponse struct {
StartBatchID int64
EndBatchID int64
4 changes: 0 additions & 4 deletions flow/model/signals.go
@@ -134,7 +134,3 @@ var FlowSignal = TypedSignal[CDCFlowSignal]{
var CDCDynamicPropertiesSignal = TypedSignal[*protos.CDCFlowConfigUpdate]{
Name: "cdc-dynamic-properties",
}

var SyncStopSignal = TypedSignal[struct{}]{
Name: "sync-stop",
}