diff --git a/pkg/air/expr.go b/pkg/air/expr.go index 8675f60..f6c737c 100644 --- a/pkg/air/expr.go +++ b/pkg/air/expr.go @@ -3,6 +3,7 @@ package air import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/table" + "github.com/consensys/go-corset/pkg/util" ) // Expr represents an expression in the Arithmetic Intermediate Representation @@ -33,6 +34,10 @@ type Expr interface { // Equate one expression with another Equate(Expr) Expr + + // Determine the maximum shift in this expression in either the negative + // (left) or positive direction (right). + MaxShift() util.Pair[uint, uint] } // Add represents the sum over zero or more expressions. @@ -50,6 +55,10 @@ func (p *Add) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} } // Equate one expression with another (equivalent to subtraction). func (p *Add) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} } +// MaxShift returns max shift in either the negative (left) or positive +// direction (right). +func (p *Add) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) } + // Sub represents the subtraction over zero or more expressions. type Sub struct{ Args []Expr } @@ -65,6 +74,10 @@ func (p *Sub) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} } // Equate one expression with another (equivalent to subtraction). func (p *Sub) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} } +// MaxShift returns max shift in either the negative (left) or positive +// direction (right). +func (p *Sub) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) } + // Mul represents the product over zero or more expressions. type Mul struct{ Args []Expr } @@ -80,6 +93,10 @@ func (p *Mul) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} } // Equate one expression with another (equivalent to subtraction). 
func (p *Mul) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} } +// MaxShift returns max shift in either the negative (left) or positive +// direction (right). +func (p *Mul) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) } + // Constant represents a constant value within an expression. type Constant struct{ Value *fr.Element } @@ -118,6 +135,10 @@ func (p *Constant) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} } // Equate one expression with another (equivalent to subtraction). func (p *Constant) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} } +// MaxShift returns max shift in either the negative (left) or positive +// direction (right). A constant has zero shift. +func (p *Constant) MaxShift() util.Pair[uint, uint] { return util.NewPair[uint, uint](0, 0) } + // ColumnAccess represents reading the value held at a given column in the // tabular context. Furthermore, the current row maybe shifted up (or down) by // a given amount. Suppose we are evaluating a constraint on row k=5 which @@ -146,3 +167,31 @@ func (p *ColumnAccess) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other} // Equate one expression with another (equivalent to subtraction). func (p *ColumnAccess) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} } + +// MaxShift returns max shift in either the negative (left) or positive +// direction (right). 
+func (p *ColumnAccess) MaxShift() util.Pair[uint, uint] { + if p.Shift >= 0 { + // Positive shift + return util.NewPair[uint, uint](0, uint(p.Shift)) + } + // Negative shift + return util.NewPair[uint, uint](uint(-p.Shift), 0) +} + +// ========================================================================== +// Helpers +// ========================================================================== + +func maxShiftOfArray(args []Expr) util.Pair[uint, uint] { + neg := uint(0) + pos := uint(0) + + for _, e := range args { + mx := e.MaxShift() + neg = max(neg, mx.Left) + pos = max(pos, mx.Right) + } + // Done + return util.NewPair(neg, pos) +} diff --git a/pkg/air/gadgets/bits.go b/pkg/air/gadgets/bits.go index 0b961c0..8e8c131 100644 --- a/pkg/air/gadgets/bits.go +++ b/pkg/air/gadgets/bits.go @@ -52,7 +52,8 @@ func ApplyBitwidthGadget(col string, nbits uint, schema *air.Schema) { // Construct X == (X:0 * 1) + ... + (X:n * 2^n) X := air.NewColumnAccess(col, 0) eq := X.Equate(sum) - schema.AddVanishingConstraint(col, nil, eq) + // Construct column name + schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", col, nbits), nil, eq) // Finally, add the necessary byte decomposition computation. schema.AddComputation(table.NewByteDecomposition(col, nbits)) } diff --git a/pkg/air/gadgets/lexicographic_sort.go b/pkg/air/gadgets/lexicographic_sort.go index 50a3e3f..eef2040 100644 --- a/pkg/air/gadgets/lexicographic_sort.go +++ b/pkg/air/gadgets/lexicographic_sort.go @@ -155,6 +155,12 @@ type lexicographicSortExpander struct { bitwidth uint } +// RequiredSpillage returns the minimum amount of spillage required to ensure +// valid traces are accepted in the presence of arbitrary padding. 
+func (p *lexicographicSortExpander) RequiredSpillage() uint { + return uint(0) +} + // Accepts checks whether a given trace has the necessary columns func (p *lexicographicSortExpander) Accepts(tr table.Trace) error { prefix := constructLexicographicSortingPrefix(p.columns, p.signs) @@ -194,14 +200,14 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error { bit[i] = make([]*fr.Element, nrows) } - for i := 0; i < nrows; i++ { + for i := uint(0); i < nrows; i++ { set := false // Initialise delta to zero delta[i] = &zero // Decide which row is the winner (if any) for j := 0; j < ncols; j++ { - prev := tr.GetByName(p.columns[j], i-1) - curr := tr.GetByName(p.columns[j], i) + prev := tr.GetByName(p.columns[j], int(i-1)) + curr := tr.GetByName(p.columns[j], int(i)) if !set && prev != nil && prev.Cmp(curr) != 0 { var diff fr.Element diff --git a/pkg/air/gadgets/normalisation.go b/pkg/air/gadgets/normalisation.go index 038cdbe..1a3875e 100644 --- a/pkg/air/gadgets/normalisation.go +++ b/pkg/air/gadgets/normalisation.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/air" "github.com/consensys/go-corset/pkg/table" + "github.com/consensys/go-corset/pkg/util" ) // Normalise constructs an expression representing the normalised value of e. @@ -73,6 +74,10 @@ func (e *Inverse) EvalAt(k int, tbl table.Trace) *fr.Element { return inv.Inverse(val) } +// MaxShift returns max shift in either the negative (left) or positive +// direction (right). 
+func (e *Inverse) MaxShift() util.Pair[uint, uint] { return e.Expr.MaxShift() } + func (e *Inverse) String() string { return fmt.Sprintf("(inv %s)", e.Expr) } diff --git a/pkg/air/schema.go b/pkg/air/schema.go index a2523ae..4c2a4c1 100644 --- a/pkg/air/schema.go +++ b/pkg/air/schema.go @@ -89,12 +89,38 @@ func (p *Schema) HasColumn(name string) bool { return false } +// RequiredSpillage returns the minimum amount of spillage required to ensure +// valid traces are accepted in the presence of arbitrary padding. Spillage can +// only arise from computations as this is where values outside of the user's +// control are determined. +func (p *Schema) RequiredSpillage() uint { + // Ensures always at least one row of spillage (referred to as the "initial + // padding row") + mx := uint(1) + // Determine if any more spillage required + for _, c := range p.computations { + mx = max(mx, c.RequiredSpillage()) + } + + return mx +} + +// ApplyPadding adds n items of padding to each column of the trace. +// Padding values are placed either at the front or the back of a given +// column, depending on their interpretation. +func (p *Schema) ApplyPadding(n uint, tr table.Trace) { + tr.Pad(n, func(j int) *fr.Element { + // Extract front value to use for padding. + return tr.GetByIndex(j, 0) + }) +} + // IsInputTrace determines whether a given input trace is a suitable // input (i.e. non-expanded) trace for this schema. Specifically, the // input trace must contain a matching column for each non-synthetic // column in this trace. func (p *Schema) IsInputTrace(tr table.Trace) error { - count := 0 + count := uint(0) for _, c := range p.dataColumns { if !c.Synthetic && !tr.HasColumn(c.Name) { @@ -112,8 +138,8 @@ func (p *Schema) IsInputTrace(tr table.Trace) error { // Determine the unknown columns for error reporting. 
unknown := make([]string, 0) - for i := 0; i < tr.Width(); i++ { - n := tr.ColumnName(i) + for i := uint(0); i < tr.Width(); i++ { + n := tr.ColumnName(int(i)) if !p.HasColumn(n) { unknown = append(unknown, n) } @@ -132,7 +158,7 @@ func (p *Schema) IsInputTrace(tr table.Trace) error { // output trace must contain a matching column for each column in this // trace (synthetic or otherwise). func (p *Schema) IsOutputTrace(tr table.Trace) error { - count := 0 + count := uint(0) for _, c := range p.dataColumns { if !tr.HasColumn(c.Name) { @@ -153,7 +179,9 @@ func (p *Schema) IsOutputTrace(tr table.Trace) error { // AddColumn appends a new data column which is either synthetic or // not. A synthetic column is one which has been introduced by the // process of lowering from HIR / MIR to AIR. That is, it is not a -// column which was original specified by the user. +// column which was originally specified by the user. Columns also support a +// "padding sign", which indicates whether padding should occur at the front +// (positive sign) or the back (negative sign). func (p *Schema) AddColumn(name string, synthetic bool) { // NOTE: the air level has no ability to enforce the type specified for a // given column. @@ -219,8 +247,6 @@ func (p *Schema) Accepts(trace table.Trace) error { // columns. Observe that computed columns have to be computed in the correct // order. 
func (p *Schema) ExpandTrace(tr table.Trace) error { - // Insert initial padding row - table.PadTrace(1, tr) // Execute all computations for _, c := range p.computations { err := c.ExpandTrace(tr) diff --git a/pkg/cmd/check.go b/pkg/cmd/check.go index 8a8e07a..f0ba6dc 100644 --- a/pkg/cmd/check.go +++ b/pkg/cmd/check.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/go-corset/pkg/hir" "github.com/consensys/go-corset/pkg/table" + "github.com/consensys/go-corset/pkg/util" "github.com/spf13/cobra" ) @@ -19,32 +20,121 @@ var checkCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { var trace *table.ArrayTrace var hirSchema *hir.Schema + var cfg checkConfig if len(args) != 2 { fmt.Println(cmd.UsageString()) os.Exit(1) } - raw := getFlag(cmd, "raw") + cfg.air = getFlag(cmd, "air") + cfg.mir = getFlag(cmd, "mir") + cfg.hir = getFlag(cmd, "hir") + cfg.expand = !getFlag(cmd, "raw") + cfg.report = getFlag(cmd, "report") + cfg.spillage = getInt(cmd, "spillage") + cfg.padding.Right = getUint(cmd, "padding") + // TODO: support true ranges + cfg.padding.Left = cfg.padding.Right // Parse trace trace = readTraceFile(args[0]) // Parse constraints hirSchema = readSchemaFile(args[1]) // Go! - checkTraceWithLowering(trace, hirSchema, raw) + checkTraceWithLowering(trace, hirSchema, cfg) }, } +// check config encapsulates certain parameters to be used when +// checking traces. +type checkConfig struct { + // Performing checking at HIR level + hir bool + // Performing checking at MIR level + mir bool + // Performing checking at AIR level + air bool + // Determines how much spillage to account for. This gives the user the + // ability to override the inferred default. A negative value indicates + // this default should be used. + spillage int + // Determines how much padding to use + padding util.Pair[uint, uint] + // Specifies whether or not to perform trace expansion. 
Trace expansion is + // not required when a "raw" trace is given which already includes all + // implied columns. + expand bool + // Specifies whether or not to report details of the failure (e.g. for + // debugging purposes). + report bool +} + // Check a given trace is consistently accepted (or rejected) at the different // IR levels. -func checkTraceWithLowering(tr *table.ArrayTrace, hirSchema *hir.Schema, raw bool) { +func checkTraceWithLowering(tr *table.ArrayTrace, schema *hir.Schema, cfg checkConfig) { + if !cfg.hir && !cfg.mir && !cfg.air { + // Process together + checkTraceWithLoweringDefault(tr, schema, cfg) + } else { + // Process individually + if cfg.hir { + checkTraceWithLoweringHir(tr, schema, cfg) + } + + if cfg.mir { + checkTraceWithLoweringMir(tr, schema, cfg) + } + + if cfg.air { + checkTraceWithLoweringAir(tr, schema, cfg) + } + } +} + +func checkTraceWithLoweringHir(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) { + trHIR, errHIR := checkTrace(tr, hirSchema, cfg) + // + if errHIR != nil { + reportError("HIR", trHIR, errHIR, cfg) + os.Exit(1) + } +} + +func checkTraceWithLoweringMir(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) { + // Lower HIR => MIR + mirSchema := hirSchema.LowerToMir() + // Check trace + trMIR, errMIR := checkTrace(tr, mirSchema, cfg) + // + if errMIR != nil { + reportError("MIR", trMIR, errMIR, cfg) + os.Exit(1) + } +} + +func checkTraceWithLoweringAir(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) { // Lower HIR => MIR mirSchema := hirSchema.LowerToMir() // Lower MIR => AIR airSchema := mirSchema.LowerToAir() + trAIR, errAIR := checkTrace(tr, airSchema, cfg) // - errHIR := checkTrace(tr, hirSchema, raw) - errMIR := checkTrace(tr, mirSchema, raw) - errAIR := checkTrace(tr, airSchema, raw) + if errAIR != nil { + reportError("AIR", trAIR, errAIR, cfg) + os.Exit(1) + } +} + +// The default check allows one to compare all levels against each other and +// look for any discrepancies. 
+func checkTraceWithLoweringDefault(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) { + // Lower HIR => MIR + mirSchema := hirSchema.LowerToMir() + // Lower MIR => AIR + airSchema := mirSchema.LowerToAir() + // + trHIR, errHIR := checkTrace(tr, hirSchema, cfg) + trMIR, errMIR := checkTrace(tr, mirSchema, cfg) + trAIR, errAIR := checkTrace(tr, airSchema, cfg) // if errHIR != nil || errMIR != nil || errAIR != nil { strHIR := toErrorString(errHIR) @@ -54,26 +144,50 @@ func checkTraceWithLowering(tr *table.ArrayTrace, hirSchema *hir.Schema, raw boo if strHIR == strMIR && strMIR == strAIR { fmt.Println(errHIR) } else { - reportError(errHIR, "HIR") - reportError(errMIR, "MIR") - reportError(errAIR, "AIR") + reportError("HIR", trHIR, errHIR, cfg) + reportError("MIR", trMIR, errMIR, cfg) + reportError("AIR", trAIR, errAIR, cfg) } os.Exit(1) } } -func checkTrace(tr *table.ArrayTrace, schema table.Schema, raw bool) error { - if !raw { +func checkTrace(tr *table.ArrayTrace, schema table.Schema, cfg checkConfig) (table.Trace, error) { + if cfg.expand { // Clone to prevent interefence with subsequent checks tr = tr.Clone() + // Apply spillage + if cfg.spillage >= 0 { + // Apply user-specified spillage + table.FrontPadWithZeros(uint(cfg.spillage), tr) + } else { + // Apply default inferred spillage + table.FrontPadWithZeros(schema.RequiredSpillage(), tr) + } // Expand trace if err := schema.ExpandTrace(tr); err != nil { - return err + return tr, err } } - // Check whether accepted or not. - return schema.Accepts(tr) + // Check whether padding requested + if cfg.padding.Left == 0 && cfg.padding.Right == 0 { + // No padding requested. Therefore, we can avoid a clone in this case. + return tr, schema.Accepts(tr) + } + // Apply padding + for n := cfg.padding.Left; n <= cfg.padding.Right; n++ { + // Prevent interference + ptr := tr.Clone() + // Apply padding + schema.ApplyPadding(n, ptr) + // Check whether accepted or not. 
+ if err := schema.Accepts(ptr); err != nil { + return ptr, err + } + } + // Done + return nil, nil } func toErrorString(err error) string { @@ -84,7 +198,11 @@ func toErrorString(err error) string { return err.Error() } -func reportError(err error, ir string) { +func reportError(ir string, tr table.Trace, err error, cfg checkConfig) { + if cfg.report { + table.PrintTrace(tr) + } + if err != nil { fmt.Printf("%s: %s\n", ir, err) } else { @@ -94,5 +212,12 @@ func init() { rootCmd.AddCommand(checkCmd) + checkCmd.Flags().Bool("report", false, "report details of failure for debugging") checkCmd.Flags().Bool("raw", false, "assume input trace already expanded") + checkCmd.Flags().Bool("hir", false, "check at HIR level") + checkCmd.Flags().Bool("mir", false, "check at MIR level") + checkCmd.Flags().Bool("air", false, "check at AIR level") + checkCmd.Flags().Uint("padding", 0, "specify amount of (front) padding to apply") + checkCmd.Flags().Int("spillage", -1, + "specify amount of spillage to account for (where -1 indicates this should be inferred)") } diff --git a/pkg/cmd/util.go index e164f4a..6f30b62 100644 --- a/pkg/cmd/util.go +++ b/pkg/cmd/util.go @@ -24,16 +24,41 @@ func getFlag(cmd *cobra.Command, flag string) bool { return r } +// Get an expected signed integer, or panic if an error arises. +func getInt(cmd *cobra.Command, flag string) int { + r, err := cmd.Flags().GetInt(flag) + if err != nil { + fmt.Println(err) + os.Exit(3) + } + + return r +} + +// Get an expected unsigned integer, or panic if an error arises. +func getUint(cmd *cobra.Command, flag string) uint { + r, err := cmd.Flags().GetUint(flag) + if err != nil { + fmt.Println(err) + os.Exit(4) + } + + return r +} + // Parse a trace file using a parser based on the extension of the filename. 
func readTraceFile(filename string) *table.ArrayTrace { + var trace *table.ArrayTrace + // Read data file bytes, err := os.ReadFile(filename) + // Check success if err == nil { // Check file extension ext := path.Ext(filename) // switch ext { case ".json": - trace, err := table.ParseJsonTrace(bytes) + trace, err = table.ParseJsonTrace(bytes) if err == nil { return trace } diff --git a/pkg/hir/parser.go b/pkg/hir/parser.go index f7dd869..292d08d 100644 --- a/pkg/hir/parser.go +++ b/pkg/hir/parser.go @@ -150,6 +150,10 @@ func (p *hirParser) parseSortedPermutationDeclaration(elements []sexp.SExp) erro if strings.HasPrefix(sortName, "+") { signs[i] = true } else if strings.HasPrefix(sortName, "-") { + if i == 0 { + return p.translator.SyntaxError(source, "sorted permutation requires ascending first column") + } + signs[i] = false } else { return p.translator.SyntaxError(source, "malformed sort direction") diff --git a/pkg/hir/schema.go b/pkg/hir/schema.go index d8d47ce..5a177c6 100644 --- a/pkg/hir/schema.go +++ b/pkg/hir/schema.go @@ -1,6 +1,7 @@ package hir import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/mir" "github.com/consensys/go-corset/pkg/table" "github.com/consensys/go-corset/pkg/util" @@ -100,6 +101,23 @@ func (p *Schema) Size() int { return len(p.dataColumns) + len(p.permutations) + len(p.vanishing) + len(p.assertions) } +// RequiredSpillage returns the minimum amount of spillage required to ensure +// valid traces are accepted in the presence of arbitrary padding. +func (p *Schema) RequiredSpillage() uint { + // Ensures always at least one row of spillage (referred to as the "initial + // padding row") + return uint(1) +} + +// ApplyPadding adds n items of padding to each column of the trace. +// Padding values are placed either at the front or the back of a given +// column, depending on their interpretation. 
+func (p *Schema) ApplyPadding(n uint, tr table.Trace) { + tr.Pad(n, func(j int) *fr.Element { + return tr.GetByIndex(j, 0) + }) +} + // GetDeclaration returns the ith declaration in this schema. func (p *Schema) GetDeclaration(index int) table.Declaration { ith := util.FlatArrayIndexOf_4(index, p.dataColumns, p.permutations, p.vanishing, p.assertions) @@ -159,8 +177,6 @@ func (p *Schema) Accepts(trace table.Trace) error { // ExpandTrace expands a given trace according to this schema. func (p *Schema) ExpandTrace(tr table.Trace) error { - // Insert initial padding row - table.PadTrace(1, tr) // Expand all the permutation columns for _, perm := range p.permutations { err := perm.ExpandTrace(tr) diff --git a/pkg/mir/schema.go b/pkg/mir/schema.go index 83ce36a..53cf321 100644 --- a/pkg/mir/schema.go +++ b/pkg/mir/schema.go @@ -3,6 +3,7 @@ package mir import ( "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/air" air_gadgets "github.com/consensys/go-corset/pkg/air/gadgets" "github.com/consensys/go-corset/pkg/table" @@ -67,6 +68,23 @@ func (p *Schema) Size() int { return len(p.dataColumns) + len(p.permutations) + len(p.vanishing) + len(p.assertions) } +// RequiredSpillage returns the minimum amount of spillage required to ensure +// valid traces are accepted in the presence of arbitrary padding. +func (p *Schema) RequiredSpillage() uint { + // Ensures always at least one row of spillage (referred to as the "initial + // padding row") + return uint(1) +} + +// ApplyPadding adds n items of padding to each column of the trace. +// Padding values are placed either at the front or the back of a given +// column, depending on their interpretation. +func (p *Schema) ApplyPadding(n uint, tr table.Trace) { + tr.Pad(n, func(j int) *fr.Element { + return tr.GetByIndex(j, 0) + }) +} + // GetDeclaration returns the ith declaration in this schema. 
func (p *Schema) GetDeclaration(index int) table.Declaration { ith := util.FlatArrayIndexOf_4(index, p.dataColumns, p.permutations, p.vanishing, p.assertions) @@ -217,8 +235,6 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche // ExpandTrace expands a given trace according to this schema. func (p *Schema) ExpandTrace(tr table.Trace) error { - // Insert initial padding row - table.PadTrace(1, tr) // Expand all the permutation columns for _, perm := range p.permutations { err := perm.ExpandTrace(tr) diff --git a/pkg/table/column.go b/pkg/table/column.go index 03d86c8..7c02db0 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -41,8 +41,8 @@ func (c *DataColumn[T]) Accepts(tr Trace) error { return fmt.Errorf("Trace missing data column ({%s})", c.Name) } // Check constraints accepted - for i := 0; i < tr.Height(); i++ { - val := tr.GetByName(c.Name, i) + for i := uint(0); i < tr.Height(); i++ { + val := tr.GetByName(c.Name, int(i)) if !c.Type.Accept(val) { // Construct useful error message @@ -70,29 +70,40 @@ func (c *DataColumn[T]) String() string { // expectation that this computation is acyclic. Furthermore, computed columns // give rise to "trace expansion". That is where the initial trace provided by // the user is expanded by determining the value of all computed columns. -type ComputedColumn struct { +type ComputedColumn[E Computable] struct { Name string // The computation which accepts a given trace and computes // the value of this column at a given row. - Expr Evaluable + Expr E } // NewComputedColumn constructs a new computed column with a given name and // determining expression. More specifically, that expression is used to // compute the values for this column during trace expansion. 
-func NewComputedColumn(name string, expr Evaluable) *ComputedColumn { - return &ComputedColumn{ +func NewComputedColumn[E Computable](name string, expr E) *ComputedColumn[E] { + return &ComputedColumn[E]{ Name: name, Expr: expr, } } +// RequiredSpillage returns the minimum amount of spillage required to ensure +// this column can be correctly computed in the presence of arbitrary (front) +// padding. +func (c *ComputedColumn[E]) RequiredSpillage() uint { + // NOTE: Spillage is only currently considered to be necessary at the front + // (i.e. start) of a trace. This is because padding is always inserted at + // the front, never the back. As such, it is the maximum positive shift + // which determines how much spillage is required for this computation. + return c.Expr.MaxShift().Right +} + // Accepts determines whether or not this column accepts the given trace. For a // data column, this means ensuring that all elements are value for the columns // type. // //nolint:revive -func (c *ComputedColumn) Accepts(tr Trace) error { +func (c *ComputedColumn[E]) Accepts(tr Trace) error { // Check column in trace! if !tr.HasColumn(c.Name) { return fmt.Errorf("Trace missing computed column ({%s})", c.Name) @@ -104,7 +115,7 @@ // ExpandTrace attempts to a new column to the trace which contains the result // of evaluating a given expression on each row. If the column already exists, // then an error is flagged. 
-func (c *ComputedColumn) ExpandTrace(tr Trace) error { +func (c *ComputedColumn[E]) ExpandTrace(tr Trace) error { if tr.HasColumn(c.Name) { msg := fmt.Sprintf("Computed column already exists ({%s})", c.Name) return errors.New(msg) @@ -127,8 +138,9 @@ func (c *ComputedColumn) ExpandTrace(tr Trace) error { return nil } -func (c *ComputedColumn) String() string { - return fmt.Sprintf("(compute %s %s)", c.Name, c.Expr) +// nolint:revive +func (c *ComputedColumn[E]) String() string { + return fmt.Sprintf("(compute %s %s)", c.Name, any(c.Expr)) } // =================================================================== @@ -149,6 +161,12 @@ func NewPermutation(target string, source string) *Permutation { return &Permutation{target, source} } +// RequiredSpillage returns the minimum amount of spillage required to ensure +// valid traces are accepted in the presence of arbitrary padding. +func (p *Permutation) RequiredSpillage() uint { + return uint(0) +} + // Accepts checks whether a permutation holds between the source and // target columns. func (p *Permutation) Accepts(tr Trace) error { @@ -190,6 +208,12 @@ func NewSortedPermutation(targets []string, signs []bool, sources []string) *Sor return &SortedPermutation{targets, signs, sources} } +// RequiredSpillage returns the minimum amount of spillage required to ensure +// valid traces are accepted in the presence of arbitrary padding. +func (p *SortedPermutation) RequiredSpillage() uint { + return uint(0) +} + // Accepts checks whether a sorted permutation holds between the // source and target columns. func (p *SortedPermutation) Accepts(tr Trace) error { @@ -218,7 +242,7 @@ func (p *SortedPermutation) Accepts(tr Trace) error { return err } - cols[i] = tr.ColumnByName(dstName) + cols[i] = tr.ColumnByName(dstName).Data() } // Check that target columns are sorted lexicographically. 
@@ -247,7 +271,7 @@ func (p *SortedPermutation) ExpandTrace(tr Trace) error { for i := 0; i < len(p.Targets); i++ { src := p.Sources[i] // Read column data to initialise permutation. - data := tr.ColumnByName(src) + data := tr.ColumnByName(src).Data() // Copy column data to initialise permutation. cols[i] = make([]*fr.Element, len(data)) copy(cols[i], data) @@ -296,8 +320,8 @@ func (p *SortedPermutation) String() string { // of another in given trace. The order in which columns are given is // not important. func IsPermutationOf(target string, source string, tr Trace) error { - dst := tr.ColumnByName(target) - src := tr.ColumnByName(source) + dst := tr.ColumnByName(target).Data() + src := tr.ColumnByName(source).Data() // Sanity check whether column exists if dst == nil { msg := fmt.Sprintf("Invalid target column for permutation ({%s})", target) diff --git a/pkg/table/computation.go b/pkg/table/computation.go index 42a1bdd..1954b24 100644 --- a/pkg/table/computation.go +++ b/pkg/table/computation.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/go-corset/pkg/util" ) // TraceComputation represents a computation which is applied to a @@ -17,6 +18,22 @@ type TraceComputation interface { // original trace, but are added during trace expansion to // form the final trace. ExpandTrace(Trace) error + // RequiredSpillage returns the minimum amount of spillage required to ensure + // valid traces are accepted in the presence of arbitrary padding. Note, + // spillage is currently assumed to be required only at the front of a + // trace. + RequiredSpillage() uint +} + +// Computable is an extension of the Evaluable interface which additionally +// allows one to determine specifics about the computation needed to ensure it +// can be correctly computed on a given trace. 
+type Computable interface { + Evaluable + + // Determine the maximum shift in this expression in either the negative + // (left) or positive direction (right). + MaxShift() util.Pair[uint, uint] } // ByteDecomposition is part of a range constraint for wide columns (e.g. u32) @@ -60,7 +77,7 @@ func (p *ByteDecomposition) ExpandTrace(tr Trace) error { // Calculate how many bytes required. n := int(p.BitWidth / 8) // Extract column data to decompose - data := tr.ColumnByName(p.Target) + data := tr.ColumnByName(p.Target).Data() // Construct byte column data cols := make([][]*fr.Element, n) // Initialise columns @@ -87,6 +104,12 @@ func (p *ByteDecomposition) String() string { return fmt.Sprintf("(decomposition %s %d)", p.Target, p.BitWidth) } +// RequiredSpillage returns the minimum amount of spillage required to ensure +// valid traces are accepted in the presence of arbitrary padding. +func (p *ByteDecomposition) RequiredSpillage() uint { + return uint(0) +} + // Decompose a given element into n bytes in little endian form. For example, // decomposing 41b into 2 bytes gives [0x1b,0x04]. func decomposeIntoBytes(val *fr.Element, n int) []*fr.Element { diff --git a/pkg/table/constraints.go b/pkg/table/constraints.go index 1fdde33..6d661fb 100644 --- a/pkg/table/constraints.go +++ b/pkg/table/constraints.go @@ -106,8 +106,8 @@ func (p *RowConstraint[T]) Accepts(tr Trace) error { // HoldsGlobally checks whether a given expression vanishes (i.e. evaluates to // zero) for all rows of a trace. If not, report an appropriate error. 
func HoldsGlobally[T Testable](handle string, constraint T, tr Trace) error { - for k := 0; k < tr.Height(); k++ { - if err := HoldsLocally(k, handle, constraint, tr); err != nil { + for k := uint(0); k < tr.Height(); k++ { + if err := HoldsLocally(int(k), handle, constraint, tr); err != nil { return err } } @@ -120,7 +120,7 @@ func HoldsGlobally[T Testable](handle string, constraint T, tr Trace) error { func HoldsLocally[T Testable](k int, handle string, constraint T, tr Trace) error { // Negative rows calculated from end of trace. if k < 0 { - k += tr.Height() + k += int(tr.Height()) } // Check whether it holds or not if !constraint.TestAt(k, tr) { @@ -188,9 +188,9 @@ func (p *RangeConstraint) IsAir() bool { return true } // Accepts checks whether a range constraint evaluates to zero on // every row of a table. If so, return nil otherwise return an error. func (p *RangeConstraint) Accepts(tr Trace) error { - for k := 0; k < tr.Height(); k++ { + for k := uint(0); k < tr.Height(); k++ { // Get the value on the kth row - kth := tr.GetByName(p.Handle, k) + kth := tr.GetByName(p.Handle, int(k)) // Perform the bounds check if kth != nil && kth.Cmp(p.Bound) >= 0 { // Construct useful error message @@ -247,9 +247,9 @@ func NewPropertyAssertion[E Evaluable](handle string, expr E) *PropertyAssertion // //nolint:revive func (p *PropertyAssertion[E]) Accepts(tr Trace) error { - for k := 0; k < tr.Height(); k++ { + for k := uint(0); k < tr.Height(); k++ { // Determine kth evaluation point - kth := p.Expr.EvalAt(k, tr) + kth := p.Expr.EvalAt(int(k), tr) // Check whether it vanished (or was undefined) if kth != nil && !kth.IsZero() { // Construct useful error message diff --git a/pkg/table/printer.go b/pkg/table/printer.go new file mode 100644 index 0000000..22775d8 --- /dev/null +++ b/pkg/table/printer.go @@ -0,0 +1,71 @@ +package table + +import ( + "fmt" + "unicode/utf8" +) + +// PrintTrace prints a trace in a more human-friendly fashion. 
+func PrintTrace(tr Trace) { + n := tr.Width() + // + rows := make([][]string, n) + for i := uint(0); i < n; i++ { + rows[i] = traceColumnData(tr, i) + } + // + widths := traceRowWidths(tr.Height(), rows) + // + printHorizontalRule(widths) + // + for _, r := range rows { + printTraceRow(r, widths) + printHorizontalRule(widths) + } +} + +func traceColumnData(tr Trace, col uint) []string { + n := tr.Height() + data := make([]string, n+1) + data[0] = tr.ColumnName(int(col)) + + for row := uint(0); row < n; row++ { + data[row+1] = tr.GetByIndex(int(col), int(row)).String() + } + + return data +} + +func traceRowWidths(height uint, rows [][]string) []int { + widths := make([]int, height+1) + + for _, row := range rows { + for i, col := range row { + w := utf8.RuneCountInString(col) + widths[i] = max(w, widths[i]) + } + } + + return widths +} + +func printTraceRow(row []string, widths []int) { + for i, col := range row { + fmt.Printf(" %*s |", widths[i], col) + } + + fmt.Println() +} + +func printHorizontalRule(widths []int) { + for _, w := range widths { + fmt.Print("-") + + for i := 0; i < w; i++ { + fmt.Print("-") + } + fmt.Print("-+") + } + + fmt.Println() +} diff --git a/pkg/table/schema.go b/pkg/table/schema.go index e069a7a..da04db8 100644 --- a/pkg/table/schema.go +++ b/pkg/table/schema.go @@ -16,6 +16,16 @@ type Schema interface { // GetDeclaration returns the ith declaration in this schema. GetDeclaration(int) Declaration + + // RequiredSpillage returns the minimum amount of spillage required to + // ensure valid traces are accepted in the presence of arbitrary padding. + // Note: this is calculated on demand. + RequiredSpillage() uint + + // ApplyPadding adds n items of padding to each column of the trace. + // Padding values are placed either at the front or the back of a given + // column, depending on their interpretation. + ApplyPadding(uint, Trace) } // Declaration represents a declared element of a schema. 
For example, a column diff --git a/pkg/table/trace.go b/pkg/table/trace.go index acbdac9..eed6020 100644 --- a/pkg/table/trace.go +++ b/pkg/table/trace.go @@ -16,6 +16,16 @@ type Acceptable interface { Accepts(Trace) error } +// Column describes an individual column of data within a trace table. +type Column interface { + // Get the name of this column + Name() string + // Return the height (i.e. number of rows) of this column. + Height() uint + // Return the data stored in this column. + Data() []*fr.Element +} + // Trace describes a set of named columns. Columns are not required to have the // same height and can be either "data" columns or "computed" columns. type Trace interface { @@ -23,14 +33,11 @@ type Trace interface { AddColumn(name string, data []*fr.Element) // Get the name of the ith column in this trace. ColumnName(int) string - // Duplicate the first row of this trace n times, whilst placing the - // duplicates at the beginning of the trace. This can be used, for example, - // to apply padding to an existing trace. Note it is an error to call this - // on an empty trace. - DuplicateFront(n int) + // ColumnByIndex returns the ith column in this trace. + ColumnByIndex(uint) Column // ColumnByName returns the data of a given column in order that it can be // inspected. If the given column does not exist, then nil is returned. - ColumnByName(name string) []*fr.Element + ColumnByName(name string) Column // Check whether this trace contains data for the given column. HasColumn(name string) bool // Get the value of a given column by its name. If the column @@ -46,14 +53,14 @@ type Trace interface { // does not exist or if the index is out-of-bounds then an // error is returned. GetByIndex(col int, row int) *fr.Element - // Insert n copies of a given row at the front of this trace using a given - // mapping function to initialise each column. Note, n cannot be negative. 
- InsertFront(n int, mapping func(int) *fr.Element) + // Pad each column in this trace with n items at the front. An iterator over + // the padding values to use for each column must be given. + Pad(n uint, signs func(int) *fr.Element) // Determine the height of this table, which is defined as the // height of the largest column. - Height() int + Height() uint // Get the number of columns in this trace. - Width() int + Width() uint } // ConstraintsAcceptTrace determines whether or not one or more groups of @@ -70,13 +77,11 @@ func ConstraintsAcceptTrace[T Acceptable](trace Trace, constraints []T) error { return nil } -// PadTrace adds n rows of padding to the given trace by duplicating the first -// row n times. This requires that a first row exists. Furthermore, we cannot -// pad a negative number of rows (i.e. when n < 0). -func PadTrace(n int, tr Trace) { +// FrontPadWithZeros adds n rows of zeros to the given trace. +func FrontPadWithZeros(n uint, tr Trace) { var zero fr.Element = fr.NewElement((0)) // Insert initial padding row - tr.InsertFront(n, func(index int) *fr.Element { return &zero }) + tr.Pad(n, func(index int) *fr.Element { return &zero }) } // =================================================================== @@ -87,7 +92,7 @@ func PadTrace(n int, tr Trace) { // array. type ArrayTrace struct { // Holds the maximum height of any column in the trace - height int + height uint // Holds the name of each column columns []*ArrayTraceColumn } @@ -105,8 +110,8 @@ func EmptyArrayTrace() *ArrayTrace { } // Width returns the number of columns in this trace. -func (p *ArrayTrace) Width() int { - return len(p.columns) +func (p *ArrayTrace) Width() uint { + return uint(len(p.columns)) } // ColumnName returns the name of the ith column in this trace. 
@@ -150,8 +155,8 @@ func (p *ArrayTrace) AddColumn(name string, data []*fr.Element) { // Append it p.columns = append(p.columns, &column) // Update maximum height - if len(data) > p.height { - p.height = len(data) + if uint(len(data)) > p.height { + p.height = uint(len(data)) } } @@ -160,16 +165,6 @@ func (p *ArrayTrace) Columns() []*ArrayTraceColumn { return p.columns } -// DuplicateFront inserts n duplicates of the first row at the front of this -// trace. -func (p *ArrayTrace) DuplicateFront(n int) { - for _, c := range p.columns { - c.DuplicateFront(n) - } - // Increment height - p.height += n -} - // GetByName gets the value of a given column (as identified by its name) at a // given row. If the column does not exist, an error is returned. func (p *ArrayTrace) GetByName(name string, row int) *fr.Element { @@ -184,13 +179,18 @@ func (p *ArrayTrace) GetByName(name string, row int) *fr.Element { panic(fmt.Sprintf("Invalid column: {%s}", name)) } +// ColumnByIndex looks up a column based on its index. +func (p *ArrayTrace) ColumnByIndex(index uint) Column { + return p.columns[index] +} + // ColumnByName looks up a column based on its name. If the column doesn't // exist, then nil is returned. -func (p *ArrayTrace) ColumnByName(name string) []*fr.Element { +func (p *ArrayTrace) ColumnByName(name string) Column { for _, c := range p.columns { if name == c.name { // Matched column - return c.data + return c } } @@ -221,15 +221,15 @@ func (p *ArrayTrace) getColumnByName(name string) *ArrayTraceColumn { } // Height determines the maximum height of any column within this trace. -func (p *ArrayTrace) Height() int { +func (p *ArrayTrace) Height() uint { return p.height } -// InsertFront inserts n duplicates of a given row at the beginning of this -// trace. -func (p *ArrayTrace) InsertFront(n int, mapping func(int) *fr.Element) { +// Pad each column in this trace with n items at the front. An iterator over +// the padding values to use for each column must be given. 
+func (p *ArrayTrace) Pad(n uint, padding func(int) *fr.Element) { for i, c := range p.columns { - c.InsertFront(n, mapping(i)) + c.Pad(n, padding(i)) } // Increment height p.height += n @@ -249,8 +249,8 @@ func (p *ArrayTrace) String() string { id.WriteString(p.columns[i].name) id.WriteString("={") - for j := 0; j < p.height; j++ { - jth := p.GetByIndex(i, j) + for j := uint(0); j < p.height; j++ { + jth := p.GetByIndex(i, int(j)) if j != 0 { id.WriteString(",") @@ -286,6 +286,16 @@ func (p *ArrayTraceColumn) Name() string { return p.name } +// Height determines the height of this column. +func (p *ArrayTraceColumn) Height() uint { + return uint(len(p.data)) +} + +// Data returns the data for the given column. +func (p *ArrayTraceColumn) Data() []*fr.Element { + return p.data +} + // Clone an ArrayTraceColumn func (p *ArrayTraceColumn) Clone() *ArrayTraceColumn { clone := new(ArrayTraceColumn) @@ -296,16 +306,16 @@ func (p *ArrayTraceColumn) Clone() *ArrayTraceColumn { return clone } -// DuplicateFront the first row of this column n times. -func (p *ArrayTraceColumn) DuplicateFront(n int) { - ndata := make([]*fr.Element, len(p.data)+n) - // Copy items from existing data over +// Pad this column with n copies of a given value, either at the front +// (sign=true) or the back (sign=false). +func (p *ArrayTraceColumn) Pad(n uint, value *fr.Element) { + // Allocate sufficient memory + ndata := make([]*fr.Element, uint(len(p.data))+n) + // Copy over the data copy(ndata[n:], p.data) - // Copy front - front := p.data[0] - // Duplicate front - for i := 0; i < n; i++ { - ndata[i] = front + // Go padding! + for i := uint(0); i < n; i++ { + ndata[i] = value } // Copy over p.data = ndata @@ -320,19 +330,6 @@ func (p *ArrayTraceColumn) Get(row int) *fr.Element { return p.data[row] } -// InsertFront inserts a given item at the front of this column. 
-func (p *ArrayTraceColumn) InsertFront(n int, item *fr.Element) { - ndata := make([]*fr.Element, len(p.data)+n) - // Copy items from existing data over - copy(ndata[n:], p.data) - // Insert new items - for i := 0; i < n; i++ { - ndata[i] = item - } - // Copy over - p.data = ndata -} - // =================================================================== // JSON Parser // =================================================================== diff --git a/pkg/test/ir_test.go b/pkg/test/ir_test.go index 503c925..a9961c4 100644 --- a/pkg/test/ir_test.go +++ b/pkg/test/ir_test.go @@ -7,7 +7,6 @@ import ( "os" "strings" "testing" - "unicode/utf8" "github.com/consensys/go-corset/pkg/hir" "github.com/consensys/go-corset/pkg/table" @@ -126,6 +125,10 @@ func Test_Shift_06(t *testing.T) { Check(t, "shift_06") } +func Test_Shift_07(t *testing.T) { + Check(t, "shift_07") +} + // =================================================================== // Normalisation Tests // =================================================================== @@ -308,7 +311,7 @@ func TestSlow_Mxp(t *testing.T) { // Determines the maximum amount of padding to use when testing. Specifically, // every trace is tested with varying amounts of padding upto this value. 
-const MAX_PADDING int = 0 +const MAX_PADDING uint = 0 // For a given set of constraints, check that all traces which we // expect to be accepted are accepted, and all traces that we expect @@ -339,21 +342,25 @@ func Check(t *testing.T, test string) { func CheckTraces(t *testing.T, test string, expected bool, traces []*table.ArrayTrace, hirSchema *hir.Schema) { for i, tr := range traces { if tr != nil { - for padding := 0; padding <= MAX_PADDING; padding++ { + for padding := uint(0); padding <= MAX_PADDING; padding++ { // Lower HIR => MIR mirSchema := hirSchema.LowerToMir() // Lower MIR => AIR airSchema := mirSchema.LowerToAir() + // Construct trace identifiers + hirID := traceId{"HIR", test, expected, i + 1, padding, hirSchema.RequiredSpillage()} + mirID := traceId{"MIR", test, expected, i + 1, padding, mirSchema.RequiredSpillage()} + airID := traceId{"AIR", test, expected, i + 1, padding, airSchema.RequiredSpillage()} // Check HIR/MIR trace (if applicable) if airSchema.IsInputTrace(tr) == nil { // This is an unexpanded input trace. - checkInputTrace(t, tr, traceId{"HIR", test, expected, i + 1, padding}, hirSchema) - checkInputTrace(t, tr, traceId{"MIR", test, expected, i + 1, padding}, mirSchema) - checkInputTrace(t, tr, traceId{"AIR", test, expected, i + 1, padding}, airSchema) + checkInputTrace(t, tr, hirID, hirSchema) + checkInputTrace(t, tr, mirID, mirSchema) + checkInputTrace(t, tr, airID, airSchema) } else if airSchema.IsOutputTrace(tr) == nil { // This is an already expanded input trace. Therefore, no need // to perform expansion. - checkExpandedTrace(t, tr, traceId{"AIR", test, expected, i + 1, 0}, airSchema) + checkExpandedTrace(t, tr, airID, airSchema) } else { // Trace appears to be malformed. 
err1 := airSchema.IsInputTrace(tr) @@ -373,6 +380,8 @@ func CheckTraces(t *testing.T, test string, expected bool, traces []*table.Array func checkInputTrace(t *testing.T, tr *table.ArrayTrace, id traceId, schema table.Schema) { // Clone trace (to ensure expansion does not affect subsequent tests) etr := tr.Clone() + // Apply spillage + table.FrontPadWithZeros(schema.RequiredSpillage(), etr) // Expand trace err := schema.ExpandTrace(etr) // Check @@ -385,7 +394,7 @@ func checkInputTrace(t *testing.T, tr *table.ArrayTrace, id traceId, schema tabl func checkExpandedTrace(t *testing.T, tr table.Trace, id traceId, schema table.Schema) { // Apply padding - table.PadTrace(id.padding, tr) + schema.ApplyPadding(id.padding, tr) // Check err := schema.Accepts(tr) // Determine whether trace accepted or not. @@ -393,14 +402,13 @@ func checkExpandedTrace(t *testing.T, tr table.Trace, id traceId, schema table.S // Process what happened versus what was supposed to happen. if !accepted && id.expected { //printTrace(tr) - msg := fmt.Sprintf("Trace rejected incorrectly (%s, %s.accepts, %d padding, line %d): %s", - id.ir, id.test, id.padding, id.line, err) + msg := fmt.Sprintf("Trace rejected incorrectly (%s, %s.accepts, line %d with spillage %d / padding %d): %s", + id.ir, id.test, id.line, id.spillage, id.padding, err) t.Errorf(msg) } else if accepted && !id.expected { - printTrace(tr) - - msg := fmt.Sprintf("Trace accepted incorrectly (%s, %s.rejects, %d padding, line %d)", - id.ir, id.test, id.padding, id.line) + //printTrace(tr) + msg := fmt.Sprintf("Trace accepted incorrectly (%s, %s.rejects, line %d with spillage %d / padding %d)", + id.ir, id.test, id.line, id.spillage, id.padding) t.Errorf(msg) } } @@ -420,8 +428,11 @@ type traceId struct { // Identifies the line number within the test file that the failing trace // original. line int - // Identifies how much padding has been added to the original trace. 
- padding int + // Identifies how much padding has been added to the expanded trace. + padding uint + // Determines how much spillage was added to the original trace (prior to + // expansion). + spillage uint } // ReadTracesFile reads a file containing zero or more traces expressed as JSON, where @@ -499,68 +510,3 @@ func readLine(reader *bufio.Reader) *string { // Done return &str } - -// Prints a trace in a more human-friendly fashion. -func printTrace(tr table.Trace) { - n := tr.Width() - // - rows := make([][]string, n) - for i := 0; i < n; i++ { - rows[i] = traceColumnData(tr, i) - } - // - widths := traceRowWidths(tr.Height(), rows) - // - printHorizontalRule(widths) - // - for _, r := range rows { - printTraceRow(r, widths) - printHorizontalRule(widths) - } -} - -func traceColumnData(tr table.Trace, col int) []string { - n := tr.Height() - data := make([]string, n+1) - data[0] = tr.ColumnName(col) - - for row := 0; row < n; row++ { - data[row+1] = tr.GetByIndex(col, row).String() - } - - return data -} - -func traceRowWidths(height int, rows [][]string) []int { - widths := make([]int, height+1) - - for _, row := range rows { - for i, col := range row { - w := utf8.RuneCountInString(col) - widths[i] = max(w, widths[i]) - } - } - - return widths -} - -func printTraceRow(row []string, widths []int) { - for i, col := range row { - fmt.Printf(" %*s |", widths[i], col) - } - - fmt.Println() -} - -func printHorizontalRule(widths []int) { - for _, w := range widths { - fmt.Print("-") - - for i := 0; i < w; i++ { - fmt.Print("-") - } - fmt.Print("-+") - } - - fmt.Println() -} diff --git a/pkg/util/pair.go b/pkg/util/pair.go new file mode 100644 index 0000000..90347f1 --- /dev/null +++ b/pkg/util/pair.go @@ -0,0 +1,18 @@ +package util + +// Pair provides a simple encapsulation of two items paired together. +type Pair[S any, T any] struct { + Left S + Right T +} + +// NewPair returns a new instance of Pair by value. 
+func NewPair[S any, T any](left S, right T) Pair[S, T] { + return Pair[S, T]{left, right} +} + +// NewPairRef returns a reference to a new instance of Pair. +func NewPairRef[S any, T any](left S, right T) *Pair[S, T] { + var p Pair[S, T] = NewPair(left, right) + return &p +} diff --git a/pkg/util/permutation.go b/pkg/util/permutation.go index 67228fe..c5f5f27 100644 --- a/pkg/util/permutation.go +++ b/pkg/util/permutation.go @@ -2,6 +2,7 @@ package util import ( //"fmt" + "slices" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" diff --git a/testdata/memory.lisp b/testdata/memory.lisp index b3c2911..1d96da7 100644 --- a/testdata/memory.lisp +++ b/testdata/memory.lisp @@ -24,7 +24,7 @@ ;; PC[0]=0 (vanish:first heartbeat_1 PC) ;; PC[k]=PC[k-1]+1 -(vanish heartbeat_2 (- PC (+ 1 (shift PC -1)))) +(vanish heartbeat_2 (* PC (- PC (+ 1 (shift PC -1))))) ;; ADDR'[k] != ADDR'[k-1] ==> (RW'[k]=1 || VAL'[k]=0) (vanish first_read_1 (ifnot (- ADDR' (shift ADDR' -1)) (* (- 1 RW') VAL'))) diff --git a/testdata/memory.rejects b/testdata/memory.rejects index f9129f8..9efe170 100644 --- a/testdata/memory.rejects +++ b/testdata/memory.rejects @@ -1,5 +1,5 @@ -{ "PC": [0], "RW": [0], "ADDR": [0], "VAL": [0] } { "PC": [2], "RW": [0], "ADDR": [0], "VAL": [0] } +{ "PC": [3], "RW": [0], "ADDR": [0], "VAL": [0] } { "PC": [1,3], "RW": [0,0], "ADDR": [0,0], "VAL": [0,0] } ;; n=1 { "PC": [1,2,3,4], "RW": [0,0,0,0], "ADDR": [0,0,0,0], "VAL": [1,0,0,0] } diff --git a/testdata/permute_03.accepts b/testdata/permute_03.accepts index 5cca95c..5ac5ede 100644 --- a/testdata/permute_03.accepts +++ b/testdata/permute_03.accepts @@ -1,12 +1,12 @@ -{"X": []} -{"X": [1]} -{"X": [2,1]} -{"X": [1,2]} -{"X": [1,2,3]} -{"X": [2,1,3]} -{"X": [1,3,2]} -{"X": [3,2,1]} -{"X": [3,2,1,4]} -{"X": [3,4,2,1]} -{"X": [3,2,1,4]} -{"X": [3,2,4,1]} +{"ST": [], "X": []} +{"ST": [1], "X": [1]} +{"ST": [1,1], "X": [2,1]} +{"ST": [1,1], "X": [1,2]} +{"ST": [1,1,1], "X": [1,2,3]} +{"ST": [1,1,1], "X": [2,1,3]} +{"ST": 
[1,1,1], "X": [1,3,2]} +{"ST": [1,1,1], "X": [3,2,1]} +{"ST": [1,1,1,1], "X": [3,2,1,4]} +{"ST": [1,1,1,1], "X": [3,4,2,1]} +{"ST": [1,1,1,1], "X": [3,2,1,4]} +{"ST": [1,1,1,1], "X": [3,2,4,1]} diff --git a/testdata/permute_03.lisp b/testdata/permute_03.lisp index 468165a..1edc92a 100644 --- a/testdata/permute_03.lisp +++ b/testdata/permute_03.lisp @@ -1,4 +1,5 @@ +(column ST) (column X :u16) (permute (Y) (+X)) ;; Ensure sorted column increments by 1 -(vanish increment (- (shift Y 1) (+ 1 Y))) +(vanish increment (* ST (- (shift Y 1) (+ 1 Y)))) diff --git a/testdata/permute_03.rejects b/testdata/permute_03.rejects index 5eeb955..56e1525 100644 --- a/testdata/permute_03.rejects +++ b/testdata/permute_03.rejects @@ -1,11 +1,11 @@ -{"X": [0,0]} -{"X": [1,1]} -{"X": [0,2]} -{"X": [2,0]} -{"X": [5,9]} -{"X": [6,4]} -{"X": [0,0,0]} -{"X": [1,2,4]} -{"X": [2,0,3]} -{"X": [1,2,1]} -{"X": [3,2,0]} +{"ST": [1,1], "X": [0,0]} +{"ST": [1,1], "X": [1,1]} +{"ST": [1,1], "X": [0,2]} +{"ST": [1,1], "X": [2,0]} +{"ST": [1,1], "X": [5,9]} +{"ST": [1,1], "X": [6,4]} +{"ST": [1,1,1], "X": [0,0,0]} +{"ST": [1,1,1], "X": [1,2,4]} +{"ST": [1,1,1], "X": [2,0,3]} +{"ST": [1,1,1], "X": [1,2,1]} +{"ST": [1,1,1], "X": [3,2,0]} diff --git a/testdata/permute_04.accepts b/testdata/permute_04.accepts index 9306be6..ce04e1e 100644 --- a/testdata/permute_04.accepts +++ b/testdata/permute_04.accepts @@ -1,12 +1,12 @@ -{"X": [5]} -{"X": [0,5]} -{"X": [5,0]} -{"X": [0,1,5]} -{"X": [0,5,1]} -{"X": [5,1,0]} -{"X": [1,0,5]} -{"X": [1,2,3,4,5]} -{"X": [1,2,3,5,4]} -{"X": [1,2,5,4,3]} -{"X": [1,5,3,4,2]} -{"X": [5,2,3,4,1]} +{"ST": [1], "X": [5]} +{"ST": [1,1], "X": [6,5]} +{"ST": [1,1], "X": [5,6]} +{"ST": [1,1,1], "X": [6,7,5]} +{"ST": [1,1,1], "X": [6,5,7]} +{"ST": [1,1,1], "X": [5,7,6]} +{"ST": [1,1,1], "X": [7,6,5]} +{"ST": [1,1,1,1,1], "X": [6,7,8,9,5]} +{"ST": [1,1,1,1,1], "X": [6,7,8,5,9]} +{"ST": [1,1,1,1,1], "X": [6,7,5,9,8]} +{"ST": [1,1,1,1,1], "X": [6,5,8,9,7]} +{"ST": [1,1,1,1,1], "X": 
[5,7,8,9,6]} diff --git a/testdata/permute_04.lisp b/testdata/permute_04.lisp index 35f92dd..f84c48c 100644 --- a/testdata/permute_04.lisp +++ b/testdata/permute_04.lisp @@ -1,3 +1,4 @@ +(column ST :u16) (column X :u16) -(permute (Y) (-X)) -(vanish:first last-row (- Y 5)) +(permute (ST' Y) (+ST -X)) +(vanish:last first-row (- Y 5)) diff --git a/testdata/permute_04.rejects b/testdata/permute_04.rejects index e5cc1eb..74fda2a 100644 --- a/testdata/permute_04.rejects +++ b/testdata/permute_04.rejects @@ -1,5 +1,5 @@ -{"X": [-1]} -{"X": [255]} -{"X": [1234987]} -{"X": [1,2,3]} -{"X": [3,44,235,623,1,35]} +{"ST": [1], "X": [-1]} +{"ST": [1], "X": [255]} +{"ST": [1], "X": [1234987]} +{"ST": [1,1,1], "X": [1,2,3]} +{"ST": [1,1,1,1,1,1], "X": [3,44,235,623,1,35]} diff --git a/testdata/permute_07.accepts b/testdata/permute_07.accepts index 9b9d829..9474bf1 100644 --- a/testdata/permute_07.accepts +++ b/testdata/permute_07.accepts @@ -1,16 +1,16 @@ -{"X": [], "Y": []} -{"X": [0], "Y": [0]} -{"X": [1], "Y": [0]} -{"X": [0], "Y": [1]} -{"X": [2], "Y": [0]} -;; n=2 -{"X": [0,0], "Y": [0,0]} -{"X": [1,0], "Y": [0,1]} -{"X": [2,1], "Y": [1,0]} -{"X": [2,0], "Y": [0,1]} +{"ST": [], "X": [], "Y": []} +{"ST": [1], "X": [1], "Y": [1]} +{"ST": [1], "X": [2], "Y": [1]} +{"ST": [1], "X": [1], "Y": [2]} +{"ST": [1], "X": [3], "Y": [1]} ;; n=3 -{"X": [0,0,0], "Y": [0,0,0]} -{"X": [3,1,4], "Y": [1,0,3]} -{"X": [3,0,1], "Y": [1,1,0]} -{"X": [3,0,1], "Y": [1,1,0]} -{"X": [5,2,3], "Y": [3,0,2]} +{"ST": [1,1], "X": [1,1], "Y": [1,1]} +{"ST": [1,1], "X": [2,1], "Y": [1,2]} +{"ST": [1,1], "X": [3,2], "Y": [2,1]} +{"ST": [1,1], "X": [3,1], "Y": [1,2]} +;; n=3 +{"ST": [1,1,1], "X": [1,1,1], "Y": [1,1,1]} +{"ST": [1,1,1], "X": [4,2,5], "Y": [2,1,4]} +{"ST": [1,1,1], "X": [4,1,2], "Y": [2,2,1]} +{"ST": [1,1,1], "X": [4,1,2], "Y": [2,2,1]} +{"ST": [1,1,1], "X": [6,3,4], "Y": [4,1,3]} diff --git a/testdata/permute_07.lisp b/testdata/permute_07.lisp index 02e693c..cd8c231 100644 --- 
a/testdata/permute_07.lisp +++ b/testdata/permute_07.lisp @@ -1,4 +1,5 @@ +(column ST :u16) (column X :u16) (column Y :u16) -(permute (A B) (-X +Y)) -(vanish diag_ab (- (shift A 1) B)) +(permute (ST' A B) (+ST -X +Y)) +(vanish diag_ab (* ST' (- (shift A 1) B))) diff --git a/testdata/permute_07.rejects b/testdata/permute_07.rejects index 42eb25a..b1b72c5 100644 --- a/testdata/permute_07.rejects +++ b/testdata/permute_07.rejects @@ -1,7 +1,7 @@ -{"X": [0,0], "Y": [1,1]} -{"X": [1,2], "Y": [1,0]} -{"X": [1,1], "Y": [1,0]} -{"X": [0,3,2], "Y": [0,0,0]} -{"X": [0,3,2], "Y": [3,0,2]} -{"X": [0,3,2], "Y": [3,1,2]} -{"X": [0,3,2], "Y": [3,2,2]} +{"ST": [1,1], "X": [0,0], "Y": [1,1]} +{"ST": [1,1], "X": [1,2], "Y": [1,0]} +{"ST": [1,1], "X": [1,1], "Y": [1,0]} +{"ST": [1,1,1], "X": [0,3,2], "Y": [0,0,0]} +{"ST": [1,1,1], "X": [0,3,2], "Y": [3,0,2]} +{"ST": [1,1,1], "X": [0,3,2], "Y": [3,1,2]} +{"ST": [1,1,1], "X": [0,3,2], "Y": [3,2,2]} diff --git a/testdata/permute_09.accepts b/testdata/permute_09.accepts index 5b3ae2b..394cd0a 100644 --- a/testdata/permute_09.accepts +++ b/testdata/permute_09.accepts @@ -1,17 +1,16 @@ -{"X": [], "Y": []} -{"X": [0], "Y": [0]} -{"X": [1], "Y": [0]} +{"ST": [], "X": [], "Y": []} +{"ST": [1], "X": [0], "Y": [0]} +{"ST": [1], "X": [1], "Y": [0]} ;; n=2 -{"X": [0,0], "Y": [0,0]} -{"X": [2,1], "Y": [1,0]} -{"X": [3,2], "Y": [2,0]} -{"X": [2,3], "Y": [0,2]} -{"X": [2,2], "Y": [0,2]} -{"X": [2,1], "Y": [1,0]} -{"X": [2,2], "Y": [2,0]} +{"ST": [1,1], "X": [0,0], "Y": [0,0]} +{"ST": [1,1], "X": [2,1], "Y": [1,0]} +{"ST": [1,1], "X": [3,2], "Y": [2,0]} +{"ST": [1,1], "X": [2,3], "Y": [0,2]} +{"ST": [1,1], "X": [2,2], "Y": [0,2]} +{"ST": [1,1], "X": [2,1], "Y": [1,0]} +{"ST": [1,1], "X": [2,2], "Y": [2,0]} ;; n=3 -{"X": [0,0,0], "Y": [0,0,0]} -;; -{"X": [3,1,2], "Y": [2,0,1]} -{"X": [3,1,2], "Y": [2,0,1]} -{"X": [2,2,1], "Y": [1,2,0]} +{"ST": [1,1,1], "X": [0,0,0], "Y": [0,0,0]} +{"ST": [1,1,1], "X": [3,1,2], "Y": [2,0,1]} +{"ST": [1,1,1], "X": 
[3,1,2], "Y": [2,0,1]} +{"ST": [1,1,1], "X": [2,2,1], "Y": [1,2,0]} diff --git a/testdata/permute_09.lisp b/testdata/permute_09.lisp index 0b4df1d..330a4f5 100644 --- a/testdata/permute_09.lisp +++ b/testdata/permute_09.lisp @@ -1,4 +1,5 @@ +(column ST :u16) (column X :u16) (column Y :u16) -(permute (A B) (-X -Y)) -(vanish diag_ab (- (shift A 1) B)) +(permute (ST' A B) (+ST -X -Y)) +(vanish diag_ab (* ST' (- (shift A 1) B))) diff --git a/testdata/permute_09.rejects b/testdata/permute_09.rejects index 9ff18f5..adb6824 100644 --- a/testdata/permute_09.rejects +++ b/testdata/permute_09.rejects @@ -1,6 +1,6 @@ -{"X": [0,0], "Y": [1,1]} -{"X": [1,2], "Y": [1,0]} -{"X": [0,3,2], "Y": [0,0,0]} -{"X": [0,3,2], "Y": [3,0,2]} -{"X": [0,3,2], "Y": [3,1,2]} -{"X": [0,3,2], "Y": [3,2,2]} +{"ST": [1,1], "X": [0,0], "Y": [1,1]} +{"ST": [1,1], "X": [1,2], "Y": [1,0]} +{"ST": [1,1,1], "X": [0,3,2], "Y": [0,0,0]} +{"ST": [1,1,1], "X": [0,3,2], "Y": [3,0,2]} +{"ST": [1,1,1], "X": [0,3,2], "Y": [3,1,2]} +{"ST": [1,1,1], "X": [0,3,2], "Y": [3,2,2]} diff --git a/testdata/shift_07.accepts b/testdata/shift_07.accepts new file mode 100644 index 0000000..2a86f0c --- /dev/null +++ b/testdata/shift_07.accepts @@ -0,0 +1,5 @@ +{"BIT_1": [], "ARG": []} +{"BIT_1": [0], "ARG": [0]} +{"BIT_1": [0], "ARG": [1]} +{"BIT_1": [0], "ARG": [2]} +{"BIT_1": [0,1], "ARG": [2,0]} diff --git a/testdata/shift_07.lisp b/testdata/shift_07.lisp new file mode 100644 index 0000000..6e4bf8d --- /dev/null +++ b/testdata/shift_07.lisp @@ -0,0 +1,8 @@ +(column BIT_1 :u1) +(column ARG) + +(vanish pivot + ;; If BIT_1[k-1]=0 and BIT_1[k]=1 + (if (+ (shift BIT_1 -1) (- 1 BIT_1)) + ;; Then ARG = 0 + ARG)) diff --git a/testdata/shift_07.rejects b/testdata/shift_07.rejects new file mode 100644 index 0000000..4e7d29f --- /dev/null +++ b/testdata/shift_07.rejects @@ -0,0 +1,2 @@ +{"BIT_1": [1], "ARG": [1]} +{"BIT_1": [0,1], "ARG": [2,1]} diff --git a/testdata/spillage_01.accepts b/testdata/spillage_01.accepts new file mode 
100644 index 0000000..8a3e0a7 --- /dev/null +++ b/testdata/spillage_01.accepts @@ -0,0 +1 @@ +{ "A": [] } diff --git a/testdata/spillage_01.lisp b/testdata/spillage_01.lisp new file mode 100644 index 0000000..73469a7 --- /dev/null +++ b/testdata/spillage_01.lisp @@ -0,0 +1,2 @@ +(column A) +(vanish spills (* A (~ (shift A 1)))) diff --git a/testdata/spillage_01.rejects b/testdata/spillage_01.rejects new file mode 100644 index 0000000..e69de29