diff --git a/pkg/air/eval.go b/pkg/air/eval.go
index 71d0fa9..b274767 100644
--- a/pkg/air/eval.go
+++ b/pkg/air/eval.go
@@ -2,14 +2,14 @@ package air
 import (
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/trace"
 )
 // EvalAt evaluates a column access at a given row in a trace, which returns the
 // value at that row of the column in question or nil if that row is
 // out-of-bounds.
-func (e *ColumnAccess) EvalAt(k int, tbl table.Trace) *fr.Element {
-	val := tbl.ColumnByIndex(e.Column).Get(k + e.Shift)
+func (e *ColumnAccess) EvalAt(k int, tr trace.Trace) *fr.Element {
+	val := tr.ColumnByIndex(e.Column).Get(k + e.Shift)
 	var clone fr.Element
 	// Clone original value
@@ -18,7 +18,7 @@ func (e *ColumnAccess) EvalAt(k int, tbl table.Trace) *fr.Element {
 // EvalAt evaluates a constant at a given row in a trace, which simply returns
 // that constant.
-func (e *Constant) EvalAt(k int, tbl table.Trace) *fr.Element {
+func (e *Constant) EvalAt(k int, tr trace.Trace) *fr.Element {
 	var clone fr.Element
 	// Clone original value
 	return clone.Set(e.Value)
@@ -26,37 +26,37 @@ func (e *Constant) EvalAt(k int, tbl table.Trace) *fr.Element {
 // EvalAt evaluates a sum at a given row in a trace by first evaluating all of
 // its arguments at that row.
-func (e *Add) EvalAt(k int, tbl table.Trace) *fr.Element {
+func (e *Add) EvalAt(k int, tr trace.Trace) *fr.Element {
 	fn := func(l *fr.Element, r *fr.Element) { l.Add(l, r) }
-	return evalExprsAt(k, tbl, e.Args, fn)
+	return evalExprsAt(k, tr, e.Args, fn)
 }
 // EvalAt evaluates a product at a given row in a trace by first evaluating all of
 // its arguments at that row.
-func (e *Mul) EvalAt(k int, tbl table.Trace) *fr.Element {
+func (e *Mul) EvalAt(k int, tr trace.Trace) *fr.Element {
 	fn := func(l *fr.Element, r *fr.Element) { l.Mul(l, r) }
-	return evalExprsAt(k, tbl, e.Args, fn)
+	return evalExprsAt(k, tr, e.Args, fn)
 }
 // EvalAt evaluates a subtraction at a given row in a trace by first evaluating all of
 // its arguments at that row.
-func (e *Sub) EvalAt(k int, tbl table.Trace) *fr.Element {
+func (e *Sub) EvalAt(k int, tr trace.Trace) *fr.Element {
 	fn := func(l *fr.Element, r *fr.Element) { l.Sub(l, r) }
-	return evalExprsAt(k, tbl, e.Args, fn)
+	return evalExprsAt(k, tr, e.Args, fn)
 }
 // EvalExprsAt evaluates all expressions in a given slice at a given row on the
 // table, and folds their results together using a combinator.
-func evalExprsAt(k int, tbl table.Trace, exprs []Expr, fn func(*fr.Element, *fr.Element)) *fr.Element {
+func evalExprsAt(k int, tr trace.Trace, exprs []Expr, fn func(*fr.Element, *fr.Element)) *fr.Element {
 	// Evaluate first argument
-	val := exprs[0].EvalAt(k, tbl)
+	val := exprs[0].EvalAt(k, tr)
 	if val == nil {
 		return nil
 	}
 	// Continue evaluating the rest
 	for i := 1; i < len(exprs); i++ {
-		ith := exprs[i].EvalAt(k, tbl)
+		ith := exprs[i].EvalAt(k, tr)
 		fn(val, ith)
 	}
diff --git a/pkg/air/expr.go b/pkg/air/expr.go
index 47c87d2..593d92f 100644
--- a/pkg/air/expr.go
+++ b/pkg/air/expr.go
@@ -2,7 +2,7 @@ package air
 import (
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/trace"
 	"github.com/consensys/go-corset/pkg/util"
 )
@@ -19,7 +19,7 @@ type Expr interface {
 	// "nil". An expression can be undefined for several reasons: firstly, if
 	// it accesses a row which does not exist (e.g. at index -1); secondly, if
 	// it accesses a column which does not exist.
-	EvalAt(int, table.Trace) *fr.Element
+	EvalAt(int, trace.Trace) *fr.Element
 	// String produces a string representing this as an S-Expression.
 	String() string
 }
diff --git a/pkg/air/gadgets/bits.go b/pkg/air/gadgets/bits.go
index 27ac0d5..0fb18d1 100644
--- a/pkg/air/gadgets/bits.go
+++ b/pkg/air/gadgets/bits.go
@@ -5,7 +5,7 @@ import (
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
 	"github.com/consensys/go-corset/pkg/air"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/schema/assignment"
 )
 // ApplyBinaryGadget adds a binarity constraint for a given column in the schema
@@ -13,7 +13,7 @@ import (
 // column X, this corresponds to the vanishing constraint X * (X-1) == 0.
 func ApplyBinaryGadget(column uint, schema *air.Schema) {
 	// Determine column name
-	name := schema.Column(column).Name()
+	name := schema.Columns().Nth(column).Name()
 	// Construct X
 	X := air.NewColumnAccess(column, 0)
 	// Construct X-1
@@ -27,7 +27,7 @@ func ApplyBinaryGadget(column uint, schema *air.Schema) {
 // ApplyBitwidthGadget ensures all values in a given column fit within a given
 // number of bits. This is implemented using a *byte decomposition* which adds
 // n columns and a vanishing constraint (where n*8 >= nbits).
-func ApplyBitwidthGadget(column uint, nbits uint, schema *air.Schema) {
+func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
 	if nbits%8 != 0 {
 		panic("asymmetric bitwidth constraints not yet supported")
 	} else if nbits == 0 {
@@ -37,44 +37,24 @@ func ApplyBitwidthGadget(column uint, nbits uint, schema *air.Schema) {
 	n := nbits / 8
 	es := make([]air.Expr, n)
 	fr256 := fr.NewElement(256)
-	name := schema.Column(column).Name()
+	name := schema.Columns().Nth(col).Name()
 	coefficient := fr.NewElement(1)
+	// Add decomposition assignment
+	index := schema.AddAssignment(assignment.NewByteDecomposition(name, n))
 	// Construct Columns
 	for i := uint(0); i < n; i++ {
-		// Determine name for the ith byte column
-		colName := fmt.Sprintf("%s:%d", name, i)
 		// Create Column + Constraint
-		colIndex := schema.AddColumn(colName, true)
-		es[i] = air.NewColumnAccess(colIndex, 0).Mul(air.NewConstCopy(&coefficient))
+		es[i] = air.NewColumnAccess(index+i, 0).Mul(air.NewConstCopy(&coefficient))
-		schema.AddRangeConstraint(colIndex, &fr256)
+		schema.AddRangeConstraint(index+i, &fr256)
 		// Update coefficient
 		coefficient.Mul(&coefficient, &fr256)
 	}
 	// Construct (X:0 * 1) + ... + (X:n * 2^n)
 	sum := &air.Add{Args: es}
 	// Construct X == (X:0 * 1) + ... + (X:n * 2^n)
-	X := air.NewColumnAccess(column, 0)
+	X := air.NewColumnAccess(col, 0)
 	eq := X.Equate(sum)
 	// Construct column name
 	schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), nil, eq)
-	// Finally, add the necessary byte decomposition computation.
-	schema.AddComputation(table.NewByteDecomposition(name, nbits))
-}
-
-// AddBitArray adds an array of n bit columns using a given prefix, including
-// the necessary binarity constraints.
-func AddBitArray(prefix string, count int, schema *air.Schema) []uint {
-	bits := make([]uint, count)
-
-	for i := 0; i < count; i++ {
-		// Construct bit column name
-		ith := fmt.Sprintf("%s:%d", prefix, i)
-		// Add (synthetic) column
-		bits[i] = schema.AddColumn(ith, true)
-		// Add binarity constraints (i.e. to enfoce that this column is a bit).
-		ApplyBinaryGadget(bits[i], schema)
-	}
-	//
-	return bits
 }
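For intuition, the vanishing constraint built by ApplyBitwidthGadget enforces X == (X:0 * 1) + (X:1 * 256) + ... over the n byte columns it introduces, each of which is range-constrained to [0, 256). Below is a minimal standalone sketch of that identity over plain uint64 values; the helper names are hypothetical and the real gadget works over fr.Element columns produced by assignment.NewByteDecomposition.

// Illustrates the byte-decomposition identity: x == sum(bytes[i] * 256^i),
// with every limb fitting the range constraint [0, 256).
package main

import "fmt"

func decomposeBytes(x uint64, n uint) []uint64 {
	bytes := make([]uint64, n)
	for i := uint(0); i < n; i++ {
		bytes[i] = x % 256 // each limb satisfies the range constraint
		x /= 256
	}
	return bytes
}

func recompose(bytes []uint64) uint64 {
	sum, coefficient := uint64(0), uint64(1)
	for _, b := range bytes {
		sum += b * coefficient
		coefficient *= 256 // mirrors coefficient.Mul(&coefficient, &fr256)
	}
	return sum
}

func main() {
	x := uint64(0xABCD) // fits in 16 bits, so n = 2 byte columns suffice
	bytes := decomposeBytes(x, 2)
	fmt.Println(bytes, recompose(bytes) == x) // [205 171] true
}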
diff --git a/pkg/air/gadgets/column_sort.go b/pkg/air/gadgets/column_sort.go
index e614773..46ad729 100644
--- a/pkg/air/gadgets/column_sort.go
+++ b/pkg/air/gadgets/column_sort.go
@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"github.com/consensys/go-corset/pkg/air"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/schema/assignment"
 )
 // ApplyColumnSortGadget adds sorting constraints for a column where the
@@ -18,13 +18,13 @@ import (
 // This gadget does not attempt to sort the column data during trace expansion,
 // and assumes the data either comes sorted or is sorted by some other
 // computation.
-func ApplyColumnSortGadget(column uint, sign bool, bitwidth uint, schema *air.Schema) {
+func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schema) {
 	var deltaName string
 	// Determine column name
-	name := schema.Column(column).Name()
+	name := schema.Columns().Nth(col).Name()
 	// Configure computation
-	Xk := air.NewColumnAccess(column, 0)
-	Xkm1 := air.NewColumnAccess(column, -1)
+	Xk := air.NewColumnAccess(col, 0)
+	Xkm1 := air.NewColumnAccess(col, -1)
 	// Account for sign
 	var Xdiff air.Expr
 	if sign {
@@ -34,10 +34,8 @@ func ApplyColumnSortGadget(column uint, sign bool, bitwidth uint, schema *air.Sc
 		Xdiff = Xkm1.Sub(Xk)
 		deltaName = fmt.Sprintf("-%s", name)
 	}
-	// Add delta column
-	deltaIndex := schema.AddColumn(deltaName, true)
-	// Add diff computation
-	schema.AddComputation(table.NewComputedColumn(deltaName, Xdiff))
+	// Add delta assignment
+	deltaIndex := schema.AddAssignment(assignment.NewComputedColumn(deltaName, Xdiff))
 	// Add necessary bitwidth constraints
 	ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
 	// Configure constraint: Delta[k] = X[k] - X[k-1]
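The column sort gadget reduces sortedness to a bitwidth check: for an ascending column, Delta[k] = X[k] - X[k-1] must decompose into the given number of bytes, which fails whenever the delta is negative (it would wrap around the field) or too large. A hypothetical standalone illustration of that idea over uint64 values, not the AIR constraint itself:

// A column is ascending iff every consecutive delta is non-negative
// and fits within the declared bitwidth.
package main

import "fmt"

func deltasFit(column []uint64, bitwidth uint) bool {
	bound := uint64(1) << bitwidth
	for k := 1; k < len(column); k++ {
		if column[k] < column[k-1] {
			return false // negative delta: would wrap around in the field
		}
		if column[k]-column[k-1] >= bound {
			return false // delta too large for the byte decomposition
		}
	}
	return true
}

func main() {
	fmt.Println(deltasFit([]uint64{1, 3, 8}, 8)) // true
	fmt.Println(deltasFit([]uint64{5, 2, 9}, 8)) // false: not ascending
}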
diff --git a/pkg/air/gadgets/lexicographic_sort.go b/pkg/air/gadgets/lexicographic_sort.go
index 6ef255f..70fc342 100644
--- a/pkg/air/gadgets/lexicographic_sort.go
+++ b/pkg/air/gadgets/lexicographic_sort.go
@@ -4,9 +4,8 @@ import (
 	"fmt"
 	"strings"
-	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
 	"github.com/consensys/go-corset/pkg/air"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/schema/assignment"
 )
 // ApplyLexicographicSortingGadget adds sorting constraints for a sequence of one
@@ -35,15 +34,13 @@ func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint
 	// Construct a unique prefix for this sort.
 	prefix := constructLexicographicSortingPrefix(columns, signs, schema)
 	// Add trace computation
-	schema.AddComputation(&lexicographicSortExpander{prefix, columns, signs, bitwidth})
-	deltaName := fmt.Sprintf("%s:delta", prefix)
+	deltaIndex := schema.AddAssignment(assignment.NewLexicographicSort(prefix, columns, signs, bitwidth))
 	// Construct selector bits.
-	bits := addLexicographicSelectorBits(prefix, columns, schema)
-	// Add delta column
-	deltaIndex := schema.AddColumn(deltaName, true)
+	addLexicographicSelectorBits(prefix, deltaIndex, columns, schema)
 	// Construct delta terms
-	constraint := constructLexicographicDeltaConstraint(deltaIndex, bits, columns, signs)
+	constraint := constructLexicographicDeltaConstraint(deltaIndex, columns, signs)
 	// Add delta constraint
+	deltaName := fmt.Sprintf("%s:delta", prefix)
 	schema.AddVanishingConstraint(deltaName, nil, constraint)
 	// Add necessary bitwidth constraints
 	ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
@@ -56,7 +53,7 @@ func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *a
 	var id strings.Builder
 	// Concatenate column names with their signs.
 	for i := 0; i < len(columns); i++ {
-		ith := schema.Column(columns[i])
+		ith := schema.Columns().Nth(columns[i])
 		id.WriteString(ith.Name())
 		if signs[i] {
@@ -76,23 +73,28 @@ func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *a
 //
 // NOTE: this implementation differs from the original corset which used an
 // additional "Eq" bit to help ensure at most one selector bit was enabled.
-func addLexicographicSelectorBits(prefix string, columns []uint, schema *air.Schema) []uint {
-	ncols := len(columns)
-	// Add bits and their binary constraints.
-	bits := AddBitArray(prefix, ncols, schema)
+func addLexicographicSelectorBits(prefix string, deltaIndex uint, columns []uint, schema *air.Schema) {
+	ncols := uint(len(columns))
+	// Calculate column index of first selector bit
+	bitIndex := deltaIndex + 1
+	// Add binary constraints for selector bits
+	for i := uint(0); i < ncols; i++ {
+		// Add binarity constraints (i.e. to enforce that this column is a bit).
+		ApplyBinaryGadget(bitIndex+i, schema)
+	}
 	// Apply constraints to ensure at most one is set.
 	terms := make([]air.Expr, ncols)
-	for i := 0; i < ncols; i++ {
-		terms[i] = air.NewColumnAccess(bits[i], 0)
+	for i := uint(0); i < ncols; i++ {
+		terms[i] = air.NewColumnAccess(bitIndex+i, 0)
 		pterms := make([]air.Expr, i+1)
 		qterms := make([]air.Expr, i)
-		for j := 0; j < i; j++ {
-			pterms[j] = air.NewColumnAccess(bits[j], 0)
-			qterms[j] = air.NewColumnAccess(bits[j], 0)
+		for j := uint(0); j < i; j++ {
+			pterms[j] = air.NewColumnAccess(bitIndex+j, 0)
+			qterms[j] = air.NewColumnAccess(bitIndex+j, 0)
 		}
 		// (∀j<=i.Bj=0) ==> C[k]=C[k-1]
-		pterms[i] = air.NewColumnAccess(bits[i], 0)
+		pterms[i] = air.NewColumnAccess(bitIndex+i, 0)
 		pDiff := air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1))
 		pName := fmt.Sprintf("%s:%d:a", prefix, i)
 		schema.AddVanishingConstraint(pName, nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
@@ -100,7 +102,7 @@ func addLexicographicSelectorBits(prefix string, columns []uint, schema *air.Sch
 		qDiff := Normalise(air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1)), schema)
 		qName := fmt.Sprintf("%s:%d:b", prefix, i)
 		// bi = 0 || C[k]≠C[k-1]
-		constraint := air.NewColumnAccess(bits[i], 0).Mul(air.NewConst64(1).Sub(qDiff))
+		constraint := air.NewColumnAccess(bitIndex+i, 0).Mul(air.NewConst64(1).Sub(qDiff))
 		if i != 0 { // (∃j
 (1 == e/e)
 	l_name := fmt.Sprintf("[%s <=]", ie.String())
-	tbl.AddVanishingConstraint(l_name, nil, e_implies_one_e_e)
+	schema.AddVanishingConstraint(l_name, nil, e_implies_one_e_e)
 	// Ensure (e/e != 0) ==> (1 == e/e)
 	r_name := fmt.Sprintf("[%s =>]", ie.String())
-	tbl.AddVanishingConstraint(r_name, nil, inv_e_implies_one_e_e)
+	schema.AddVanishingConstraint(r_name, nil, inv_e_implies_one_e_e)
 	// Done
 	return air.NewColumnAccess(index, 0)
 }
@@ -65,7 +66,7 @@ type Inverse struct{ Expr air.Expr }
 // EvalAt computes the multiplicative inverse of a given expression at a given
 // row in the table.
-func (e *Inverse) EvalAt(k int, tbl table.Trace) *fr.Element {
+func (e *Inverse) EvalAt(k int, tbl tr.Trace) *fr.Element {
 	inv := new(fr.Element)
 	val := e.Expr.EvalAt(k, tbl)
 	// Go syntax huh?
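The normalisation gadget above relies on the standard inverse trick: the prover supplies e⁻¹ in a computed column, and the two vanishing constraints force e·(e⁻¹) to equal exactly 1 when e is non-zero and 0 otherwise. A minimal sketch of the underlying field identity using gnark-crypto, where Inverse(0) is defined to return 0 (the function name norm is hypothetical):

// norm returns 0 if x == 0 and 1 otherwise, computed as x * (1/x).
package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
)

func norm(x fr.Element) fr.Element {
	var inv, out fr.Element
	inv.Inverse(&x) // gnark-crypto defines Inverse(0) as 0
	out.Mul(&x, &inv)
	return out
}

func main() {
	a, b := norm(fr.NewElement(0)), norm(fr.NewElement(5))
	fmt.Println(a.IsZero(), b.IsOne()) // true true
}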
diff --git a/pkg/air/schema.go b/pkg/air/schema.go
index 577c043..73fada3 100644
--- a/pkg/air/schema.go
+++ b/pkg/air/schema.go
@@ -2,213 +2,122 @@ package air
 import (
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/schema/assignment"
+	"github.com/consensys/go-corset/pkg/schema/constraint"
 	"github.com/consensys/go-corset/pkg/util"
 )
 // DataColumn captures the essence of a data column at AIR level.
-type DataColumn = *table.DataColumn[*table.FieldType]
-
-// VanishingConstraint captures the essence of a vanishing constraint at the HIR
-// level. A vanishing constraint is a row constraint which must evaluate to
-// zero.
-type VanishingConstraint = *table.RowConstraint[table.ZeroTest[Expr]]
+type DataColumn = *assignment.DataColumn
 // PropertyAssertion captures the notion of an arbitrary property which should
 // hold for all acceptable traces. However, such a property is not enforced by
 // the prover.
-type PropertyAssertion = *table.PropertyAssertion[table.ZeroTest[table.Evaluable]]
-
-// Permutation captures the notion of a simple column permutation at the AIR
-// level.
-type Permutation = *table.Permutation
+type PropertyAssertion = *schema.PropertyAssertion[constraint.ZeroTest[schema.Evaluable]]
 // Schema for AIR traces which is parameterised on a notion of computation as
 // permissible in computed columns.
 type Schema struct {
-	// The data columns of this schema.
-	dataColumns []DataColumn
-	// The permutation columns of this schema.
-	permutations []Permutation
-	// The vanishing constraints of this schema.
-	vanishing []VanishingConstraint
-	// The range constraints of this schema.
-	ranges []*table.RangeConstraint
+	// The set of data columns corresponding to the inputs of this schema.
+	inputs []schema.Declaration
+	// Assignments defines the set of column declarations whose trace values are
+	// computed from the inputs.
+	assignments []schema.Assignment
+	// The constraints of this schema. A constraint is either a vanishing
+	// constraint, a permutation constraint, a lookup constraint or a range
+	// constraint.
+	constraints []schema.Constraint
 	// Property assertions.
 	assertions []PropertyAssertion
-	// The computations used to construct traces which adhere to
-	// this schema. Such computations are not expressible at the
-	// prover level and, hence, can only be used to pre-process
-	// traces prior to prove generation.
-	computations []table.TraceComputation
 }
 // EmptySchema is used to construct a fresh schema onto which new columns and
 // constraints will be added.
-func EmptySchema[C table.Evaluable]() *Schema {
+func EmptySchema[C schema.Evaluable]() *Schema {
 	p := new(Schema)
-	p.dataColumns = make([]DataColumn, 0)
-	p.permutations = make([]Permutation, 0)
-	p.vanishing = make([]VanishingConstraint, 0)
-	p.ranges = make([]*table.RangeConstraint, 0)
+	p.inputs = make([]schema.Declaration, 0)
+	p.assignments = make([]schema.Assignment, 0)
+	p.constraints = make([]schema.Constraint, 0)
 	p.assertions = make([]PropertyAssertion, 0)
-	p.computations = make([]table.TraceComputation, 0)
 	// Done
 	return p
}
-// Width returns the number of column groups in this schema.
-func (p *Schema) Width() uint {
-	return uint(len(p.dataColumns))
-}
-
-// ColumnGroup returns information about the ith column group in this schema.
-func (p *Schema) ColumnGroup(i uint) table.ColumnGroup {
-	return p.dataColumns[i]
-}
-
-// Column returns information about the ith column in this schema.
-func (p *Schema) Column(i uint) table.ColumnSchema {
-	return p.dataColumns[i]
-}
-
-// Size returns the number of declarations in this schema.
-func (p *Schema) Size() int {
-	return len(p.dataColumns) + len(p.permutations) + len(p.vanishing) +
-		len(p.ranges) + len(p.assertions) + len(p.computations)
-}
-
-// GetDeclaration returns the ith declaration in this schema.
-func (p *Schema) GetDeclaration(index int) table.Declaration {
-	ith := util.FlatArrayIndexOf_6(index, p.dataColumns, p.permutations,
-		p.vanishing, p.ranges, p.assertions, p.computations)
-	return ith.(table.Declaration)
-}
-
-// Columns returns the set of data columns.
-func (p *Schema) Columns() []DataColumn {
-	return p.dataColumns
-}
-
-// HasColumn checks whether a given schema has a given column.
-func (p *Schema) HasColumn(name string) bool {
-	for _, c := range p.dataColumns {
-		if c.Name() == name {
-			return true
-		}
-	}
-
-	return false
-}
-
-// ColumnIndex determines the column index for a given column in this schema, or
-// returns false indicating an error.
-func (p *Schema) ColumnIndex(name string) (uint, bool) {
-	for i, c := range p.dataColumns {
-		if c.Name() == name {
-			return uint(i), true
-		}
-	}
-
-	return 0, false
-}
-
-// RequiredSpillage returns the minimum amount of spillage required to ensure
-// valid traces are accepted in the presence of arbitrary padding. Spillage can
-// only arise from computations as this is where values outside of the user's
-// control are determined.
-func (p *Schema) RequiredSpillage() uint {
-	// Ensures always at least one row of spillage (referred to as the "initial
-	// padding row")
-	mx := uint(1)
-	// Determine if any more spillage required
-	for _, c := range p.computations {
-		mx = max(mx, c.RequiredSpillage())
-	}
-
-	return mx
-}
-
-// AddColumn appends a new data column which is either synthetic or
-// not. A synthetic column is one which has been introduced by the
-// process of lowering from HIR / MIR to AIR. That is, it is not a
-// column which was original specified by the user. Columns also support a
-// "padding sign", which indicates whether padding should occur at the front
-// (positive sign) or the back (negative sign).
-func (p *Schema) AddColumn(name string, synthetic bool) uint {
+// AddColumn appends a new data column whose values must be provided by the
+// user.
+func (p *Schema) AddColumn(name string, datatype schema.Type) uint {
 	// NOTE: the air level has no ability to enforce the type specified for a
 	// given column.
-	p.dataColumns = append(p.dataColumns, table.NewDataColumn(name, &table.FieldType{}, synthetic))
+	p.inputs = append(p.inputs, assignment.NewDataColumn(name, datatype))
 	// Calculate column index
-	return uint(len(p.dataColumns) - 1)
+	return uint(len(p.inputs) - 1)
 }
-// AddComputation appends a new computation to be used during trace
-// expansion for this schema.
-func (p *Schema) AddComputation(c table.TraceComputation) {
-	p.computations = append(p.computations, c)
+// AddAssignment appends a new assignment (i.e. set of computed columns) to be
+// used during trace expansion for this schema. Computed columns are introduced
+// by the process of lowering from HIR / MIR to AIR.
+func (p *Schema) AddAssignment(c schema.Assignment) uint {
+	index := p.Columns().Count()
+	p.assignments = append(p.assignments, c)
+
+	return index
 }
 // AddPermutationConstraint appends a new permutation constraint which
 // ensures that one column is a permutation of another.
 func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) {
-	p.permutations = append(p.permutations, table.NewPermutation(targets, sources))
+	p.constraints = append(p.constraints, constraint.NewPermutationConstraint(targets, sources))
 }
 // AddVanishingConstraint appends a new vanishing constraint.
 func (p *Schema) AddVanishingConstraint(handle string, domain *int, expr Expr) {
-	p.vanishing = append(p.vanishing, table.NewRowConstraint(handle, domain, table.ZeroTest[Expr]{Expr: expr}))
+	p.constraints = append(p.constraints,
+		constraint.NewVanishingConstraint(handle, domain, constraint.ZeroTest[Expr]{Expr: expr}))
 }
 // AddRangeConstraint appends a new range constraint.
 func (p *Schema) AddRangeConstraint(column uint, bound *fr.Element) {
-	p.ranges = append(p.ranges, table.NewRangeConstraint(column, bound))
+	p.constraints = append(p.constraints, constraint.NewRangeConstraint(column, bound))
 }
-// Accepts determines whether this schema will accept a given trace. That
-// is, whether or not the given trace adheres to the schema. A trace can fail
-// to adhere to the schema for a variety of reasons, such as having a constraint
-// which does not hold.
-func (p *Schema) Accepts(trace table.Trace) error {
-	// Check vanishing constraints
-	err := table.ConstraintsAcceptTrace(trace, p.vanishing)
-	if err != nil {
-		return err
-	}
-	// Check permutation constraints
-	err = table.ConstraintsAcceptTrace(trace, p.permutations)
-	if err != nil {
-		return err
-	}
-	// Check range constraints
-	err = table.ConstraintsAcceptTrace(trace, p.ranges)
-	if err != nil {
-		return err
-	}
-	// Check computations
-	err = table.ConstraintsAcceptTrace(trace, p.computations)
-	if err != nil {
-		return err
-	}
-	// TODO: handle assertions. These cannot be checked in the same way as for
-	// other constraints at the AIR level because the prover does not support
-	// them.
-
-	return nil
+// ============================================================================
+// Schema Interface
+// ============================================================================
+
+// Inputs returns an array over the input declarations of this schema. That is,
+// the subset of declarations whose trace values must be provided by the user.
+func (p *Schema) Inputs() util.Iterator[schema.Declaration] {
+	return util.NewArrayIterator(p.inputs)
 }
-// ExpandTrace expands a given trace according to this schema. More
-// specifically, that means computing the actual values for any computed
-// columns. Observe that computed columns have to be computed in the correct
-// order.
-func (p *Schema) ExpandTrace(tr table.Trace) error {
-	// Execute all computations
-	for _, c := range p.computations {
-		err := c.ExpandTrace(tr)
-		if err != nil {
-			return err
-		}
-	}
-	// Done
-	return nil
+// Assignments returns an array over the assignments of this schema. That
+// is, the subset of declarations whose trace values can be computed from
+// the inputs.
+func (p *Schema) Assignments() util.Iterator[schema.Assignment] {
+	return util.NewArrayIterator(p.assignments)
+}
+
+// Columns returns an array over the underlying columns of this schema.
+// Specifically, the index of a column in this array is its column index.
+func (p *Schema) Columns() util.Iterator[schema.Column] {
+	is := util.NewFlattenIterator[schema.Declaration, schema.Column](p.Inputs(),
+		func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() })
+	ps := util.NewFlattenIterator[schema.Assignment, schema.Column](p.Assignments(),
+		func(d schema.Assignment) util.Iterator[schema.Column] { return d.Columns() })
+	//
+	return is.Append(ps)
+}
+
+// Constraints returns an array over the underlying constraints of this
+// schema.
+func (p *Schema) Constraints() util.Iterator[schema.Constraint] {
+	return util.NewArrayIterator(p.constraints)
+}
+
+// Declarations returns an array over the column declarations of this
+// schema.
+func (p *Schema) Declarations() util.Iterator[schema.Declaration] {
+	ps := util.NewCastIterator[schema.Assignment, schema.Declaration](p.Assignments())
+	return p.Inputs().Append(ps)
 }
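Under the new schema interface, a column's global index is simply its position in the concatenation of the input columns followed by the columns of each assignment, in declaration order. That is why AddAssignment can return p.Columns().Count() as the index of the first column it introduces. A toy illustration of that numbering, with plain slices standing in for util.Iterator and hypothetical column names:

package main

import "fmt"

func main() {
	inputs := []string{"X", "Y"}                      // user-provided columns
	assignments := [][]string{{"X:0", "X:1"}, {"+X"}} // computed column groups
	index := uint(0)
	for _, name := range inputs {
		fmt.Println(index, name)
		index++
	}
	for _, group := range assignments {
		for _, name := range group {
			fmt.Println(index, name) // AddAssignment returned the first of these
			index++
		}
	}
}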
diff --git a/pkg/binfile/computation.go b/pkg/binfile/computation.go
index 857a8f1..b7bf130 100644
--- a/pkg/binfile/computation.go
+++ b/pkg/binfile/computation.go
@@ -2,6 +2,7 @@ package binfile
 import (
 	"github.com/consensys/go-corset/pkg/hir"
+	sc "github.com/consensys/go-corset/pkg/schema"
 )
 type jsonComputationSet struct {
@@ -25,8 +26,17 @@ type jsonSortedComputation struct {
 func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
 	for _, c := range e.Computations {
 		if c.Sorted != nil {
-			targets := asColumnRefs(c.Sorted.Tos)
+			refs := asColumnRefs(c.Sorted.Tos)
 			sources := asColumnRefs(c.Sorted.Froms)
+			// Convert target refs into columns
+			targets := make([]sc.Column, len(refs))
+
+			for i, r := range refs {
+				// TODO: correctly determine type
+				ith := &sc.FieldType{}
+				targets[i] = sc.NewColumn(r, ith)
+			}
+			// Finally, add the permutation columns
 			schema.AddPermutationColumns(targets, c.Sorted.Signs, sources)
 		}
 	}
diff --git a/pkg/binfile/constraint_set.go b/pkg/binfile/constraint_set.go
index 50fdad2..910b621 100644
--- a/pkg/binfile/constraint_set.go
+++ b/pkg/binfile/constraint_set.go
@@ -106,7 +106,12 @@ func HirSchemaFromJson(bytes []byte) (schema *hir.Schema, err error) {
 			fmt.Printf("COLUMN: %s\n", c.Handle)
 			panic("invalid JSON column configuration")
 		} else {
-			schema.AddDataColumn(c.Handle, c.Type.toHir(c.MustProve))
+			t := c.Type.toHir()
+			schema.AddDataColumn(c.Handle, t)
+			// Check whether a type constraint is required or not.
+			if c.MustProve {
+				schema.AddTypeConstraint(c.Handle, t)
+			}
 		}
 	}
 	// Add constraints
diff --git a/pkg/binfile/type.go b/pkg/binfile/type.go
index 35a5dcb..899a13a 100644
--- a/pkg/binfile/type.go
+++ b/pkg/binfile/type.go
@@ -3,7 +3,7 @@ package binfile
 import (
 	"fmt"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/schema"
 )
 type jsonType struct {
@@ -21,16 +21,16 @@ type jsonType struct {
 // Translation
 // =============================================================================
-func (e *jsonType) toHir(checked bool) table.Type {
+func (e *jsonType) toHir() schema.Type {
 	// Check whether magma is string
 	if str, ok := e.Magma.(string); ok {
 		switch str {
 		case "Native":
-			return &table.FieldType{}
+			return &schema.FieldType{}
 		case "Byte":
-			return table.NewUintType(8, checked)
+			return schema.NewUintType(8)
 		case "Binary":
-			return table.NewUintType(1, checked)
+			return schema.NewUintType(1)
 		default:
 			panic(fmt.Sprintf("Unknown JSON type encountered: %s:%s", e.Magma, e.Conditioning))
 		}
@@ -39,7 +39,7 @@ func (e *jsonType) toHir(checked bool) table.Type {
 	if intMap, ok := e.Magma.(map[string]any); ok {
 		if val, isInt := intMap["Integer"]; isInt {
 			nbits := uint(val.(float64))
-			return table.NewUintType(nbits, checked)
+			return schema.NewUintType(nbits)
 		}
 	}
 	// Fail
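jsonType.toHir() now returns a plain schema.Type, with the MustProve flag handled separately via AddTypeConstraint. A hypothetical standalone rendering of the Magma-to-type mapping shown above, with strings standing in for the real schema.Type values:

package main

import "fmt"

func toHir(magma any) string {
	if s, ok := magma.(string); ok {
		switch s {
		case "Native":
			return "field"
		case "Byte":
			return "u8"
		case "Binary":
			return "u1"
		}
	}
	if m, ok := magma.(map[string]any); ok {
		if v, ok := m["Integer"]; ok {
			return fmt.Sprintf("u%d", uint(v.(float64)))
		}
	}
	panic("unknown JSON type")
}

func main() {
	fmt.Println(toHir("Byte"), toHir(map[string]any{"Integer": float64(16)})) // u8 u16
}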
diff --git a/pkg/cmd/check.go b/pkg/cmd/check.go
index 884800d..e3faa81 100644
--- a/pkg/cmd/check.go
+++ b/pkg/cmd/check.go
@@ -5,7 +5,8 @@ import (
 	"os"
 	"github.com/consensys/go-corset/pkg/hir"
-	"github.com/consensys/go-corset/pkg/table"
+	sc "github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/trace"
 	"github.com/consensys/go-corset/pkg/util"
 	"github.com/spf13/cobra"
 )
@@ -18,7 +19,7 @@ var checkCmd = &cobra.Command{
 	Traces can be given either as JSON or binary lt files.
 	Constraints can be given either as lisp or bin files.`,
 	Run: func(cmd *cobra.Command, args []string) {
-		var trace *table.ArrayTrace
+		var trace *trace.ArrayTrace
 		var hirSchema *hir.Schema
 		var cfg checkConfig
@@ -70,7 +71,7 @@ type checkConfig struct {
 // Check a given trace is consistently accepted (or rejected) at the different
 // IR levels.
-func checkTraceWithLowering(tr *table.ArrayTrace, schema *hir.Schema, cfg checkConfig) {
+func checkTraceWithLowering(tr *trace.ArrayTrace, schema *hir.Schema, cfg checkConfig) {
 	if !cfg.hir && !cfg.mir && !cfg.air {
 		// Process together
 		checkTraceWithLoweringDefault(tr, schema, cfg)
@@ -90,7 +91,7 @@ func checkTraceWithLowering(tr *table.ArrayTrace, schema *hir.Schema, cfg checkC
 	}
 }
-func checkTraceWithLoweringHir(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) {
+func checkTraceWithLoweringHir(tr *trace.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) {
 	trHIR, errHIR := checkTrace(tr, hirSchema, cfg)
 	//
 	if errHIR != nil {
@@ -99,7 +100,7 @@ func checkTraceWithLoweringHir(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg
 	}
 }
-func checkTraceWithLoweringMir(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) {
+func checkTraceWithLoweringMir(tr *trace.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) {
 	// Lower HIR => MIR
 	mirSchema := hirSchema.LowerToMir()
 	// Check trace
@@ -111,7 +112,7 @@ func checkTraceWithLoweringMir(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg
 	}
 }
-func checkTraceWithLoweringAir(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) {
+func checkTraceWithLoweringAir(tr *trace.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) {
 	// Lower HIR => MIR
 	mirSchema := hirSchema.LowerToMir()
 	// Lower MIR => AIR
@@ -126,7 +127,7 @@ func checkTraceWithLoweringAir(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg
 // The default check allows one to compare all levels against each other and
 // look for any discrepancies.
-func checkTraceWithLoweringDefault(tr *table.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) {
+func checkTraceWithLoweringDefault(tr *trace.ArrayTrace, hirSchema *hir.Schema, cfg checkConfig) {
 	// Lower HIR => MIR
 	mirSchema := hirSchema.LowerToMir()
 	// Lower MIR => AIR
@@ -153,7 +154,7 @@ func checkTraceWithLoweringDefault(tr *table.ArrayTrace, hirSchema *hir.Schema,
 	}
 }
-func checkTrace(tr *table.ArrayTrace, schema table.Schema, cfg checkConfig) (table.Trace, error) {
+func checkTrace(tr *trace.ArrayTrace, schema sc.Schema, cfg checkConfig) (trace.Trace, error) {
 	if cfg.expand {
 		// Clone to prevent interference with subsequent checks
 		tr = tr.Clone()
@@ -163,25 +164,25 @@ func checkTrace(tr *table.ArrayTrace, schema table.Schema, cfg checkConfig) (tab
 			tr.Pad(uint(cfg.spillage))
 		} else {
 			// Apply default inferred spillage
-			tr.Pad(schema.RequiredSpillage())
+			tr.Pad(sc.RequiredSpillage(schema))
 		}
 		// Perform Input Alignment
-		if err := tr.AlignInputWith(schema); err != nil {
+		if err := sc.AlignInputs(tr, schema); err != nil {
 			return tr, err
 		}
 		// Expand trace
-		if err := schema.ExpandTrace(tr); err != nil {
+		if err := sc.ExpandTrace(schema, tr); err != nil {
 			return tr, err
 		}
 	}
 	// Perform Alignment
-	if err := tr.AlignWith(schema); err != nil {
+	if err := sc.Align(tr, schema); err != nil {
 		return tr, err
 	}
 	// Check whether padding requested
 	if cfg.padding.Left == 0 && cfg.padding.Right == 0 {
 		// No padding requested. Therefore, we can avoid a clone in this case.
-		return tr, schema.Accepts(tr)
+		return tr, sc.Accepts(schema, tr)
 	}
 	// Apply padding
 	for n := cfg.padding.Left; n <= cfg.padding.Right; n++ {
@@ -190,7 +191,7 @@ func checkTrace(tr *table.ArrayTrace, schema table.Schema, cfg checkConfig) (tab
 		// Apply padding
 		ptr.Pad(n)
 		// Check whether accepted or not.
-		if err := schema.Accepts(ptr); err != nil {
+		if err := sc.Accepts(schema, ptr); err != nil {
 			return ptr, err
 		}
 	}
@@ -206,9 +207,9 @@ func toErrorString(err error) string {
 	return err.Error()
 }
-func reportError(ir string, tr table.Trace, err error, cfg checkConfig) {
+func reportError(ir string, tr trace.Trace, err error, cfg checkConfig) {
 	if cfg.report {
-		table.PrintTrace(tr)
+		trace.PrintTrace(tr)
 	}
 	if err != nil {
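checkTrace sweeps every padding level in [cfg.padding.Left, cfg.padding.Right], cloning the trace each time so one padding run cannot interfere with the next. A minimal sketch of that loop, with hypothetical stand-ins for the real trace and schema packages:

package main

import "fmt"

type Trace struct{ rows []int }

func (t *Trace) Clone() *Trace { return &Trace{append([]int(nil), t.rows...)} }
func (t *Trace) Pad(n uint)    { t.rows = append(make([]int, n), t.rows...) }

// accepts stands in for sc.Accepts(schema, tr).
func accepts(t *Trace) error { return nil }

func main() {
	tr := &Trace{rows: []int{1, 2, 3}}
	for n := uint(0); n <= 2; n++ {
		ptr := tr.Clone()
		ptr.Pad(n)
		if err := accepts(ptr); err != nil {
			fmt.Println("rejected at padding", n, err)
			return
		}
	}
	fmt.Println("accepted at all padding levels")
}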
diff --git a/pkg/cmd/compute.go b/pkg/cmd/compute.go
index 20be4f9..fec92c5 100644
--- a/pkg/cmd/compute.go
+++ b/pkg/cmd/compute.go
@@ -28,12 +28,13 @@ var computeCmd = &cobra.Command{
 			panic(err)
 		}
 		// Print columns
-		for _, c := range schema.Columns() {
-			fmt.Printf("column %s : %s\n", c.Name(), c.Type)
+		for i := schema.Columns(); i.HasNext(); {
+			ith := i.Next()
+			fmt.Printf("column %s : %s\n", ith.Name(), ith.Type())
 		}
 		// Print constraints
-		for _, c := range schema.Constraints() {
-			fmt.Println(c)
+		for i := schema.Constraints(); i.HasNext(); {
+			fmt.Println(i.Next())
 		}
 	},
diff --git a/pkg/cmd/debug.go b/pkg/cmd/debug.go
index 81a6dcd..8e529f7 100644
--- a/pkg/cmd/debug.go
+++ b/pkg/cmd/debug.go
@@ -4,10 +4,7 @@ import (
 	"fmt"
 	"os"
-	"github.com/consensys/go-corset/pkg/air"
-	"github.com/consensys/go-corset/pkg/hir"
-	"github.com/consensys/go-corset/pkg/mir"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/schema"
 	"github.com/spf13/cobra"
 )
@@ -44,85 +41,12 @@ var debugCmd = &cobra.Command{
 }
 // Print out all declarations included in a given
-func printSchema(schema table.Schema, stats bool) {
-	dataColumns := 0
-	permutations := 0
-	vanishing := 0
-	ranges := 0
-	assertions := 0
-	computations := 0
-	// Print declarations in order of appearance.
-	for i := 0; i < schema.Size(); i++ {
-		ith := schema.GetDeclaration(i)
-		fmt.Println(ith.String())
-		// Count stats
-		if isDataColumn(ith) {
-			dataColumns++
-		} else if isPermutation(ith) {
-			permutations++
-		} else if isVanishing(ith) {
-			vanishing++
-		} else if isRange(ith) {
-			ranges++
-		} else {
-			computations++
-		}
-	}
-	//
-	if stats {
-		fmt.Println("--")
-		fmt.Printf("%d column(s), %d permutation(s), %d constraint(s), %d range(s), %d assertion(s) and %d computation(s).\n",
-			dataColumns, permutations, vanishing, ranges, assertions, computations)
-	}
-}
-
-func isDataColumn(d table.Declaration) bool {
-	if _, ok := d.(air.DataColumn); ok {
-		return true
-	} else if _, ok := d.(mir.DataColumn); ok {
-		return true
-	} else if _, ok := d.(hir.DataColumn); ok {
-		return true
-	}
-
-	return false
-}
-
-func isPermutation(d table.Declaration) bool {
-	if _, ok := d.(air.Permutation); ok {
-		return true
-	} else if _, ok := d.(mir.Permutation); ok {
-		return true
-	} else if _, ok := d.(hir.Permutation); ok {
-		return true
-	}
-
-	return false
-}
-
-func isVanishing(d table.Declaration) bool {
-	if _, ok := d.(air.VanishingConstraint); ok {
-		return true
-	} else if _, ok := d.(mir.VanishingConstraint); ok {
-		return true
-	} else if _, ok := d.(hir.VanishingConstraint); ok {
-		return true
-	}
-
-	return false
-}
-
-func isRange(d table.Declaration) bool {
-	if _, ok := d.(*table.RangeConstraint); ok {
-		return true
-	}
-
-	return false
+func printSchema(schema schema.Schema, stats bool) {
+	panic("todo")
 }
 func init() {
 	rootCmd.AddCommand(debugCmd)
-	debugCmd.Flags().BoolP("stats", "s", false, "Report statistics")
 	debugCmd.Flags().Bool("hir", false, "Print constraints at HIR level")
 	debugCmd.Flags().Bool("mir", false, "Print constraints at MIR level")
 	debugCmd.Flags().Bool("air", false, "Print constraints at AIR level")
diff --git a/pkg/cmd/util.go b/pkg/cmd/util.go
index 6f30b62..719255b 100644
--- a/pkg/cmd/util.go
+++ b/pkg/cmd/util.go
@@ -9,7 +9,7 @@ import (
 	"github.com/consensys/go-corset/pkg/binfile"
 	"github.com/consensys/go-corset/pkg/hir"
 	"github.com/consensys/go-corset/pkg/sexp"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/trace"
 	"github.com/spf13/cobra"
 )
@@ -47,8 +47,8 @@ func getUint(cmd *cobra.Command, flag string) uint {
 }
 // Parse a trace file using a parser based on the extension of the filename.
-func readTraceFile(filename string) *table.ArrayTrace {
-	var trace *table.ArrayTrace
+func readTraceFile(filename string) *trace.ArrayTrace {
+	var tr *trace.ArrayTrace
 	// Read data file
 	bytes, err := os.ReadFile(filename)
 	// Check success
@@ -58,9 +58,9 @@ func readTraceFile(filename string) *table.ArrayTrace {
 	//
 	switch ext {
 	case ".json":
-		trace, err = table.ParseJsonTrace(bytes)
+		tr, err = trace.ParseJsonTrace(bytes)
 		if err == nil {
-			return trace
+			return tr
 		}
 	case ".lt":
 		panic("Support for lt trace files not implemented (yet).")
 	}
diff --git a/pkg/hir/eval.go b/pkg/hir/eval.go
index 0eb6a43..98dd4f7 100644
--- a/pkg/hir/eval.go
+++ b/pkg/hir/eval.go
@@ -2,14 +2,14 @@ package hir
 import (
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/trace"
 )
 // EvalAllAt evaluates a column access at a given row in a trace, which returns the
 // value at that row of the column in question or nil if that row is
 // out-of-bounds.
-func (e *ColumnAccess) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
-	val := tbl.ColumnByName(e.Column).Get(k + e.Shift)
+func (e *ColumnAccess) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
+	val := tr.ColumnByName(e.Column).Get(k + e.Shift)
 	var clone fr.Element
 	// Clone original value
@@ -18,7 +18,7 @@ func (e *ColumnAccess) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
 // EvalAllAt evaluates a constant at a given row in a trace, which simply returns
 // that constant.
-func (e *Constant) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
+func (e *Constant) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
 	var clone fr.Element
 	// Clone original value
 	return []*fr.Element{clone.Set(e.Val)}
@@ -26,16 +26,16 @@ func (e *Constant) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
 // EvalAllAt evaluates a sum at a given row in a trace by first evaluating all of
 // its arguments at that row.
-func (e *Add) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
+func (e *Add) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
 	fn := func(l *fr.Element, r *fr.Element) { l.Add(l, r) }
-	return evalExprsAt(k, tbl, e.Args, fn)
+	return evalExprsAt(k, tr, e.Args, fn)
 }
 // EvalAllAt evaluates a product at a given row in a trace by first evaluating all of
 // its arguments at that row.
-func (e *Mul) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
+func (e *Mul) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
 	fn := func(l *fr.Element, r *fr.Element) { l.Mul(l, r) }
-	return evalExprsAt(k, tbl, e.Args, fn)
+	return evalExprsAt(k, tr, e.Args, fn)
 }
 // EvalAllAt evaluates a conditional at a given row in a trace by first evaluating
@@ -43,16 +43,16 @@ func (e *Mul) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
 // (if applicable) is evaluated; otherwise if the condition is non-zero then the
 // false branch (if applicable) is evaluated. If the branch to be evaluated is
 // missing (i.e. nil), then nil is returned.
-func (e *IfZero) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
+func (e *IfZero) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
 	vals := make([]*fr.Element, 0)
 	// Evaluate condition
-	conditions := e.Condition.EvalAllAt(k, tbl)
+	conditions := e.Condition.EvalAllAt(k, tr)
 	// Check all results
 	for _, cond := range conditions {
 		if cond.IsZero() && e.TrueBranch != nil {
-			vals = append(vals, e.TrueBranch.EvalAllAt(k, tbl)...)
+			vals = append(vals, e.TrueBranch.EvalAllAt(k, tr)...)
 		} else if !cond.IsZero() && e.FalseBranch != nil {
-			vals = append(vals, e.FalseBranch.EvalAllAt(k, tbl)...)
+			vals = append(vals, e.FalseBranch.EvalAllAt(k, tr)...)
 		}
 	}
@@ -61,11 +61,11 @@
 // EvalAllAt evaluates a list at a given row in a trace by evaluating each of its
 // arguments at that row.
-func (e *List) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
+func (e *List) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
 	vals := make([]*fr.Element, 0)
 	for _, e := range e.Args {
-		vs := e.EvalAllAt(k, tbl)
+		vs := e.EvalAllAt(k, tr)
 		vals = append(vals, vs...)
 	}
@@ -75,9 +75,9 @@
 // EvalAllAt evaluates the normalisation of some expression by first evaluating
 // that expression. Then, zero is returned if the result is zero; otherwise one
 // is returned.
-func (e *Normalise) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
+func (e *Normalise) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
 	// Check whether argument evaluates to zero or not.
-	vals := e.Arg.EvalAllAt(k, tbl)
+	vals := e.Arg.EvalAllAt(k, tr)
 	// Normalise values (as necessary)
 	for _, e := range vals {
 		if !e.IsZero() {
@@ -90,20 +90,20 @@
 // EvalAllAt evaluates a subtraction at a given row in a trace by first evaluating all of
 // its arguments at that row.
-func (e *Sub) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
+func (e *Sub) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
 	fn := func(l *fr.Element, r *fr.Element) { l.Sub(l, r) }
-	return evalExprsAt(k, tbl, e.Args, fn)
+	return evalExprsAt(k, tr, e.Args, fn)
 }
 // EvalExprsAt evaluates all expressions in a given slice at a given row on the
 // table, and folds their results together using a combinator.
-func evalExprsAt(k int, tbl table.Trace, exprs []Expr, fn func(*fr.Element, *fr.Element)) []*fr.Element {
+func evalExprsAt(k int, tr trace.Trace, exprs []Expr, fn func(*fr.Element, *fr.Element)) []*fr.Element {
 	// Evaluate first argument.
-	vals := exprs[0].EvalAllAt(k, tbl)
+	vals := exprs[0].EvalAllAt(k, tr)
 	// Continue evaluating the rest.
 	for i := 1; i < len(exprs); i++ {
-		vs := exprs[i].EvalAllAt(k, tbl)
+		vs := exprs[i].EvalAllAt(k, tr)
 		vals = evalExprsAtApply(vals, vs, fn)
 	}
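EvalAllAt returns a slice because a single HIR expression can denote several values per row: a List contributes all of its elements, and an IfZero contributes whichever branch matches each condition value. A toy multi-valued evaluator over ints that mirrors this shape (hypothetical types; the real code yields []*fr.Element):

package main

import "fmt"

type expr interface{ evalAll(row int) []int }

type konst int

func (k konst) evalAll(int) []int { return []int{int(k)} }

type list []expr

func (l list) evalAll(row int) []int {
	var out []int
	for _, e := range l {
		out = append(out, e.evalAll(row)...) // each element may itself be multi-valued
	}
	return out
}

func main() {
	e := list{konst(1), list{konst(2), konst(3)}}
	fmt.Println(e.evalAll(0)) // [1 2 3]
}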
diff --git a/pkg/hir/expr.go b/pkg/hir/expr.go
index fdbd27a..a596546 100644
--- a/pkg/hir/expr.go
+++ b/pkg/hir/expr.go
@@ -3,7 +3,7 @@ package hir
 import (
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
 	"github.com/consensys/go-corset/pkg/mir"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/trace"
 	"github.com/consensys/go-corset/pkg/util"
 )
@@ -28,7 +28,7 @@ type Expr interface {
 	// undefined for several reasons: firstly, if it accesses a
 	// row which does not exist (e.g. at index -1); secondly, if
 	// it accesses a column which does not exist.
-	EvalAllAt(int, table.Trace) []*fr.Element
+	EvalAllAt(int, trace.Trace) []*fr.Element
 	// String produces a string representing this as an S-Expression.
 	String() string
 }
diff --git a/pkg/hir/lower.go b/pkg/hir/lower.go
index 0b7d37c..214b854 100644
--- a/pkg/hir/lower.go
+++ b/pkg/hir/lower.go
@@ -5,8 +5,58 @@ import (
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
 	"github.com/consensys/go-corset/pkg/mir"
+	sc "github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/schema/constraint"
 )
+// LowerToMir lowers (or refines) an HIR schema into an MIR schema. That means
+// lowering all the columns and constraints, whilst adding additional columns /
+// constraints as necessary to preserve the original semantics.
+func (p *Schema) LowerToMir() *mir.Schema {
+	mirSchema := mir.EmptySchema()
+	// First, lower columns
+	for _, input := range p.inputs {
+		col := input.(DataColumn)
+		mirSchema.AddDataColumn(col.Name(), col.Type())
+	}
+	// Second, lower permutations
+	for _, asn := range p.assignments {
+		col := asn.(Permutation)
+		mirSchema.AddPermutationColumns(col.Targets(), col.Signs, col.Sources)
+	}
+	// Third, lower constraints
+	for _, c := range p.constraints {
+		lowerConstraintToMir(c, mirSchema)
+	}
+	// Fourth, copy property assertions. Observe that these do not require lowering
+	// because they are already MIR-level expressions.
+	for _, c := range p.assertions {
+		properties := c.Property.Expr.LowerTo(mirSchema)
+		for _, p := range properties {
+			mirSchema.AddPropertyAssertion(c.Handle, p)
+		}
+	}
+	//
+	return mirSchema
+}
+
+func lowerConstraintToMir(c sc.Constraint, schema *mir.Schema) {
+	// Check what kind of constraint we have
+	if v, ok := c.(VanishingConstraint); ok {
+		mir_exprs := v.Constraint.Expr.LowerTo(schema)
+		// Add individual constraints arising
+		for _, mir_expr := range mir_exprs {
+			schema.AddVanishingConstraint(v.Handle, v.Domain, mir_expr)
+		}
+	} else if v, ok := c.(*constraint.TypeConstraint); ok {
+		schema.AddTypeConstraint(v.Target(), v.Type())
+	} else {
+		// Should be unreachable as no other constraint types can be added to a
+		// schema.
+		panic("unreachable")
+	}
+}
 // LowerTo lowers a sum expression to the MIR level. This requires expanding
 // the arguments, then lowering them. Furthermore, conditionals are "lifted" to
 // the top.
@@ -163,7 +213,7 @@
 	} else if p, ok := e.(*Constant); ok {
 		return &mir.Constant{Value: p.Val}
 	} else if p, ok := e.(*ColumnAccess); ok {
-		if index, ok := schema.ColumnIndex(p.Column); ok {
+		if index, ok := sc.ColumnIndexOf(schema, p.Column); ok {
 			return &mir.ColumnAccess{Column: index, Shift: p.Shift}
 		}
 		// Should be unreachable as all columns should have been vetted earlier
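lowerConstraintToMir dispatches on the concrete constraint type, and a single HIR vanishing constraint may fan out into several MIR constraints because Expr.LowerTo returns a slice (for example, one expression per List element). A schematic rendering of that fan-out, with strings standing in for MIR expressions:

package main

import "fmt"

type hirConstraint struct {
	handle string
	exprs  []string // stands in for the []mir.Expr produced by Expr.LowerTo
}

func lower(c hirConstraint, addConstraint func(handle, expr string)) {
	for _, e := range c.exprs {
		addConstraint(c.handle, e) // one MIR constraint per lowered expression
	}
}

func main() {
	c := hirConstraint{handle: "ctr", exprs: []string{"X-Y", "Y-Z"}}
	lower(c, func(h, e string) { fmt.Printf("mir constraint %s: %s == 0\n", h, e) })
}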
diff --git a/pkg/hir/parser.go b/pkg/hir/parser.go
index bec6e98..52b128a 100644
--- a/pkg/hir/parser.go
+++ b/pkg/hir/parser.go
@@ -6,8 +6,9 @@ import (
 	"strings"
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
+	"github.com/consensys/go-corset/pkg/schema"
+	sc "github.com/consensys/go-corset/pkg/schema"
 	"github.com/consensys/go-corset/pkg/sexp"
-	"github.com/consensys/go-corset/pkg/table"
 )
 // ===================================================================
@@ -82,7 +83,7 @@ func (p *hirParser) parseDeclaration(s sexp.SExp) error {
 	} else if e.Len() == 3 && e.MatchSymbols(2, "assert") {
 		return p.parseAssertionDeclaration(e.Elements)
 	} else if e.Len() == 3 && e.MatchSymbols(1, "permute") {
-		return p.parseSortedPermutationDeclaration(e.Elements)
+		return p.parseSortedPermutationDeclaration(e)
 	}
 }
 // Error
@@ -98,11 +99,11 @@ func (p *hirParser) parseColumnDeclaration(l *sexp.List) error {
 	// Extract column name
 	columnName := l.Elements[1].String()
 	// Sanity check doesn't already exist
-	if p.schema.HasColumn(columnName) {
+	if sc.HasColumn(p.schema, columnName) {
 		return p.translator.SyntaxError(l, "duplicate column declaration")
 	}
 	// Default to field type
-	var columnType table.Type = &table.FieldType{}
+	var columnType sc.Type = &sc.FieldType{}
 	// Parse type (if applicable)
 	if len(l.Elements) == 3 {
 		var err error
@@ -114,36 +115,34 @@
 	}
 	// Register column in Schema
 	p.schema.AddDataColumn(columnName, columnType)
+	p.schema.AddTypeConstraint(columnName, columnType)
 	return nil
 }
 // Parse a sorted permutation declaration
-func (p *hirParser) parseSortedPermutationDeclaration(elements []sexp.SExp) error {
+func (p *hirParser) parseSortedPermutationDeclaration(l *sexp.List) error {
 	// Target columns are (sorted) permutations of source columns.
-	sexpTargets := elements[1].AsList()
+	sexpTargets := l.Elements[1].AsList()
 	// Source columns.
-	sexpSources := elements[2].AsList()
+	sexpSources := l.Elements[2].AsList()
 	// Convert into appropriate form.
-	targets := make([]string, sexpTargets.Len())
+	targets := make([]schema.Column, sexpTargets.Len())
 	sources := make([]string, sexpSources.Len())
 	signs := make([]bool, sexpSources.Len())
 	//
-	for i := 0; i < sexpTargets.Len(); i++ {
-		target := sexpTargets.Get(i).AsSymbol()
-		// Sanity check syntax as expected
-		if target == nil {
-			return p.translator.SyntaxError(sexpTargets.Get(i), "malformed column")
-		}
-		// Copy over
-		targets[i] = target.String()
+	if sexpTargets.Len() != sexpSources.Len() {
+		return p.translator.SyntaxError(l, "sorted permutation requires matching number of source and target columns")
 	}
 	//
 	for i := 0; i < sexpSources.Len(); i++ {
 		source := sexpSources.Get(i).AsSymbol()
+		target := sexpTargets.Get(i).AsSymbol()
 		// Sanity check syntax as expected
 		if source == nil {
 			return p.translator.SyntaxError(sexpSources.Get(i), "malformed column")
+		} else if target == nil {
+			return p.translator.SyntaxError(sexpTargets.Get(i), "malformed column")
 		}
 		// Determine source column sign (i.e. sort direction)
 		sortName := source.String()
@@ -160,6 +159,8 @@
 		}
 		// Copy over column name
 		sources[i] = sortName[1:]
+		// FIXME: determine source column type
+		targets[i] = schema.NewColumn(target.String(), &schema.FieldType{})
 	}
 	//
 	p.schema.AddPermutationColumns(targets, signs, sources)
@@ -195,7 +196,7 @@ func (p *hirParser) parseVanishingDeclaration(elements []sexp.SExp, domain *int)
 	return nil
 }
-func (p *hirParser) parseType(term sexp.SExp) (table.Type, error) {
+func (p *hirParser) parseType(term sexp.SExp) (sc.Type, error) {
 	symbol := term.AsSymbol()
 	if symbol == nil {
 		return nil, p.translator.SyntaxError(term, "malformed column")
@@ -207,8 +208,8 @@
 		if err != nil {
 			return nil, err
 		}
-		// TODO: support @prove
-		return table.NewUintType(uint(n), true), nil
+		// Done
+		return sc.NewUintType(uint(n)), nil
 	}
 	// Error
 	return nil, p.translator.SyntaxError(symbol, "unknown type")
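The parser derives each source column's sort direction from a sign prefix on its name, keeping only the remainder as the column name (sources[i] = sortName[1:]). A hypothetical helper showing that convention, assuming '+' means ascending and '-' descending, matching the delta-column names ("+X"/"-X") used by ApplyColumnSortGadget:

package main

import "fmt"

func parseSortDirection(sortName string) (name string, ascending bool, err error) {
	if len(sortName) < 2 {
		return "", false, fmt.Errorf("malformed column %q", sortName)
	}
	switch sortName[0] {
	case '+':
		return sortName[1:], true, nil
	case '-':
		return sortName[1:], false, nil
	}
	return "", false, fmt.Errorf("missing sort direction in %q", sortName)
}

func main() {
	fmt.Println(parseSortDirection("+X")) // X true <nil>
	fmt.Println(parseSortDirection("-Y")) // Y false <nil>
}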
diff --git a/pkg/hir/schema.go b/pkg/hir/schema.go
index 7d7c4bd..72e1a3f 100644
--- a/pkg/hir/schema.go
+++ b/pkg/hir/schema.go
@@ -1,8 +1,10 @@ package hir
 import (
-	"github.com/consensys/go-corset/pkg/mir"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/schema/assignment"
+	"github.com/consensys/go-corset/pkg/schema/constraint"
+	"github.com/consensys/go-corset/pkg/trace"
 	"github.com/consensys/go-corset/pkg/util"
 )
@@ -16,7 +18,7 @@ type ZeroArrayTest struct {
 // TestAt determines whether or not every element from a given array of
 // expressions evaluates to zero. Observe that any expressions which are
 // undefined are assumed to hold.
-func (p ZeroArrayTest) TestAt(row int, tr table.Trace) bool {
+func (p ZeroArrayTest) TestAt(row int, tr trace.Trace) bool {
 	// Evaluates expression yielding zero or more values.
 	vals := p.Expr.EvalAllAt(row, tr)
 	// Check each value in turn against zero.
@@ -40,29 +42,30 @@ func (p ZeroArrayTest) Bounds() util.Bounds {
 }
 // DataColumn captures the essence of a data column at AIR level.
-type DataColumn = *table.DataColumn[table.Type]
+type DataColumn = *assignment.DataColumn
 // VanishingConstraint captures the essence of a vanishing constraint at the HIR
 // level. A vanishing constraint is a row constraint which must evaluate to
 // zero.
-type VanishingConstraint = *table.RowConstraint[ZeroArrayTest]
+type VanishingConstraint = *constraint.VanishingConstraint[ZeroArrayTest]
 // PropertyAssertion captures the notion of an arbitrary property which should
 // hold for all acceptable traces. However, such a property is not enforced by
 // the prover.
-type PropertyAssertion = *table.PropertyAssertion[ZeroArrayTest]
+type PropertyAssertion = *schema.PropertyAssertion[ZeroArrayTest]
 // Permutation captures the notion of a (sorted) permutation at the HIR level.
-type Permutation = *table.SortedPermutation
+type Permutation = *assignment.SortedPermutation
 // Schema for HIR constraints and columns.
 type Schema struct {
 	// The data columns of this schema.
-	dataColumns []DataColumn
+	inputs []schema.Declaration
 	// The sorted permutations of this schema.
-	permutations []Permutation
-	// The vanishing constraints of this schema.
-	vanishing []VanishingConstraint
+	assignments []schema.Assignment
+	// Constraints of this schema, which are either vanishing, lookup or type
+	// constraints.
+	constraints []schema.Constraint
 	// The property assertions for this schema.
 	assertions []PropertyAssertion
 }
@@ -71,98 +74,18 @@
 func EmptySchema() *Schema {
 	p := new(Schema)
-	p.dataColumns = make([]DataColumn, 0)
-	p.permutations = make([]Permutation, 0)
-	p.vanishing = make([]VanishingConstraint, 0)
+	p.inputs = make([]schema.Declaration, 0)
+	p.assignments = make([]schema.Assignment, 0)
+	p.constraints = make([]schema.Constraint, 0)
 	p.assertions = make([]PropertyAssertion, 0)
 	// Done
 	return p
 }
-// Column returns information about the ith column in this schema.
-func (p *Schema) Column(i uint) table.ColumnSchema {
-	panic("todo")
-}
-
-// Width returns the number of column groups in this schema.
-func (p *Schema) Width() uint {
-	return uint(len(p.dataColumns) + len(p.permutations))
-}
-
-// ColumnGroup returns information about the ith column group in this schema.
-func (p *Schema) ColumnGroup(i uint) table.ColumnGroup {
-	n := uint(len(p.dataColumns))
-	if i < n {
-		return p.dataColumns[i]
-	}
-
-	return p.permutations[i-n]
-}
-
-// ColumnIndex determines the column index for a given column in this schema, or
-// returns false indicating an error.
-func (p *Schema) ColumnIndex(name string) (uint, bool) {
-	index := uint(0)
-
-	for i := uint(0); i < p.Width(); i++ {
-		ith := p.ColumnGroup(i)
-		for j := uint(0); j < ith.Width(); j++ {
-			if ith.NameOf(j) == name {
-				// hit
-				return index, true
-			}
-
-			index++
-		}
-	}
-	// miss
-	return 0, false
-}
-
-// HasColumn checks whether a given schema has a given column.
-func (p *Schema) HasColumn(name string) bool {
-	for _, c := range p.dataColumns {
-		if (*c).Name() == name {
-			return true
-		}
-	}
-
-	return false
-}
-
-// Columns returns the set of (data) columns declared within this schema.
-func (p *Schema) Columns() []*table.DataColumn[table.Type] {
-	return p.dataColumns
-}
-
-// Constraints returns the set of (vanishing) constraints declared within this schema.
-func (p *Schema) Constraints() []VanishingConstraint {
-	return p.vanishing
-}
-
-// Size returns the number of declarations in this schema.
-func (p *Schema) Size() int {
-	return len(p.dataColumns) + len(p.permutations) + len(p.vanishing) + len(p.assertions)
-}
-
-// RequiredSpillage returns the minimum amount of spillage required to ensure
-// valid traces are accepted in the presence of arbitrary padding.
-func (p *Schema) RequiredSpillage() uint {
-	// Ensures always at least one row of spillage (referred to as the "initial
-	// padding row")
-	return uint(1)
-}
-
-// GetDeclaration returns the ith declaration in this schema.
-func (p *Schema) GetDeclaration(index int) table.Declaration {
-	ith := util.FlatArrayIndexOf_4(index, p.dataColumns, p.permutations, p.vanishing, p.assertions)
-	return ith.(table.Declaration)
-}
-
 // AddDataColumn appends a new data column with a given type. Furthermore, the
 // type is enforced by the system when checking is enabled.
-func (p *Schema) AddDataColumn(name string, base table.Type) {
-	p.dataColumns = append(p.dataColumns, table.NewDataColumn(name, base, false))
+func (p *Schema) AddDataColumn(name string, base schema.Type) {
+	p.inputs = append(p.inputs, assignment.NewDataColumn(name, base))
 }
 // AddPermutationColumns introduces a permutation of one or more
@@ -171,87 +94,65 @@ func (p *Schema) AddDataColumn(name string, base table.Type) {
 // source columns. Each source column is associated with a "sign"
 // which indicates the direction of sorting (i.e. ascending versus
 // descending).
-func (p *Schema) AddPermutationColumns(targets []string, signs []bool, sources []string) {
-	p.permutations = append(p.permutations, table.NewSortedPermutation(targets, signs, sources))
+func (p *Schema) AddPermutationColumns(targets []schema.Column, signs []bool, sources []string) {
+	p.assignments = append(p.assignments, assignment.NewSortedPermutation(targets, signs, sources))
 }
 // AddVanishingConstraint appends a new vanishing constraint.
 func (p *Schema) AddVanishingConstraint(handle string, domain *int, expr Expr) {
-	p.vanishing = append(p.vanishing, table.NewRowConstraint(handle, domain, ZeroArrayTest{expr}))
+	p.constraints = append(p.constraints, constraint.NewVanishingConstraint(handle, domain, ZeroArrayTest{expr}))
+}
+
+// AddTypeConstraint appends a new type constraint.
+func (p *Schema) AddTypeConstraint(target string, t schema.Type) {
+	// Check whether this is a field type, as these can actually be ignored.
+	if t.AsField() == nil {
+		p.constraints = append(p.constraints, constraint.NewTypeConstraint(target, t))
+	}
 }
 // AddPropertyAssertion appends a new property assertion.
 func (p *Schema) AddPropertyAssertion(handle string, property Expr) {
-	p.assertions = append(p.assertions, table.NewPropertyAssertion[ZeroArrayTest](handle, ZeroArrayTest{property}))
+	p.assertions = append(p.assertions, schema.NewPropertyAssertion[ZeroArrayTest](handle, ZeroArrayTest{property}))
}
-// Accepts determines whether this schema will accept a given trace. That
-// is, whether or not the given trace adheres to the schema. A trace can fail
-// to adhere to the schema for a variety of reasons, such as having a constraint
-// which does not hold.
-func (p *Schema) Accepts(trace table.Trace) error {
-	// Check (typed) data columns
-	if err := table.ConstraintsAcceptTrace(trace, p.dataColumns); err != nil {
-		return err
-	}
-	// Check permutations
-	if err := table.ConstraintsAcceptTrace(trace, p.permutations); err != nil {
-		return err
-	}
-	// Check vanishing constraints
-	if err := table.ConstraintsAcceptTrace(trace, p.vanishing); err != nil {
-		return err
-	}
-	// Check properties
-	if err := table.ConstraintsAcceptTrace(trace, p.assertions); err != nil {
-		return err
-	}
-	// Done
-	return nil
-}
+// ============================================================================
+// Schema Interface
+// ============================================================================
 
-// ExpandTrace expands a given trace according to this schema.
-func (p *Schema) ExpandTrace(tr table.Trace) error {
-	// Expand all the permutation columns
-	for _, perm := range p.permutations {
-		err := perm.ExpandTrace(tr)
-		if err != nil {
-			return err
-		}
-	}
+// Inputs returns an array over the input declarations of this schema. That is,
+// the subset of declarations whose trace values must be provided by the user.
+func (p *Schema) Inputs() util.Iterator[schema.Declaration] {
+	return util.NewArrayIterator(p.inputs)
+}
 
-	return nil
+// Assignments returns an array over the assignments of this schema. That
+// is, the subset of declarations whose trace values can be computed from
+// the inputs.
+func (p *Schema) Assignments() util.Iterator[schema.Assignment] {
+	return util.NewArrayIterator(p.assignments)
 }
 
-// LowerToMir lowers (or refines) an HIR table into an MIR table. That means
-// lowering all the columns and constraints, whilst adding additional columns /
-// constraints as necessary to preserve the original semantics.
-func (p *Schema) LowerToMir() *mir.Schema {
-	mirSchema := mir.EmptySchema()
-	// First, lower columns
-	for _, col := range p.dataColumns {
-		mirSchema.AddDataColumn(col.Name(), col.Type)
-	}
-	// Second, lower permutations
-	for _, col := range p.permutations {
-		mirSchema.AddPermutationColumns(col.Targets, col.Signs, col.Sources)
-	}
-	// Third, lower constraints
-	for _, c := range p.vanishing {
-		mir_exprs := c.Constraint.Expr.LowerTo(mirSchema)
-		// Add individual constraints arising
-		for _, mir_expr := range mir_exprs {
-			mirSchema.AddVanishingConstraint(c.Handle, c.Domain, mir_expr)
-		}
-	}
-	// Fourth, copy property assertions. Observe, these do not require lowering
-	// because they are already MIR-level expressions.
-	for _, c := range p.assertions {
-		properties := c.Property.Expr.LowerTo(mirSchema)
-		for _, p := range properties {
-			mirSchema.AddPropertyAssertion(c.Handle, p)
-		}
-	}
+// Columns returns an array over the underlying columns of this schema.
+// Specifically, the index of a column in this array is its column index.
+func (p *Schema) Columns() util.Iterator[schema.Column] {
+	is := util.NewFlattenIterator[schema.Declaration, schema.Column](p.Inputs(),
+		func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() })
+	ps := util.NewFlattenIterator[schema.Assignment, schema.Column](p.Assignments(),
+		func(d schema.Assignment) util.Iterator[schema.Column] { return d.Columns() })
 	//
-	return mirSchema
+	return is.Append(ps)
+}
+
+// Constraints returns an array over the underlying constraints of this
+// schema.
+func (p *Schema) Constraints() util.Iterator[schema.Constraint] {
+	return util.NewArrayIterator(p.constraints)
+}
+
+// Declarations returns an array over the column declarations of this
+// schema.
+func (p *Schema) Declarations() util.Iterator[schema.Declaration] {
+	ps := util.NewCastIterator[schema.Assignment, schema.Declaration](p.Assignments())
+	return p.Inputs().Append(ps)
 }
diff --git a/pkg/mir/eval.go b/pkg/mir/eval.go
index 9ccbddb..10e602c 100644
--- a/pkg/mir/eval.go
+++ b/pkg/mir/eval.go
@@ -2,14 +2,14 @@ package mir
 
 import (
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/trace"
 )
 
 // EvalAt evaluates a column access at a given row in a trace, which returns the
 // value at that row of the column in question or nil if that row is
 // out-of-bounds.
-func (e *ColumnAccess) EvalAt(k int, tbl table.Trace) *fr.Element {
-	val := tbl.ColumnByIndex(e.Column).Get(k + e.Shift)
+func (e *ColumnAccess) EvalAt(k int, tr trace.Trace) *fr.Element {
+	val := tr.ColumnByIndex(e.Column).Get(k + e.Shift)
 
 	var clone fr.Element
 	// Clone original value
@@ -18,7 +18,7 @@ func (e *ColumnAccess) EvalAt(k int, tbl table.Trace) *fr.Element {
 
 // EvalAt evaluates a constant at a given row in a trace, which simply returns
 // that constant.
-func (e *Constant) EvalAt(k int, tbl table.Trace) *fr.Element {
+func (e *Constant) EvalAt(k int, tr trace.Trace) *fr.Element {
 	var clone fr.Element
 	// Clone original value
 	return clone.Set(e.Value)
@@ -26,24 +26,24 @@ func (e *Constant) EvalAt(k int, tbl table.Trace) *fr.Element {
 
 // EvalAt evaluates a sum at a given row in a trace by first evaluating all of
 // its arguments at that row.
-func (e *Add) EvalAt(k int, tbl table.Trace) *fr.Element {
+func (e *Add) EvalAt(k int, tr trace.Trace) *fr.Element {
 	fn := func(l *fr.Element, r *fr.Element) { l.Add(l, r) }
-	return evalExprsAt(k, tbl, e.Args, fn)
+	return evalExprsAt(k, tr, e.Args, fn)
 }
 
 // EvalAt evaluates a product at a given row in a trace by first evaluating all of
 // its arguments at that row.
-func (e *Mul) EvalAt(k int, tbl table.Trace) *fr.Element {
+func (e *Mul) EvalAt(k int, tr trace.Trace) *fr.Element {
 	fn := func(l *fr.Element, r *fr.Element) { l.Mul(l, r) }
-	return evalExprsAt(k, tbl, e.Args, fn)
+	return evalExprsAt(k, tr, e.Args, fn)
 }
 
 // EvalAt evaluates the normalisation of some expression by first evaluating
 // that expression. Then, zero is returned if the result is zero; otherwise one
 // is returned.
-func (e *Normalise) EvalAt(k int, tbl table.Trace) *fr.Element {
+func (e *Normalise) EvalAt(k int, tr trace.Trace) *fr.Element {
 	// Check whether argument evaluates to zero or not.
-	val := e.Arg.EvalAt(k, tbl)
+	val := e.Arg.EvalAt(k, tr)
 	// Normalise value (if necessary)
 	if !val.IsZero() {
 		val.SetOne()
@@ -54,19 +54,19 @@ func (e *Normalise) EvalAt(k int, tbl table.Trace) *fr.Element {
 
 // EvalAt evaluates a subtraction at a given row in a trace by first evaluating all of
 // its arguments at that row.
-func (e *Sub) EvalAt(k int, tbl table.Trace) *fr.Element {
+func (e *Sub) EvalAt(k int, tr trace.Trace) *fr.Element {
 	fn := func(l *fr.Element, r *fr.Element) { l.Sub(l, r) }
-	return evalExprsAt(k, tbl, e.Args, fn)
+	return evalExprsAt(k, tr, e.Args, fn)
 }
 
 // Evaluate all expressions in a given slice at a given row on the
 // table, and fold their results together using a combinator.
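+// For example, an Add expression with arguments [X, Y, 1] first evaluates X
+// at row k, then folds in the values of Y and 1 using the addition combinator.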
-func evalExprsAt(k int, tbl table.Trace, exprs []Expr, fn func(*fr.Element, *fr.Element)) *fr.Element {
+func evalExprsAt(k int, tr trace.Trace, exprs []Expr, fn func(*fr.Element, *fr.Element)) *fr.Element {
 	// Evaluate first argument
-	val := exprs[0].EvalAt(k, tbl)
+	val := exprs[0].EvalAt(k, tr)
 	// Continue evaluating the rest
 	for i := 1; i < len(exprs); i++ {
-		ith := exprs[i].EvalAt(k, tbl)
+		ith := exprs[i].EvalAt(k, tr)
 		fn(val, ith)
 	}
diff --git a/pkg/mir/expr.go b/pkg/mir/expr.go
index 505e9f2..0699c64 100644
--- a/pkg/mir/expr.go
+++ b/pkg/mir/expr.go
@@ -3,7 +3,7 @@ package mir
 import (
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
 	"github.com/consensys/go-corset/pkg/air"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/trace"
 	"github.com/consensys/go-corset/pkg/util"
 )
 
@@ -25,7 +25,7 @@ type Expr interface {
 	// undefined for several reasons: firstly, if it accesses a
 	// row which does not exist (e.g. at index -1); secondly, if
 	// it accesses a column which does not exist.
-	EvalAt(int, table.Trace) *fr.Element
+	EvalAt(int, trace.Trace) *fr.Element
 	// String produces a string representing this as an S-Expression.
 	String() string
 }
diff --git a/pkg/mir/lower.go b/pkg/mir/lower.go
index 805bdb5..46a660d 100644
--- a/pkg/mir/lower.go
+++ b/pkg/mir/lower.go
@@ -3,8 +3,125 @@ package mir
 import (
 	"github.com/consensys/go-corset/pkg/air"
 	air_gadgets "github.com/consensys/go-corset/pkg/air/gadgets"
+	"github.com/consensys/go-corset/pkg/schema"
+	sc "github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/schema/constraint"
 )
 
+// LowerToAir lowers (or refines) an MIR schema into an AIR schema. That means
+// lowering all the columns and constraints, whilst adding additional columns /
+// constraints as necessary to preserve the original semantics.
+func (p *Schema) LowerToAir() *air.Schema {
+	airSchema := air.EmptySchema[Expr]()
+	// Add data columns.
+	for _, c := range p.inputs {
+		col := c.(DataColumn)
+		airSchema.AddColumn(col.Name(), col.Type())
+	}
+	// Add Assignments. Again this has to be done first for things to work.
+	// Essentially to reflect the fact that these columns have been added above
+	// before others. Realistically, the overall design of this process is a
+	// bit broken right now.
+	for _, perm := range p.assignments {
+		airSchema.AddAssignment(perm.(Permutation))
+	}
+	// Lower permutation columns
+	for _, perm := range p.assignments {
+		lowerPermutationToAir(perm.(Permutation), p, airSchema)
+	}
+	// Lower vanishing constraints
+	for _, c := range p.constraints {
+		lowerConstraintToAir(c, airSchema)
+	}
+	// Done
+	return airSchema
+}
+
+// Lower a constraint to the AIR level.
+func lowerConstraintToAir(c sc.Constraint, schema *air.Schema) {
+	// Check what kind of constraint we have
+	if v, ok := c.(VanishingConstraint); ok {
+		air_expr := v.Constraint.Expr.LowerTo(schema)
+		schema.AddVanishingConstraint(v.Handle, v.Domain, air_expr)
+	} else if v, ok := c.(*constraint.TypeConstraint); ok {
+		if t := v.Type().AsUint(); t != nil {
+			index, ok := sc.ColumnIndexOf(schema, v.Target())
+			// Sanity check
+			if !ok {
+				panic("Cannot find column")
+			}
+			// Yes, a constraint is implied. Now, decide whether to use a range
+			// constraint or just a vanishing constraint.
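+			// For example (sketch): a u1 column X yields the vanishing
+			// constraint X*(X-1) == 0; a u8 column yields a range constraint
+			// with bound 256; and a u16 column X yields two byte columns X:0
+			// and X:1 together with the constraint X == X:0 + 256*X:1.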
+			if t.HasBound(2) {
+				// u1 => use vanishing constraint X * (X - 1)
+				air_gadgets.ApplyBinaryGadget(index, schema)
+			} else if t.HasBound(256) {
+				// u2..8 use range constraints
+				schema.AddRangeConstraint(index, t.Bound())
+			} else {
+				// u9+ use byte decompositions.
+				air_gadgets.ApplyBitwidthGadget(index, t.BitWidth(), schema)
+			}
+		}
+	} else {
+		// Should be unreachable as no other constraint types can be added to a
+		// schema.
+		panic("unreachable")
+	}
+}
+
+// Lower a permutation to the AIR level. This has quite a few
+// effects. Firstly, permutation constraints are added for all of the
+// new columns. Secondly, sorting constraints (and their associated
+// computed columns) must also be added. Finally, a trace
+// computation is required to ensure traces are correctly expanded to
+// meet the requirements of a sorted permutation.
+func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Schema) {
+	c_targets := c.Targets()
+	ncols := len(c_targets)
+	//
+	targets := make([]uint, ncols)
+	sources := make([]uint, ncols)
+	// Add individual permutation constraints
+	for i := 0; i < ncols; i++ {
+		var ok1, ok2 bool
+		// TODO: REPLACE
+		sources[i], ok1 = sc.ColumnIndexOf(airSchema, c.Sources[i])
+		targets[i], ok2 = sc.ColumnIndexOf(airSchema, c_targets[i].Name())
+
+		if !ok1 || !ok2 {
+			panic("missing column")
+		}
+	}
+	//
+	airSchema.AddPermutationConstraint(targets, sources)
+	// Add sorting constraints + computed columns as necessary.
+	if ncols == 1 {
+		// For a single column sort, it's actually a bit easier because we
+		// don't need to implement a multiplexor (i.e. to determine which
+		// column differs, etc). Instead, we just need a delta column which
+		// ensures there is a non-negative difference between consecutive
+		// rows. This also requires bitwidth constraints.
+		bitwidth := schema.ColumnByName(mirSchema, c.Sources[0]).Type().AsUint().BitWidth()
+		// Add column sorting constraints
+		air_gadgets.ApplyColumnSortGadget(targets[0], c.Signs[0], bitwidth, airSchema)
+	} else {
+		// For a multi column sort, it's a bit harder as we need additional
+		// logic to ensure the target columns are lexicographically sorted.
+		bitwidth := uint(0)
+
+		for i := 0; i < ncols; i++ {
+			// Extract bitwidth of ith column
+			ith := schema.ColumnByName(mirSchema, c.Sources[i]).Type().AsUint().BitWidth()
+			if ith > bitwidth {
+				bitwidth = ith
+			}
+		}
+		// Add lexicographically sorted constraints
+		air_gadgets.ApplyLexicographicSortingGadget(targets, c.Signs, bitwidth, airSchema)
+	}
+}
+
 // LowerTo lowers a sum expression to the AIR level by lowering the arguments.
 func (e *Add) LowerTo(schema *air.Schema) air.Expr {
 	return &air.Add{Args: lowerExprs(e.Args, schema)}
diff --git a/pkg/mir/schema.go b/pkg/mir/schema.go
index 1dacb13..3272f5c 100644
--- a/pkg/mir/schema.go
+++ b/pkg/mir/schema.go
@@ -1,38 +1,37 @@
 package mir
 
 import (
-	"fmt"
-
-	"github.com/consensys/go-corset/pkg/air"
-	air_gadgets "github.com/consensys/go-corset/pkg/air/gadgets"
-	"github.com/consensys/go-corset/pkg/table"
+	"github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/schema/assignment"
+	"github.com/consensys/go-corset/pkg/schema/constraint"
 	"github.com/consensys/go-corset/pkg/util"
 )
 
 // DataColumn captures the essence of a data column at the MIR level.
-type DataColumn = *table.DataColumn[table.Type]
+type DataColumn = *assignment.DataColumn
 
 // VanishingConstraint captures the essence of a vanishing constraint at the
 // MIR level.
 // A vanishing constraint is a row constraint which must evaluate to zero.
-type VanishingConstraint = *table.RowConstraint[table.ZeroTest[Expr]]
+type VanishingConstraint = *constraint.VanishingConstraint[constraint.ZeroTest[Expr]]
 
 // PropertyAssertion captures the notion of an arbitrary property which should
 // hold for all acceptable traces. However, such a property is not enforced by
 // the prover.
-type PropertyAssertion = *table.PropertyAssertion[table.ZeroTest[Expr]]
+type PropertyAssertion = *schema.PropertyAssertion[constraint.ZeroTest[Expr]]
 
 // Permutation captures the notion of a (sorted) permutation at the MIR level.
-type Permutation = *table.SortedPermutation
+type Permutation = *assignment.SortedPermutation
 
 // Schema for MIR traces
 type Schema struct {
 	// The data columns of this schema.
-	dataColumns []DataColumn
+	inputs []schema.Declaration
 	// The sorted permutations of this schema.
-	permutations []Permutation
-	// The vanishing constraints of this schema.
-	vanishing []VanishingConstraint
+	assignments []schema.Assignment
+	// The constraints of this schema, which are either vanishing constraints,
+	// type constraints or lookup constraints.
+	constraints []schema.Constraint
 	// The property assertions for this schema.
 	assertions []PropertyAssertion
 }
@@ -41,89 +40,17 @@ type Schema struct {
 // constraints will be added.
 func EmptySchema() *Schema {
 	p := new(Schema)
-	p.dataColumns = make([]DataColumn, 0)
-	p.permutations = make([]Permutation, 0)
-	p.vanishing = make([]VanishingConstraint, 0)
+	p.inputs = make([]schema.Declaration, 0)
+	p.assignments = make([]schema.Assignment, 0)
+	p.constraints = make([]schema.Constraint, 0)
 	p.assertions = make([]PropertyAssertion, 0)
 	// Done
 	return p
 }
 
-// Width returns the number of column groups in this schema.
-func (p *Schema) Width() uint {
-	return uint(len(p.dataColumns) + len(p.permutations))
-}
-
-// Column returns information about the ith column in this schema.
-func (p *Schema) Column(i uint) table.ColumnSchema {
-	panic("todo")
-}
-
-// ColumnGroup returns information about the ith column group in this schema.
-func (p *Schema) ColumnGroup(i uint) table.ColumnGroup {
-	n := uint(len(p.dataColumns))
-	if i < n {
-		return p.dataColumns[i]
-	}
-
-	return p.permutations[i-n]
-}
-
-// ColumnIndex determines the column index for a given column in this schema, or
-// returns false indicating an error.
-func (p *Schema) ColumnIndex(name string) (uint, bool) {
-	index := uint(0)
-
-	for i := uint(0); i < p.Width(); i++ {
-		ith := p.ColumnGroup(i)
-		for j := uint(0); j < ith.Width(); j++ {
-			if ith.NameOf(j) == name {
-				// hit
-				return index, true
-			}
-
-			index++
-		}
-	}
-	// miss
-	return 0, false
-}
-
-// GetColumnByName gets a given data column based on its name. If no such
-// column exists, it panics.
-func (p *Schema) GetColumnByName(name string) DataColumn {
-	for _, c := range p.dataColumns {
-		if c.Name() == name {
-			return c
-		}
-	}
-
-	msg := fmt.Sprintf("unknown column encountered (%s)", name)
-	panic(msg)
-}
-
-// Size returns the number of declarations in this schema.
-func (p *Schema) Size() int {
-	return len(p.dataColumns) + len(p.permutations) + len(p.vanishing) + len(p.assertions)
-}
-
-// RequiredSpillage returns the minimum amount of spillage required to ensure
-// valid traces are accepted in the presence of arbitrary padding.
-func (p *Schema) RequiredSpillage() uint {
-	// Ensures always at least one row of spillage (referred to as the "initial
-	// padding row")
-	return uint(1)
-}
-
-// GetDeclaration returns the ith declaration in this schema.
-func (p *Schema) GetDeclaration(index int) table.Declaration {
-	ith := util.FlatArrayIndexOf_4(index, p.dataColumns, p.permutations, p.vanishing, p.assertions)
-	return ith.(table.Declaration)
-}
-
 // AddDataColumn appends a new data column.
-func (p *Schema) AddDataColumn(name string, base table.Type) {
-	p.dataColumns = append(p.dataColumns, table.NewDataColumn(name, base, false))
+func (p *Schema) AddDataColumn(name string, base schema.Type) {
+	p.inputs = append(p.inputs, assignment.NewDataColumn(name, base))
 }
 
 // AddPermutationColumns introduces a permutation of one or more
@@ -132,173 +59,67 @@ func (p *Schema) AddDataColumn(name string, base table.Type) {
 // source columns. Each source column is associated with a "sign"
 // which indicates the direction of sorting (i.e. ascending versus
 // descending).
-func (p *Schema) AddPermutationColumns(targets []string, signs []bool, sources []string) {
-	p.permutations = append(p.permutations, table.NewSortedPermutation(targets, signs, sources))
+func (p *Schema) AddPermutationColumns(targets []schema.Column, signs []bool, sources []string) {
+	p.assignments = append(p.assignments, assignment.NewSortedPermutation(targets, signs, sources))
 }
 
 // AddVanishingConstraint appends a new vanishing constraint.
 func (p *Schema) AddVanishingConstraint(handle string, domain *int, expr Expr) {
-	p.vanishing = append(p.vanishing, table.NewRowConstraint(handle, domain, table.ZeroTest[Expr]{Expr: expr}))
+	p.constraints = append(p.constraints,
+		constraint.NewVanishingConstraint(handle, domain, constraint.ZeroTest[Expr]{Expr: expr}))
 }
 
-// AddPropertyAssertion appends a new property assertion.
-func (p *Schema) AddPropertyAssertion(handle string, expr Expr) {
-	test := table.ZeroTest[Expr]{Expr: expr}
-	p.assertions = append(p.assertions, table.NewPropertyAssertion(handle, test))
-}
-
-// Accepts determines whether this schema will accept a given trace. That
-// is, whether or not the given trace adheres to the schema. A trace can fail
-// to adhere to the schema for a variety of reasons, such as having a constraint
-// which does not hold.
-func (p *Schema) Accepts(trace table.Trace) error {
-	// Check (typed) data columns
-	if err := table.ConstraintsAcceptTrace(trace, p.dataColumns); err != nil {
-		return err
-	}
-	// Check permutations
-	if err := table.ConstraintsAcceptTrace(trace, p.permutations); err != nil {
-		return err
-	}
-	// Check vanishing constraints
-	if err := table.ConstraintsAcceptTrace(trace, p.vanishing); err != nil {
-		return err
-	}
-	// Check property assertions
-	if err := table.ConstraintsAcceptTrace(trace, p.assertions); err != nil {
-		return err
+// AddTypeConstraint appends a new type constraint.
+func (p *Schema) AddTypeConstraint(target string, t schema.Type) {
+	// Check whether this is a field type, as these can actually be ignored.
+	if t.AsField() == nil {
+		p.constraints = append(p.constraints, constraint.NewTypeConstraint(target, t))
 	}
-	// Done
-	return nil
 }
 
-// LowerToAir lowers (or refines) an MIR table into an AIR table. That means
-// lowering all the columns and constraints, whilst adding additional columns /
-// constraints as necessary to preserve the original semantics.
-func (p *Schema) LowerToAir() *air.Schema {
-	airSchema := air.EmptySchema[Expr]()
-	// Allocate data and permutation columns. This must be done first to ensure
-	// alignment is preserved across lowering.
-	index := uint(0)
+// AddPropertyAssertion appends a new property assertion.
+func (p *Schema) AddPropertyAssertion(handle string, expr Expr) {
+	test := constraint.ZeroTest[Expr]{Expr: expr}
+	p.assertions = append(p.assertions, schema.NewPropertyAssertion(handle, test))
+}
 
-	for i := uint(0); i < p.Width(); i++ {
-		ith := p.ColumnGroup(i)
-		for j := uint(0); j < ith.Width(); j++ {
-			col := ith.NameOf(j)
-			airSchema.AddColumn(col, ith.IsSynthetic())
+// ============================================================================
+// Schema Interface
+// ============================================================================
 
-			index++
-		}
-	}
-	// Add computations. Again this has to be done first for things to work.
-	// Essentially to reflect the fact that these columns have been added above
-	// before others. Realistically, the overall design of this process is a
-	// bit broken right now.
-	for _, perm := range p.permutations {
-		airSchema.AddComputation(perm)
-	}
-	// Lower checked data columns
-	for i, col := range p.dataColumns {
-		lowerColumnToAir(uint(i), col, airSchema)
-	}
-	// Lower permutations columns
-	for _, perm := range p.permutations {
-		lowerPermutationToAir(perm, p, airSchema)
-	}
-	// Lower vanishing constraints
-	for _, c := range p.vanishing {
-		// FIXME: this is broken because its currently
-		// assuming that an AirConstraint is always a
-		// VanishingConstraint. Eventually this will not be
-		// true.
-		air_expr := c.Constraint.Expr.LowerTo(airSchema)
-		airSchema.AddVanishingConstraint(c.Handle, c.Domain, air_expr)
-	}
-	// Done
-	return airSchema
+// Inputs returns an array over the input declarations of this schema. That is,
+// the subset of declarations whose trace values must be provided by the user.
+func (p *Schema) Inputs() util.Iterator[schema.Declaration] {
+	return util.NewArrayIterator(p.inputs)
 }
 
-// Lower a datacolumn to the AIR level. The main effect of this is that, for
-// columns with non-trivial types, we must add appropriate range constraints to
-// the enclosing schema.
-func lowerColumnToAir(index uint, c *table.DataColumn[table.Type], schema *air.Schema) {
-	// Check whether a constraint is implied by the column's type
-	if t := c.Type.AsUint(); t != nil && t.Checked() {
-		// Yes, a constraint is implied. Now, decide whether to use a range
-		// constraint or just a vanishing constraint.
-		if t.HasBound(2) {
-			// u1 => use vanishing constraint X * (X - 1)
-			air_gadgets.ApplyBinaryGadget(index, schema)
-		} else if t.HasBound(256) {
-			// u2..8 use range constraints
-			schema.AddRangeConstraint(index, t.Bound())
-		} else {
-			// u9+ use byte decompositions.
-			air_gadgets.ApplyBitwidthGadget(index, t.BitWidth(), schema)
-		}
-	}
+// Assignments returns an array over the assignments of this schema. That
+// is, the subset of declarations whose trace values can be computed from
+// the inputs.
+func (p *Schema) Assignments() util.Iterator[schema.Assignment] {
+	return util.NewArrayIterator(p.assignments)
 }
 
-// Lower a permutation to the AIR level. This has quite a few
-// effects. Firstly, permutation constraints are added for all of the
-// new columns. Secondly, sorting constraints (and their associated
-// synthetic columns) must also be added.
-// Finally, a trace computation is required to ensure traces are correctly
-// expanded to meet the requirements of a sorted permutation.
-func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Schema) {
-	ncols := len(c.Targets)
-	//
-	targets := make([]uint, ncols)
-	sources := make([]uint, ncols)
-	// Add individual permutation constraints
-	for i := 0; i < ncols; i++ {
-		var ok1, ok2 bool
-		// TODO: REPLACE
-		sources[i], ok1 = airSchema.ColumnIndex(c.Sources[i])
-		targets[i], ok2 = airSchema.ColumnIndex(c.Targets[i])
-
-		if !ok1 || !ok2 {
-			panic("missing column")
-		}
-	}
+// Columns returns an array over the underlying columns of this schema.
+// Specifically, the index of a column in this array is its column index.
+func (p *Schema) Columns() util.Iterator[schema.Column] {
+	is := util.NewFlattenIterator[schema.Declaration, schema.Column](p.Inputs(),
+		func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() })
+	ps := util.NewFlattenIterator[schema.Assignment, schema.Column](p.Assignments(),
+		func(d schema.Assignment) util.Iterator[schema.Column] { return d.Columns() })
 	//
-	airSchema.AddPermutationConstraint(targets, sources)
-	// Add sorting constraints + synthetic columns as necessary.
-	if ncols == 1 {
-		// For a single column sort, its actually a bit easier because we don't
-		// need to implement a multiplexor (i.e. to determine which column is
-		// differs, etc). Instead, we just need a delta column which ensures
-		// there is a non-negative difference between consecutive rows. This
-		// also requires bitwidth constraints.
-		bitwidth := mirSchema.GetColumnByName(c.Sources[0]).Type.AsUint().BitWidth()
-		// Add column sorting constraints
-		air_gadgets.ApplyColumnSortGadget(targets[0], c.Signs[0], bitwidth, airSchema)
-	} else {
-		// For a multi column sort, its a bit harder as we need additional
-		// logicl to ensure the target columns are lexicographally sorted.
-		bitwidth := uint(0)
-
-		for i := 0; i < ncols; i++ {
-			// Extract bitwidth of ith column
-			ith := mirSchema.GetColumnByName(c.Sources[i]).Type.AsUint().BitWidth()
-			if ith > bitwidth {
-				bitwidth = ith
-			}
-		}
-		// Add lexicographically sorted constraints
-		air_gadgets.ApplyLexicographicSortingGadget(targets, c.Signs, bitwidth, airSchema)
-	}
+	return is.Append(ps)
 }
 
-// ExpandTrace expands a given trace according to this schema.
-func (p *Schema) ExpandTrace(tr table.Trace) error {
-	// Expand all the permutation columns
-	for _, perm := range p.permutations {
-		err := perm.ExpandTrace(tr)
-		if err != nil {
-			return err
-		}
-	}
+// Constraints returns an array over the underlying constraints of this
+// schema.
+func (p *Schema) Constraints() util.Iterator[schema.Constraint] {
+	return util.NewArrayIterator(p.constraints)
+}
 
-	return nil
+// Declarations returns an array over the column declarations of this
+// schema.
+func (p *Schema) Declarations() util.Iterator[schema.Declaration] {
+	ps := util.NewCastIterator[schema.Assignment, schema.Declaration](p.Assignments())
+	return p.Inputs().Append(ps)
 }
diff --git a/pkg/schema/alignment.go b/pkg/schema/alignment.go
new file mode 100644
index 0000000..72ceb50
--- /dev/null
+++ b/pkg/schema/alignment.go
@@ -0,0 +1,81 @@
+package schema
+
+import (
+	"fmt"
+
+	tr "github.com/consensys/go-corset/pkg/trace"
+)
+
+// AlignInputs attempts to align this trace with the input columns of a given
+// schema. This means ensuring the order of columns in this trace matches the
+// order of input columns in the schema.
+// Thus, column indexes used by constraints in the schema can be used directly
+// to access this trace (i.e. without name lookup). Alignment can fail,
+// however, if there is a mismatch between columns in the trace and those
+// expected by the schema.
+func AlignInputs(p tr.Trace, schema Schema) error {
+	return alignWith(false, p, schema)
+}
+
+// Align attempts to align this trace with a given schema. This means ensuring
+// the order of columns in this trace matches the order in the schema. Thus,
+// column indexes used by constraints in the schema can be used directly to
+// access this trace (i.e. without name lookup). Alignment can fail, however,
+// if there is a mismatch between columns in the trace and those expected by
+// the schema.
+func Align(p tr.Trace, schema Schema) error {
+	return alignWith(true, p, schema)
+}
+
+// Alignment algorithm which operates either in unexpanded or expanded mode. In
+// expanded mode, all columns must be accounted for and will be aligned. In
+// unexpanded mode, the trace is only expected to contain input (i.e.
+// non-computed) columns. Furthermore, in the schema these are expected to be
+// allocated before computed columns. As such, alignment of these input
+// columns is performed.
+func alignWith(expand bool, p tr.Trace, schema Schema) error {
+	ncols := p.Width()
+	index := uint(0)
+	// Check that each column described in this schema is present in the trace.
+	for i := schema.Declarations(); i.HasNext(); {
+		ith := i.Next()
+		if expand || !ith.IsComputed() {
+			for j := ith.Columns(); j.HasNext(); {
+				jth := j.Next()
+				// Determine column name
+				schemaName := jth.Name()
+				// Sanity check column exists
+				if index >= ncols {
+					return fmt.Errorf("trace missing column %s", schemaName)
+				}
+
+				traceName := p.ColumnByIndex(index).Name()
+				// Check alignment
+				if traceName != schemaName {
+					// Not aligned --- so fix
+					k, ok := p.ColumnIndex(schemaName)
+					// check exists
+					if !ok {
+						return fmt.Errorf("trace missing column %s", schemaName)
+					}
+					// Swap columns
+					p.Swap(index, k)
+				}
+				// Continue
+				index++
+			}
+		}
+	}
+	// Check whether all columns matched
+	if index == ncols {
+		// Yes, alignment complete.
+		return nil
+	}
+	// Error Case.
+	n := ncols - index
+	unknowns := make([]string, n)
+	// Determine names of unknown columns.
+	for i := index; i < ncols; i++ {
+		unknowns[i-index] = p.ColumnByIndex(i).Name()
+	}
+	//
+	return fmt.Errorf("trace contains unknown columns: %v", unknowns)
+}
diff --git a/pkg/schema/assertion.go b/pkg/schema/assertion.go
new file mode 100644
index 0000000..f76cc29
--- /dev/null
+++ b/pkg/schema/assertion.go
@@ -0,0 +1,57 @@
+package schema
+
+import (
+	"errors"
+	"fmt"
+
+	tr "github.com/consensys/go-corset/pkg/trace"
+)
+
+// PropertyAssertion is similar to a vanishing constraint but is used only for
+// debugging / testing / verification. Unlike vanishing constraints, property
+// assertions do not represent something that the prover can enforce. Rather,
+// they represent properties which are expected to hold for every valid trace.
+// That is, they should be implied by the actual constraints. Thus, whilst the
+// prover cannot enforce such properties, external tools (such as for formal
+// verification) can attempt to ensure they do indeed always hold.
+type PropertyAssertion[T Testable] struct {
+	// A unique identifier for this constraint. This is primarily
+	// useful for debugging.
+	Handle string
+	// The actual assertion itself, namely an expression which
+	// should hold (i.e. vanish) for every row of a trace.
+	// Observe that this can be any function which is computable
+	// on a given trace --- we are not restricted to expressions
+	// which can be arithmetised.
+	Property T
+}
+
+// NewPropertyAssertion constructs a new property assertion!
+func NewPropertyAssertion[T Testable](handle string, property T) *PropertyAssertion[T] {
+	return &PropertyAssertion[T]{handle, property}
+}
+
+// GetHandle returns the handle associated with this constraint.
+//
+//nolint:revive
+func (p *PropertyAssertion[T]) GetHandle() string {
+	return p.Handle
+}
+
+// Accepts checks whether a property assertion holds on every row of a
+// table. If so, return nil otherwise return an error.
+//
+//nolint:revive
+func (p *PropertyAssertion[T]) Accepts(tr tr.Trace) error {
+	for k := uint(0); k < tr.Height(); k++ {
+		// Check whether property holds (or was undefined)
+		if !p.Property.TestAt(int(k), tr) {
+			// Construct useful error message
+			msg := fmt.Sprintf("property assertion %s does not hold (row %d)", p.Handle, k)
+			// Evaluation failure
+			return errors.New(msg)
+		}
+	}
+	// All good
+	return nil
+}
diff --git a/pkg/table/computation.go b/pkg/schema/assignment/byte_decomposition.go
similarity index 57%
rename from pkg/table/computation.go
rename to pkg/schema/assignment/byte_decomposition.go
index 9267406..aaef0cd 100644
--- a/pkg/table/computation.go
+++ b/pkg/schema/assignment/byte_decomposition.go
@@ -1,73 +1,74 @@
-package table
+package assignment
 
 import (
 	"fmt"
 
 	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
+	"github.com/consensys/go-corset/pkg/schema"
+	tr "github.com/consensys/go-corset/pkg/trace"
+	"github.com/consensys/go-corset/pkg/util"
 )
 
-// TraceComputation represents a computation which is applied to a
-// high-level trace in order to expand it to a low-level trace. This
-// typically involves adding columns, evaluating compute-only
-// expressions, sorting columns, etc.
-type TraceComputation interface {
-	Acceptable
-	// ExpandTrace expands a given trace to include "computed
-	// columns". These are columns which do not exist in the
-	// original trace, but are added during trace expansion to
-	// form the final trace.
-	ExpandTrace(Trace) error
-	// RequiredSpillage returns the minimum amount of spillage required to ensure
-	// valid traces are accepted in the presence of arbitrary padding. Note,
-	// spillage is currently assumed to be required only at the front of a
-	// trace.
-	RequiredSpillage() uint
-}
-
 // ByteDecomposition is part of a range constraint for wide columns (e.g. u32)
 // implemented using a byte decomposition.
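+// For example, a u32 column X decomposes into four byte columns X:0 ... X:3
+// (least significant byte first) such that
+//
+//	X == X:0 + 256*X:1 + 65536*X:2 + 16777216*X:3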
 type ByteDecomposition struct {
-	// The target column being decomposed
-	Target string
-	// The bitwidth of the target column
-	BitWidth uint
+	// The source column being decomposed
+	source string
+	// Target columns needed for decomposition
+	targets []schema.Column
 }
 
-// NewByteDecomposition creates a new byte decomposition
+// NewByteDecomposition creates a new byte decomposition
-func NewByteDecomposition(target string, width uint) *ByteDecomposition {
-	if width%8 != 0 {
-		panic("asymetric byte decomposition not yet supported")
-	} else if width == 0 {
+func NewByteDecomposition(source string, width uint) *ByteDecomposition {
+	if width == 0 {
 		panic("zero byte decomposition encountered")
 	}
+	// Define type of bytes
+	U8 := schema.NewUintType(8)
+	// Construct target names
+	targets := make([]schema.Column, width)
 
-	return &ByteDecomposition{target, width}
-}
-
-// Accepts checks whether a given trace has the necessary columns
-func (p *ByteDecomposition) Accepts(tr Trace) error {
-	n := int(p.BitWidth / 8)
-	//
-	for i := 0; i < n; i++ {
-		colName := fmt.Sprintf("%s:%d", p.Target, i)
-		if !tr.HasColumn(colName) {
-			return fmt.Errorf("Trace missing byte decomposition column ({%s})", colName)
-		}
+	for i := uint(0); i < width; i++ {
+		name := fmt.Sprintf("%s:%d", source, i)
+		targets[i] = schema.NewColumn(name, U8)
 	}
 	// Done
-	return nil
+	return &ByteDecomposition{source, targets}
+}
+
+func (p *ByteDecomposition) String() string {
+	return fmt.Sprintf("(decomposition %s %d)", p.source, len(p.targets))
 }
 
+// ============================================================================
+// Declaration Interface
+// ============================================================================
+
+// Columns returns the columns declared by this byte decomposition (in the order
+// of declaration).
+func (p *ByteDecomposition) Columns() util.Iterator[schema.Column] {
+	return util.NewArrayIterator[schema.Column](p.targets)
+}
+
+// IsComputed determines whether or not this declaration is computed.
+func (p *ByteDecomposition) IsComputed() bool {
+	return true
+}
+
+// ============================================================================
+// Assignment Interface
+// ============================================================================
+
 // ExpandTrace expands a given trace to include the columns specified by a given
 // ByteDecomposition. This requires computing the value of each byte column in
 // the decomposition.
-func (p *ByteDecomposition) ExpandTrace(tr Trace) error {
+func (p *ByteDecomposition) ExpandTrace(tr tr.Trace) error {
 	// Calculate how many bytes required.
-	n := int(p.BitWidth / 8)
+	n := len(p.targets)
 	// Identify target column
-	target := tr.ColumnByName(p.Target)
+	target := tr.ColumnByName(p.source)
 	// Extract column data to decompose
-	data := tr.ColumnByName(p.Target).Data()
+	data := tr.ColumnByName(p.source).Data()
 	// Construct byte column data
 	cols := make([][]*fr.Element, n)
 	// Initialise columns
@@ -85,17 +86,13 @@ func (p *ByteDecomposition) ExpandTrace(tr Trace) error {
 	padding := decomposeIntoBytes(target.Padding(), n)
 	// Finally, add byte columns to trace
 	for i := 0; i < n; i++ {
-		col := fmt.Sprintf("%s:%d", p.Target, i)
+		col := fmt.Sprintf("%s:%d", p.source, i)
 		tr.AddColumn(col, cols[i], padding[i])
 	}
 	// Done
 	return nil
 }
 
-func (p *ByteDecomposition) String() string {
-	return fmt.Sprintf("(decomposition %s %d)", p.Target, p.BitWidth)
-}
-
 // RequiredSpillage returns the minimum amount of spillage required to ensure
 // valid traces are accepted in the presence of arbitrary padding.
 func (p *ByteDecomposition) RequiredSpillage() uint {
diff --git a/pkg/schema/assignment/computed.go b/pkg/schema/assignment/computed.go
new file mode 100644
index 0000000..26dbc0c
--- /dev/null
+++ b/pkg/schema/assignment/computed.go
@@ -0,0 +1,102 @@
+package assignment
+
+import (
+	"fmt"
+
+	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
+	"github.com/consensys/go-corset/pkg/schema"
+	sc "github.com/consensys/go-corset/pkg/schema"
+	tr "github.com/consensys/go-corset/pkg/trace"
+	"github.com/consensys/go-corset/pkg/util"
+)
+
+// ComputedColumn describes a column whose values are computed on-demand, rather
+// than being stored in a data array. Typically computed columns read values
+// from other columns in a trace in order to calculate their value. There is an
+// expectation that this computation is acyclic. Furthermore, computed columns
+// give rise to "trace expansion". That is where the initial trace provided by
+// the user is expanded by determining the value of all computed columns.
+type ComputedColumn[E sc.Evaluable] struct {
+	name string
+	// The computation which accepts a given trace and computes
+	// the value of this column at a given row.
+	expr E
+}
+
+// NewComputedColumn constructs a new computed column with a given name and
+// determining expression. More specifically, that expression is used to
+// compute the values for this column during trace expansion.
+func NewComputedColumn[E sc.Evaluable](name string, expr E) *ComputedColumn[E] {
+	return &ComputedColumn[E]{name, expr}
+}
+
+// nolint:revive
+func (p *ComputedColumn[E]) String() string {
+	return fmt.Sprintf("(compute %s %s)", p.name, any(p.expr))
+}
+
+// Name returns the name of this computed column.
+func (p *ComputedColumn[E]) Name() string {
+	return p.name
+}
+
+// ============================================================================
+// Declaration Interface
+// ============================================================================
+
+// Columns returns the columns declared by this computed column.
+func (p *ComputedColumn[E]) Columns() util.Iterator[schema.Column] {
+	// TODO: figure out appropriate type for computed column
+	column := schema.NewColumn(p.name, &schema.FieldType{})
+	return util.NewUnitIterator[schema.Column](column)
+}
+
+// IsComputed determines whether or not this declaration is computed (which it
+// is).
+func (p *ComputedColumn[E]) IsComputed() bool {
+	return true
+}
+
+// ============================================================================
+// Assignment Interface
+// ============================================================================
+
+// RequiredSpillage returns the minimum amount of spillage required to ensure
+// this column can be correctly computed in the presence of arbitrary (front)
+// padding.
+func (p *ComputedColumn[E]) RequiredSpillage() uint {
+	// NOTE: Spillage is only currently considered to be necessary at the front
+	// (i.e. start) of a trace. This is because padding is always inserted at
+	// the front, never the back. As such, it is the maximum positive shift
+	// which determines how much spillage is required for this computation.
+	return p.expr.Bounds().End
+}
+
+// ExpandTrace attempts to add a new column to the trace which contains the
+// result of evaluating a given expression on each row. If the column already
+// exists, then an error is flagged.
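+// For example (a sketch, assuming some Evaluable expression e over columns
+// already present in the trace tr; the column name "X:c" is illustrative
+// only):
+//
+//	col := NewComputedColumn("X:c", e)
+//	err := col.ExpandTrace(tr) // adds a new column "X:c" to tr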
+func (p *ComputedColumn[E]) ExpandTrace(tr tr.Trace) error {
+	if tr.HasColumn(p.name) {
+		return fmt.Errorf("Computed column already exists ({%s})", p.name)
+	}
+
+	data := make([]*fr.Element, tr.Height())
+	// Expand the trace
+	for i := 0; i < len(data); i++ {
+		val := p.expr.EvalAt(i, tr)
+		if val != nil {
+			data[i] = val
+		} else {
+			zero := fr.NewElement(0)
+			data[i] = &zero
+		}
+	}
+	// Determine padding value. A negative row index is used here to ensure
+	// that all columns return their padding value which is then used to compute
+	// the padding value for *this* column.
+	padding := p.expr.EvalAt(-1, tr)
+	// Column needs to be expanded.
+	tr.AddColumn(p.name, data, padding)
+	// Done
+	return nil
+}
diff --git a/pkg/schema/assignment/data_column.go b/pkg/schema/assignment/data_column.go
new file mode 100644
index 0000000..118db5c
--- /dev/null
+++ b/pkg/schema/assignment/data_column.go
@@ -0,0 +1,58 @@
+package assignment
+
+import (
+	"fmt"
+
+	"github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/util"
+)
+
+// DataColumn represents a column of user-provided values.
+type DataColumn struct {
+	name string
+	// Expected type of values held in this column. Observe that this should be
+	// true for the input columns for any valid trace and, furthermore, every
+	// computed column should have values of this type.
+	datatype schema.Type
+}
+
+// NewDataColumn constructs a new data column with a given name.
+func NewDataColumn(name string, base schema.Type) *DataColumn {
+	return &DataColumn{name, base}
+}
+
+// Name forms part of the ColumnSchema interface, and provides access to
+// information about the ith column in a schema.
+func (p *DataColumn) Name() string {
+	return p.name
+}
+
+// Type returns the expected type of data in this column
+func (p *DataColumn) Type() schema.Type {
+	return p.datatype
+}
+
+//nolint:revive
+func (c *DataColumn) String() string {
+	if c.datatype.AsField() != nil {
+		return fmt.Sprintf("(column %s)", c.Name())
+	}
+
+	return fmt.Sprintf("(column %s :%s)", c.Name(), c.datatype)
+}
+
+// ============================================================================
+// Declaration Interface
+// ============================================================================
+
+// Columns returns the columns declared by this data column.
+func (p *DataColumn) Columns() util.Iterator[schema.Column] {
+	column := schema.NewColumn(p.name, p.datatype)
+	return util.NewUnitIterator[schema.Column](column)
+}
+
+// IsComputed determines whether or not this declaration is computed (which data
+// columns never are).
+func (p *DataColumn) IsComputed() bool {
+	return false
+}
diff --git a/pkg/schema/assignment/lexicographic_sort.go b/pkg/schema/assignment/lexicographic_sort.go
new file mode 100644
index 0000000..218df22
--- /dev/null
+++ b/pkg/schema/assignment/lexicographic_sort.go
@@ -0,0 +1,125 @@
+package assignment
+
+import (
+	"fmt"
+
+	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
+	"github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/trace"
+	"github.com/consensys/go-corset/pkg/util"
+)
+
+// LexicographicSort provides the necessary computation for filling out columns
+// added to enforce lexicographic sorting constraints between one or more source
+// columns. Specifically, a delta column is required along with one selector
+// column (binary) for each source column.
+type LexicographicSort struct {
+	// The target columns to be filled.
+	// The first entry is for the delta column, and the remaining n entries
+	// are for the selector columns.
+	targets []schema.Column
+	// Source columns being sorted
+	sources  []uint
+	signs    []bool
+	bitwidth uint
+}
+
+// NewLexicographicSort constructs a new lexicographic sorting assignment.
+func NewLexicographicSort(prefix string, sources []uint, signs []bool, bitwidth uint) *LexicographicSort {
+	targets := make([]schema.Column, len(sources)+1)
+	// Create delta column
+	targets[0] = schema.NewColumn(fmt.Sprintf("%s:delta", prefix), schema.NewUintType(bitwidth))
+	// Create selector columns
+	for i := range sources {
+		ithName := fmt.Sprintf("%s:%d", prefix, i)
+		targets[1+i] = schema.NewColumn(ithName, schema.NewUintType(1))
+	}
+	// Done
+	return &LexicographicSort{targets, sources, signs, bitwidth}
+}
+
+// ============================================================================
+// Declaration Interface
+// ============================================================================
+
+// Columns returns the columns declared by this assignment.
+func (p *LexicographicSort) Columns() util.Iterator[schema.Column] {
+	return util.NewArrayIterator(p.targets)
+}
+
+// IsComputed determines whether or not this declaration is computed (which it
+// is).
+func (p *LexicographicSort) IsComputed() bool {
+	return true
+}
+
+// ============================================================================
+// Assignment Interface
+// ============================================================================
+
+// RequiredSpillage returns the minimum amount of spillage required to ensure
+// valid traces are accepted in the presence of arbitrary padding.
+func (p *LexicographicSort) RequiredSpillage() uint {
+	return uint(0)
+}
+
+// ExpandTrace adds columns as needed to support the LexicographicSortingGadget.
+// That includes the delta column, and the bit selectors.
+func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {
+	zero := fr.NewElement(0)
+	one := fr.NewElement(1)
+	// Exact number of columns involved in the sort
+	ncols := len(p.sources)
+	// Determine how many rows to be constrained.
+	nrows := tr.Height()
+	// Initialise new data columns
+	delta := make([]*fr.Element, nrows)
+	bit := make([][]*fr.Element, ncols)
+
+	for i := 0; i < ncols; i++ {
+		bit[i] = make([]*fr.Element, nrows)
+	}
+
+	for i := 0; i < int(nrows); i++ {
+		set := false
+		// Initialise delta to zero
+		delta[i] = &zero
+		// Decide which column is the winner (if any)
+		for j := 0; j < ncols; j++ {
+			prev := tr.ColumnByIndex(p.sources[j]).Get(i - 1)
+			curr := tr.ColumnByIndex(p.sources[j]).Get(i)
+
+			if !set && prev != nil && prev.Cmp(curr) != 0 {
+				var diff fr.Element
+
+				bit[j][i] = &one
+				// Compute curr - prev
+				if p.signs[j] {
+					diff.Set(curr)
+					delta[i] = diff.Sub(&diff, prev)
+				} else {
+					diff.Set(prev)
+					delta[i] = diff.Sub(&diff, curr)
+				}
+
+				set = true
+			} else {
+				bit[j][i] = &zero
+			}
+		}
+	}
+	// Add delta column data
+	tr.AddColumn(p.targets[0].Name(), delta, &zero)
+	// Add bit column data
+	for i := 0; i < ncols; i++ {
+		bitName := p.targets[1+i].Name()
+		tr.AddColumn(bitName, bit[i], &zero)
+	}
+	// Done.
+	return nil
+}
+
+// String returns a string representation of this assignment. This is primarily
+// used for debugging.
+func (p *LexicographicSort) String() string {
+	return fmt.Sprintf("(lexer (%v) (%v) :%d)", any(p.targets), p.signs, p.bitwidth)
+}
diff --git a/pkg/schema/assignment/sorted_permutation.go b/pkg/schema/assignment/sorted_permutation.go
new file mode 100644
index 0000000..e973f6f
--- /dev/null
+++ b/pkg/schema/assignment/sorted_permutation.go
@@ -0,0 +1,130 @@
+package assignment
+
+import (
+	"fmt"
+
+	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
+	"github.com/consensys/go-corset/pkg/schema"
+	tr "github.com/consensys/go-corset/pkg/trace"
+	"github.com/consensys/go-corset/pkg/util"
+)
+
+// SortedPermutation declares one or more columns as sorted permutations of
+// existing columns.
+type SortedPermutation struct {
+	// The new (sorted) columns
+	targets []schema.Column
+	// The sorting criteria
+	Signs []bool
+	// The existing columns
+	Sources []string
+}
+
+// NewSortedPermutation creates a new sorted permutation
+func NewSortedPermutation(targets []schema.Column, signs []bool, sources []string) *SortedPermutation {
+	if len(targets) != len(signs) || len(signs) != len(sources) {
+		panic("target and source column widths must match")
+	}
+
+	return &SortedPermutation{targets, signs, sources}
+}
+
+// Targets returns the columns declared by this sorted permutation (in the order
+// of declaration). This is the same as Columns(), except that it avoids using
+// an iterator.
+func (p *SortedPermutation) Targets() []schema.Column {
+	return p.targets
+}
+
+// String returns a string representation of this assignment. This is primarily
+// used for debugging.
+func (p *SortedPermutation) String() string {
+	targets := ""
+	sources := ""
+
+	index := 0
+	for i := 0; i != len(p.targets); i++ {
+		if index != 0 {
+			targets += " "
+		}
+
+		targets += p.targets[i].Name()
+		index++
+	}
+
+	for i, s := range p.Sources {
+		if i != 0 {
+			sources += " "
+		}
+
+		if p.Signs[i] {
+			sources += fmt.Sprintf("+%s", s)
+		} else {
+			sources += fmt.Sprintf("-%s", s)
+		}
+	}
+
+	return fmt.Sprintf("(permute (%s) (%s))", targets, sources)
+}
+
+// ============================================================================
+// Declaration Interface
+// ============================================================================
+
+// Columns returns the columns declared by this sorted permutation (in the order
+// of declaration).
+func (p *SortedPermutation) Columns() util.Iterator[schema.Column] {
+	return util.NewArrayIterator(p.targets)
+}
+
+// IsComputed determines whether or not this declaration is computed.
+func (p *SortedPermutation) IsComputed() bool {
+	return true
+}
+
+// ============================================================================
+// Assignment Interface
+// ============================================================================
+
+// RequiredSpillage returns the minimum amount of spillage required to ensure
+// valid traces are accepted in the presence of arbitrary padding.
+func (p *SortedPermutation) RequiredSpillage() uint {
+	return uint(0)
+}
+
+// ExpandTrace expands a given trace to include the columns specified by a given
+// SortedPermutation. This requires copying the data in the source columns, and
+// sorting that data according to the permutation criteria.
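+// For example, a source column containing [3, 1, 2] and sorted in ascending
+// order (i.e. with a positive sign) yields a target column containing
+// [1, 2, 3].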
+func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {
+	// Ensure target columns don't exist
+	for i := p.Columns(); i.HasNext(); {
+		if tr.HasColumn(i.Next().Name()) {
+			panic("target column already exists")
+		}
+	}
+
+	cols := make([][]*fr.Element, len(p.Sources))
+	// Construct target columns
+	for i := 0; i < len(p.Sources); i++ {
+		src := p.Sources[i]
+		// Read column data to initialise permutation.
+		data := tr.ColumnByName(src).Data()
+		// Copy column data to initialise permutation.
+		cols[i] = make([]*fr.Element, len(data))
+		copy(cols[i], data)
+	}
+	// Sort target columns
+	util.PermutationSort(cols, p.Signs)
+	// Physically add the columns
+	index := 0
+
+	for i := p.Columns(); i.HasNext(); {
+		dstColName := i.Next().Name()
+		srcCol := tr.ColumnByName(p.Sources[index])
+		tr.AddColumn(dstColName, cols[index], srcCol.Padding())
+
+		index++
+	}
+	//
+	return nil
+}
diff --git a/pkg/schema/constraint/permutation.go b/pkg/schema/constraint/permutation.go
new file mode 100644
index 0000000..b60b990
--- /dev/null
+++ b/pkg/schema/constraint/permutation.go
@@ -0,0 +1,85 @@
+package constraint
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
+	"github.com/consensys/go-corset/pkg/trace"
+	"github.com/consensys/go-corset/pkg/util"
+)
+
+// PermutationConstraint declares a constraint that one set of columns is a
+// permutation of another.
+type PermutationConstraint struct {
+	// The target columns
+	Targets []uint
+	// The source columns
+	Sources []uint
+}
+
+// NewPermutationConstraint creates a new permutation
+func NewPermutationConstraint(targets []uint, sources []uint) *PermutationConstraint {
+	if len(targets) != len(sources) {
+		panic("differing number of target / source permutation columns")
+	}
+
+	return &PermutationConstraint{targets, sources}
+}
+
+// RequiredSpillage returns the minimum amount of spillage required to ensure
+// valid traces are accepted in the presence of arbitrary padding.
+func (p *PermutationConstraint) RequiredSpillage() uint {
+	return uint(0)
+}
+
+// Accepts checks whether a permutation holds between the source and
+// target columns.
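+// For example, a target column containing [2, 1, 3] is accepted against a
+// source column containing [1, 2, 3], whereas one containing [1, 1, 3] is not.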
+func (p *PermutationConstraint) Accepts(tr trace.Trace) error {
+	// Slice out data
+	src := sliceColumns(p.Sources, tr)
+	dst := sliceColumns(p.Targets, tr)
+	// Check that target columns are a permutation of source columns
+	if !util.ArePermutationOf(dst, src) {
+		msg := fmt.Sprintf("Target columns (%v) not a permutation of source columns ({%v})",
+			p.Targets, p.Sources)
+		return errors.New(msg)
+	}
+	// Success
+	return nil
+}
+
+func (p *PermutationConstraint) String() string {
+	targets := ""
+	sources := ""
+
+	for i, s := range p.Targets {
+		if i != 0 {
+			targets += " "
+		}
+
+		targets += fmt.Sprintf("%d", s)
+	}
+
+	for i, s := range p.Sources {
+		if i != 0 {
+			sources += " "
+		}
+
+		sources += fmt.Sprintf("%d", s)
+	}
+
+	return fmt.Sprintf("(permutation (%s) (%s))", targets, sources)
+}
+
+func sliceColumns(columns []uint, tr trace.Trace) [][]*fr.Element {
+	// Allocate return array
+	cols := make([][]*fr.Element, len(columns))
+	// Slice out the data
+	for i, n := range columns {
+		nth := tr.ColumnByIndex(n)
+		cols[i] = nth.Data()
+	}
+	// Done
+	return cols
+}
diff --git a/pkg/schema/constraint/range.go b/pkg/schema/constraint/range.go
new file mode 100644
index 0000000..e5d8455
--- /dev/null
+++ b/pkg/schema/constraint/range.go
@@ -0,0 +1,59 @@
+package constraint
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
+	"github.com/consensys/go-corset/pkg/trace"
+)
+
+// RangeConstraint restricts all values in a given column to be within
+// a range [0..n) for some bound n. For example, a bound of 256 would
+// restrict all values to be bytes. At this time, range constraints
+// are explicitly limited at the arithmetic level to bounds of at most
+// 256 (i.e. to ensuring bytes). This restriction is somewhat
+// arbitrary and is determined by the underlying prover.
+type RangeConstraint struct {
+	// Column index to be constrained.
+	Column uint
+	// The upper bound on all values in this column. NOTE: an fr.Element is
+	// used here to store the bound simply to make the necessary comparison
+	// against table data more direct.
+	Bound *fr.Element
+}
+
+// NewRangeConstraint constructs a new Range constraint!
+func NewRangeConstraint(column uint, bound *fr.Element) *RangeConstraint {
+	var n fr.Element = fr.NewElement(256)
+	if bound.Cmp(&n) > 0 {
+		panic("Range constraint for bitwidth above 8 not supported")
+	}
+
+	return &RangeConstraint{column, bound}
+}
+
+// Accepts checks whether all values in the given column are within the given
+// bound on every row of a table. If so, return nil otherwise return an error.
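+// For example, with a bound of 256, a column containing the value 255 is
+// accepted whilst one containing the value 256 is not.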
+func (p *RangeConstraint) Accepts(tr trace.Trace) error {
+	column := tr.ColumnByIndex(p.Column)
+	for k := 0; k < int(tr.Height()); k++ {
+		// Get the value on the kth row
+		kth := column.Get(k)
+		// Perform the bounds check
+		if kth != nil && kth.Cmp(p.Bound) >= 0 {
+			name := column.Name()
+			// Construct useful error message
+			msg := fmt.Sprintf("value out-of-bounds (row %d, %s)", k, name)
+			// Evaluation failure
+			return errors.New(msg)
+		}
+	}
+	// All good
+	return nil
+}
+
+func (p *RangeConstraint) String() string {
+	return fmt.Sprintf("(range #%d %s)", p.Column, p.Bound)
+}
diff --git a/pkg/schema/constraint/type.go b/pkg/schema/constraint/type.go
new file mode 100644
index 0000000..857fe36
--- /dev/null
+++ b/pkg/schema/constraint/type.go
@@ -0,0 +1,61 @@
+package constraint
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/trace"
+)
+
+// TypeConstraint restricts all values in a given column to be within
+// a range [0..n) for some bound n. Any bound is supported, and the system will
+// choose the best underlying implementation as needed.
+type TypeConstraint struct {
+	// Column to be constrained.
+	column string
+	// The expected type of all values in this column.
+	expected schema.Type
+}
+
+// NewTypeConstraint constructs a new type constraint!
+func NewTypeConstraint(column string, expected schema.Type) *TypeConstraint {
+	return &TypeConstraint{column, expected}
+}
+
+// Target returns the target column for this constraint.
+func (p *TypeConstraint) Target() string {
+	return p.column
+}
+
+// Type returns the expected type for all values in the target column.
+func (p *TypeConstraint) Type() schema.Type {
+	return p.expected
+}
+
+// Accepts checks whether all values in the given column match the expected
+// type on every row of a table. If so, return nil otherwise return an error.
+func (p *TypeConstraint) Accepts(tr trace.Trace) error {
+	column := tr.ColumnByName(p.column)
+	for k := 0; k < int(tr.Height()); k++ {
+		// Get the value on the kth row
+		kth := column.Get(k)
+		// Perform the type check
+		if kth != nil && !p.expected.Accept(kth) {
+			name := column.Name()
+			// Construct useful error message
+			msg := fmt.Sprintf("value out-of-bounds (row %d, %s)", k, name)
+			// Evaluation failure
+			return errors.New(msg)
+		}
+	}
+	// All good
+	return nil
+}
+
+func (p *TypeConstraint) String() string {
+	return fmt.Sprintf("(type %s %s)", p.column, p.expected.String())
+}
diff --git a/pkg/schema/constraint/vanishing.go b/pkg/schema/constraint/vanishing.go
new file mode 100644
index 0000000..ec7be35
--- /dev/null
+++ b/pkg/schema/constraint/vanishing.go
@@ -0,0 +1,127 @@
+package constraint
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/consensys/go-corset/pkg/schema"
+	"github.com/consensys/go-corset/pkg/trace"
+	"github.com/consensys/go-corset/pkg/util"
+)
+
+// ZeroTest is a wrapper which converts an Evaluable expression into a Testable
+// constraint. Specifically, by checking whether or not the given expression
+// vanishes (i.e. evaluates to zero).
+type ZeroTest[E schema.Evaluable] struct {
+	Expr E
+}
+
+// TestAt determines whether or not a given expression evaluates to zero.
+// Observe that if the expression is undefined, then it is assumed not to hold.
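+// For example, the expression X - Y tests whether columns X and Y hold the
+// same value on a given row; if either access is out-of-bounds (and hence
+// evaluates to nil) then the test fails.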
+func (p ZeroTest[E]) TestAt(row int, tr trace.Trace) bool { + val := p.Expr.EvalAt(row, tr) + return val != nil && val.IsZero() +} + +// Bounds determines the bounds for this zero test. +func (p ZeroTest[E]) Bounds() util.Bounds { + return p.Expr.Bounds() +} + +// String generates a human-readable string. +// +//nolint:revive +func (p ZeroTest[E]) String() string { + return fmt.Sprintf("%s", any(p.Expr)) +} + +// VanishingConstraint specifies a constraint which should hold on every row of the +// table. The only exception is when the constraint is undefined (e.g. because +// it references a non-existent table cell). In such a case, the constraint is +// ignored. This is parameterised by the type of the constraint expression. +// Thus, we can reuse this definition across the various intermediate +// representations (e.g. Mid-Level IR, Arithmetic IR, etc). +type VanishingConstraint[T schema.Testable] struct { + // A unique identifier for this constraint. This is primarily + // useful for debugging. + Handle string + // Indicates (when nil) a global constraint that applies to all rows. + // Otherwise, indicates a local constraint which applies to the specific row + // given here. + Domain *int + // The actual constraint itself (e.g. an expression which + // should evaluate to zero, etc) + Constraint T +} + +// NewVanishingConstraint constructs a new vanishing constraint! +func NewVanishingConstraint[T schema.Testable](handle string, domain *int, constraint T) *VanishingConstraint[T] { + return &VanishingConstraint[T]{handle, domain, constraint} +} + +// GetHandle returns the handle associated with this constraint. +func (p *VanishingConstraint[T]) GetHandle() string { + return p.Handle +} + +// Accepts checks whether a vanishing constraint evaluates to zero on every row +// of a table. If so, return nil otherwise return an error. +// +//nolint:revive +func (p *VanishingConstraint[T]) Accepts(tr trace.Trace) error { + if p.Domain == nil { + // Global Constraint + return HoldsGlobally(p.Handle, p.Constraint, tr) + } + // Check specific row + return HoldsLocally(*p.Domain, p.Handle, p.Constraint, tr) +} + +// HoldsGlobally checks whether a given expression vanishes (i.e. evaluates to +// zero) for all rows of a trace. If not, report an appropriate error. +func HoldsGlobally[T schema.Testable](handle string, constraint T, tr trace.Trace) error { + // Determine well-definedness bounds for this constraint + bounds := constraint.Bounds() + // Sanity check enough rows + if bounds.End < tr.Height() { + // Check all in-bounds values + for k := bounds.Start; k < (tr.Height() - bounds.End); k++ { + if err := HoldsLocally(int(k), handle, constraint, tr); err != nil { + return err + } + } + } + // Success + return nil +} + +// HoldsLocally checks whether a given constraint holds (e.g. vanishes) on a +// specific row of a trace. If not, report an appropriate error. +func HoldsLocally[T schema.Testable](k int, handle string, constraint T, tr trace.Trace) error { + // Negative rows calculated from end of trace. + if k < 0 { + k += int(tr.Height()) + } + // Check whether it holds or not + if !constraint.TestAt(k, tr) { + // Construct useful error message + msg := fmt.Sprintf("constraint \"%s\" does not hold (row %d)", handle, k) + // Evaluation failure + return errors.New(msg) + } + // Success + return nil +} + +// String generates a human-readable string.
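+// For example, a global constraint prints as "(vanish h c)", whilst a local +// constraint on the first row (Domain == 0) prints as "(vanish:first h c)" and +// one on the last row (e.g. Domain == -1) as "(vanish:last h c)".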
+// +//nolint:revive +func (p *VanishingConstraint[T]) String() string { + if p.Domain == nil { + return fmt.Sprintf("(vanish %s %s)", p.Handle, any(p.Constraint)) + } else if *p.Domain == 0 { + return fmt.Sprintf("(vanish:first %s %s)", p.Handle, any(p.Constraint)) + } + // + return fmt.Sprintf("(vanish:last %s %s)", p.Handle, any(p.Constraint)) +} diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go new file mode 100644 index 0000000..12fa7ca --- /dev/null +++ b/pkg/schema/schema.go @@ -0,0 +1,129 @@ +package schema + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + tr "github.com/consensys/go-corset/pkg/trace" + "github.com/consensys/go-corset/pkg/util" +) + +// Schema represents a schema which can be used to manipulate a trace. +type Schema interface { + // Assignments returns an iterator over the assignments of this schema. That + // is, the subset of declarations whose trace values can be computed from + // the inputs. + Assignments() util.Iterator[Assignment] + + // Columns returns an iterator over the underlying columns of this schema. + // Specifically, the index of a column in this iteration is its column index. + Columns() util.Iterator[Column] + + // Constraints returns an iterator over the underlying constraints of this + // schema. + Constraints() util.Iterator[Constraint] + + // Declarations returns an iterator over the column declarations of this + // schema. + Declarations() util.Iterator[Declaration] +} + +// Assignment represents a schema element which declares one or more columns +// whose values are "assigned" from the results of a computation. An assignment +// is a column group which, additionally, can provide information about the +// computation (e.g. which columns it depends upon, etc). +type Assignment interface { + Declaration + + // ExpandTrace expands a given trace to include "computed + // columns". These are columns which do not exist in the + // original trace, but are added during trace expansion to + // form the final trace. + ExpandTrace(tr.Trace) error + // RequiredSpillage returns the minimum amount of spillage required to ensure + // valid traces are accepted in the presence of arbitrary padding. Note, + // spillage is currently assumed to be required only at the front of a + // trace. + RequiredSpillage() uint +} + +// Constraint represents an element which can "accept" a trace, or reject +// it with an error (or, eventually, perhaps a warning). +type Constraint interface { + Accepts(tr.Trace) error +} + +// Declaration represents an element which declares one (or more) columns in the +// schema. For example, a single data column is (for now) always a column group +// of size 1. Likewise, an array of size n is a column group of size n, etc. +type Declaration interface { + // Return the declared columns (in the order of declaration). + Columns() util.Iterator[Column] + + // Determines whether or not this declaration is computed. + IsComputed() bool +} + +// Evaluable captures something which can be evaluated on a given table row to +// produce an evaluation point. For example, expressions in the +// Mid-Level or Arithmetic-Level IR can all be evaluated at rows of a +// table. +type Evaluable interface { + util.Boundable + // EvalAt evaluates this expression in a given tabular context. + // Observe that if this expression is *undefined* within this + // context then it returns "nil". An expression can be + // undefined for several reasons: firstly, if it accesses a + // row which does not exist (e.g.
at index -1); secondly, if + it accesses a column which does not exist. + EvalAt(int, tr.Trace) *fr.Element +} + +// Testable captures the notion of a constraint which can be tested on a given +// row of a given trace. It is very similar to Evaluable, except that it only +// indicates success or failure. The reason for using this interface over +// Evaluable is that, for historical reasons, constraints at the HIR cannot be +// Evaluable (i.e. because they return multiple values, rather than a single +// value). However, constraints at the HIR level remain testable. +type Testable interface { + util.Boundable + + // TestAt evaluates this expression in a given tabular context and checks it + // against zero. Observe that if this expression is *undefined* within this + // context then the test fails (i.e. returns false). An expression can be undefined for + // several reasons: firstly, if it accesses a row which does not exist (e.g. + // at index -1); secondly, if it accesses a column which does not exist. + TestAt(int, tr.Trace) bool +} + +// ============================================================================ +// Column +// ============================================================================ + +// Column represents a specific column in the schema that, ultimately, will +// correspond 1:1 with a column in the trace. +type Column struct { + // Returns the name of this column + name string + // Returns the expected type of data in this column + datatype Type +} + +// NewColumn constructs a new column +func NewColumn(name string, datatype Type) Column { + return Column{name, datatype} +} + +// Name returns the name of this column +func (p Column) Name() string { + return p.name +} + +// Type returns the expected type of data in this column +func (p Column) Type() Type { + return p.datatype +} + +func (p Column) String() string { + return fmt.Sprintf("%s:%s", p.name, p.datatype.String()) +} diff --git a/pkg/schema/schemas.go b/pkg/schema/schemas.go new file mode 100644 index 0000000..8c3c95c --- /dev/null +++ b/pkg/schema/schemas.go @@ -0,0 +1,94 @@ +package schema + +import ( + "fmt" + + tr "github.com/consensys/go-corset/pkg/trace" +) + +// RequiredSpillage returns the minimum amount of spillage required to ensure +// valid traces are accepted in the presence of arbitrary padding. Spillage can +// only arise from computations as this is where values outside of the user's +// control are determined. +func RequiredSpillage(schema Schema) uint { + // Ensures always at least one row of spillage (referred to as the "initial + // padding row") + mx := uint(1) + // Determine if any more spillage required + for i := schema.Assignments(); i.HasNext(); { + // Get ith assignment + ith := i.Next() + // Incorporate its spillage requirements + mx = max(mx, ith.RequiredSpillage()) + } + + return mx +} + +// ExpandTrace expands a given trace according to this schema. More +// specifically, that means computing the actual values for any assignments. +// Observe that assignments have to be computed in the correct order. +func ExpandTrace(schema Schema, trace tr.Trace) error { + // Compute each assignment in turn + for i := schema.Assignments(); i.HasNext(); { + // Get ith assignment + ith := i.Next() + // Compute ith assignment(s) + if err := ith.ExpandTrace(trace); err != nil { + return err + } + } + // Done + return nil +} + +// Accepts determines whether this schema will accept a given trace. That +// is, whether or not the given trace adheres to the schema.
A trace can fail +// to adhere to the schema for a variety of reasons, such as having a constraint +// which does not hold. +// +//nolint:revive +func Accepts(schema Schema, trace tr.Trace) error { + // Check each constraint in turn + for i := schema.Constraints(); i.HasNext(); { + // Get ith constraint + ith := i.Next() + // Check it holds (or report an error) + if err := ith.Accepts(trace); err != nil { + return err + } + } + // Success + return nil +} + +// ColumnIndexOf returns the column index of the column with the given name, or +// returns false if no matching column exists. +func ColumnIndexOf(schema Schema, name string) (uint, bool) { + return schema.Columns().Find(func(c Column) bool { + return c.Name() == name + }) +} + +// ColumnByName returns the column with the matching name, or panics if no such +// column exists. +func ColumnByName(schema Schema, name string) Column { + var col Column + // Attempt to determine the index of this column + _, ok := schema.Columns().Find(func(c Column) bool { + col = c + return c.Name() == name + }) + // If we found it, then done. + if ok { + return col + } + // Otherwise panic. + panic(fmt.Sprintf("unknown column %s", name)) +} + +// HasColumn checks whether a column of the given name is declared within the schema. +func HasColumn(schema Schema, name string) bool { + _, ok := ColumnIndexOf(schema, name) + return ok +} diff --git a/pkg/table/type.go b/pkg/schema/type.go similarity index 89% rename from pkg/table/type.go rename to pkg/schema/type.go index 09ced76..fe7be6b 100644 --- a/pkg/table/type.go +++ b/pkg/schema/type.go @@ -1,4 +1,4 @@ -package table +package schema import ( "fmt" @@ -33,12 +33,10 @@ type UintType struct { nbits uint // The numeric bound of all values in this type (e.g. 2^8 for u8, etc). bound *fr.Element - // Indicates whether or not this type should be enforced (or not). - checked bool } // NewUintType constructs a new integer type for a given bit width. -func NewUintType(nbits uint, checked bool) *UintType { +func NewUintType(nbits uint) *UintType { var maxBigInt big.Int // Compute 2^n maxBigInt.Exp(big.NewInt(2), big.NewInt(int64(nbits)), nil) @@ -46,7 +44,7 @@ func NewUintType(nbits uint, checked bool) *UintType { bound := new(fr.Element) bound.SetBigInt(&maxBigInt) - return &UintType{nbits, bound, checked} + return &UintType{nbits, bound} } // AsUint accesses this type assuming it is a Uint. Since this is the case, @@ -61,12 +59,6 @@ func (p *UintType) AsField() *FieldType { return nil } -// Checked identifies whether the type of this column must be enforced using one -// more constraints and/or columns. -func (p *UintType) Checked() bool { - return p.checked -} - // Accept determines whether a given value is an element of this type. For // example, 123 is an element of the type u8 whilst 256 is not. func (p *UintType) Accept(val *fr.Element) bool { diff --git a/pkg/table/column.go b/pkg/table/column.go deleted file mode 100644 index 971dc85..0000000 --- a/pkg/table/column.go +++ /dev/null @@ -1,338 +0,0 @@ -package table - -import ( - "errors" - "fmt" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/go-corset/pkg/util" -) - -// DataColumn represents a column of user-provided values. -type DataColumn[T Type] struct { - name string - // Expected type of values held in this column. Observe that this type is - // enforced only when checking is enabled. Unchecked typed columns can - // still make sense when their values are implied by some other constraint. 
- Type T - // Indicates whether or not this column was created by the compiler (i.e. is - // synthetic), or was specified by the user (i.e. is natural). - Synthetic bool -} - -// NewDataColumn constructs a new data column with a given name. -func NewDataColumn[T Type](name string, base T, synthetic bool) *DataColumn[T] { - return &DataColumn[T]{name, base, synthetic} -} - -// Name forms part of the ColumnSchema interface, and provides access to -// information about the ith column in a schema. -func (p *DataColumn[T]) Name() string { - return p.name -} - -// Width forms part of the ColumnGroup interface, and determines how many -// columns are in the group. Data columns already represent a group of size 1. -func (p *DataColumn[T]) Width() uint { - return 1 -} - -// NameOf forms part of the ColumnGroup interface, and provides access to the -// ith column in a group. Since a data column represents a group of size 1, -// there is only ever one name. -func (p *DataColumn[T]) NameOf(index uint) string { - return p.name -} - -// IsSynthetic forms part of the ColumnGroup interface, and determines whether or -// not the group (as a whole) is synthetic. -func (p *DataColumn[T]) IsSynthetic() bool { - return p.Synthetic -} - -// Accepts determines whether or not this column accepts the given trace. For a -// data column, this means ensuring that all elements are value for the columns -// type. -// -//nolint:revive -func (p *DataColumn[T]) Accepts(tr Trace) error { - // Only check for non-field types. This is simply because a column with the - // field type always accepts everything. - if p.Type.AsField() == nil { - // Access corresponding column in trace - col := tr.ColumnByName(p.name) - // Check constraints accepted - for i := 0; i < int(tr.Height()); i++ { - val := col.Get(i) - - if !p.Type.Accept(val) { - // Construct useful error message - msg := fmt.Sprintf("column %s value out-of-bounds (row %d, %s)", p.Name(), i, val) - // Evaluation failure - return errors.New(msg) - } - } - } - // All good - return nil -} - -//nolint:revive -func (c *DataColumn[T]) String() string { - if c.Type.AsField() != nil { - return fmt.Sprintf("(column %s)", c.Name()) - } - - return fmt.Sprintf("(column %s :%s)", c.Name(), c.Type) -} - -// ComputedColumn describes a column whose values are computed on-demand, rather -// than being stored in a data array. Typically computed columns read values -// from other columns in a trace in order to calculate their value. There is an -// expectation that this computation is acyclic. Furthermore, computed columns -// give rise to "trace expansion". That is where the initial trace provided by -// the user is expanded by determining the value of all computed columns. -type ComputedColumn[E Evaluable] struct { - Name string - // The computation which accepts a given trace and computes - // the value of this column at a given row. - Expr E -} - -// NewComputedColumn constructs a new computed column with a given name and -// determining expression. More specifically, that expression is used to -// compute the values for this column during trace expansion. -func NewComputedColumn[E Evaluable](name string, expr E) *ComputedColumn[E] { - return &ComputedColumn[E]{ - Name: name, - Expr: expr, - } -} - -// RequiredSpillage returns the minimum amount of spillage required to ensure -// this column can be correctly computed in the presence of arbitrary (front) -// padding. 
-func (c *ComputedColumn[E]) RequiredSpillage() uint { - // NOTE: Spillage is only currently considered to be necessary at the front - // (i.e. start) of a trace. This is because padding is always inserted at - // the front, never the back. As such, it is the maximum positive shift - // which determines how much spillage is required for this comptuation. - return c.Expr.Bounds().End -} - -// Accepts determines whether or not this column accepts the given trace. For a -// data column, this means ensuring that all elements are value for the columns -// type. -// -//nolint:revive -func (c *ComputedColumn[E]) Accepts(tr Trace) error { - // Check column in trace! - if !tr.HasColumn(c.Name) { - return fmt.Errorf("Trace missing computed column ({%s})", c.Name) - } - - return nil -} - -// ExpandTrace attempts to a new column to the trace which contains the result -// of evaluating a given expression on each row. If the column already exists, -// then an error is flagged. -func (c *ComputedColumn[E]) ExpandTrace(tr Trace) error { - if tr.HasColumn(c.Name) { - return fmt.Errorf("Computed column already exists ({%s})", c.Name) - } - - data := make([]*fr.Element, tr.Height()) - // Expand the trace - for i := 0; i < len(data); i++ { - val := c.Expr.EvalAt(i, tr) - if val != nil { - data[i] = val - } else { - zero := fr.NewElement(0) - data[i] = &zero - } - } - // Determine padding value. A negative row index is used here to ensure - // that all columns return their padding value which is then used to compute - // the padding value for *this* column. - padding := c.Expr.EvalAt(-1, tr) - // Colunm needs to be expanded. - tr.AddColumn(c.Name, data, padding) - // Done - return nil -} - -// nolint:revive -func (c *ComputedColumn[E]) String() string { - return fmt.Sprintf("(compute %s %s)", c.Name, any(c.Expr)) -} - -// =================================================================== -// Sorted Permutations -// =================================================================== - -// SortedPermutation declares one or more columns as sorted permutations of -// existing columns. -type SortedPermutation struct { - // The new (sorted) columns - Targets []string - // The sorting criteria - Signs []bool - // The existing columns - Sources []string -} - -// NewSortedPermutation creates a new sorted permutation -func NewSortedPermutation(targets []string, signs []bool, sources []string) *SortedPermutation { - if len(targets) != len(signs) || len(signs) != len(sources) { - panic("target and source column widths must match") - } - - return &SortedPermutation{targets, signs, sources} -} - -// Width forms part of the ColumnGroup interface, and provides access to the -// ith column in a group. Sorted permutations have define one or more new -// columns. -func (p *SortedPermutation) Width() uint { - return uint(len(p.Targets)) -} - -// NameOf forms part of the ColumnGroup interface, and provides access to the -// ith column in a group. Since a data column represents a group of size 1, -// there is only ever one name. -func (p *SortedPermutation) NameOf(index uint) string { - return p.Targets[index] -} - -// IsSynthetic forms part of the ColumnGroup interface which determines whether -// or not the group (as a whole) is synthetic. Sorted permutation columns are -// always synthetic. -func (p *SortedPermutation) IsSynthetic() bool { - return true -} - -// RequiredSpillage returns the minimum amount of spillage required to ensure -// valid traces are accepted in the presence of arbitrary padding. 
-func (p *SortedPermutation) RequiredSpillage() uint { - return uint(0) -} - -// Accepts checks whether a sorted permutation holds between the -// source and target columns. -func (p *SortedPermutation) Accepts(tr Trace) error { - // Sanity check columns well formed. - if err := validPermutationColumns(p.Targets, p.Sources, tr); err != nil { - return err - } - // Slice out data - src := sliceMatchingColumns(p.Sources, tr) - dst := sliceMatchingColumns(p.Targets, tr) - // Sanity check whether column exists - if !util.ArePermutationOf(dst, src) { - msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})", - p.Targets, p.Sources) - return errors.New(msg) - } - // Check that target columns are sorted lexicographically. - if util.AreLexicographicallySorted(dst, p.Signs) { - return nil - } - // Error case - msg := fmt.Sprintf("Permutation columns not lexicographically sorted ({%s})", p.Targets) - // Done - return errors.New(msg) -} - -// ExpandTrace expands a given trace to include the columns specified by a given -// SortedPermutation. This requires copying the data in the source columns, and -// sorting that data according to the permutation criteria. -func (p *SortedPermutation) ExpandTrace(tr Trace) error { - // Ensure target columns don't exist - for _, col := range p.Targets { - if tr.HasColumn(col) { - panic("target column already exists") - } - } - - cols := make([][]*fr.Element, len(p.Sources)) - // Construct target columns - for i := 0; i < len(p.Targets); i++ { - src := p.Sources[i] - // Read column data to initialise permutation. - data := tr.ColumnByName(src).Data() - // Copy column data to initialise permutation. - cols[i] = make([]*fr.Element, len(data)) - copy(cols[i], data) - } - // Sort target columns - util.PermutationSort(cols, p.Signs) - // Physically add the columns - for i := 0; i < len(p.Targets); i++ { - dstColName := p.Targets[i] - srcCol := tr.ColumnByName(p.Sources[i]) - tr.AddColumn(dstColName, cols[i], srcCol.Padding()) - } - // - return nil -} - -// String returns a string representation of this constraint. This is primarily -// used for debugging. 
-func (p *SortedPermutation) String() string { - targets := "" - sources := "" - - for i, s := range p.Targets { - if i != 0 { - targets += " " - } - - targets += s - } - - for i, s := range p.Sources { - if i != 0 { - sources += " " - } - - if p.Signs[i] { - sources += fmt.Sprintf("+%s", s) - } else { - sources += fmt.Sprintf("-%s", s) - } - } - - return fmt.Sprintf("(permute (%s) (%s))", targets, sources) -} - -func validPermutationColumns(targets []string, sources []string, tr Trace) error { - ncols := len(targets) - // Sanity check matching length - if len(sources) != ncols { - return fmt.Errorf("Number of source and target columns differs") - } - // Check required columns in trace - for i := 0; i < ncols; i++ { - if !tr.HasColumn(targets[i]) { - return fmt.Errorf("Trace missing permutation target column ({%s})", targets[i]) - } else if !tr.HasColumn(sources[i]) { - return fmt.Errorf("Trace missing permutation source ({%s})", sources[i]) - } - } - // - return nil -} - -func sliceMatchingColumns(names []string, tr Trace) [][]*fr.Element { - // Allocate return array - cols := make([][]*fr.Element, len(names)) - // Slice out the data - for i, n := range names { - nth := tr.ColumnByName(n) - cols[i] = nth.Data() - } - // Done - return cols -} diff --git a/pkg/table/constraints.go b/pkg/table/constraints.go deleted file mode 100644 index eddc407..0000000 --- a/pkg/table/constraints.go +++ /dev/null @@ -1,351 +0,0 @@ -package table - -import ( - "errors" - "fmt" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/go-corset/pkg/util" -) - -// Evaluable captures something which can be evaluated on a given table row to -// produce an evaluation point. For example, expressions in the -// Mid-Level or Arithmetic-Level IR can all be evaluated at rows of a -// table. -type Evaluable interface { - util.Boundable - // EvalAt evaluates this expression in a given tabular context. - // Observe that if this expression is *undefined* within this - // context then it returns "nil". An expression can be - // undefined for several reasons: firstly, if it accesses a - // row which does not exist (e.g. at index -1); secondly, if - // it accesses a column which does not exist. - EvalAt(int, Trace) *fr.Element -} - -// Testable captures the notion of a constraint which can be tested on a given -// row of a given trace. It is very similar to Evaluable, except that it only -// indicates success or failure. The reason for using this interface over -// Evaluable is that, for historical reasons, constraints at the HIR cannot be -// Evaluable (i.e. because they return multiple values, rather than a single -// value). However, constraints at the HIR level remain testable. -type Testable interface { - util.Boundable - - // TestAt evaluates this expression in a given tabular context and checks it - // against zero. Observe that if this expression is *undefined* within this - // context then it returns "nil". An expression can be undefined for - // several reasons: firstly, if it accesses a row which does not exist (e.g. - // at index -1); secondly, if it accesses a column which does not exist. - TestAt(int, Trace) bool -} - -// =================================================================== -// Vanishing Constraints -// =================================================================== - -// ZeroTest is a wrapper which converts an Evaluable expression into a Testable -// constraint. Specifically, by checking whether or not the given expression -// vanishes (i.e. evaluates to zero). 
-type ZeroTest[E Evaluable] struct { - Expr E -} - -// TestAt determines whether or not a given expression evaluates to zero. -// Observe that if the expression is undefined, then it is assumed not to hold. -func (p ZeroTest[E]) TestAt(row int, tr Trace) bool { - val := p.Expr.EvalAt(row, tr) - return val != nil && val.IsZero() -} - -// Bounds determines the bounds for this zero test. -func (p ZeroTest[E]) Bounds() util.Bounds { - return p.Expr.Bounds() -} - -// String generates a human-readble string. -// -//nolint:revive -func (p ZeroTest[E]) String() string { - return fmt.Sprintf("%s", any(p.Expr)) -} - -// RowConstraint specifies a constraint which should hold on every row of the -// table. The only exception is when the constraint is undefined (e.g. because -// it references a non-existent table cell). In such case, the constraint is -// ignored. This is parameterised by the type of the constraint expression. -// Thus, we can reuse this definition across the various intermediate -// representations (e.g. Mid-Level IR, Arithmetic IR, etc). -type RowConstraint[T Testable] struct { - // A unique identifier for this constraint. This is primarily - // useful for debugging. - Handle string - // Indicates (when nil) a global constraint that applies to all rows. - // Otherwise, indicates a local constraint which applies to the specific row - // given here. - Domain *int - // The actual constraint itself (e.g. an expression which - // should evaluate to zero, etc) - Constraint T -} - -// NewRowConstraint constructs a new vanishing constraint! -func NewRowConstraint[T Testable](handle string, domain *int, constraint T) *RowConstraint[T] { - return &RowConstraint[T]{handle, domain, constraint} -} - -// GetHandle returns the handle associated with this constraint. -func (p *RowConstraint[T]) GetHandle() string { - return p.Handle -} - -// Accepts checks whether a vanishing constraint evaluates to zero on every row -// of a table. If so, return nil otherwise return an error. -// -//nolint:revive -func (p *RowConstraint[T]) Accepts(tr Trace) error { - if p.Domain == nil { - // Global Constraint - return HoldsGlobally(p.Handle, p.Constraint, tr) - } - // Check specific row - return HoldsLocally(*p.Domain, p.Handle, p.Constraint, tr) -} - -// HoldsGlobally checks whether a given expression vanishes (i.e. evaluates to -// zero) for all rows of a trace. If not, report an appropriate error. -func HoldsGlobally[T Testable](handle string, constraint T, tr Trace) error { - // Determine well-definedness bounds for this constraint - bounds := constraint.Bounds() - // Sanity check enough rows - if bounds.End < tr.Height() { - // Check all in-bounds values - for k := bounds.Start; k < (tr.Height() - bounds.End); k++ { - if err := HoldsLocally(int(k), handle, constraint, tr); err != nil { - return err - } - } - } - // Success - return nil -} - -// HoldsLocally checks whether a given constraint holds (e.g. vanishes) on a -// specific row of a trace. If not, report an appropriate error. -func HoldsLocally[T Testable](k int, handle string, constraint T, tr Trace) error { - // Negative rows calculated from end of trace. - if k < 0 { - k += int(tr.Height()) - } - // Check whether it holds or not - if !constraint.TestAt(k, tr) { - // Construct useful error message - msg := fmt.Sprintf("constraint \"%s\" does not hold (row %d)", handle, k) - // Evaluation failure - return errors.New(msg) - } - // Success - return nil -} - -// String generates a human-readble string. 
-// -//nolint:revive -func (p *RowConstraint[T]) String() string { - if p.Domain == nil { - return fmt.Sprintf("(vanish %s %s)", p.Handle, any(p.Constraint)) - } else if *p.Domain == 0 { - return fmt.Sprintf("(vanish:first %s %s)", p.Handle, any(p.Constraint)) - } - // - return fmt.Sprintf("(vanish:last %s %s)", p.Handle, any(p.Constraint)) -} - -// =================================================================== -// Range Constraint -// =================================================================== - -// RangeConstraint restricts all values in a given column to be within -// a range [0..n) for some bound n. For example, a bound of 256 would -// restrict all values to be bytes. At this time, range constraints -// are explicitly limited at the arithmetic level to bounds of at most -// 256 (i.e. to ensuring bytes). This restriction is somewhat -// arbitrary and is determined by the underlying prover. -type RangeConstraint struct { - // Column index to be constrained. - Column uint - // The actual constraint itself, namely an expression which - // should evaluate to zero. NOTE: an fr.Element is used here - // to store the bound simply to make the necessary comparison - // against table data more direct. - Bound *fr.Element -} - -// NewRangeConstraint constructs a new Range constraint! -func NewRangeConstraint(column uint, bound *fr.Element) *RangeConstraint { - var n fr.Element = fr.NewElement(256) - if bound.Cmp(&n) > 0 { - panic("Range constraint for bitwidth above 8 not supported") - } - - return &RangeConstraint{column, bound} -} - -// IsAir is a marker that indicates this is an AIR column. -func (p *RangeConstraint) IsAir() bool { return true } - -// Accepts checks whether a range constraint evaluates to zero on -// every row of a table. If so, return nil otherwise return an error. -func (p *RangeConstraint) Accepts(tr Trace) error { - column := tr.ColumnByIndex(p.Column) - for k := 0; k < int(tr.Height()); k++ { - // Get the value on the kth row - kth := column.Get(k) - // Perform the bounds check - if kth != nil && kth.Cmp(p.Bound) >= 0 { - name := column.Name() - // Construct useful error message - msg := fmt.Sprintf("value out-of-bounds (row %d, %s)", kth, name) - // Evaluation failure - return errors.New(msg) - } - } - // All good - return nil -} - -func (p *RangeConstraint) String() string { - return fmt.Sprintf("(range #%d %s)", p.Column, p.Bound) -} - -// =================================================================== -// Permutation -// =================================================================== - -// Permutation declares a constraint that one column is a permutation -// of another. -type Permutation struct { - // The target columns - Targets []uint - // The source columns - Sources []uint -} - -// NewPermutation creates a new permutation -func NewPermutation(targets []uint, sources []uint) *Permutation { - if len(targets) != len(sources) { - panic("differeng number of target / source permutation columns") - } - - return &Permutation{targets, sources} -} - -// RequiredSpillage returns the minimum amount of spillage required to ensure -// valid traces are accepted in the presence of arbitrary padding. -func (p *Permutation) RequiredSpillage() uint { - return uint(0) -} - -// Accepts checks whether a permutation holds between the source and -// target columns. 
-func (p *Permutation) Accepts(tr Trace) error { - // Slice out data - src := sliceColumns(p.Sources, tr) - dst := sliceColumns(p.Targets, tr) - // Sanity check whether column exists - if !util.ArePermutationOf(dst, src) { - msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})", - p.Targets, p.Sources) - return errors.New(msg) - } - // Success - return nil -} - -func (p *Permutation) String() string { - targets := "" - sources := "" - - for i, s := range p.Targets { - if i != 0 { - targets += " " - } - - targets += fmt.Sprintf("%d", s) - } - - for i, s := range p.Sources { - if i != 0 { - sources += " " - } - - sources += fmt.Sprintf("%d", s) - } - - return fmt.Sprintf("(permutation (%s) (%s))", targets, sources) -} - -func sliceColumns(columns []uint, tr Trace) [][]*fr.Element { - // Allocate return array - cols := make([][]*fr.Element, len(columns)) - // Slice out the data - for i, n := range columns { - nth := tr.ColumnByIndex(n) - cols[i] = nth.Data() - } - // Done - return cols -} - -// =================================================================== -// Property Assertion -// =================================================================== - -// PropertyAssertion is similar to a vanishing constraint but is used only for -// debugging / testing / verification. Unlike vanishing constraints, property -// assertions do not represent something that the prover can enforce. Rather, -// they represent properties which are expected to hold for every valid trace. -// That is, they should be implied by the actual constraints. Thus, whilst the -// prover cannot enforce such properties, external tools (such as for formal -// verification) can attempt to ensure they do indeed always hold. -type PropertyAssertion[T Testable] struct { - // A unique identifier for this constraint. This is primarily - // useful for debugging. - Handle string - // The actual assertion itself, namely an expression which - // should hold (i.e. vanish) for every row of a trace. - // Observe that this can be any function which is computable - // on a given trace --- we are not restricted to expressions - // which can be arithmetised. - Property T -} - -// NewPropertyAssertion constructs a new property assertion! -func NewPropertyAssertion[T Testable](handle string, property T) *PropertyAssertion[T] { - return &PropertyAssertion[T]{handle, property} -} - -// GetHandle returns the handle associated with this constraint. -// -//nolint:revive -func (p *PropertyAssertion[T]) GetHandle() string { - return p.Handle -} - -// Accepts checks whether a vanishing constraint evaluates to zero on every row -// of a table. If so, return nil otherwise return an error. -// -//nolint:revive -func (p *PropertyAssertion[T]) Accepts(tr Trace) error { - for k := uint(0); k < tr.Height(); k++ { - // Check whether property holds (or was undefined) - if !p.Property.TestAt(int(k), tr) { - // Construct useful error message - msg := fmt.Sprintf("property assertion %s does not hold (row %d)", p.Handle, k) - // Evaluation failure - return errors.New(msg) - } - } - // All good - return nil -} diff --git a/pkg/table/schema.go b/pkg/table/schema.go deleted file mode 100644 index b744a21..0000000 --- a/pkg/table/schema.go +++ /dev/null @@ -1,66 +0,0 @@ -package table - -// Schema represents a schema which can be used to manipulate a trace. -// Specifically, a schema can determine whether or not a trace is accepted; -// likewise, a schema can expand a trace according to its internal computation. 
-type Schema interface { - Accepts(Trace) error - // ExpandTrace expands a given trace to include "computed - // columns". These are columns which do not exist in the - // original trace, but are added during trace expansion to - // form the final trace. - ExpandTrace(Trace) error - - // Size returns the number of declarations in this schema. - Size() int - - // GetDeclaration returns the ith declaration in this schema. - GetDeclaration(int) Declaration - - // RequiredSpillage returns the minimum amount of spillage required to - // ensure valid traces are accepted in the presence of arbitrary padding. - // Note: this is calculated on demand. - RequiredSpillage() uint - - // Determine the number of column groups in this schema. - Width() uint - - // Determine the index of a named column in this schema, or return false if - // no matching column exists. - ColumnIndex(string) (uint, bool) - - // Access information about the ith column group in this schema. - ColumnGroup(uint) ColumnGroup - - // Access information about the ith column in this schema. - Column(uint) ColumnSchema -} - -// ColumnGroup represents a group of related columns in the schema. For -// example, a single data column is (for now) always a column group of size 1. -// Likewise, an array of size n is a column group of size n, etc. -type ColumnGroup interface { - // Return the number of columns in this group. - Width() uint - - // Returns the name of the ith column in this group. - NameOf(uint) string - - // Determines whether or not this column group is synthetic. - IsSynthetic() bool -} - -// ColumnSchema provides information about a specific column in the schema. -type ColumnSchema interface { - // Returns the name of this column - Name() string -} - -// Declaration represents a declared element of a schema. For example, a column -// declaration or a vanishing constraint declaration. The purpose of this -// interface is to provide some generic interactions that are available -// regardless of the IR level. -type Declaration interface { - // Return a human-readable string for this declaration. - String() string -} diff --git a/pkg/test/ir_test.go b/pkg/test/ir_test.go index 82570b6..cd01504 100644 --- a/pkg/test/ir_test.go +++ b/pkg/test/ir_test.go @@ -9,7 +9,9 @@ import ( "testing" "github.com/consensys/go-corset/pkg/hir" - "github.com/consensys/go-corset/pkg/table" + "github.com/consensys/go-corset/pkg/schema" + sc "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/trace" ) // Determines the (relative) location of the test directory. That is @@ -359,7 +361,7 @@ func Check(t *testing.T, test string) { // Check a given set of tests have an expected outcome (i.e. are // either accepted or rejected) by a given set of constraints. 
-func CheckTraces(t *testing.T, test string, expected bool, traces []*table.ArrayTrace, hirSchema *hir.Schema) { +func CheckTraces(t *testing.T, test string, expected bool, traces []*trace.ArrayTrace, hirSchema *hir.Schema) { for i, tr := range traces { if tr != nil { for padding := uint(0); padding <= MAX_PADDING; padding++ { @@ -368,14 +370,14 @@ func CheckTraces(t *testing.T, test string, expected bool, traces []*table.Array // Lower MIR => AIR airSchema := mirSchema.LowerToAir() // Construct trace identifiers - hirID := traceId{"HIR", test, expected, i + 1, padding, hirSchema.RequiredSpillage()} - mirID := traceId{"MIR", test, expected, i + 1, padding, mirSchema.RequiredSpillage()} - airID := traceId{"AIR", test, expected, i + 1, padding, airSchema.RequiredSpillage()} + hirID := traceId{"HIR", test, expected, i + 1, padding, schema.RequiredSpillage(hirSchema)} + mirID := traceId{"MIR", test, expected, i + 1, padding, schema.RequiredSpillage(mirSchema)} + airID := traceId{"AIR", test, expected, i + 1, padding, schema.RequiredSpillage(airSchema)} // Check whether trace is input compatible with schema - if err := tr.AlignInputWith(hirSchema); err != nil { + if err := sc.AlignInputs(tr, hirSchema); err != nil { // Alignment failed. So, attempt alignment as expanded // trace instead. - if err := tr.AlignWith(airSchema); err != nil { + if err := sc.Align(tr, airSchema); err != nil { // Still failed, hence trace must be malformed in some way if expected { t.Errorf("Trace malformed (%s.accepts, line %d): [%s]", test, i+1, err) @@ -397,17 +399,17 @@ func CheckTraces(t *testing.T, test string, expected bool, traces []*table.Array } } -func checkInputTrace(t *testing.T, tr *table.ArrayTrace, id traceId, schema table.Schema) { +func checkInputTrace(t *testing.T, tr *trace.ArrayTrace, id traceId, schema sc.Schema) { // Clone trace (to ensure expansion does not affect subsequent tests) etr := tr.Clone() // Apply spillage - etr.Pad(schema.RequiredSpillage()) + etr.Pad(id.spillage) // Expand trace - err := schema.ExpandTrace(etr) + err := sc.ExpandTrace(schema, etr) // Check if err != nil { t.Error(err) - } else if err := etr.AlignWith(schema); err != nil { + } else if err := sc.Align(etr, schema); err != nil { // Alignment problem t.Error(err) } else { @@ -415,11 +417,11 @@ func checkInputTrace(t *testing.T, tr *table.ArrayTrace, id traceId, schema tabl } } -func checkExpandedTrace(t *testing.T, tr table.Trace, id traceId, schema table.Schema) { +func checkExpandedTrace(t *testing.T, tr trace.Trace, id traceId, schema sc.Schema) { // Apply padding tr.Pad(id.padding) // Check - err := schema.Accepts(tr) + err := sc.Accepts(schema, tr) // Determine whether trace accepted or not. accepted := (err == nil) // Process what happened versus what was supposed to happen. @@ -460,14 +462,14 @@ type traceId struct { // ReadTracesFile reads a file containing zero or more traces expressed as JSON, where // each trace is on a separate line. 
-func ReadTracesFile(name string, ext string) []*table.ArrayTrace { +func ReadTracesFile(name string, ext string) []*trace.ArrayTrace { lines := ReadInputFile(name, ext) - traces := make([]*table.ArrayTrace, len(lines)) + traces := make([]*trace.ArrayTrace, len(lines)) // Read constraints line by line for i, line := range lines { // Parse input line as JSON if line != "" && !strings.HasPrefix(line, ";;") { - tr, err := table.ParseJsonTrace([]byte(line)) + tr, err := trace.ParseJsonTrace([]byte(line)) if err != nil { msg := fmt.Sprintf("%s.%s:%d: %s", name, ext, i+1, err) panic(msg) diff --git a/pkg/table/array_trace.go b/pkg/trace/array_trace.go similarity index 66% rename from pkg/table/array_trace.go rename to pkg/trace/array_trace.go index 2701681..e0fac77 100644 --- a/pkg/table/array_trace.go +++ b/pkg/trace/array_trace.go @@ -1,7 +1,6 @@ -package table +package trace import ( - "fmt" "strings" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -38,8 +37,9 @@ func (p *ArrayTrace) ColumnName(index int) string { return p.columns[index].Name() } -// IndexOf returns the index of the given name in this trace. -func (p *ArrayTrace) IndexOf(name string) (uint, bool) { +// ColumnIndex returns the column index of the column with the given name in +// this trace, or returns false if no such column exists. +func (p *ArrayTrace) ColumnIndex(name string) (uint, bool) { for i, c := range p.columns { if c.name == name { return uint(i), true @@ -74,76 +74,6 @@ func (p *ArrayTrace) HasColumn(name string) bool { return false } -// AlignInputWith attempts to align this trace with the input columns of a -// given schema. This means ensuring the order of columns in this trace matches -// the order of input columns in the schema. Thus, column indexes used by -// constraints in the schema can directly access in this trace (i.e. without -// name lookup). Alignment can fail, however, if there is a mismatch between -// columns in the trace and those expected by the schema. -func (p *ArrayTrace) AlignInputWith(schema Schema) error { - return alignWith(false, p, schema) -} - -// AlignWith attempts to align this trace with a given schema. This means -// ensuring the order of columns in this trace matches the order in the schema. -// Thus, column indexes used by constraints in the schema can directly access in -// this trace (i.e. without name lookup). Alignment can fail, however, if there -// is a mismatch between columns in the trace and those expected by the schema. -func (p *ArrayTrace) AlignWith(schema Schema) error { - return alignWith(true, p, schema) -} - -// Alignment algorithm which operates either in unexpanded or expanded mode. In -// expanded mode, all columns must be accounted for and will be aligned. In -// unexpanded mode, the trace is only expected to contain input (i.e. -// non-synthetic) columns. Furthermore, in the schema these are expected to be -// allocated before synthetic columns. As such, alignment of these input -// columns is performed. -func alignWith(expand bool, p *ArrayTrace, schema Schema) error { - ncols := uint(len(p.columns)) - index := uint(0) - // Check each column described in this schema is present in the trace. 
- for i := uint(0); i < schema.Width(); i++ { - group := schema.ColumnGroup(i) - if expand || !group.IsSynthetic() { - for j := uint(0); j < group.Width(); j++ { - // Determine column name - schemaName := group.NameOf(j) - // Sanity check column exists - if index >= ncols { - return fmt.Errorf("trace missing column %s", schemaName) - } - - traceName := p.columns[index].name - // Check alignment - if traceName != schemaName { - // Not aligned --- so fix - k, ok := p.IndexOf(schemaName) - // check exists - if !ok { - return fmt.Errorf("trace missing column %s", schemaName) - } - // Swap columns - tmp := p.columns[index] - p.columns[index] = p.columns[k] - p.columns[k] = tmp - } - // Continue - index++ - } - } - } - // Check whether all columns matched - if index == ncols { - // Yes, alignment complete. - return nil - } - // Error Case. - unknowns := p.columns[index:] - // - return fmt.Errorf("trace contains unknown columns: %v", unknowns) -} - // AddColumn adds a new column of data to this trace. func (p *ArrayTrace) AddColumn(name string, data []*fr.Element, padding *fr.Element) { // Sanity check the column does not already exist. @@ -198,6 +128,14 @@ func (p *ArrayTrace) Pad(n uint) { p.height += n } +// Swap the order of two columns in this trace. This is needed, in +// particular, for alignment. +func (p *ArrayTrace) Swap(l uint, r uint) { + tmp := p.columns[l] + p.columns[l] = p.columns[r] + p.columns[r] = tmp +} + func (p *ArrayTrace) String() string { // Use string builder to try and make this vaguely efficient. var id strings.Builder diff --git a/pkg/table/printer.go b/pkg/trace/printer.go similarity index 98% rename from pkg/table/printer.go rename to pkg/trace/printer.go index 0ab3c1e..d3a09df 100644 --- a/pkg/table/printer.go +++ b/pkg/trace/printer.go @@ -1,4 +1,4 @@ -package table +package trace import ( "fmt" diff --git a/pkg/table/trace.go b/pkg/trace/trace.go similarity index 71% rename from pkg/table/trace.go rename to pkg/trace/trace.go index 3922681..f988176 100644 --- a/pkg/table/trace.go +++ b/pkg/trace/trace.go @@ -1,4 +1,4 @@ -package table +package trace import ( "encoding/json" @@ -8,12 +8,6 @@ import ( "github.com/consensys/go-corset/pkg/util" ) -// Acceptable represents an element which can "accept" a trace, or either reject -// with an error (or eventually perhaps report a warning). -type Acceptable interface { - Accepts(Trace) error -} - // Column describes an individual column of data within a trace table. type Column interface { // Get the name of this column @@ -33,13 +27,6 @@ type Column interface { // Trace describes a set of named columns. Columns are not required to have the // same height and can be either "data" columns or "computed" columns. type Trace interface { - // Attempt to align this trace with a given schema. This means ensuring the - // order of columns in this trace matches the order in the schema. Thus, - // column indexes used by constraints in the schema can directly access in - // this trace (i.e. without name lookup). Alignment can fail, however, if - // there is a mismatch between columns in the trace and those expected by - // the schema. - AlignWith(schema Schema) error // Add a new column of data AddColumn(name string, data []*fr.Element, padding *fr.Element) // ColumnByIndex returns the ith column in this trace. @@ -47,6 +34,9 @@ type Trace interface { // ColumnByName returns the data of a given column in order that it can be // inspected. If the given column does not exist, then nil is returned. 
ColumnByName(name string) Column + // Determine the index of a particular column in this trace, or return false + // if no such column exists. + ColumnIndex(name string) (uint, bool) // Check whether this trace contains data for the given column. HasColumn(name string) bool // Pad each column in this trace with n items at the front. An iterator over @@ -55,24 +45,13 @@ type Trace interface { // Determine the height of this table, which is defined as the // height of the largest column. Height() uint + // Swap the order of two columns in this trace. This is needed, in + // particular, for alignment. + Swap(uint, uint) // Get the number of columns in this trace. Width() uint } -// ConstraintsAcceptTrace determines whether or not one or more groups of -// constraints accept a given trace. It returns the first error or warning -// encountered. -func ConstraintsAcceptTrace[T Acceptable](trace Trace, constraints []T) error { - for _, c := range constraints { - err := c.Accepts(trace) - if err != nil { - return err - } - } - // - return nil -} - // =================================================================== // JSON Parser // =================================================================== diff --git a/pkg/util/iterator.go b/pkg/util/iterator.go new file mode 100644 index 0000000..bb40dbb --- /dev/null +++ b/pkg/util/iterator.go @@ -0,0 +1,527 @@ +package util + +// Predicate abstracts the notion of a function which identifies something. +type Predicate[T any] func(T) bool + +// Iterator is an adapter which sits on top of a BaseIterator and provides +// various useful and reusable functions. +type Iterator[T any] interface { + BaseIterator[T] + + // Append another iterator onto the end of this iterator. Thus, when all + // items are visited in this iterator, iteration continues into the other. + Append(Iterator[T]) Iterator[T] + + // Clone creates a copy of this iterator at the given cursor position. + // Modifying the clone (i.e. by calling Next) iterator will not modify the + // original. + Clone() Iterator[T] + + // Collect allocates a new array containing all items of this iterator. + // This drains the iterator. + Collect() []T + + // Find returns the index of the first match for a given predicate, or + // return false if no match is found. This will mutate the iterator. + Find(Predicate[T]) (uint, bool) + + // Count the number of items left. Note, this does not modify the iterator. + Count() uint + + // Get the nth item in this iterator. This will mutate the iterator. + Nth(uint) T +} + +// =============================================================== +// Append Iterator +// =============================================================== + +type appendIterator[T comparable] struct { + left Iterator[T] + right Iterator[T] +} + +// NewAppendIterator construct an iterator over an array of items. +func NewAppendIterator[T comparable](left Iterator[T], right Iterator[T]) Iterator[T] { + return &appendIterator[T]{left, right} +} + +// HasNext checks whether or not there are any items remaining to visit. +// +//nolint:revive +func (p *appendIterator[T]) HasNext() bool { + return p.left.HasNext() || p.right.HasNext() +} + +// Next returns the next item, and advance the iterator. +// +//nolint:revive +func (p *appendIterator[T]) Next() T { + if p.left.HasNext() { + return p.left.Next() + } + + return p.right.Next() +} + +// Append another iterator onto the end of this iterator. Thus, when all +// items are visited in this iterator, iteration continues into the other. 
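+// For example (illustrative): +// +// i := NewArrayIterator([]uint{1, 2}).Append(NewArrayIterator([]uint{3})) +// +// Here, i.Collect() yields the flat array [1 2 3].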
+// +//nolint:revive +func (p *appendIterator[T]) Append(iter Iterator[T]) Iterator[T] { + return NewAppendIterator(p, iter) +} + +// Clone creates a copy of this iterator at the given cursor position. +// Modifying the clone (i.e. by calling Next) iterator will not modify the +// original. +func (p *appendIterator[T]) Clone() Iterator[T] { + return NewAppendIterator[T](p.left.Clone(), p.right.Clone()) +} + +// Collect allocates a new array containing all items of this iterator. +// This drains the iterator. +// +//nolint:revive +func (p *appendIterator[T]) Collect() []T { + lhs := p.left.Collect() + rhs := p.right.Collect() + + return append(lhs, rhs...) +} + +// Count returns the number of items left in the iterator +// +//nolint:revive +func (p *appendIterator[T]) Count() uint { + return p.left.Count() + p.right.Count() +} + +// Find returns the index of the first match for a given predicate, or +// returns false if no match is found. +// +//nolint:revive +func (p *appendIterator[T]) Find(predicate Predicate[T]) (uint, bool) { + return baseFind(p, predicate) +} + +// Nth returns the nth item in this iterator +// +//nolint:revive +func (p *appendIterator[T]) Nth(n uint) T { + // TODO: improve performance. + return baseNth(p, n) +} + +// =============================================================== +// Array Iterator +// =============================================================== + +// arrayIterator provides an iterator implementation over an array of items. +type arrayIterator[T comparable] struct { + items []T + index uint +} + +// NewArrayIterator constructs an iterator over an array of items. +func NewArrayIterator[T comparable](items []T) Iterator[T] { + return &arrayIterator[T]{items, 0} +} + +// HasNext checks whether or not there are any items remaining to visit. +// +//nolint:revive +func (p *arrayIterator[T]) HasNext() bool { + return p.index < uint(len(p.items)) +} + +// Next returns the next item, and advances the iterator. +// +//nolint:revive +func (p *arrayIterator[T]) Next() T { + next := p.items[p.index] + p.index++ + + return next +} + +// Append another iterator onto the end of this iterator. Thus, when all +// items are visited in this iterator, iteration continues into the other. +// +//nolint:revive +func (p *arrayIterator[T]) Append(iter Iterator[T]) Iterator[T] { + return NewAppendIterator(p, iter) +} + +// Clone creates a copy of this iterator at the given cursor position. +// Modifying the clone (i.e. by calling Next) iterator will not modify the +// original. +// +//nolint:revive +func (p *arrayIterator[T]) Clone() Iterator[T] { + return &arrayIterator[T]{p.items, p.index} +} + +// Collect allocates a new array containing all items of this iterator. +// This drains the iterator. +// +//nolint:revive +func (p *arrayIterator[T]) Collect() []T { + items := make([]T, len(p.items)) + copy(items, p.items) + + return items +} + +// Count returns the number of items left in the iterator +// +//nolint:revive +func (p *arrayIterator[T]) Count() uint { + return uint(len(p.items)) - p.index +} + +// Find returns the index of the first match for a given predicate, or +// returns false if no match is found.
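+// For example (illustrative, noting that Find mutates the iterator): +// +// i := NewArrayIterator([]uint{3, 7, 9}) +// n, ok := i.Find(func(x uint) bool { return x > 5 }) +// +// Here, n == 1 and ok == true, since 7 is the first item above 5.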
+// +//nolint:revive +func (p *arrayIterator[T]) Find(predicate Predicate[T]) (uint, bool) { + return baseFind(p, predicate) +} + +// Nth returns the nth item in this iterator +// +//nolint:revive +func (p *arrayIterator[T]) Nth(n uint) T { + p.index = n + return p.items[n] +} + +// =============================================================== +// Cast Iterator +// =============================================================== +type castIterator[S, T comparable] struct { + iter Iterator[S] +} + +// NewCastIterator constructs an iterator which casts each item of type S +// into an item of type T. +func NewCastIterator[S, T comparable](iter Iterator[S]) Iterator[T] { + return &castIterator[S, T]{iter} +} + +// HasNext checks whether or not there are any items remaining to visit. +// +//nolint:revive +func (p *castIterator[S, T]) HasNext() bool { + return p.iter.HasNext() +} + +// Next returns the next item, and advances the iterator. +// +//nolint:revive +func (p *castIterator[S, T]) Next() T { + n := any(p.iter.Next()) + return n.(T) +} + +// Append another iterator onto the end of this iterator. Thus, when all +// items are visited in this iterator, iteration continues into the other. +// +//nolint:revive +func (p *castIterator[S, T]) Append(iter Iterator[T]) Iterator[T] { + return NewAppendIterator(p, iter) +} + +// Clone creates a copy of this iterator at the given cursor position. +// Modifying the clone (i.e. by calling Next) iterator will not modify the +// original. +// +//nolint:revive +func (p *castIterator[S, T]) Clone() Iterator[T] { + return NewCastIterator[S, T](p.iter.Clone()) +} + +// Collect allocates a new array containing all items of this iterator. This drains the iterator. +// +//nolint:revive +func (p *castIterator[S, T]) Collect() []T { + items := make([]T, p.iter.Count()) + index := 0 + + for i := p.iter; i.HasNext(); { + n := any(i.Next()) + items[index] = n.(T) + index++ + } + + return items +} + +// Count returns the number of items left in the iterator +// +//nolint:revive +func (p *castIterator[S, T]) Count() uint { + return p.iter.Count() +} + +// Find returns the index of the first match for a given predicate, or +// returns false if no match is found. +// +//nolint:revive +func (p *castIterator[S, T]) Find(predicate Predicate[T]) (uint, bool) { + return p.iter.Find(func(item S) bool { + tmp := any(item) + return predicate(tmp.(T)) + }) +} + +// Nth returns the nth item in this iterator +// +//nolint:revive +func (p *castIterator[S, T]) Nth(n uint) T { + v := any(p.iter.Nth(n)) + return v.(T) +} + +// =============================================================== +// Flatten Iterator +// =============================================================== + +// flattenIterator flattens a sequence of items S, each of which can be +// iterated as items T, into a flat sequence of items T. +type flattenIterator[S, T comparable] struct { + // Outermost iterator + iter Iterator[S] + // Innermost iterator + curr Iterator[T] + // Mapping function + fn func(S) Iterator[T] +} + +// NewFlattenIterator adapts a sequence of items S which themselves can be +// iterated as items T, into a flat sequence of items T. +func NewFlattenIterator[S, T comparable](iter Iterator[S], fn func(S) Iterator[T]) Iterator[T] { + return &flattenIterator[S, T]{iter, nil, fn} +} + +// HasNext checks whether or not there are any items remaining to visit.
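+// Observe that this skips over any empty inner iterators. For example, +// flattening the groups [1 2], [], [3] yields the flat sequence 1, 2, 3.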
+
+// ===============================================================
+// Flatten Iterator
+// ===============================================================
+
+// flattenIterator flattens a sequence of items, each of which can itself
+// be iterated, into a single sequence.
+type flattenIterator[S, T comparable] struct {
+	// Outermost iterator
+	iter Iterator[S]
+	// Innermost iterator
+	curr Iterator[T]
+	// Mapping function
+	fn func(S) Iterator[T]
+}
+
+// NewFlattenIterator adapts a sequence of items S, each of which can be
+// iterated as items T, into a flat sequence of items T.
+func NewFlattenIterator[S, T comparable](iter Iterator[S], fn func(S) Iterator[T]) Iterator[T] {
+	return &flattenIterator[S, T]{iter, nil, fn}
+}
+
+// HasNext checks whether or not there are any items remaining to visit.
+//
+//nolint:revive
+func (p *flattenIterator[S, T]) HasNext() bool {
+	if p.curr != nil && p.curr.HasNext() {
+		return true
+	}
+	// Find next hit
+	for p.iter.HasNext() {
+		p.curr = p.fn(p.iter.Next())
+		if p.curr.HasNext() {
+			return true
+		}
+	}
+	// Failed
+	return false
+}
+
+// Next returns the next item and advances the iterator.
+//
+//nolint:revive
+func (p *flattenIterator[S, T]) Next() T {
+	// Assumes HasNext() was called and returned true; otherwise this is
+	// undefined anyway :)
+	return p.curr.Next()
+}
+
+// Append another iterator onto the end of this iterator. Thus, when all
+// items are visited in this iterator, iteration continues into the other.
+//
+//nolint:revive
+func (p *flattenIterator[S, T]) Append(iter Iterator[T]) Iterator[T] {
+	return NewAppendIterator[T](p, iter)
+}
+
+// Clone creates a copy of this iterator at the given cursor position.
+// Modifying the clone (e.g. by calling Next) will not modify the
+// original.
+//
+//nolint:revive
+func (p *flattenIterator[S, T]) Clone() Iterator[T] {
+	var curr Iterator[T]
+	if p.curr != nil {
+		curr = p.curr.Clone()
+	}
+
+	return &flattenIterator[S, T]{p.iter.Clone(), curr, p.fn}
+}
+
+// Collect allocates a new array containing all items of this iterator.
+// This drains the iterator.
+//
+//nolint:revive
+func (p *flattenIterator[S, T]) Collect() []T {
+	items := make([]T, 0)
+	if p.curr != nil {
+		items = p.curr.Collect()
+	}
+	// Flatten each group in turn
+	for p.iter.HasNext() {
+		ithItems := p.fn(p.iter.Next()).Collect()
+		items = append(items, ithItems...)
+	}
+	// Done
+	return items
+}
+
+// Count returns the number of items left in the iterator.
+//
+//nolint:revive
+func (p *flattenIterator[S, T]) Count() uint {
+	count := uint(0)
+
+	for i := p.Clone(); i.HasNext(); {
+		i.Next()
+		count++
+	}
+
+	return count
+}
+
+// Find returns the index of the first match for a given predicate, or
+// returns false if no match is found.
+//
+//nolint:revive
+func (p *flattenIterator[S, T]) Find(predicate Predicate[T]) (uint, bool) {
+	return baseFind(p, predicate)
+}
+
+// Nth returns the nth item in this iterator.
+//
+//nolint:revive
+func (p *flattenIterator[S, T]) Nth(n uint) T {
+	// TODO: improve performance.
+	return baseNth(p, n)
+}
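A small sketch of the flattening behaviour, again assuming the util package: each string is expanded into its bytes, and empty groups are skipped transparently.

```go
package main

import (
	"fmt"

	// Assumed location of the iterator utilities introduced in this diff.
	"github.com/consensys/go-corset/pkg/util"
)

func main() {
	words := util.NewArrayIterator([]string{"ab", "", "c"})
	// Flatten each word into its constituent bytes; "" yields nothing.
	letters := util.NewFlattenIterator(words, func(w string) util.Iterator[byte] {
		return util.NewArrayIterator([]byte(w))
	})

	fmt.Println(string(letters.Collect())) // prints "abc"
}
```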
+
+// ===============================================================
+// Unit Iterator
+// ===============================================================
+
+// unitIterator provides an iterator over exactly one item.
+type unitIterator[T comparable] struct {
+	item  T
+	index uint
+}
+
+// NewUnitIterator constructs an iterator over a single item.
+func NewUnitIterator[T comparable](item T) *unitIterator[T] {
+	return &unitIterator[T]{item, 0}
+}
+
+// HasNext checks whether or not there are any items remaining to visit.
+//
+//nolint:revive
+func (p *unitIterator[T]) HasNext() bool {
+	return p.index < 1
+}
+
+// Next returns the next item and advances the iterator.
+//
+//nolint:revive
+func (p *unitIterator[T]) Next() T {
+	p.index++
+	return p.item
+}
+
+// Append another iterator onto the end of this iterator. Thus, when all
+// items are visited in this iterator, iteration continues into the other.
+//
+//nolint:revive
+func (p *unitIterator[T]) Append(iter Iterator[T]) Iterator[T] {
+	return NewAppendIterator(p, iter)
+}
+
+// Clone creates a copy of this iterator at the given cursor position. Modifying
+// the clone (e.g. by calling Next) will not modify the original.
+//
+//nolint:revive
+func (p *unitIterator[T]) Clone() Iterator[T] {
+	return &unitIterator[T]{p.item, p.index}
+}
+
+// Collect allocates a new array containing all items of this iterator.
+// This drains the iterator.
+//
+//nolint:revive
+func (p *unitIterator[T]) Collect() []T {
+	if p.index != 0 {
+		// Already visited, hence nothing left to collect.
+		return []T{}
+	}
+	// Drain the iterator.
+	p.index = 1
+
+	return []T{p.item}
+}
+
+// Count returns the number of items left in the iterator.
+//
+//nolint:revive
+func (p *unitIterator[T]) Count() uint {
+	if p.index == 0 {
+		return 1
+	}
+	// nothing left
+	return 0
+}
+
+// Find returns the index of the first match for a given predicate, or
+// returns false if no match is found.
+//
+//nolint:revive
+func (p *unitIterator[T]) Find(predicate Predicate[T]) (uint, bool) {
+	if predicate(p.item) {
+		// Success
+		return 0, true
+	}
+	// Failed
+	return 0, false
+}
+
+// Nth returns the nth item in this iterator. Since there is exactly one
+// item, only n == 0 is meaningful here.
+//
+//nolint:revive
+func (p *unitIterator[T]) Nth(n uint) T {
+	return p.item
+}
+
+// ===============================================================
+// Base Iterator
+// ===============================================================
+
+// BaseIterator abstracts the process of iterating over a sequence of elements.
+type BaseIterator[T any] interface {
+	// Check whether or not there are any items remaining to visit.
+	HasNext() bool
+
+	// Get the next item, and advance the iterator.
+	Next() T
+}
+
+// baseFind provides a default implementation of Find, which scans the given
+// iterator from its current position until the first match.
+func baseFind[T comparable, S BaseIterator[T]](iter S, predicate Predicate[T]) (uint, bool) {
+	index := uint(0)
+
+	for i := iter; i.HasNext(); {
+		if predicate(i.Next()) {
+			return index, true
+		}
+
+		index++
+	}
+	// Failed to find it
+	return 0, false
+}
+
+// baseNth provides a default implementation of Nth, which scans the given
+// iterator from its current position until the nth item.
+func baseNth[T comparable, S BaseIterator[T]](iter S, n uint) T {
+	index := uint(0)
+
+	for i := iter; i.HasNext(); {
+		ith := i.Next()
+		if index == n {
+			return ith
+		}
+
+		index++
+	}
+	// Issue!
+	panic("iterator out-of-bounds")
+}
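Both helpers scan from the iterator's current position, so the returned index is relative to where iteration currently stands, and the scan consumes every item it inspects. A short sketch of that behaviour, under the same util package assumption:

```go
package main

import (
	"fmt"

	// Assumed location of the iterator utilities introduced in this diff.
	"github.com/consensys/go-corset/pkg/util"
)

func main() {
	iter := util.NewArrayIterator([]uint{5, 7, 9})
	// Find scans forward, consuming items up to and including the match.
	index, ok := iter.Find(func(x uint) bool { return x > 6 })
	fmt.Println(index, ok) // prints "1 true"
	// Only 9 remains, since 5 and 7 were consumed by the scan.
	fmt.Println(iter.Count()) // prints "1"
}
```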