diff --git a/fs/serve.go b/fs/serve.go index 79c77557..d16e4993 100644 --- a/fs/serve.go +++ b/fs/serve.go @@ -481,10 +481,11 @@ type Server struct { wg sync.WaitGroup } -// Serve serves the FUSE connection by making calls to the methods -// of fs and the Nodes and Handles it makes available. It returns only -// when the connection has been closed or an unexpected error occurs. -func (s *Server) Serve(fs FS) error { +// ServeContext serves the FUSE connection by making calls to the methods +// of fs and the Nodes and Handles it makes available. It returns when the +// context is closed, when the connection has been closed, or an unexpected +// error occurs. +func (s *Server) ServeContext(ctx context.Context, fs FS) error { defer s.wg.Wait() // Wait for worker goroutines to complete before return s.fs = fs @@ -507,22 +508,48 @@ func (s *Server) Serve(fs FS) error { }) s.handle = append(s.handle, nil) - for { - req, err := s.conn.ReadRequest() - if err != nil { - if err == io.EOF { - break + retErr := make(chan error) + + go func() { + for { + req, err := s.conn.ReadRequest() + if err != nil { + if err == io.EOF { + break + } + retErr <- err + return } - return err + + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.serve(ctx, req) + }() } + retErr <- nil + }() - s.wg.Add(1) - go func() { - defer s.wg.Done() - s.serve(req) - }() + select { + case err := <-retErr: + return err + case <-ctx.Done(): + return nil } - return nil +} + +// Serve serves the FUSE connection by making calls to the methods +// of fs and the Nodes and Handles it makes available. It returns only +// when the connection has been closed or an unexpected error occurs. +func (s *Server) Serve(fs FS) error { + return s.ServeContext(context.Background(), fs) +} + +// ServeContext serves a FUSE connection with the default settings. See +// Server.ServeContext. +func ServeContext(ctx context.Context, c *fuse.Conn, fs FS) error { + server := New(c, nil) + return server.ServeContext(ctx, fs) } // Serve serves a FUSE connection with the default settings. See @@ -915,8 +942,8 @@ func (m *logDuplicateRequestID) String() string { return fmt.Sprintf("Duplicate request: new %v, old %v", m.New, m.Old) } -func (c *Server) serve(r fuse.Request) { - ctx, cancel := context.WithCancel(context.Background()) +func (c *Server) serve(originContext context.Context, r fuse.Request) { + ctx, cancel := context.WithCancel(originContext) defer cancel() parentCtx := ctx if c.context != nil {