Skip to content

Commit

Permalink
close progress with stdin
Browse files Browse the repository at this point in the history
  • Loading branch information
mostlikelee committed Dec 2, 2024
1 parent 806580b commit 0a3001d
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 112 deletions.
2 changes: 1 addition & 1 deletion orbit/pkg/dialog/dialog.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Dialog interface {
ShowInfo(ctx context.Context, opts InfoOptions) error
// Progress displays a dialog that shows progress. It waits until the
// context is cancelled.
ShowProgress(ctx context.Context, opts ProgressOptions) error
ShowProgress(opts ProgressOptions) (cancelFunc func() error, err error)
}

// EntryOptions represents options for a dialog that accepts end user input.
Expand Down
10 changes: 10 additions & 0 deletions orbit/pkg/execuser/execuser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// SYSTEM service on Windows) as the current login user.
package execuser

import "io"

type eopts struct {
env [][2]string
args [][2]string
Expand Down Expand Up @@ -49,3 +51,11 @@ func RunWithOutput(path string, opts ...Option) (output []byte, exitCode int, er
}
return runWithOutput(path, o)
}

func RunWithStdin(path string, opts ...Option) (io.WriteCloser, error) {
var o eopts
for _, fn := range opts {
fn(&o)
}
return runWithStdin(path, o)
}
4 changes: 4 additions & 0 deletions orbit/pkg/execuser/execuser_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,7 @@ func run(path string, opts eopts) (lastLogs string, err error) {
func runWithOutput(path string, opts eopts) (output []byte, exitCode int, err error) {
return nil, 0, errors.New("not implemented")
}

func runWithStdin(path string, opts eopts) (io.WriteCloser, error) {
return nil, errors.New("not implemented")
}
29 changes: 29 additions & 0 deletions orbit/pkg/execuser/execuser_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,35 @@ func runWithOutput(path string, opts eopts) (output []byte, exitCode int, err er
return output, exitCode, nil
}

func runWithStdin(path string, opts eopts) (io.WriteCloser, error) {
args, err := getUserAndDisplayArgs(path, opts)
if err != nil {
return nil, fmt.Errorf("get args: %w", err)
}

args = append(args, path)

if len(opts.args) > 0 {
for _, arg := range opts.args {
args = append(args, arg[0], arg[1])
}
}

cmd := exec.Command("sudo", args...)
log.Printf("cmd=%s", cmd.String())

stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("stdin pipe: %w", err)
}

if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("open path %q: %w", path, err)
}

return stdin, nil
}

func getUserAndDisplayArgs(path string, opts eopts) ([]string, error) {
user, err := getLoginUID()
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions orbit/pkg/execuser/execuser_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package execuser
import (
"errors"
"fmt"
"io"
"os"
"unsafe"

Expand Down Expand Up @@ -121,6 +122,10 @@ func runWithOutput(path string, opts eopts) (output []byte, exitCode int, err er
return nil, 0, errors.New("not implemented")
}

func runWithStdin(path string, opts eopts) (io.WriteCloser, error) {
return nil, errors.New("not implemented")
}

// getCurrentUserSessionId will attempt to resolve
// the session ID of the user currently active on
// the system.
Expand Down
18 changes: 7 additions & 11 deletions orbit/pkg/luks/luks_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (lr *LuksRunner) Run(oc *fleet.OrbitConfig) error {
if err := removeKeySlot(ctx, devicePath, *keyslot); err != nil {
log.Error().Err(err).Msgf("failed to remove key slot %d", *keyslot)
}
return fmt.Errorf("Failed to get salt for key slot: %w", err)
response.Err = fmt.Sprintf("Failed to get salt for key slot: %s", err)
}
response.Salt = salt
}
Expand Down Expand Up @@ -118,13 +118,14 @@ func (lr *LuksRunner) getEscrowKey(ctx context.Context, devicePath string) ([]by
return nil, nil, nil
}

err = lr.notifier.ShowProgress(ctx, dialog.ProgressOptions{
cancelProgress, err := lr.notifier.ShowProgress(dialog.ProgressOptions{
Title: infoTitle,
Text: "Validating passphrase...",
})
if err != nil {
log.Error().Err(err).Msg("failed to show progress dialog")
}
defer cancelProgress()

Check failure on line 128 in orbit/pkg/luks/luks_linux.go

View workflow job for this annotation

GitHub Actions / lint (ubuntu-latest)

Error return value is not checked (errcheck)

// Validate the passphrase
for {
Expand All @@ -147,22 +148,17 @@ func (lr *LuksRunner) getEscrowKey(ctx context.Context, devicePath string) ([]by
return nil, nil, nil
}

err = lr.notifier.ShowProgress(ctx, dialog.ProgressOptions{
Title: infoTitle,
Text: "Validating passphrase...",
})
if err != nil {
log.Error().Err(err).Msg("failed to show progress dialog after retry")
}
}

err = lr.notifier.ShowProgress(ctx, dialog.ProgressOptions{
cancelProgress()

Check failure on line 153 in orbit/pkg/luks/luks_linux.go

View workflow job for this annotation

GitHub Actions / lint (ubuntu-latest)

Error return value is not checked (errcheck)
cancelProgress, err = lr.notifier.ShowProgress(dialog.ProgressOptions{
Title: infoTitle,
Text: "Key escrow in progress...",
Text: "Escrowing key...",
})
if err != nil {
log.Error().Err(err).Msg("failed to show progress dialog")
}
defer cancelProgress()

Check failure on line 161 in orbit/pkg/luks/luks_linux.go

View workflow job for this annotation

GitHub Actions / lint (ubuntu-latest)

Error return value is not checked (errcheck)

escrowPassphrase, err := generateRandomPassphrase()
if err != nil {
Expand Down
52 changes: 17 additions & 35 deletions orbit/pkg/zenity/zenity.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import (

"github.com/fleetdm/fleet/v4/orbit/pkg/dialog"
"github.com/fleetdm/fleet/v4/orbit/pkg/execuser"
"github.com/fleetdm/fleet/v4/orbit/pkg/platform"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/rs/zerolog/log"
)

const zenityProcessName = "zenity"
Expand All @@ -18,26 +16,21 @@ type Zenity struct {
// cmdWithOutput can be set in tests to mock execution of the dialog.
cmdWithOutput func(ctx context.Context, args ...string) ([]byte, int, error)
// cmdWithWait can be set in tests to mock execution of the dialog.
cmdWithWait func(ctx context.Context, args ...string) error
// killZenityFunc can be set in tests to mock killing the zenity process.
killZenityFunc func()
cmdWithCancel func(args ...string) (func() error, error)
}

// New creates a new Zenity dialog instance for zenity v4 on Linux.
// Zenity implements the Dialog interface.
func New() *Zenity {
return &Zenity{
cmdWithOutput: execCmdWithOutput,
cmdWithWait: execCmdWithWait,
killZenityFunc: killZenityProcesses,
cmdWithOutput: execCmdWithOutput,
cmdWithCancel: execCmdWithCancel,
}
}

// ShowEntry displays an dialog that accepts end user input. It returns the entered
// text or errors ErrCanceled, ErrTimeout, or ErrUnknown.
func (z *Zenity) ShowEntry(ctx context.Context, opts dialog.EntryOptions) ([]byte, error) {
z.killZenityFunc()

args := []string{"--entry"}
if opts.Title != "" {
args = append(args, fmt.Sprintf("--title=%s", opts.Title))
Expand Down Expand Up @@ -69,8 +62,6 @@ func (z *Zenity) ShowEntry(ctx context.Context, opts dialog.EntryOptions) ([]byt

// ShowInfo displays an information dialog. It returns errors ErrTimeout or ErrUnknown.
func (z *Zenity) ShowInfo(ctx context.Context, opts dialog.InfoOptions) error {
z.killZenityFunc()

args := []string{"--info"}
if opts.Title != "" {
args = append(args, fmt.Sprintf("--title=%s", opts.Title))
Expand All @@ -95,18 +86,9 @@ func (z *Zenity) ShowInfo(ctx context.Context, opts dialog.InfoOptions) error {
return nil
}

// ShowProgress starts a Zenity progress dialog with the given options.
// This function is designed to block until the provided context is canceled.
// It is intended to be used within a separate goroutine to avoid blocking
// the main execution flow.
//
// If the context is already canceled, the function will return immediately.
//
// Use this function for cases where a progress dialog is needed to run
// alongside other operations, with explicit cancellation or termination.
func (z *Zenity) ShowProgress(ctx context.Context, opts dialog.ProgressOptions) error {
z.killZenityFunc()

// ShowProgress starts a Zenity pulsating progress dialog with the given options.
// It returns a cancel function that can be used to cancel the dialog.
func (z *Zenity) ShowProgress(opts dialog.ProgressOptions) (func() error, error) {
args := []string{"--progress"}
if opts.Title != "" {
args = append(args, fmt.Sprintf("--title=%s", opts.Title))
Expand All @@ -121,12 +103,15 @@ func (z *Zenity) ShowProgress(ctx context.Context, opts dialog.ProgressOptions)
// --no-cancel disables the cancel button
args = append(args, "--no-cancel")

err := z.cmdWithWait(ctx, args...)
// --auto-close automatically closes the dialog when stdin is closed
args = append(args, "--auto-close")

cancel, err := z.cmdWithCancel(args...)
if err != nil {
return ctxerr.Wrap(ctx, dialog.ErrUnknown, err.Error())
return nil, fmt.Errorf("failed to start progress dialog: %w", err)
}

return nil
return cancel, nil
}

func execCmdWithOutput(ctx context.Context, args ...string) ([]byte, int, error) {
Expand All @@ -143,19 +128,16 @@ func execCmdWithOutput(ctx context.Context, args ...string) ([]byte, int, error)
return output, exitCode, err
}

func execCmdWithWait(ctx context.Context, args ...string) error {
func execCmdWithCancel(args ...string) (func() error, error) {
var opts []execuser.Option
for _, arg := range args {
opts = append(opts, execuser.WithArg(arg, "")) // Using empty value for positional args
}

_, err := execuser.Run(zenityProcessName, opts...)
return err
}

func killZenityProcesses() {
_, err := platform.KillAllProcessByName(zenityProcessName)
stdin, err := execuser.RunWithStdin(zenityProcessName, opts...)
if err != nil {
log.Warn().Err(err).Msg("failed to kill zenity process")
return nil, err
}

return stdin.Close, err
}
62 changes: 10 additions & 52 deletions orbit/pkg/zenity/zenity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ type mockExecCmd struct {
output []byte
exitCode int
capturedArgs []string
waitDuration time.Duration
}

// MockCommandContext simulates exec.CommandContext and captures arguments
Expand All @@ -29,17 +28,10 @@ func (m *mockExecCmd) runWithOutput(ctx context.Context, args ...string) ([]byte
return m.output, m.exitCode, nil
}

func (m *mockExecCmd) runWithWait(ctx context.Context, args ...string) error {
func (m *mockExecCmd) runWithStdin(args ...string) (func() error, error) {
m.capturedArgs = append(m.capturedArgs, args...)

select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(m.waitDuration):

}

return nil
return nil, nil
}

func TestShowEntryArgs(t *testing.T) {
Expand Down Expand Up @@ -76,8 +68,7 @@ func TestShowEntryArgs(t *testing.T) {
output: []byte("some output"),
}
z := &Zenity{
cmdWithOutput: mock.runWithOutput,
killZenityFunc: func() {},
cmdWithOutput: mock.runWithOutput,
}
output, err := z.ShowEntry(ctx, tt.opts)
assert.NoError(t, err)
Expand Down Expand Up @@ -118,8 +109,7 @@ func TestShowEntryError(t *testing.T) {
exitCode: tt.exitCode,
}
z := &Zenity{
cmdWithOutput: mock.runWithOutput,
killZenityFunc: func() {},
cmdWithOutput: mock.runWithOutput,
}
output, err := z.ShowEntry(ctx, dialog.EntryOptions{})
require.ErrorIs(t, err, tt.expectedErr)
Expand All @@ -135,8 +125,7 @@ func TestShowEntrySuccess(t *testing.T) {
output: []byte("some output"),
}
z := &Zenity{
cmdWithOutput: mock.runWithOutput,
killZenityFunc: func() {},
cmdWithOutput: mock.runWithOutput,
}
output, err := z.ShowEntry(ctx, dialog.EntryOptions{})
assert.NoError(t, err)
Expand Down Expand Up @@ -171,8 +160,7 @@ func TestShowInfoArgs(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
mock := &mockExecCmd{}
z := &Zenity{
cmdWithOutput: mock.runWithOutput,
killZenityFunc: func() {},
cmdWithOutput: mock.runWithOutput,
}
err := z.ShowInfo(ctx, tt.opts)
assert.NoError(t, err)
Expand Down Expand Up @@ -207,8 +195,7 @@ func TestShowInfoError(t *testing.T) {
exitCode: tt.exitCode,
}
z := &Zenity{
cmdWithOutput: mock.runWithOutput,
killZenityFunc: func() {},
cmdWithOutput: mock.runWithOutput,
}
err := z.ShowInfo(ctx, dialog.InfoOptions{})
require.ErrorIs(t, err, tt.expectedErr)
Expand All @@ -217,8 +204,6 @@ func TestShowInfoError(t *testing.T) {
}

func TestProgressArgs(t *testing.T) {
ctx := context.Background()

testCases := []struct {
name string
opts dialog.ProgressOptions
Expand All @@ -230,46 +215,19 @@ func TestProgressArgs(t *testing.T) {
Title: "A Title",
Text: "Some text",
},
expectedArgs: []string{"--progress", "--title=A Title", "--text=Some text", "--pulsate", "--no-cancel"},
expectedArgs: []string{"--progress", "--title=A Title", "--text=Some text", "--pulsate", "--no-cancel", "--auto-close"},
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
mock := &mockExecCmd{}
z := &Zenity{
cmdWithWait: mock.runWithWait,
killZenityFunc: func() {},
cmdWithCancel: mock.runWithStdin,
}
err := z.ShowProgress(ctx, tt.opts)
_, err := z.ShowProgress(tt.opts)
assert.NoError(t, err)
assert.Equal(t, tt.expectedArgs, mock.capturedArgs)
})
}
}

func TestProgressKillOnCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())

mock := &mockExecCmd{
waitDuration: 5 * time.Second,
}
z := &Zenity{
cmdWithWait: mock.runWithWait,
killZenityFunc: func() {},
}

done := make(chan struct{})
start := time.Now()

go func() {
_ = z.ShowProgress(ctx, dialog.ProgressOptions{})
close(done)
}()

time.Sleep(100 * time.Millisecond)
cancel()
<-done

assert.True(t, time.Since(start) < 5*time.Second)
}
Loading

0 comments on commit 0a3001d

Please sign in to comment.