diff --git a/pkg/air/eval.go b/pkg/air/eval.go index 1f14a7f..71d0fa9 100644 --- a/pkg/air/eval.go +++ b/pkg/air/eval.go @@ -9,7 +9,7 @@ import ( // value at that row of the column in question or nil is that row is // out-of-bounds. func (e *ColumnAccess) EvalAt(k int, tbl table.Trace) *fr.Element { - val := tbl.GetByName(e.Column, k+e.Shift) + val := tbl.ColumnByIndex(e.Column).Get(k + e.Shift) var clone fr.Element // Clone original value diff --git a/pkg/air/expr.go b/pkg/air/expr.go index 1f37aa7..47c87d2 100644 --- a/pkg/air/expr.go +++ b/pkg/air/expr.go @@ -143,14 +143,14 @@ func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND } // accesses the STAMP column at row 5, whilst CT(-1) accesses the CT column at // row 4. type ColumnAccess struct { - Column string + Column uint Shift int } // NewColumnAccess constructs an AIR expression representing the value of a given // column on the current row. -func NewColumnAccess(name string, shift int) Expr { - return &ColumnAccess{name, shift} +func NewColumnAccess(column uint, shift int) Expr { + return &ColumnAccess{column, shift} } // Add two expressions together, producing a third. diff --git a/pkg/air/gadgets/bits.go b/pkg/air/gadgets/bits.go index 8e8c131..27ac0d5 100644 --- a/pkg/air/gadgets/bits.go +++ b/pkg/air/gadgets/bits.go @@ -11,21 +11,23 @@ import ( // ApplyBinaryGadget adds a binarity constraint for a given column in the schema // which enforces that all values in the given column are either 0 or 1. For a // column X, this corresponds to the vanishing constraint X * (X-1) == 0. -func ApplyBinaryGadget(col string, schema *air.Schema) { +func ApplyBinaryGadget(column uint, schema *air.Schema) { + // Determine column name + name := schema.Column(column).Name() // Construct X - X := air.NewColumnAccess(col, 0) + X := air.NewColumnAccess(column, 0) // Construct X-1 X_m1 := X.Sub(air.NewConst64(1)) // Construct X * (X-1) X_X_m1 := X.Mul(X_m1) // Done! 
- schema.AddVanishingConstraint(col, nil, X_X_m1) + schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), nil, X_X_m1) } // ApplyBitwidthGadget ensures all values in a given column fit within a given // number of bits. This is implemented using a *byte decomposition* which adds // n columns and a vanishing constraint (where n*8 >= nbits). -func ApplyBitwidthGadget(col string, nbits uint, schema *air.Schema) { +func ApplyBitwidthGadget(column uint, nbits uint, schema *air.Schema) { if nbits%8 != 0 { panic("asymmetric bitwidth constraints not yet supported") } else if nbits == 0 { @@ -35,38 +37,41 @@ func ApplyBitwidthGadget(col string, nbits uint, schema *air.Schema) { n := nbits / 8 es := make([]air.Expr, n) fr256 := fr.NewElement(256) + name := schema.Column(column).Name() coefficient := fr.NewElement(1) // Construct Columns for i := uint(0); i < n; i++ { // Determine name for the ith byte column - colName := fmt.Sprintf("%s:%d", col, i) + colName := fmt.Sprintf("%s:%d", name, i) // Create Column + Constraint - schema.AddColumn(colName, true) - schema.AddRangeConstraint(colName, &fr256) - es[i] = air.NewColumnAccess(colName, 0).Mul(air.NewConstCopy(&coefficient)) + colIndex := schema.AddColumn(colName, true) + es[i] = air.NewColumnAccess(colIndex, 0).Mul(air.NewConstCopy(&coefficient)) + + schema.AddRangeConstraint(colIndex, &fr256) // Update coefficient coefficient.Mul(&coefficient, &fr256) } // Construct (X:0 * 1) + ... + (X:n * 2^n) sum := &air.Add{Args: es} // Construct X == (X:0 * 1) + ... + (X:n * 2^n) - X := air.NewColumnAccess(col, 0) + X := air.NewColumnAccess(column, 0) eq := X.Equate(sum) // Construct column name - schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", col, nbits), nil, eq) + schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), nil, eq) // Finally, add the necessary byte decomposition computation. 
- schema.AddComputation(table.NewByteDecomposition(col, nbits)) + schema.AddComputation(table.NewByteDecomposition(name, nbits)) } // AddBitArray adds an array of n bit columns using a given prefix, including // the necessary binarity constraints. -func AddBitArray(prefix string, count int, schema *air.Schema) []string { - bits := make([]string, count) +func AddBitArray(prefix string, count int, schema *air.Schema) []uint { + bits := make([]uint, count) + for i := 0; i < count; i++ { // Construct bit column name - bits[i] = fmt.Sprintf("%s:%d", prefix, i) + ith := fmt.Sprintf("%s:%d", prefix, i) // Add (synthetic) column - schema.AddColumn(bits[i], true) + bits[i] = schema.AddColumn(ith, true) // Add binarity constraints (i.e. to enfoce that this column is a bit). ApplyBinaryGadget(bits[i], schema) } diff --git a/pkg/air/gadgets/column_sort.go b/pkg/air/gadgets/column_sort.go index cf9002e..e614773 100644 --- a/pkg/air/gadgets/column_sort.go +++ b/pkg/air/gadgets/column_sort.go @@ -7,7 +7,7 @@ import ( "github.com/consensys/go-corset/pkg/table" ) -// ApplyColumnSortGadget Add sorting constraints for a column where the +// ApplyColumnSortGadget adds sorting constraints for a column where the // difference between any two rows (i.e. the delta) is constrained to fit within // a given bitwidth. The target column is assumed to have an appropriate // (enforced) bitwidth to ensure overflow cannot arise. The sorting constraint @@ -18,8 +18,10 @@ import ( // This gadget does not attempt to sort the column data during trace expansion, // and assumes the data either comes sorted or is sorted by some other // computation. 
-func ApplyColumnSortGadget(column string, sign bool, bitwidth uint, schema *air.Schema) { +func ApplyColumnSortGadget(column uint, sign bool, bitwidth uint, schema *air.Schema) { var deltaName string + // Determine column name + name := schema.Column(column).Name() // Configure computation Xk := air.NewColumnAccess(column, 0) Xkm1 := air.NewColumnAccess(column, -1) @@ -27,18 +29,18 @@ func ApplyColumnSortGadget(column string, sign bool, bitwidth uint, schema *air. var Xdiff air.Expr if sign { Xdiff = Xk.Sub(Xkm1) - deltaName = fmt.Sprintf("+%s", column) + deltaName = fmt.Sprintf("+%s", name) } else { Xdiff = Xkm1.Sub(Xk) - deltaName = fmt.Sprintf("-%s", column) + deltaName = fmt.Sprintf("-%s", name) } // Add delta column - schema.AddColumn(deltaName, true) + deltaIndex := schema.AddColumn(deltaName, true) // Add diff computation schema.AddComputation(table.NewComputedColumn(deltaName, Xdiff)) // Add necessary bitwidth constraints - ApplyBitwidthGadget(deltaName, bitwidth, schema) + ApplyBitwidthGadget(deltaIndex, bitwidth, schema) // Configure constraint: Delta[k] = X[k] - X[k-1] - Dk := air.NewColumnAccess(deltaName, 0) + Dk := air.NewColumnAccess(deltaIndex, 0) schema.AddVanishingConstraint(deltaName, nil, Dk.Equate(Xdiff)) } diff --git a/pkg/air/gadgets/lexicographic_sort.go b/pkg/air/gadgets/lexicographic_sort.go index 95c8b53..6ef255f 100644 --- a/pkg/air/gadgets/lexicographic_sort.go +++ b/pkg/air/gadgets/lexicographic_sort.go @@ -26,37 +26,38 @@ import ( // case (see above). The delta value captures the difference Ci[k]-Ci[k-1] to // ensure it is positive. The delta column is constrained to a given bitwidth, // with constraints added as necessary to ensure this. 
-func ApplyLexicographicSortingGadget(columns []string, signs []bool, bitwidth uint, schema *air.Schema) { +func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint, schema *air.Schema) { // Check preconditions ncols := len(columns) if ncols != len(signs) { panic("Inconsistent number of columns and signs for lexicographic sort.") } - // Add trace computation - schema.AddComputation(&lexicographicSortExpander{columns, signs, bitwidth}) // Construct a unique prefix for this sort. - prefix := constructLexicographicSortingPrefix(columns, signs) + prefix := constructLexicographicSortingPrefix(columns, signs, schema) + // Add trace computation + schema.AddComputation(&lexicographicSortExpander{prefix, columns, signs, bitwidth}) deltaName := fmt.Sprintf("%s:delta", prefix) // Construct selecto bits. bits := addLexicographicSelectorBits(prefix, columns, schema) // Add delta column - schema.AddColumn(deltaName, true) + deltaIndex := schema.AddColumn(deltaName, true) // Construct delta terms - constraint := constructLexicographicDeltaConstraint(deltaName, bits, columns, signs) + constraint := constructLexicographicDeltaConstraint(deltaIndex, bits, columns, signs) // Add delta constraint schema.AddVanishingConstraint(deltaName, nil, constraint) // Add necessary bitwidth constraints - ApplyBitwidthGadget(deltaName, bitwidth, schema) + ApplyBitwidthGadget(deltaIndex, bitwidth, schema) } // Construct a unique identifier for the given sort. This should not conflict // with the identifier for any other sort. -func constructLexicographicSortingPrefix(columns []string, signs []bool) string { +func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *air.Schema) string { // Use string builder to try and make this vaguely efficient. var id strings.Builder // Concatenate column names with their signs. 
for i := 0; i < len(columns); i++ { - id.WriteString(columns[i]) + ith := schema.Column(columns[i]) + id.WriteString(ith.Name()) if signs[i] { id.WriteString("+") @@ -75,7 +76,7 @@ func constructLexicographicSortingPrefix(columns []string, signs []bool) string // // NOTE: this implementation differs from the original corset which used an // additional "Eq" bit to help ensure at most one selector bit was enabled. -func addLexicographicSelectorBits(prefix string, columns []string, schema *air.Schema) []string { +func addLexicographicSelectorBits(prefix string, columns []uint, schema *air.Schema) []uint { ncols := len(columns) // Add bits and their binary constraints. bits := AddBitArray(prefix, ncols, schema) @@ -123,11 +124,11 @@ func addLexicographicSelectorBits(prefix string, columns []string, schema *air.S // appropriately for the sign) between the ith column whose multiplexor bit is // set. This is assumes that multiplexor bits are mutually exclusive (i.e. at // most is one). -func constructLexicographicDeltaConstraint(deltaName string, bits []string, columns []string, signs []bool) air.Expr { +func constructLexicographicDeltaConstraint(delta uint, bits []uint, columns []uint, signs []bool) air.Expr { ncols := len(columns) // Construct delta terms terms := make([]air.Expr, ncols) - Dk := air.NewColumnAccess(deltaName, 0) + Dk := air.NewColumnAccess(delta, 0) for i := 0; i < ncols; i++ { var Xdiff air.Expr @@ -150,7 +151,8 @@ func constructLexicographicDeltaConstraint(deltaName string, bits []string, colu } type lexicographicSortExpander struct { - columns []string + prefix string + columns []uint signs []bool bitwidth uint } @@ -163,15 +165,15 @@ func (p *lexicographicSortExpander) RequiredSpillage() uint { // Accepts checks whether a given trace has the necessary columns func (p *lexicographicSortExpander) Accepts(tr table.Trace) error { - prefix := constructLexicographicSortingPrefix(p.columns, p.signs) - deltaName := fmt.Sprintf("%s:delta", prefix) + //prefix 
:= constructLexicographicSortingPrefix(p.columns, p.signs) + deltaName := fmt.Sprintf("%s:delta", p.prefix) // Check delta column exists if !tr.HasColumn(deltaName) { return fmt.Errorf("Trace missing lexicographic delta column ({%s})", deltaName) } // Check selector columns exist for i := range p.columns { - bitName := fmt.Sprintf("%s:%d", prefix, i) + bitName := fmt.Sprintf("%s:%d", p.prefix, i) if !tr.HasColumn(bitName) { return fmt.Errorf("Trace missing lexicographic selector column ({%s})", bitName) } @@ -190,8 +192,7 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error { // Determine how many rows to be constrained. nrows := tr.Height() // Construct a unique prefix for this sort. - prefix := constructLexicographicSortingPrefix(p.columns, p.signs) - deltaName := fmt.Sprintf("%s:delta", prefix) + deltaName := fmt.Sprintf("%s:delta", p.prefix) // Initialise new data columns delta := make([]*fr.Element, nrows) bit := make([][]*fr.Element, ncols) @@ -200,14 +201,14 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error { bit[i] = make([]*fr.Element, nrows) } - for i := uint(0); i < nrows; i++ { + for i := 0; i < int(nrows); i++ { set := false // Initialise delta to zero delta[i] = &zero // Decide which row is the winner (if any) for j := 0; j < ncols; j++ { - prev := tr.GetByName(p.columns[j], int(i-1)) - curr := tr.GetByName(p.columns[j], int(i)) + prev := tr.ColumnByIndex(p.columns[j]).Get(i - 1) + curr := tr.ColumnByIndex(p.columns[j]).Get(i) if !set && prev != nil && prev.Cmp(curr) != 0 { var diff fr.Element @@ -228,12 +229,11 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error { } } } - // Add delta column data tr.AddColumn(deltaName, delta, &zero) // Add bit column data for i := 0; i < ncols; i++ { - bitName := fmt.Sprintf("%s:%d", prefix, i) + bitName := fmt.Sprintf("%s:%d", p.prefix, i) tr.AddColumn(bitName, bit[i], &zero) } // Done. 
@@ -243,5 +243,5 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error { // String returns a string representation of this constraint. This is primarily // used for debugging. func (p *lexicographicSortExpander) String() string { - return fmt.Sprintf("(lexer (%s) (%v) :%d))", any(p.columns), p.signs, p.bitwidth) + return fmt.Sprintf("(lexer (%v) (%v) :%d))", any(p.columns), p.signs, p.bitwidth) } diff --git a/pkg/air/gadgets/normalisation.go b/pkg/air/gadgets/normalisation.go index 70e5e33..5a80f02 100644 --- a/pkg/air/gadgets/normalisation.go +++ b/pkg/air/gadgets/normalisation.go @@ -30,15 +30,17 @@ func ApplyPseudoInverseGadget(e air.Expr, tbl *air.Schema) air.Expr { ie := &Inverse{Expr: e} // Determine computed column name name := ie.String() + // Look up column + index, ok := tbl.ColumnIndex(name) // Add new column (if it does not already exist) - if !tbl.HasColumn(name) { + if !ok { // Add (synthetic) computed column - tbl.AddColumn(name, true) + index = tbl.AddColumn(name, true) tbl.AddComputation(table.NewComputedColumn(name, ie)) } // Construct 1/e - inv_e := air.NewColumnAccess(name, 0) + inv_e := air.NewColumnAccess(index, 0) // Construct e/e e_inv_e := e.Mul(inv_e) // Construct 1 == e/e @@ -54,7 +56,7 @@ func ApplyPseudoInverseGadget(e air.Expr, tbl *air.Schema) air.Expr { r_name := fmt.Sprintf("[%s =>]", ie.String()) tbl.AddVanishingConstraint(r_name, nil, inv_e_implies_one_e_e) // Done - return air.NewColumnAccess(name, 0) + return air.NewColumnAccess(index, 0) } // Inverse represents a computation which computes the multiplicative @@ -66,10 +68,6 @@ type Inverse struct{ Expr air.Expr } func (e *Inverse) EvalAt(k int, tbl table.Trace) *fr.Element { inv := new(fr.Element) val := e.Expr.EvalAt(k, tbl) - // Catch undefined case - if val == nil { - return nil - } // Go syntax huh? 
return inv.Inverse(val) } diff --git a/pkg/air/schema.go b/pkg/air/schema.go index 6ee0b0c..577c043 100644 --- a/pkg/air/schema.go +++ b/pkg/air/schema.go @@ -1,9 +1,6 @@ package air import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/table" "github.com/consensys/go-corset/pkg/util" @@ -20,7 +17,7 @@ type VanishingConstraint = *table.RowConstraint[table.ZeroTest[Expr]] // PropertyAssertion captures the notion of an arbitrary property which should // hold for all acceptable traces. However, such a property is not enforced by // the prover. -type PropertyAssertion = *table.PropertyAssertion[table.Evaluable] +type PropertyAssertion = *table.PropertyAssertion[table.ZeroTest[table.Evaluable]] // Permutation captures the notion of a simple column permutation at the AIR // level. @@ -60,6 +57,21 @@ func EmptySchema[C table.Evaluable]() *Schema { return p } +// Width returns the number of column groups in this schema. +func (p *Schema) Width() uint { + return uint(len(p.dataColumns)) +} + +// ColumnGroup returns information about the ith column group in this schema. +func (p *Schema) ColumnGroup(i uint) table.ColumnGroup { + return p.dataColumns[i] +} + +// Column returns information about the ith column in this schema. +func (p *Schema) Column(i uint) table.ColumnSchema { + return p.dataColumns[i] +} + // Size returns the number of declarations in this schema. func (p *Schema) Size() int { return len(p.dataColumns) + len(p.permutations) + len(p.vanishing) + @@ -81,7 +93,7 @@ func (p *Schema) Columns() []DataColumn { // HasColumn checks whether a given schema has a given column. 
func (p *Schema) HasColumn(name string) bool { for _, c := range p.dataColumns { - if c.Name == name { + if c.Name() == name { return true } } @@ -89,6 +101,18 @@ func (p *Schema) HasColumn(name string) bool { return false } +// ColumnIndex determines the column index for a given column in this schema, or +// returns false indicating an error. +func (p *Schema) ColumnIndex(name string) (uint, bool) { + for i, c := range p.dataColumns { + if c.Name() == name { + return uint(i), true + } + } + + return 0, false +} + // RequiredSpillage returns the minimum amount of spillage required to ensure // valid traces are accepted in the presence of arbitrary padding. Spillage can // only arise from computations as this is where values outside of the user's @@ -105,77 +129,18 @@ func (p *Schema) RequiredSpillage() uint { return mx } -// IsInputTrace determines whether a given input trace is a suitable -// input (i.e. non-expanded) trace for this schema. Specifically, the -// input trace must contain a matching column for each non-synthetic -// column in this trace. -func (p *Schema) IsInputTrace(tr table.Trace) error { - count := uint(0) - - for _, c := range p.dataColumns { - if !c.Synthetic && !tr.HasColumn(c.Name) { - msg := fmt.Sprintf("Trace missing input column ({%s})", c.Name) - return errors.New(msg) - } else if c.Synthetic && tr.HasColumn(c.Name) { - msg := fmt.Sprintf("Trace has synthetic column ({%s})", c.Name) - return errors.New(msg) - } else if !c.Synthetic { - count = count + 1 - } - } - // Check geometry - if tr.Width() != count { - // Determine the unknown columns for error reporting. - unknown := make([]string, 0) - - for i := uint(0); i < tr.Width(); i++ { - n := tr.ColumnName(int(i)) - if !p.HasColumn(n) { - unknown = append(unknown, n) - } - } - - msg := fmt.Sprintf("Trace has unknown columns {%s}", unknown) - - return errors.New(msg) - } - // Done - return nil -} - -// IsOutputTrace determines whether a given input trace is a suitable -// output (i.e. 
expanded) trace for this schema. Specifically, the -// output trace must contain a matching column for each column in this -// trace (synthetic or otherwise). -func (p *Schema) IsOutputTrace(tr table.Trace) error { - count := uint(0) - - for _, c := range p.dataColumns { - if !tr.HasColumn(c.Name) { - msg := fmt.Sprintf("Trace missing input column ({%s})", c.Name) - return errors.New(msg) - } - - count++ - } - // Check geometry - if tr.Width() != count { - return errors.New("Trace has unknown columns") - } - // Done - return nil -} - // AddColumn appends a new data column which is either synthetic or // not. A synthetic column is one which has been introduced by the // process of lowering from HIR / MIR to AIR. That is, it is not a // column which was original specified by the user. Columns also support a // "padding sign", which indicates whether padding should occur at the front // (positive sign) or the back (negative sign). -func (p *Schema) AddColumn(name string, synthetic bool) { +func (p *Schema) AddColumn(name string, synthetic bool) uint { // NOTE: the air level has no ability to enforce the type specified for a // given column. p.dataColumns = append(p.dataColumns, table.NewDataColumn(name, &table.FieldType{}, synthetic)) + // Calculate column index + return uint(len(p.dataColumns) - 1) } // AddComputation appends a new computation to be used during trace @@ -186,7 +151,7 @@ func (p *Schema) AddComputation(c table.TraceComputation) { // AddPermutationConstraint appends a new permutation constraint which // ensures that one column is a permutation of another. -func (p *Schema) AddPermutationConstraint(targets []string, sources []string) { +func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) { p.permutations = append(p.permutations, table.NewPermutation(targets, sources)) } @@ -196,7 +161,7 @@ func (p *Schema) AddVanishingConstraint(handle string, domain *int, expr Expr) { } // AddRangeConstraint appends a new range constraint. 
-func (p *Schema) AddRangeConstraint(column string, bound *fr.Element) { +func (p *Schema) AddRangeConstraint(column uint, bound *fr.Element) { p.ranges = append(p.ranges, table.NewRangeConstraint(column, bound)) } diff --git a/pkg/air/string.go b/pkg/air/string.go index d0883c5..bb68e18 100644 --- a/pkg/air/string.go +++ b/pkg/air/string.go @@ -6,10 +6,10 @@ import ( func (e *ColumnAccess) String() string { if e.Shift == 0 { - return e.Column + return fmt.Sprintf("#%d", e.Column) } - return fmt.Sprintf("(shift %s %d)", e.Column, e.Shift) + return fmt.Sprintf("(shift #%d %d)", e.Column, e.Shift) } func (e *Constant) String() string { diff --git a/pkg/cmd/check.go b/pkg/cmd/check.go index a834027..884800d 100644 --- a/pkg/cmd/check.go +++ b/pkg/cmd/check.go @@ -165,11 +165,19 @@ func checkTrace(tr *table.ArrayTrace, schema table.Schema, cfg checkConfig) (tab // Apply default inferred spillage tr.Pad(schema.RequiredSpillage()) } + // Perform Input Alignment + if err := tr.AlignInputWith(schema); err != nil { + return tr, err + } // Expand trace if err := schema.ExpandTrace(tr); err != nil { return tr, err } } + // Perform Alignment + if err := tr.AlignWith(schema); err != nil { + return tr, err + } // Check whether padding requested if cfg.padding.Left == 0 && cfg.padding.Right == 0 { // No padding requested. Therefore, we can avoid a clone in this case. 
diff --git a/pkg/cmd/compute.go b/pkg/cmd/compute.go index a8266e9..20be4f9 100644 --- a/pkg/cmd/compute.go +++ b/pkg/cmd/compute.go @@ -29,7 +29,7 @@ var computeCmd = &cobra.Command{ } // Print columns for _, c := range schema.Columns() { - fmt.Printf("column %s : %s\n", c.Name, c.Type) + fmt.Printf("column %s : %s\n", c.Name(), c.Type) } // Print constraints for _, c := range schema.Constraints() { diff --git a/pkg/hir/eval.go b/pkg/hir/eval.go index 55349c1..0eb6a43 100644 --- a/pkg/hir/eval.go +++ b/pkg/hir/eval.go @@ -9,7 +9,7 @@ import ( // value at that row of the column in question or nil is that row is // out-of-bounds. func (e *ColumnAccess) EvalAllAt(k int, tbl table.Trace) []*fr.Element { - val := tbl.GetByName(e.Column, k+e.Shift) + val := tbl.ColumnByName(e.Column).Get(k + e.Shift) var clone fr.Element // Clone original value diff --git a/pkg/hir/expr.go b/pkg/hir/expr.go index 7766be9..fdbd27a 100644 --- a/pkg/hir/expr.go +++ b/pkg/hir/expr.go @@ -21,7 +21,7 @@ type Expr interface { // Representation. Observe that a single expression at this // level can expand into *multiple* expressions at the MIR // level. - LowerTo() []mir.Expr + LowerTo(*mir.Schema) []mir.Expr // EvalAt evaluates this expression in a given tabular context. // Observe that if this expression is *undefined* within this // context then it returns "nil". An expression can be diff --git a/pkg/hir/lower.go b/pkg/hir/lower.go index c96f78b..0b7d37c 100644 --- a/pkg/hir/lower.go +++ b/pkg/hir/lower.go @@ -10,58 +10,58 @@ import ( // LowerTo lowers a sum expression to the MIR level. This requires expanding // the arguments, then lowering them. Furthermore, conditionals are "lifted" to // the top. -func (e *Add) LowerTo() []mir.Expr { - return lowerTo(e) +func (e *Add) LowerTo(schema *mir.Schema) []mir.Expr { + return lowerTo(e, schema) } // LowerTo lowers a constant to the MIR level. This requires expanding the // arguments, then lowering them. 
Furthermore, conditionals are "lifted" to the // top. -func (e *Constant) LowerTo() []mir.Expr { - return lowerTo(e) +func (e *Constant) LowerTo(schema *mir.Schema) []mir.Expr { + return lowerTo(e, schema) } // LowerTo lowers a column access to the MIR level. This requires expanding // the arguments, then lowering them. Furthermore, conditionals are "lifted" to // the top. -func (e *ColumnAccess) LowerTo() []mir.Expr { - return lowerTo(e) +func (e *ColumnAccess) LowerTo(schema *mir.Schema) []mir.Expr { + return lowerTo(e, schema) } // LowerTo lowers a product expression to the MIR level. This requires expanding // the arguments, then lowering them. Furthermore, conditionals are "lifted" to // the top. -func (e *Mul) LowerTo() []mir.Expr { - return lowerTo(e) +func (e *Mul) LowerTo(schema *mir.Schema) []mir.Expr { + return lowerTo(e, schema) } // LowerTo lowers a list expression to the MIR level by eliminating it // altogether. This still requires expanding the arguments, then lowering them. // Furthermore, conditionals are "lifted" to the top.. -func (e *List) LowerTo() []mir.Expr { - return lowerTo(e) +func (e *List) LowerTo(schema *mir.Schema) []mir.Expr { + return lowerTo(e, schema) } // LowerTo lowers a normalise expression to the MIR level. This requires // expanding the arguments, then lowering them. Furthermore, conditionals are // "lifted" to the top.. -func (e *Normalise) LowerTo() []mir.Expr { - return lowerTo(e) +func (e *Normalise) LowerTo(schema *mir.Schema) []mir.Expr { + return lowerTo(e, schema) } // LowerTo lowers an if expression to the MIR level by "compiling out" the // expression using normalisation at the MIR level. This also requires // expanding the arguments, then lowering them. Furthermore, conditionals are // "lifted" to the top. -func (e *IfZero) LowerTo() []mir.Expr { - return lowerTo(e) +func (e *IfZero) LowerTo(schema *mir.Schema) []mir.Expr { + return lowerTo(e, schema) } // LowerTo lowers a subtract expression to the MIR level. 
This also requires // expanding the arguments, then lowering them. Furthermore, conditionals are // "lifted" to the top. -func (e *Sub) LowerTo() []mir.Expr { - return lowerTo(e) +func (e *Sub) LowerTo(schema *mir.Schema) []mir.Expr { + return lowerTo(e, schema) } // ============================================================================ @@ -71,15 +71,15 @@ func (e *Sub) LowerTo() []mir.Expr { // Lowers a given expression to the MIR level. The expression is first expanded // into one or more target expressions. Furthermore, conditions must be "lifted" // to the root. -func lowerTo(e Expr) []mir.Expr { +func lowerTo(e Expr, schema *mir.Schema) []mir.Expr { // First expand expression es := expand(e) // Now lower each one (carefully) mes := make([]mir.Expr, len(es)) // for i, e := range es { - c := lowerCondition(e) - b := lowerBody(e) + c := lowerCondition(e, schema) + b := lowerBody(e, schema) mes[i] = mul2(c, b) } // Done @@ -89,30 +89,30 @@ func lowerTo(e Expr) []mir.Expr { // Lower the "condition" of an expression. Every expression can be view as a // conditional constraint of the form "if c then e", where "c" is the condition. // This is allowed to return nil if the body is unconditional. 
-func lowerCondition(e Expr) mir.Expr { +func lowerCondition(e Expr, schema *mir.Schema) mir.Expr { if p, ok := e.(*Add); ok { - return lowerConditions(p.Args) + return lowerConditions(p.Args, schema) } else if _, ok := e.(*Constant); ok { return nil } else if _, ok := e.(*ColumnAccess); ok { return nil } else if p, ok := e.(*Mul); ok { - return lowerConditions(p.Args) + return lowerConditions(p.Args, schema) } else if p, ok := e.(*Normalise); ok { - return lowerCondition(p.Arg) + return lowerCondition(p.Arg, schema) } else if p, ok := e.(*IfZero); ok { - return lowerIfZeroCondition(p) + return lowerIfZeroCondition(p, schema) } else if p, ok := e.(*Sub); ok { - return lowerConditions(p.Args) + return lowerConditions(p.Args, schema) } // Should be unreachable panic(fmt.Sprintf("unknown expression: %s", e.String())) } -func lowerConditions(es []Expr) mir.Expr { +func lowerConditions(es []Expr, schema *mir.Schema) mir.Expr { var r mir.Expr = nil for _, e := range es { - r = mul2(r, lowerCondition(e)) + r = mul2(r, lowerCondition(e, schema)) } return r @@ -120,11 +120,11 @@ func lowerConditions(es []Expr) mir.Expr { // Lowering conditional expressions is slightly more complex than others, so it // gets a case of its own. -func lowerIfZeroCondition(e *IfZero) mir.Expr { +func lowerIfZeroCondition(e *IfZero, schema *mir.Schema) mir.Expr { var bc mir.Expr // Lower condition - cc := lowerCondition(e.Condition) - cb := lowerBody(e.Condition) + cc := lowerCondition(e.Condition, schema) + cb := lowerBody(e.Condition, schema) // Add conditions arising if e.TrueBranch != nil && e.FalseBranch != nil { // Expansion should ensure this case does not exist. 
This is necessary @@ -145,10 +145,10 @@ func lowerIfZeroCondition(e *IfZero) mir.Expr { cb = oneMinusNormBody // Lower conditional's arising from body - bc = lowerCondition(e.TrueBranch) + bc = lowerCondition(e.TrueBranch, schema) } else { // Lower conditional's arising from body - bc = lowerCondition(e.FalseBranch) + bc = lowerCondition(e.FalseBranch, schema) } // return mul3(cc, cb, bc) @@ -157,39 +157,44 @@ func lowerIfZeroCondition(e *IfZero) mir.Expr { // Translate the "body" of an expression. Every expression can be view as a // conditional constraint of the form "if c then e", where "e" is the // constraint. -func lowerBody(e Expr) mir.Expr { +func lowerBody(e Expr, schema *mir.Schema) mir.Expr { if p, ok := e.(*Add); ok { - return &mir.Add{Args: lowerBodies(p.Args)} + return &mir.Add{Args: lowerBodies(p.Args, schema)} } else if p, ok := e.(*Constant); ok { return &mir.Constant{Value: p.Val} } else if p, ok := e.(*ColumnAccess); ok { - return &mir.ColumnAccess{Column: p.Column, Shift: p.Shift} + if index, ok := schema.ColumnIndex(p.Column); ok { + return &mir.ColumnAccess{Column: index, Shift: p.Shift} + } + // Should be unreachable as all columns should have been vetted earlier + // in the pipeline. + panic(fmt.Sprintf("invalid column access for %s", p.Column)) } else if p, ok := e.(*Mul); ok { - return &mir.Mul{Args: lowerBodies(p.Args)} + return &mir.Mul{Args: lowerBodies(p.Args, schema)} } else if p, ok := e.(*Normalise); ok { - return &mir.Normalise{Arg: lowerBody(p.Arg)} + return &mir.Normalise{Arg: lowerBody(p.Arg, schema)} } else if p, ok := e.(*IfZero); ok { if p.TrueBranch != nil && p.FalseBranch != nil { // Expansion should ensure this case does not exist. This is necessary // to ensure exactly one expression is generated from this expression. 
panic(fmt.Sprintf("unexpanded expression (%s)", e.String())) } else if p.TrueBranch != nil { - return lowerBody(p.TrueBranch) + return lowerBody(p.TrueBranch, schema) } // Done - return lowerBody(p.FalseBranch) + return lowerBody(p.FalseBranch, schema) } else if p, ok := e.(*Sub); ok { - return &mir.Sub{Args: lowerBodies(p.Args)} + return &mir.Sub{Args: lowerBodies(p.Args, schema)} } // Should be unreachable panic(fmt.Sprintf("unknown expression: %s", e.String())) } // Lower a vector of expanded expressions to the MIR level. -func lowerBodies(es []Expr) []mir.Expr { +func lowerBodies(es []Expr, schema *mir.Schema) []mir.Expr { rs := make([]mir.Expr, len(es)) for i, e := range es { - rs[i] = lowerBody(e) + rs[i] = lowerBody(e, schema) } return rs diff --git a/pkg/hir/parser.go b/pkg/hir/parser.go index 292d08d..bec6e98 100644 --- a/pkg/hir/parser.go +++ b/pkg/hir/parser.go @@ -175,10 +175,8 @@ func (p *hirParser) parseAssertionDeclaration(elements []sexp.SExp) error { if err != nil { return err } - // Add all assertions arising. - for _, e := range expr.LowerTo() { - p.schema.AddPropertyAssertion(handle, e) - } + // Add assertion. + p.schema.AddPropertyAssertion(handle, expr) return nil } diff --git a/pkg/hir/schema.go b/pkg/hir/schema.go index 3a2ac34..7d7c4bd 100644 --- a/pkg/hir/schema.go +++ b/pkg/hir/schema.go @@ -50,7 +50,7 @@ type VanishingConstraint = *table.RowConstraint[ZeroArrayTest] // PropertyAssertion captures the notion of an arbitrary property which should // hold for all acceptable traces. However, such a property is not enforced by // the prover. -type PropertyAssertion = mir.PropertyAssertion +type PropertyAssertion = *table.PropertyAssertion[ZeroArrayTest] // Permutation captures the notion of a (sorted) permutation at the HIR level. type Permutation = *table.SortedPermutation @@ -79,10 +79,50 @@ func EmptySchema() *Schema { return p } +// Column returns information about the ith column in this schema. 
+func (p *Schema) Column(i uint) table.ColumnSchema { + panic("todo") +} + +// Width returns the number of column groups in this schema. +func (p *Schema) Width() uint { + return uint(len(p.dataColumns) + len(p.permutations)) +} + +// ColumnGroup returns information about the ith column group in this schema. +func (p *Schema) ColumnGroup(i uint) table.ColumnGroup { + n := uint(len(p.dataColumns)) + if i < n { + return p.dataColumns[i] + } + + return p.permutations[i-n] +} + +// ColumnIndex determines the column index for a given column in this schema, or +// returns false indicating an error. +func (p *Schema) ColumnIndex(name string) (uint, bool) { + index := uint(0) + + for i := uint(0); i < p.Width(); i++ { + ith := p.ColumnGroup(i) + for j := uint(0); j < ith.Width(); j++ { + if ith.NameOf(j) == name { + // hit + return index, true + } + + index++ + } + } + // miss + return 0, false +} + // HasColumn checks whether a given schema has a given column. func (p *Schema) HasColumn(name string) bool { for _, c := range p.dataColumns { - if (*c).Name == name { + if (*c).Name() == name { return true } } @@ -141,8 +181,8 @@ func (p *Schema) AddVanishingConstraint(handle string, domain *int, expr Expr) { } // AddPropertyAssertion appends a new property assertion. -func (p *Schema) AddPropertyAssertion(handle string, expr mir.Expr) { - p.assertions = append(p.assertions, table.NewPropertyAssertion[mir.Expr](handle, expr)) +func (p *Schema) AddPropertyAssertion(handle string, property Expr) { + p.assertions = append(p.assertions, table.NewPropertyAssertion[ZeroArrayTest](handle, ZeroArrayTest{property})) } // Accepts determines whether this schema will accept a given trace. 
That @@ -190,7 +230,7 @@ func (p *Schema) LowerToMir() *mir.Schema { mirSchema := mir.EmptySchema() // First, lower columns for _, col := range p.dataColumns { - mirSchema.AddDataColumn(col.Name, col.Type) + mirSchema.AddDataColumn(col.Name(), col.Type) } // Second, lower permutations for _, col := range p.permutations { @@ -198,7 +238,7 @@ func (p *Schema) LowerToMir() *mir.Schema { } // Third, lower constraints for _, c := range p.vanishing { - mir_exprs := c.Constraint.Expr.LowerTo() + mir_exprs := c.Constraint.Expr.LowerTo(mirSchema) // Add individual constraints arising for _, mir_expr := range mir_exprs { mirSchema.AddVanishingConstraint(c.Handle, c.Domain, mir_expr) @@ -207,7 +247,10 @@ func (p *Schema) LowerToMir() *mir.Schema { // Fourth, copy property assertions. Observe, these do not require lowering // because they are already MIR-level expressions. for _, c := range p.assertions { - mirSchema.AddPropertyAssertion(c.Handle, c.Expr) + properties := c.Property.Expr.LowerTo(mirSchema) + for _, p := range properties { + mirSchema.AddPropertyAssertion(c.Handle, p) + } } // return mirSchema diff --git a/pkg/mir/eval.go b/pkg/mir/eval.go index fca0eb6..9ccbddb 100644 --- a/pkg/mir/eval.go +++ b/pkg/mir/eval.go @@ -9,7 +9,7 @@ import ( // value at that row of the column in question or nil is that row is // out-of-bounds. func (e *ColumnAccess) EvalAt(k int, tbl table.Trace) *fr.Element { - val := tbl.GetByName(e.Column, k+e.Shift) + val := tbl.ColumnByIndex(e.Column).Get(k + e.Shift) var clone fr.Element // Clone original value diff --git a/pkg/mir/expr.go b/pkg/mir/expr.go index 620fb37..505e9f2 100644 --- a/pkg/mir/expr.go +++ b/pkg/mir/expr.go @@ -73,7 +73,7 @@ func (p *Normalise) Bounds() util.Bounds { return p.Arg.Bounds() } // accesses the STAMP column at row 5, whilst CT(-1) accesses the CT column at // row 4. 
type ColumnAccess struct { - Column string + Column uint Shift int } diff --git a/pkg/mir/lower.go b/pkg/mir/lower.go index fe1c086..805bdb5 100644 --- a/pkg/mir/lower.go +++ b/pkg/mir/lower.go @@ -6,49 +6,49 @@ import ( ) // LowerTo lowers a sum expression to the AIR level by lowering the arguments. -func (e *Add) LowerTo(tbl *air.Schema) air.Expr { - return &air.Add{Args: lowerExprs(e.Args, tbl)} +func (e *Add) LowerTo(schema *air.Schema) air.Expr { + return &air.Add{Args: lowerExprs(e.Args, schema)} } // LowerTo lowers a subtract expression to the AIR level by lowering the arguments. -func (e *Sub) LowerTo(tbl *air.Schema) air.Expr { - return &air.Sub{Args: lowerExprs(e.Args, tbl)} +func (e *Sub) LowerTo(schema *air.Schema) air.Expr { + return &air.Sub{Args: lowerExprs(e.Args, schema)} } // LowerTo lowers a product expression to the AIR level by lowering the arguments. -func (e *Mul) LowerTo(tbl *air.Schema) air.Expr { - return &air.Mul{Args: lowerExprs(e.Args, tbl)} +func (e *Mul) LowerTo(schema *air.Schema) air.Expr { + return &air.Mul{Args: lowerExprs(e.Args, schema)} } // LowerTo lowers a normalise expression to the AIR level by "compiling it out" // using a computed column. -func (p *Normalise) LowerTo(tbl *air.Schema) air.Expr { +func (p *Normalise) LowerTo(schema *air.Schema) air.Expr { // Lower the expression being normalised - e := p.Arg.LowerTo(tbl) + e := p.Arg.LowerTo(schema) // Construct an expression representing the normalised value of e. That is, // an expression which is 0 when e is 0, and 1 when e is non-zero. - return air_gadgets.Normalise(e, tbl) + return air_gadgets.Normalise(e, schema) } // LowerTo lowers a column access to the AIR level. This is straightforward as // it is already in the correct form. -func (e *ColumnAccess) LowerTo(tbl *air.Schema) air.Expr { +func (e *ColumnAccess) LowerTo(schema *air.Schema) air.Expr { return &air.ColumnAccess{Column: e.Column, Shift: e.Shift} } // LowerTo lowers a constant to the AIR level. 
This is straightforward as it is // already in the correct form. -func (e *Constant) LowerTo(tbl *air.Schema) air.Expr { +func (e *Constant) LowerTo(schema *air.Schema) air.Expr { return &air.Constant{Value: e.Value} } // Lower a set of zero or more MIR expressions. -func lowerExprs(exprs []Expr, tbl *air.Schema) []air.Expr { +func lowerExprs(exprs []Expr, schema *air.Schema) []air.Expr { n := len(exprs) nexprs := make([]air.Expr, n) for i := 0; i < n; i++ { - nexprs[i] = exprs[i].LowerTo(tbl) + nexprs[i] = exprs[i].LowerTo(schema) } return nexprs diff --git a/pkg/mir/schema.go b/pkg/mir/schema.go index 80b0900..1dacb13 100644 --- a/pkg/mir/schema.go +++ b/pkg/mir/schema.go @@ -20,7 +20,7 @@ type VanishingConstraint = *table.RowConstraint[table.ZeroTest[Expr]] // PropertyAssertion captures the notion of an arbitrary property which should // hold for all acceptable traces. However, such a property is not enforced by // the prover. -type PropertyAssertion = *table.PropertyAssertion[Expr] +type PropertyAssertion = *table.PropertyAssertion[table.ZeroTest[Expr]] // Permutation captures the notion of a (sorted) permutation at the MIR level. type Permutation = *table.SortedPermutation @@ -49,11 +49,51 @@ func EmptySchema() *Schema { return p } +// Width returns the number of column groups in this schema. +func (p *Schema) Width() uint { + return uint(len(p.dataColumns) + len(p.permutations)) +} + +// Column returns information about the ith column in this schema. +func (p *Schema) Column(i uint) table.ColumnSchema { + panic("todo") +} + +// ColumnGroup returns information about the ith column group in this schema. +func (p *Schema) ColumnGroup(i uint) table.ColumnGroup { + n := uint(len(p.dataColumns)) + if i < n { + return p.dataColumns[i] + } + + return p.permutations[i-n] +} + +// ColumnIndex determines the column index for a given column in this schema, or +// returns false indicating an error. 
+func (p *Schema) ColumnIndex(name string) (uint, bool) { + index := uint(0) + + for i := uint(0); i < p.Width(); i++ { + ith := p.ColumnGroup(i) + for j := uint(0); j < ith.Width(); j++ { + if ith.NameOf(j) == name { + // hit + return index, true + } + + index++ + } + } + // miss + return 0, false +} + // GetColumnByName gets a given data column based on its name. If no such // column exists, it panics. func (p *Schema) GetColumnByName(name string) DataColumn { for _, c := range p.dataColumns { - if c.Name == name { + if c.Name() == name { return c } } @@ -103,7 +143,8 @@ func (p *Schema) AddVanishingConstraint(handle string, domain *int, expr Expr) { // AddPropertyAssertion appends a new property assertion. func (p *Schema) AddPropertyAssertion(handle string, expr Expr) { - p.assertions = append(p.assertions, table.NewPropertyAssertion(handle, expr)) + test := table.ZeroTest[Expr]{Expr: expr} + p.assertions = append(p.assertions, table.NewPropertyAssertion(handle, test)) } // Accepts determines whether this schema will accept a given trace. That @@ -136,9 +177,29 @@ func (p *Schema) Accepts(trace table.Trace) error { // constraints as necessary to preserve the original semantics. func (p *Schema) LowerToAir() *air.Schema { airSchema := air.EmptySchema[Expr]() - // Lower data columns - for _, col := range p.dataColumns { - lowerColumnToAir(col, airSchema) + // Allocate data and permutation columns. This must be done first to ensure + // alignment is preserved across lowering. + index := uint(0) + + for i := uint(0); i < p.Width(); i++ { + ith := p.ColumnGroup(i) + for j := uint(0); j < ith.Width(); j++ { + col := ith.NameOf(j) + airSchema.AddColumn(col, ith.IsSynthetic()) + + index++ + } + } + // Add computations. Again this has to be done first for things to work. + // Essentially to reflect the fact that these columns have been added above + // before others. Realistically, the overall design of this process is a + // bit broken right now. 
+ for _, perm := range p.permutations { + airSchema.AddComputation(perm) + } + // Lower checked data columns + for i, col := range p.dataColumns { + lowerColumnToAir(uint(i), col, airSchema) } // Lower permutations columns for _, perm := range p.permutations { @@ -160,25 +221,22 @@ func (p *Schema) LowerToAir() *air.Schema { // Lower a datacolumn to the AIR level. The main effect of this is that, for // columns with non-trivial types, we must add appropriate range constraints to // the enclosing schema. -func lowerColumnToAir(c *table.DataColumn[table.Type], schema *air.Schema) { +func lowerColumnToAir(index uint, c *table.DataColumn[table.Type], schema *air.Schema) { // Check whether a constraint is implied by the column's type if t := c.Type.AsUint(); t != nil && t.Checked() { // Yes, a constraint is implied. Now, decide whether to use a range // constraint or just a vanishing constraint. if t.HasBound(2) { // u1 => use vanishing constraint X * (X - 1) - air_gadgets.ApplyBinaryGadget(c.Name, schema) + air_gadgets.ApplyBinaryGadget(index, schema) } else if t.HasBound(256) { // u2..8 use range constraints - schema.AddRangeConstraint(c.Name, t.Bound()) + schema.AddRangeConstraint(index, t.Bound()) } else { // u9+ use byte decompositions. - air_gadgets.ApplyBitwidthGadget(c.Name, t.BitWidth(), schema) + air_gadgets.ApplyBitwidthGadget(index, t.BitWidth(), schema) } } - // Finally, add an (untyped) data column representing this - // data column. - schema.AddColumn(c.Name, false) } // Lower a permutation to the AIR level. This has quite a few @@ -189,14 +247,22 @@ func lowerColumnToAir(c *table.DataColumn[table.Type], schema *air.Schema) { // meet the requirements of a sorted permutation. 
func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Schema) { ncols := len(c.Targets) + // + targets := make([]uint, ncols) + sources := make([]uint, ncols) // Add individual permutation constraints for i := 0; i < ncols; i++ { - airSchema.AddColumn(c.Targets[i], true) + var ok1, ok2 bool + // TODO: REPLACE + sources[i], ok1 = airSchema.ColumnIndex(c.Sources[i]) + targets[i], ok2 = airSchema.ColumnIndex(c.Targets[i]) + + if !ok1 || !ok2 { + panic("missing column") + } } // - airSchema.AddPermutationConstraint(c.Targets, c.Sources) - // Add the trace computation. - airSchema.AddComputation(c) + airSchema.AddPermutationConstraint(targets, sources) // Add sorting constraints + synthetic columns as necessary. if ncols == 1 { // For a single column sort, its actually a bit easier because we don't @@ -206,7 +272,7 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche // also requires bitwidth constraints. bitwidth := mirSchema.GetColumnByName(c.Sources[0]).Type.AsUint().BitWidth() // Add column sorting constraints - air_gadgets.ApplyColumnSortGadget(c.Targets[0], c.Signs[0], bitwidth, airSchema) + air_gadgets.ApplyColumnSortGadget(targets[0], c.Signs[0], bitwidth, airSchema) } else { // For a multi column sort, its a bit harder as we need additional // logicl to ensure the target columns are lexicographally sorted. 
@@ -220,7 +286,7 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche } } // Add lexicographically sorted constraints - air_gadgets.ApplyLexicographicSortingGadget(c.Targets, c.Signs, bitwidth, airSchema) + air_gadgets.ApplyLexicographicSortingGadget(targets, c.Signs, bitwidth, airSchema) } } diff --git a/pkg/mir/string.go b/pkg/mir/string.go index bc2752d..cf904f3 100644 --- a/pkg/mir/string.go +++ b/pkg/mir/string.go @@ -6,10 +6,10 @@ import ( func (e *ColumnAccess) String() string { if e.Shift == 0 { - return e.Column + return fmt.Sprintf("#%d", e.Column) } - return fmt.Sprintf("(shift %s %d)", e.Column, e.Shift) + return fmt.Sprintf("(shift #%d %d)", e.Column, e.Shift) } func (e *Constant) String() string { diff --git a/pkg/table/array_trace.go b/pkg/table/array_trace.go new file mode 100644 index 0000000..2701681 --- /dev/null +++ b/pkg/table/array_trace.go @@ -0,0 +1,305 @@ +package table + +import ( + "fmt" + "strings" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +// ArrayTrace provides an implementation of Trace which stores columns as an +// array. +type ArrayTrace struct { + // Holds the maximum height of any column in the trace + height uint + // Holds the name of each column + columns []*ArrayTraceColumn +} + +// EmptyArrayTrace constructs an empty array trace into which column data can be +// added. +func EmptyArrayTrace() *ArrayTrace { + p := new(ArrayTrace) + // Initially empty columns + p.columns = make([]*ArrayTraceColumn, 0) + // Initialise height as 0 + p.height = 0 + // done + return p +} + +// Width returns the number of columns in this trace. +func (p *ArrayTrace) Width() uint { + return uint(len(p.columns)) +} + +// ColumnName returns the name of the ith column in this trace. +func (p *ArrayTrace) ColumnName(index int) string { + return p.columns[index].Name() +} + +// IndexOf returns the index of the given name in this trace.
+func (p *ArrayTrace) IndexOf(name string) (uint, bool) { + for i, c := range p.columns { + if c.name == name { + return uint(i), true + } + } + // Column does not exist + return 0, false +} + +// Clone creates an identical clone of this trace. +func (p *ArrayTrace) Clone() *ArrayTrace { + clone := new(ArrayTrace) + clone.columns = make([]*ArrayTraceColumn, len(p.columns)) + clone.height = p.height + // + for i, c := range p.columns { + // TODO: can this be avoided? + clone.columns[i] = c.Clone() + } + // done + return clone +} + +// HasColumn checks whether the trace has a given column or not. +func (p *ArrayTrace) HasColumn(name string) bool { + for _, c := range p.columns { + if c.name == name { + return true + } + } + + return false +} + +// AlignInputWith attempts to align this trace with the input columns of a +// given schema. This means ensuring the order of columns in this trace matches +// the order of input columns in the schema. Thus, column indexes used by +// constraints in the schema can directly access in this trace (i.e. without +// name lookup). Alignment can fail, however, if there is a mismatch between +// columns in the trace and those expected by the schema. +func (p *ArrayTrace) AlignInputWith(schema Schema) error { + return alignWith(false, p, schema) +} + +// AlignWith attempts to align this trace with a given schema. This means +// ensuring the order of columns in this trace matches the order in the schema. +// Thus, column indexes used by constraints in the schema can directly access in +// this trace (i.e. without name lookup). Alignment can fail, however, if there +// is a mismatch between columns in the trace and those expected by the schema. +func (p *ArrayTrace) AlignWith(schema Schema) error { + return alignWith(true, p, schema) +} + +// Alignment algorithm which operates either in unexpanded or expanded mode. In +// expanded mode, all columns must be accounted for and will be aligned. 
In +// unexpanded mode, the trace is only expected to contain input (i.e. +// non-synthetic) columns. Furthermore, in the schema these are expected to be +// allocated before synthetic columns. As such, alignment of these input +// columns is performed. +func alignWith(expand bool, p *ArrayTrace, schema Schema) error { + ncols := uint(len(p.columns)) + index := uint(0) + // Check each column described in this schema is present in the trace. + for i := uint(0); i < schema.Width(); i++ { + group := schema.ColumnGroup(i) + if expand || !group.IsSynthetic() { + for j := uint(0); j < group.Width(); j++ { + // Determine column name + schemaName := group.NameOf(j) + // Sanity check column exists + if index >= ncols { + return fmt.Errorf("trace missing column %s", schemaName) + } + + traceName := p.columns[index].name + // Check alignment + if traceName != schemaName { + // Not aligned --- so fix + k, ok := p.IndexOf(schemaName) + // check exists + if !ok { + return fmt.Errorf("trace missing column %s", schemaName) + } + // Swap columns + tmp := p.columns[index] + p.columns[index] = p.columns[k] + p.columns[k] = tmp + } + // Continue + index++ + } + } + } + // Check whether all columns matched + if index == ncols { + // Yes, alignment complete. + return nil + } + // Error Case. + unknowns := p.columns[index:] + // + return fmt.Errorf("trace contains unknown columns: %v", unknowns) +} + +// AddColumn adds a new column of data to this trace. +func (p *ArrayTrace) AddColumn(name string, data []*fr.Element, padding *fr.Element) { + // Sanity check the column does not already exist. + if p.HasColumn(name) { + panic("column already exists") + } + // Construct new column + column := ArrayTraceColumn{name, data, padding} + // Append it + p.columns = append(p.columns, &column) + // Update maximum height + if uint(len(data)) > p.height { + p.height = uint(len(data)) + } +} + +// Columns returns the set of columns in this trace. 
+func (p *ArrayTrace) Columns() []*ArrayTraceColumn { + return p.columns +} + +// ColumnByIndex looks up a column based on its index. +func (p *ArrayTrace) ColumnByIndex(index uint) Column { + return p.columns[index] +} + +// ColumnByName looks up a column based on its name. If the column doesn't +// exist, then nil is returned. +func (p *ArrayTrace) ColumnByName(name string) Column { + for _, c := range p.columns { + if name == c.name { + // Matched column + return c + } + } + + return nil +} + +// Height determines the maximum height of any column within this trace. +func (p *ArrayTrace) Height() uint { + return p.height +} + +// Pad each column in this trace with n items at the front. An iterator over +// the padding values to use for each column must be given. +func (p *ArrayTrace) Pad(n uint) { + for _, c := range p.columns { + c.Pad(n) + } + // Increment height + p.height += n +} + +func (p *ArrayTrace) String() string { + // Use string builder to try and make this vaguely efficient. + var id strings.Builder + + id.WriteString("{") + + for i := 0; i < len(p.columns); i++ { + if i != 0 { + id.WriteString(",") + } + + id.WriteString(p.columns[i].name) + id.WriteString("={") + + for j := 0; j < int(p.height); j++ { + jth := p.columns[i].Get(j) + + if j != 0 { + id.WriteString(",") + } + + if jth == nil { + id.WriteString("_") + } else { + id.WriteString(jth.String()) + } + } + id.WriteString("}") + } + id.WriteString("}") + // + return id.String() +} + +// =================================================================== +// Array Trace Column +// =================================================================== + +// ArrayTraceColumn represents a column of data within an array trace. +type ArrayTraceColumn struct { + // Holds the name of this column + name string + // Holds the raw data making up this column + data []*fr.Element + // Value to be used when padding this column + padding *fr.Element +} + +// Name returns the name of the given column. 
+func (p *ArrayTraceColumn) Name() string { + return p.name +} + +// Height determines the height of this column. +func (p *ArrayTraceColumn) Height() uint { + return uint(len(p.data)) +} + +// Padding returns the value which will be used for padding this column. +func (p *ArrayTraceColumn) Padding() *fr.Element { + return p.padding +} + +// Data returns the data for the given column. +func (p *ArrayTraceColumn) Data() []*fr.Element { + return p.data +} + +// Get the value at a given row in this column. If the row is +// out-of-bounds, then the column's padding value is returned instead. +// Thus, this function always succeeds. +func (p *ArrayTraceColumn) Get(row int) *fr.Element { + if row < 0 || row >= len(p.data) { + // out-of-bounds access + return p.padding + } + // in-bounds access + return p.data[row] +} + +// Clone an ArrayTraceColumn +func (p *ArrayTraceColumn) Clone() *ArrayTraceColumn { + clone := new(ArrayTraceColumn) + clone.name = p.name + clone.data = make([]*fr.Element, len(p.data)) + clone.padding = p.padding + copy(clone.data, p.data) + + return clone +} + +// Pad this column with n copies of a given value, either at the front +// (sign=true) or the back (sign=false). +func (p *ArrayTraceColumn) Pad(n uint) { + // Allocate sufficient memory + ndata := make([]*fr.Element, uint(len(p.data))+n) + // Copy over the data + copy(ndata[n:], p.data) + // Go padding! + for i := uint(0); i < n; i++ { + ndata[i] = p.padding + } + // Copy over + p.data = ndata +} diff --git a/pkg/table/column.go b/pkg/table/column.go index 58c5cd6..971dc85 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -10,7 +10,7 @@ import ( // DataColumn represents a column of user-provided values. type DataColumn[T Type] struct { - Name string + name string // Expected type of values held in this column. Observe that this type is // enforced only when checking is enabled. Unchecked typed columns can // still make sense when their values are implied by some other constraint. 
@@ -25,9 +25,29 @@ func NewDataColumn[T Type](name string, base T, synthetic bool) *DataColumn[T] { return &DataColumn[T]{name, base, synthetic} } -// Get the value of this column at a given row in a given trace. -func (c *DataColumn[T]) Get(row int, tr Trace) *fr.Element { - return tr.GetByName(c.Name, row) +// Name forms part of the ColumnSchema interface, and provides access to +// information about the ith column in a schema. +func (p *DataColumn[T]) Name() string { + return p.name +} + +// Width forms part of the ColumnGroup interface, and determines how many +// columns are in the group. Data columns already represent a group of size 1. +func (p *DataColumn[T]) Width() uint { + return 1 +} + +// NameOf forms part of the ColumnGroup interface, and provides access to the +// ith column in a group. Since a data column represents a group of size 1, +// there is only ever one name. +func (p *DataColumn[T]) NameOf(index uint) string { + return p.name +} + +// IsSynthetic forms part of the ColumnGroup interface, and determines whether or +// not the group (as a whole) is synthetic. +func (p *DataColumn[T]) IsSynthetic() bool { + return p.Synthetic } // Accepts determines whether or not this column accepts the given trace. For a @@ -35,20 +55,22 @@ func (c *DataColumn[T]) Get(row int, tr Trace) *fr.Element { // type. // //nolint:revive -func (c *DataColumn[T]) Accepts(tr Trace) error { - // Check column in trace! - if !tr.HasColumn(c.Name) { - return fmt.Errorf("Trace missing data column ({%s})", c.Name) - } - // Check constraints accepted - for i := uint(0); i < tr.Height(); i++ { - val := tr.GetByName(c.Name, int(i)) - - if !c.Type.Accept(val) { - // Construct useful error message - msg := fmt.Sprintf("column %s value out-of-bounds (row %d, %s)", c.Name, i, val) - // Evaluation failure - return errors.New(msg) +func (p *DataColumn[T]) Accepts(tr Trace) error { + // Only check for non-field types. 
This is simply because a column with the + // field type always accepts everything. + if p.Type.AsField() == nil { + // Access corresponding column in trace + col := tr.ColumnByName(p.name) + // Check constraints accepted + for i := 0; i < int(tr.Height()); i++ { + val := col.Get(i) + + if !p.Type.Accept(val) { + // Construct useful error message + msg := fmt.Sprintf("column %s value out-of-bounds (row %d, %s)", p.Name(), i, val) + // Evaluation failure + return errors.New(msg) + } } } // All good @@ -58,10 +80,10 @@ func (c *DataColumn[T]) Accepts(tr Trace) error { //nolint:revive func (c *DataColumn[T]) String() string { if c.Type.AsField() != nil { - return fmt.Sprintf("(column %s)", c.Name) + return fmt.Sprintf("(column %s)", c.Name()) } - return fmt.Sprintf("(column %s :%s)", c.Name, c.Type) + return fmt.Sprintf("(column %s :%s)", c.Name(), c.Type) } // ComputedColumn describes a column whose values are computed on-demand, rather @@ -117,8 +139,7 @@ func (c *ComputedColumn[E]) Accepts(tr Trace) error { // then an error is flagged. func (c *ComputedColumn[E]) ExpandTrace(tr Trace) error { if tr.HasColumn(c.Name) { - msg := fmt.Sprintf("Computed column already exists ({%s})", c.Name) - return errors.New(msg) + return fmt.Errorf("Computed column already exists ({%s})", c.Name) } data := make([]*fr.Element, tr.Height()) @@ -151,77 +172,6 @@ func (c *ComputedColumn[E]) String() string { // Sorted Permutations // =================================================================== -// Permutation declares a constraint that one column is a permutation -// of another. 
-type Permutation struct { - // The target columns - Targets []string - // The so columns - Sources []string -} - -// NewPermutation creates a new permutation -func NewPermutation(targets []string, sources []string) *Permutation { - if len(targets) != len(sources) { - panic("differeng number of target / source permutation columns") - } - - return &Permutation{targets, sources} -} - -// RequiredSpillage returns the minimum amount of spillage required to ensure -// valid traces are accepted in the presence of arbitrary padding. -func (p *Permutation) RequiredSpillage() uint { - return uint(0) -} - -// Accepts checks whether a permutation holds between the source and -// target columns. -func (p *Permutation) Accepts(tr Trace) error { - // Sanity check columns well formed. - if err := validPermutationColumns(p.Targets, p.Sources, tr); err != nil { - return err - } - // Slice out data - src := sliceMatchingColumns(p.Sources, tr) - dst := sliceMatchingColumns(p.Targets, tr) - // Sanity check whether column exists - if !util.ArePermutationOf(dst, src) { - msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})", - p.Targets, p.Sources) - return errors.New(msg) - } - // Success - return nil -} - -func (p *Permutation) String() string { - targets := "" - sources := "" - - for i, s := range p.Targets { - if i != 0 { - targets += " " - } - - targets += s - } - - for i, s := range p.Sources { - if i != 0 { - sources += " " - } - - sources += s - } - - return fmt.Sprintf("(permutation (%s) (%s))", targets, sources) -} - -// =================================================================== -// Sorted Permutations -// =================================================================== - // SortedPermutation declares one or more columns as sorted permutations of // existing columns. 
type SortedPermutation struct { @@ -242,6 +192,27 @@ func NewSortedPermutation(targets []string, signs []bool, sources []string) *Sor return &SortedPermutation{targets, signs, sources} } +// Width forms part of the ColumnGroup interface, and determines how many +// columns are in the group. Sorted permutations define one or more new +// columns. +func (p *SortedPermutation) Width() uint { + return uint(len(p.Targets)) +} + +// NameOf forms part of the ColumnGroup interface, and provides access to the +// ith column in a group. For a sorted permutation, the ith name is the name +// of the ith target column. +func (p *SortedPermutation) NameOf(index uint) string { + return p.Targets[index] +} + +// IsSynthetic forms part of the ColumnGroup interface which determines whether +// or not the group (as a whole) is synthetic. Sorted permutation columns are +// always synthetic. +func (p *SortedPermutation) IsSynthetic() bool { + return true +} + // RequiredSpillage returns the minimum amount of spillage required to ensure // valid traces are accepted in the presence of arbitrary padding. func (p *SortedPermutation) RequiredSpillage() uint { diff --git a/pkg/table/constraints.go b/pkg/table/constraints.go index d8bcb9e..eddc407 100644 --- a/pkg/table/constraints.go +++ b/pkg/table/constraints.go @@ -52,10 +52,10 @@ type ZeroTest[E Evaluable] struct { } // TestAt determines whether or not a given expression evaluates to zero. -// Observe that if the expression is undefined, then it is assumed to hold. +// Observe that if the expression is undefined, then it is assumed not to hold. func (p ZeroTest[E]) TestAt(row int, tr Trace) bool { val := p.Expr.EvalAt(row, tr) - return val == nil || val.IsZero() + return val != nil && val.IsZero() } // Bounds determines the bounds for this zero test. @@ -172,9 +172,8 @@ func (p *RowConstraint[T]) String() string { // 256 (i.e. to ensuring bytes). This restriction is somewhat // arbitrary and is determined by the underlying prover.
type RangeConstraint struct { - // A unique identifier for this constraint. This is primarily - // useful for debugging. - Handle string + // Column index to be constrained. + Column uint // The actual constraint itself, namely an expression which // should evaluate to zero. NOTE: an fr.Element is used here // to store the bound simply to make the necessary comparison @@ -183,7 +182,7 @@ type RangeConstraint struct { } // NewRangeConstraint constructs a new Range constraint! -func NewRangeConstraint(column string, bound *fr.Element) *RangeConstraint { +func NewRangeConstraint(column uint, bound *fr.Element) *RangeConstraint { var n fr.Element = fr.NewElement(256) if bound.Cmp(&n) > 0 { panic("Range constraint for bitwidth above 8 not supported") @@ -192,24 +191,21 @@ func NewRangeConstraint(column string, bound *fr.Element) *RangeConstraint { return &RangeConstraint{column, bound} } -// GetHandle returns the handle associated with this constraint. -func (p *RangeConstraint) GetHandle() string { - return p.Handle -} - // IsAir is a marker that indicates this is an AIR column. func (p *RangeConstraint) IsAir() bool { return true } // Accepts checks whether a range constraint evaluates to zero on // every row of a table. If so, return nil otherwise return an error. 
 func (p *RangeConstraint) Accepts(tr Trace) error {
-	for k := uint(0); k < tr.Height(); k++ {
+	column := tr.ColumnByIndex(p.Column)
+	for k := 0; k < int(tr.Height()); k++ {
 		// Get the value on the kth row
-		kth := tr.GetByName(p.Handle, int(k))
+		kth := column.Get(k)
 		// Perform the bounds check
 		if kth != nil && kth.Cmp(p.Bound) >= 0 {
+			name := column.Name()
 			// Construct useful error message
-			msg := fmt.Sprintf("value out-of-bounds (row %d, %s)", kth, p.Handle)
+			msg := fmt.Sprintf("value out-of-bounds (row %d, %s)", k, name)
 			// Evaluation failure
 			return errors.New(msg)
 		}
@@ -219,7 +215,86 @@ func (p *RangeConstraint) Accepts(tr Trace) error {
 }
 
 func (p *RangeConstraint) String() string {
-	return fmt.Sprintf("(range %s %s)", p.Handle, p.Bound)
+	return fmt.Sprintf("(range #%d %s)", p.Column, p.Bound)
+}
+
+// ===================================================================
+// Permutation
+// ===================================================================
+
+// Permutation declares a constraint that one column is a permutation
+// of another.
+type Permutation struct {
+	// The target columns
+	Targets []uint
+	// The source columns
+	Sources []uint
+}
+
+// NewPermutation creates a new permutation
+func NewPermutation(targets []uint, sources []uint) *Permutation {
+	if len(targets) != len(sources) {
+		panic("differing number of target / source permutation columns")
+	}
+
+	return &Permutation{targets, sources}
+}
+
+// RequiredSpillage returns the minimum amount of spillage required to ensure
+// valid traces are accepted in the presence of arbitrary padding.
+func (p *Permutation) RequiredSpillage() uint {
+	return uint(0)
+}
+
+// Accepts checks whether a permutation holds between the source and
+// target columns.
+func (p *Permutation) Accepts(tr Trace) error { + // Slice out data + src := sliceColumns(p.Sources, tr) + dst := sliceColumns(p.Targets, tr) + // Sanity check whether column exists + if !util.ArePermutationOf(dst, src) { + msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})", + p.Targets, p.Sources) + return errors.New(msg) + } + // Success + return nil +} + +func (p *Permutation) String() string { + targets := "" + sources := "" + + for i, s := range p.Targets { + if i != 0 { + targets += " " + } + + targets += fmt.Sprintf("%d", s) + } + + for i, s := range p.Sources { + if i != 0 { + sources += " " + } + + sources += fmt.Sprintf("%d", s) + } + + return fmt.Sprintf("(permutation (%s) (%s))", targets, sources) +} + +func sliceColumns(columns []uint, tr Trace) [][]*fr.Element { + // Allocate return array + cols := make([][]*fr.Element, len(columns)) + // Slice out the data + for i, n := range columns { + nth := tr.ColumnByIndex(n) + cols[i] = nth.Data() + } + // Done + return cols } // =================================================================== @@ -233,7 +308,7 @@ func (p *RangeConstraint) String() string { // That is, they should be implied by the actual constraints. Thus, whilst the // prover cannot enforce such properties, external tools (such as for formal // verification) can attempt to ensure they do indeed always hold. -type PropertyAssertion[E Evaluable] struct { +type PropertyAssertion[T Testable] struct { // A unique identifier for this constraint. This is primarily // useful for debugging. Handle string @@ -242,33 +317,31 @@ type PropertyAssertion[E Evaluable] struct { // Observe that this can be any function which is computable // on a given trace --- we are not restricted to expressions // which can be arithmetised. - Expr E + Property T +} + +// NewPropertyAssertion constructs a new property assertion! 
+func NewPropertyAssertion[T Testable](handle string, property T) *PropertyAssertion[T] { + return &PropertyAssertion[T]{handle, property} } // GetHandle returns the handle associated with this constraint. // //nolint:revive -func (p *PropertyAssertion[E]) GetHandle() string { +func (p *PropertyAssertion[T]) GetHandle() string { return p.Handle } -// NewPropertyAssertion constructs a new property assertion! -func NewPropertyAssertion[E Evaluable](handle string, expr E) *PropertyAssertion[E] { - return &PropertyAssertion[E]{handle, expr} -} - // Accepts checks whether a vanishing constraint evaluates to zero on every row // of a table. If so, return nil otherwise return an error. // //nolint:revive -func (p *PropertyAssertion[E]) Accepts(tr Trace) error { +func (p *PropertyAssertion[T]) Accepts(tr Trace) error { for k := uint(0); k < tr.Height(); k++ { - // Determine kth evaluation point - kth := p.Expr.EvalAt(int(k), tr) - // Check whether it vanished (or was undefined) - if kth != nil && !kth.IsZero() { + // Check whether property holds (or was undefined) + if !p.Property.TestAt(int(k), tr) { // Construct useful error message - msg := fmt.Sprintf("property assertion %s does not hold (row %d, %s)", p.Handle, k, kth) + msg := fmt.Sprintf("property assertion %s does not hold (row %d)", p.Handle, k) // Evaluation failure return errors.New(msg) } diff --git a/pkg/table/printer.go b/pkg/table/printer.go index 22775d8..0ab3c1e 100644 --- a/pkg/table/printer.go +++ b/pkg/table/printer.go @@ -26,18 +26,19 @@ func PrintTrace(tr Trace) { func traceColumnData(tr Trace, col uint) []string { n := tr.Height() - data := make([]string, n+1) - data[0] = tr.ColumnName(int(col)) + data := make([]string, n+2) + data[0] = fmt.Sprintf("#%d", col) + data[1] = tr.ColumnByIndex(col).Name() - for row := uint(0); row < n; row++ { - data[row+1] = tr.GetByIndex(int(col), int(row)).String() + for row := 0; row < int(n); row++ { + data[row+2] = tr.ColumnByIndex(col).Get(row).String() } return 
data } func traceRowWidths(height uint, rows [][]string) []int { - widths := make([]int, height+1) + widths := make([]int, height+2) for _, row := range rows { for i, col := range row { diff --git a/pkg/table/schema.go b/pkg/table/schema.go index 634d824..b744a21 100644 --- a/pkg/table/schema.go +++ b/pkg/table/schema.go @@ -21,6 +21,39 @@ type Schema interface { // ensure valid traces are accepted in the presence of arbitrary padding. // Note: this is calculated on demand. RequiredSpillage() uint + + // Determine the number of column groups in this schema. + Width() uint + + // Determine the index of a named column in this schema, or return false if + // no matching column exists. + ColumnIndex(string) (uint, bool) + + // Access information about the ith column group in this schema. + ColumnGroup(uint) ColumnGroup + + // Access information about the ith column in this schema. + Column(uint) ColumnSchema +} + +// ColumnGroup represents a group of related columns in the schema. For +// example, a single data column is (for now) always a column group of size 1. +// Likewise, an array of size n is a column group of size n, etc. +type ColumnGroup interface { + // Return the number of columns in this group. + Width() uint + + // Returns the name of the ith column in this group. + NameOf(uint) string + + // Determines whether or not this column group is synthetic. + IsSynthetic() bool +} + +// ColumnSchema provides information about a specific column in the schema. +type ColumnSchema interface { + // Returns the name of this column + Name() string } // Declaration represents a declared element of a schema. 
For example, a column
diff --git a/pkg/table/trace.go b/pkg/table/trace.go
index 6d2e888..3922681 100644
--- a/pkg/table/trace.go
+++ b/pkg/table/trace.go
@@ -2,16 +2,14 @@ package table
 
 import (
 	"encoding/json"
-	"fmt"
 	"math/big"
-	"strings"
 
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
 	"github.com/consensys/go-corset/pkg/util"
 )
 
 // Acceptable represents an element which can "accept" a trace, or either reject
-// with an error or report a warning.
+// with an error (or eventually perhaps report a warning).
 type Acceptable interface {
 	Accepts(Trace) error
 }
@@ -22,19 +20,28 @@ type Column interface {
 	Name() string
 	// Return the height (i.e. number of rows) of this column.
 	Height() uint
-	// Return the data stored in this column.
+	// Return the raw data stored in this column.
 	Data() []*fr.Element
 	// Return the value to use for padding this column.
 	Padding() *fr.Element
+	// Get the value at a given row in this column.  If the row is
+	// out-of-bounds, then the column's padding value is returned instead.
+	// Thus, this function always succeeds.
+	Get(row int) *fr.Element
 }
 
 // Trace describes a set of named columns.  Columns are not required to have the
 // same height and can be either "data" columns or "computed" columns.
 type Trace interface {
+	// Attempt to align this trace with a given schema.  This means ensuring the
+	// order of columns in this trace matches the order in the schema.  Thus,
+	// column indexes used by constraints in the schema can be used directly to
+	// access columns in this trace (i.e. without name lookup).  Alignment can
+	// fail, however, if there is a mismatch between columns in the trace and
+	// those expected by the schema.
+	AlignWith(schema Schema) error
 	// Add a new column of data
 	AddColumn(name string, data []*fr.Element, padding *fr.Element)
-	// Get the name of the ith column in this trace.
-	ColumnName(int) string
 	// ColumnByIndex returns the ith column in this trace.
ColumnByIndex(uint) Column // ColumnByName returns the data of a given column in order that it can be @@ -42,19 +49,6 @@ type Trace interface { ColumnByName(name string) Column // Check whether this trace contains data for the given column. HasColumn(name string) bool - // Get the value of a given column by its name. If the column - // does not exist or if the index is out-of-bounds then an - // error is returned. - // - // NOTE: this operation is expected to be slower than - // GetByindex as, depending on the underlying data format, - // this may first resolve the name into a physical column - // index. - GetByName(name string, row int) *fr.Element - // Get the value of a given column by its index. If the column - // does not exist or if the index is out-of-bounds then an - // error is returned. - GetByIndex(col int, row int) *fr.Element // Pad each column in this trace with n items at the front. An iterator over // the padding values to use for each column must be given. Pad(n uint) @@ -79,260 +73,6 @@ func ConstraintsAcceptTrace[T Acceptable](trace Trace, constraints []T) error { return nil } -// =================================================================== -// Array Trace -// =================================================================== - -// ArrayTrace provides an implementation of Trace which stores columns as an -// array. -type ArrayTrace struct { - // Holds the maximum height of any column in the trace - height uint - // Holds the name of each column - columns []*ArrayTraceColumn -} - -// EmptyArrayTrace constructs an empty array trace into which column data can be -// added. -func EmptyArrayTrace() *ArrayTrace { - p := new(ArrayTrace) - // Initially empty columns - p.columns = make([]*ArrayTraceColumn, 0) - // Initialise height as 0 - p.height = 0 - // done - return p -} - -// Width returns the number of columns in this trace. 
-func (p *ArrayTrace) Width() uint { - return uint(len(p.columns)) -} - -// ColumnName returns the name of the ith column in this trace. -func (p *ArrayTrace) ColumnName(index int) string { - return p.columns[index].Name() -} - -// Clone creates an identical clone of this trace. -func (p *ArrayTrace) Clone() *ArrayTrace { - clone := new(ArrayTrace) - clone.columns = make([]*ArrayTraceColumn, len(p.columns)) - clone.height = p.height - // - for i, c := range p.columns { - // TODO: can this be avoided? - clone.columns[i] = c.Clone() - } - // done - return clone -} - -// HasColumn checks whether the trace has a given column or not. -func (p *ArrayTrace) HasColumn(name string) bool { - for _, c := range p.columns { - if c.name == name { - return true - } - } - - return false -} - -// AddColumn adds a new column of data to this trace. -func (p *ArrayTrace) AddColumn(name string, data []*fr.Element, padding *fr.Element) { - // Sanity check the column does not already exist. - if p.HasColumn(name) { - panic("column already exists") - } - // Construct new column - column := ArrayTraceColumn{name, data, padding} - // Append it - p.columns = append(p.columns, &column) - // Update maximum height - if uint(len(data)) > p.height { - p.height = uint(len(data)) - } -} - -// Columns returns the set of columns in this trace. -func (p *ArrayTrace) Columns() []*ArrayTraceColumn { - return p.columns -} - -// GetByName gets the value of a given column (as identified by its name) at a -// given row. If the column does not exist, an error is returned. -func (p *ArrayTrace) GetByName(name string, row int) *fr.Element { - // NOTE: Could improve performance here if names were kept in - // sorted order. - c := p.getColumnByName(name) - if c != nil { - // Matched column - return c.Get(row) - } - // Precondition failure - panic(fmt.Sprintf("Invalid column: {%s}", name)) -} - -// ColumnByIndex looks up a column based on its index. 
-func (p *ArrayTrace) ColumnByIndex(index uint) Column { - return p.columns[index] -} - -// ColumnByName looks up a column based on its name. If the column doesn't -// exist, then nil is returned. -func (p *ArrayTrace) ColumnByName(name string) Column { - for _, c := range p.columns { - if name == c.name { - // Matched column - return c - } - } - - return nil -} - -// GetByIndex returns the value of a given column (as identifier by its index or -// register) at a given row. If the column is out-of-bounds an error is -// returned. -func (p *ArrayTrace) GetByIndex(col int, row int) *fr.Element { - if col < 0 || col >= len(p.columns) { - // Precondition failure - panic(fmt.Sprintf("Invalid column: {%d}", col)) - } - - return p.columns[col].Get(row) -} - -func (p *ArrayTrace) getColumnByName(name string) *ArrayTraceColumn { - for _, c := range p.columns { - if name == c.name { - // Matched column - return c - } - } - - return nil -} - -// Height determines the maximum height of any column within this trace. -func (p *ArrayTrace) Height() uint { - return p.height -} - -// Pad each column in this trace with n items at the front. An iterator over -// the padding values to use for each column must be given. -func (p *ArrayTrace) Pad(n uint) { - for _, c := range p.columns { - c.Pad(n) - } - // Increment height - p.height += n -} - -func (p *ArrayTrace) String() string { - // Use string builder to try and make this vaguely efficient. 
- var id strings.Builder - - id.WriteString("{") - - for i := 0; i < len(p.columns); i++ { - if i != 0 { - id.WriteString(",") - } - - id.WriteString(p.columns[i].name) - id.WriteString("={") - - for j := uint(0); j < p.height; j++ { - jth := p.GetByIndex(i, int(j)) - - if j != 0 { - id.WriteString(",") - } - - if jth == nil { - id.WriteString("_") - } else { - id.WriteString(jth.String()) - } - } - id.WriteString("}") - } - id.WriteString("}") - // - return id.String() -} - -// =================================================================== -// Array Trace Column -// =================================================================== - -// ArrayTraceColumn represents a column of data within an array trace. -type ArrayTraceColumn struct { - // Holds the name of this column - name string - // Holds the raw data making up this column - data []*fr.Element - // Value to be used when padding this column - padding *fr.Element -} - -// Name returns the name of the given column. -func (p *ArrayTraceColumn) Name() string { - return p.name -} - -// Height determines the height of this column. -func (p *ArrayTraceColumn) Height() uint { - return uint(len(p.data)) -} - -// Padding returns the value which will be used for padding this column. -func (p *ArrayTraceColumn) Padding() *fr.Element { - return p.padding -} - -// Data returns the data for the given column. -func (p *ArrayTraceColumn) Data() []*fr.Element { - return p.data -} - -// Clone an ArrayTraceColumn -func (p *ArrayTraceColumn) Clone() *ArrayTraceColumn { - clone := new(ArrayTraceColumn) - clone.name = p.name - clone.data = make([]*fr.Element, len(p.data)) - clone.padding = p.padding - copy(clone.data, p.data) - - return clone -} - -// Pad this column with n copies of a given value, either at the front -// (sign=true) or the back (sign=false). 
-func (p *ArrayTraceColumn) Pad(n uint) { - // Allocate sufficient memory - ndata := make([]*fr.Element, uint(len(p.data))+n) - // Copy over the data - copy(ndata[n:], p.data) - // Go padding! - for i := uint(0); i < n; i++ { - ndata[i] = p.padding - } - // Copy over - p.data = ndata -} - -// Get the value at the given row of this column. -func (p *ArrayTraceColumn) Get(row int) *fr.Element { - if row >= 0 && row < len(p.data) { - return p.data[row] - } - // For out-of-bounds access, return padding value. - return p.padding -} - // =================================================================== // JSON Parser // =================================================================== @@ -358,7 +98,6 @@ func ParseJsonTrace(bytes []byte) (*ArrayTrace, error) { // Add new column to the trace trace.AddColumn(name, rawElements, &zero) } - // Done. return trace, nil } diff --git a/pkg/test/ir_test.go b/pkg/test/ir_test.go index 79b87a0..82570b6 100644 --- a/pkg/test/ir_test.go +++ b/pkg/test/ir_test.go @@ -371,26 +371,26 @@ func CheckTraces(t *testing.T, test string, expected bool, traces []*table.Array hirID := traceId{"HIR", test, expected, i + 1, padding, hirSchema.RequiredSpillage()} mirID := traceId{"MIR", test, expected, i + 1, padding, mirSchema.RequiredSpillage()} airID := traceId{"AIR", test, expected, i + 1, padding, airSchema.RequiredSpillage()} - // Check HIR/MIR trace (if applicable) - if airSchema.IsInputTrace(tr) == nil { - // This is an unexpanded input trace. + // Check whether trace is input compatible with schema + if err := tr.AlignInputWith(hirSchema); err != nil { + // Alignment failed. So, attempt alignment as expanded + // trace instead. 
+ if err := tr.AlignWith(airSchema); err != nil { + // Still failed, hence trace must be malformed in some way + if expected { + t.Errorf("Trace malformed (%s.accepts, line %d): [%s]", test, i+1, err) + } else { + t.Errorf("Trace malformed (%s.rejects, line %d): [%s]", test, i+1, err) + } + } else { + // Aligned as expanded trace + checkExpandedTrace(t, tr, airID, airSchema) + } + } else { + // Aligned as unexpanded trace. checkInputTrace(t, tr, hirID, hirSchema) checkInputTrace(t, tr, mirID, mirSchema) checkInputTrace(t, tr, airID, airSchema) - } else if airSchema.IsOutputTrace(tr) == nil { - // This is an already expanded input trace. Therefore, no need - // to perform expansion. - checkExpandedTrace(t, tr, airID, airSchema) - } else { - // Trace appears to be malformed. - err1 := airSchema.IsInputTrace(tr) - err2 := airSchema.IsOutputTrace(tr) - - if expected { - t.Errorf("Trace malformed (%s.accepts, line %d): [%s][%s]", test, i+1, err1, err2) - } else { - t.Errorf("Trace malformed (%s.rejects, line %d): [%s][%s]", test, i+1, err1, err2) - } } } } @@ -407,6 +407,9 @@ func checkInputTrace(t *testing.T, tr *table.ArrayTrace, id traceId, schema tabl // Check if err != nil { t.Error(err) + } else if err := etr.AlignWith(schema); err != nil { + // Alignment problem + t.Error(err) } else { checkExpandedTrace(t, etr, id, schema) } @@ -421,7 +424,7 @@ func checkExpandedTrace(t *testing.T, tr table.Trace, id traceId, schema table.S accepted := (err == nil) // Process what happened versus what was supposed to happen. 
if !accepted && id.expected { - //printTrace(tr) + //table.PrintTrace(tr) msg := fmt.Sprintf("Trace rejected incorrectly (%s, %s.accepts, line %d with spillage %d / padding %d): %s", id.ir, id.test, id.line, id.spillage, id.padding, err) t.Errorf(msg) diff --git a/testdata/norm_05.rejects b/testdata/norm_05.rejects index edd3c57..b9d1de5 100644 --- a/testdata/norm_05.rejects +++ b/testdata/norm_05.rejects @@ -1,2 +1,2 @@ -{ "A": [1], "(inv A)": [0] } -{ "A": [0], "(inv A)": [1] } +{ "A": [1], "(inv #0)": [0] } +{ "A": [0], "(inv #0)": [1] } diff --git a/testdata/norm_06.rejects b/testdata/norm_06.rejects index 22e5c07..b872c85 100644 --- a/testdata/norm_06.rejects +++ b/testdata/norm_06.rejects @@ -4,7 +4,7 @@ { "A": [1], "B": [0] } { "A": [0], "B": [1] } { "A": [1], "B": [1] } -{ "A": [0], "B": [1], "(inv (+ A B))": [0] } -{ "A": [1], "B": [0], "(inv (+ A B))": [0] } -{ "A": [1], "B": [1], "(inv (+ A B))": [0] } -{ "A": [-2], "B": [1], "(inv (+ A B))": [0] } +{ "A": [0], "B": [1], "(inv (+ #0 #1))": [0] } +{ "A": [1], "B": [0], "(inv (+ #0 #1))": [0] } +{ "A": [1], "B": [1], "(inv (+ #0 #1))": [0] } +{ "A": [-2], "B": [1], "(inv (+ #0 #1))": [0] } diff --git a/testdata/norm_07.rejects b/testdata/norm_07.rejects index 01c18cc..16db16b 100644 --- a/testdata/norm_07.rejects +++ b/testdata/norm_07.rejects @@ -3,8 +3,8 @@ { "A": [0] } { "A": [1] } { "A": [2] } -{ "A": [1], "(inv A)": [1] } -{ "A": [2], "(inv A)": [1] } -{ "A": [1], "(inv A)": [2] } -{ "A": [-2], "(inv A)": [-1] } -{ "A": [-1], "(inv A)": [1] } +{ "A": [1], "(inv #0)": [1] } +{ "A": [2], "(inv #0)": [1] } +{ "A": [1], "(inv #0)": [2] } +{ "A": [-2], "(inv #0)": [-1] } +{ "A": [-1], "(inv #0)": [1] }