Skip to content

Commit

Permalink
Support Array Type and Array Access
Browse files Browse the repository at this point in the history
This adds support for columns with an array type declaration, along with
array access expressions.
  • Loading branch information
DavePearce committed Dec 16, 2024
1 parent c0edb2b commit 44631ad
Show file tree
Hide file tree
Showing 13 changed files with 830 additions and 71 deletions.
4 changes: 2 additions & 2 deletions pkg/corset/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewGlobalEnvironment(scope *GlobalScope) GlobalEnvironment {
if binding, ok := b.(*ColumnBinding); ok && !binding.computed {
binding.AllocateId(columnId)
// Increase the column id
columnId++
columnId += binding.dataType.Width()
}
}
}
Expand All @@ -47,7 +47,7 @@ func NewGlobalEnvironment(scope *GlobalScope) GlobalEnvironment {
if binding, ok := b.(*ColumnBinding); ok && binding.computed {
binding.AllocateId(columnId)
// Increase the column id
columnId++
columnId += binding.dataType.Width()
}
}
}
Expand Down
100 changes: 100 additions & 0 deletions pkg/corset/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,106 @@ func (e *Add) Dependencies() []Symbol {
return DependenciesOfExpressions(e.Args)
}

// ============================================================================
// ArrayAccess
// ============================================================================

// ArrayAccess represents the a given value taken to a power.
type ArrayAccess struct {
name string
arg Expr
binding Binding
}

// IsQualified determines whether this symbol is qualfied or not (i.e. has an
// explicitly module specifier).
func (e *ArrayAccess) IsQualified() bool {
return false
}

// IsFunction indicates whether or not this symbol refers to a function (which
// of course it always does).
func (e *ArrayAccess) IsFunction() bool {
return false
}

// IsResolved checks whether this symbol has been resolved already, or not.
func (e *ArrayAccess) IsResolved() bool {
return e.binding != nil
}

// AsConstant attempts to evaluate this expression as a constant (signed) value.
// If this expression is not constant, then nil is returned.
func (e *ArrayAccess) AsConstant() *big.Int {
return nil
}

// Multiplicity determines the number of values that evaluating this expression
// can generate.
func (e *ArrayAccess) Multiplicity() uint {
return determineMultiplicity([]Expr{e.arg})
}

// Module returns the module used to qualify this array access. At this time,
// however, array accesses are always unqualified.
func (e *ArrayAccess) Module() string {
panic("unqualified array access")
}

// Name returns the (unqualified) name of this symbol
func (e *ArrayAccess) Name() string {
return e.name
}

// Binding gets binding associated with this interface. This will panic if this
// symbol is not yet resolved.
func (e *ArrayAccess) Binding() Binding {
if e.binding == nil {
panic("variable access is unresolved")
}
//
return e.binding
}

// Context returns the context for this expression. Observe that the
// expression must have been resolved for this to be defined (i.e. it may
// panic if it has not been resolved yet).
func (e *ArrayAccess) Context() Context {
return e.arg.Context()
}

// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.
func (e *ArrayAccess) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *ArrayAccess) Substitute(mapping map[uint]Expr) Expr {
return &ArrayAccess{e.name, e.arg.Substitute(mapping), e.binding}
}

// Resolve this symbol by associating it with the binding associated with
// the definition of the symbol to which this refers.
func (e *ArrayAccess) Resolve(binding Binding) bool {
if binding == nil {
panic("empty binding")
} else if e.binding != nil {
panic("already resolved")
}
//
e.binding = binding
//
return true
}

// Dependencies needed to signal declaration.
func (e *ArrayAccess) Dependencies() []Symbol {
deps := e.arg.Dependencies()
return append(deps, e)
}

// ============================================================================
// Constants
// ============================================================================
Expand Down
111 changes: 74 additions & 37 deletions pkg/corset/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,17 @@ func NewParser(srcfile *sexp.SourceFile, srcmap *sexp.SourceMap[sexp.SExp]) *Par
// Configure expression translator
p.AddSymbolRule(constantParserRule)
p.AddSymbolRule(varAccessParserRule)
p.AddRecursiveRule("+", addParserRule)
p.AddRecursiveRule("-", subParserRule)
p.AddRecursiveRule("*", mulParserRule)
p.AddRecursiveRule("~", normParserRule)
p.AddRecursiveRule("^", powParserRule)
p.AddRecursiveRule("begin", beginParserRule)
p.AddRecursiveListRule("+", addParserRule)
p.AddRecursiveListRule("-", subParserRule)
p.AddRecursiveListRule("*", mulParserRule)
p.AddRecursiveListRule("~", normParserRule)
p.AddRecursiveListRule("^", powParserRule)
p.AddRecursiveListRule("begin", beginParserRule)
p.AddListRule("for", forParserRule(parser))
p.AddRecursiveRule("if", ifParserRule)
p.AddRecursiveRule("shift", shiftParserRule)
p.AddDefaultRecursiveRule(invokeParserRule)
p.AddRecursiveListRule("if", ifParserRule)
p.AddRecursiveListRule("shift", shiftParserRule)
p.AddDefaultRecursiveListRule(invokeParserRule)
p.AddDefaultRecursiveArrayRule(arrayAccessParserRule)
//
return parser
}
Expand Down Expand Up @@ -363,29 +364,55 @@ func (p *Parser) parseColumnDeclarationAttributes(attrs []sexp.SExp) (Type, bool
var (
dataType Type = NewFieldType()
mustProve bool = false
array uint
err *SyntaxError
)

for _, attr := range attrs {
symbol := attr.AsSymbol()
for i := 0; i < len(attrs); i++ {
ith := attrs[i]
symbol := ith.AsSymbol()
// Sanity check
if symbol == nil {
return nil, false, p.translator.SyntaxError(attr, "unknown column attribute")
return nil, false, p.translator.SyntaxError(ith, "unknown column attribute")
}
//
switch symbol.Value {
case ":display", ":opcode":
// skip these for now, as they are only relevant to the inspector.
case ":array":
if array, err = p.parseArrayDimension(attrs[i+1]); err != nil {
return nil, false, err
}
// skip dimension
i++
default:
if dataType, mustProve, err = p.parseType(attr); err != nil {
if dataType, mustProve, err = p.parseType(ith); err != nil {
return nil, false, err
}
}
}
// Done
if array != 0 {
return NewArrayType(dataType, array), mustProve, nil
}
//
return dataType, mustProve, nil
}

func (p *Parser) parseArrayDimension(s sexp.SExp) (uint, *SyntaxError) {
dim := s.AsArray()
//
if dim == nil || dim.Len() != 1 || dim.Get(0).AsSymbol() == nil {
return 0, p.translator.SyntaxError(s, "invalid array dimension")
}
//
if num, ok := strconv.Atoi(dim.Get(0).AsSymbol().Value); ok == nil && num >= 0 {
return uint(num), nil
}
//
return 0, p.translator.SyntaxError(s, "invalid array dimension")
}

// Parse a constant declaration
func (p *Parser) parseDefConst(elements []sexp.SExp) (Declaration, []SyntaxError) {
var (
Expand Down Expand Up @@ -919,34 +946,26 @@ func forParserRule(p *Parser) sexp.ListRule[Expr] {
return func(list *sexp.List) (Expr, []SyntaxError) {
var (
errors []SyntaxError
n int = list.Len() - 1
rangeStr string
indexVar *sexp.Symbol
)
// Check we've got the expected number
if list.Len() != 4 {
msg := fmt.Sprintf("expected 3 arguments, found %d", list.Len())
return nil, p.translator.SyntaxErrors(list, msg)
}
// Extract index variable
if indexVar = list.Get(1).AsSymbol(); indexVar == nil {
err := p.translator.SyntaxError(list.Get(1), "invalid index variable")
errors = append(errors, *err)
}
// Extract range
for i := 2; i < n; i++ {
if ith := list.Get(i).AsSymbol(); ith != nil {
rangeStr = fmt.Sprintf("%s%s", rangeStr, ith.Value)
} else {
err := p.translator.SyntaxError(list.Get(i), "invalid range component")
errors = append(errors, *err)
}
}
// Parse range
start, end, ok := parseForRange(rangeStr)
start, end, errs := parseForRange(p, list.Get(2))
// Error Check
if !ok {
errors = append(errors, *p.translator.SyntaxError(list.Get(2), "malformed index range"))
}
errors = append(errors, errs...)
// Parse body
body, errs := p.translator.Translate(list.Get(n))
body, errs := p.translator.Translate(list.Get(3))
errors = append(errors, errs...)
//
// Error check
if len(errors) > 0 {
return nil, errors
}
Expand All @@ -957,24 +976,30 @@ func forParserRule(p *Parser) sexp.ListRule[Expr] {
}
}

func parseForRange(rangeStr string) (uint, uint, bool) {
// Parse a range which, represented as a string is "[s:e]".
func parseForRange(p *Parser, interval sexp.SExp) (uint, uint, []SyntaxError) {
var (
start int
end int
err1 error
err2 error
)
// This is a bit dirty. Essentially, we turn the sexp.Array back into a
// string and then parse it from there.
str := interval.String(false)
// Strip out any whitespace (which is permitted)
str = strings.ReplaceAll(str, " ", "")
// Check has form "[...]"
if !strings.HasPrefix(rangeStr, "[") || !strings.HasSuffix(rangeStr, "]") {
if !strings.HasPrefix(str, "[") || !strings.HasSuffix(str, "]") {
// error
return 0, 0, false
return 0, 0, p.translator.SyntaxErrors(interval, "invalid interval")
}
// Split out components
splits := strings.Split(rangeStr[1:len(rangeStr)-1], ":")
splits := strings.Split(str[1:len(str)-1], ":")
// Error check
if len(splits) == 0 || len(splits) > 2 {
// error
return 0, 0, false
return 0, 0, p.translator.SyntaxErrors(interval, "invalid interval")
} else if len(splits) == 1 {
start, err1 = strconv.Atoi(splits[0])
end = start
Expand All @@ -983,7 +1008,11 @@ func parseForRange(rangeStr string) (uint, uint, bool) {
end, err2 = strconv.Atoi(splits[1])
}
//
return uint(start), uint(end), err1 == nil && err2 == nil
if err1 != nil || err2 != nil {
return 0, 0, p.translator.SyntaxErrors(interval, "invalid interval")
}
// Success
return uint(start), uint(end), nil
}

func constantParserRule(symbol string) (Expr, bool, error) {
Expand Down Expand Up @@ -1016,7 +1045,7 @@ func constantParserRule(symbol string) (Expr, bool, error) {
func varAccessParserRule(col string) (Expr, bool, error) {
// Sanity check what we have
if !unicode.IsLetter(rune(col[0])) {
return nil, false, nil
return nil, false, errors.New("malformed column access")
}
// Handle qualified accesses (where permitted)
// Attempt to split column name into module / column pair.
Expand All @@ -1030,6 +1059,14 @@ func varAccessParserRule(col string) (Expr, bool, error) {
}
}

func arrayAccessParserRule(name string, args []Expr) (Expr, error) {
if len(args) != 1 {
return nil, errors.New("malformed array access")
}
//
return &ArrayAccess{name, args[0], nil}, nil
}

func addParserRule(_ string, args []Expr) (Expr, error) {
return &Add{args}, nil
}
Expand Down
26 changes: 24 additions & 2 deletions pkg/corset/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,9 @@ func (r *resolver) finaliseExpressionsInModule(scope LocalScope, args []Expr) ([
//
//nolint:staticcheck
func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) (Type, []SyntaxError) {
if v, ok := expr.(*Constant); ok {
if v, ok := expr.(*ArrayAccess); ok {
return r.finaliseArrayAccessInModule(scope, v)
} else if v, ok := expr.(*Constant); ok {
nbits := v.Val.BitLen()
return NewUintType(uint(nbits)), nil
} else if v, ok := expr.(*Add); ok {
Expand Down Expand Up @@ -542,10 +544,30 @@ func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) (Type
} else if v, ok := expr.(*VariableAccess); ok {
return r.finaliseVariableInModule(scope, v)
} else {
return nil, r.srcmap.SyntaxErrors(expr, "unknown expression")
return nil, r.srcmap.SyntaxErrors(expr, "unknown expression encountered during resolution")
}
}

// Resolve a specific array access contained within some expression which, in
// turn, is contained within some module.
func (r *resolver) finaliseArrayAccessInModule(scope LocalScope, expr *ArrayAccess) (Type, []SyntaxError) {
// Resolve argument
if _, errors := r.finaliseExpressionInModule(scope, expr.arg); errors != nil {
return nil, errors
}
//
if !expr.IsResolved() && !scope.Bind(expr) {
return nil, r.srcmap.SyntaxErrors(expr, "unknown array column")
} else if binding, ok := expr.Binding().(*ColumnBinding); ok {
// Found column, now check it is an array.
if arr_t, ok := binding.dataType.(*ArrayType); ok {
return arr_t.element, nil
}
}
// Default
return nil, r.srcmap.SyntaxErrors(expr, "not an array column")
}

// Resolve an if condition contained within some expression which, in turn, is
// contained within some module. An important step occurrs here where, based on
// the semantics of the condition, this is inferred as an "if-zero" or an
Expand Down
Loading

0 comments on commit 44631ad

Please sign in to comment.