Skip to content

Commit

Permalink
agent: extract logic related to starting/stopping the agent from `App…
Browse files Browse the repository at this point in the history
…lyDisruption`
  • Loading branch information
nadiamoe committed Sep 13, 2023
1 parent 5d8eabc commit c8e9d6b
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 32 deletions.
9 changes: 7 additions & 2 deletions cmd/agent/commands/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ func BuildGrpcCmd(env runtime.Environment, config *agent.Config) *cobra.Command
return fmt.Errorf("upstream host cannot be localhost when running in transparent mode")
}

agent, err := agent.BuildAndStart(env, config)
if err != nil {
return fmt.Errorf("initializing agent: %w", err)
}

defer agent.Stop()

listenAddress := net.JoinHostPort("", fmt.Sprint(port))
upstreamAddress := net.JoinHostPort(upstreamHost, fmt.Sprint(targetPort))

Expand Down Expand Up @@ -80,8 +87,6 @@ func BuildGrpcCmd(env runtime.Environment, config *agent.Config) *cobra.Command
return err
}

agent := agent.BuildAgent(env, config)

return agent.ApplyDisruption(cmd.Context(), disruptor, duration)
},
}
Expand Down
9 changes: 7 additions & 2 deletions cmd/agent/commands/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ func BuildHTTPCmd(env runtime.Environment, config *agent.Config) *cobra.Command
return fmt.Errorf("upstream host cannot be localhost when running in transparent mode")
}

agent, err := agent.BuildAndStart(env, config)
if err != nil {
return fmt.Errorf("initializing agent: %w", err)
}

defer agent.Stop()

listenAddress := net.JoinHostPort("", fmt.Sprint(port))
upstreamAddress := "http://" + net.JoinHostPort(upstreamHost, fmt.Sprint(targetPort))

Expand Down Expand Up @@ -79,8 +86,6 @@ func BuildHTTPCmd(env runtime.Environment, config *agent.Config) *cobra.Command
return err
}

agent := agent.BuildAgent(env, config)

return agent.ApplyDisruption(cmd.Context(), disruptor, duration)
},
}
Expand Down
63 changes: 39 additions & 24 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package agent
import (
"context"
"fmt"
"io"
"os"
"syscall"
"time"

Expand All @@ -19,48 +21,51 @@ type Config struct {

// Agent maintains the state required for executing an agent command
type Agent struct {
env runtime.Environment
config *Config
env runtime.Environment
sc <-chan os.Signal
profileCloser io.Closer
}

// BuildAgent builds an instance of an agent
func BuildAgent(env runtime.Environment, config *Config) *Agent {
return &Agent{
env: env,
config: config,
// BuildAndStart constructs an Agent bound to env and starts it using config.
// Starting registers the agent for SIGINT, SIGTERM and SIGHUP, acquires the process lock so that no other agent
// instance runs in the same environment, and starts the profiler.
// If starting fails, whatever was already initialized is stopped and a nil agent is returned with the error.
// Callers must call Stop on the returned agent once they are done with it.
func BuildAndStart(env runtime.Environment, config *Config) (*Agent, error) {
	instance := new(Agent)
	instance.env = env

	err := instance.start(config)
	if err != nil {
		// start may have failed midway: release anything it managed to set up.
		instance.Stop()

		return nil, err
	}

	return instance, nil
}

// ApplyDisruption applies a disruption to the target
func (r *Agent) ApplyDisruption(ctx context.Context, disruptor protocol.Disruptor, duration time.Duration) error {
sc := r.env.Signal().Notify(syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
defer func() {
r.env.Signal().Reset()
}()
func (a *Agent) start(config *Config) error {
a.sc = a.env.Signal().Notify(syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)

acquired, err := r.env.Lock().Acquire()
acquired, err := a.env.Lock().Acquire()
if err != nil {
return fmt.Errorf("could not acquire process lock: %w", err)
}

if !acquired {
return fmt.Errorf("another instance of the agent is already running")
}

defer func() {
_ = r.env.Lock().Release()
}()

// start profiler
profiler, err := r.env.Profiler().Start(ctx, *r.config.Profiler)
a.profileCloser, err = a.env.Profiler().Start(*config.Profiler)
if err != nil {
return fmt.Errorf("could not create profiler %w", err)
}

// ensure the profiler is closed even if there's an error executing the command
defer func() {
_ = profiler.Close()
}()
return nil
}

// ApplyDisruption applies a disruption to the target
func (a *Agent) ApplyDisruption(ctx context.Context, disruptor protocol.Disruptor, duration time.Duration) error {
// set context for command
ctx, cancel := context.WithCancel(ctx)

Expand All @@ -83,7 +88,17 @@ func (r *Agent) ApplyDisruption(ctx context.Context, disruptor protocol.Disrupto
return ctx.Err()
case err := <-cc:
return err
case s := <-sc:
case s := <-a.sc:
return fmt.Errorf("received signal %q", s)
}
}

// Stop tears down a started agent. It resets the signal handling configured through the environment, releases the
// process lock and, when a profiler was started, closes it.
// Errors returned while releasing the lock or closing the profiler are discarded.
// Stop also runs when start fails midway (see BuildAndStart), which is why the profiler closer may still be unset.
func (a *Agent) Stop() {
	env := a.env

	env.Signal().Reset()
	_ = env.Lock().Release()

	if closer := a.profileCloser; closer != nil {
		_ = closer.Close()
	}
}
18 changes: 14 additions & 4 deletions pkg/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ func Test_CancelContext(t *testing.T) {
t.Parallel()
env := runtime.NewFakeRuntime(tc.args, tc.vars)

agent := BuildAgent(env, tc.config)
agent, err := BuildAndStart(env, tc.config)
if err != nil {
t.Fatalf("starting agent: %v", err)
}

defer agent.Stop()

ctx, cancel := context.WithCancel(context.Background())
go func() {
Expand All @@ -72,7 +77,7 @@ func Test_CancelContext(t *testing.T) {
}()

disruptor := &FakeProtocolDisruptor{}
err := agent.ApplyDisruption(ctx, disruptor, tc.delay)
err = agent.ApplyDisruption(ctx, disruptor, tc.delay)
if !errors.Is(err, tc.expected) {
t.Errorf("expected %v got %v", tc.err, err)
}
Expand Down Expand Up @@ -126,7 +131,12 @@ func Test_Signals(t *testing.T) {
t.Parallel()
env := runtime.NewFakeRuntime(tc.args, tc.vars)

agent := BuildAgent(env, tc.config)
agent, err := BuildAndStart(env, tc.config)
if err != nil {
t.Fatalf("starting agent: %v", err)
}

defer agent.Stop()

go func() {
time.Sleep(1 * time.Second)
Expand All @@ -136,7 +146,7 @@ func Test_Signals(t *testing.T) {
}()

disruptor := &FakeProtocolDisruptor{}
err := agent.ApplyDisruption(context.TODO(), disruptor, tc.delay)
err = agent.ApplyDisruption(context.TODO(), disruptor, tc.delay)
if tc.expectErr && err == nil {
t.Errorf("should had failed")
return
Expand Down

0 comments on commit c8e9d6b

Please sign in to comment.