Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
DavePearce authored Jan 24, 2025
2 parents 90ed0b9 + dd86f82 commit 6ff4a5a
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 42 deletions.
2 changes: 1 addition & 1 deletion pkg/air/gadgets/column_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem
// Add new column (if it does not already exist)
if !ok {
deltaIndex = schema.AddAssignment(
assignment.NewComputedColumn(column.Context, deltaName, Xdiff))
assignment.NewComputedColumn[air.Expr](column.Context, deltaName, Xdiff))
}
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
Expand Down
2 changes: 1 addition & 1 deletion pkg/air/gadgets/expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func Expand(ctx trace.Context, e air.Expr, schema *air.Schema) uint {
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(ctx, name, e))
index = schema.AddAssignment(assignment.NewComputedColumn[air.Expr](ctx, name, e))
// Construct v == [e]
v := air.NewColumnAccess(index, 0)
// Construct 1 == e/e
Expand Down
19 changes: 18 additions & 1 deletion pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(ctx, name, ie))
index = schema.AddAssignment(assignment.NewComputedColumn[air.Expr](ctx, name, ie))
// Construct 1/e
inv_e := air.NewColumnAccess(index, 0)
// Construct e/e
Expand Down Expand Up @@ -82,6 +82,23 @@ func (e *Inverse) EvalAt(k int, tbl tr.Trace) fr.Element {
return inv
}

// Add two expressions together, producing a third.
func (e *Inverse) Add(other air.Expr) air.Expr { panic("unreachable") }

// Sub (subtract) one expression from another.
func (e *Inverse) Sub(other air.Expr) air.Expr { panic("unreachable") }

// Mul (multiply) two expressions together, producing a third.
func (e *Inverse) Mul(other air.Expr) air.Expr { panic("unreachable") }

// Equate one expression with another (equivalent to subtraction).
func (e *Inverse) Equate(other air.Expr) air.Expr { panic("unreachable") }

// AsConstant determines whether or not this is a constant expression. If
// so, the constant is returned; otherwise, nil is returned. NOTE: this
// does not perform any form of simplification to determine this.
func (e *Inverse) AsConstant() *fr.Element { return nil }

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (e *Inverse) Bounds() util.Bounds { return e.Expr.Bounds() }
Expand Down
88 changes: 65 additions & 23 deletions pkg/cmd/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"fmt"
"os"
"reflect"
"strings"

"github.com/consensys/go-corset/pkg/air"
"github.com/consensys/go-corset/pkg/hir"
"github.com/consensys/go-corset/pkg/mir"
"github.com/consensys/go-corset/pkg/schema"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -133,19 +135,20 @@ type schemaSummariser struct {

var schemaSummarisers []schemaSummariser = []schemaSummariser{
// Constraints
constraintCounter("Constraints", "*constraint.VanishingConstraint"),
constraintCounter("Lookups", "*constraint.LookupConstraint"),
constraintCounter("Permutations", "*constraint.PermutationConstraint"),
constraintCounter("Types", "*constraint.TypeConstraint"),
constraintCounter("Range", "*constraint.RangeConstraint"),
constraintCounter("Constraints", vanishingConstraints...),
constraintCounter("Lookups", lookupConstraints...),
constraintCounter("Permutations", permutationConstraints...),
constraintCounter("Range", rangeConstraints...),
// Assignments
assignmentCounter("Decompositions", "*assignment.ByteDecomposition"),
assignmentCounter("Computed Columns", "*assignment.ComputedColumn"),
assignmentCounter("Committed Columns", "*assignment.DataColumn"),
assignmentCounter("Interleavings", "*assignment.Interleaving"),
assignmentCounter("Lexicographic Orderings", "*assignment.LexicographicSort"),
assignmentCounter("Sorted Permutations", "*assignment.SortedPermutation"),
// Column Width
assignmentCounter("Decompositions", reflect.TypeOf((*assignment.ByteDecomposition)(nil))),
assignmentCounter("Committed Columns", reflect.TypeOf((*assignment.DataColumn)(nil))),
assignmentCounter("Computed Columns", computedColumns...),
assignmentCounter("Computation Columns", reflect.TypeOf((*assignment.Computation)(nil))),
assignmentCounter("Interleavings", reflect.TypeOf((*assignment.Interleaving)(nil))),
assignmentCounter("Lexicographic Orderings", reflect.TypeOf((*assignment.LexicographicSort)(nil))),
assignmentCounter("Sorted Permutations", reflect.TypeOf((*assignment.SortedPermutation)(nil))),
// Columns
columnCounter(),
columnWidthSummariser(1, 1),
columnWidthSummariser(2, 4),
columnWidthSummariser(5, 8),
Expand All @@ -156,41 +159,80 @@ var schemaSummarisers []schemaSummariser = []schemaSummariser{
columnWidthSummariser(129, 256),
}

func constraintCounter(title string, prefix string) schemaSummariser {
var vanishingConstraints = []reflect.Type{
reflect.TypeOf((hir.VanishingConstraint)(nil)),
reflect.TypeOf((mir.VanishingConstraint)(nil)),
reflect.TypeOf((air.VanishingConstraint)(nil))}

var lookupConstraints = []reflect.Type{
reflect.TypeOf((hir.LookupConstraint)(nil)),
reflect.TypeOf((mir.LookupConstraint)(nil)),
reflect.TypeOf((air.LookupConstraint)(nil))}

var rangeConstraints = []reflect.Type{
reflect.TypeOf((hir.RangeConstraint)(nil)),
reflect.TypeOf((mir.RangeConstraint)(nil)),
reflect.TypeOf((air.RangeConstraint)(nil))}

var permutationConstraints = []reflect.Type{
// permutation constraints only exist at AIR level
reflect.TypeOf((air.PermutationConstraint)(nil))}

var computedColumns = []reflect.Type{
// permutation constraints only exist at AIR level
reflect.TypeOf((*assignment.ComputedColumn[air.Expr])(nil))}

func constraintCounter(title string, types ...reflect.Type) schemaSummariser {
return schemaSummariser{
name: title,
summary: func(schema sc.Schema) int {
return typeOfCounter(schema.Constraints(), prefix)
sum := 0
for _, t := range types {
sum += typeOfCounter(schema.Constraints(), t)
}
return sum
},
}
}

func assignmentCounter(title string, prefix string) schemaSummariser {
func assignmentCounter(title string, types ...reflect.Type) schemaSummariser {
return schemaSummariser{
name: title,
summary: func(schema sc.Schema) int {
return typeOfCounter(schema.Declarations(), prefix)
sum := 0
for _, t := range types {
sum += typeOfCounter(schema.Declarations(), t)
}
return sum
},
}
}

func typeOfCounter[T any](iter util.Iterator[T], prefix string) int {
func typeOfCounter[T any](iter util.Iterator[T], dyntype reflect.Type) int {
count := 0

for iter.HasNext() {
ith := iter.Next()
if isTypeOf(ith, prefix) {
if dyntype == reflect.TypeOf(ith) {
count++
}
}

return count
}

func isTypeOf(obj any, prefix string) bool {
dyntype := reflect.TypeOf(obj)
// Check whether dynamic type matches prefix
return strings.HasPrefix(dyntype.String(), prefix)
func columnCounter() schemaSummariser {
return schemaSummariser{
name: "Columns (all)",
summary: func(sc schema.Schema) int {
count := 0
for i := sc.Columns(); i.HasNext(); {
i.Next()
count++
}
return count
},
}
}

func columnWidthSummariser(lowWidth uint, highWidth uint) schemaSummariser {
Expand Down
5 changes: 4 additions & 1 deletion pkg/corset/compiler/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,13 @@ func (t *translator) translateDefConstraint(decl *ast.DefConstraint, module util
}
// Apply guard (if applicable)
if guard != nil {
constraint = &hir.Mul{Args: []hir.Expr{guard, constraint}}
constraint = &hir.IfZero{Condition: guard, TrueBranch: nil, FalseBranch: constraint}
}
// Apply perspective selector (if applicable)
if selector != nil {
// NOTE: using an ifnot (as above) would be preferable here. However,
// this is currently done just to ensure constraints identical to the
// original are generated.
constraint = &hir.Mul{Args: []hir.Expr{selector, constraint}}
}
//
Expand Down
2 changes: 0 additions & 2 deletions pkg/corset/stdlib.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@

(defpurefun ((not :binary@bool :force) (x :binary)) (- 1 x))

(defpurefun ((eq! :binary@loob :force) (x :binary) (y :binary)) (^ (- x y) 2))
(defpurefun ((eq! :@loob) x y) (- x y))
(defpurefun ((neq! :binary@loob :force) x y) (not (~ (eq! x y))))
(defunalias = eq!)

(defpurefun ((eq :binary@bool :force) (x :binary) (y :binary)) (- 1 (^ (- x y) 2)))
(defpurefun ((eq :binary@bool :force) x y) (- 1 (~ (eq! x y))))
(defpurefun ((neq :binary@bool :force) x y) (eq! x y))

Expand Down
24 changes: 22 additions & 2 deletions pkg/hir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,35 @@ func extractBodies(es []Expr, schema *mir.Schema) []mir.Expr {
func expand(e Expr, schema sc.Schema) []Expr {
if p, ok := e.(*Add); ok {
return expandWithNaryConstructor(p.Args, func(nargs []Expr) Expr {
return &Add{Args: nargs}
var args []Expr
// Flatten nested sums
for _, e := range nargs {
if a, ok := e.(*Add); ok {
args = append(args, a.Args...)
} else {
args = append(args, e)
}
}
// Done
return &Add{Args: args}
}, schema)
} else if _, ok := e.(*Constant); ok {
return []Expr{e}
} else if _, ok := e.(*ColumnAccess); ok {
return []Expr{e}
} else if p, ok := e.(*Mul); ok {
return expandWithNaryConstructor(p.Args, func(nargs []Expr) Expr {
return &Mul{Args: nargs}
var args []Expr
// Flatten nested products
for _, e := range nargs {
if a, ok := e.(*Mul); ok {
args = append(args, a.Args...)
} else {
args = append(args, e)
}
}
// Done
return &Mul{Args: args}
}, schema)
} else if p, ok := e.(*List); ok {
ees := make([]Expr, 0)
Expand Down
49 changes: 39 additions & 10 deletions pkg/mir/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,26 @@ func applyConstantPropagation(e Expr, schema sc.Schema) Expr {

func applyConstantPropagationAdd(es []Expr, schema sc.Schema) Expr {
sum := fr.NewElement(0)
is_const := true
count := 0
rs := make([]Expr, len(es))
//
for i, e := range es {
rs[i] = applyConstantPropagation(e, schema)
// Check for constant
c, ok := rs[i].(*Constant)
// Try to continue sum
if ok && is_const {
if ok {
sum.Add(&sum, &c.Value)
} else {
is_const = false
// Increase count of constants
count++
}
}
// Check if constant
if is_const {
if count == len(es) {
// Propagate constant
return &Constant{sum}
} else if count > 1 {
rs = mergeConstants(sum, rs)
}
// Done
return &Add{rs}
Expand Down Expand Up @@ -85,10 +87,10 @@ func applyConstantPropagationSub(es []Expr, schema sc.Schema) Expr {

func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr {
one := fr.NewElement(1)
is_const := true
prod := one
rs := make([]Expr, len(es))
ones := 0
consts := 0
//
for i, e := range es {
rs[i] = applyConstantPropagation(e, schema)
Expand All @@ -100,23 +102,27 @@ func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr {
return &Constant{c.Value}
} else if ok && c.Value.IsOne() {
ones++
consts++
rs[i] = nil
} else if ok && is_const {
} else if ok {
// Continue building constant
prod.Mul(&prod, &c.Value)
} else {
is_const = false
//
consts++
}
}
// Check if constant
if is_const {
if consts == len(es) {
return &Constant{prod}
} else if ones > 0 {
rs = util.RemoveMatching[Expr](rs, func(item Expr) bool { return item == nil })
}
// Sanity check what's left.
if len(rs) == 1 {
return rs[0]
} else if consts-ones > 1 {
// Combine constants
rs = mergeConstants(prod, rs)
}
// Done
return &Mul{rs}
Expand Down Expand Up @@ -155,3 +161,26 @@ func applyConstantPropagationNorm(arg Expr, schema sc.Schema) Expr {
//
return &Normalise{arg}
}

// Replace all constants within a given sequence of expressions with a single
// constant (whose value has been precomputed from those constants). The new
// value replaces the first constant in the list.
func mergeConstants(constant fr.Element, es []Expr) []Expr {
j := 0
first := true
//
for i := range es {
// Check for constant
if _, ok := es[i].(*Constant); ok && first {
es[j] = &Constant{constant}
first = false
j++
} else if !ok {
// Retain non-constant expression
es[j] = es[i]
j++
}
}
// Return slice
return es[0:j]
}
1 change: 0 additions & 1 deletion pkg/schema/assignment/computation.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ func mapIfNativeFunction(trace tr.Trace, sources []uint) []util.FrArray {
rhs := fmt.Sprintf("%v=>%s", ith_row, val.String())
panic(fmt.Sprintf("conflicting values in source map (row %d): %s vs %s", i, lhs, rhs))
} else if !ok {
fmt.Printf("Inserting source key (row %d): %v\n", i, extractIthColumns(i, source_keys))
// Item not previously in map
source_map.Insert(ith_key, ith_value)
}
Expand Down

0 comments on commit 6ff4a5a

Please sign in to comment.