From 49696fe1ca3a8084fa659678a7d0a404b390e5b3 Mon Sep 17 00:00:00 2001 From: DavePearce Date: Thu, 25 Jul 2024 20:27:30 +1200 Subject: [PATCH] Fix for Field Array Gremlin This puts through a relatively straightforward fix relating to permutations. The problem was a pointer into an array which was being sorted. --- pkg/schema/constraint/permutation.go | 18 ++-- pkg/trace/util.go | 18 ++++ pkg/util/arrays.go | 4 +- pkg/util/field_array.go | 144 ++++++++++++++++++++++++++- pkg/util/permutation.go | 24 ++--- 5 files changed, 185 insertions(+), 23 deletions(-) diff --git a/pkg/schema/constraint/permutation.go b/pkg/schema/constraint/permutation.go index 8a91b18c..1d71533d 100644 --- a/pkg/schema/constraint/permutation.go +++ b/pkg/schema/constraint/permutation.go @@ -39,13 +39,19 @@ func (p *PermutationConstraint) Accepts(tr trace.Trace) error { src := sliceColumns(p.sources, tr) dst := sliceColumns(p.targets, tr) // Sanity check whether column exists - if !util.ArePermutationOf(dst, src) { - msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})", - p.targets, p.sources) - return errors.New(msg) + if util.ArePermutationOf(dst, src) { + // Success + return nil } - // Success - return nil + // Prepare suitable error message + src_names := trace.QualifiedColumnNamesToCommaSeparatedString(p.sources, tr) + dst_names := trace.QualifiedColumnNamesToCommaSeparatedString(p.targets, tr) + // + msg := fmt.Sprintf("Target columns (%s) not permutation of source columns (%s)", + dst_names, src_names) + // Done + return errors.New(msg) + } func (p *PermutationConstraint) String() string { diff --git a/pkg/trace/util.go b/pkg/trace/util.go index 99411c11..905eaf71 100644 --- a/pkg/trace/util.go +++ b/pkg/trace/util.go @@ -1,5 +1,7 @@ package trace +import "strings" + // PadColumns pads every column in a given trace with a given amount of padding. func PadColumns(tr Trace, padding uint) { modules := tr.Modules() @@ -21,3 +23,19 @@ func MaxHeight(tr Trace) uint { // Done return h } + +// QualifiedColumnNamesToCommaSeparatedString produces a suitable string for use +// in error messages from a list of one or more column identifies. +func QualifiedColumnNamesToCommaSeparatedString(columns []uint, trace Trace) string { + var names strings.Builder + + for i, c := range columns { + if i != 0 { + names.WriteString(",") + } + + names.WriteString(trace.Columns().Get(c).Name()) + } + // Done + return names.String() +} diff --git a/pkg/util/arrays.go b/pkg/util/arrays.go index 47557d89..0557c865 100644 --- a/pkg/util/arrays.go +++ b/pkg/util/arrays.go @@ -38,7 +38,7 @@ func Equals(lhs []*fr.Element, rhs []*fr.Element) bool { } // Equals2d returns true if two 2D arrays are equal. -func Equals2d(lhs [][]*fr.Element, rhs [][]*fr.Element) bool { +func Equals2d(lhs [][]fr.Element, rhs [][]fr.Element) bool { if len(lhs) != len(rhs) { return false } @@ -52,7 +52,7 @@ func Equals2d(lhs [][]*fr.Element, rhs [][]*fr.Element) bool { } // Check elements match for j := 0; j < len(lhs_i); j++ { - if lhs_i[j].Cmp(rhs_i[j]) != 0 { + if lhs_i[j].Cmp(&rhs_i[j]) != 0 { return false } } diff --git a/pkg/util/field_array.go b/pkg/util/field_array.go index 9ffec864..b8ccd701 100644 --- a/pkg/util/field_array.go +++ b/pkg/util/field_array.go @@ -3,6 +3,7 @@ package util import ( "io" "math/big" + "strings" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" ) @@ -47,9 +48,10 @@ func NewFrArray(height uint, bitWidth uint) FrArray { var pool FrIndexPool[uint16] = NewFrIndexPool[uint16]() return NewFrPoolArray[uint16](height, bitWidth, pool) default: - // return NewFrElementArray(height, bitWidth) - var pool FrMapPool = NewFrMapPool(bitWidth) - return NewFrPoolArray[uint32](height, bitWidth, pool) + return NewFrElementArray(height, bitWidth) + //return NewFrPtrElementArray(height, bitWidth) + // var pool FrMapPool = NewFrMapPool(bitWidth) + // return NewFrPoolArray[uint32](height, bitWidth, pool) } } @@ -148,6 +150,122 @@ func (p *FrElementArray) Write(w io.Writer) error { return nil } +func (p *FrElementArray) String() string { + var sb strings.Builder + + sb.WriteString("[") + + for i := 0; i < len(p.elements); i++ { + if i != 0 { + sb.WriteString(",") + } + + sb.WriteString(p.elements[i].String()) + } + + sb.WriteString("]") + + return sb.String() +} + +// ---------------------------------------------------------------------------- + +// FrPtrElementArray implements an array of field elements using an underlying +// byte array. Each element occupies a fixed number of bytes, known as the +// width. This is space efficient when a known upper bound holds for the given +// elements. For example, when storing elements which always fit within 16bits, +// etc. +type FrPtrElementArray struct { + // The data stored in this column (as bytes). + elements []*fr.Element + // Maximum number of bits required to store an element of this array. + bitwidth uint +} + +// NewFrPtrElementArray constructs a new field array with a given capacity. +func NewFrPtrElementArray(height uint, bitwidth uint) *FrPtrElementArray { + elements := make([]*fr.Element, height) + return &FrPtrElementArray{elements, bitwidth} +} + +// Len returns the number of elements in this field array. +func (p *FrPtrElementArray) Len() uint { + return uint(len(p.elements)) +} + +// BitWidth returns the width (in bits) of elements in this array. +func (p *FrPtrElementArray) BitWidth() uint { + return p.bitwidth +} + +// Get returns the field element at the given index in this array. +func (p *FrPtrElementArray) Get(index uint) *fr.Element { + return p.elements[index] +} + +// Set sets the field element at the given index in this array, overwriting the +// original value. +func (p *FrPtrElementArray) Set(index uint, element *fr.Element) { + p.elements[index] = element +} + +// Clone makes clones of this array producing an otherwise identical copy. +func (p *FrPtrElementArray) Clone() Array[*fr.Element] { + // Allocate sufficient memory + ndata := make([]*fr.Element, uint(len(p.elements))) + // Copy over the data + copy(ndata, p.elements) + // + return &FrPtrElementArray{ndata, p.bitwidth} +} + +// PadFront (i.e. insert at the beginning) this array with n copies of the given padding value. +func (p *FrPtrElementArray) PadFront(n uint, padding *fr.Element) Array[*fr.Element] { + // Allocate sufficient memory + ndata := make([]*fr.Element, uint(len(p.elements))+n) + // Copy over the data + copy(ndata[n:], p.elements) + // Go padding! + for i := uint(0); i < n; i++ { + ndata[i] = padding + } + // Copy over + return &FrPtrElementArray{ndata, p.bitwidth} +} + +// Write the raw bytes of this column to a given writer, returning an error +// if this failed (for some reason). +func (p *FrPtrElementArray) Write(w io.Writer) error { + for _, e := range p.elements { + // Read exactly 32 bytes + bytes := e.Bytes() + // Write them out + if _, err := w.Write(bytes[:]); err != nil { + return err + } + } + // + return nil +} + +func (p *FrPtrElementArray) String() string { + var sb strings.Builder + + sb.WriteString("[") + + for i := 0; i < len(p.elements); i++ { + if i != 0 { + sb.WriteString(",") + } + + sb.WriteString(p.elements[i].String()) + } + + sb.WriteString("]") + + return sb.String() +} + // ---------------------------------------------------------------------------- // FrPoolArray implements an array of field elements using an index to pool @@ -231,3 +349,23 @@ func (p *FrPoolArray[K, P]) Write(w io.Writer) error { // return nil } + +//nolint:rev +func (p *FrPoolArray[K, P]) String() string { + var sb strings.Builder + + sb.WriteString("[") + + for i := 0; i < len(p.elements); i++ { + if i != 0 { + sb.WriteString(",") + } + + index := p.elements[i] + sb.WriteString(p.pool.Get(index).String()) + } + + sb.WriteString("[") + + return sb.String() +} diff --git a/pkg/util/permutation.go b/pkg/util/permutation.go index b4500024..55520cd3 100644 --- a/pkg/util/permutation.go +++ b/pkg/util/permutation.go @@ -33,10 +33,10 @@ func ArePermutationOf[T Array[*fr.Element]](dst []T, src []T) bool { return Equals2d(dstCopy, srcCopy) } -func permutationFunc(lhs []*fr.Element, rhs []*fr.Element) int { +func permutationFunc(lhs []fr.Element, rhs []fr.Element) int { for i := 0; i < len(lhs); i++ { // Compare ith elements - c := lhs[i].Cmp(rhs[i]) + c := lhs[i].Cmp(&rhs[i]) // Check whether same if c != 0 { // Positive @@ -67,14 +67,14 @@ func PermutationSort[T Array[*fr.Element]](cols []T, signs []bool) { // Rotate input matrix rows := rotate(cols, m, n) // Perform the permutation sort - slices.SortFunc(rows, func(l []*fr.Element, r []*fr.Element) int { + slices.SortFunc(rows, func(l []fr.Element, r []fr.Element) int { return permutationSortFunc(l, r, signs) }) // Project back for i := uint(0); i < n; i++ { row := rows[i] for j := 0; j < m; j++ { - cols[j].Set(i, row[j]) + cols[j].Set(i, &row[j]) } } } @@ -82,14 +82,14 @@ func PermutationSort[T Array[*fr.Element]](cols []T, signs []bool) { // AreLexicographicallySorted checks whether one or more columns are // lexicographically sorted according to the given signs. This operation does // not modify or clone either array. -func AreLexicographicallySorted(cols [][]*fr.Element, signs []bool) bool { +func AreLexicographicallySorted(cols [][]fr.Element, signs []bool) bool { ncols := len(cols) nrows := len(cols[0]) for i := 1; i < nrows; i++ { for j := 0; j < ncols; j++ { // Compare ith elements - c := cols[j][i].Cmp(cols[j][i-1]) + c := cols[j][i].Cmp(&cols[j][i-1]) // Check whether same if signs[j] && c < 0 { return false @@ -104,10 +104,10 @@ func AreLexicographicallySorted(cols [][]*fr.Element, signs []bool) bool { return true } -func permutationSortFunc(lhs []*fr.Element, rhs []*fr.Element, signs []bool) int { +func permutationSortFunc(lhs []fr.Element, rhs []fr.Element, signs []bool) int { for i := 0; i < len(lhs); i++ { // Compare ith elements - c := lhs[i].Cmp(rhs[i]) + c := lhs[i].Cmp(&rhs[i]) // Check whether same if c != 0 { if signs[i] { @@ -123,14 +123,14 @@ func permutationSortFunc(lhs []*fr.Element, rhs []*fr.Element, signs []bool) int } // Clone and rotate a 2-dimensional array assuming a given geometry. -func rotate[T Array[*fr.Element]](src []T, ncols int, nrows uint) [][]*fr.Element { +func rotate[T Array[*fr.Element]](src []T, ncols int, nrows uint) [][]fr.Element { // Copy outer arrays - dst := make([][]*fr.Element, nrows) + dst := make([][]fr.Element, nrows) // Copy inner arrays for i := uint(0); i < nrows; i++ { - row := make([]*fr.Element, ncols) + row := make([]fr.Element, ncols) for j := 0; j < ncols; j++ { - row[j] = src[j].Get(i) + row[j] = *src[j].Get(i) } dst[i] = row