diff --git a/pkg/corset/ast.go b/pkg/corset/ast.go index 7f6d838..3bc44b0 100644 --- a/pkg/corset/ast.go +++ b/pkg/corset/ast.go @@ -754,8 +754,6 @@ func (p *DefProperty) Lisp() sexp.SExp { // defined within its enclosing context. type DefFun struct { name string - // Specify whether is pure (or not) - pure bool // Parameters parameters []*DefParameter // diff --git a/pkg/corset/binding.go b/pkg/corset/binding.go index 89d4972..e7af4f1 100644 --- a/pkg/corset/binding.go +++ b/pkg/corset/binding.go @@ -134,6 +134,11 @@ func NewFunctionBinding(pure bool, paramTypes []sc.Type, returnType sc.Type, bod return FunctionBinding{pure, paramTypes, returnType, body} } +// IsPure checks whether this is a defpurefun or not +func (p *FunctionBinding) IsPure() bool { + return p.pure +} + // IsFinalised checks whether this binding has been finalised yet or not. func (p *FunctionBinding) IsFinalised() bool { return p.returnType != nil diff --git a/pkg/corset/parser.go b/pkg/corset/parser.go index a8b4063..6e89c60 100644 --- a/pkg/corset/parser.go +++ b/pkg/corset/parser.go @@ -628,9 +628,9 @@ func (p *Parser) parseDefFun(pure bool, elements []sexp.SExp) (Declaration, []Sy paramTypes[i] = p.DataType } // Construct binding - binding := NewFunctionBinding(true, paramTypes, ret, body) + binding := NewFunctionBinding(pure, paramTypes, ret, body) // - return &DefFun{name, pure, params, binding}, nil + return &DefFun{name, params, binding}, nil } func (p *Parser) parseFunctionSignature(elements []sexp.SExp) (string, sc.Type, []*DefParameter, []SyntaxError) { diff --git a/pkg/corset/resolver.go b/pkg/corset/resolver.go index b0b8096..67b75a7 100644 --- a/pkg/corset/resolver.go +++ b/pkg/corset/resolver.go @@ -250,7 +250,7 @@ func (r *resolver) declarationDependenciesAreFinalised(scope *ModuleScope, // Finalise a declaration. func (r *resolver) finaliseDeclaration(scope *ModuleScope, decl Declaration) []SyntaxError { if d, ok := decl.(*DefConst); ok { - return r.finaliseDefConstInModule(d) + return r.finaliseDefConstInModule(scope, d) } else if d, ok := decl.(*DefConstraint); ok { return r.finaliseDefConstraintInModule(scope, d) } else if d, ok := decl.(*DefFun); ok { @@ -273,10 +273,14 @@ func (r *resolver) finaliseDeclaration(scope *ModuleScope, decl Declaration) []S // Finalise one or more constant definitions within a given module. // Specifically, we need to check that the constant values provided are indeed // constants. -func (r *resolver) finaliseDefConstInModule(decl *DefConst) []SyntaxError { +func (r *resolver) finaliseDefConstInModule(enclosing Scope, decl *DefConst) []SyntaxError { var errors []SyntaxError // for _, c := range decl.constants { + scope := NewLocalScope(enclosing, false, true) + // Resolve constant body + errors = append(errors, r.finaliseExpressionInModule(scope, c.binding.value)...) + // Check it is indeed constant! if constant := c.binding.value.AsConstant(); constant == nil { err := r.srcmap.SyntaxError(c, "definition not constant") errors = append(errors, *err) @@ -292,7 +296,7 @@ func (r *resolver) finaliseDefConstInModule(decl *DefConst) []SyntaxError { func (r *resolver) finaliseDefConstraintInModule(enclosing Scope, decl *DefConstraint) []SyntaxError { var ( errors []SyntaxError - scope = NewLocalScope(enclosing, false) + scope = NewLocalScope(enclosing, false, false) ) // Resolve guard if decl.Guard != nil { @@ -385,7 +389,7 @@ func (r *resolver) finaliseDefPermutationInModule(decl *DefPermutation) []Syntax func (r *resolver) finaliseDefInRangeInModule(enclosing Scope, decl *DefInRange) []SyntaxError { var ( errors []SyntaxError - scope = NewLocalScope(enclosing, false) + scope = NewLocalScope(enclosing, false, false) ) // Resolve property body errors = append(errors, r.finaliseExpressionInModule(scope, decl.Expr)...) @@ -401,7 +405,7 @@ func (r *resolver) finaliseDefInRangeInModule(enclosing Scope, decl *DefInRange) func (r *resolver) finaliseDefFunInModule(enclosing Scope, decl *DefFun) []SyntaxError { var ( errors []SyntaxError - scope = NewLocalScope(enclosing, false) + scope = NewLocalScope(enclosing, false, decl.IsPure()) ) // Declare parameters in local scope for _, p := range decl.Parameters() { @@ -417,8 +421,8 @@ func (r *resolver) finaliseDefFunInModule(enclosing Scope, decl *DefFun) []Synta func (r *resolver) finaliseDefLookupInModule(enclosing Scope, decl *DefLookup) []SyntaxError { var ( errors []SyntaxError - sourceScope = NewLocalScope(enclosing, true) - targetScope = NewLocalScope(enclosing, true) + sourceScope = NewLocalScope(enclosing, true, false) + targetScope = NewLocalScope(enclosing, true, false) ) // Resolve source expressions errors = append(errors, r.finaliseExpressionsInModule(sourceScope, decl.Sources)...) @@ -432,7 +436,7 @@ func (r *resolver) finaliseDefLookupInModule(enclosing Scope, decl *DefLookup) [ func (r *resolver) finaliseDefPropertyInModule(enclosing Scope, decl *DefProperty) []SyntaxError { var ( errors []SyntaxError - scope = NewLocalScope(enclosing, false) + scope = NewLocalScope(enclosing, false, false) ) // Resolve property body errors = append(errors, r.finaliseExpressionInModule(scope, decl.Assertion)...) @@ -466,7 +470,11 @@ func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) []Syn } else if v, ok := expr.(*Add); ok { return r.finaliseExpressionsInModule(scope, v.Args) } else if v, ok := expr.(*Exp); ok { - return r.finaliseExpressionsInModule(scope, []Expr{v.Arg, v.Pow}) + purescope := scope.NestedPureScope() + arg_errs := r.finaliseExpressionInModule(scope, v.Arg) + pow_errs := r.finaliseExpressionInModule(purescope, v.Pow) + // combine errors + return append(arg_errs, pow_errs...) } else if v, ok := expr.(*IfZero); ok { return r.finaliseExpressionsInModule(scope, []Expr{v.Condition, v.TrueBranch, v.FalseBranch}) } else if v, ok := expr.(*Invoke); ok { @@ -478,7 +486,11 @@ func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) []Syn } else if v, ok := expr.(*Normalise); ok { return r.finaliseExpressionInModule(scope, v.Arg) } else if v, ok := expr.(*Shift); ok { - return r.finaliseExpressionsInModule(scope, []Expr{v.Arg, v.Shift}) + purescope := scope.NestedPureScope() + arg_errs := r.finaliseExpressionInModule(scope, v.Arg) + shf_errs := r.finaliseExpressionInModule(purescope, v.Shift) + // combine errors + return append(arg_errs, shf_errs...) } else if v, ok := expr.(*Sub); ok { return r.finaliseExpressionsInModule(scope, v.Args) } else if v, ok := expr.(*VariableAccess); ok { @@ -499,6 +511,8 @@ func (r *resolver) finaliseInvokeInModule(scope LocalScope, expr *Invoke) []Synt // Lookup the corresponding function definition. if !scope.Bind(expr) { return r.srcmap.SyntaxErrors(expr, "unknown function") + } else if scope.IsPure() && !expr.binding.IsPure() { + return r.srcmap.SyntaxErrors(expr, "not permitted in pure context") } // Success return nil @@ -522,6 +536,8 @@ func (r *resolver) finaliseVariableInModule(scope LocalScope, if binding, ok := expr.Binding().(*ColumnBinding); ok { if !scope.FixContext(binding.Context()) { return r.srcmap.SyntaxErrors(expr, "conflicting context") + } else if scope.IsPure() { + return r.srcmap.SyntaxErrors(expr, "not permitted in pure context") } } else if _, ok := expr.Binding().(*ConstantBinding); !ok { // Unable to resolve variable diff --git a/pkg/corset/scope.go b/pkg/corset/scope.go index 4495bbc..bc589d0 100644 --- a/pkg/corset/scope.go +++ b/pkg/corset/scope.go @@ -208,6 +208,9 @@ func (p *ModuleScope) Alias(alias string, symbol Symbol) bool { // which must be evaluated within. type LocalScope struct { global bool + // Determines whether or not this scope is "pure" (i.e. whether or not + // columns can be accessed, etc). + pure bool // Represents the enclosing scope enclosing Scope // Context for this scope @@ -220,11 +223,11 @@ type LocalScope struct { // local scope can have local variables declared within it. A local scope can // also be "global" in the sense that accessing symbols from other modules is // permitted. -func NewLocalScope(enclosing Scope, global bool) LocalScope { +func NewLocalScope(enclosing Scope, global bool, pure bool) LocalScope { context := tr.VoidContext[string]() locals := make(map[string]uint) // - return LocalScope{global, enclosing, &context, locals} + return LocalScope{global, pure, enclosing, &context, locals} } // NestedScope creates a nested scope within this local scope. @@ -235,7 +238,19 @@ func (p LocalScope) NestedScope() LocalScope { nlocals[k] = v } // Done - return LocalScope{p.global, p.enclosing, p.context, nlocals} + return LocalScope{p.global, p.pure, p, p.context, nlocals} +} + +// NestedPureScope creates a nested scope within this local scope which, in +// addition, is always pure. +func (p LocalScope) NestedPureScope() LocalScope { + nlocals := make(map[string]uint) + // Clone allocated variables + for k, v := range p.locals { + nlocals[k] = v + } + // Done + return LocalScope{p.global, true, p, p.context, nlocals} } // IsGlobal determines whether symbols can be accessed in modules other than the @@ -244,6 +259,13 @@ func (p LocalScope) IsGlobal() bool { return p.global } +// IsPure determines whether or not this scope is pure. That is, whether or not +// expressions in this scope are permitted to access columns (either directly, +// or indirectly via impure invocations). +func (p LocalScope) IsPure() bool { + return p.pure +} + // FixContext fixes the context for this scope. Since every scope requires // exactly one context, this fails if we fix it to incompatible contexts. func (p LocalScope) FixContext(context Context) bool { diff --git a/pkg/test/invalid_corset_test.go b/pkg/test/invalid_corset_test.go index 7a30c77..4880f9f 100644 --- a/pkg/test/invalid_corset_test.go +++ b/pkg/test/invalid_corset_test.go @@ -128,6 +128,18 @@ func Test_Invalid_Constant_14(t *testing.T) { CheckInvalid(t, "constant_invalid_14") } +func Test_Invalid_Constant_15(t *testing.T) { + CheckInvalid(t, "constant_invalid_15") +} + +func Test_Invalid_Constant_16(t *testing.T) { + CheckInvalid(t, "constant_invalid_16") +} + +func Test_Invalid_Constant_17(t *testing.T) { + CheckInvalid(t, "constant_invalid_17") +} + // =================================================================== // Alias Tests // =================================================================== @@ -373,11 +385,10 @@ func Test_Invalid_PureFun_03(t *testing.T) { CheckInvalid(t, "purefun_invalid_03") } -/* - func Test_Invalid_PureFun_04(t *testing.T) { - CheckInvalid(t, "purefun_invalid_04") - } -*/ +func Test_Invalid_PureFun_04(t *testing.T) { + CheckInvalid(t, "purefun_invalid_04") +} + func Test_Invalid_PureFun_05(t *testing.T) { CheckInvalid(t, "purefun_invalid_05") } @@ -388,6 +399,10 @@ func Test_Invalid_PureFun_05(t *testing.T) { } */ +func Test_Invalid_PureFun_07(t *testing.T) { + CheckInvalid(t, "purefun_invalid_07") +} + // =================================================================== // Test Helpers // =================================================================== diff --git a/testdata/constant_invalid_15.lisp b/testdata/constant_invalid_15.lisp new file mode 100644 index 0000000..1b56aef --- /dev/null +++ b/testdata/constant_invalid_15.lisp @@ -0,0 +1,2 @@ +(defun (ONE) 1) +(defconst X (ONE)) diff --git a/testdata/constant_invalid_16.lisp b/testdata/constant_invalid_16.lisp new file mode 100644 index 0000000..4d19344 --- /dev/null +++ b/testdata/constant_invalid_16.lisp @@ -0,0 +1,3 @@ +(defcolumns X) +(defun (ONE) 1) +(defconstraint c1 () (* X (^ 2 (ONE)))) diff --git a/testdata/constant_invalid_17.lisp b/testdata/constant_invalid_17.lisp new file mode 100644 index 0000000..68148bb --- /dev/null +++ b/testdata/constant_invalid_17.lisp @@ -0,0 +1,3 @@ +(defcolumns X) +(defun (ONE) 1) +(defconstraint c1 () (shift X (ONE))) diff --git a/testdata/purefun_invalid_07.lisp b/testdata/purefun_invalid_07.lisp new file mode 100644 index 0000000..73cbc05 --- /dev/null +++ b/testdata/purefun_invalid_07.lisp @@ -0,0 +1,5 @@ +(defcolumns A) +(defun (getA) A) +;; not pure! +(defpurefun (id x) (+ x (getA))) +(defconstraint test () (id 1))