Skip to content

Commit

Permalink
feat: support corset arrays (#444)
Browse files Browse the repository at this point in the history
* Support Array Type and Array Access

This adds support for columns with an array type declaration, along with
array access expressions.

* Add additional array tests

This adds a number of additional tests for arrays, and fixes a few bugs
uncovered by them.
  • Loading branch information
DavePearce authored Dec 16, 2024
1 parent c0edb2b commit 977366d
Show file tree
Hide file tree
Showing 25 changed files with 1,203 additions and 71 deletions.
4 changes: 2 additions & 2 deletions pkg/corset/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewGlobalEnvironment(scope *GlobalScope) GlobalEnvironment {
if binding, ok := b.(*ColumnBinding); ok && !binding.computed {
binding.AllocateId(columnId)
// Increase the column id
columnId++
columnId += binding.dataType.Width()
}
}
}
Expand All @@ -47,7 +47,7 @@ func NewGlobalEnvironment(scope *GlobalScope) GlobalEnvironment {
if binding, ok := b.(*ColumnBinding); ok && binding.computed {
binding.AllocateId(columnId)
// Increase the column id
columnId++
columnId += binding.dataType.Width()
}
}
}
Expand Down
100 changes: 100 additions & 0 deletions pkg/corset/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,106 @@ func (e *Add) Dependencies() []Symbol {
return DependenciesOfExpressions(e.Args)
}

// ============================================================================
// ArrayAccess
// ============================================================================

// ArrayAccess represents an access to a single element of an array column,
// identified by the (unqualified) name of the column and an index expression.
type ArrayAccess struct {
// Name of the array being accessed, as returned by Name().
name string
// Expression which determines the element being accessed.
arg Expr
// Binding for the array; nil until Resolve() has been called.
binding Binding
}

// IsQualified determines whether this symbol is qualified or not (i.e. has an
// explicit module specifier).  For an array access this always returns false.
func (e *ArrayAccess) IsQualified() bool {
return false
}

// IsFunction indicates whether or not this symbol refers to a function.  For
// an array access this always returns false.
func (e *ArrayAccess) IsFunction() bool {
return false
}

// IsResolved checks whether this symbol has been resolved already, or not.
// That is, whether or not a binding has been associated with it.
func (e *ArrayAccess) IsResolved() bool {
return e.binding != nil
}

// AsConstant attempts to evaluate this expression as a constant (signed) value.
// If this expression is not constant, then nil is returned.  An array access
// is never reported as constant and, hence, this always returns nil.
func (e *ArrayAccess) AsConstant() *big.Int {
return nil
}

// Multiplicity reports how many values evaluating this expression can
// generate, which is determined solely by the argument expression.
func (e *ArrayAccess) Multiplicity() uint {
	args := []Expr{e.arg}
	//
	return determineMultiplicity(args)
}

// Module returns the module used to qualify this array access.  At this time,
// however, array accesses are always unqualified and, hence, calling this
// method always panics.
func (e *ArrayAccess) Module() string {
panic("unqualified array access")
}

// Name returns the (unqualified) name of this symbol.
func (e *ArrayAccess) Name() string {
return e.name
}

// Binding gets the binding associated with this symbol.  This will panic if
// this symbol is not yet resolved.
func (e *ArrayAccess) Binding() Binding {
	if e.binding == nil {
		// Name the construct correctly, so that the panic is not confused with
		// that arising from an unresolved variable access.
		panic("array access is unresolved")
	}
	//
	return e.binding
}

// Context returns the context for this expression, which is taken directly
// from the argument expression.  Observe that the expression must have been
// resolved for this to be defined (i.e. it may panic if it has not been
// resolved yet).
func (e *ArrayAccess) Context() Context {
return e.arg.Context()
}

// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.
//
// NOTE: this is not yet implemented for array accesses and, hence, calling it
// always panics.
func (e *ArrayAccess) Lisp() sexp.SExp {
panic("todo")
}

// Substitute produces a fresh array access in which all variables (such as
// function parameters) arising in the argument expression have been
// substituted according to the given mapping.  The name and binding carry
// over unchanged.
func (e *ArrayAccess) Substitute(mapping map[uint]Expr) Expr {
	arg := e.arg.Substitute(mapping)
	//
	return &ArrayAccess{name: e.name, arg: arg, binding: e.binding}
}

// Resolve associates this symbol with the binding of the definition to which
// it refers.  This panics if the given binding is nil, or if this symbol has
// already been resolved; otherwise, it always returns true.
func (e *ArrayAccess) Resolve(binding Binding) bool {
	switch {
	case binding == nil:
		panic("empty binding")
	case e.binding != nil:
		panic("already resolved")
	}
	// Record binding
	e.binding = binding
	//
	return true
}

// Dependencies returns the symbols on which this expression depends, namely
// those of the argument expression followed by this array access itself.
func (e *ArrayAccess) Dependencies() []Symbol {
	return append(e.arg.Dependencies(), e)
}

// ============================================================================
// Constants
// ============================================================================
Expand Down
111 changes: 74 additions & 37 deletions pkg/corset/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,17 @@ func NewParser(srcfile *sexp.SourceFile, srcmap *sexp.SourceMap[sexp.SExp]) *Par
// Configure expression translator
p.AddSymbolRule(constantParserRule)
p.AddSymbolRule(varAccessParserRule)
p.AddRecursiveRule("+", addParserRule)
p.AddRecursiveRule("-", subParserRule)
p.AddRecursiveRule("*", mulParserRule)
p.AddRecursiveRule("~", normParserRule)
p.AddRecursiveRule("^", powParserRule)
p.AddRecursiveRule("begin", beginParserRule)
p.AddRecursiveListRule("+", addParserRule)
p.AddRecursiveListRule("-", subParserRule)
p.AddRecursiveListRule("*", mulParserRule)
p.AddRecursiveListRule("~", normParserRule)
p.AddRecursiveListRule("^", powParserRule)
p.AddRecursiveListRule("begin", beginParserRule)
p.AddListRule("for", forParserRule(parser))
p.AddRecursiveRule("if", ifParserRule)
p.AddRecursiveRule("shift", shiftParserRule)
p.AddDefaultRecursiveRule(invokeParserRule)
p.AddRecursiveListRule("if", ifParserRule)
p.AddRecursiveListRule("shift", shiftParserRule)
p.AddDefaultRecursiveListRule(invokeParserRule)
p.AddDefaultRecursiveArrayRule(arrayAccessParserRule)
//
return parser
}
Expand Down Expand Up @@ -363,29 +364,55 @@ func (p *Parser) parseColumnDeclarationAttributes(attrs []sexp.SExp) (Type, bool
var (
dataType Type = NewFieldType()
mustProve bool = false
array uint
err *SyntaxError
)

for _, attr := range attrs {
symbol := attr.AsSymbol()
for i := 0; i < len(attrs); i++ {
ith := attrs[i]
symbol := ith.AsSymbol()
// Sanity check
if symbol == nil {
return nil, false, p.translator.SyntaxError(attr, "unknown column attribute")
return nil, false, p.translator.SyntaxError(ith, "unknown column attribute")
}
//
switch symbol.Value {
case ":display", ":opcode":
// skip these for now, as they are only relevant to the inspector.
case ":array":
if array, err = p.parseArrayDimension(attrs[i+1]); err != nil {
return nil, false, err
}
// skip dimension
i++
default:
if dataType, mustProve, err = p.parseType(attr); err != nil {
if dataType, mustProve, err = p.parseType(ith); err != nil {
return nil, false, err
}
}
}
// Done
if array != 0 {
return NewArrayType(dataType, array), mustProve, nil
}
//
return dataType, mustProve, nil
}

// parseArrayDimension parses the dimension of an array declaration, which must
// be an array S-Expression holding exactly one symbol that is a non-negative
// decimal integer (note that zero is accepted here).  Anything else results in
// an "invalid array dimension" syntax error.
func (p *Parser) parseArrayDimension(s sexp.SExp) (uint, *SyntaxError) {
	const msg = "invalid array dimension"
	// Check we have an array with exactly one element
	arr := s.AsArray()
	if arr == nil || arr.Len() != 1 {
		return 0, p.translator.SyntaxError(s, msg)
	}
	// Check that element is a symbol
	sym := arr.Get(0).AsSymbol()
	if sym == nil {
		return 0, p.translator.SyntaxError(s, msg)
	}
	// Check symbol is a non-negative integer
	size, err := strconv.Atoi(sym.Value)
	if err != nil || size < 0 {
		return 0, p.translator.SyntaxError(s, msg)
	}
	// Success
	return uint(size), nil
}

// Parse a constant declaration
func (p *Parser) parseDefConst(elements []sexp.SExp) (Declaration, []SyntaxError) {
var (
Expand Down Expand Up @@ -919,34 +946,26 @@ func forParserRule(p *Parser) sexp.ListRule[Expr] {
return func(list *sexp.List) (Expr, []SyntaxError) {
var (
errors []SyntaxError
n int = list.Len() - 1
rangeStr string
indexVar *sexp.Symbol
)
// Check we've got the expected number
if list.Len() != 4 {
msg := fmt.Sprintf("expected 3 arguments, found %d", list.Len())
return nil, p.translator.SyntaxErrors(list, msg)
}
// Extract index variable
if indexVar = list.Get(1).AsSymbol(); indexVar == nil {
err := p.translator.SyntaxError(list.Get(1), "invalid index variable")
errors = append(errors, *err)
}
// Extract range
for i := 2; i < n; i++ {
if ith := list.Get(i).AsSymbol(); ith != nil {
rangeStr = fmt.Sprintf("%s%s", rangeStr, ith.Value)
} else {
err := p.translator.SyntaxError(list.Get(i), "invalid range component")
errors = append(errors, *err)
}
}
// Parse range
start, end, ok := parseForRange(rangeStr)
start, end, errs := parseForRange(p, list.Get(2))
// Error Check
if !ok {
errors = append(errors, *p.translator.SyntaxError(list.Get(2), "malformed index range"))
}
errors = append(errors, errs...)
// Parse body
body, errs := p.translator.Translate(list.Get(n))
body, errs := p.translator.Translate(list.Get(3))
errors = append(errors, errs...)
//
// Error check
if len(errors) > 0 {
return nil, errors
}
Expand All @@ -957,24 +976,30 @@ func forParserRule(p *Parser) sexp.ListRule[Expr] {
}
}

func parseForRange(rangeStr string) (uint, uint, bool) {
// Parse a range which, represented as a string is "[s:e]".
func parseForRange(p *Parser, interval sexp.SExp) (uint, uint, []SyntaxError) {
var (
start int
end int
err1 error
err2 error
)
// This is a bit dirty. Essentially, we turn the sexp.Array back into a
// string and then parse it from there.
str := interval.String(false)
// Strip out any whitespace (which is permitted)
str = strings.ReplaceAll(str, " ", "")
// Check has form "[...]"
if !strings.HasPrefix(rangeStr, "[") || !strings.HasSuffix(rangeStr, "]") {
if !strings.HasPrefix(str, "[") || !strings.HasSuffix(str, "]") {
// error
return 0, 0, false
return 0, 0, p.translator.SyntaxErrors(interval, "invalid interval")
}
// Split out components
splits := strings.Split(rangeStr[1:len(rangeStr)-1], ":")
splits := strings.Split(str[1:len(str)-1], ":")
// Error check
if len(splits) == 0 || len(splits) > 2 {
// error
return 0, 0, false
return 0, 0, p.translator.SyntaxErrors(interval, "invalid interval")
} else if len(splits) == 1 {
start, err1 = strconv.Atoi(splits[0])
end = start
Expand All @@ -983,7 +1008,11 @@ func parseForRange(rangeStr string) (uint, uint, bool) {
end, err2 = strconv.Atoi(splits[1])
}
//
return uint(start), uint(end), err1 == nil && err2 == nil
if err1 != nil || err2 != nil {
return 0, 0, p.translator.SyntaxErrors(interval, "invalid interval")
}
// Success
return uint(start), uint(end), nil
}

func constantParserRule(symbol string) (Expr, bool, error) {
Expand Down Expand Up @@ -1016,7 +1045,7 @@ func constantParserRule(symbol string) (Expr, bool, error) {
func varAccessParserRule(col string) (Expr, bool, error) {
// Sanity check what we have
if !unicode.IsLetter(rune(col[0])) {
return nil, false, nil
return nil, false, errors.New("malformed column access")
}
// Handle qualified accesses (where permitted)
// Attempt to split column name into module / column pair.
Expand All @@ -1030,6 +1059,14 @@ func varAccessParserRule(col string) (Expr, bool, error) {
}
}

// arrayAccessParserRule constructs an (as yet unresolved) array access for the
// array of the given name.  Exactly one argument is required; any other number
// yields a "malformed array access" error.
func arrayAccessParserRule(name string, args []Expr) (Expr, error) {
	if len(args) == 1 {
		// Binding is left empty until resolution.
		return &ArrayAccess{name: name, arg: args[0]}, nil
	}
	//
	return nil, errors.New("malformed array access")
}

// addParserRule constructs an addition over the given arguments.  The operator
// name is ignored, and this never fails.
func addParserRule(_ string, args []Expr) (Expr, error) {
	sum := &Add{Args: args}
	//
	return sum, nil
}
Expand Down
30 changes: 28 additions & 2 deletions pkg/corset/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,9 @@ func (r *resolver) finaliseExpressionsInModule(scope LocalScope, args []Expr) ([
//
//nolint:staticcheck
func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) (Type, []SyntaxError) {
if v, ok := expr.(*Constant); ok {
if v, ok := expr.(*ArrayAccess); ok {
return r.finaliseArrayAccessInModule(scope, v)
} else if v, ok := expr.(*Constant); ok {
nbits := v.Val.BitLen()
return NewUintType(uint(nbits)), nil
} else if v, ok := expr.(*Add); ok {
Expand Down Expand Up @@ -542,7 +544,31 @@ func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) (Type
} else if v, ok := expr.(*VariableAccess); ok {
return r.finaliseVariableInModule(scope, v)
} else {
return nil, r.srcmap.SyntaxErrors(expr, "unknown expression")
return nil, r.srcmap.SyntaxErrors(expr, "unknown expression encountered during resolution")
}
}

// Resolve a specific array access contained within some expression which, in
// turn, is contained within some module.  The index expression is finalised
// first; then, the access itself is bound (if not already) and checked to
// ensure that: (1) it refers to a column; (2) that column has array type; (3)
// the index is a constant; (4) the index lies between 1 and the width of the
// array type (inclusive).  On success, the element type of the array is
// returned.
func (r *resolver) finaliseArrayAccessInModule(scope LocalScope, expr *ArrayAccess) (Type, []SyntaxError) {
	// Resolve argument
	if _, errs := r.finaliseExpressionInModule(scope, expr.arg); errs != nil {
		return nil, errs
	}
	// Bind the access itself (if necessary)
	if !expr.IsResolved() && !scope.Bind(expr) {
		return nil, r.srcmap.SyntaxErrors(expr, "unknown array column")
	}
	// Check access refers to a column
	binding, ok := expr.Binding().(*ColumnBinding)
	if !ok {
		return nil, r.srcmap.SyntaxErrors(expr, "unknown array column")
	}
	// Check column has array type
	arrType, ok := binding.dataType.(*ArrayType)
	if !ok {
		return nil, r.srcmap.SyntaxErrors(expr, "expected array column")
	}
	// Check index is constant
	index := expr.arg.AsConstant()
	if index == nil {
		return nil, r.srcmap.SyntaxErrors(expr, "expected constant array index")
	}
	// Check index is within bounds.  The constant is a signed big integer and
	// the result of Uint64() is undefined when the value is negative or does
	// not fit in 64 bits.  Therefore, the sign and range must be checked
	// before converting, as otherwise an index such as -1 could be silently
	// accepted.
	if index.Sign() <= 0 || !index.IsUint64() || index.Uint64() > uint64(arrType.Width()) {
		return nil, r.srcmap.SyntaxErrors(expr, "array access out-of-bounds")
	}
	// All good
	return arrType.element, nil
}

Expand Down
Loading

0 comments on commit 977366d

Please sign in to comment.