Skip to content

Commit

Permalink
feat: support reduce (#445)
Browse files Browse the repository at this point in the history
* Add test cases for reduce

* Support reduce operator

This adds support for the reduce operator, which performsing folding
over binary functions.

* Support intrinsics

This adds support for intrinsics which are needed for reductions
(amongst other things).  Specifically, they are required for reductions
which operate over built-in operators, such as "+", "-", etc.  At this
stage, only a few intrinsics have been added and I'm anticipating that
more will be required at some point.
  • Loading branch information
DavePearce authored Dec 17, 2024
1 parent 977366d commit 572dff1
Show file tree
Hide file tree
Showing 34 changed files with 616 additions and 150 deletions.
57 changes: 43 additions & 14 deletions pkg/corset/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,26 @@ type Binding interface {
IsFinalised() bool
}

// FunctionBinding is a special kind of binding which captures the essence of
// something which can be called. For example, this could be a user-defined
// function or an intrinsic.
type FunctionBinding interface {
Binding
// IsPure checks whether this function binding has side-effects or not.
IsPure() bool
// HasArity checks whether this binding supports a given number of
// parameters. For example, intrinsic functions are often nary --- meaning
// they can accept any number of arguments. In contrast, a user-defined
// function may only accept a specific number of arguments, etc.
HasArity(uint) bool
// Apply a set of concreate arguments to this function. This substitutes
// them through the body of the function producing a single expression.
Apply([]Expr) Expr
// Get the declared return type of this function, or nil if no return type
// was declared.
ReturnType() Type
}

// ============================================================================
// ColumnBinding
// ============================================================================
Expand Down Expand Up @@ -158,12 +178,12 @@ func (p *LocalVariableBinding) Finalise(index uint) {
}

// ============================================================================
// FunctionBinding
// DefunBinding
// ============================================================================

// FunctionBinding represents the binding of a function application to its
// physical definition.
type FunctionBinding struct {
// DefunBinding is a function binding arising from a user-defined function (as
// opposed, for example, to a function binding arising from an intrinsic).
type DefunBinding struct {
// Flag whether or not is pure function
pure bool
// Types of parameters (optional)
Expand All @@ -177,33 +197,42 @@ type FunctionBinding struct {
body Expr
}

// NewFunctionBinding constructs a new function binding.
func NewFunctionBinding(pure bool, paramTypes []Type, returnType Type, body Expr) FunctionBinding {
return FunctionBinding{pure, paramTypes, returnType, nil, body}
var _ FunctionBinding = &DefunBinding{}

// NewDefunBinding constructs a new function binding.
func NewDefunBinding(pure bool, paramTypes []Type, returnType Type, body Expr) DefunBinding {
return DefunBinding{pure, paramTypes, returnType, nil, body}
}

// IsPure checks whether this is a defpurefun or not
func (p *FunctionBinding) IsPure() bool {
func (p *DefunBinding) IsPure() bool {
return p.pure
}

// IsFinalised checks whether this binding has been finalised yet or not.
func (p *FunctionBinding) IsFinalised() bool {
func (p *DefunBinding) IsFinalised() bool {
return p.bodyType != nil
}

// Arity returns the number of parameters that this function accepts.
func (p *FunctionBinding) Arity() uint {
return uint(len(p.paramTypes))
// HasArity checks whether this function accepts a given number of arguments (or
// not).
func (p *DefunBinding) HasArity(arity uint) bool {
return arity == uint(len(p.paramTypes))
}

// ReturnType gets the declared return type of this function, or nil if no
// return type was declared.
func (p *DefunBinding) ReturnType() Type {
return p.returnType
}

// Finalise this binding by providing the necessary missing information.
func (p *FunctionBinding) Finalise(bodyType Type) {
func (p *DefunBinding) Finalise(bodyType Type) {
p.bodyType = bodyType
}

// Apply a given set of arguments to this function binding.
func (p *FunctionBinding) Apply(args []Expr) Expr {
func (p *DefunBinding) Apply(args []Expr) Expr {
mapping := make(map[uint]Expr)
// Setup the mapping
for i, e := range args {
Expand Down
2 changes: 1 addition & 1 deletion pkg/corset/declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ type DefFun struct {
// Parameters
parameters []*DefParameter
//
binding FunctionBinding
binding DefunBinding
}

// IsFunction is always true for a function definition!
Expand Down
155 changes: 77 additions & 78 deletions pkg/corset/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ func (e *ArrayAccess) Context() Context {
// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.
func (e *ArrayAccess) Lisp() sexp.SExp {
panic("todo")
return sexp.NewArray([]sexp.SExp{
sexp.NewSymbol(e.name),
e.arg.Lisp(),
})
}

// Substitute all variables (such as for function parameters) arising in
Expand Down Expand Up @@ -461,65 +464,23 @@ func (e *If) Dependencies() []Symbol {

// Invoke represents an attempt to invoke a given function.
type Invoke struct {
module *string
name string
args []Expr
binding *FunctionBinding
fn *VariableAccess
args []Expr
}

// AsConstant attempts to evaluate this expression as a constant (signed) value.
// If this expression is not constant, then nil is returned.
func (e *Invoke) AsConstant() *big.Int {
if e.binding == nil {
if e.fn.binding == nil {
panic("unresolved invocation")
} else if fn_binding, ok := e.fn.binding.(FunctionBinding); ok {
// Unroll body
body := fn_binding.Apply(e.args)
// Attempt to evaluate as constant
return body.AsConstant()
}
// Unroll body
body := e.binding.Apply(e.args)
// Attempt to evaluate as constant
return body.AsConstant()
}

// IsQualified determines whether this symbol is qualfied or not (i.e. has an
// explicitly module specifier).
func (e *Invoke) IsQualified() bool {
return e.module != nil
}

// IsFunction indicates whether or not this symbol refers to a function (which
// of course it always does).
func (e *Invoke) IsFunction() bool {
return true
}

// IsResolved checks whether this symbol has been resolved already, or not.
func (e *Invoke) IsResolved() bool {
return e.binding != nil
}

// Resolve this symbol by associating it with the binding associated with
// the definition of the symbol to which this refers.
func (e *Invoke) Resolve(binding Binding) bool {
if fb, ok := binding.(*FunctionBinding); ok {
e.binding = fb
return true
}
// Problem
return false
}

// Module returns the optional module qualification. This will panic if this
// invocation is unqualified.
func (e *Invoke) Module() string {
if e.module == nil {
panic("invocation has no module qualifier")
}

return *e.module
}

// Name of the function being invoked.
func (e *Invoke) Name() string {
return e.name
// Just fail
return nil
}

// Args returns the arguments provided by this invocation to the function being
Expand All @@ -528,24 +489,10 @@ func (e *Invoke) Args() []Expr {
return e.args
}

// Binding gets binding associated with this interface. This will panic if this
// symbol is not yet resolved.
func (e *Invoke) Binding() Binding {
if e.binding == nil {
panic("invocation not yet resolved")
}

return e.binding
}

// Context returns the context for this expression. Observe that the
// expression must have been resolved for this to be defined (i.e. it may
// panic if it has not been resolved yet).
func (e *Invoke) Context() Context {
if e.binding == nil {
panic("unresolved expressions encountered whilst resolving context")
}
// TODO: impure functions can have their own context.
return ContextOfExpressions(e.args)
}

Expand All @@ -559,28 +506,21 @@ func (e *Invoke) Multiplicity() uint {
// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.
func (e *Invoke) Lisp() sexp.SExp {
var fn sexp.SExp
if e.module != nil {
fn = sexp.NewSymbol(fmt.Sprintf("%s.%s", *e.module, e.name))
} else {
fn = sexp.NewSymbol(e.name)
}

return ListOfExpressions(fn, e.args)
return ListOfExpressions(e.fn.Lisp(), e.args)
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Invoke) Substitute(mapping map[uint]Expr) Expr {
return &Invoke{e.module, e.name, SubstituteExpressions(e.args, mapping), e.binding}
return &Invoke{e.fn, SubstituteExpressions(e.args, mapping)}
}

// Dependencies needed to signal declaration.
func (e *Invoke) Dependencies() []Symbol {
deps := DependenciesOfExpressions(e.args)
// Include this expression as a symbol (which must be bound to the function
// being invoked)
return append(deps, e)
return append(deps, e.fn)
}

// ============================================================================
Expand Down Expand Up @@ -718,6 +658,60 @@ func (e *Normalise) Dependencies() []Symbol {
return e.Arg.Dependencies()
}

// ============================================================================
// Reduction
// ============================================================================

// Reduce reduces (i.e. folds) a list using a given binary function.
type Reduce struct {
fn *VariableAccess
arg Expr
}

// AsConstant attempts to evaluate this expression as a constant (signed) value.
// If this expression is not constant, then nil is returned.
func (e *Reduce) AsConstant() *big.Int {
// TODO: potentially we can do better here.
return nil
}

// Multiplicity determines the number of values that evaluating this expression
// can generate.
func (e *Reduce) Multiplicity() uint {
return 1
}

// Context returns the context for this expression. Observe that the
// expression must have been resolved for this to be defined (i.e. it may
// panic if it has not been resolved yet).
func (e *Reduce) Context() Context {
return e.arg.Context()
}

// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.
func (e *Reduce) Lisp() sexp.SExp {
return sexp.NewList([]sexp.SExp{
sexp.NewSymbol("reduce"),
sexp.NewSymbol(e.fn.name),
e.arg.Lisp()})
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Reduce) Substitute(mapping map[uint]Expr) Expr {
return &Reduce{
e.fn,
e.arg.Substitute(mapping),
}
}

// Dependencies needed to signal declaration.
func (e *Reduce) Dependencies() []Symbol {
deps := e.arg.Dependencies()
return append(deps, e.fn)
}

// ============================================================================
// Subtraction
// ============================================================================
Expand Down Expand Up @@ -828,6 +822,7 @@ func (e *Shift) Dependencies() []Symbol {
type VariableAccess struct {
module *string
name string
fn bool
binding Binding
}

Expand All @@ -850,7 +845,7 @@ func (e *VariableAccess) IsQualified() bool {
// IsFunction determines whether this symbol refers to a function (which, of
// course, variable accesses never do).
func (e *VariableAccess) IsFunction() bool {
return false
return e.fn
}

// IsResolved checks whether this symbol has been resolved already, or not.
Expand All @@ -865,6 +860,10 @@ func (e *VariableAccess) Resolve(binding Binding) bool {
panic("empty binding")
} else if e.binding != nil {
panic("already resolved")
} else if _, ok := binding.(FunctionBinding); ok && !e.fn {
return false
} else if _, ok := binding.(FunctionBinding); !ok && e.fn {
return false
}
//
e.binding = binding
Expand Down
Loading

0 comments on commit 572dff1

Please sign in to comment.