diff --git a/cmd/root.go b/cmd/root.go index c479c73..2ce8216 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -21,6 +21,7 @@ import ( "github.com/asmaloney/gactar/util/executil" "github.com/asmaloney/gactar/util/filesystem" "github.com/asmaloney/gactar/util/frameworkutil" + "github.com/asmaloney/gactar/util/runoptions" "github.com/asmaloney/gactar/util/version" ) @@ -43,7 +44,10 @@ var ( flagVersion = false // options for the default command line mode - defaultModeOptions defaultmode.Options + defaultModeRunAfterGeneration bool + defaultModeLogLevel string + defaultModeTraceActivations bool + defaultModeRandomSeed uint32 ) type errRequiresSubcommand struct { @@ -97,9 +101,30 @@ var rootCmd = &cobra.Command{ return err } - defaultModeOptions.FileList = args + options := defaultmode.CommandLineOptions{ + FileList: args, + RunAfterGeneration: defaultModeRunAfterGeneration, + } + + // validate & override options + if cmd.Flags().Changed("trace") { + options.TraceActivations = &defaultModeTraceActivations + } + + if cmd.Flags().Changed("logging") { + if !runoptions.ValidLogLevel(defaultModeLogLevel) { + return runoptions.ErrInvalidLogLevel{Level: defaultModeLogLevel} + } + + logLevel := runoptions.ACTRLogLevel(defaultModeLogLevel) + options.LogLevel = &logLevel + } + + if cmd.Flags().Changed("rand") { + options.RandomSeed = &defaultModeRandomSeed + } - s, err := defaultmode.Initialize(settings, defaultModeOptions) + s, err := defaultmode.Initialize(settings, options) if err != nil { return err } @@ -152,7 +177,10 @@ func init() { // Local flags - only run when this action is called directly. rootCmd.Flags().BoolVarP(&flagVersion, "version", "v", false, "output the version and quit") // Run options for default command line mode. - rootCmd.Flags().BoolVarP(&defaultModeOptions.RunAfterGeneration, "run", "r", false, "run the models after generating the code") + rootCmd.Flags().BoolVarP(&defaultModeRunAfterGeneration, "run", "r", false, "run the models after generating the code") + rootCmd.Flags().StringVarP(&defaultModeLogLevel, "logging", "l", defaultModeLogLevel, fmt.Sprintf("logging level - valid options: %s", strings.Join(runoptions.ACTRLoggingLevels, ", "))) + rootCmd.Flags().BoolVarP(&defaultModeTraceActivations, "trace", "t", false, "output trace activations") + rootCmd.Flags().Uint32VarP(&defaultModeRandomSeed, "seed", "s", 0, "set the random number seed") rootCmd.MarkFlagsMutuallyExclusive("run", "version") rootCmd.SetGlobalNormalizationFunc(normalizeAliasFlagsFunc) diff --git a/modes/defaultmode/defaultmode.go b/modes/defaultmode/defaultmode.go index 4e78702..6b4be5b 100644 --- a/modes/defaultmode/defaultmode.go +++ b/modes/defaultmode/defaultmode.go @@ -13,6 +13,7 @@ import ( "github.com/asmaloney/gactar/util/chalk" "github.com/asmaloney/gactar/util/cli" "github.com/asmaloney/gactar/util/filesystem" + "github.com/asmaloney/gactar/util/runoptions" "github.com/asmaloney/gactar/util/validate" ) @@ -22,24 +23,28 @@ var ( ErrNoValidModels = errors.New("no valid models to run") ) -type Options struct { +// CommandLineOptions come from the command line. +type CommandLineOptions struct { FileList []string RunAfterGeneration bool + + // these override any options from the model + runoptions.Options } type DefaultMode struct { settings *cli.Settings - runOptions Options + commandLineOptions CommandLineOptions } -func Initialize(settings *cli.Settings, options Options) (d *DefaultMode, err error) { +func Initialize(settings *cli.Settings, options CommandLineOptions) (d *DefaultMode, err error) { // Check if files exist first if len(options.FileList) == 0 { return nil, ErrNoInputFiles } - existingFiles := make([]string, len(options.FileList)) + var existingFiles []string for _, file := range options.FileList { if _, fileErr := os.Stat(file); errors.Is(fileErr, os.ErrNotExist) { @@ -57,11 +62,11 @@ func Initialize(settings *cli.Settings, options Options) (d *DefaultMode, err er } d = &DefaultMode{ - settings: settings, - runOptions: options, + settings: settings, + commandLineOptions: options, } - d.runOptions.FileList = existingFiles + d.commandLineOptions.FileList = existingFiles return } @@ -69,21 +74,21 @@ func Initialize(settings *cli.Settings, options Options) (d *DefaultMode, err er func (d *DefaultMode) Start() (err error) { fmt.Printf("Intermediate file path: %q\n", d.settings.TempPath) - err = generateCode(d.settings.ActiveFrameworks, d.runOptions.FileList, d.settings.TempPath) + err = d.generateCode() if err != nil { return err } - if d.runOptions.RunAfterGeneration { - runCode(d.settings.ActiveFrameworks) + if d.commandLineOptions.RunAfterGeneration { + d.runCode(d.settings.ActiveFrameworks) } return } -func generateCode(frameworks framework.List, files []string, outputDir string) (err error) { +func (d *DefaultMode) generateCode() (err error) { modelMap := map[string]*actr.Model{} - for _, file := range files { + for _, file := range d.commandLineOptions.FileList { fmt.Printf("Generating model for %s\n", file) model, log, modelErr := amod.GenerateModelFromFile(file) if modelErr != nil { @@ -103,7 +108,7 @@ func generateCode(frameworks framework.List, files []string, outputDir string) ( return ErrNoValidModels } - for _, f := range frameworks { + for _, f := range d.settings.ActiveFrameworks { fmt.Printf(" %s\n", f.Info().Name) for file, model := range modelMap { fmt.Printf("\t- generating code for %s\n", file) @@ -120,7 +125,9 @@ func generateCode(frameworks framework.List, files []string, outputDir string) ( continue } - fileName, err := f.WriteModel(outputDir, &model.DefaultParams) + options := overrideRunOptions(&model.DefaultParams, &d.commandLineOptions.Options) + + fileName, err := f.WriteModel(d.settings.TempPath, options) if err != nil { fmt.Println(err.Error()) continue @@ -132,10 +139,13 @@ func generateCode(frameworks framework.List, files []string, outputDir string) ( return } -func runCode(frameworks framework.List) { +func (d *DefaultMode) runCode(frameworks framework.List) { for _, f := range frameworks { model := f.Model() - result, err := f.Run(&model.DefaultParams) + + options := overrideRunOptions(&model.DefaultParams, &d.commandLineOptions.Options) + + result, err := f.Run(options) if err != nil { fmt.Println(err.Error()) continue @@ -146,3 +156,22 @@ func runCode(frameworks framework.List) { fmt.Println() } } + +// overrideRunOptions overrides options set in the model with any set on the command line. +func overrideRunOptions(modelOptions, cliOptions *runoptions.Options) *runoptions.Options { + options := *modelOptions + + if cliOptions.LogLevel != nil { + options.LogLevel = cliOptions.LogLevel + } + + if cliOptions.TraceActivations != nil { + options.TraceActivations = cliOptions.TraceActivations + } + + if cliOptions.RandomSeed != nil { + options.RandomSeed = cliOptions.RandomSeed + } + + return &options +}