Skip to content

Commit

Permalink
Add reporter.TableCallback for direct retrieval of table observer o…
Browse files Browse the repository at this point in the history
…utput in Go code
  • Loading branch information
mlange-42 committed May 23, 2024
1 parent 06db52d commit 9b1b07c
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 11 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

## [[unpublished]](https://github.com/mlange-42/arche-model/compare/v0.8.0...main)

### Breaking changes

* Renames `reporter.Callback` to `reporter.RowCallback` (#66)

### Features

* Adds option `Final` to `reporter.Callback` to report data once on finalize instead of on ticks (#65)
* Adds option `Final` to `reporter.RowCallback` to report data once on finalize instead of on ticks (#65)
* Adds `reporter.TableCallback` for direct retrieval of table observer output in Go code (#66)

## [[v0.8.1]](https://github.com/mlange-42/arche-model/compare/v0.8.0...v0.8.1)

Expand All @@ -16,7 +21,7 @@

### Features

* Adds `reporter.Callback` for direct retrieval of observer output in Go code (#61)
* Adds `reporter.Callback` for direct retrieval of row observer output in Go code (#61)

### Bugfixes

Expand Down
53 changes: 48 additions & 5 deletions reporter/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"github.com/mlange-42/arche/ecs"
)

// Callback reporter calling a function on each update.
type Callback struct {
// RowCallback reporter calling a function on each update, using an [observer.Row].
type RowCallback struct {
Observer observer.Row // Observer to get data from.
UpdateInterval int // Update interval in model ticks.
HeaderCallback func(header []string) // Called with the header of the observer during initialization.
Expand All @@ -16,7 +16,7 @@ type Callback struct {
}

// Initialize the system
func (s *Callback) Initialize(w *ecs.World) {
func (s *RowCallback) Initialize(w *ecs.World) {
s.Observer.Initialize(w)
if s.UpdateInterval == 0 {
s.UpdateInterval = 1
Expand All @@ -28,7 +28,7 @@ func (s *Callback) Initialize(w *ecs.World) {
}

// Update the system
func (s *Callback) Update(w *ecs.World) {
func (s *RowCallback) Update(w *ecs.World) {
s.Observer.Update(w)

if !s.Final && s.step%int64(s.UpdateInterval) == 0 {
Expand All @@ -40,7 +40,50 @@ func (s *Callback) Update(w *ecs.World) {
}

// Finalize the system
func (s *Callback) Finalize(w *ecs.World) {
func (s *RowCallback) Finalize(w *ecs.World) {
if !s.Final {
return
}
values := s.Observer.Values(w)
s.Callback(int(s.step), values)
}

// RowCallback reporter calling a function on each update, using an [observer.Table].
type TableCallback struct {
Observer observer.Table // Observer to get data from.
UpdateInterval int // Update interval in model ticks.
HeaderCallback func(header []string) // Called with the header of the observer during initialization.
Callback func(step int, table [][]float64) // Called with step and data table on each update (subject to UpdateInterval).
Final bool // Whether Callback should be called on finalization only, instead of on every tick.
step int64
}

// Initialize the system
func (s *TableCallback) Initialize(w *ecs.World) {
s.Observer.Initialize(w)
if s.UpdateInterval == 0 {
s.UpdateInterval = 1
}
if s.HeaderCallback != nil {
s.HeaderCallback(s.Observer.Header())
}
s.step = 0
}

// Update the system
func (s *TableCallback) Update(w *ecs.World) {
s.Observer.Update(w)

if !s.Final && s.step%int64(s.UpdateInterval) == 0 {
values := s.Observer.Values(w)
s.Callback(int(s.step), values)
}

s.step++
}

// Finalize the system
func (s *TableCallback) Finalize(w *ecs.World) {
if !s.Final {
return
}
Expand Down
52 changes: 48 additions & 4 deletions reporter/callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ import (
"github.com/stretchr/testify/assert"
)

func ExampleCallback() {
func ExampleRowCallback() {
// Create a new model.
m := model.New()

data := [][]float64{}

// Add a Print reporter with an Observer.
m.AddSystem(&reporter.Callback{
m.AddSystem(&reporter.RowCallback{
Observer: &ExampleObserver{},
Callback: func(step int, row []float64) {
data = append(data, row)
Expand All @@ -36,11 +36,37 @@ func ExampleCallback() {
// [[1 2 3] [1 2 3] [1 2 3]]
}

func TestCallbackFinal(t *testing.T) {
func ExampleTableCallback() {
// Create a new model.
m := model.New()

data := [][]float64{}

// Add a Print reporter with an Observer.
m.AddSystem(&reporter.TableCallback{
Observer: &ExampleSnapshotObserver{},
Callback: func(step int, table [][]float64) {
data = append(data, table...)
},
HeaderCallback: func(header []string) {},
})

// Add a termination system that ends the simulation.
m.AddSystem(&system.FixedTermination{Steps: 3})

// Run the simulation.
m.Run()

fmt.Println(data)
// Output:
// [[1 2 3] [1 2 3] [1 2 3] [1 2 3] [1 2 3] [1 2 3] [1 2 3] [1 2 3] [1 2 3]]
}

func TestRowCallbackFinal(t *testing.T) {
m := model.New()
counter := 0

m.AddSystem(&reporter.Callback{
m.AddSystem(&reporter.RowCallback{
Observer: &ExampleObserver{},
Callback: func(step int, row []float64) {
counter++
Expand All @@ -53,3 +79,21 @@ func TestCallbackFinal(t *testing.T) {

assert.Equal(t, 1, counter)
}

func TestTableCallbackFinal(t *testing.T) {
m := model.New()
counter := 0

m.AddSystem(&reporter.TableCallback{
Observer: &ExampleSnapshotObserver{},
Callback: func(step int, table [][]float64) {
counter++
},
HeaderCallback: func(header []string) {},
Final: true,
})
m.AddSystem(&system.FixedTermination{Steps: 3})
m.Run()

assert.Equal(t, 1, counter)
}

0 comments on commit 9b1b07c

Please sign in to comment.