Skip to content

Commit

Permalink
Separate run options from model
Browse files Browse the repository at this point in the history
Instead of storing them directly on the model, pass them as a param.

The model now stores the default which come from the amod file. These may overridden by the command line or the web API.
  • Loading branch information
asmaloney committed Mar 1, 2024
1 parent ddcc20b commit 5ad204c
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 76 deletions.
12 changes: 7 additions & 5 deletions actr/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ type Model struct {

Productions []*Production

runoptions.Options
// These defaults come from the amod file and may be overridden on the command line
// or by web requests.
DefaultParams runoptions.Options

// Used to validate our parameters
parameters param.ParametersInterface
Expand Down Expand Up @@ -80,7 +82,7 @@ func (model *Model) Initialize() {
model.Procedural = modules.NewProcedural()
model.Modules = append(model.Modules, model.Procedural)

model.LogLevel = "info"
model.DefaultParams = runoptions.New()

// Declare our parameters
loggingParam := param.NewStr(
Expand Down Expand Up @@ -293,16 +295,16 @@ func (model *Model) SetParam(kv *keyvalue.KeyValue) (err error) {

switch kv.Key {
case "log_level":
model.LogLevel = runoptions.ACTRLogLevel(*value.Str)
model.DefaultParams.LogLevel = runoptions.ACTRLogLevel(*value.Str)

case "trace_activations":
boolVal, _ := value.AsBool() // already validated
model.TraceActivations = boolVal
model.DefaultParams.TraceActivations = boolVal

case "random_seed":
seed := uint32(*value.Number)

model.RandomSeed = &seed
model.DefaultParams.RandomSeed = &seed

default:
return param.ErrUnrecognizedOption{Option: kv.Key}
Expand Down
37 changes: 19 additions & 18 deletions framework/ccm_pyactr/ccm_pyactr.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/asmaloney/gactar/util/filesystem"
"github.com/asmaloney/gactar/util/issues"
"github.com/asmaloney/gactar/util/numbers"
"github.com/asmaloney/gactar/util/runoptions"
)

//go:embed ccm_print.py
Expand Down Expand Up @@ -115,8 +116,8 @@ func (c CCMPyACTR) Model() (model *actr.Model) {

// Run generates the python code from the amod file, writes it to disk, creates a "run" file
// to actually run the model, and returns the output (stdout and stderr combined).
func (c *CCMPyACTR) Run(initialBuffers framework.InitialBuffers) (result *framework.RunResult, err error) {
runFile, err := c.WriteModel(c.tmpPath, initialBuffers)
func (c *CCMPyACTR) Run(options *runoptions.Options, initialBuffers framework.InitialBuffers) (result *framework.RunResult, err error) {
runFile, err := c.WriteModel(c.tmpPath, options, initialBuffers)
if err != nil {
return
}
Expand All @@ -137,7 +138,7 @@ func (c *CCMPyACTR) Run(initialBuffers framework.InitialBuffers) (result *framew
}

// WriteModel converts the internal actr.Model to Python and writes it to a file.
func (c *CCMPyACTR) WriteModel(path string, initialBuffers framework.InitialBuffers) (outputFileName string, err error) {
func (c *CCMPyACTR) WriteModel(path string, options *runoptions.Options, initialBuffers framework.InitialBuffers) (outputFileName string, err error) {
// If our model has a print statement, then write out our support file
if c.model.HasPrintStatement() {
err = framework.WriteSupportFile(path, ccmPrintFileName, ccmPrintPython)
Expand All @@ -147,7 +148,7 @@ func (c *CCMPyACTR) WriteModel(path string, initialBuffers framework.InitialBuff
}

// If our model is tracing activations, then write out our support file
if c.model.TraceActivations {
if options.TraceActivations {
err = framework.WriteSupportFile(path, gactarActivateTraceFileName, gactarActivateTraceFile)
if err != nil {
return
Expand All @@ -164,7 +165,7 @@ func (c *CCMPyACTR) WriteModel(path string, initialBuffers framework.InitialBuff
return "", err
}

_, err = c.GenerateCode(initialBuffers)
_, err = c.GenerateCode(options, initialBuffers)
if err != nil {
return
}
Expand All @@ -178,7 +179,7 @@ func (c *CCMPyACTR) WriteModel(path string, initialBuffers framework.InitialBuff
}

// GenerateCode converts the internal actr.Model to Python code.
func (c *CCMPyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code []byte, err error) {
func (c *CCMPyACTR) GenerateCode(options *runoptions.Options, initialBuffers framework.InitialBuffers) (code []byte, err error) {
patterns, err := framework.ParseInitialBuffers(c.model, initialBuffers)
if err != nil {
return
Expand All @@ -195,13 +196,13 @@ func (c *CCMPyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code

memory := c.model.Memory

c.writeImports()
c.writeImports(options)

c.Write("\n\n")

// random
if c.model.RandomSeed != nil {
c.Writeln("random.seed(%d)", *c.model.RandomSeed)
if options.RandomSeed != nil {
c.Writeln("random.seed(%d)", *options.RandomSeed)
c.Write("\n\n")
}

Expand Down Expand Up @@ -237,7 +238,7 @@ func (c *CCMPyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code
c.Writeln(" %s = Memory(%s)", memory.ModuleName(), memory.BufferName())
}

if c.model.TraceActivations {
if options.TraceActivations {
c.Writeln(" trace = ActivateTrace(%s)", memory.ModuleName())
}

Expand Down Expand Up @@ -286,7 +287,7 @@ func (c *CCMPyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code
c.Writeln("")
}

if c.model.LogLevel == "info" {
if options.LogLevel == "info" {
// this turns on some logging at the high level
c.Writeln(" def __init__(self):")
c.Writeln(" super().__init__(log=True)")
Expand All @@ -308,7 +309,7 @@ func (c *CCMPyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code

c.Writeln("")

c.writeMain()
c.writeMain(options)

code = c.GetContents()
return
Expand Down Expand Up @@ -346,8 +347,8 @@ func (c CCMPyACTR) writeAuthors() {
c.Writeln("")
}

func (c CCMPyACTR) writeImports() {
if c.model.RandomSeed != nil {
func (c CCMPyACTR) writeImports(runOptions *runoptions.Options) {
if runOptions.RandomSeed != nil {
c.Writeln("import random")
}

Expand Down Expand Up @@ -379,7 +380,7 @@ func (c CCMPyACTR) writeImports() {
c.Write("from python_actr import %s\n", strings.Join(additionalImports, ", "))
}

if c.model.LogLevel == "detail" {
if runOptions.LogLevel == "detail" {
c.Writeln("from python_actr import log, log_everything")
}

Expand All @@ -388,7 +389,7 @@ func (c CCMPyACTR) writeImports() {
c.Writeln(fmt.Sprintf("from %s import CCMPrint", ccmPrintImportName))
}

if c.model.TraceActivations {
if runOptions.TraceActivations {
c.Writeln("")
c.Writeln(fmt.Sprintf("from %s import ActivateTrace", gactarActivateTraceImportName))
}
Expand Down Expand Up @@ -489,11 +490,11 @@ func (c CCMPyACTR) writeProductions() {
}
}

func (c CCMPyACTR) writeMain() {
func (c CCMPyACTR) writeMain(runOptions *runoptions.Options) {
c.Writeln("if __name__ == \"__main__\":")
c.Writeln(fmt.Sprintf(" model = %s()", c.className))

if c.model.LogLevel == "detail" {
if runOptions.LogLevel == "detail" {
c.Writeln(" log(summary=1)")
c.Writeln(" log_everything(model)")
}
Expand Down
7 changes: 4 additions & 3 deletions framework/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/asmaloney/gactar/actr"

"github.com/asmaloney/gactar/util/issues"
"github.com/asmaloney/gactar/util/runoptions"
"github.com/asmaloney/gactar/util/version"
)

Expand Down Expand Up @@ -50,9 +51,9 @@ type Framework interface {
SetModel(model *actr.Model) (err error)
Model() (model *actr.Model)

Run(initialBuffers InitialBuffers) (result *RunResult, err error)
WriteModel(path string, initialBuffers InitialBuffers) (outputFileName string, err error)
GenerateCode(initialBuffers InitialBuffers) (code []byte, err error)
Run(options *runoptions.Options, initialBuffers InitialBuffers) (result *RunResult, err error)
WriteModel(path string, options *runoptions.Options, initialBuffers InitialBuffers) (outputFileName string, err error)
GenerateCode(options *runoptions.Options, initialBuffers InitialBuffers) (code []byte, err error)
}

type List map[string]Framework
Expand Down
31 changes: 16 additions & 15 deletions framework/pyactr/pyactr.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/asmaloney/gactar/util/filesystem"
"github.com/asmaloney/gactar/util/issues"
"github.com/asmaloney/gactar/util/numbers"
"github.com/asmaloney/gactar/util/runoptions"
)

//go:embed pyactr_print.py
Expand Down Expand Up @@ -132,8 +133,8 @@ func (p PyACTR) Model() (model *actr.Model) {
return p.model
}

func (p *PyACTR) Run(initialBuffers framework.InitialBuffers) (result *framework.RunResult, err error) {
runFile, err := p.WriteModel(p.tmpPath, initialBuffers)
func (p *PyACTR) Run(options *runoptions.Options, initialBuffers framework.InitialBuffers) (result *framework.RunResult, err error) {
runFile, err := p.WriteModel(p.tmpPath, options, initialBuffers)
if err != nil {
return
}
Expand All @@ -156,7 +157,7 @@ func (p *PyACTR) Run(initialBuffers framework.InitialBuffers) (result *framework
}

// WriteModel converts the internal actr.Model to Python and writes it to a file.
func (p *PyACTR) WriteModel(path string, initialBuffers framework.InitialBuffers) (outputFileName string, err error) {
func (p *PyACTR) WriteModel(path string, options *runoptions.Options, initialBuffers framework.InitialBuffers) (outputFileName string, err error) {
// If our model has a print statement, then write out our support file
if p.model.HasPrintStatement() {
err = framework.WriteSupportFile(path, pyactrPrintFileName, pyactrPrintPython)
Expand All @@ -175,7 +176,7 @@ func (p *PyACTR) WriteModel(path string, initialBuffers framework.InitialBuffers
return "", err
}

_, err = p.GenerateCode(initialBuffers)
_, err = p.GenerateCode(options, initialBuffers)
if err != nil {
return
}
Expand All @@ -189,7 +190,7 @@ func (p *PyACTR) WriteModel(path string, initialBuffers framework.InitialBuffers
}

// GenerateCode converts the internal actr.Model to Python code.
func (p *PyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code []byte, err error) {
func (p *PyACTR) GenerateCode(options *runoptions.Options, initialBuffers framework.InitialBuffers) (code []byte, err error) {
patterns, err := framework.ParseInitialBuffers(p.model, initialBuffers)
if err != nil {
return
Expand All @@ -204,13 +205,13 @@ func (p *PyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code []b

p.writeHeader()

p.writeImports()
p.writeImports(options)

p.Writeln("")

// random
if p.model.RandomSeed != nil {
p.Writeln("numpy.random.seed(%d)\n", *p.model.RandomSeed)
if options.RandomSeed != nil {
p.Writeln("numpy.random.seed(%d)\n", *options.RandomSeed)
}

memory := p.model.Memory
Expand Down Expand Up @@ -253,7 +254,7 @@ func (p *PyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code []b
p.Writeln(" rule_firing=%s,", numbers.Float64Str(*procedural.DefaultActionTime))
}

if p.model.TraceActivations {
if options.TraceActivations {
p.Writeln(" activation_trace=True,")
}

Expand Down Expand Up @@ -329,7 +330,7 @@ func (p *PyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code []b
p.Writeln("")

// ...add our code to run
p.writeMain()
p.writeMain(options)

code = p.GetContents()
return
Expand Down Expand Up @@ -367,8 +368,8 @@ func (p PyACTR) writeAuthors() {
p.Writeln("")
}

func (p PyACTR) writeImports() {
if p.model.RandomSeed != nil {
func (p PyACTR) writeImports(runOptions *runoptions.Options) {
if runOptions.RandomSeed != nil {
p.Writeln("import numpy")
}

Expand Down Expand Up @@ -491,20 +492,20 @@ func (p PyACTR) writeProductions() {
}
}

func (p PyACTR) writeMain() {
func (p PyACTR) writeMain(runOptions *runoptions.Options) {
p.Writeln("# Main")
p.Writeln("if __name__ == '__main__':")

options := []string{"gui=False"}

if p.model.LogLevel == "min" {
if runOptions.LogLevel == "min" {
options = append(options, "trace=False")
}

p.Writeln(" sim = %s.simulation( %s )", p.className, strings.Join(options, ", "))
p.Writeln(" sim.run()")

if p.model.LogLevel != "min" {
if runOptions.LogLevel != "min" {
p.Writeln(" if goal.test_buffer('full'):")
p.Writeln(" print('chunk left in goal: ' + str(goal.pop()))")
p.Writeln(" if %s.retrieval.test_buffer('full'):", p.className)
Expand Down
2 changes: 1 addition & 1 deletion framework/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func GenerateCodeFromFile(fw Framework, inputFile string, initialBuffers Initial
return
}

code, err = fw.GenerateCode(initialBuffers)
code, err = fw.GenerateCode(&model.DefaultParams, initialBuffers)
if err != nil {
return
}
Expand Down
Loading

0 comments on commit 5ad204c

Please sign in to comment.