diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a00d5f..3938593 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,14 @@ ## [[unpublished]](https://github.com/mlange-42/arche-model/compare/v0.7.0...main) +### Features + +* Adds `reporter.Callback` for direct retrieval of observer output in Go code (#61) + ### Bugfixes * Fix typo in error message when adding UI system as normal system (#60) +* Fix reporters did not work with unspecified `UpdateInterval` (#61) ## [[v0.7.0]](https://github.com/mlange-42/arche-model/compare/v0.6.0...v0.7.0) diff --git a/reporter/callback.go b/reporter/callback.go new file mode 100644 index 0000000..cd64637 --- /dev/null +++ b/reporter/callback.go @@ -0,0 +1,40 @@ +package reporter + +import ( + "github.com/mlange-42/arche-model/observer" + "github.com/mlange-42/arche/ecs" +) + +// Callback reporter calling a function on each update. +type Callback struct { + Observer observer.Row // Observer to get data from. + UpdateInterval int // Update/print interval in model ticks. + HeaderCallback func(header []string) + Callback func(step int, row []float64) + step int64 +} + +// Initialize the system +func (s *Callback) Initialize(w *ecs.World) { + s.Observer.Initialize(w) + if s.UpdateInterval == 0 { + s.UpdateInterval = 1 + } + if s.HeaderCallback != nil { + s.HeaderCallback(s.Observer.Header()) + } + s.step = 0 +} + +// Update the system +func (s *Callback) Update(w *ecs.World) { + s.Observer.Update(w) + if s.step%int64(s.UpdateInterval) == 0 { + values := s.Observer.Values(w) + s.Callback(int(s.step), values) + } + s.step++ +} + +// Finalize the system +func (s *Callback) Finalize(w *ecs.World) {} diff --git a/reporter/callback_test.go b/reporter/callback_test.go new file mode 100644 index 0000000..728949d --- /dev/null +++ b/reporter/callback_test.go @@ -0,0 +1,35 @@ +package reporter_test + +import ( + "fmt" + + "github.com/mlange-42/arche-model/model" + "github.com/mlange-42/arche-model/reporter" + "github.com/mlange-42/arche-model/system" +) + +func ExampleCallback() { + // Create a new model. + m := model.New() + + data := [][]float64{} + + // Add a Print reporter with an Observer. + m.AddSystem(&reporter.Callback{ + Observer: &ExampleObserver{}, + Callback: func(step int, row []float64) { + data = append(data, row) + }, + HeaderCallback: func(header []string) {}, + }) + + // Add a termination system that ends the simulation. + m.AddSystem(&system.FixedTermination{Steps: 3}) + + // Run the simulation. + m.Run() + + fmt.Println(data) + // Output: + // [[1 2 3] [1 2 3] [1 2 3]] +} diff --git a/reporter/csv.go b/reporter/csv.go index d5481df..77d71b6 100644 --- a/reporter/csv.go +++ b/reporter/csv.go @@ -28,6 +28,9 @@ type CSV struct { func (s *CSV) Initialize(w *ecs.World) { s.Observer.Initialize(w) s.header = s.Observer.Header() + if s.UpdateInterval == 0 { + s.UpdateInterval = 1 + } if s.Sep == "" { s.Sep = "," } diff --git a/reporter/csv_example_test.go b/reporter/csv_example_test.go index ef7f0ce..8bb822a 100644 --- a/reporter/csv_example_test.go +++ b/reporter/csv_example_test.go @@ -13,10 +13,9 @@ func ExampleCSV() { // Add a CSV reporter with an Observer. m.AddSystem(&reporter.CSV{ - Observer: &ExampleObserver{}, - File: "../out/test.csv", - Sep: ";", - UpdateInterval: 10, + Observer: &ExampleObserver{}, + File: "../out/test.csv", + Sep: ";", }) // Add a termination system that ends the simulation. diff --git a/reporter/csv_snaphot.go b/reporter/csv_snaphot.go index ac23181..a2c0487 100644 --- a/reporter/csv_snaphot.go +++ b/reporter/csv_snaphot.go @@ -27,6 +27,9 @@ type SnapshotCSV struct { func (s *SnapshotCSV) Initialize(w *ecs.World) { s.Observer.Initialize(w) s.header = s.Observer.Header() + if s.UpdateInterval == 0 { + s.UpdateInterval = 1 + } if s.Sep == "" { s.Sep = "," diff --git a/reporter/csv_snapshot_example_test.go b/reporter/csv_snapshot_example_test.go index 25159d4..bd0be56 100644 --- a/reporter/csv_snapshot_example_test.go +++ b/reporter/csv_snapshot_example_test.go @@ -13,10 +13,9 @@ func ExampleSnapshotCSV() { // Add a SnapshotCSV reporter with an Observer. m.AddSystem(&reporter.SnapshotCSV{ - Observer: &ExampleSnapshotObserver{}, - FilePattern: "../out/test-%06d.csv", - Sep: ";", - UpdateInterval: 10, + Observer: &ExampleSnapshotObserver{}, + FilePattern: "../out/test-%06d.csv", + Sep: ";", }) // Add a termination system that ends the simulation. diff --git a/reporter/csv_snapshot_test.go b/reporter/csv_snapshot_test.go index 786e38f..91ed4ef 100644 --- a/reporter/csv_snapshot_test.go +++ b/reporter/csv_snapshot_test.go @@ -14,9 +14,8 @@ func TestSnapshotCSV(t *testing.T) { m := model.New() m.AddSystem(&reporter.SnapshotCSV{ - Observer: &ExampleSnapshotObserver{}, - FilePattern: "../out/test-%06d.csv", - UpdateInterval: 10, + Observer: &ExampleSnapshotObserver{}, + FilePattern: "../out/test-%06d.csv", }) m.AddSystem(&system.FixedTermination{Steps: 100}) diff --git a/reporter/csv_test.go b/reporter/csv_test.go index d675b72..5e6245d 100644 --- a/reporter/csv_test.go +++ b/reporter/csv_test.go @@ -14,9 +14,8 @@ func TestCSV(t *testing.T) { m := model.New() m.AddSystem(&reporter.CSV{ - Observer: &ExampleObserver{}, - File: "../out/test.csv", - UpdateInterval: 10, + Observer: &ExampleObserver{}, + File: "../out/test.csv", }) m.AddSystem(&system.FixedTermination{Steps: 100}) diff --git a/reporter/print.go b/reporter/print.go index 4229a46..5a5ee34 100644 --- a/reporter/print.go +++ b/reporter/print.go @@ -19,6 +19,9 @@ type Print struct { func (s *Print) Initialize(w *ecs.World) { s.Observer.Initialize(w) s.header = s.Observer.Header() + if s.UpdateInterval == 0 { + s.UpdateInterval = 1 + } s.step = 0 } diff --git a/reporter/print_test.go b/reporter/print_test.go index 67d2f63..53e9ef1 100644 --- a/reporter/print_test.go +++ b/reporter/print_test.go @@ -12,12 +12,11 @@ func ExamplePrint() { // Add a Print reporter with an Observer. m.AddSystem(&reporter.Print{ - Observer: &ExampleObserver{}, - UpdateInterval: 10, + Observer: &ExampleObserver{}, }) // Add a termination system that ends the simulation. - m.AddSystem(&system.FixedTermination{Steps: 20}) + m.AddSystem(&system.FixedTermination{Steps: 3}) // Run the simulation. m.Run() @@ -26,4 +25,6 @@ func ExamplePrint() { // [1 2 3] // [A B C] // [1 2 3] + // [A B C] + // [1 2 3] }