Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support FrArray #243

Merged
merged 2 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ func checkTraceWithLoweringDefault(tr trace.Trace, hirSchema *hir.Schema, cfg ch
}

func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, error) {
//
if cfg.expand {
// Clone to prevent interefence with subsequent checks
tr = tr.Clone()
Expand All @@ -186,12 +187,10 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace,
return tr, err
}
}

// Perform Alignment
if err := performAlignment(false, tr, schema, cfg); err != nil {
return tr, err
}

// Apply padding (as necessary)
for n := cfg.padding.Left; n <= cfg.padding.Right; n++ {
if ptr, err := padAndCheckTrace(n, tr, schema); err != nil {
Expand Down
14 changes: 4 additions & 10 deletions pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"os"
"strings"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"

Expand Down Expand Up @@ -80,12 +79,7 @@ func filterColumns(tr trace.Trace, prefix string) trace.Trace {
if strings.HasPrefix(qName, prefix) {
ith := tr.Columns().Get(i)
// Copy column data
data := make([]*fr.Element, ith.Height())
//
for j := 0; j < int(ith.Height()); j++ {
data[j] = ith.Get(j)
}

data := ith.Data().Clone()
err := builder.Add(qName, ith.Padding(), data)
// Sanity check
if err != nil {
Expand All @@ -102,9 +96,9 @@ func listColumns(tr trace.Trace) {
tbl := util.NewTablePrinter(3, n)

for i := uint(0); i < n; i++ {
ith := tr.Columns().Get(i)
elems := fmt.Sprintf("%d rows", ith.Height())
bytes := fmt.Sprintf("%d bytes", ith.Width()*ith.Height())
ith := tr.Columns().Get(i).Data()
elems := fmt.Sprintf("%d rows", ith.Len())
bytes := fmt.Sprintf("(%d*%d) = %d bytes", ith.Len(), ith.ByteWidth(), ith.ByteWidth()*ith.Len())
tbl.SetRow(i, QualifiedColumnName(i, tr), elems, bytes)
}

Expand Down
13 changes: 7 additions & 6 deletions pkg/schema/assignment/byte_decomposition.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,25 @@ func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error {
// Identify source column
source := columns.Get(p.source)
// Construct byte column data
cols := make([][]*fr.Element, n)
cols := make([]util.FrArray, n)
// Initialise columns
for i := 0; i < n; i++ {
cols[i] = make([]*fr.Element, source.Height())
// Construct a byte column for ith byte
cols[i] = util.NewFrArray(source.Height(), 1)
}
// Decompose each row of each column
for i := 0; i < int(source.Height()); i = i + 1 {
ith := decomposeIntoBytes(source.Get(i), n)
for i := uint(0); i < source.Height(); i = i + 1 {
ith := decomposeIntoBytes(source.Get(int(i)), n)
for j := 0; j < n; j++ {
cols[j][i] = ith[j]
cols[j].Set(i, ith[j])
}
}
// Determine padding values
padding := decomposeIntoBytes(source.Padding(), n)
// Finally, add byte columns to trace
for i := 0; i < n; i++ {
ith := p.targets[i]
columns.Add(trace.NewFieldColumn(ith.Context(), ith.Name(), cols[i], padding[i]))
columns.Add(ith.Context(), ith.Name(), cols[i], padding[i])
}
// Done
return nil
Expand Down
12 changes: 6 additions & 6 deletions pkg/schema/assignment/computed_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,23 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
// Determine multiplied height
height := tr.Modules().Get(p.target.Context().Module()).Height() * multiplier
// Make space for computed data
data := make([]*fr.Element, height)
data := util.NewFrArray(height, 32)
// Expand the trace
for i := 0; i < len(data); i++ {
val := p.expr.EvalAt(i, tr)
for i := uint(0); i < data.Len(); i++ {
val := p.expr.EvalAt(int(i), tr)
if val != nil {
data[i] = val
data.Set(i, val)
} else {
zero := fr.NewElement(0)
data[i] = &zero
data.Set(i, &zero)
}
}
// Determine padding value. A negative row index is used here to ensure
// that all columns return their padding value which is then used to compute
// the padding value for *this* column.
padding := p.expr.EvalAt(-1, tr)
// Colunm needs to be expanded.
columns.Add(trace.NewFieldColumn(p.target.Context(), p.Name(), data, padding))
columns.Add(p.target.Context(), p.Name(), data, padding)
// Done
return nil
}
18 changes: 10 additions & 8 deletions pkg/schema/assignment/interleave.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ package assignment
import (
"fmt"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)
Expand Down Expand Up @@ -73,12 +71,16 @@ func (p *Interleaving) RequiredSpillage() uint {
func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
columns := tr.Columns()
ctx := p.target.Context()
// Byte width records the largest width of any column.
byte_width := uint(0)
// Ensure target column doesn't exist
for i := p.Columns(); i.HasNext(); {
name := i.Next().Name()
ith := i.Next()
// Update byte width
byte_width = max(byte_width, ith.Type().ByteWidth())
// Sanity check no column already exists with this name.
if _, ok := columns.IndexOf(ctx.Module(), name); ok {
return fmt.Errorf("interleaved column already exists ({%s})", name)
if _, ok := columns.IndexOf(ctx.Module(), ith.Name()); ok {
return fmt.Errorf("interleaved column already exists ({%s})", ith.Name())
}
}
// Determine interleaving width
Expand All @@ -90,7 +92,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
// the interleaved column)
height := tr.Modules().Get(ctx.Module()).Height() * multiplier
// Construct empty array
data := make([]*fr.Element, height*width)
data := util.NewFrArray(height*width, byte_width)
// Offset just gives the column index
offset := uint(0)
// Copy interleaved data
Expand All @@ -99,7 +101,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
col := tr.Columns().Get(p.sources[i])
// Copy over
for j := uint(0); j < height; j++ {
data[offset+(j*width)] = col.Get(int(j))
data.Set(offset+(j*width), col.Get(int(j)))
}

offset++
Expand All @@ -108,7 +110,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
// column in the interleaving.
padding := columns.Get(0).Padding()
// Colunm needs to be expanded.
columns.Add(trace.NewFieldColumn(ctx, p.target.Name(), data, padding))
columns.Add(ctx, p.target.Name(), data, padding)
//
return nil
}
33 changes: 20 additions & 13 deletions pkg/schema/assignment/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,48 +80,55 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {
// Determine how many rows to be constrained.
nrows := tr.Modules().Get(p.context.Module()).Height() * multiplier
// Initialise new data columns
delta := make([]*fr.Element, nrows)
bit := make([][]*fr.Element, ncols)
bit := make([]util.FrArray, ncols)
// Byte width records the largest width of any column.
byte_width := uint(0)

for i := 0; i < ncols; i++ {
bit[i] = make([]*fr.Element, nrows)
// TODO: following can be optimised to use a single bit per element,
// rather than an entire byte.
bit[i] = util.NewFrArray(nrows, 1)
ith := columns.Get(p.sources[i])
byte_width = max(byte_width, ith.Data().ByteWidth())
}

for i := 0; i < int(nrows); i++ {
delta := util.NewFrArray(nrows, byte_width)

for i := uint(0); i < nrows; i++ {
set := false
// Initialise delta to zero
delta[i] = &zero
delta.Set(i, &zero)
// Decide which row is the winner (if any)
for j := 0; j < ncols; j++ {
prev := columns.Get(p.sources[j]).Get(i - 1)
curr := columns.Get(p.sources[j]).Get(i)
prev := columns.Get(p.sources[j]).Get(int(i - 1))
curr := columns.Get(p.sources[j]).Get(int(i))

if !set && prev != nil && prev.Cmp(curr) != 0 {
var diff fr.Element

bit[j][i] = &one
bit[j].Set(i, &one)
// Compute curr - prev
if p.signs[j] {
diff.Set(curr)
delta[i] = diff.Sub(&diff, prev)
delta.Set(i, diff.Sub(&diff, prev))
} else {
diff.Set(prev)
delta[i] = diff.Sub(&diff, curr)
delta.Set(i, diff.Sub(&diff, curr))
}

set = true
} else {
bit[j][i] = &zero
bit[j].Set(i, &zero)
}
}
}
// Add delta column data
first := p.targets[0]
columns.Add(trace.NewFieldColumn(first.Context(), first.Name(), delta, &zero))
columns.Add(first.Context(), first.Name(), delta, &zero)
// Add bit column data
for i := 0; i < ncols; i++ {
ith := p.targets[1+i]
columns.Add(trace.NewFieldColumn(ith.Context(), ith.Name(), bit[i], &zero))
columns.Add(ith.Context(), ith.Name(), bit[i], &zero)
}
// Done.
return nil
Expand Down
19 changes: 6 additions & 13 deletions pkg/schema/assignment/sorted_permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package assignment
import (
"fmt"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
tr "github.com/consensys/go-corset/pkg/trace"
Expand Down Expand Up @@ -132,20 +131,14 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {
}
}

cols := make([][]*fr.Element, len(p.sources))
cols := make([]util.FrArray, len(p.sources))
// Construct target columns
for i := 0; i < len(p.sources); i++ {
src := p.sources[i]
// Read column data to initialise permutation.
col := columns.Get(src)
// Copy column data to initialise permutation.
copy := make([]*fr.Element, col.Height())
//
for j := 0; j < int(col.Height()); j++ {
copy[j] = col.Get(j)
}
// Copy over
cols[i] = copy
// Read column data
data := columns.Get(src).Data()
// Clone it to initialise permutation.
cols[i] = data.Clone()
}
// Sort target columns
util.PermutationSort(cols, p.signs)
Expand All @@ -156,7 +149,7 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {
ith := i.Next()
dstColName := ith.Name()
srcCol := tr.Columns().Get(p.sources[index])
columns.Add(trace.NewFieldColumn(ith.Context(), dstColName, cols[index], srcCol.Padding()))
columns.Add(ith.Context(), dstColName, cols[index], srcCol.Padding())
}
//
return nil
Expand Down
13 changes: 3 additions & 10 deletions pkg/schema/constraint/permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)
Expand Down Expand Up @@ -72,20 +71,14 @@ func (p *PermutationConstraint) String() string {
return fmt.Sprintf("(permutation (%s) (%s))", targets, sources)
}

func sliceColumns(columns []uint, tr trace.Trace) [][]*fr.Element {
func sliceColumns(columns []uint, tr trace.Trace) []util.FrArray {
// Allocate return array
cols := make([][]*fr.Element, len(columns))
cols := make([]util.FrArray, len(columns))
// Slice out the data
for i, n := range columns {
nth := tr.Columns().Get(n)
// Copy column data to initialise permutation.
copy := make([]*fr.Element, nth.Height())
//
for j := 0; j < int(nth.Height()); j++ {
copy[j] = nth.Get(j)
}
// Copy over
cols[i] = copy
cols[i] = nth.Data()
}
// Done
return cols
Expand Down
22 changes: 22 additions & 0 deletions pkg/schema/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ type Type interface {
// Accept checks whether a specific value is accepted by this type
Accept(*fr.Element) bool

// Return the number of bytes required represent any element of this type.
ByteWidth() uint

// Produce a string representation of this type.
String() string
}
Expand Down Expand Up @@ -59,6 +62,19 @@ func (p *UintType) AsField() *FieldType {
return nil
}

// ByteWidth returns the number of bytes required represent any element of this
// type.
func (p *UintType) ByteWidth() uint {
m := p.nbits / 8
n := p.nbits % 8
// Check for even division
if n == 0 {
return m
}
//
return m + 1
}

// Accept determines whether a given value is an element of this type. For
// example, 123 is an element of the type u8 whilst 256 is not.
func (p *UintType) Accept(val *fr.Element) bool {
Expand Down Expand Up @@ -104,6 +120,12 @@ func (p *FieldType) AsField() *FieldType {
return p
}

// ByteWidth returns the number of bytes required represent any element of this
// type.
func (p *FieldType) ByteWidth() uint {
return 32
}

// Accept determines whether a given value is an element of this type. In
// fact, all field elements are members of this type.
func (p *FieldType) Accept(val *fr.Element) bool {
Expand Down
Loading