Skip to content

Commit

Permalink
Support arrays with inteval dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
DavePearce committed Dec 21, 2024
1 parent 33e467b commit ec0b5a4
Show file tree
Hide file tree
Showing 10 changed files with 621 additions and 23 deletions.
34 changes: 22 additions & 12 deletions pkg/corset/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ func (p *Parser) parseColumnDeclarationAttributes(attrs []sexp.SExp) (Type, bool
var (
dataType Type = NewFieldType()
mustProve bool = false
array uint
array_min uint
array_max uint
err *SyntaxError
)

Expand Down Expand Up @@ -396,7 +397,7 @@ func (p *Parser) parseColumnDeclarationAttributes(attrs []sexp.SExp) (Type, bool
return nil, false, p.translator.SyntaxError(ith, "unknown display definition")
}
case ":array":
if array, err = p.parseArrayDimension(attrs[i+1]); err != nil {
if array_min, array_max, err = p.parseArrayDimension(attrs[i+1]); err != nil {
return nil, false, err
}
// skip dimension
Expand All @@ -408,25 +409,34 @@ func (p *Parser) parseColumnDeclarationAttributes(attrs []sexp.SExp) (Type, bool
}
}
// Done
if array != 0 {
return NewArrayType(dataType, array), mustProve, nil
if array_max != 0 {
return NewArrayType(dataType, array_min, array_max), mustProve, nil
}
//
return dataType, mustProve, nil
}

func (p *Parser) parseArrayDimension(s sexp.SExp) (uint, *SyntaxError) {
func (p *Parser) parseArrayDimension(s sexp.SExp) (uint, uint, *SyntaxError) {
dim := s.AsArray()
//
if dim == nil || dim.Len() != 1 || dim.Get(0).AsSymbol() == nil {
return 0, p.translator.SyntaxError(s, "invalid array dimension")
}
//
if num, ok := strconv.Atoi(dim.Get(0).AsSymbol().Value); ok == nil && num >= 0 {
return uint(num), nil
if dim == nil || dim.Get(0).AsSymbol() == nil || dim.Len() != 1 {
return 0, 0, p.translator.SyntaxError(s, "invalid array dimension")
} else {
// Check for interval dimensions
split := strings.Split(dim.Get(0).AsSymbol().Value, ":")
//
if len(split) == 0 || len(split) > 2 {
return 0, 0, p.translator.SyntaxError(s, "invalid array dimension")
} else if m, ok_m := strconv.Atoi(split[0]); len(split) == 1 && m >= 0 && ok_m == nil {
return uint(1), uint(m), nil
} else if ok_m != nil || m < 0 {
//unlikely scenarios
} else if n, ok_n := strconv.Atoi(split[1]); len(split) == 2 && n >= 0 && ok_n == nil {
return uint(m), uint(n), nil
}
}
//
return 0, p.translator.SyntaxError(s, "invalid array dimension")
return 0, 0, p.translator.SyntaxError(s, "invalid array dimension")
}

// Parse a constant declaration
Expand Down
18 changes: 13 additions & 5 deletions pkg/corset/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package corset

import (
"fmt"
"math"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/hir"
Expand Down Expand Up @@ -109,7 +110,7 @@ func (t *translator) translateDefColumn(decl *DefColumn, module string) []Syntax
if arr_t, ok := decl.DataType().(*ArrayType); ok {
var errors []SyntaxError
// Handle array types
for i := uint(1); i <= arr_t.size; i++ {
for i := arr_t.min; i <= arr_t.max; i++ {
name := fmt.Sprintf("%s_%d", decl.name, i)
errs := t.translateRawColumn(decl, module, name, arr_t.element.AsUnderlying(), columnId)
errors = append(errors, errs...)
Expand Down Expand Up @@ -455,19 +456,26 @@ func (t *translator) translateExpressionInModule(expr Expr, module string, shift
}

func (t *translator) translateArrayAccessInModule(expr *ArrayAccess, shift int) (hir.Expr, []SyntaxError) {
var errors []SyntaxError
var (
errors []SyntaxError
min uint = 0
max uint = math.MaxUint
)
// Lookup the column
binding, ok := expr.Binding().(*ColumnBinding)
// Did we find it?
if !ok {
errors = append(errors, *t.srcmap.SyntaxError(expr.arg, "invalid array index encountered during translation"))
} else if arr_t, ok := binding.dataType.(*ArrayType); ok {
min = arr_t.min
max = arr_t.max
}
// Array index should be statically known
index := expr.arg.AsConstant()
//
if index == nil {
errors = append(errors, *t.srcmap.SyntaxError(expr.arg, "expected constant array index"))
} else if i := uint(index.Uint64()); i == 0 || (binding != nil && i > binding.dataType.Width()) {
} else if i := uint(index.Uint64()); i < min || i > max {
errors = append(errors, *t.srcmap.SyntaxError(expr.arg, "array index out-of-bounds"))
}
// Error check
Expand All @@ -476,8 +484,8 @@ func (t *translator) translateArrayAccessInModule(expr *ArrayAccess, shift int)
}
// Lookup underlying column info
info := t.env.Column(binding.module, expr.Name())
// Update column id (remember indices start from 1)
columnId := info.cid + uint(index.Uint64()) - 1
// Update column id
columnId := info.cid + uint(index.Uint64()) - min
// Done
return &hir.ColumnAccess{Column: columnId, Shift: shift}, nil
}
Expand Down
14 changes: 8 additions & 6 deletions pkg/corset/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,15 @@ func (p *NativeType) String() string {
type ArrayType struct {
// element type
element Type
// array size
size uint
// min index
min uint
// max index
max uint
}

// NewArrayType constructs a new array type of a given (fixed) size.
func NewArrayType(element Type, size uint) *ArrayType {
return &ArrayType{element, size}
func NewArrayType(element Type, min uint, max uint) *ArrayType {
return &ArrayType{element, min, max}
}

// HasLoobeanSemantics indicates whether or not this type supports "loobean"
Expand Down Expand Up @@ -247,7 +249,7 @@ func (p *ArrayType) WithBooleanSemantics() Type {
// Width returns the number of underlying columns represented by this column.
// For example, an array of size n will expand into n underlying columns.
func (p *ArrayType) Width() uint {
return p.size
return p.max - p.min + 1
}

// AsUnderlying attempts to convert this type into an underlying type. If this
Expand All @@ -266,5 +268,5 @@ func (p *ArrayType) SubtypeOf(other Type) bool {
}

func (p *ArrayType) String() string {
return fmt.Sprintf("(%s)[%d]", p.element.String(), p.size)
return fmt.Sprintf("(%s)[%d:%d]", p.element.String(), p.min, p.max)
}
8 changes: 8 additions & 0 deletions pkg/test/valid_corset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,14 @@ func Test_Array_03(t *testing.T) {
Check(t, false, "array_03")
}

func Test_Array_04(t *testing.T) {
Check(t, false, "array_04")
}

func Test_Array_05(t *testing.T) {
Check(t, false, "array_05")
}

// ===================================================================
// Reduce
// ===================================================================
Expand Down
18 changes: 18 additions & 0 deletions testdata/array_04.accepts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{ "ARG": [], "BIT_4": [], "BIT_3": [], "BIT_2": [], "BIT_1": [] }
;;
{ "ARG": [0], "BIT_4": [0], "BIT_3": [0], "BIT_2": [0], "BIT_1": [0] }
{ "ARG": [1], "BIT_4": [0], "BIT_3": [0], "BIT_2": [0], "BIT_1": [1] }
{ "ARG": [2], "BIT_4": [0], "BIT_3": [0], "BIT_2": [1], "BIT_1": [0] }
{ "ARG": [3], "BIT_4": [0], "BIT_3": [0], "BIT_2": [1], "BIT_1": [1] }
{ "ARG": [4], "BIT_4": [0], "BIT_3": [1], "BIT_2": [0], "BIT_1": [0] }
{ "ARG": [5], "BIT_4": [0], "BIT_3": [1], "BIT_2": [0], "BIT_1": [1] }
{ "ARG": [6], "BIT_4": [0], "BIT_3": [1], "BIT_2": [1], "BIT_1": [0] }
{ "ARG": [7], "BIT_4": [0], "BIT_3": [1], "BIT_2": [1], "BIT_1": [1] }
{ "ARG": [8], "BIT_4": [1], "BIT_3": [0], "BIT_2": [0], "BIT_1": [0] }
{ "ARG": [9], "BIT_4": [1], "BIT_3": [0], "BIT_2": [0], "BIT_1": [1] }
{ "ARG": [10], "BIT_4": [1], "BIT_3": [0], "BIT_2": [1], "BIT_1": [0] }
{ "ARG": [11], "BIT_4": [1], "BIT_3": [0], "BIT_2": [1], "BIT_1": [1] }
{ "ARG": [12], "BIT_4": [1], "BIT_3": [1], "BIT_2": [0], "BIT_1": [0] }
{ "ARG": [13], "BIT_4": [1], "BIT_3": [1], "BIT_2": [0], "BIT_1": [1] }
{ "ARG": [14], "BIT_4": [1], "BIT_3": [1], "BIT_2": [1], "BIT_1": [0] }
{ "ARG": [15], "BIT_4": [1], "BIT_3": [1], "BIT_2": [1], "BIT_1": [1] }
11 changes: 11 additions & 0 deletions testdata/array_04.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
(defcolumns
(BIT :binary@prove :array [1:4])
(ARG :i16@loob))

(defconstraint bits ()
(- ARG
(+
(* 1 [BIT 1])
(* 2 [BIT 2])
(* 4 [BIT 3])
(* 8 [BIT 4]))))
Loading

0 comments on commit ec0b5a4

Please sign in to comment.