feat: binary constraints format (#514)
* Add compile cli cmd

This adds a "compile" command to the CLI.  This is where compilation
will be controlled.  A notion of "legacy" versus "non-legacy" is patched
in.

The goal here is to retain the ability to read legacy bin files using
the original format (for now). However, writing them is not supported
at this time.

* support binary encoding / decoding

This adds support for binary encoding / decoding of HIR schemas via the
gob format. To help ensure consistency between compiled source files
and encoded binary files, all tests are additionally run through a
filter which encodes and then decodes the file before running the test.

This also makes a breaking change by using an option, instead of a
pointer, to represent the vanishing constraint domain. This seems to
serialise better.
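As a rough sketch of the encode-then-decode round trip that the test filter performs, the snippet below gob-encodes a stand-in schema value and decodes it back. The `Schema` type and its fields here are illustrative placeholders, not the real HIR types.

```go
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
)

// Schema stands in for the real HIR schema; gob only serialises exported fields.
type Schema struct {
	Modules []string
	Columns []string
}

func main() {
	original := Schema{Modules: []string{"m1"}, Columns: []string{"m1.X"}}
	// Encode the schema into an in-memory binary buffer.
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(original); err != nil {
		panic(err)
	}
	// Decode it back, as the test filter would before running each test.
	var decoded Schema
	if err := gob.NewDecoder(&buf).Decode(&decoded); err != nil {
		panic(err)
	}
	fmt.Println(len(decoded.Columns) == len(original.Columns)) // true
}
```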
DavePearce authored Jan 9, 2025
1 parent a725493 commit 3a770dd
Showing 50 changed files with 702 additions and 542 deletions.
6 changes: 3 additions & 3 deletions cmd/testgen/main.go
@@ -167,11 +167,11 @@ func traceToColumns(schema sc.Schema, trace tr.Trace) []tr.RawColumn {
for iter := schema.InputColumns(); iter.HasNext(); {
sc_col := iter.Next()
// Lookup the column data
tr_col := findColumn(sc_col.Context().Module(), sc_col.Name(), schema, trace)
tr_col := findColumn(sc_col.Context.Module(), sc_col.Name, schema, trace)
// Determine module name
mod := schema.Modules().Nth(sc_col.Context().Module())
mod := schema.Modules().Nth(sc_col.Context.Module())
// Assign the raw column
cols[i] = tr.RawColumn{Module: mod.Name(), Name: sc_col.Name(), Data: tr_col.Data()}
cols[i] = tr.RawColumn{Module: mod.Name, Name: sc_col.Name, Data: tr_col.Data()}
//
i++
}
2 changes: 1 addition & 1 deletion pkg/air/expr.go
@@ -271,7 +271,7 @@ func NewColumnAccess(column uint, shift int) *ColumnAccess {
// expression.
func (p *ColumnAccess) Context(schema sc.Schema) trace.Context {
col := schema.Columns().Nth(p.Column)
return col.Context()
return col.Context
}

// RequiredColumns returns the set of columns on which this term depends.
11 changes: 6 additions & 5 deletions pkg/air/gadgets/bits.go
@@ -6,6 +6,7 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/air"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/util"
)

// ApplyBinaryGadget adds a binarity constraint for a given column in the schema
@@ -15,15 +16,15 @@ func ApplyBinaryGadget(col uint, schema *air.Schema) {
// Identify target column
column := schema.Columns().Nth(col)
// Determine column name
name := column.Name()
name := column.Name
// Construct X
X := air.NewColumnAccess(col, 0)
// Construct X-1
X_m1 := X.Sub(air.NewConst64(1))
// Construct X * (X-1)
X_X_m1 := X.Mul(X_m1)
// Done!
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Context(), nil, X_X_m1)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Context, util.None[int](), X_X_m1)
}

// ApplyBitwidthGadget ensures all values in a given column fit within a given
@@ -41,11 +42,11 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
n := nbits / 8
es := make([]air.Expr, n)
fr256 := fr.NewElement(256)
name := column.Name()
name := column.Name
coefficient := fr.NewElement(1)
// Add decomposition assignment
index := schema.AddAssignment(
assignment.NewByteDecomposition(name, column.Context(), col, n))
assignment.NewByteDecomposition(name, column.Context, col, n))
// Construct Columns
for i := uint(0); i < n; i++ {
// Create Column + Constraint
Expand All @@ -61,5 +62,5 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
X := air.NewColumnAccess(col, 0)
eq := X.Equate(sum)
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Context(), nil, eq)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Context, util.None[int](), eq)
}
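The two gadgets above enforce, respectively, binarity via the vanishing constraint X·(X−1) = 0 and an nbits bound via decomposition into n = nbits/8 byte columns whose weighted sum must equal the original column. A tiny standalone check of that decomposition identity for a 16-bit value (illustrative only, not part of the diff):

```go
package main

import "fmt"

func main() {
	x := uint64(0x1234)
	// Byte columns produced by a 16-bit decomposition (nbits = 16, so n = 2).
	b0 := x % 256         // least significant byte
	b1 := (x / 256) % 256 // next byte
	// The vanishing constraint asserts X - (b0 + 256*b1) = 0.
	sum := b0 + 256*b1
	fmt.Println(sum == x) // true
}
```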
9 changes: 5 additions & 4 deletions pkg/air/gadgets/column_sort.go
@@ -6,6 +6,7 @@ import (
"github.com/consensys/go-corset/pkg/air"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/util"
)

// ApplyColumnSortGadget adds sorting constraints for a column where the
@@ -24,7 +25,7 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem
// Identify target column
column := schema.Columns().Nth(col)
// Determine column name
name := column.Name()
name := column.Name
// Configure computation
Xk := air.NewColumnAccess(col, 0)
Xkm1 := air.NewColumnAccess(col, -1)
@@ -38,15 +39,15 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem
deltaName = fmt.Sprintf("-%s", name)
}
// Look up column
deltaIndex, ok := sc.ColumnIndexOf(schema, column.Context().Module(), deltaName)
deltaIndex, ok := sc.ColumnIndexOf(schema, column.Context.Module(), deltaName)
// Add new column (if it does not already exist)
if !ok {
deltaIndex = schema.AddAssignment(
assignment.NewComputedColumn(column.Context(), deltaName, Xdiff))
assignment.NewComputedColumn(column.Context, deltaName, Xdiff))
}
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
// Configure constraint: Delta[k] = X[k] - X[k-1]
Dk := air.NewColumnAccess(deltaIndex, 0)
schema.AddVanishingConstraint(deltaName, column.Context(), nil, Dk.Equate(Xdiff))
schema.AddVanishingConstraint(deltaName, column.Context, util.None[int](), Dk.Equate(Xdiff))
}
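An illustrative check (not part of the codebase) of the property this gadget encodes: for an ascending column every delta X[k] − X[k−1] is non-negative and fits the declared bitwidth, whereas an out-of-order row makes the delta wrap around and violate the range bound, mirroring how the field-arithmetic constraint detects it.

```go
package main

import "fmt"

func main() {
	column := []uint64{1, 3, 3, 7}
	bitwidth := uint(8)
	sorted := true
	for k := 1; k < len(column); k++ {
		// An unsigned decrease wraps around, pushing the delta far
		// outside the bitwidth bound, just as it would in the field.
		delta := column[k] - column[k-1]
		sorted = sorted && delta < (1 << bitwidth)
	}
	fmt.Println(sorted) // true
}
```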
3 changes: 2 additions & 1 deletion pkg/air/gadgets/expand.go
@@ -5,6 +5,7 @@ import (
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

// Expand converts an arbitrary expression into a specific column index. In
@@ -34,7 +35,7 @@ func Expand(ctx trace.Context, e air.Expr, schema *air.Schema) uint {
// Construct 1 == e/e
eq_e_v := v.Equate(e)
// Ensure (e - v) == 0, where v is value of computed column.
schema.AddVanishingConstraint(name, ctx, nil, eq_e_v)
schema.AddVanishingConstraint(name, ctx, util.None[int](), eq_e_v)
}
//
return index
11 changes: 6 additions & 5 deletions pkg/air/gadgets/lexicographic_sort.go
@@ -8,6 +8,7 @@ import (
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

// ApplyLexicographicSortingGadget Add sorting constraints for a sequence of one
@@ -46,7 +47,7 @@ func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint
constraint := constructLexicographicDeltaConstraint(deltaIndex, columns, signs)
// Add delta constraint
deltaName := fmt.Sprintf("%s:delta", prefix)
schema.AddVanishingConstraint(deltaName, ctx, nil, constraint)
schema.AddVanishingConstraint(deltaName, ctx, util.None[int](), constraint)
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
}
@@ -59,7 +60,7 @@ func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *a
// Concatenate column names with their signs.
for i := 0; i < len(columns); i++ {
ith := schema.Columns().Nth(columns[i])
id.WriteString(ith.Name())
id.WriteString(ith.Name)

if signs[i] {
id.WriteString("+")
@@ -104,7 +105,7 @@ func addLexicographicSelectorBits(prefix string, context trace.Context,
pDiff := air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1))
pName := fmt.Sprintf("%s:%d:a", prefix, i)
schema.AddVanishingConstraint(pName, context,
nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
util.None[int](), air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
// (∀j<i.Bj=0) ∧ Bi=1 ==> C[k]≠C[k-1]
qDiff := Normalise(air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1)), schema)
qName := fmt.Sprintf("%s:%d:b", prefix, i)
@@ -116,14 +117,14 @@ func addLexicographicSelectorBits(prefix string, context trace.Context,
constraint = air.NewConst64(1).Sub(&air.Add{Args: qterms}).Mul(constraint)
}

schema.AddVanishingConstraint(qName, context, nil, constraint)
schema.AddVanishingConstraint(qName, context, util.None[int](), constraint)
}

sum := &air.Add{Args: terms}
// (sum = 0) ∨ (sum = 1)
constraint := sum.Mul(sum.Equate(air.NewConst64(1)))
name := fmt.Sprintf("%s:xor", prefix)
schema.AddVanishingConstraint(name, context, nil, constraint)
schema.AddVanishingConstraint(name, context, util.None[int](), constraint)
}

// Construct the lexicographic delta constraint. This states that the delta
4 changes: 2 additions & 2 deletions pkg/air/gadgets/normalisation.go
@@ -57,10 +57,10 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
inv_e_implies_one_e_e := inv_e.Mul(one_e_e)
// Ensure (e != 0) ==> (1 == e/e)
l_name := fmt.Sprintf("%s <=", name)
schema.AddVanishingConstraint(l_name, ctx, nil, e_implies_one_e_e)
schema.AddVanishingConstraint(l_name, ctx, util.None[int](), e_implies_one_e_e)
// Ensure (e/e != 0) ==> (1 == e/e)
r_name := fmt.Sprintf("%s =>", name)
schema.AddVanishingConstraint(r_name, ctx, nil, inv_e_implies_one_e_e)
schema.AddVanishingConstraint(r_name, ctx, util.None[int](), inv_e_implies_one_e_e)
}
// Done
return air.NewColumnAccess(index, 0)
4 changes: 2 additions & 2 deletions pkg/air/schema.go
@@ -143,7 +143,7 @@ func (p *Schema) AddPropertyAssertion(handle string, context trace.Context, asse
}

// AddVanishingConstraint appends a new vanishing constraint.
func (p *Schema) AddVanishingConstraint(handle string, context trace.Context, domain *int, expr Expr) {
func (p *Schema) AddVanishingConstraint(handle string, context trace.Context, domain util.Option[int], expr Expr) {
if context.Module() >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", context.Module()))
}
@@ -156,7 +156,7 @@ func (p *Schema) AddVanishingConstraint(handle string, context trace.Context, do
func (p *Schema) AddRangeConstraint(column uint, bound fr.Element) {
col := p.Columns().Nth(column)
handle := col.QualifiedName(p)
tc := constraint.NewRangeConstraint[*ColumnAccess](handle, col.Context(), NewColumnAccess(column, 0), bound)
tc := constraint.NewRangeConstraint[*ColumnAccess](handle, col.Context, NewColumnAccess(column, 0), bound)
p.constraints = append(p.constraints, tc)
}

18 changes: 9 additions & 9 deletions pkg/binfile/computation.go
@@ -71,11 +71,11 @@ func addSortedComputation(sorted *jsonSortedComputation, index uint,
dst_hnd := asHandle(dst_col.Handle)
src_col := schema.Columns().Nth(sources[i])
// Sanity check source column type
if src_col.Type().AsUint() == nil {
panic(fmt.Sprintf("source column %s has field type", src_col.Name()))
if src_col.DataType.AsUint() == nil {
panic(fmt.Sprintf("source column %s has field type", src_col.Name))
}

targets[i] = sc.NewColumn(ctx, dst_hnd.column, src_col.Type())
targets[i] = sc.NewColumn(ctx, dst_hnd.column, src_col.DataType)
// Update allocation information.
colmap[target_id] = index
index++
@@ -100,7 +100,7 @@ func addInterleavedComputation(c *jsonInterleavedComputation, index uint,
for i := range sources {
src_col := schema.Columns().Nth(sources[i])
// Update the column type
dst_type = sc.Join(dst_type, src_col.Type())
dst_type = sc.Join(dst_type, src_col.DataType)
}
// Update multiplier
ctx = ctx.Multiply(uint(len(sources)))
@@ -118,7 +118,7 @@ func sourceColumnsFromHandles(handles []string, columns []column,
handle := asHandle(columns[sourceIDs[0]].Handle)
// Resolve enclosing module
mid, ok := schema.Modules().Find(func(m sc.Module) bool {
return m.Name() == handle.module
return m.Name == handle.module
})
// Sanity check assumptions
if !ok {
@@ -139,16 +139,16 @@ func sourceColumnsFromHandles(handles []string, columns []column,
// Extract schema info about source column
src_col := schema.Columns().Nth(src_cid)
// Sanity check enclosing modules match
if src_col.Context().Module() != mid {
if src_col.Context.Module() != mid {
panic("inconsistent enclosing module for sorted permutation (source)")
}

ctx = ctx.Join(src_col.Context())
ctx = ctx.Join(src_col.Context)
// Sanity check we have a sensible type here.
if ctx.IsConflicted() {
panic(fmt.Sprintf("source column %s has conflicted evaluation context", src_col.Name()))
panic(fmt.Sprintf("source column %s has conflicted evaluation context", src_col.Name))
} else if ctx.IsVoid() {
panic(fmt.Sprintf("source column %s has void evaluation context", src_col.Name()))
panic(fmt.Sprintf("source column %s has void evaluation context", src_col.Name))
}

sources[i] = src_cid
7 changes: 4 additions & 3 deletions pkg/binfile/constraint.go
@@ -5,6 +5,7 @@ import (

"github.com/consensys/go-corset/pkg/hir"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/util"
)

// JsonConstraint is an enumeration of constraint forms. Exactly one of these fields
@@ -93,13 +94,13 @@ func (e jsonConstraint) addToSchema(colmap map[uint]uint, schema *hir.Schema) {
}
}

func (e jsonDomain) toHir() *int {
func (e jsonDomain) toHir() util.Option[int] {
if len(e.Set) == 1 {
domain := e.Set[0]
return &domain
return util.Some(domain)
} else if e.Set != nil {
panic("Unknown domain")
}
// Default
return nil
return util.None[int]()
}
8 changes: 4 additions & 4 deletions pkg/binfile/constraint_set.go
@@ -226,10 +226,10 @@ func checkAllocation(cs *columnSet, colmap map[uint]uint, schema *hir.Schema) {
}

sc_col := schema.Columns().Nth(cid)
sc_mod := schema.Modules().Nth(sc_col.Context().Module())
sc_mod := schema.Modules().Nth(sc_col.Context.Module())
// Perform the check
if sc_mod.Name() != handle.module || sc_col.Name() != handle.column {
panic(fmt.Sprintf("invalid allocation %s.%s != %s.%s", handle.module, handle.column, sc_mod.Name(), sc_col.Name()))
if sc_mod.Name != handle.module || sc_col.Name != handle.column {
panic(fmt.Sprintf("invalid allocation %s.%s != %s.%s", handle.module, handle.column, sc_mod.Name, sc_col.Name))
}
}
}
@@ -239,7 +239,7 @@ func checkAllocation(cs *columnSet, colmap map[uint]uint, schema *hir.Schema) {
func registerModule(schema *hir.Schema, module string) uint {
// Attempt to find existing module with same name
mid, ok := schema.Modules().Find(func(m sc.Module) bool {
return m.Name() == module
return m.Name == module
})
// Check whether search successful, or not.
if ok {
13 changes: 7 additions & 6 deletions pkg/cmd/check.go
@@ -34,6 +34,7 @@ var checkCmd = &cobra.Command{
if GetFlag(cmd, "verbose") {
log.SetLevel(log.DebugLevel)
}
legacy := GetFlag(cmd, "legacy")
//
cfg.air = GetFlag(cmd, "air")
cfg.mir = GetFlag(cmd, "mir")
@@ -60,7 +61,7 @@ var checkCmd = &cobra.Command{
//
stats := util.NewPerfStats()
// Parse constraints
hirSchema = readSchema(cfg.stdlib, cfg.debug, args[1:])
hirSchema = readSchema(cfg.stdlib, cfg.debug, legacy, args[1:])
//
stats.Log("Reading constraints file")
// Parse trace file
@@ -196,9 +197,9 @@ func validationCheck(tr tr.Trace, schema sc.Schema) error {
// Extract schema for ith column
scCol := schemaCols.Next()
// Determine enclosing module
mod := schema.Modules().Nth(scCol.Context().Module())
mod := schema.Modules().Nth(scCol.Context.Module())
// Extract type for ith column
colType := scCol.Type()
colType := scCol.DataType
// Check elements
go func() {
// Send outcome back
@@ -221,7 +222,7 @@ func validateColumn(colType sc.Type, col tr.Column, mod sc.Module) error {
for j := 0; j < int(col.Data().Len()); j++ {
jth := col.Get(j)
if !colType.Accept(jth) {
qualColName := tr.QualifiedColumnName(mod.Name(), col.Name())
qualColName := tr.QualifiedColumnName(mod.Name, col.Name())
return fmt.Errorf("row %d of column %s is out-of-bounds (%s)", j, qualColName, jth.String())
}
}
@@ -249,10 +250,10 @@ func reportFailures(ir string, failures []sc.Failure, trace tr.Trace, cfg checkC
func reportFailure(failure sc.Failure, trace tr.Trace, cfg checkConfig) {
if f, ok := failure.(*constraint.VanishingFailure); ok {
cells := f.RequiredCells(trace)
reportConstraintFailure("constraint", f.Handle(), cells, trace, cfg)
reportConstraintFailure("constraint", f.Handle, cells, trace, cfg)
} else if f, ok := failure.(*sc.AssertionFailure); ok {
cells := f.RequiredCells(trace)
reportConstraintFailure("assertion", f.Handle(), cells, trace, cfg)
reportConstraintFailure("assertion", f.Handle, cells, trace, cfg)
}
}

37 changes: 37 additions & 0 deletions pkg/cmd/compile.go
@@ -0,0 +1,37 @@
package cmd

import (
"fmt"
"os"

"github.com/spf13/cobra"
)

var compileCmd = &cobra.Command{
Use: "compile [flags] constraint_file(s)",
Short: "compile constraints into a binary package.",
Long: `Compile a given set of constraint file(s) into a single binary package which can
be subsequently used without requiring a full compilation step.`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 1 {
fmt.Println(cmd.UsageString())
os.Exit(1)
}
stdlib := !GetFlag(cmd, "no-stdlib")
debug := GetFlag(cmd, "debug")
legacy := GetFlag(cmd, "legacy")
output := GetString(cmd, "output")
// Parse constraints
hirSchema := readSchema(stdlib, debug, legacy, args)
// Serialise as a gob file.
writeHirSchema(hirSchema, legacy, output)
},
}

//nolint:errcheck
func init() {
rootCmd.AddCommand(compileCmd)
compileCmd.Flags().Bool("debug", false, "enable debugging constraints")
compileCmd.Flags().StringP("output", "o", "a.bin", "specify output file.")
compileCmd.MarkFlagRequired("output")
}
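Assuming the tool builds to a binary named go-corset (illustrative, not taken from this diff), the new command would be invoked along the lines of `go-corset compile -o pkg.bin constraints.lisp`; the resulting package can then presumably be read back by other commands (such as check, which now reads a legacy flag) without recompiling the source constraints.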