diff --git a/ci/release/changelogs/next.md b/ci/release/changelogs/next.md index c5c1be154b..59c8f3b634 100644 --- a/ci/release/changelogs/next.md +++ b/ci/release/changelogs/next.md @@ -12,6 +12,7 @@ - Encoding API switches to standard zlib encoding so that decoding doesn't depend on source. [#1709](https://github.com/terrastruct/d2/pull/1709) - `currentcolor` is accepted as a color option to inherit parent colors. (ty @hboomsma) [#1700](https://github.com/terrastruct/d2/pull/1700) - grid containers can now be sized with `width`/`height` even when using a layout plugin without that feature. [#1731](https://github.com/terrastruct/d2/pull/1731) +- Watch mode watches for changes in both the input file and imported files [#1720](https://github.com/terrastruct/d2/pull/1720) #### Bugfixes ⛑️ diff --git a/d2cli/main.go b/d2cli/main.go index 4aae3c4943..793c048396 100644 --- a/d2cli/main.go +++ b/d2cli/main.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "io/fs" "os" "os/exec" "os/user" @@ -332,7 +333,7 @@ func Run(ctx context.Context, ms *xmain.State) (err error) { ctx, cancel := timelib.WithTimeout(ctx, time.Minute*2) defer cancel() - _, written, err := compile(ctx, ms, plugins, layoutFlag, renderOpts, fontFamily, *animateIntervalFlag, inputPath, outputPath, "", *bundleFlag, *forceAppendixFlag, pw.Page) + _, written, err := compile(ctx, ms, plugins, nil, layoutFlag, renderOpts, fontFamily, *animateIntervalFlag, inputPath, outputPath, "", *bundleFlag, *forceAppendixFlag, pw.Page) if err != nil { if written { return fmt.Errorf("failed to fully compile (partial render written) %s: %w", ms.HumanPath(inputPath), err) @@ -367,7 +368,7 @@ func LayoutResolver(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plu } } -func compile(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plugin, layout *string, renderOpts d2svg.RenderOpts, fontFamily *d2fonts.FontFamily, animateInterval int64, inputPath, outputPath, boardPath string, bundle, forceAppendix bool, page playwright.Page) (_ []byte, written bool, _ error) { +func compile(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plugin, fs fs.FS, layout *string, renderOpts d2svg.RenderOpts, fontFamily *d2fonts.FontFamily, animateInterval int64, inputPath, outputPath, boardPath string, bundle, forceAppendix bool, page playwright.Page) (_ []byte, written bool, _ error) { start := time.Now() input, err := ms.ReadPath(inputPath) if err != nil { @@ -385,6 +386,7 @@ func compile(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plugin, la InputPath: inputPath, LayoutResolver: LayoutResolver(ctx, ms, plugins), Layout: layout, + FS: fs, } cancel := background.Repeat(func() { diff --git a/d2cli/watch.go b/d2cli/watch.go index 866a480b3f..b721e7726f 100644 --- a/d2cli/watch.go +++ b/d2cli/watch.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "runtime" + "sort" "strings" "sync" "time" @@ -73,6 +74,7 @@ type watcher struct { l net.Listener staticFileServer http.Handler + boardpathMu sync.Mutex wsclientsMu sync.Mutex closing bool wsclientsWG sync.WaitGroup @@ -218,10 +220,13 @@ func (w *watcher) goFunc(fn func(context.Context) error) { * TODO: Abstract out file system and fsnotify to test this with 100% coverage. See comment in main_test.go */ func (w *watcher) watchLoop(ctx context.Context) error { - lastModified, err := w.ensureAddWatch(ctx) + lastModified := make(map[string]time.Time) + + mt, err := w.ensureAddWatch(ctx, w.inputPath) if err != nil { return err } + lastModified[w.inputPath] = mt w.ms.Log.Info.Printf("compiling %v...", w.ms.HumanPath(w.inputPath)) w.requestCompile() @@ -230,6 +235,8 @@ func (w *watcher) watchLoop(ctx context.Context) error { pollTicker := time.NewTicker(time.Second * 10) defer pollTicker.Stop() + changed := make(map[string]struct{}) + for { select { case <-pollTicker.C: @@ -237,13 +244,18 @@ func (w *watcher) watchLoop(ctx context.Context) error { // getting any more events. // File notification APIs are notoriously unreliable. I've personally experienced // many quirks and so feel this check is justified even if excessive. - mt, err := w.ensureAddWatch(ctx) - if err != nil { - return err + missedChanges := false + for _, watched := range w.fw.WatchList() { + mt, err := w.ensureAddWatch(ctx, watched) + if err != nil { + return err + } + if mt2, ok := lastModified[watched]; !ok || !mt.Equal(mt2) { + missedChanges = true + lastModified[watched] = mt + } } - if !mt.Equal(lastModified) { - // We missed changes. - lastModified = mt + if missedChanges { w.requestCompile() } case ev, ok := <-w.fw.Events: @@ -251,19 +263,20 @@ func (w *watcher) watchLoop(ctx context.Context) error { return errors.New("fsnotify watcher closed") } w.ms.Log.Debug.Printf("received file system event %v", ev) - mt, err := w.ensureAddWatch(ctx) + mt, err := w.ensureAddWatch(ctx, ev.Name) if err != nil { return err } if ev.Op == fsnotify.Chmod { - if mt.Equal(lastModified) { + if mt.Equal(lastModified[ev.Name]) { // Benign Chmod. // See https://github.com/fsnotify/fsnotify/issues/15 continue } // We missed changes. - lastModified = mt + lastModified[ev.Name] = mt } + changed[ev.Name] = struct{}{} // The purpose of eatBurstTimer is to wait at least 16 milliseconds after a sequence of // events to ensure that whomever is editing the file is now done. // @@ -276,8 +289,18 @@ func (w *watcher) watchLoop(ctx context.Context) error { // misleading error. eatBurstTimer.Reset(time.Millisecond * 16) case <-eatBurstTimer.C: - w.ms.Log.Info.Printf("detected change in %v: recompiling...", w.ms.HumanPath(w.inputPath)) + var changedList []string + for k := range changed { + changedList = append(changedList, k) + } + sort.Strings(changedList) + changedStr := w.ms.HumanPath(changedList[0]) + for i := 1; i < len(changed); i++ { + changedStr += fmt.Sprintf(", %s", w.ms.HumanPath(changedList[i])) + } + w.ms.Log.Info.Printf("detected change in %s: recompiling...", changedStr) w.requestCompile() + changed = make(map[string]struct{}) case err, ok := <-w.fw.Errors: if !ok { return errors.New("fsnotify watcher closed") @@ -296,17 +319,17 @@ func (w *watcher) requestCompile() { } } -func (w *watcher) ensureAddWatch(ctx context.Context) (time.Time, error) { +func (w *watcher) ensureAddWatch(ctx context.Context, path string) (time.Time, error) { interval := time.Millisecond * 16 tc := time.NewTimer(0) <-tc.C for { - mt, err := w.addWatch(ctx) + mt, err := w.addWatch(ctx, path) if err == nil { return mt, nil } if interval >= time.Second { - w.ms.Log.Error.Printf("failed to watch inputPath %q: %v (retrying in %v)", w.ms.HumanPath(w.inputPath), err, interval) + w.ms.Log.Error.Printf("failed to watch %q: %v (retrying in %v)", w.ms.HumanPath(path), err, interval) } tc.Reset(interval) @@ -324,19 +347,56 @@ func (w *watcher) ensureAddWatch(ctx context.Context) (time.Time, error) { } } -func (w *watcher) addWatch(ctx context.Context) (time.Time, error) { - err := w.fw.Add(w.inputPath) +func (w *watcher) addWatch(ctx context.Context, path string) (time.Time, error) { + err := w.fw.Add(path) if err != nil { return time.Time{}, err } var d os.FileInfo - d, err = os.Stat(w.inputPath) + d, err = os.Stat(path) if err != nil { return time.Time{}, err } return d.ModTime(), nil } +func (w *watcher) replaceWatchList(ctx context.Context, paths []string) error { + // First remove the files no longer being watched + for _, watched := range w.fw.WatchList() { + if watched == w.inputPath { + continue + } + found := false + for _, p := range paths { + if watched == p { + found = true + break + } + } + if !found { + // Don't mind errors here + w.fw.Remove(watched) + } + } + // Then add the files newly being watched + for _, p := range paths { + found := false + for _, watched := range w.fw.WatchList() { + if watched == p { + found = true + break + } + } + if !found { + _, err := w.ensureAddWatch(ctx, p) + if err != nil { + return err + } + } + } + return nil +} + func (w *watcher) compileLoop(ctx context.Context) error { firstCompile := true for { @@ -364,7 +424,10 @@ func (w *watcher) compileLoop(ctx context.Context) error { w.pw = newPW } - svg, _, err := compile(ctx, w.ms, w.plugins, w.layout, w.renderOpts, w.fontFamily, w.animateInterval, w.inputPath, w.outputPath, w.boardPath, w.bundle, w.forceAppendix, w.pw.Page) + fs := trackedFS{} + w.boardpathMu.Lock() + svg, _, err := compile(ctx, w.ms, w.plugins, &fs, w.layout, w.renderOpts, w.fontFamily, w.animateInterval, w.inputPath, w.outputPath, w.boardPath, w.bundle, w.forceAppendix, w.pw.Page) + w.boardpathMu.Unlock() errs := "" if err != nil { if len(svg) > 0 { @@ -375,6 +438,11 @@ func (w *watcher) compileLoop(ctx context.Context) error { errs = err.Error() w.ms.Log.Error.Print(errs) } + err = w.replaceWatchList(ctx, fs.opened) + if err != nil { + return err + } + w.broadcast(&compileResult{ SVG: string(svg), Scale: w.renderOpts.Scale, @@ -442,13 +510,19 @@ func (w *watcher) handleRoot(hw http.ResponseWriter, r *http.Request) { `, filepath.Base(w.outputPath), w.devMode) + w.boardpathMu.Lock() // if path is "/x.svg", we just want "x" boardPath := strings.TrimPrefix(r.URL.Path, "/") if idx := strings.LastIndexByte(boardPath, '.'); idx != -1 { boardPath = boardPath[:idx] } + recompile := false if boardPath != w.boardPath { w.boardPath = boardPath + recompile = true + } + w.boardpathMu.Unlock() + if recompile { w.requestCompile() } } @@ -574,3 +648,16 @@ func wsHeartbeat(ctx context.Context, c *websocket.Conn) { } } } + +// trackedFS is OS's FS with the addition that it tracks which files are opened successfully +type trackedFS struct { + opened []string +} + +func (tfs *trackedFS) Open(name string) (fs.File, error) { + f, err := os.Open(name) + if err == nil { + tfs.opened = append(tfs.opened, name) + } + return f, err +} diff --git a/e2etests-cli/main_test.go b/e2etests-cli/main_test.go index 272457eb3b..f851c9833a 100644 --- a/e2etests-cli/main_test.go +++ b/e2etests-cli/main_test.go @@ -575,11 +575,8 @@ layers: { // Wait for watch server to spin up and listen urlRE := regexp.MustCompile(`127.0.0.1:([0-9]+)`) - watchURL := waitLogs(ctx, stderr, urlRE) - - if watchURL == "" { - t.Error(errors.New(stderr.String())) - } + watchURL, err := waitLogs(ctx, stderr, urlRE) + assert.Success(t, err) stderr.Reset() // Start a client @@ -599,8 +596,8 @@ layers: { assert.Success(t, err) successRE := regexp.MustCompile(`broadcasting update to 1 client`) - line := waitLogs(ctx, stderr, successRE) - assert.NotEqual(t, "", line) + _, err = waitLogs(ctx, stderr, successRE) + assert.Success(t, err) }, }, { @@ -631,11 +628,9 @@ layers: { // Wait for watch server to spin up and listen urlRE := regexp.MustCompile(`127.0.0.1:([0-9]+)`) - watchURL := waitLogs(ctx, stderr, urlRE) + watchURL, err := waitLogs(ctx, stderr, urlRE) + assert.Success(t, err) - if watchURL == "" { - t.Error(errors.New(stderr.String())) - } stderr.Reset() // Start a client @@ -655,8 +650,8 @@ layers: { assert.Success(t, err) successRE := regexp.MustCompile(`broadcasting update to 1 client`) - line := waitLogs(ctx, stderr, successRE) - assert.NotEqual(t, "", line) + _, err = waitLogs(ctx, stderr, successRE) + assert.Success(t, err) }, }, { @@ -685,11 +680,8 @@ layers: { // Wait for watch server to spin up and listen urlRE := regexp.MustCompile(`127.0.0.1:([0-9]+)`) - watchURL := waitLogs(ctx, stderr, urlRE) - - if watchURL == "" { - t.Error(errors.New(stderr.String())) - } + watchURL, err := waitLogs(ctx, stderr, urlRE) + assert.Success(t, err) stderr.Reset() // Start a client @@ -709,8 +701,82 @@ layers: { assert.Success(t, err) successRE := regexp.MustCompile(`broadcasting update to 1 client`) - line := waitLogs(ctx, stderr, successRE) - assert.NotEqual(t, "", line) + _, err = waitLogs(ctx, stderr, successRE) + assert.Success(t, err) + }, + }, + { + name: "watch-imported-file", + run: func(t *testing.T, ctx context.Context, dir string, env *xos.Env) { + writeFile(t, dir, "a.d2", ` +...@b +`) + writeFile(t, dir, "b.d2", ` +x +`) + stderr := &bytes.Buffer{} + tms := testMain(dir, env, "--watch", "--browser=0", "a.d2") + tms.Stderr = stderr + + tms.Start(t, ctx) + defer func() { + err := tms.Signal(ctx, os.Interrupt) + assert.Success(t, err) + }() + + // Wait for first compilation to finish + doneRE := regexp.MustCompile(`successfully compiled a.d2`) + _, err := waitLogs(ctx, stderr, doneRE) + assert.Success(t, err) + stderr.Reset() + + // Test that writing an imported file will cause recompilation + writeFile(t, dir, "b.d2", ` +x -> y +`) + bRE := regexp.MustCompile(`detected change in b.d2`) + _, err = waitLogs(ctx, stderr, bRE) + assert.Success(t, err) + stderr.Reset() + + // Test burst of both files changing + writeFile(t, dir, "a.d2", ` +...@b +hey +`) + writeFile(t, dir, "b.d2", ` +x +hi +`) + bothRE := regexp.MustCompile(`detected change in a.d2, b.d2`) + _, err = waitLogs(ctx, stderr, bothRE) + assert.Success(t, err) + + // Wait for that compilation to fully finish + _, err = waitLogs(ctx, stderr, doneRE) + assert.Success(t, err) + stderr.Reset() + + // Update the main file to no longer have that dependency + writeFile(t, dir, "a.d2", ` +a +`) + _, err = waitLogs(ctx, stderr, doneRE) + assert.Success(t, err) + stderr.Reset() + + // Change b + writeFile(t, dir, "b.d2", ` +y +`) + // Change a to retrigger compilation + // The test works by seeing that the report only says "a" changed, otherwise testing for omission of compilation from "b" would require waiting + writeFile(t, dir, "a.d2", ` +c +`) + + _, err = waitLogs(ctx, stderr, doneRE) + assert.Success(t, err) }, }, } @@ -810,7 +876,9 @@ func getNumBoards(svg string) int { return strings.Count(svg, `class="d2`) } -func waitLogs(ctx context.Context, buf *bytes.Buffer, pattern *regexp.Regexp) string { +var errRE = regexp.MustCompile(`err:`) + +func waitLogs(ctx context.Context, buf *bytes.Buffer, pattern *regexp.Regexp) (string, error) { ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() var match string @@ -819,13 +887,20 @@ func waitLogs(ctx context.Context, buf *bytes.Buffer, pattern *regexp.Regexp) st case <-ticker.C: out := buf.String() match = pattern.FindString(out) + errMatch := errRE.FindString(out) + if errMatch != "" { + return "", errors.New(buf.String()) + } case <-ctx.Done(): ticker.Stop() - return "" + return "", fmt.Errorf("could not match pattern in log. logs: %s", buf.String()) } } + if match == "" { + return "", errors.New(buf.String()) + } - return match + return match, nil } func getWatchPage(ctx context.Context, t *testing.T, page string) error {