diff --git a/pkg/air/gadgets/column_sort.go b/pkg/air/gadgets/column_sort.go index 1d11426..d50a479 100644 --- a/pkg/air/gadgets/column_sort.go +++ b/pkg/air/gadgets/column_sort.go @@ -43,7 +43,7 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem // Add new column (if it does not already exist) if !ok { deltaIndex = schema.AddAssignment( - assignment.NewComputedColumn(column.Context, deltaName, Xdiff)) + assignment.NewComputedColumn[air.Expr](column.Context, deltaName, Xdiff)) } // Add necessary bitwidth constraints ApplyBitwidthGadget(deltaIndex, bitwidth, schema) diff --git a/pkg/air/gadgets/expand.go b/pkg/air/gadgets/expand.go index e29da64..7f27a15 100644 --- a/pkg/air/gadgets/expand.go +++ b/pkg/air/gadgets/expand.go @@ -29,7 +29,7 @@ func Expand(ctx trace.Context, e air.Expr, schema *air.Schema) uint { // Add new column (if it does not already exist) if !ok { // Add computed column - index = schema.AddAssignment(assignment.NewComputedColumn(ctx, name, e)) + index = schema.AddAssignment(assignment.NewComputedColumn[air.Expr](ctx, name, e)) // Construct v == [e] v := air.NewColumnAccess(index, 0) // Construct 1 == e/e diff --git a/pkg/air/gadgets/normalisation.go b/pkg/air/gadgets/normalisation.go index 10d08ee..f781de1 100644 --- a/pkg/air/gadgets/normalisation.go +++ b/pkg/air/gadgets/normalisation.go @@ -44,7 +44,7 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr { // Add new column (if it does not already exist) if !ok { // Add computed column - index = schema.AddAssignment(assignment.NewComputedColumn(ctx, name, ie)) + index = schema.AddAssignment(assignment.NewComputedColumn[air.Expr](ctx, name, ie)) // Construct 1/e inv_e := air.NewColumnAccess(index, 0) // Construct e/e @@ -82,6 +82,23 @@ func (e *Inverse) EvalAt(k int, tbl tr.Trace) fr.Element { return inv } +// Add two expressions together, producing a third. +func (e *Inverse) Add(other air.Expr) air.Expr { panic("unreachable") } + +// Sub (subtract) one expression from another. +func (e *Inverse) Sub(other air.Expr) air.Expr { panic("unreachable") } + +// Mul (multiply) two expressions together, producing a third. +func (e *Inverse) Mul(other air.Expr) air.Expr { panic("unreachable") } + +// Equate one expression with another (equivalent to subtraction). +func (e *Inverse) Equate(other air.Expr) air.Expr { panic("unreachable") } + +// AsConstant determines whether or not this is a constant expression. If +// so, the constant is returned; otherwise, nil is returned. NOTE: this +// does not perform any form of simplification to determine this. +func (e *Inverse) AsConstant() *fr.Element { return nil } + // Bounds returns max shift in either the negative (left) or positive // direction (right). func (e *Inverse) Bounds() util.Bounds { return e.Expr.Bounds() } diff --git a/pkg/cmd/debug.go b/pkg/cmd/debug.go index d497078..b1ded80 100644 --- a/pkg/cmd/debug.go +++ b/pkg/cmd/debug.go @@ -4,11 +4,13 @@ import ( "fmt" "os" "reflect" - "strings" + "github.com/consensys/go-corset/pkg/air" "github.com/consensys/go-corset/pkg/hir" + "github.com/consensys/go-corset/pkg/mir" "github.com/consensys/go-corset/pkg/schema" sc "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/schema/assignment" "github.com/consensys/go-corset/pkg/util" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -133,19 +135,20 @@ type schemaSummariser struct { var schemaSummarisers []schemaSummariser = []schemaSummariser{ // Constraints - constraintCounter("Constraints", "*constraint.VanishingConstraint"), - constraintCounter("Lookups", "*constraint.LookupConstraint"), - constraintCounter("Permutations", "*constraint.PermutationConstraint"), - constraintCounter("Types", "*constraint.TypeConstraint"), - constraintCounter("Range", "*constraint.RangeConstraint"), + constraintCounter("Constraints", vanishingConstraints...), + constraintCounter("Lookups", lookupConstraints...), + constraintCounter("Permutations", permutationConstraints...), + constraintCounter("Range", rangeConstraints...), // Assignments - assignmentCounter("Decompositions", "*assignment.ByteDecomposition"), - assignmentCounter("Computed Columns", "*assignment.ComputedColumn"), - assignmentCounter("Committed Columns", "*assignment.DataColumn"), - assignmentCounter("Interleavings", "*assignment.Interleaving"), - assignmentCounter("Lexicographic Orderings", "*assignment.LexicographicSort"), - assignmentCounter("Sorted Permutations", "*assignment.SortedPermutation"), - // Column Width + assignmentCounter("Decompositions", reflect.TypeOf((*assignment.ByteDecomposition)(nil))), + assignmentCounter("Committed Columns", reflect.TypeOf((*assignment.DataColumn)(nil))), + assignmentCounter("Computed Columns", computedColumns...), + assignmentCounter("Computation Columns", reflect.TypeOf((*assignment.Computation)(nil))), + assignmentCounter("Interleavings", reflect.TypeOf((*assignment.Interleaving)(nil))), + assignmentCounter("Lexicographic Orderings", reflect.TypeOf((*assignment.LexicographicSort)(nil))), + assignmentCounter("Sorted Permutations", reflect.TypeOf((*assignment.SortedPermutation)(nil))), + // Columns + columnCounter(), columnWidthSummariser(1, 1), columnWidthSummariser(2, 4), columnWidthSummariser(5, 8), @@ -156,30 +159,61 @@ var schemaSummarisers []schemaSummariser = []schemaSummariser{ columnWidthSummariser(129, 256), } -func constraintCounter(title string, prefix string) schemaSummariser { +var vanishingConstraints = []reflect.Type{ + reflect.TypeOf((hir.VanishingConstraint)(nil)), + reflect.TypeOf((mir.VanishingConstraint)(nil)), + reflect.TypeOf((air.VanishingConstraint)(nil))} + +var lookupConstraints = []reflect.Type{ + reflect.TypeOf((hir.LookupConstraint)(nil)), + reflect.TypeOf((mir.LookupConstraint)(nil)), + reflect.TypeOf((air.LookupConstraint)(nil))} + +var rangeConstraints = []reflect.Type{ + reflect.TypeOf((hir.RangeConstraint)(nil)), + reflect.TypeOf((mir.RangeConstraint)(nil)), + reflect.TypeOf((air.RangeConstraint)(nil))} + +var permutationConstraints = []reflect.Type{ + // permutation constraints only exist at AIR level + reflect.TypeOf((air.PermutationConstraint)(nil))} + +var computedColumns = []reflect.Type{ + // permutation constraints only exist at AIR level + reflect.TypeOf((*assignment.ComputedColumn[air.Expr])(nil))} + +func constraintCounter(title string, types ...reflect.Type) schemaSummariser { return schemaSummariser{ name: title, summary: func(schema sc.Schema) int { - return typeOfCounter(schema.Constraints(), prefix) + sum := 0 + for _, t := range types { + sum += typeOfCounter(schema.Constraints(), t) + } + return sum }, } } -func assignmentCounter(title string, prefix string) schemaSummariser { +func assignmentCounter(title string, types ...reflect.Type) schemaSummariser { return schemaSummariser{ name: title, summary: func(schema sc.Schema) int { - return typeOfCounter(schema.Declarations(), prefix) + sum := 0 + for _, t := range types { + sum += typeOfCounter(schema.Declarations(), t) + } + return sum }, } } -func typeOfCounter[T any](iter util.Iterator[T], prefix string) int { +func typeOfCounter[T any](iter util.Iterator[T], dyntype reflect.Type) int { count := 0 for iter.HasNext() { ith := iter.Next() - if isTypeOf(ith, prefix) { + if dyntype == reflect.TypeOf(ith) { count++ } } @@ -187,10 +221,18 @@ func typeOfCounter[T any](iter util.Iterator[T], prefix string) int { return count } -func isTypeOf(obj any, prefix string) bool { - dyntype := reflect.TypeOf(obj) - // Check whether dynamic type matches prefix - return strings.HasPrefix(dyntype.String(), prefix) +func columnCounter() schemaSummariser { + return schemaSummariser{ + name: "Columns (all)", + summary: func(sc schema.Schema) int { + count := 0 + for i := sc.Columns(); i.HasNext(); { + i.Next() + count++ + } + return count + }, + } } func columnWidthSummariser(lowWidth uint, highWidth uint) schemaSummariser { diff --git a/pkg/corset/compiler/translator.go b/pkg/corset/compiler/translator.go index 6f87e4f..e40dba1 100644 --- a/pkg/corset/compiler/translator.go +++ b/pkg/corset/compiler/translator.go @@ -291,10 +291,13 @@ func (t *translator) translateDefConstraint(decl *ast.DefConstraint, module util } // Apply guard (if applicable) if guard != nil { - constraint = &hir.Mul{Args: []hir.Expr{guard, constraint}} + constraint = &hir.IfZero{Condition: guard, TrueBranch: nil, FalseBranch: constraint} } // Apply perspective selector (if applicable) if selector != nil { + // NOTE: using an ifnot (as above) would be preferable here. However, + // this is currently done just to ensure constraints identical to the + // original are generated. constraint = &hir.Mul{Args: []hir.Expr{selector, constraint}} } // diff --git a/pkg/corset/stdlib.lisp b/pkg/corset/stdlib.lisp index 44e356e..f429b3c 100644 --- a/pkg/corset/stdlib.lisp +++ b/pkg/corset/stdlib.lisp @@ -21,12 +21,10 @@ (defpurefun ((not :binary@bool :force) (x :binary)) (- 1 x)) -(defpurefun ((eq! :binary@loob :force) (x :binary) (y :binary)) (^ (- x y) 2)) (defpurefun ((eq! :@loob) x y) (- x y)) (defpurefun ((neq! :binary@loob :force) x y) (not (~ (eq! x y)))) (defunalias = eq!) -(defpurefun ((eq :binary@bool :force) (x :binary) (y :binary)) (- 1 (^ (- x y) 2))) (defpurefun ((eq :binary@bool :force) x y) (- 1 (~ (eq! x y)))) (defpurefun ((neq :binary@bool :force) x y) (eq! x y)) diff --git a/pkg/hir/lower.go b/pkg/hir/lower.go index 8fa56f8..74a5b69 100644 --- a/pkg/hir/lower.go +++ b/pkg/hir/lower.go @@ -299,7 +299,17 @@ func extractBodies(es []Expr, schema *mir.Schema) []mir.Expr { func expand(e Expr, schema sc.Schema) []Expr { if p, ok := e.(*Add); ok { return expandWithNaryConstructor(p.Args, func(nargs []Expr) Expr { - return &Add{Args: nargs} + var args []Expr + // Flatten nested sums + for _, e := range nargs { + if a, ok := e.(*Add); ok { + args = append(args, a.Args...) + } else { + args = append(args, e) + } + } + // Done + return &Add{Args: args} }, schema) } else if _, ok := e.(*Constant); ok { return []Expr{e} @@ -307,7 +317,17 @@ func expand(e Expr, schema sc.Schema) []Expr { return []Expr{e} } else if p, ok := e.(*Mul); ok { return expandWithNaryConstructor(p.Args, func(nargs []Expr) Expr { - return &Mul{Args: nargs} + var args []Expr + // Flatten nested products + for _, e := range nargs { + if a, ok := e.(*Mul); ok { + args = append(args, a.Args...) + } else { + args = append(args, e) + } + } + // Done + return &Mul{Args: args} }, schema) } else if p, ok := e.(*List); ok { ees := make([]Expr, 0) diff --git a/pkg/mir/const.go b/pkg/mir/const.go index b5854cd..f84407a 100644 --- a/pkg/mir/const.go +++ b/pkg/mir/const.go @@ -32,7 +32,7 @@ func applyConstantPropagation(e Expr, schema sc.Schema) Expr { func applyConstantPropagationAdd(es []Expr, schema sc.Schema) Expr { sum := fr.NewElement(0) - is_const := true + count := 0 rs := make([]Expr, len(es)) // for i, e := range es { @@ -40,16 +40,18 @@ func applyConstantPropagationAdd(es []Expr, schema sc.Schema) Expr { // Check for constant c, ok := rs[i].(*Constant) // Try to continue sum - if ok && is_const { + if ok { sum.Add(&sum, &c.Value) - } else { - is_const = false + // Increase count of constants + count++ } } // Check if constant - if is_const { + if count == len(es) { // Propagate constant return &Constant{sum} + } else if count > 1 { + rs = mergeConstants(sum, rs) } // Done return &Add{rs} @@ -85,10 +87,10 @@ func applyConstantPropagationSub(es []Expr, schema sc.Schema) Expr { func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr { one := fr.NewElement(1) - is_const := true prod := one rs := make([]Expr, len(es)) ones := 0 + consts := 0 // for i, e := range es { rs[i] = applyConstantPropagation(e, schema) @@ -100,16 +102,17 @@ func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr { return &Constant{c.Value} } else if ok && c.Value.IsOne() { ones++ + consts++ rs[i] = nil - } else if ok && is_const { + } else if ok { // Continue building constant prod.Mul(&prod, &c.Value) - } else { - is_const = false + // + consts++ } } // Check if constant - if is_const { + if consts == len(es) { return &Constant{prod} } else if ones > 0 { rs = util.RemoveMatching[Expr](rs, func(item Expr) bool { return item == nil }) @@ -117,6 +120,9 @@ func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr { // Sanity check what's left. if len(rs) == 1 { return rs[0] + } else if consts-ones > 1 { + // Combine constants + rs = mergeConstants(prod, rs) } // Done return &Mul{rs} @@ -155,3 +161,26 @@ func applyConstantPropagationNorm(arg Expr, schema sc.Schema) Expr { // return &Normalise{arg} } + +// Replace all constants within a given sequence of expressions with a single +// constant (whose value has been precomputed from those constants). The new +// value replaces the first constant in the list. +func mergeConstants(constant fr.Element, es []Expr) []Expr { + j := 0 + first := true + // + for i := range es { + // Check for constant + if _, ok := es[i].(*Constant); ok && first { + es[j] = &Constant{constant} + first = false + j++ + } else if !ok { + // Retain non-constant expression + es[j] = es[i] + j++ + } + } + // Return slice + return es[0:j] +} diff --git a/pkg/schema/assignment/computation.go b/pkg/schema/assignment/computation.go index 1346b51..014ecea 100644 --- a/pkg/schema/assignment/computation.go +++ b/pkg/schema/assignment/computation.go @@ -230,7 +230,6 @@ func mapIfNativeFunction(trace tr.Trace, sources []uint) []util.FrArray { rhs := fmt.Sprintf("%v=>%s", ith_row, val.String()) panic(fmt.Sprintf("conflicting values in source map (row %d): %s vs %s", i, lhs, rhs)) } else if !ok { - fmt.Printf("Inserting source key (row %d): %v\n", i, extractIthColumns(i, source_keys)) // Item not previously in map source_map.Insert(ith_key, ith_value) }