Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support corset arrays #444

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/corset/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewGlobalEnvironment(scope *GlobalScope) GlobalEnvironment {
if binding, ok := b.(*ColumnBinding); ok && !binding.computed {
binding.AllocateId(columnId)
// Increase the column id
columnId++
columnId += binding.dataType.Width()
}
}
}
Expand All @@ -47,7 +47,7 @@ func NewGlobalEnvironment(scope *GlobalScope) GlobalEnvironment {
if binding, ok := b.(*ColumnBinding); ok && binding.computed {
binding.AllocateId(columnId)
// Increase the column id
columnId++
columnId += binding.dataType.Width()
}
}
}
Expand Down
100 changes: 100 additions & 0 deletions pkg/corset/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,106 @@ func (e *Add) Dependencies() []Symbol {
return DependenciesOfExpressions(e.Args)
}

// ============================================================================
// ArrayAccess
// ============================================================================

// ArrayAccess represents the a given value taken to a power.
type ArrayAccess struct {
name string
arg Expr
binding Binding
}

// IsQualified determines whether this symbol is qualfied or not (i.e. has an
// explicitly module specifier).
func (e *ArrayAccess) IsQualified() bool {
return false
}

// IsFunction indicates whether or not this symbol refers to a function (which
// of course it always does).
func (e *ArrayAccess) IsFunction() bool {
return false
}

// IsResolved checks whether this symbol has been resolved already, or not.
func (e *ArrayAccess) IsResolved() bool {
return e.binding != nil
}

// AsConstant attempts to evaluate this expression as a constant (signed) value.
// If this expression is not constant, then nil is returned.
func (e *ArrayAccess) AsConstant() *big.Int {
return nil
}

// Multiplicity determines the number of values that evaluating this expression
// can generate.
func (e *ArrayAccess) Multiplicity() uint {
return determineMultiplicity([]Expr{e.arg})
}

// Module returns the module used to qualify this array access. At this time,
// however, array accesses are always unqualified.
func (e *ArrayAccess) Module() string {
panic("unqualified array access")
}

// Name returns the (unqualified) name of this symbol
func (e *ArrayAccess) Name() string {
return e.name
}

// Binding gets binding associated with this interface. This will panic if this
// symbol is not yet resolved.
func (e *ArrayAccess) Binding() Binding {
if e.binding == nil {
panic("variable access is unresolved")
}
//
return e.binding
}

// Context returns the context for this expression. Observe that the
// expression must have been resolved for this to be defined (i.e. it may
// panic if it has not been resolved yet).
func (e *ArrayAccess) Context() Context {
return e.arg.Context()
}

// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.
func (e *ArrayAccess) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *ArrayAccess) Substitute(mapping map[uint]Expr) Expr {
return &ArrayAccess{e.name, e.arg.Substitute(mapping), e.binding}
}

// Resolve this symbol by associating it with the binding associated with
// the definition of the symbol to which this refers.
func (e *ArrayAccess) Resolve(binding Binding) bool {
if binding == nil {
panic("empty binding")
} else if e.binding != nil {
panic("already resolved")
}
//
e.binding = binding
//
return true
}

// Dependencies needed to signal declaration.
func (e *ArrayAccess) Dependencies() []Symbol {
deps := e.arg.Dependencies()
return append(deps, e)
}

// ============================================================================
// Constants
// ============================================================================
Expand Down
111 changes: 74 additions & 37 deletions pkg/corset/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,17 @@ func NewParser(srcfile *sexp.SourceFile, srcmap *sexp.SourceMap[sexp.SExp]) *Par
// Configure expression translator
p.AddSymbolRule(constantParserRule)
p.AddSymbolRule(varAccessParserRule)
p.AddRecursiveRule("+", addParserRule)
p.AddRecursiveRule("-", subParserRule)
p.AddRecursiveRule("*", mulParserRule)
p.AddRecursiveRule("~", normParserRule)
p.AddRecursiveRule("^", powParserRule)
p.AddRecursiveRule("begin", beginParserRule)
p.AddRecursiveListRule("+", addParserRule)
p.AddRecursiveListRule("-", subParserRule)
p.AddRecursiveListRule("*", mulParserRule)
p.AddRecursiveListRule("~", normParserRule)
p.AddRecursiveListRule("^", powParserRule)
p.AddRecursiveListRule("begin", beginParserRule)
p.AddListRule("for", forParserRule(parser))
p.AddRecursiveRule("if", ifParserRule)
p.AddRecursiveRule("shift", shiftParserRule)
p.AddDefaultRecursiveRule(invokeParserRule)
p.AddRecursiveListRule("if", ifParserRule)
p.AddRecursiveListRule("shift", shiftParserRule)
p.AddDefaultRecursiveListRule(invokeParserRule)
p.AddDefaultRecursiveArrayRule(arrayAccessParserRule)
//
return parser
}
Expand Down Expand Up @@ -363,29 +364,55 @@ func (p *Parser) parseColumnDeclarationAttributes(attrs []sexp.SExp) (Type, bool
var (
dataType Type = NewFieldType()
mustProve bool = false
array uint
err *SyntaxError
)

for _, attr := range attrs {
symbol := attr.AsSymbol()
for i := 0; i < len(attrs); i++ {
ith := attrs[i]
symbol := ith.AsSymbol()
// Sanity check
if symbol == nil {
return nil, false, p.translator.SyntaxError(attr, "unknown column attribute")
return nil, false, p.translator.SyntaxError(ith, "unknown column attribute")
}
//
switch symbol.Value {
case ":display", ":opcode":
// skip these for now, as they are only relevant to the inspector.
case ":array":
if array, err = p.parseArrayDimension(attrs[i+1]); err != nil {
return nil, false, err
}
// skip dimension
i++
default:
if dataType, mustProve, err = p.parseType(attr); err != nil {
if dataType, mustProve, err = p.parseType(ith); err != nil {
return nil, false, err
}
}
}
// Done
if array != 0 {
return NewArrayType(dataType, array), mustProve, nil
}
//
return dataType, mustProve, nil
}

func (p *Parser) parseArrayDimension(s sexp.SExp) (uint, *SyntaxError) {
dim := s.AsArray()
//
if dim == nil || dim.Len() != 1 || dim.Get(0).AsSymbol() == nil {
return 0, p.translator.SyntaxError(s, "invalid array dimension")
}
//
if num, ok := strconv.Atoi(dim.Get(0).AsSymbol().Value); ok == nil && num >= 0 {
return uint(num), nil
}
//
return 0, p.translator.SyntaxError(s, "invalid array dimension")
}

// Parse a constant declaration
func (p *Parser) parseDefConst(elements []sexp.SExp) (Declaration, []SyntaxError) {
var (
Expand Down Expand Up @@ -919,34 +946,26 @@ func forParserRule(p *Parser) sexp.ListRule[Expr] {
return func(list *sexp.List) (Expr, []SyntaxError) {
var (
errors []SyntaxError
n int = list.Len() - 1
rangeStr string
indexVar *sexp.Symbol
)
// Check we've got the expected number
if list.Len() != 4 {
msg := fmt.Sprintf("expected 3 arguments, found %d", list.Len())
return nil, p.translator.SyntaxErrors(list, msg)
}
// Extract index variable
if indexVar = list.Get(1).AsSymbol(); indexVar == nil {
err := p.translator.SyntaxError(list.Get(1), "invalid index variable")
errors = append(errors, *err)
}
// Extract range
for i := 2; i < n; i++ {
if ith := list.Get(i).AsSymbol(); ith != nil {
rangeStr = fmt.Sprintf("%s%s", rangeStr, ith.Value)
} else {
err := p.translator.SyntaxError(list.Get(i), "invalid range component")
errors = append(errors, *err)
}
}
// Parse range
start, end, ok := parseForRange(rangeStr)
start, end, errs := parseForRange(p, list.Get(2))
// Error Check
if !ok {
errors = append(errors, *p.translator.SyntaxError(list.Get(2), "malformed index range"))
}
errors = append(errors, errs...)
// Parse body
body, errs := p.translator.Translate(list.Get(n))
body, errs := p.translator.Translate(list.Get(3))
errors = append(errors, errs...)
//
// Error check
if len(errors) > 0 {
return nil, errors
}
Expand All @@ -957,24 +976,30 @@ func forParserRule(p *Parser) sexp.ListRule[Expr] {
}
}

func parseForRange(rangeStr string) (uint, uint, bool) {
// Parse a range which, represented as a string is "[s:e]".
func parseForRange(p *Parser, interval sexp.SExp) (uint, uint, []SyntaxError) {
var (
start int
end int
err1 error
err2 error
)
// This is a bit dirty. Essentially, we turn the sexp.Array back into a
// string and then parse it from there.
str := interval.String(false)
// Strip out any whitespace (which is permitted)
str = strings.ReplaceAll(str, " ", "")
// Check has form "[...]"
if !strings.HasPrefix(rangeStr, "[") || !strings.HasSuffix(rangeStr, "]") {
if !strings.HasPrefix(str, "[") || !strings.HasSuffix(str, "]") {
// error
return 0, 0, false
return 0, 0, p.translator.SyntaxErrors(interval, "invalid interval")
}
// Split out components
splits := strings.Split(rangeStr[1:len(rangeStr)-1], ":")
splits := strings.Split(str[1:len(str)-1], ":")
// Error check
if len(splits) == 0 || len(splits) > 2 {
// error
return 0, 0, false
return 0, 0, p.translator.SyntaxErrors(interval, "invalid interval")
} else if len(splits) == 1 {
start, err1 = strconv.Atoi(splits[0])
end = start
Expand All @@ -983,7 +1008,11 @@ func parseForRange(rangeStr string) (uint, uint, bool) {
end, err2 = strconv.Atoi(splits[1])
}
//
return uint(start), uint(end), err1 == nil && err2 == nil
if err1 != nil || err2 != nil {
return 0, 0, p.translator.SyntaxErrors(interval, "invalid interval")
}
// Success
return uint(start), uint(end), nil
}

func constantParserRule(symbol string) (Expr, bool, error) {
Expand Down Expand Up @@ -1016,7 +1045,7 @@ func constantParserRule(symbol string) (Expr, bool, error) {
func varAccessParserRule(col string) (Expr, bool, error) {
// Sanity check what we have
if !unicode.IsLetter(rune(col[0])) {
return nil, false, nil
return nil, false, errors.New("malformed column access")
}
// Handle qualified accesses (where permitted)
// Attempt to split column name into module / column pair.
Expand All @@ -1030,6 +1059,14 @@ func varAccessParserRule(col string) (Expr, bool, error) {
}
}

func arrayAccessParserRule(name string, args []Expr) (Expr, error) {
if len(args) != 1 {
return nil, errors.New("malformed array access")
}
//
return &ArrayAccess{name, args[0], nil}, nil
}

func addParserRule(_ string, args []Expr) (Expr, error) {
return &Add{args}, nil
}
Expand Down
30 changes: 28 additions & 2 deletions pkg/corset/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,9 @@ func (r *resolver) finaliseExpressionsInModule(scope LocalScope, args []Expr) ([
//
//nolint:staticcheck
func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) (Type, []SyntaxError) {
if v, ok := expr.(*Constant); ok {
if v, ok := expr.(*ArrayAccess); ok {
return r.finaliseArrayAccessInModule(scope, v)
} else if v, ok := expr.(*Constant); ok {
nbits := v.Val.BitLen()
return NewUintType(uint(nbits)), nil
} else if v, ok := expr.(*Add); ok {
Expand Down Expand Up @@ -542,7 +544,31 @@ func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) (Type
} else if v, ok := expr.(*VariableAccess); ok {
return r.finaliseVariableInModule(scope, v)
} else {
return nil, r.srcmap.SyntaxErrors(expr, "unknown expression")
return nil, r.srcmap.SyntaxErrors(expr, "unknown expression encountered during resolution")
}
}

// Resolve a specific array access contained within some expression which, in
// turn, is contained within some module.
func (r *resolver) finaliseArrayAccessInModule(scope LocalScope, expr *ArrayAccess) (Type, []SyntaxError) {
// Resolve argument
if _, errors := r.finaliseExpressionInModule(scope, expr.arg); errors != nil {
return nil, errors
}
//
if !expr.IsResolved() && !scope.Bind(expr) {
return nil, r.srcmap.SyntaxErrors(expr, "unknown array column")
} else if binding, ok := expr.Binding().(*ColumnBinding); !ok {
return nil, r.srcmap.SyntaxErrors(expr, "unknown array column")
} else if arr_t, ok := binding.dataType.(*ArrayType); !ok {
return nil, r.srcmap.SyntaxErrors(expr, "expected array column")
} else if c := expr.arg.AsConstant(); c == nil {
return nil, r.srcmap.SyntaxErrors(expr, "expected constant array index")
} else if i := uint(c.Uint64()); i == 0 || i > arr_t.Width() {
return nil, r.srcmap.SyntaxErrors(expr, "array access out-of-bounds")
} else {
// All good
return arr_t.element, nil
}
}

Expand Down
Loading
Loading