diff --git a/pkg/corset/ast.go b/pkg/corset/ast.go index 4ce3e9ff..2027f2f0 100644 --- a/pkg/corset/ast.go +++ b/pkg/corset/ast.go @@ -4,7 +4,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" sc "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/sexp" - "github.com/consensys/go-corset/pkg/trace" tr "github.com/consensys/go-corset/pkg/trace" ) @@ -175,7 +174,7 @@ type DefInterleaved struct { // The target column being defined Target string // The source columns used to define the interleaved target column. - Sources []*DefSourceColumn + Sources []*DefName } // CanFinalise checks whether or not this interleaving is ready to be finalised. @@ -200,19 +199,6 @@ func (p *DefInterleaved) Lisp() sexp.SExp { panic("got here") } -// DefSourceColumn provides information about a column being permuted by a -// sorted permutation. -type DefSourceColumn struct { - // Name of the column to be permuted - Name string -} - -// Lisp converts this node into its lisp representation. This is primarily used -// for debugging purposes. -func (p *DefSourceColumn) Lisp() sexp.SExp { - panic("got here") -} - // DefLookup represents a lookup constraint between a set N of source // expressions and a set of N target expressions. The source expressions must // have a single context (i.e. all be in the same module) and likewise for the @@ -325,6 +311,52 @@ func (p *DefProperty) Lisp() sexp.SExp { // parameters). In contrast, an impure function can access those columns // defined within its enclosing context. type DefFun struct { + Name *DefName + // Flag whether or not is pure function + Pure bool + // Return type + Return sc.Type + // Parameters + Parameters []*DefParameter + // Body + Body Expr +} + +// IsDeclaration needed to signal declaration. +func (p *DefFun) IsDeclaration() {} + +// Lisp converts this node into its lisp representation. This is primarily used +// for debugging purposes. +func (p *DefFun) Lisp() sexp.SExp { + panic("got here") +} + +// DefParameter packages together those piece relevant to declaring an individual +// parameter, such its name and type. +type DefParameter struct { + // Column name + Name string + // The datatype which all values in this parameter should inhabit. + DataType sc.Type +} + +// Lisp converts this node into its lisp representation. This is primarily used +// for debugging purposes. +func (p *DefParameter) Lisp() sexp.SExp { + panic("got here") +} + +// DefName is simply a wrapper around a string which can be associated with +// source information for producing syntax errors. +type DefName struct { + // Name of the column to be permuted + Name string +} + +// Lisp converts this node into its lisp representation. This is primarily used +// for debugging purposes. +func (p *DefName) Lisp() sexp.SExp { + panic("got here") } // Expr represents an arbitrary expression over the columns of a given context @@ -347,6 +379,10 @@ type Expr interface { // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). Context() tr.Context + + // Substitute all variables (such as for function parameters) arising in + // this expression. + Substitute(args []Expr) Expr } // ============================================================================ @@ -375,6 +411,12 @@ func (e *Add) Lisp() sexp.SExp { panic("todo") } +// Substitute all variables (such as for function parameters) arising in +// this expression. +func (e *Add) Substitute(args []Expr) Expr { + return &Add{SubstituteExpressions(e.Args, args)} +} + // ============================================================================ // Constants // ============================================================================ @@ -401,6 +443,12 @@ func (e *Constant) Lisp() sexp.SExp { return sexp.NewSymbol(e.Val.String()) } +// Substitute all variables (such as for function parameters) arising in +// this expression. +func (e *Constant) Substitute(args []Expr) Expr { + return e +} + // ============================================================================ // Exponentiation // ============================================================================ @@ -430,6 +478,12 @@ func (e *Exp) Lisp() sexp.SExp { panic("todo") } +// Substitute all variables (such as for function parameters) arising in +// this expression. +func (e *Exp) Substitute(args []Expr) Expr { + return &Exp{e.Arg.Substitute(args), e.Pow} +} + // ============================================================================ // IfZero // ============================================================================ @@ -464,6 +518,15 @@ func (e *IfZero) Lisp() sexp.SExp { panic("todo") } +// Substitute all variables (such as for function parameters) arising in +// this expression. +func (e *IfZero) Substitute(args []Expr) Expr { + return &IfZero{e.Condition.Substitute(args), + SubstituteOptionalExpression(e.TrueBranch, args), + SubstituteOptionalExpression(e.FalseBranch, args), + } +} + // ============================================================================ // List // ============================================================================ @@ -490,6 +553,12 @@ func (e *List) Lisp() sexp.SExp { panic("todo") } +// Substitute all variables (such as for function parameters) arising in +// this expression. +func (e *List) Substitute(args []Expr) Expr { + return &List{SubstituteExpressions(e.Args, args)} +} + // ============================================================================ // Multiplication // ============================================================================ @@ -516,6 +585,12 @@ func (e *Mul) Lisp() sexp.SExp { panic("todo") } +// Substitute all variables (such as for function parameters) arising in +// this expression. +func (e *Mul) Substitute(args []Expr) Expr { + return &Mul{SubstituteExpressions(e.Args, args)} +} + // ============================================================================ // Normalise // ============================================================================ @@ -543,6 +618,12 @@ func (e *Normalise) Lisp() sexp.SExp { panic("todo") } +// Substitute all variables (such as for function parameters) arising in +// this expression. +func (e *Normalise) Substitute(args []Expr) Expr { + return &Normalise{e.Arg.Substitute(args)} +} + // ============================================================================ // Subtraction // ============================================================================ @@ -569,6 +650,54 @@ func (e *Sub) Lisp() sexp.SExp { panic("todo") } +// Substitute all variables (such as for function parameters) arising in +// this expression. +func (e *Sub) Substitute(args []Expr) Expr { + return &Sub{SubstituteExpressions(e.Args, args)} +} + +// ============================================================================ +// VariableAccess +// ============================================================================ + +// Invoke represents an attempt to invoke a given function. +type Invoke struct { + Module *string + Name string + Args []Expr + Binding *FunctionBinding +} + +// Context returns the context for this expression. Observe that the +// expression must have been resolved for this to be defined (i.e. it may +// panic if it has not been resolved yet). +func (e *Invoke) Context() tr.Context { + if e.Binding == nil { + panic("unresolved expressions encountered whilst resolving context") + } + // TODO: impure functions can have their own context. + return ContextOfExpressions(e.Args) +} + +// Multiplicity determines the number of values that evaluating this expression +// can generate. +func (e *Invoke) Multiplicity() uint { + // FIXME: is this always correct? + return 1 +} + +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *Invoke) Lisp() sexp.SExp { + panic("todo") +} + +// Substitute all variables (such as for function parameters) arising in +// this expression. +func (e *Invoke) Substitute(args []Expr) Expr { + return &Invoke{e.Module, e.Name, SubstituteExpressions(e.Args, args), e.Binding} +} + // ============================================================================ // VariableAccess // ============================================================================ @@ -579,13 +708,12 @@ type VariableAccess struct { Module *string Name string Shift int - Binding *Binder + Binding Binding } // Multiplicity determines the number of values that evaluating this expression // can generate. func (e *VariableAccess) Multiplicity() uint { - // NOTE: this might not be true for invocations. return 1 } @@ -597,26 +725,28 @@ func (e *VariableAccess) Context() tr.Context { panic("unresolved expressions encountered whilst resolving context") } // Extract saved context - return e.Binding.Context + return e.Binding.Context() } // Lisp converts this schema element into a simple S-Expression, for example -// so it can be printed. +// so it can be printed.a func (e *VariableAccess) Lisp() sexp.SExp { panic("todo") } -// Binder provides additional information determined during the resolution -// phase. Specifically, it clarifies the meaning of a given variable name used -// within an expression (i.e. is it a column access, a local variable access, -// etc). -type Binder struct { - // Identifies whether this is a column access, or a variable access. - Column bool - // For a column access, this identifies the enclosing context. - Context trace.Context - // Identifies the variable or column index (as appropriate). - Index uint +// Substitute all variables (such as for function parameters) arising in +// this expression. +func (e *VariableAccess) Substitute(args []Expr) Expr { + if b, ok := e.Binding.(*ParameterBinding); ok { + // This is a variable to be substituted. + if e.Shift != 0 { + panic("support variable shifts") + } + // + return args[b.index] + } + // Nothing to do here + return e } // ============================================================================ @@ -638,6 +768,28 @@ func ContextOfExpressions(exprs []Expr) tr.Context { return context } +// SubstituteExpressions substitutes all variables found in a given set of +// expressions. +func SubstituteExpressions(exprs []Expr, vars []Expr) []Expr { + nexprs := make([]Expr, len(exprs)) + // + for i := 0; i < len(nexprs); i++ { + nexprs[i] = exprs[i].Substitute(vars) + } + // + return nexprs +} + +// SubstituteOptionalExpression substitutes through an expression which is +// optional (i.e. might be nil). In such case, nil is returned. +func SubstituteOptionalExpression(expr Expr, vars []Expr) Expr { + if expr != nil { + expr = expr.Substitute(vars) + } + // + return expr +} + func determineMultiplicity(exprs []Expr) uint { width := uint(1) // diff --git a/pkg/corset/binding.go b/pkg/corset/binding.go new file mode 100644 index 00000000..fc427e5c --- /dev/null +++ b/pkg/corset/binding.go @@ -0,0 +1,62 @@ +package corset + +import ( + tr "github.com/consensys/go-corset/pkg/trace" +) + +// Binding represents an association between a name, as found in a source file, +// and concrete item (e.g. a column, function, etc). +type Binding interface { + // Returns the context associated with this binding. + Context() tr.Context +} + +// ColumnBinding represents something bound to a given column. +type ColumnBinding struct { + // For a column access, this identifies the enclosing context. + context tr.Context + // Identifies the variable or column index (as appropriate). + index uint +} + +// Context returns the enclosing context for this column access. +func (p *ColumnBinding) Context() tr.Context { + return p.context +} + +// ColumnID returns the column identifier that this column access refers to. +func (p *ColumnBinding) ColumnID() uint { + return p.index +} + +// ParameterBinding represents something bound to a given column. +type ParameterBinding struct { + // Identifies the variable or column index (as appropriate). + index uint +} + +// Context for a parameter is always void, as it does not correspond to a column +// in given module. +func (p *ParameterBinding) Context() tr.Context { + return tr.VoidContext() +} + +// FunctionBinding represents the binding of a function application to its +// physical definition. +type FunctionBinding struct { + // arity determines the number of arguments this function takes. + arity uint + // body of the function in question. + body Expr +} + +// Context for a parameter is always void, as it does not correspond to a column +// in given module. +func (p *FunctionBinding) Context() tr.Context { + return tr.VoidContext() +} + +// Apply a given set of arguments to this function binding. +func (p *FunctionBinding) Apply(args []Expr) Expr { + return p.body.Substitute(args) +} diff --git a/pkg/corset/compiler.go b/pkg/corset/compiler.go index 717b33a5..487f88ee 100644 --- a/pkg/corset/compiler.go +++ b/pkg/corset/compiler.go @@ -1,6 +1,8 @@ package corset import ( + "fmt" + "github.com/consensys/go-corset/pkg/hir" "github.com/consensys/go-corset/pkg/sexp" ) @@ -70,7 +72,7 @@ func (p *Compiler) Compile() (*hir.Schema, []SyntaxError) { } // Check constraint contexts (e.g. for constraints, lookups, etc) // Type check constraints - + fmt.Println("Translating Circuit...") // Finally, translate everything and add it to the schema. return TranslateCircuit(env, p.srcmap, &p.circuit) } diff --git a/pkg/corset/environment.go b/pkg/corset/environment.go index 42d6ac7f..7906649c 100644 --- a/pkg/corset/environment.go +++ b/pkg/corset/environment.go @@ -54,6 +54,12 @@ func EmptyEnvironment() *Environment { return &Environment{modules, columns} } +// NewModuleScope creates a new evaluation scope. +func (p *Environment) NewModuleScope(module string) *ModuleScope { + mid := p.Module(module) + return &ModuleScope{mid, p, make(map[string]FunctionBinding)} +} + // RegisterModule registers a new module within this environment. Observe that // this will panic if the module already exists. Furthermore, the module // identifier is always determined as the next available identifier. diff --git a/pkg/corset/parser.go b/pkg/corset/parser.go index 0996d963..71e42fb6 100644 --- a/pkg/corset/parser.go +++ b/pkg/corset/parser.go @@ -27,7 +27,7 @@ import ( func ParseSourceFiles(files []*sexp.SourceFile) (Circuit, *sexp.SourceMaps[Node], []SyntaxError) { var circuit Circuit // (for now) at most one error per source file is supported. - var errors []SyntaxError = make([]SyntaxError, len(files)) + var errors []SyntaxError // Construct an initially empty source map srcmaps := sexp.NewSourceMaps[Node]() // num_errs counts the number of errors reported @@ -37,13 +37,13 @@ func ParseSourceFiles(files []*sexp.SourceFile) (Circuit, *sexp.SourceMaps[Node] // Names identifies the names of each unique module. names := make([]string, 0) // - for i, file := range files { - c, srcmap, err := ParseSourceFile(file) + for _, file := range files { + c, srcmap, errs := ParseSourceFile(file) // Handle errors - if err != nil { - num_errs++ + if len(errs) > 0 { + num_errs += uint(len(errs)) // Report any errors encountered - errors[i] = *err + errors = append(errors, errs...) } else { // Combine source maps srcmaps.Join(srcmap) @@ -82,20 +82,23 @@ func ParseSourceFiles(files []*sexp.SourceFile) (Circuit, *sexp.SourceMaps[Node] // ParseSourceFile parses the contents of a single lisp file into one or more // modules. Observe that every lisp file starts in the "prelude" or "root" // module, and may declare items for additional modules as necessary. -func ParseSourceFile(srcfile *sexp.SourceFile) (Circuit, *sexp.SourceMap[Node], *SyntaxError) { - var circuit Circuit +func ParseSourceFile(srcfile *sexp.SourceFile) (Circuit, *sexp.SourceMap[Node], []SyntaxError) { + var ( + circuit Circuit + errors []SyntaxError + ) // Parse bytes into an S-Expression terms, srcmap, err := srcfile.ParseAll() // Check test file parsed ok if err != nil { - return circuit, nil, err + return circuit, nil, []SyntaxError{*err} } // Construct parser for corset syntax p := NewParser(srcfile, srcmap) // Parse whatever is declared at the beginning of the file before the first // module declaration. These declarations form part of the "prelude". - if circuit.Declarations, terms, err = p.parseModuleContents(terms); err != nil { - return circuit, nil, err + if circuit.Declarations, terms, errors = p.parseModuleContents(terms); len(errors) > 0 { + return circuit, nil, errors } // Continue parsing string until nothing remains. for len(terms) != 0 { @@ -104,12 +107,12 @@ func ParseSourceFile(srcfile *sexp.SourceFile) (Circuit, *sexp.SourceMap[Node], decls []Declaration ) // Extract module name - if name, err = p.parseModuleStart(terms[0]); err != nil { - return circuit, nil, err + if name, errors = p.parseModuleStart(terms[0]); len(errors) > 0 { + return circuit, nil, errors } // Parse module contents - if decls, terms, err = p.parseModuleContents(terms[1:]); err != nil { - return circuit, nil, err + if decls, terms, errors = p.parseModuleContents(terms[1:]); len(errors) > 0 { + return circuit, nil, errors } else if len(decls) != 0 { circuit.Modules = append(circuit.Modules, Module{name, decls}) } @@ -151,6 +154,7 @@ func NewParser(srcfile *sexp.SourceFile, srcmap *sexp.SourceMap[sexp.SExp]) *Par p.AddRecursiveRule("^", powParserRule) p.AddRecursiveRule("if", ifParserRule) p.AddRecursiveRule("begin", beginParserRule) + p.AddDefaultRecursiveRule(invokeParserRule) // return parser } @@ -173,7 +177,8 @@ func (p *Parser) mapSourceNode(from sexp.SExp, to Node) { } // Extract all declarations associated with a given module and package them up. -func (p *Parser) parseModuleContents(terms []sexp.SExp) ([]Declaration, []sexp.SExp, *SyntaxError) { +func (p *Parser) parseModuleContents(terms []sexp.SExp) ([]Declaration, []sexp.SExp, []SyntaxError) { + var errors []SyntaxError // decls := make([]Declaration, 0) // @@ -181,35 +186,38 @@ func (p *Parser) parseModuleContents(terms []sexp.SExp) ([]Declaration, []sexp.S e, ok := s.(*sexp.List) // Check for error if !ok { - return nil, nil, p.translator.SyntaxError(s, "unexpected or malformed declaration") - } - // Check for end-of-module - if e.MatchSymbols(2, "module") { + err := p.translator.SyntaxError(s, "unexpected or malformed declaration") + errors = append(errors, *err) + } else if e.MatchSymbols(2, "module") { return decls, terms[i:], nil - } - // Parse the declaration - if decl, err := p.parseDeclaration(e); err != nil { - return nil, nil, err + } else if decl, errs := p.parseDeclaration(e); errs != nil { + errors = append(errors, errs...) } else { // Continue accumulating declarations for this module. decls = append(decls, decl) } } + // Sanity check errors + if len(errors) > 0 { + return nil, nil, errors + } // End-of-file signals end-of-module. return decls, make([]sexp.SExp, 0), nil } // Parse a module declaration of the form "(module m1)" which indicates the // start of module m1. -func (p *Parser) parseModuleStart(s sexp.SExp) (string, *SyntaxError) { +func (p *Parser) parseModuleStart(s sexp.SExp) (string, []SyntaxError) { l, ok := s.(*sexp.List) // Check for error if !ok { - return "", p.translator.SyntaxError(s, "unexpected or malformed declaration") + err := p.translator.SyntaxError(s, "unexpected or malformed declaration") + return "", []SyntaxError{*err} } // Sanity check declaration if len(l.Elements) > 2 { - return "", p.translator.SyntaxError(l, "malformed module declaration") + err := p.translator.SyntaxError(l, "malformed module declaration") + return "", []SyntaxError{*err} } // Extract column name name := l.Elements[1].AsSymbol().Value @@ -217,54 +225,68 @@ func (p *Parser) parseModuleStart(s sexp.SExp) (string, *SyntaxError) { return name, nil } -func (p *Parser) parseDeclaration(s *sexp.List) (Declaration, *SyntaxError) { +func (p *Parser) parseDeclaration(s *sexp.List) (Declaration, []SyntaxError) { var ( - decl Declaration - error *SyntaxError + decl Declaration + errors []SyntaxError + err *SyntaxError ) // if s.MatchSymbols(1, "defcolumns") { - decl, error = p.parseDefColumns(s) + decl, errors = p.parseDefColumns(s) } else if s.Len() == 4 && s.MatchSymbols(2, "defconstraint") { - decl, error = p.parseDefConstraint(s.Elements) + decl, errors = p.parseDefConstraint(s.Elements) + } else if s.Len() == 3 && s.MatchSymbols(1, "defpurefun") { + decl, errors = p.parseDefPureFun(s.Elements) } else if s.Len() == 3 && s.MatchSymbols(1, "definrange") { - decl, error = p.parseDefInRange(s.Elements) + decl, err = p.parseDefInRange(s.Elements) } else if s.Len() == 3 && s.MatchSymbols(1, "definterleaved") { - decl, error = p.parseDefInterleaved(s.Elements) + decl, err = p.parseDefInterleaved(s.Elements) } else if s.Len() == 4 && s.MatchSymbols(1, "deflookup") { - decl, error = p.parseDefLookup(s.Elements) + decl, err = p.parseDefLookup(s.Elements) } else if s.Len() == 3 && s.MatchSymbols(2, "defpermutation") { - decl, error = p.parseDefPermutation(s.Elements) + decl, err = p.parseDefPermutation(s.Elements) } else if s.Len() == 3 && s.MatchSymbols(2, "defproperty") { - decl, error = p.parseDefProperty(s.Elements) + decl, err = p.parseDefProperty(s.Elements) } else { - error = p.translator.SyntaxError(s, "malformed declaration") + err = p.translator.SyntaxError(s, "malformed declaration") + } + // Handle unit error case + if err != nil { + errors = append(errors, *err) } // Register node if appropriate if decl != nil { p.mapSourceNode(s, decl) } // done - return decl, error + return decl, errors } // Parse a column declaration -func (p *Parser) parseDefColumns(l *sexp.List) (*DefColumns, *SyntaxError) { +func (p *Parser) parseDefColumns(l *sexp.List) (*DefColumns, []SyntaxError) { columns := make([]*DefColumn, l.Len()-1) // Sanity check declaration if len(l.Elements) == 1 { - return nil, p.translator.SyntaxError(l, "malformed column declaration") + err := p.translator.SyntaxError(l, "malformed column declaration") + return nil, []SyntaxError{*err} } + // + var errors []SyntaxError // Process column declarations one by one. for i := 1; i < len(l.Elements); i++ { decl, err := p.parseColumnDeclaration(l.Elements[i]) // Extract column name if err != nil { - return nil, err + errors = append(errors, *err) } // Assign the declaration columns[i-1] = decl } + // Sanity check errors + if len(errors) > 0 { + return nil, errors + } // Done return &DefColumns{columns}, nil } @@ -301,10 +323,12 @@ func (p *Parser) parseColumnDeclaration(e sexp.SExp) (*DefColumn, *SyntaxError) } // Parse a vanishing declaration -func (p *Parser) parseDefConstraint(elements []sexp.SExp) (*DefConstraint, *SyntaxError) { +func (p *Parser) parseDefConstraint(elements []sexp.SExp) (*DefConstraint, []SyntaxError) { + var errors []SyntaxError // Initial sanity checks if elements[1].AsSymbol() == nil { - return nil, p.translator.SyntaxError(elements[1], "expected constraint handle") + err := p.translator.SyntaxError(elements[1], "expected constraint handle") + return nil, []SyntaxError{*err} } // handle := elements[1].AsSymbol().Value @@ -313,12 +337,16 @@ func (p *Parser) parseDefConstraint(elements []sexp.SExp) (*DefConstraint, *Synt domain, guard, err := p.parseConstraintAttributes(elements[2]) // Check for error if err != nil { - return nil, err + errors = append(errors, *err) } // Translate expression expr, err := p.translator.Translate(elements[3]) if err != nil { - return nil, err + errors = append(errors, *err) + } + // + if len(errors) > 0 { + return nil, errors } // Done return &DefConstraint{handle, domain, guard, expr}, nil @@ -335,7 +363,7 @@ func (p *Parser) parseDefInterleaved(elements []sexp.SExp) (*DefInterleaved, *Sy // Extract target and source columns target := elements[1].AsSymbol().Value sexpSources := elements[2].AsList() - sources := make([]*DefSourceColumn, sexpSources.Len()) + sources := make([]*DefName, sexpSources.Len()) // for i := 0; i != sexpSources.Len(); i++ { ith := sexpSources.Get(i) @@ -343,7 +371,7 @@ func (p *Parser) parseDefInterleaved(elements []sexp.SExp) (*DefInterleaved, *Sy return nil, p.translator.SyntaxError(ith, "malformed source column") } // Extract column name - sources[i] = &DefSourceColumn{ith.AsSymbol().Value} + sources[i] = &DefName{ith.AsSymbol().Value} } // Done return &DefInterleaved{target, sources}, nil @@ -479,6 +507,73 @@ func (p *Parser) parseDefProperty(elements []sexp.SExp) (*DefProperty, *SyntaxEr return &DefProperty{handle, expr}, nil } +// Parse a permutation declaration +func (p *Parser) parseDefPureFun(elements []sexp.SExp) (*DefFun, []SyntaxError) { + var ( + name *DefName + ret sc.Type + params []*DefParameter + errors []SyntaxError + signature *sexp.List = elements[1].AsList() + ) + // Parse signature + if signature == nil || signature.Len() == 0 { + err := p.translator.SyntaxError(elements[1], "malformed function signature") + errors = append(errors, *err) + } else { + name, ret, params, errors = p.parseFunctionSignature(signature.Elements) + } + // Translate expression + body, err := p.translator.Translate(elements[2]) + if err != nil { + errors = append(errors, *err) + } + // Check for errors + if len(errors) > 0 { + return nil, errors + } + // + return &DefFun{name, true, ret, params, body}, nil +} + +func (p *Parser) parseFunctionSignature(elements []sexp.SExp) (*DefName, sc.Type, []*DefParameter, []SyntaxError) { + var ( + name *sexp.Symbol = elements[0].AsSymbol() + params []*DefParameter = make([]*DefParameter, len(elements)-1) + ret sc.Type = &sc.FieldType{} + errors []SyntaxError + ) + // Parse name + if name == nil { + err := p.translator.SyntaxError(elements[1], "expected function name") + errors = append(errors, *err) + } + // Parse parameters + for i := 0; i < len(params); i = i + 1 { + var errs []SyntaxError + + if params[i], errs = p.parseFunctionParameter(elements[i+1]); len(errs) > 0 { + errors = append(errors, errs...) + } + } + // Check for any errors arising + if len(errors) > 0 { + return nil, nil, nil, errors + } + // + return &DefName{name.Value}, ret, params, nil +} + +func (p *Parser) parseFunctionParameter(element sexp.SExp) (*DefParameter, []SyntaxError) { + if symbol := element.AsSymbol(); symbol != nil { + return &DefParameter{symbol.Value, &sc.FieldType{}}, nil + } + // Construct error message (for now) + err := p.translator.SyntaxError(element, "malformed parameter declaration") + // + return nil, []SyntaxError{*err} +} + // Parse a range declaration func (p *Parser) parseDefInRange(elements []sexp.SExp) (*DefInRange, *SyntaxError) { var bound fr.Element @@ -651,6 +746,23 @@ func ifParserRule(_ string, args []Expr) (Expr, error) { return nil, errors.New("incorrect number of arguments") } +func invokeParserRule(name string, args []Expr) (Expr, error) { + // Sanity check what we have + if !unicode.IsLetter(rune(name[0])) { + return nil, nil + } + // Handle qualified accesses (where permitted) + // Attempt to split column name into module / column pair. + split := strings.Split(name, ".") + if len(split) == 2 { + return &Invoke{&split[0], split[1], args, nil}, nil + } else if len(split) > 2 { + return nil, errors.New("malformed function invocation") + } + // Done + return &Invoke{nil, name, args, nil}, nil +} + func shiftParserRule(col string, amt string) (Expr, error) { n, err := strconv.Atoi(amt) diff --git a/pkg/corset/resolver.go b/pkg/corset/resolver.go index 9a58236a..e06ccc29 100644 --- a/pkg/corset/resolver.go +++ b/pkg/corset/resolver.go @@ -337,11 +337,13 @@ func (r *resolver) finalisePermutationAssignment(module uint, decl *DefPermutati // declared; secondly, to determine what each variable represents (i.e. column // access, a constant, etc). func (r *resolver) resolveConstraints(circuit *Circuit) []SyntaxError { - errs := r.resolveConstraintsInModule("", circuit.Declarations) + root := r.buildModuleScope("", circuit.Declarations) + errs := r.resolveConstraintsInModule(root, circuit.Declarations) // for _, m := range circuit.Modules { + module := r.buildModuleScope(m.Name, circuit.Declarations) // Process all declarations in the module - merrs := r.resolveConstraintsInModule(m.Name, m.Declarations) + merrs := r.resolveConstraintsInModule(module, m.Declarations) // Package up all errors errs = append(errs, merrs...) } @@ -349,31 +351,49 @@ func (r *resolver) resolveConstraints(circuit *Circuit) []SyntaxError { return errs } +func (r *resolver) buildModuleScope(name string, decls []Declaration) Scope { + var ( + scope *ModuleScope = r.env.NewModuleScope(name) + ) + // + for _, d := range decls { + // Look for defcolumns decalarations only + if c, ok := d.(*DefFun); ok { + // TODO: sanity check if function already declared. + scope.DeclareFunction(c.Name.Name, uint(len(c.Parameters)), c.Body) + } + } + // + return scope +} + // Helper for resolve constraints which considers those constraints declared in // a particular module. -func (r *resolver) resolveConstraintsInModule(module string, decls []Declaration) []SyntaxError { +func (r *resolver) resolveConstraintsInModule(enclosing Scope, decls []Declaration) []SyntaxError { var errors []SyntaxError - + // for _, d := range decls { // Look for defcolumns decalarations only if _, ok := d.(*DefColumns); ok { // Safe to ignore. } else if c, ok := d.(*DefConstraint); ok { - errors = append(errors, r.resolveDefConstraintInModule(module, c)...) + errors = append(errors, r.resolveDefConstraintInModule(enclosing, c)...) } else if c, ok := d.(*DefInRange); ok { - errors = append(errors, r.resolveDefInRangeInModule(module, c)...) + errors = append(errors, r.resolveDefInRangeInModule(enclosing, c)...) } else if _, ok := d.(*DefInterleaved); ok { // Nothing to do here, since this assignment form contains no // expressions to be resolved. } else if c, ok := d.(*DefLookup); ok { - errors = append(errors, r.resolveDefLookupInModule(module, c)...) + errors = append(errors, r.resolveDefLookupInModule(enclosing, c)...) } else if _, ok := d.(*DefPermutation); ok { // Nothing to do here, since this assignment form contains no // expressions to be resolved. + } else if c, ok := d.(*DefFun); ok { + errors = append(errors, r.resolveDefFunInModule(enclosing, c)...) } else if c, ok := d.(*DefProperty); ok { - errors = append(errors, r.resolveDefPropertyInModule(module, c)...) + errors = append(errors, r.resolveDefPropertyInModule(enclosing, c)...) } else { - errors = append(errors, *r.srcmap.SyntaxError(d, fmt.Sprintf("unknown declaration in module %s", module))) + errors = append(errors, *r.srcmap.SyntaxError(d, "unknown declaration")) } } // @@ -381,70 +401,86 @@ func (r *resolver) resolveConstraintsInModule(module string, decls []Declaration } // Resolve those variables appearing in either the guard or the body of this constraint. -func (r *resolver) resolveDefConstraintInModule(module string, decl *DefConstraint) []SyntaxError { +func (r *resolver) resolveDefConstraintInModule(enclosing Scope, decl *DefConstraint) []SyntaxError { var ( - errors []SyntaxError - context = tr.VoidContext() + errors []SyntaxError + scope = NewLocalScope(enclosing, false) ) // Resolve guard if decl.Guard != nil { - errors = r.resolveExpressionInModule(module, false, &context, decl.Guard) + errors = r.resolveExpressionInModule(scope, decl.Guard) } // Resolve constraint body - errors = append(errors, r.resolveExpressionInModule(module, false, &context, decl.Constraint)...) + errors = append(errors, r.resolveExpressionInModule(scope, decl.Constraint)...) // Done return errors } // Resolve those variables appearing in the body of this range constraint. -func (r *resolver) resolveDefInRangeInModule(module string, decl *DefInRange) []SyntaxError { +func (r *resolver) resolveDefInRangeInModule(enclosing Scope, decl *DefInRange) []SyntaxError { var ( - errors []SyntaxError - context = tr.VoidContext() + errors []SyntaxError + scope = NewLocalScope(enclosing, false) ) // Resolve property body - errors = append(errors, r.resolveExpressionInModule(module, false, &context, decl.Expr)...) + errors = append(errors, r.resolveExpressionInModule(scope, decl.Expr)...) + // Done + return errors +} + +// Resolve those variables appearing in the body of this function. +func (r *resolver) resolveDefFunInModule(enclosing Scope, decl *DefFun) []SyntaxError { + var ( + errors []SyntaxError + scope = NewLocalScope(enclosing, false) + ) + // Declare parameters in local scope + for _, p := range decl.Parameters { + scope.DeclareLocal(p.Name) + } + // Resolve property body + errors = append(errors, r.resolveExpressionInModule(scope, decl.Body)...) + // Remove parameters from enclosing environment // Done return errors } // Resolve those variables appearing in the body of this lookup constraint. -func (r *resolver) resolveDefLookupInModule(module string, decl *DefLookup) []SyntaxError { +func (r *resolver) resolveDefLookupInModule(enclosing Scope, decl *DefLookup) []SyntaxError { var ( - errors []SyntaxError - sourceContext = tr.VoidContext() - targetContext = tr.VoidContext() + errors []SyntaxError + sourceScope = NewLocalScope(enclosing, true) + targetScope = NewLocalScope(enclosing, true) ) // Resolve source expressions - errors = append(errors, r.resolveExpressionsInModule(module, true, &sourceContext, decl.Sources)...) + errors = append(errors, r.resolveExpressionsInModule(sourceScope, decl.Sources)...) // Resolve target expressions - errors = append(errors, r.resolveExpressionsInModule(module, true, &targetContext, decl.Targets)...) + errors = append(errors, r.resolveExpressionsInModule(targetScope, decl.Targets)...) // Done return errors } // Resolve those variables appearing in the body of this property assertion. -func (r *resolver) resolveDefPropertyInModule(module string, decl *DefProperty) []SyntaxError { +func (r *resolver) resolveDefPropertyInModule(enclosing Scope, decl *DefProperty) []SyntaxError { var ( - errors []SyntaxError - context = tr.VoidContext() + errors []SyntaxError + scope = NewLocalScope(enclosing, false) ) // Resolve property body - errors = append(errors, r.resolveExpressionInModule(module, false, &context, decl.Assertion)...) + errors = append(errors, r.resolveExpressionInModule(scope, decl.Assertion)...) // Done return errors } // Resolve a sequence of zero or more expressions within a given module. This // simply resolves each of the arguments in turn, collecting any errors arising. -func (r *resolver) resolveExpressionsInModule(module string, global bool, - context *tr.Context, args []Expr) []SyntaxError { +func (r *resolver) resolveExpressionsInModule(scope LocalScope, args []Expr) []SyntaxError { var errors []SyntaxError // Visit each argument for _, arg := range args { if arg != nil { - errs := r.resolveExpressionInModule(module, global, context, arg) + errs := r.resolveExpressionInModule(scope, arg) errors = append(errors, errs...) } } @@ -457,57 +493,77 @@ func (r *resolver) resolveExpressionsInModule(module string, global bool, // variable accesses. As above, the goal is ensure variable refers to something // that was declared and, more specifically, what kind of access it is (e.g. // column access, constant access, etc). -func (r *resolver) resolveExpressionInModule(module string, global bool, context *tr.Context, expr Expr) []SyntaxError { +func (r *resolver) resolveExpressionInModule(scope LocalScope, expr Expr) []SyntaxError { if _, ok := expr.(*Constant); ok { return nil } else if v, ok := expr.(*Add); ok { - return r.resolveExpressionsInModule(module, global, context, v.Args) + return r.resolveExpressionsInModule(scope, v.Args) } else if v, ok := expr.(*Exp); ok { - return r.resolveExpressionInModule(module, global, context, v.Arg) + return r.resolveExpressionInModule(scope, v.Arg) } else if v, ok := expr.(*IfZero); ok { - return r.resolveExpressionsInModule(module, global, context, []Expr{v.Condition, v.TrueBranch, v.FalseBranch}) + return r.resolveExpressionsInModule(scope, []Expr{v.Condition, v.TrueBranch, v.FalseBranch}) + } else if v, ok := expr.(*Invoke); ok { + return r.resolveInvokeInModule(scope, v) } else if v, ok := expr.(*List); ok { - return r.resolveExpressionsInModule(module, global, context, v.Args) + return r.resolveExpressionsInModule(scope, v.Args) } else if v, ok := expr.(*Mul); ok { - return r.resolveExpressionsInModule(module, global, context, v.Args) + return r.resolveExpressionsInModule(scope, v.Args) } else if v, ok := expr.(*Normalise); ok { - return r.resolveExpressionInModule(module, global, context, v.Arg) + return r.resolveExpressionInModule(scope, v.Arg) } else if v, ok := expr.(*Sub); ok { - return r.resolveExpressionsInModule(module, global, context, v.Args) + return r.resolveExpressionsInModule(scope, v.Args) } else if v, ok := expr.(*VariableAccess); ok { - return r.resolveVariableInModule(module, global, context, v) + return r.resolveVariableInModule(scope, v) } else { return r.srcmap.SyntaxErrors(expr, "unknown expression") } } +// Resolve a specific invocation contained within some expression which, in +// turn, is contained within some module. Note, qualified accesses are only +// permitted in a global context. +func (r *resolver) resolveInvokeInModule(scope LocalScope, expr *Invoke) []SyntaxError { + // Resolve arguments + if errors := r.resolveExpressionsInModule(scope, expr.Args); errors != nil { + return errors + } + // Lookup the corresponding function definition. + binding := scope.Bind(nil, expr.Name, true) + // Check what we got + if fnBinding, ok := binding.(*FunctionBinding); ok { + expr.Binding = fnBinding + return nil + } + // + return r.srcmap.SyntaxErrors(expr, "unknown function") +} + // Resolve a specific variable access contained within some expression which, in // turn, is contained within some module. Note, qualified accesses are only // permitted in a global context. -func (r *resolver) resolveVariableInModule(module string, global bool, context *tr.Context, +func (r *resolver) resolveVariableInModule(scope LocalScope, expr *VariableAccess) []SyntaxError { + // Will identify module of variable + //var module string = scope.EnclosingModule() + var mid *uint // Check whether this is a qualified access, or not. - if !global && expr.Module != nil { + if !scope.IsGlobal() && expr.Module != nil { return r.srcmap.SyntaxErrors(expr, "qualified access not permitted here") - } else if expr.Module != nil && !r.env.HasModule(*expr.Module) { + } else if expr.Module != nil && !scope.HasModule(*expr.Module) { return r.srcmap.SyntaxErrors(expr, fmt.Sprintf("unknown module %s", *expr.Module)) } else if expr.Module != nil { - module = *expr.Module + tmp := scope.Module(*expr.Module) + mid = &tmp } - // - mid := r.env.Module(module) // Attempt resolve as a column access in enclosing module - if cinfo, ok := r.env.LookupColumn(mid, expr.Name); ok { - ctx := tr.NewContext(mid, cinfo.multiplier) + if expr.Binding = scope.Bind(mid, expr.Name, false); expr.Binding != nil { // Update context - if *context = context.Join(ctx); context.IsConflicted() { + if !scope.FixContext(expr.Binding.Context()) { return r.srcmap.SyntaxErrors(expr, "conflicting context") } - // Register the binding to complete resolution. - expr.Binding = &Binder{true, ctx, cinfo.cid} // Done return nil } // Unable to resolve variable - return r.srcmap.SyntaxErrors(expr, fmt.Sprintf("unknown symbol in module %s", module)) + return r.srcmap.SyntaxErrors(expr, "unknown symbol") } diff --git a/pkg/corset/scope.go b/pkg/corset/scope.go new file mode 100644 index 00000000..79344907 --- /dev/null +++ b/pkg/corset/scope.go @@ -0,0 +1,180 @@ +package corset + +import ( + "fmt" + + tr "github.com/consensys/go-corset/pkg/trace" +) + +// Scope represents a region of code in which an expression can be evaluated. +// The purpose of a scope is to assist with determining what, exactly, a given +// variable used within an expression refers to. For example, a variable can +// refer to a column, or a parameter, etc. +type Scope interface { + // Get the name of the enclosing module. This is generally useful for + // reporting errors. + EnclosingModule() uint + // HasModule checks whether a given module exists, or not. + HasModule(string) bool + // Lookup the identifier for a given module. This assumes that the module + // exists, and will panic otherwise. + Module(string) uint + // Lookup a given variable being referenced with an optional module + // specifier. This variable could correspond to a column, a function, a + // parameter, or a local variable. Furthermore, the returned binding will + // be nil if this variable does not exist. + Bind(*uint, string, bool) Binding +} + +// ============================================================================= +// Module Scope +// ============================================================================= + +// ModuleScope represents the scope characterised by a module. +type ModuleScope struct { + // Module ID + module uint + // Provides access to global environment + environment *Environment + // Maps function names to their contents. + functions map[string]FunctionBinding +} + +// EnclosingModule returns the name of the enclosing module. This is generally +// useful for reporting errors. +func (p *ModuleScope) EnclosingModule() uint { + return p.module +} + +// HasModule checks whether a given module exists, or not. +func (p *ModuleScope) HasModule(module string) bool { + return p.environment.HasModule(module) +} + +// Module determines the module index for a given module. This assumes the +// module exists, and will panic otherwise. +func (p *ModuleScope) Module(module string) uint { + return p.environment.Module(module) +} + +// Bind looks up a given variable being referenced within a given module. For a +// root context, this is either a column, an alias or a function declaration. +func (p *ModuleScope) Bind(module *uint, name string, fn bool) Binding { + var mid uint + // Determine module for this lookup. + if module != nil { + mid = *module + } else { + mid = p.module + } + // Lookup function + if binding, ok := p.functions[name]; ok && module == nil { + return &binding + } else if info, ok := p.environment.LookupColumn(mid, name); ok && !fn { + ctx := tr.NewContext(mid, info.multiplier) + return &ColumnBinding{ctx, info.cid} + } + // error + return nil +} + +// DeclareFunction declares a given function within this module scope. +func (p *ModuleScope) DeclareFunction(name string, arity uint, body Expr) { + if _, ok := p.functions[name]; ok { + panic(fmt.Sprintf("attempt to redeclared function \"%s\"/%d", name, arity)) + } + // + p.functions[name] = FunctionBinding{arity, body} +} + +// ============================================================================= +// Local Scope +// ============================================================================= + +// LocalScope represents a simple implementation of scope in which local +// variables can be declared. A local scope must have a single context +// associated with it, and this will be inferred by resolving those expressions +// which must be evaluated within. +type LocalScope struct { + global bool + // Represents the enclosing scope + enclosing Scope + // Context for this scope + context *tr.Context + // Maps inputs parameters to the declaration index. + locals map[string]uint +} + +// NewLocalScope constructs a new local scope within a given enclosing scope. A +// local scope can have local variables declared within it. A local scope can +// also be "global" in the sense that accessing symbols from other modules is +// permitted. +func NewLocalScope(enclosing Scope, global bool) LocalScope { + context := tr.VoidContext() + locals := make(map[string]uint) + // + return LocalScope{global, enclosing, &context, locals} +} + +// NestedScope creates a nested scope within this local scope. +func (p LocalScope) NestedScope() LocalScope { + nlocals := make(map[string]uint) + // Clone allocated variables + for k, v := range p.locals { + nlocals[k] = v + } + // Done + return LocalScope{p.global, p.enclosing, p.context, nlocals} +} + +// IsGlobal determines whether symbols can be accessed in modules other than the +// enclosing module. +func (p LocalScope) IsGlobal() bool { + return p.global +} + +// EnclosingModule returns the name of the enclosing module. This is generally +// useful for reporting errors. +func (p LocalScope) EnclosingModule() uint { + return p.enclosing.EnclosingModule() +} + +// FixContext fixes the context for this scope. Since every scope requires +// exactly one context, this fails if we fix it to incompatible contexts. +func (p LocalScope) FixContext(context tr.Context) bool { + // Join contexts together + *p.context = p.context.Join(context) + // Check they were compatible + return !p.context.IsConflicted() +} + +// HasModule checks whether a given module exists, or not. +func (p LocalScope) HasModule(module string) bool { + return p.enclosing.HasModule(module) +} + +// Module determines the module index for a given module. This assumes the +// module exists, and will panic otherwise. +func (p LocalScope) Module(module string) uint { + return p.enclosing.Module(module) +} + +// Bind looks up a given variable or function being referenced either within the +// enclosing scope (module==nil) or within a specified module. +func (p LocalScope) Bind(module *uint, name string, fn bool) Binding { + // Check whether this is a local variable access. + if id, ok := p.locals[name]; ok && !fn && module == nil { + // Yes, this is a local variable access. + return &ParameterBinding{id} + } + // No, this is not a local variable access. + return p.enclosing.Bind(module, name, fn) +} + +// DeclareLocal registers a new local variable (e.g. a parameter). +func (p LocalScope) DeclareLocal(name string) uint { + index := uint(len(p.locals)) + p.locals[name] = index + // Return variable index + return index +} diff --git a/pkg/corset/translator.go b/pkg/corset/translator.go index fd62800a..e0493c5b 100644 --- a/pkg/corset/translator.go +++ b/pkg/corset/translator.go @@ -129,7 +129,7 @@ func (t *translator) translateAssignmentsAndConstraintsInModule(module string, d context := t.env.Module(module) // for _, d := range decls { - errs := t.translateAssignmentOrConstraint(d, context) + errs := t.translateDeclaration(d, context) errors = append(errors, errs...) } // Done @@ -138,13 +138,16 @@ func (t *translator) translateAssignmentsAndConstraintsInModule(module string, d // Translate an assignment or constraint declarartion which occurs within a // given module. -func (t *translator) translateAssignmentOrConstraint(decl Declaration, module uint) []SyntaxError { +func (t *translator) translateDeclaration(decl Declaration, module uint) []SyntaxError { var errors []SyntaxError // if _, ok := decl.(*DefColumns); ok { // Not an assignment or a constraint, hence ignore. } else if d, ok := decl.(*DefConstraint); ok { errors = t.translateDefConstraint(d, module) + } else if _, ok := decl.(*DefFun); ok { + // For now, functions are always compiled out when going down to HIR. + // In the future, this might change if we add support for macros to HIR. } else if d, ok := decl.(*DefInRange); ok { errors = t.translateDefInRange(d, module) } else if d, Ok := decl.(*DefInterleaved); Ok { @@ -352,6 +355,16 @@ func (t *translator) translateExpressionInModule(expr Expr, module uint) (hir.Ex } else if v, ok := expr.(*IfZero); ok { args, errs := t.translateExpressionsInModule([]Expr{v.Condition, v.TrueBranch, v.FalseBranch}, module) return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: args[2]}, errs + } else if e, ok := expr.(*Invoke); ok { + if e.Binding != nil && e.Binding.arity == uint(len(e.Args)) { + body := e.Binding.Apply(e.Args) + return t.translateExpressionInModule(body, module) + } else if e.Binding != nil { + msg := fmt.Sprintf("incorrect number of arguments (expected %d, found %d)", e.Binding.arity, len(e.Args)) + return nil, t.srcmap.SyntaxErrors(expr, msg) + } + // + return nil, t.srcmap.SyntaxErrors(expr, "unbound function") } else if v, ok := expr.(*List); ok { args, errs := t.translateExpressionsInModule(v.Args, module) return &hir.List{Args: args}, errs @@ -365,7 +378,11 @@ func (t *translator) translateExpressionInModule(expr Expr, module uint) (hir.Ex args, errs := t.translateExpressionsInModule(v.Args, module) return &hir.Sub{Args: args}, errs } else if e, ok := expr.(*VariableAccess); ok { - return &hir.ColumnAccess{Column: e.Binding.Index, Shift: e.Shift}, nil + if binding, ok := e.Binding.(*ColumnBinding); ok { + return &hir.ColumnAccess{Column: binding.ColumnID(), Shift: e.Shift}, nil + } + // error + return nil, t.srcmap.SyntaxErrors(expr, "unbound variable") } else { return nil, t.srcmap.SyntaxErrors(expr, "unknown expression") } diff --git a/pkg/hir/eval.go b/pkg/hir/eval.go index ec59ff4d..4012f646 100644 --- a/pkg/hir/eval.go +++ b/pkg/hir/eval.go @@ -15,12 +15,6 @@ func (e *ColumnAccess) EvalAllAt(k int, tr trace.Trace) []fr.Element { return []fr.Element{val} } -// EvalAllAt attempts to evaluate a variable access at a given row in a trace. -// However, at this time, that does not make sense. -func (e *VariableAccess) EvalAllAt(k int, tr trace.Trace) []fr.Element { - panic("unsupported operation") -} - // EvalAllAt evaluates a constant at a given row in a trace, which simply returns // that constant. func (e *Constant) EvalAllAt(k int, tr trace.Trace) []fr.Element { diff --git a/pkg/hir/expr.go b/pkg/hir/expr.go index 7a5a02ac..279ac932 100644 --- a/pkg/hir/expr.go +++ b/pkg/hir/expr.go @@ -415,47 +415,6 @@ func (p *Normalise) RequiredCells(row int, tr trace.Trace) *util.AnySortedSet[tr // does not perform any form of simplification to determine this. func (p *Normalise) AsConstant() *fr.Element { return nil } -// ============================================================================ -// VariableAccess -// ============================================================================ - -// VariableAccess represents reading the value of a given local variable (such -// as a function parameter). -type VariableAccess struct { - Name string - Shift int -} - -// Bounds returns max shift in either the negative (left) or positive -// direction (right). -func (p *VariableAccess) Bounds() util.Bounds { - panic("variable accesses do not have bounds") -} - -// Context determines the evaluation context (i.e. enclosing module) for this -// expression. -func (p *VariableAccess) Context(schema sc.Schema) trace.Context { - panic("variable accesses do not have a context") -} - -// RequiredColumns returns the set of columns on which this term depends. -// That is, columns whose values may be accessed when evaluating this term -// on a given trace. -func (p *VariableAccess) RequiredColumns() *util.SortedSet[uint] { - panic("unsupported operation") -} - -// RequiredCells returns the set of trace cells on which this term depends. -// In this case, that is the empty set. -func (p *VariableAccess) RequiredCells(row int, tr trace.Trace) *util.AnySortedSet[trace.CellRef] { - panic("unsupported operation") -} - -// AsConstant determines whether or not this is a constant expression. If -// so, the constant is returned; otherwise, nil is returned. NOTE: this -// does not perform any form of simplification to determine this. -func (p *VariableAccess) AsConstant() *fr.Element { return nil } - // ============================================================================ // ColumnAccess // ============================================================================ diff --git a/pkg/hir/lisp.go b/pkg/hir/lisp.go index cb557455..23756c9b 100644 --- a/pkg/hir/lisp.go +++ b/pkg/hir/lisp.go @@ -23,21 +23,6 @@ func (e *ColumnAccess) Lisp(schema sc.Schema) sexp.SExp { return sexp.NewList([]sexp.SExp{sexp.NewSymbol("shift"), access, shift}) } -// Lisp converts this schema element into a simple S-Expression, for example -// so it can be printed. -func (e *VariableAccess) Lisp(schema sc.Schema) sexp.SExp { - access := sexp.NewSymbol(e.Name) - // Check whether shifted (or not) - if e.Shift == 0 { - // Not shifted - return access - } - // Shifted - shift := sexp.NewSymbol(fmt.Sprintf("%d", e.Shift)) - - return sexp.NewList([]sexp.SExp{sexp.NewSymbol("shift"), access, shift}) -} - // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *Constant) Lisp(schema sc.Schema) sexp.SExp { diff --git a/pkg/hir/lower.go b/pkg/hir/lower.go index 949baf97..8df3efa9 100644 --- a/pkg/hir/lower.go +++ b/pkg/hir/lower.go @@ -113,13 +113,6 @@ func (e *ColumnAccess) LowerTo(schema *mir.Schema) []mir.Expr { return lowerTo(e, schema) } -// LowerTo lowers a variable access to the MIR level. This requires expanding -// the arguments, then lowering them. Furthermore, conditionals are "lifted" to -// the top. -func (e *VariableAccess) LowerTo(schema *mir.Schema) []mir.Expr { - return lowerTo(e, schema) -} - // LowerTo lowers an exponent expression to the MIR level. This requires expanding // the argument andn lowering it. Furthermore, conditionals are "lifted" to // the top. diff --git a/pkg/hir/macro.go b/pkg/hir/macro.go deleted file mode 100644 index f7ad5e6f..00000000 --- a/pkg/hir/macro.go +++ /dev/null @@ -1,18 +0,0 @@ -package hir - -// MacroDefinition represents something which can be called, and that will be -// inlined at the point of call. -type MacroDefinition struct { - // Enclosing module - module uint - // Name of the macro - name string - // Parameters of the macro - params []string - // Body of the macro - body Expr - // Indicates whether or not this macro is "pure". More specifically, pure - // macros can only refer to parameters (i.e. cannot access enclosing columns - // directly). - pure bool -} diff --git a/pkg/hir/schema.go b/pkg/hir/schema.go index 3373288a..0d8d5a32 100644 --- a/pkg/hir/schema.go +++ b/pkg/hir/schema.go @@ -51,9 +51,6 @@ type Schema struct { assertions []PropertyAssertion // Cache list of columns declared in inputs and assignments. column_cache []sc.Column - // Macros determines the set of macros which can be called within - // expressions, etc. - macros []*MacroDefinition } // EmptySchema is used to construct a fresh schema onto which new columns and @@ -146,15 +143,6 @@ func (p *Schema) AddPropertyAssertion(handle string, context trace.Context, prop p.assertions = append(p.assertions, sc.NewPropertyAssertion[ZeroArrayTest](handle, context, ZeroArrayTest{property})) } -// AddMacroDefinition adds a definition for a macro (either pure or impure). -func (p *Schema) AddMacroDefinition(module uint, name string, params []string, body Expr, pure bool) uint { - index := p.Macros().Count() - macro := &MacroDefinition{module, name, params, body, pure} - p.macros = append(p.macros, macro) - - return index -} - // ============================================================================ // Schema Interface // ============================================================================ @@ -204,11 +192,6 @@ func (p *Schema) Declarations() util.Iterator[sc.Declaration] { return inputs.Append(ps) } -// Macros returns an array over the macro definitions available in this schema. -func (p *Schema) Macros() util.Iterator[*MacroDefinition] { - return util.NewArrayIterator(p.macros) -} - // Modules returns an iterator over the declared set of modules within this // schema. func (p *Schema) Modules() util.Iterator[sc.Module] { diff --git a/pkg/test/invalid_corset_test.go b/pkg/test/invalid_corset_test.go index 9622bd22..9615de48 100644 --- a/pkg/test/invalid_corset_test.go +++ b/pkg/test/invalid_corset_test.go @@ -243,6 +243,37 @@ func Test_Invalid_Interleave_09(t *testing.T) { CheckInvalid(t, "interleave_invalid_09") } +// =================================================================== +// Functions +// =================================================================== + +func Test_Invalid_PureFun_01(t *testing.T) { + CheckInvalid(t, "purefun_invalid_01") +} + +func Test_Invalid_PureFun_02(t *testing.T) { + CheckInvalid(t, "purefun_invalid_02") +} + +func Test_Invalid_PureFun_03(t *testing.T) { + CheckInvalid(t, "purefun_invalid_03") +} + +/* + func Test_Invalid_PureFun_04(t *testing.T) { + CheckInvalid(t, "purefun_invalid_04") + } +*/ +func Test_Invalid_PureFun_05(t *testing.T) { + CheckInvalid(t, "purefun_invalid_05") +} + +/* + func Test_Invalid_PureFun_06(t *testing.T) { + CheckInvalid(t, "purefun_invalid_06") + } +*/ + // =================================================================== // Test Helpers // =================================================================== diff --git a/pkg/test/valid_corset_test.go b/pkg/test/valid_corset_test.go index 7feb1ced..3532ff76 100644 --- a/pkg/test/valid_corset_test.go +++ b/pkg/test/valid_corset_test.go @@ -482,14 +482,20 @@ func Test_Interleave_04(t *testing.T) { // Functions // =================================================================== -/* func Test_PureFun_01(t *testing.T) { +func Test_PureFun_01(t *testing.T) { Check(t, "purefun_01") } func Test_PureFun_02(t *testing.T) { Check(t, "purefun_02") } + +/* + func Test_PureFun_03(t *testing.T) { + Check(t, "purefun_03") + } */ + // =================================================================== // Complex Tests // =================================================================== diff --git a/testdata/purefun_03.lisp b/testdata/purefun_03.lisp new file mode 100644 index 00000000..b4083c79 --- /dev/null +++ b/testdata/purefun_03.lisp @@ -0,0 +1,3 @@ +(defcolumns A) +(defpurefun ((vanishes! :@loob) e) e) +(defconstraint test () (vanishes! A)) diff --git a/testdata/purefun_invalid_01.lisp b/testdata/purefun_invalid_01.lisp new file mode 100644 index 00000000..b33fe145 --- /dev/null +++ b/testdata/purefun_invalid_01.lisp @@ -0,0 +1,3 @@ +(defcolumns A) +;;(defpurefun (id x) x) +(defconstraint test () (id A)) diff --git a/testdata/purefun_invalid_02.lisp b/testdata/purefun_invalid_02.lisp new file mode 100644 index 00000000..c6a92174 --- /dev/null +++ b/testdata/purefun_invalid_02.lisp @@ -0,0 +1,3 @@ +(defcolumns A) +(defpurefun (id x) x) +(defconstraint test () (+ id A)) diff --git a/testdata/purefun_invalid_03.lisp b/testdata/purefun_invalid_03.lisp new file mode 100644 index 00000000..742d9e3f --- /dev/null +++ b/testdata/purefun_invalid_03.lisp @@ -0,0 +1,3 @@ +(defcolumns A) +(defpurefun (id x) x) +(defconstraint test () (id A A)) diff --git a/testdata/purefun_invalid_04.lisp b/testdata/purefun_invalid_04.lisp new file mode 100644 index 00000000..b7a1d3b6 --- /dev/null +++ b/testdata/purefun_invalid_04.lisp @@ -0,0 +1,4 @@ +(defcolumns A) +;; not pure! +(defpurefun (id x) (+ x A)) +(defconstraint test () (id 1)) diff --git a/testdata/purefun_invalid_05.lisp b/testdata/purefun_invalid_05.lisp new file mode 100644 index 00000000..8624a7dd --- /dev/null +++ b/testdata/purefun_invalid_05.lisp @@ -0,0 +1 @@ +(defpurefun (id x) (+ x y)) diff --git a/testdata/purefun_invalid_06.lisp b/testdata/purefun_invalid_06.lisp new file mode 100644 index 00000000..a1ba38a5 --- /dev/null +++ b/testdata/purefun_invalid_06.lisp @@ -0,0 +1,5 @@ +(defcolumns X) +;; recursive :) +(defpurefun (id x) (+ x (id x))) +;; infinite loop? +(defconstraint c1 () (id X))