From 0ed72204cf7cadf06ad97a40e289118573cb9cf7 Mon Sep 17 00:00:00 2001 From: DavePearce Date: Mon, 15 Jul 2024 17:32:05 +1200 Subject: [PATCH 1/2] Support `FieldArray` This provides a simple implementation of an array of field elements which is backed by a byte array. This is now working, and is consuming less memory. However, its taking roughly twice as long to execute for reasons unknown. --- pkg/cmd/trace.go | 14 +- pkg/schema/assignment/byte_decomposition.go | 13 +- pkg/schema/assignment/computed_column.go | 12 +- pkg/schema/assignment/interleave.go | 18 +-- pkg/schema/assignment/lexicographic_sort.go | 33 +++-- pkg/schema/assignment/sorted_permutation.go | 18 +-- pkg/schema/constraint/permutation.go | 12 +- pkg/schema/type.go | 22 +++ pkg/trace/array_trace.go | 104 ++++++++++++-- pkg/trace/builder.go | 13 +- pkg/trace/bytes_column.go | 140 ------------------ pkg/trace/field_column.go | 123 ---------------- pkg/trace/json/reader.go | 3 +- pkg/trace/lt/reader.go | 7 +- pkg/trace/lt/writer.go | 7 +- pkg/trace/trace.go | 19 +-- pkg/util/field_array.go | 148 ++++++++++++++++++++ pkg/util/fields.go | 16 --- pkg/util/permutation.go | 18 +-- 19 files changed, 349 insertions(+), 391 deletions(-) delete mode 100644 pkg/trace/bytes_column.go delete mode 100644 pkg/trace/field_column.go create mode 100644 pkg/util/field_array.go diff --git a/pkg/cmd/trace.go b/pkg/cmd/trace.go index 0194fca2..f0178c85 100644 --- a/pkg/cmd/trace.go +++ b/pkg/cmd/trace.go @@ -6,7 +6,6 @@ import ( "os" "strings" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" @@ -80,12 +79,7 @@ func filterColumns(tr trace.Trace, prefix string) trace.Trace { if strings.HasPrefix(qName, prefix) { ith := tr.Columns().Get(i) // Copy column data - data := make([]*fr.Element, ith.Height()) - // - for j := 0; j < int(ith.Height()); j++ { - data[j] = ith.Get(j) - } - + data := ith.Data().Clone() err := builder.Add(qName, ith.Padding(), data) // Sanity check if err != nil { @@ -102,9 +96,9 @@ func listColumns(tr trace.Trace) { tbl := util.NewTablePrinter(3, n) for i := uint(0); i < n; i++ { - ith := tr.Columns().Get(i) - elems := fmt.Sprintf("%d rows", ith.Height()) - bytes := fmt.Sprintf("%d bytes", ith.Width()*ith.Height()) + ith := tr.Columns().Get(i).Data() + elems := fmt.Sprintf("%d rows", ith.Len()) + bytes := fmt.Sprintf("%d bytes", ith.ByteWidth()*ith.Len()) tbl.SetRow(i, QualifiedColumnName(i, tr), elems, bytes) } diff --git a/pkg/schema/assignment/byte_decomposition.go b/pkg/schema/assignment/byte_decomposition.go index 6839d561..70cdefb1 100644 --- a/pkg/schema/assignment/byte_decomposition.go +++ b/pkg/schema/assignment/byte_decomposition.go @@ -69,16 +69,17 @@ func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error { // Identify source column source := columns.Get(p.source) // Construct byte column data - cols := make([][]*fr.Element, n) + cols := make([]*util.FieldArray, n) // Initialise columns for i := 0; i < n; i++ { - cols[i] = make([]*fr.Element, source.Height()) + // Construct a byte column for ith byte + cols[i] = util.NewFieldArray(source.Height(), 1) } // Decompose each row of each column - for i := 0; i < int(source.Height()); i = i + 1 { - ith := decomposeIntoBytes(source.Get(i), n) + for i := uint(0); i < source.Height(); i = i + 1 { + ith := decomposeIntoBytes(source.Get(int(i)), n) for j := 0; j < n; j++ { - cols[j][i] = ith[j] + cols[j].Set(i, ith[j]) } } // Determine padding values @@ -86,7 +87,7 @@ func (p 
*ByteDecomposition) ExpandTrace(tr trace.Trace) error { // Finally, add byte columns to trace for i := 0; i < n; i++ { ith := p.targets[i] - columns.Add(trace.NewFieldColumn(ith.Context(), ith.Name(), cols[i], padding[i])) + columns.Add(ith.Context(), ith.Name(), cols[i], padding[i]) } // Done return nil diff --git a/pkg/schema/assignment/computed_column.go b/pkg/schema/assignment/computed_column.go index 332223c5..a50bada5 100644 --- a/pkg/schema/assignment/computed_column.go +++ b/pkg/schema/assignment/computed_column.go @@ -87,15 +87,15 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error { // Determine multiplied height height := tr.Modules().Get(p.target.Context().Module()).Height() * multiplier // Make space for computed data - data := make([]*fr.Element, height) + data := util.NewFieldArray(height, 32) // Expand the trace - for i := 0; i < len(data); i++ { - val := p.expr.EvalAt(i, tr) + for i := uint(0); i < data.Len(); i++ { + val := p.expr.EvalAt(int(i), tr) if val != nil { - data[i] = val + data.Set(i, val) } else { zero := fr.NewElement(0) - data[i] = &zero + data.Set(i, &zero) } } // Determine padding value. A negative row index is used here to ensure @@ -103,7 +103,7 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error { // the padding value for *this* column. padding := p.expr.EvalAt(-1, tr) // Colunm needs to be expanded. - columns.Add(trace.NewFieldColumn(p.target.Context(), p.Name(), data, padding)) + columns.Add(p.target.Context(), p.Name(), data, padding) // Done return nil } diff --git a/pkg/schema/assignment/interleave.go b/pkg/schema/assignment/interleave.go index 50331756..58de5272 100644 --- a/pkg/schema/assignment/interleave.go +++ b/pkg/schema/assignment/interleave.go @@ -3,9 +3,7 @@ package assignment import ( "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/schema" - "github.com/consensys/go-corset/pkg/trace" tr "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" ) @@ -73,12 +71,16 @@ func (p *Interleaving) RequiredSpillage() uint { func (p *Interleaving) ExpandTrace(tr tr.Trace) error { columns := tr.Columns() ctx := p.target.Context() + // Byte width records the largest width of any column. + byte_width := uint(0) // Ensure target column doesn't exist for i := p.Columns(); i.HasNext(); { - name := i.Next().Name() + ith := i.Next() + // Update byte width + byte_width = max(byte_width, ith.Type().ByteWidth()) // Sanity check no column already exists with this name. 
- if _, ok := columns.IndexOf(ctx.Module(), name); ok { - return fmt.Errorf("interleaved column already exists ({%s})", name) + if _, ok := columns.IndexOf(ctx.Module(), ith.Name()); ok { + return fmt.Errorf("interleaved column already exists ({%s})", ith.Name()) } } // Determine interleaving width @@ -90,7 +92,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error { // the interleaved column) height := tr.Modules().Get(ctx.Module()).Height() * multiplier // Construct empty array - data := make([]*fr.Element, height*width) + data := util.NewFieldArray(height*width, uint8(byte_width)) // Offset just gives the column index offset := uint(0) // Copy interleaved data @@ -99,7 +101,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error { col := tr.Columns().Get(p.sources[i]) // Copy over for j := uint(0); j < height; j++ { - data[offset+(j*width)] = col.Get(int(j)) + data.Set(offset+(j*width), col.Get(int(j))) } offset++ @@ -108,7 +110,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error { // column in the interleaving. padding := columns.Get(0).Padding() // Colunm needs to be expanded. - columns.Add(trace.NewFieldColumn(ctx, p.target.Name(), data, padding)) + columns.Add(ctx, p.target.Name(), data, padding) // return nil } diff --git a/pkg/schema/assignment/lexicographic_sort.go b/pkg/schema/assignment/lexicographic_sort.go index 68bd139e..efb6403c 100644 --- a/pkg/schema/assignment/lexicographic_sort.go +++ b/pkg/schema/assignment/lexicographic_sort.go @@ -80,48 +80,55 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error { // Determine how many rows to be constrained. nrows := tr.Modules().Get(p.context.Module()).Height() * multiplier // Initialise new data columns - delta := make([]*fr.Element, nrows) - bit := make([][]*fr.Element, ncols) + bit := make([]*util.FieldArray, ncols) + // Byte width records the largest width of any column. + byte_width := uint(0) for i := 0; i < ncols; i++ { - bit[i] = make([]*fr.Element, nrows) + // TODO: following can be optimised to use a single bit per element, + // rather than an entire byte. + bit[i] = util.NewFieldArray(nrows, 1) + ith := columns.Get(p.sources[i]) + byte_width = max(byte_width, ith.Data().ByteWidth()) } - for i := 0; i < int(nrows); i++ { + delta := util.NewFieldArray(nrows, uint8(byte_width)) + + for i := uint(0); i < nrows; i++ { set := false // Initialise delta to zero - delta[i] = &zero + delta.Set(i, &zero) // Decide which row is the winner (if any) for j := 0; j < ncols; j++ { - prev := columns.Get(p.sources[j]).Get(i - 1) - curr := columns.Get(p.sources[j]).Get(i) + prev := columns.Get(p.sources[j]).Get(int(i - 1)) + curr := columns.Get(p.sources[j]).Get(int(i)) if !set && prev != nil && prev.Cmp(curr) != 0 { var diff fr.Element - bit[j][i] = &one + bit[j].Set(i, &one) // Compute curr - prev if p.signs[j] { diff.Set(curr) - delta[i] = diff.Sub(&diff, prev) + delta.Set(i, diff.Sub(&diff, prev)) } else { diff.Set(prev) - delta[i] = diff.Sub(&diff, curr) + delta.Set(i, diff.Sub(&diff, curr)) } set = true } else { - bit[j][i] = &zero + bit[j].Set(i, &zero) } } } // Add delta column data first := p.targets[0] - columns.Add(trace.NewFieldColumn(first.Context(), first.Name(), delta, &zero)) + columns.Add(first.Context(), first.Name(), delta, &zero) // Add bit column data for i := 0; i < ncols; i++ { ith := p.targets[1+i] - columns.Add(trace.NewFieldColumn(ith.Context(), ith.Name(), bit[i], &zero)) + columns.Add(ith.Context(), ith.Name(), bit[i], &zero) } // Done. 
return nil diff --git a/pkg/schema/assignment/sorted_permutation.go b/pkg/schema/assignment/sorted_permutation.go index 071bb3ae..721ac096 100644 --- a/pkg/schema/assignment/sorted_permutation.go +++ b/pkg/schema/assignment/sorted_permutation.go @@ -132,20 +132,14 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error { } } - cols := make([][]*fr.Element, len(p.sources)) + cols := make([]util.Array[*fr.Element], len(p.sources)) // Construct target columns for i := 0; i < len(p.sources); i++ { src := p.sources[i] - // Read column data to initialise permutation. - col := columns.Get(src) - // Copy column data to initialise permutation. - copy := make([]*fr.Element, col.Height()) - // - for j := 0; j < int(col.Height()); j++ { - copy[j] = col.Get(j) - } - // Copy over - cols[i] = copy + // Read column data + data := columns.Get(src).Data() + // Clone it to initialise permutation. + cols[i] = data.Clone() } // Sort target columns util.PermutationSort(cols, p.signs) @@ -156,7 +150,7 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error { ith := i.Next() dstColName := ith.Name() srcCol := tr.Columns().Get(p.sources[index]) - columns.Add(trace.NewFieldColumn(ith.Context(), dstColName, cols[index], srcCol.Padding())) + columns.Add(ith.Context(), dstColName, cols[index], srcCol.Padding()) } // return nil diff --git a/pkg/schema/constraint/permutation.go b/pkg/schema/constraint/permutation.go index e44cde14..1b0d1a68 100644 --- a/pkg/schema/constraint/permutation.go +++ b/pkg/schema/constraint/permutation.go @@ -72,20 +72,14 @@ func (p *PermutationConstraint) String() string { return fmt.Sprintf("(permutation (%s) (%s))", targets, sources) } -func sliceColumns(columns []uint, tr trace.Trace) [][]*fr.Element { +func sliceColumns(columns []uint, tr trace.Trace) []util.Array[*fr.Element] { // Allocate return array - cols := make([][]*fr.Element, len(columns)) + cols := make([]util.Array[*fr.Element], len(columns)) // Slice out the data for i, n := range columns { nth := tr.Columns().Get(n) - // Copy column data to initialise permutation. - copy := make([]*fr.Element, nth.Height()) - // - for j := 0; j < int(nth.Height()); j++ { - copy[j] = nth.Get(j) - } // Copy over - cols[i] = copy + cols[i] = nth.Data() } // Done return cols diff --git a/pkg/schema/type.go b/pkg/schema/type.go index fe7be6b0..446e5aaa 100644 --- a/pkg/schema/type.go +++ b/pkg/schema/type.go @@ -22,6 +22,9 @@ type Type interface { // Accept checks whether a specific value is accepted by this type Accept(*fr.Element) bool + // Return the number of bytes required represent any element of this type. + ByteWidth() uint + // Produce a string representation of this type. String() string } @@ -59,6 +62,19 @@ func (p *UintType) AsField() *FieldType { return nil } +// ByteWidth returns the number of bytes required represent any element of this +// type. +func (p *UintType) ByteWidth() uint { + m := p.nbits / 8 + n := p.nbits % 8 + // Check for even division + if n == 0 { + return m + } + // + return m + 1 +} + // Accept determines whether a given value is an element of this type. For // example, 123 is an element of the type u8 whilst 256 is not. func (p *UintType) Accept(val *fr.Element) bool { @@ -104,6 +120,12 @@ func (p *FieldType) AsField() *FieldType { return p } +// ByteWidth returns the number of bytes required represent any element of this +// type. +func (p *FieldType) ByteWidth() uint { + return 32 +} + // Accept determines whether a given value is an element of this type. 
In // fact, all field elements are members of this type. func (p *FieldType) Accept(val *fr.Element) bool { diff --git a/pkg/trace/array_trace.go b/pkg/trace/array_trace.go index 7874deaf..c59f060a 100644 --- a/pkg/trace/array_trace.go +++ b/pkg/trace/array_trace.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/util" ) @@ -13,7 +14,7 @@ type ArrayTrace struct { // Holds the complete set of columns in this trace. The index of each // column in this array uniquely identifies it, and is referred to as the // "column index". - columns []Column + columns []*ArrayTraceColumn // Holds the complete set of modules in this trace. The index of each // module in this array uniquely identifies it, and is referred to as the // "module index". @@ -29,7 +30,7 @@ func (p *ArrayTrace) Columns() ColumnSet { // Clone creates an identical clone of this trace. func (p *ArrayTrace) Clone() Trace { clone := new(ArrayTrace) - clone.columns = make([]Column, len(p.columns)) + clone.columns = make([]*ArrayTraceColumn, len(p.columns)) clone.modules = make([]Module, len(p.modules)) // Clone modules for i, m := range p.modules { @@ -37,7 +38,7 @@ func (p *ArrayTrace) Clone() Trace { } // Clone columns for i, c := range p.columns { - clone.columns[i] = c.Clone() + clone.columns[i] = NewArrayTraceColumn(c.context, c.name, c.data.Clone(), c.padding) } // done return clone @@ -102,17 +103,16 @@ type arrayTraceColumnSet struct { } // Add a new column to this column set. -func (p arrayTraceColumnSet) Add(column Column) uint { - ctx := column.Context() +func (p arrayTraceColumnSet) Add(ctx Context, name string, data util.Array[*fr.Element], padding *fr.Element) uint { m := &p.trace.modules[ctx.Module()] // Sanity check effective height - if column.Height() != (ctx.LengthMultiplier() * m.Height()) { - panic(fmt.Sprintf("invalid column height for %s: %d vs %d*%d", column.Name(), - column.Height(), m.Height(), ctx.LengthMultiplier())) + if data.Len() != (ctx.LengthMultiplier() * m.Height()) { + panic(fmt.Sprintf("invalid column height for %s: %d vs %d*%d", name, + data.Len(), m.Height(), ctx.LengthMultiplier())) } // Proceed index := uint(len(p.trace.columns)) - p.trace.columns = append(p.trace.columns, column) + p.trace.columns = append(p.trace.columns, NewArrayTraceColumn(ctx, name, data, padding)) // Register column with enclosing module m.registerColumn(index) // Done @@ -227,12 +227,94 @@ func (p arrayTraceModuleSet) Pad(index uint, n uint) { m.height += n // for _, c := range m.columns { - p.trace.columns[c].Pad(n) + p.trace.columns[c].pad(n) } } func (p arrayTraceModuleSet) reseatColumns(mid uint, columns []uint) { for _, c := range columns { - p.trace.columns[c].Reseat(mid) + p.trace.columns[c].reseat(mid) } } + +// ============================================================================ +// ArrayTraceColumn +// ============================================================================ + +// ArrayTraceColumn represents a column of data within a trace where each row is +// stored directly as a field element. This is the simplest form of column, +// which provides the fastest Get operation (i.e. because it just reads the +// field element out directly). However, at the same time, it can potentially +// use quite a lot of memory. In particular, when there are many different +// field elements which have smallish values then this requires excess data. 
+type ArrayTraceColumn struct { + // Evaluation context of this column + context Context + // Holds the name of this column + name string + // Holds the raw data making up this column + data util.Array[*fr.Element] + // Value to be used when padding this column + padding *fr.Element +} + +// NewArrayTraceColumn constructs a ArrayTraceColumn with the give name, data and padding. +func NewArrayTraceColumn(context Context, name string, data util.Array[*fr.Element], + padding *fr.Element) *ArrayTraceColumn { + // Sanity check data length + if data.Len()%context.LengthMultiplier() != 0 { + panic("data length has incorrect multiplier") + } + // Done + return &ArrayTraceColumn{context, name, data, padding} +} + +// Context returns the evaluation context this column provides. +func (p *ArrayTraceColumn) Context() Context { + return p.context +} + +// Name returns the name of the given column. +func (p *ArrayTraceColumn) Name() string { + return p.name +} + +// Height determines the height of this column. +func (p *ArrayTraceColumn) Height() uint { + return p.data.Len() +} + +// Padding returns the value which will be used for padding this column. +func (p *ArrayTraceColumn) Padding() *fr.Element { + return p.padding +} + +// Data provides access to the underlying data of this column +func (p *ArrayTraceColumn) Data() util.Array[*fr.Element] { + return p.data +} + +// Get the value at a given row in this column. If the row is +// out-of-bounds, then the column's padding value is returned instead. +// Thus, this function always succeeds. +func (p *ArrayTraceColumn) Get(row int) *fr.Element { + if row < 0 || uint(row) >= p.data.Len() { + // out-of-bounds access + return p.padding + } + // in-bounds access + return p.data.Get(uint(row)) +} + +func (p *ArrayTraceColumn) pad(n uint) { + // Apply the length multiplier + n = n * p.context.LengthMultiplier() + // Pad front of array + p.data = p.data.PadFront(n, p.padding) +} + +// Reseat updates the module index of this column (e.g. as a result of a +// realignment). +func (p *ArrayTraceColumn) reseat(mid uint) { + p.context = NewContext(mid, p.context.LengthMultiplier()) +} diff --git a/pkg/trace/builder.go b/pkg/trace/builder.go index 9add6455..77b6d4b7 100644 --- a/pkg/trace/builder.go +++ b/pkg/trace/builder.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/go-corset/pkg/util" ) // Builder is a helper utility for constructing new traces. It simplifies @@ -18,7 +19,7 @@ type Builder struct { // Mapping from name to module index modmap map[string]uint // Set of known columns - columns []Column + columns []*ArrayTraceColumn } // NewBuilder constructs an empty builder which can then be used to build a new @@ -26,7 +27,7 @@ type Builder struct { func NewBuilder() *Builder { modules := make([]Module, 0) modmap := make(map[string]uint, 0) - columns := make([]Column, 0) + columns := make([]*ArrayTraceColumn, 0) // Initially empty environment return &Builder{modules, modmap, columns} } @@ -39,7 +40,7 @@ func (p *Builder) Build() Trace { // Add a new column to this trace based on a fully qualified column name. This // splits the qualified column name and (if necessary) registers a new module // with the given height. 
-func (p *Builder) Add(name string, padding *fr.Element, data []*fr.Element) error { +func (p *Builder) Add(name string, padding *fr.Element, data util.Array[*fr.Element]) error { var err error // Split qualified column name modname, colname := p.splitQualifiedColumnName(name) @@ -47,7 +48,7 @@ func (p *Builder) Add(name string, padding *fr.Element, data []*fr.Element) erro mid, ok := p.modmap[modname] // Register module (if not located) if !ok { - if mid, err = p.Register(modname, uint(len(data))); err != nil { + if mid, err = p.Register(modname, data.Len()); err != nil { // Should be unreachable. return err } @@ -57,7 +58,7 @@ func (p *Builder) Add(name string, padding *fr.Element, data []*fr.Element) erro // where we are importing expanded traces, then this might not be true. context := NewContext(mid, 1) // Register new column. - return p.registerColumn(NewFieldColumn(context, colname, data, padding)) + return p.registerColumn(NewArrayTraceColumn(context, colname, data, padding)) } // HasModule checks whether a given module has already been registered with this @@ -100,7 +101,7 @@ func (p *Builder) splitQualifiedColumnName(name string) (string, string) { // RegisterColumn registers a new column with this builder. An error can arise // if the column's module does not exist, or if the column's height does not // match that of its enclosing module. -func (p *Builder) registerColumn(col Column) error { +func (p *Builder) registerColumn(col *ArrayTraceColumn) error { mid := col.Context().Module() // Sanity check module exists if mid >= uint(len(p.modules)) { diff --git a/pkg/trace/bytes_column.go b/pkg/trace/bytes_column.go deleted file mode 100644 index 5e4ec81c..00000000 --- a/pkg/trace/bytes_column.go +++ /dev/null @@ -1,140 +0,0 @@ -package trace - -import ( - "io" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" -) - -// BytesColumn represents a column of data within a trace as a raw byte array, -// such that each element occupies a fixed number of bytes. Accessing elements -// in this column is potentially slower than for a FieldColumn, as the raw bytes -// must be converted into a field element. -type BytesColumn struct { - // Evaluation context of this column - context Context - // Holds the name of this column - name string - // Determines how many bytes each field element takes. For the BLS12-377 - // curve, this should be 32. In the future, when other curves are - // supported, this could be less. - width uint8 - // The number of data elements in this column. - length uint - // The data stored in this column (as bytes). - bytes []byte - // Value to be used when padding this column - padding *fr.Element -} - -// NewBytesColumn constructs a new BytesColumn from its constituent parts. -func NewBytesColumn(context Context, name string, width uint8, length uint, - bytes []byte, padding *fr.Element) *BytesColumn { - // Sanity check data length - if length%context.LengthMultiplier() != 0 { - panic("data length has incorrect multiplier") - } - - return &BytesColumn{context, name, width, length, bytes, padding} -} - -// Context returns the evaluation context this column provides. -func (p *BytesColumn) Context() Context { - return p.context -} - -// Name returns the name of this column -func (p *BytesColumn) Name() string { - return p.name -} - -// Width returns the number of bytes required for each element in this column. -func (p *BytesColumn) Width() uint { - return uint(p.width) -} - -// Height returns the number of rows in this column. 
-func (p *BytesColumn) Height() uint { - return p.length -} - -// Padding returns the value which will be used for padding this column. -func (p *BytesColumn) Padding() *fr.Element { - return p.padding -} - -// Get the ith row of this column as a field element. -func (p *BytesColumn) Get(i int) *fr.Element { - // TODO: error for out-of-bounds accesses!!!! - var elem fr.Element - // Determine starting offset within bytes slice - start := int(p.width) * i - end := start + int(p.width) - // Construct field element. - return elem.SetBytes(p.bytes[start:end]) -} - -// Clone an BytesColumn -func (p *BytesColumn) Clone() Column { - clone := new(BytesColumn) - clone.context = p.context - clone.name = p.name - clone.length = p.length - clone.width = p.width - clone.padding = p.padding - // NOTE: the following is as we never actually mutate the underlying bytes - // array. - clone.bytes = p.bytes - // Done - return clone -} - -// SetBytes sets the raw byte array underlying this column. Care must be taken -// when mutating a column which is already being used in a trace, as this could -// lead to unexpected behaviour. -func (p *BytesColumn) SetBytes(bytes []byte) { - p.bytes = bytes -} - -// Pad this column with n copies of the column's padding value. -func (p *BytesColumn) Pad(n uint) { - // Apply the length multiplier - n = n * p.context.LengthMultiplier() - // Computing padding length (in bytes) - padding_len := n * uint(p.width) - // Access bytes to use for padding - padding_bytes := p.padding.Bytes() - padded_bytes := make([]byte, padding_len+uint(len(p.bytes))) - // Append padding - offset := 0 - - for i := uint(0); i < n; i++ { - // Calculate starting position within the 32byte array, remembering that - // padding_bytes is stored in _big endian_ format meaning - // padding_bytes[0] is the _most significant_ byte. - start := 32 - p.width - // Copy over least significant bytes - for j := start; j < 32; j++ { - padded_bytes[offset] = padding_bytes[j] - offset++ - } - } - // Copy over original data - copy(padded_bytes[padding_len:], p.bytes) - // Done - p.bytes = padded_bytes - p.length += n -} - -// Reseat updates the module index of this column (e.g. as a result of a -// realignment). -func (p *BytesColumn) Reseat(mid uint) { - p.context = NewContext(mid, p.context.LengthMultiplier()) -} - -// Write the raw bytes of this column to a given writer, returning an error -// if this failed (for some reason). -func (p *BytesColumn) Write(w io.Writer) error { - _, err := w.Write(p.bytes) - return err -} diff --git a/pkg/trace/field_column.go b/pkg/trace/field_column.go deleted file mode 100644 index cf44e09f..00000000 --- a/pkg/trace/field_column.go +++ /dev/null @@ -1,123 +0,0 @@ -package trace - -import ( - "io" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" -) - -// FieldColumn represents a column of data within a trace where each row is -// stored directly as a field element. This is the simplest form of column, -// which provides the fastest Get operation (i.e. because it just reads the -// field element out directly). However, at the same time, it can potentially -// use quite a lot of memory. In particular, when there are many different -// field elements which have smallish values then this requires excess data. 
-type FieldColumn struct { - // Evaluation context of this column - context Context - // Holds the name of this column - name string - // Holds the raw data making up this column - data []*fr.Element - // Value to be used when padding this column - padding *fr.Element -} - -// NewFieldColumn constructs a FieldColumn with the give name, data and padding. -func NewFieldColumn(context Context, name string, data []*fr.Element, padding *fr.Element) *FieldColumn { - // Sanity check data length - if uint(len(data))%context.LengthMultiplier() != 0 { - panic("data length has incorrect multiplier") - } - // Done - return &FieldColumn{context, name, data, padding} -} - -// Context returns the evaluation context this column provides. -func (p *FieldColumn) Context() Context { - return p.context -} - -// Name returns the name of the given column. -func (p *FieldColumn) Name() string { - return p.name -} - -// Width determines the number of bytes per element for this column (which, in -// this case, is always 32). -func (p *FieldColumn) Width() uint { - return 32 -} - -// Height determines the height of this column. -func (p *FieldColumn) Height() uint { - return uint(len(p.data)) -} - -// Padding returns the value which will be used for padding this column. -func (p *FieldColumn) Padding() *fr.Element { - return p.padding -} - -// Get the value at a given row in this column. If the row is -// out-of-bounds, then the column's padding value is returned instead. -// Thus, this function always succeeds. -func (p *FieldColumn) Get(row int) *fr.Element { - if row < 0 || row >= len(p.data) { - // out-of-bounds access - return p.padding - } - // in-bounds access - return p.data[row] -} - -// Clone an FieldColumn -func (p *FieldColumn) Clone() Column { - clone := new(FieldColumn) - clone.context = p.context - clone.name = p.name - clone.padding = p.padding - // NOTE: the following is as we never actually mutate the underlying bytes - // array. - clone.data = p.data - - return clone -} - -// Pad this column with n copies of the column's padding value. -func (p *FieldColumn) Pad(n uint) { - // Apply the length multiplier - n = n * p.context.LengthMultiplier() - // Allocate sufficient memory - ndata := make([]*fr.Element, uint(len(p.data))+n) - // Copy over the data - copy(ndata[n:], p.data) - // Go padding! - for i := uint(0); i < n; i++ { - ndata[i] = p.padding - } - // Copy over - p.data = ndata -} - -// Reseat updates the module index of this column (e.g. as a result of a -// realignment). -func (p *FieldColumn) Reseat(mid uint) { - p.context = NewContext(mid, p.context.LengthMultiplier()) -} - -// Write the raw bytes of this column to a given writer, returning an error -// if this failed (for some reason). Observe that this always writes data in -// 32byte chunks. -func (p *FieldColumn) Write(w io.Writer) error { - for _, e := range p.data { - // Read exactly 32 bytes - bytes := e.Bytes() - // Write them out - if _, err := w.Write(bytes[:]); err != nil { - return err - } - } - // - return nil -} diff --git a/pkg/trace/json/reader.go b/pkg/trace/json/reader.go index e983e663..a1960034 100644 --- a/pkg/trace/json/reader.go +++ b/pkg/trace/json/reader.go @@ -26,7 +26,8 @@ func FromBytes(bytes []byte) (trace.Trace, error) { for name, rawInts := range rawData { // Translate raw bigints into raw field elements - rawElements := util.ToFieldElements(rawInts) + // TODO: support native field widths in column name. 
+ rawElements := util.FieldArrayFromBigInts(32, rawInts) // Add column and sanity check for errors if err := builder.Add(name, &zero, rawElements); err != nil { return nil, err diff --git a/pkg/trace/lt/reader.go b/pkg/trace/lt/reader.go index bad030d0..e1c8668a 100644 --- a/pkg/trace/lt/reader.go +++ b/pkg/trace/lt/reader.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/trace" + "github.com/consensys/go-corset/pkg/util" ) // FromBytes parses a byte array representing a given LT trace file into an @@ -93,8 +94,8 @@ func readColumnHeader(buf *bytes.Reader) (columnHeader, error) { return header, nil } -func readColumnData(header columnHeader, bytes []byte) []*fr.Element { - data := make([]*fr.Element, header.length) +func readColumnData(header columnHeader, bytes []byte) *util.FieldArray { + data := util.NewFieldArray(header.length, uint8(header.width)) offset := uint(0) for i := uint(0); i < header.length; i++ { @@ -102,7 +103,7 @@ func readColumnData(header columnHeader, bytes []byte) []*fr.Element { // Calculate position of next element next := offset + header.width // Construct ith field element - data[i] = ith.SetBytes(bytes[offset:next]) + data.Set(i, ith.SetBytes(bytes[offset:next])) // Move offset to next element offset = next } diff --git a/pkg/trace/lt/writer.go b/pkg/trace/lt/writer.go index 41182448..7975ade4 100644 --- a/pkg/trace/lt/writer.go +++ b/pkg/trace/lt/writer.go @@ -42,6 +42,7 @@ func WriteBytes(tr trace.Trace, buf io.Writer) error { // Write header information for i := uint(0); i < ncols; i++ { col := columns.Get(i) + data := col.Data() mod := modules.Get(col.Context().Module()) name := col.Name() // Prepend module name (if applicable) @@ -61,18 +62,18 @@ func WriteBytes(tr trace.Trace, buf io.Writer) error { log.Fatal(err) } // Write bytes per element - if err := binary.Write(buf, binary.BigEndian, uint8(col.Width())); err != nil { + if err := binary.Write(buf, binary.BigEndian, uint8(data.ByteWidth())); err != nil { log.Fatal(err) } // Write Data length - if err := binary.Write(buf, binary.BigEndian, uint32(col.Height())); err != nil { + if err := binary.Write(buf, binary.BigEndian, uint32(data.Len())); err != nil { log.Fatal(err) } } // Write column data information for i := uint(0); i < ncols; i++ { col := columns.Get(i) - if err := col.Write(buf); err != nil { + if err := col.Data().Write(buf); err != nil { return err } } diff --git a/pkg/trace/trace.go b/pkg/trace/trace.go index f8755b2e..20eb535f 100644 --- a/pkg/trace/trace.go +++ b/pkg/trace/trace.go @@ -1,9 +1,8 @@ package trace import ( - "io" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/go-corset/pkg/util" ) // Trace describes a set of named columns. Columns are not required to have the @@ -20,7 +19,7 @@ type Trace interface { // ColumnSet provides an interface to the declared columns within this trace. type ColumnSet interface { // Add a new column to this column set. - Add(column Column) uint + Add(ctx Context, name string, data util.Array[*fr.Element], padding *fr.Element) uint // Get the ith module in this set. Get(uint) Column // Determine index of given column, or return false if this fails. @@ -36,12 +35,12 @@ type ColumnSet interface { // Column describes an individual column of data within a trace table. type Column interface { - // Clone this column - Clone() Column // Get the value at a given row in this column. If the row is // out-of-bounds, then the column's padding value is returned instead. 
// Thus, this function always succeeds. Get(row int) *fr.Element + // Access the underlying data array for this column + Data() util.Array[*fr.Element] // Return the height (i.e. number of rows) of this column. Height() uint // Returns the evaluation context for this column. That identifies the @@ -54,16 +53,6 @@ type Column interface { Name() string // Return the value to use for padding this column. Padding() *fr.Element - // Pad this column with n copies of the column's padding value. - Pad(n uint) - // Reseat updates the module index of this column (e.g. as a result of a - // realignment). - Reseat(mid uint) - // Return the width (i.e. number of bytes per element) of this column. - Width() uint - // Write the raw bytes of this column to a given writer, returning an error - // if this failed (for some reason). - Write(io.Writer) error } // ModuleSet provides an interface to the declared moules within this trace. diff --git a/pkg/util/field_array.go b/pkg/util/field_array.go new file mode 100644 index 00000000..732c1b99 --- /dev/null +++ b/pkg/util/field_array.go @@ -0,0 +1,148 @@ +package util + +import ( + "io" + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +// Array provides a generice interface to an array of elements. Typically, we +// are interested in arrays of field elements here. +type Array[T comparable] interface { + // Returns the number of elements in this array. + Len() uint + // Get returns the element at the given index in this array. + Get(uint) T + // Set the element at the given index in this array, overwriting the + // original value. + Set(uint, T) + // Clone makes clones of this array producing an otherwise identical copy. + Clone() Array[T] + // Return the number of bytes required to store an element of this array. + ByteWidth() uint + // Insert a given number of copies of T at start of array producing an + // updated array. + PadFront(uint, T) Array[T] + // Write out the contents of this array, assuming a minimal unit of 1 byte + // per element. + Write(w io.Writer) error +} + +// ---------------------------------------------------------------------------- + +// FieldArray implements an array of field elements using an underlying +// byte array. Each element occupies a fixed number of bytes, known as the +// width. This is space efficient when a known upper bound holds for the given +// elements. For example, when storing elements which always fit within 16bits, +// etc. +type FieldArray struct { + // The data stored in this column (as bytes). + bytes []byte + // The number of data elements in this column. + height uint + // Determines how many bytes each field element takes. For the BLS12-377 + // curve, this should be 32. In the future, when other curves are + // supported, this could be less. + width uint8 +} + +// NewFieldArray constructs a new field array with a given capacity. +func NewFieldArray(height uint, width uint8) *FieldArray { + bytes := make([]byte, height*uint(width)) + return &FieldArray{bytes, height, width} +} + +// FieldArrayFromBigInts converts an array of big integers into an array of +// field elements. +func FieldArrayFromBigInts(width uint8, ints []*big.Int) *FieldArray { + elements := NewFieldArray(uint(len(ints)), width) + // Convert each integer in turn. + for i, v := range ints { + element := new(fr.Element) + element.SetBigInt(v) + elements.Set(uint(i), element) + } + + // Done. + return elements +} + +// Len returns the number of elements in this field array. 
+func (p *FieldArray) Len() uint { + return p.height +} + +// ByteWidth returns the width of elements in this array. +func (p *FieldArray) ByteWidth() uint { + return uint(p.width) +} + +// Get returns the field element at the given index in this array. +func (p *FieldArray) Get(index uint) *fr.Element { + if index >= p.height { + panic("out-of-bounds access") + } + // Element which will hold value. + var elem fr.Element + // Determine starting offset within bytes slice + start := uint(p.width) * index + end := start + uint(p.width) + // Construct field element. + return elem.SetBytes(p.bytes[start:end]) +} + +// Set sets the field element at the given index in this array, overwriting the +// original value. +func (p *FieldArray) Set(index uint, element *fr.Element) { + bytes := element.Bytes() + // Determine starting offset within bytes slice + bytes_start := uint(p.width) * index + bytes_end := bytes_start + uint(p.width) + elem_start := 32 - p.width + // Copy data + copy(p.bytes[bytes_start:bytes_end], bytes[elem_start:]) +} + +// Clone makes clones of this array producing an otherwise identical copy. +func (p *FieldArray) Clone() Array[*fr.Element] { + n := len(p.bytes) + nbytes := make([]byte, n) + copy(nbytes, p.bytes) + // Done + return &FieldArray{nbytes, p.height, p.width} +} + +// PadFront (i.e. insert at the beginning) this array with n copies of the given padding value. +func (p *FieldArray) PadFront(n uint, padding *fr.Element) Array[*fr.Element] { + // Computing padding length (in bytes) + padding_len := n * uint(p.width) + // Access bytes to use for padding + padding_bytes := padding.Bytes() + padded_bytes := make([]byte, padding_len+uint(len(p.bytes))) + // Append padding + offset := 0 + + for i := uint(0); i < n; i++ { + // Calculate starting position within the 32byte array, remembering that + // padding_bytes is stored in _big endian_ format meaning + // padding_bytes[0] is the _most significant_ byte. + start := 32 - p.width + // Copy over least significant bytes + for j := start; j < 32; j++ { + padded_bytes[offset] = padding_bytes[j] + offset++ + } + } + // Copy over original data + copy(padded_bytes[padding_len:], p.bytes) + // Done + return &FieldArray{padded_bytes, p.height + n, p.width} +} + +// Write the raw bytes of this column to a given writer, returning an error +// if this failed (for some reason). +func (p *FieldArray) Write(w io.Writer) error { + _, err := w.Write(p.bytes) + return err +} diff --git a/pkg/util/fields.go b/pkg/util/fields.go index bfd66b42..06bfea43 100644 --- a/pkg/util/fields.go +++ b/pkg/util/fields.go @@ -1,25 +1,9 @@ package util import ( - "math/big" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" ) -// ToFieldElements converts an array of big integers into an array of field elements. -func ToFieldElements(ints []*big.Int) []*fr.Element { - elements := make([]*fr.Element, len(ints)) - // Convert each integer in turn. - for i, v := range ints { - element := new(fr.Element) - element.SetBigInt(v) - elements[i] = element - } - - // Done. - return elements -} - // Pow takes a given value to the power n. func Pow(val *fr.Element, n uint64) { if n == 0 { diff --git a/pkg/util/permutation.go b/pkg/util/permutation.go index 73474f27..b4500024 100644 --- a/pkg/util/permutation.go +++ b/pkg/util/permutation.go @@ -16,13 +16,13 @@ import ( // // This function operators by cloning the arrays, sorting them and checking they // are the same. 
-func ArePermutationOf(dst [][]*fr.Element, src [][]*fr.Element) bool { +func ArePermutationOf[T Array[*fr.Element]](dst []T, src []T) bool { if len(dst) != len(src) { return false } // Determine geometry ncols := len(dst) - nrows := len(dst[0]) + nrows := dst[0].Len() // Rotate input arrays dstCopy := rotate(dst, ncols, nrows) srcCopy := rotate(src, ncols, nrows) @@ -61,8 +61,8 @@ func permutationFunc(lhs []*fr.Element, rhs []*fr.Element) int { // NOTE: the current implementation is not intended to be particularly // efficient. In particular, would be better to do the sort directly // on the columns array without projecting into the row-wise form. -func PermutationSort(cols [][]*fr.Element, signs []bool) { - n := len(cols[0]) +func PermutationSort[T Array[*fr.Element]](cols []T, signs []bool) { + n := cols[0].Len() m := len(cols) // Rotate input matrix rows := rotate(cols, m, n) @@ -71,10 +71,10 @@ func PermutationSort(cols [][]*fr.Element, signs []bool) { return permutationSortFunc(l, r, signs) }) // Project back - for i := 0; i < n; i++ { + for i := uint(0); i < n; i++ { row := rows[i] for j := 0; j < m; j++ { - cols[j][i] = row[j] + cols[j].Set(i, row[j]) } } } @@ -123,14 +123,14 @@ func permutationSortFunc(lhs []*fr.Element, rhs []*fr.Element, signs []bool) int } // Clone and rotate a 2-dimensional array assuming a given geometry. -func rotate(src [][]*fr.Element, ncols int, nrows int) [][]*fr.Element { +func rotate[T Array[*fr.Element]](src []T, ncols int, nrows uint) [][]*fr.Element { // Copy outer arrays dst := make([][]*fr.Element, nrows) // Copy inner arrays - for i := 0; i < nrows; i++ { + for i := uint(0); i < nrows; i++ { row := make([]*fr.Element, ncols) for j := 0; j < ncols; j++ { - row[j] = src[j][i] + row[j] = src[j].Get(i) } dst[i] = row From 1a6fafc26e49fe02b22ca357332d68e3d720aa98 Mon Sep 17 00:00:00 2001 From: DavePearce Date: Tue, 16 Jul 2024 16:34:40 +1200 Subject: [PATCH 2/2] Support `FrElementArray` Time is best when using this. For short traces, it also seems like about 1.5s is spent reading the `bin` file and `1.7s` lowering. So, there are optimisations to be achieved there. 
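To make the memory/time trade-off behind these two patches concrete, here is a minimal, self-contained sketch — illustrative only, not the code in this series — of the two storage strategies. A byte-backed array packs each element into a fixed number of bytes and so saves memory, but must rebuild an `fr.Element` on every `Get`; an element-backed array stores `*fr.Element` directly, so `Get` is a plain slice read at the cost of a full 32-byte element (plus a pointer) per row. The type and function names below are hypothetical.

package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
)

// byteArray packs each field element into `width` bytes (the least
// significant bytes of its 32-byte big-endian encoding). Memory-cheap for
// small values, but every Get pays a SetBytes conversion.
type byteArray struct {
	bytes []byte
	width uint
}

func newByteArray(height, width uint) *byteArray {
	return &byteArray{bytes: make([]byte, height*width), width: width}
}

func (p *byteArray) Set(i uint, e *fr.Element) {
	full := e.Bytes() // [32]byte, big endian
	copy(p.bytes[i*p.width:(i+1)*p.width], full[32-p.width:])
}

func (p *byteArray) Get(i uint) *fr.Element {
	var e fr.Element
	return e.SetBytes(p.bytes[i*p.width : (i+1)*p.width])
}

// elementArray stores elements directly: Get is a plain slice read, at the
// cost of a full 32-byte element (plus pointer) per row.
type elementArray struct{ elements []*fr.Element }

func newElementArray(height uint) *elementArray {
	return &elementArray{elements: make([]*fr.Element, height)}
}

func (p *elementArray) Set(i uint, e *fr.Element) { p.elements[i] = e }
func (p *elementArray) Get(i uint) *fr.Element    { return p.elements[i] }

func main() {
	v := fr.NewElement(1234)
	ba, ea := newByteArray(4, 2), newElementArray(4)
	ba.Set(0, &v)
	ea.Set(0, &v)
	fmt.Println(ba.Get(0), ea.Get(0)) // both print 1234
}

This mirrors why the byte-backed `FieldArray` of the first patch reduced memory but roughly doubled execution time, and why `NewFrArray` in this patch only uses the byte-backed form for narrow columns, falling back to `FrElementArray` otherwise.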
--- pkg/cmd/check.go | 3 +- pkg/cmd/trace.go | 2 +- pkg/schema/assignment/byte_decomposition.go | 4 +- pkg/schema/assignment/computed_column.go | 2 +- pkg/schema/assignment/interleave.go | 2 +- pkg/schema/assignment/lexicographic_sort.go | 6 +- pkg/schema/assignment/sorted_permutation.go | 3 +- pkg/schema/constraint/permutation.go | 5 +- pkg/trace/array_trace.go | 8 +- pkg/trace/builder.go | 2 +- pkg/trace/json/reader.go | 2 +- pkg/trace/lt/reader.go | 4 +- pkg/trace/trace.go | 4 +- pkg/util/field_array.go | 151 ++++++++++++++++---- 14 files changed, 144 insertions(+), 54 deletions(-) diff --git a/pkg/cmd/check.go b/pkg/cmd/check.go index 39e1c50c..2714b453 100644 --- a/pkg/cmd/check.go +++ b/pkg/cmd/check.go @@ -162,6 +162,7 @@ func checkTraceWithLoweringDefault(tr trace.Trace, hirSchema *hir.Schema, cfg ch } func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, error) { + // if cfg.expand { // Clone to prevent interefence with subsequent checks tr = tr.Clone() @@ -186,12 +187,10 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, return tr, err } } - // Perform Alignment if err := performAlignment(false, tr, schema, cfg); err != nil { return tr, err } - // Apply padding (as necessary) for n := cfg.padding.Left; n <= cfg.padding.Right; n++ { if ptr, err := padAndCheckTrace(n, tr, schema); err != nil { diff --git a/pkg/cmd/trace.go b/pkg/cmd/trace.go index f0178c85..37565b54 100644 --- a/pkg/cmd/trace.go +++ b/pkg/cmd/trace.go @@ -98,7 +98,7 @@ func listColumns(tr trace.Trace) { for i := uint(0); i < n; i++ { ith := tr.Columns().Get(i).Data() elems := fmt.Sprintf("%d rows", ith.Len()) - bytes := fmt.Sprintf("%d bytes", ith.ByteWidth()*ith.Len()) + bytes := fmt.Sprintf("(%d*%d) = %d bytes", ith.Len(), ith.ByteWidth(), ith.ByteWidth()*ith.Len()) tbl.SetRow(i, QualifiedColumnName(i, tr), elems, bytes) } diff --git a/pkg/schema/assignment/byte_decomposition.go b/pkg/schema/assignment/byte_decomposition.go index 70cdefb1..b752f2eb 100644 --- a/pkg/schema/assignment/byte_decomposition.go +++ b/pkg/schema/assignment/byte_decomposition.go @@ -69,11 +69,11 @@ func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error { // Identify source column source := columns.Get(p.source) // Construct byte column data - cols := make([]*util.FieldArray, n) + cols := make([]util.FrArray, n) // Initialise columns for i := 0; i < n; i++ { // Construct a byte column for ith byte - cols[i] = util.NewFieldArray(source.Height(), 1) + cols[i] = util.NewFrArray(source.Height(), 1) } // Decompose each row of each column for i := uint(0); i < source.Height(); i = i + 1 { diff --git a/pkg/schema/assignment/computed_column.go b/pkg/schema/assignment/computed_column.go index a50bada5..051b11c4 100644 --- a/pkg/schema/assignment/computed_column.go +++ b/pkg/schema/assignment/computed_column.go @@ -87,7 +87,7 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error { // Determine multiplied height height := tr.Modules().Get(p.target.Context().Module()).Height() * multiplier // Make space for computed data - data := util.NewFieldArray(height, 32) + data := util.NewFrArray(height, 32) // Expand the trace for i := uint(0); i < data.Len(); i++ { val := p.expr.EvalAt(int(i), tr) diff --git a/pkg/schema/assignment/interleave.go b/pkg/schema/assignment/interleave.go index 58de5272..1058548a 100644 --- a/pkg/schema/assignment/interleave.go +++ b/pkg/schema/assignment/interleave.go @@ -92,7 +92,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error { // the 
interleaved column) height := tr.Modules().Get(ctx.Module()).Height() * multiplier // Construct empty array - data := util.NewFieldArray(height*width, uint8(byte_width)) + data := util.NewFrArray(height*width, byte_width) // Offset just gives the column index offset := uint(0) // Copy interleaved data diff --git a/pkg/schema/assignment/lexicographic_sort.go b/pkg/schema/assignment/lexicographic_sort.go index efb6403c..f0510909 100644 --- a/pkg/schema/assignment/lexicographic_sort.go +++ b/pkg/schema/assignment/lexicographic_sort.go @@ -80,19 +80,19 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error { // Determine how many rows to be constrained. nrows := tr.Modules().Get(p.context.Module()).Height() * multiplier // Initialise new data columns - bit := make([]*util.FieldArray, ncols) + bit := make([]util.FrArray, ncols) // Byte width records the largest width of any column. byte_width := uint(0) for i := 0; i < ncols; i++ { // TODO: following can be optimised to use a single bit per element, // rather than an entire byte. - bit[i] = util.NewFieldArray(nrows, 1) + bit[i] = util.NewFrArray(nrows, 1) ith := columns.Get(p.sources[i]) byte_width = max(byte_width, ith.Data().ByteWidth()) } - delta := util.NewFieldArray(nrows, uint8(byte_width)) + delta := util.NewFrArray(nrows, byte_width) for i := uint(0); i < nrows; i++ { set := false diff --git a/pkg/schema/assignment/sorted_permutation.go b/pkg/schema/assignment/sorted_permutation.go index 721ac096..3ecd2546 100644 --- a/pkg/schema/assignment/sorted_permutation.go +++ b/pkg/schema/assignment/sorted_permutation.go @@ -3,7 +3,6 @@ package assignment import ( "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/trace" tr "github.com/consensys/go-corset/pkg/trace" @@ -132,7 +131,7 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error { } } - cols := make([]util.Array[*fr.Element], len(p.sources)) + cols := make([]util.FrArray, len(p.sources)) // Construct target columns for i := 0; i < len(p.sources); i++ { src := p.sources[i] diff --git a/pkg/schema/constraint/permutation.go b/pkg/schema/constraint/permutation.go index 1b0d1a68..8a91b18c 100644 --- a/pkg/schema/constraint/permutation.go +++ b/pkg/schema/constraint/permutation.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" ) @@ -72,9 +71,9 @@ func (p *PermutationConstraint) String() string { return fmt.Sprintf("(permutation (%s) (%s))", targets, sources) } -func sliceColumns(columns []uint, tr trace.Trace) []util.Array[*fr.Element] { +func sliceColumns(columns []uint, tr trace.Trace) []util.FrArray { // Allocate return array - cols := make([]util.Array[*fr.Element], len(columns)) + cols := make([]util.FrArray, len(columns)) // Slice out the data for i, n := range columns { nth := tr.Columns().Get(n) diff --git a/pkg/trace/array_trace.go b/pkg/trace/array_trace.go index c59f060a..6c845d4b 100644 --- a/pkg/trace/array_trace.go +++ b/pkg/trace/array_trace.go @@ -103,7 +103,7 @@ type arrayTraceColumnSet struct { } // Add a new column to this column set. 
-func (p arrayTraceColumnSet) Add(ctx Context, name string, data util.Array[*fr.Element], padding *fr.Element) uint { +func (p arrayTraceColumnSet) Add(ctx Context, name string, data util.FrArray, padding *fr.Element) uint { m := &p.trace.modules[ctx.Module()] // Sanity check effective height if data.Len() != (ctx.LengthMultiplier() * m.Height()) { @@ -253,13 +253,13 @@ type ArrayTraceColumn struct { // Holds the name of this column name string // Holds the raw data making up this column - data util.Array[*fr.Element] + data util.FrArray // Value to be used when padding this column padding *fr.Element } // NewArrayTraceColumn constructs a ArrayTraceColumn with the give name, data and padding. -func NewArrayTraceColumn(context Context, name string, data util.Array[*fr.Element], +func NewArrayTraceColumn(context Context, name string, data util.FrArray, padding *fr.Element) *ArrayTraceColumn { // Sanity check data length if data.Len()%context.LengthMultiplier() != 0 { @@ -290,7 +290,7 @@ func (p *ArrayTraceColumn) Padding() *fr.Element { } // Data provides access to the underlying data of this column -func (p *ArrayTraceColumn) Data() util.Array[*fr.Element] { +func (p *ArrayTraceColumn) Data() util.FrArray { return p.data } diff --git a/pkg/trace/builder.go b/pkg/trace/builder.go index 77b6d4b7..bedbcd60 100644 --- a/pkg/trace/builder.go +++ b/pkg/trace/builder.go @@ -40,7 +40,7 @@ func (p *Builder) Build() Trace { // Add a new column to this trace based on a fully qualified column name. This // splits the qualified column name and (if necessary) registers a new module // with the given height. -func (p *Builder) Add(name string, padding *fr.Element, data util.Array[*fr.Element]) error { +func (p *Builder) Add(name string, padding *fr.Element, data util.FrArray) error { var err error // Split qualified column name modname, colname := p.splitQualifiedColumnName(name) diff --git a/pkg/trace/json/reader.go b/pkg/trace/json/reader.go index a1960034..6603d37c 100644 --- a/pkg/trace/json/reader.go +++ b/pkg/trace/json/reader.go @@ -27,7 +27,7 @@ func FromBytes(bytes []byte) (trace.Trace, error) { for name, rawInts := range rawData { // Translate raw bigints into raw field elements // TODO: support native field widths in column name. - rawElements := util.FieldArrayFromBigInts(32, rawInts) + rawElements := util.FrArrayFromBigInts(32, rawInts) // Add column and sanity check for errors if err := builder.Add(name, &zero, rawElements); err != nil { return nil, err diff --git a/pkg/trace/lt/reader.go b/pkg/trace/lt/reader.go index e1c8668a..126f9aa0 100644 --- a/pkg/trace/lt/reader.go +++ b/pkg/trace/lt/reader.go @@ -94,8 +94,8 @@ func readColumnHeader(buf *bytes.Reader) (columnHeader, error) { return header, nil } -func readColumnData(header columnHeader, bytes []byte) *util.FieldArray { - data := util.NewFieldArray(header.length, uint8(header.width)) +func readColumnData(header columnHeader, bytes []byte) util.FrArray { + data := util.NewFrArray(header.length, header.width) offset := uint(0) for i := uint(0); i < header.length; i++ { diff --git a/pkg/trace/trace.go b/pkg/trace/trace.go index 20eb535f..122c39bd 100644 --- a/pkg/trace/trace.go +++ b/pkg/trace/trace.go @@ -19,7 +19,7 @@ type Trace interface { // ColumnSet provides an interface to the declared columns within this trace. type ColumnSet interface { // Add a new column to this column set. 
- Add(ctx Context, name string, data util.Array[*fr.Element], padding *fr.Element) uint + Add(ctx Context, name string, data util.FrArray, padding *fr.Element) uint // Get the ith module in this set. Get(uint) Column // Determine index of given column, or return false if this fails. @@ -40,7 +40,7 @@ type Column interface { // Thus, this function always succeeds. Get(row int) *fr.Element // Access the underlying data array for this column - Data() util.Array[*fr.Element] + Data() util.FrArray // Return the height (i.e. number of rows) of this column. Height() uint // Returns the evaluation context for this column. That identifies the diff --git a/pkg/util/field_array.go b/pkg/util/field_array.go index 732c1b99..e85f89e2 100644 --- a/pkg/util/field_array.go +++ b/pkg/util/field_array.go @@ -31,12 +31,42 @@ type Array[T comparable] interface { // ---------------------------------------------------------------------------- -// FieldArray implements an array of field elements using an underlying +// FrArray represents an array of field elements. +type FrArray = Array[*fr.Element] + +// NewFrArray creates a new FrArray dynamically based on the given width. +func NewFrArray(height uint, width uint) FrArray { + switch width { + case 1, 2: + return NewFrByteArray(height, uint8(width)) + default: + return NewFrElementArray(height) + } +} + +// FrArrayFromBigInts converts an array of big integers into an array of +// field elements. +func FrArrayFromBigInts(width uint, ints []*big.Int) FrArray { + elements := NewFrArray(uint(len(ints)), width) + // Convert each integer in turn. + for i, v := range ints { + element := new(fr.Element) + element.SetBigInt(v) + elements.Set(uint(i), element) + } + + // Done. + return elements +} + +// ---------------------------------------------------------------------------- + +// FrByteArray implements an array of field elements using an underlying // byte array. Each element occupies a fixed number of bytes, known as the // width. This is space efficient when a known upper bound holds for the given // elements. For example, when storing elements which always fit within 16bits, // etc. -type FieldArray struct { +type FrByteArray struct { // The data stored in this column (as bytes). bytes []byte // The number of data elements in this column. @@ -47,39 +77,24 @@ type FieldArray struct { width uint8 } -// NewFieldArray constructs a new field array with a given capacity. -func NewFieldArray(height uint, width uint8) *FieldArray { +// NewFrByteArray constructs a new field array with a given capacity. +func NewFrByteArray(height uint, width uint8) *FrByteArray { bytes := make([]byte, height*uint(width)) - return &FieldArray{bytes, height, width} -} - -// FieldArrayFromBigInts converts an array of big integers into an array of -// field elements. -func FieldArrayFromBigInts(width uint8, ints []*big.Int) *FieldArray { - elements := NewFieldArray(uint(len(ints)), width) - // Convert each integer in turn. - for i, v := range ints { - element := new(fr.Element) - element.SetBigInt(v) - elements.Set(uint(i), element) - } - - // Done. - return elements + return &FrByteArray{bytes, height, width} } // Len returns the number of elements in this field array. -func (p *FieldArray) Len() uint { +func (p *FrByteArray) Len() uint { return p.height } // ByteWidth returns the width of elements in this array. -func (p *FieldArray) ByteWidth() uint { +func (p *FrByteArray) ByteWidth() uint { return uint(p.width) } // Get returns the field element at the given index in this array. 
-func (p *FieldArray) Get(index uint) *fr.Element { +func (p *FrByteArray) Get(index uint) *fr.Element { if index >= p.height { panic("out-of-bounds access") } @@ -94,7 +109,7 @@ func (p *FieldArray) Get(index uint) *fr.Element { // Set sets the field element at the given index in this array, overwriting the // original value. -func (p *FieldArray) Set(index uint, element *fr.Element) { +func (p *FrByteArray) Set(index uint, element *fr.Element) { bytes := element.Bytes() // Determine starting offset within bytes slice bytes_start := uint(p.width) * index @@ -105,16 +120,16 @@ func (p *FieldArray) Set(index uint, element *fr.Element) { } // Clone makes clones of this array producing an otherwise identical copy. -func (p *FieldArray) Clone() Array[*fr.Element] { +func (p *FrByteArray) Clone() Array[*fr.Element] { n := len(p.bytes) nbytes := make([]byte, n) copy(nbytes, p.bytes) // Done - return &FieldArray{nbytes, p.height, p.width} + return &FrByteArray{nbytes, p.height, p.width} } // PadFront (i.e. insert at the beginning) this array with n copies of the given padding value. -func (p *FieldArray) PadFront(n uint, padding *fr.Element) Array[*fr.Element] { +func (p *FrByteArray) PadFront(n uint, padding *fr.Element) Array[*fr.Element] { // Computing padding length (in bytes) padding_len := n * uint(p.width) // Access bytes to use for padding @@ -137,12 +152,90 @@ func (p *FieldArray) PadFront(n uint, padding *fr.Element) Array[*fr.Element] { // Copy over original data copy(padded_bytes[padding_len:], p.bytes) // Done - return &FieldArray{padded_bytes, p.height + n, p.width} + return &FrByteArray{padded_bytes, p.height + n, p.width} } // Write the raw bytes of this column to a given writer, returning an error // if this failed (for some reason). -func (p *FieldArray) Write(w io.Writer) error { +func (p *FrByteArray) Write(w io.Writer) error { _, err := w.Write(p.bytes) return err } + +// ---------------------------------------------------------------------------- + +// FrElementArray implements an array of field elements using an underlying +// byte array. Each element occupies a fixed number of bytes, known as the +// width. This is space efficient when a known upper bound holds for the given +// elements. For example, when storing elements which always fit within 16bits, +// etc. +type FrElementArray struct { + // The data stored in this column (as bytes). + elements []*fr.Element +} + +// NewFrElementArray constructs a new field array with a given capacity. +func NewFrElementArray(height uint) *FrElementArray { + elements := make([]*fr.Element, height) + return &FrElementArray{elements} +} + +// Len returns the number of elements in this field array. +func (p *FrElementArray) Len() uint { + return uint(len(p.elements)) +} + +// ByteWidth returns the width of elements in this array. +func (p *FrElementArray) ByteWidth() uint { + return 32 +} + +// Get returns the field element at the given index in this array. +func (p *FrElementArray) Get(index uint) *fr.Element { + return p.elements[index] +} + +// Set sets the field element at the given index in this array, overwriting the +// original value. +func (p *FrElementArray) Set(index uint, element *fr.Element) { + p.elements[index] = element +} + +// Clone makes clones of this array producing an otherwise identical copy. 
+func (p *FrElementArray) Clone() Array[*fr.Element] { + // Allocate sufficient memory + ndata := make([]*fr.Element, uint(len(p.elements))) + // Copy over the data + copy(ndata, p.elements) + // + return &FrElementArray{ndata} +} + +// PadFront (i.e. insert at the beginning) this array with n copies of the given padding value. +func (p *FrElementArray) PadFront(n uint, padding *fr.Element) Array[*fr.Element] { + // Allocate sufficient memory + ndata := make([]*fr.Element, uint(len(p.elements))+n) + // Copy over the data + copy(ndata[n:], p.elements) + // Go padding! + for i := uint(0); i < n; i++ { + ndata[i] = padding + } + // Copy over + return &FrElementArray{ndata} +} + +// Write the raw bytes of this column to a given writer, returning an error +// if this failed (for some reason). +func (p *FrElementArray) Write(w io.Writer) error { + for _, e := range p.elements { + // Read exactly 32 bytes + bytes := e.Bytes() + // Write them out + if _, err := w.Write(bytes[:]); err != nil { + return err + } + } + // + return nil +}
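As a closing usage note, here is a minimal sketch — assuming `pkg/util` builds as of this commit — of how the `FrArray` interface introduced above is meant to be driven; the heights, widths and values are arbitrary.

package main

import (
	"bytes"
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
	"github.com/consensys/go-corset/pkg/util"
)

func main() {
	// Width 2 selects the byte-backed FrByteArray; anything wider currently
	// falls back to FrElementArray (see NewFrArray).
	col := util.NewFrArray(3, 2)
	for i := uint(0); i < col.Len(); i++ {
		v := fr.NewElement(uint64(i) + 1)
		col.Set(i, &v)
	}
	// Pad the front with two zero rows, as trace padding does.
	zero := fr.NewElement(0)
	col = col.PadFront(2, &zero)
	// Serialise in the same layout the lt writer uses.
	var buf bytes.Buffer
	if err := col.Write(&buf); err != nil {
		panic(err)
	}
	fmt.Printf("%d rows, %d bytes\n", col.Len(), buf.Len())
}

Callers that build columns (interleaving, lexicographic sort) pick the array width from the widest source column, which is why `ByteWidth()` appears both on `schema.Type` and on the array interface itself.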