Skip to content

Commit

Permalink
implemented cut
Browse files Browse the repository at this point in the history
  • Loading branch information
billhails committed Oct 14, 2023
1 parent 7f11ea5 commit bea9e68
Show file tree
Hide file tree
Showing 20 changed files with 1,329 additions and 38 deletions.
168 changes: 168 additions & 0 deletions docs/lambda-conversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,172 @@ Note the equality check has to be part of the DFA, because we don't yet
know which branch we're on so can't start binding variables that will
differ between branches.

## New Approach

None of the above pans out, there are issues with variable binding confusing the NFA to DFA converter. You can see how far I got in prototyping this in [prototyping/DFATree.py](../prototyping/DFATree.py).

The new approach is to do what Haskell perportedly does, which is to check each pattern in turn. This is certainly simpler, and we can use a beefed-up variant of `amb` to help out by backtracking when a branch doesn't match.

The basic skeleton of the generated code should be something like:

```scheme
(amb (<check branch 1> <body 1>)
(amb (<check branch 2> <body 2>)
(amb (<check branch 3> <body 3>)
(error "patterns exhausted"))))
```

Each branch can be generated completely independantly of the others, and need only do a `(back)` if the args fail to match.

However there is a little problem, if such a function is backtracked through
for other reasons (normal use of `amb` by other code) then a subsequent branch
would be attempted in error.

How Prolog addresses this problem, where a branch once determined should be committed to, is to use a mechanism called a "green cut". This ensures that when a function body is backtracked out of, the entire function is backtracked out of.

We can achieve the same by having a new version of `amb` which I'm calling `escape`, and a new version of `back` which I'm calling `cut`.

`escape` is just like `amb` except it takes only one argument expression, and when backtracked to from that expression, it itself backtracks.

`cut` is a little different from a normal `back` though, when invoked, instead of just restoring the previous failure continuation, it repeatedly restores the previous continuation unlil it encounters an `escape` continuation.

The skeleton now becomes:

```scheme
(escape
(amb (<check branch 1> (amb <body 1> (cut))
(amb (<check branch 2> (amb <body 2> (cut))
(amb (<check branch 3> (amb <body 3> (cut))
(error "patterns exhausted")))))
```

So if for example downstream code attempts to backtrack through `<body 1>`, the innermost `amb` will catch the failure and invoke `cut`, which will jump over the outer `amb` to the `escape` which will then continue to backtrack out of the entire compound function.

Let's look at some concrete examples converting familiar functions to lambdas,
first here's `map`

```scheme
; fn map {
; (_, []) { [] }
; (f, h @ t) { f(h) @ map(f, t) }
; }
(define map
(lambda ($1 $2)
(escape
(amb (match $2
(nil (amb nil (cut)))
(pair (back))
(amb (let (f $1)
(match $2
(nil (let (h (field $2 0))
(let (t (field $2 1))
(amb (pair (f h) (map f t)) (cut)))))
(pair (back))))
(error "patterns exhausted in function map")))))))
```

Arguments to the lambda are bound to generated symbols, which shouldn't be
lexically symbols so they can't conflict (dollar-prefix should do it).

Then the entire compound is wrapped in an escape, and a nest of `amb`s check
each branch.

`match` is a simple exhaustive case statement for types.

`field` extracts a zero-indexed field from a compound structure like `pair`.

Nested `let`s bind variables appropriately, and the body of the function is constructed within those `let` bindings, wrapped in an `amb` with a trailing `cut`.

next let's look at `member`.

```scheme
; fn member {
; (_, []) { false }
; (x, x @ _) { true }
; (x, _ @ t) { member(x, t) }
; }
(define member
(lambda ($1 $2)
(escape
(amb (match $2
(0 (amb false (cut)))
(1 (back)))
(amb (let (x $1)
(match $2
(1 (if (eq x (field $2 0))
(amb true (cut))
(back)))
(0 (back))))
(amb (let (x $1)
(match $2
(1 (let (t (field $2 1))
(amb (member x t) (cut))))
(0 (back))))
(error "patterns exhausted in function member")))))))
```

Much the same process, The additional wrinkle is the comparison of the second
binding of `x` in the true branch, rather than just binding `x`.

There's rather an accumulation of failure continuations using this approach,
a function call not otherwise using `amb` now costs 2 failure continuations
that are likely never invoked if the application makes no use of `amb`.

Maybe there's a less costly way.

## Refinement

Leave `escape` as is, but change the behaviour of `cut`. Have `cut` now take an expression to evaluate, and *before* evaluating it, peel away all the failure
continuations up to and including the `escape`.

So instead of:

```scheme
(amb (member x t) (cut))
```

We just need:

```scheme
(cut (member x t))
```

In fact we can gain a bit more efficiency still by having `escape` merely
tag the current failure continuation, then `cut` peels back to leave that
continuation, un-tagging it instead of removing it. Any downstream backtracking
from the argument to `cut` will hit that continuation.


In fact, we probably don't even need `escape`. If the use of `cut` is restricted
to this specific situation, there will only ever be one failure continuation
installed for pattern matching, and `cut` merely restores the previous one.

## Changes to "The Math"

Complex expressions now include `cut`

$$
\begin{array}{rcl}
\mathtt{cexp} &::=& \mathtt{(aexp_0\\ aexp_1\dots aexp_n)}
\\
&|& \mathtt{(if\\ aexp\\ exp\\ exp)}
\\
&|& \mathtt{(call/cc\\ aexp)}
\\
&|& \mathtt{(letrec\\ ((var_1\\ aexp_1)\dots(var_n\\ aexp_n))\\ exp)}
\\
&|& \mathtt{(amb\\ exp\\ exp)}
\\
&|& \mathtt{(cut\\ exp)}
\\
&|& \mathtt{(back)}
\end{array}
$$

`cut` pops the topmost failure continuation and arranges for its argument to be evaluated. It would be an error if `cut` was invoked without a failure continuation in place:

$$
step(\mathtt{(cut\ exp)}, \rho, \kappa, \mathbf{backtrack}(\mathtt{exp'}, \rho', \kappa', f) = (\mathtt{exp}, \rho, \kappa, f))
$$

That's it. We won't expose `cut` as a language feature because its use is purely internal to the implementation.
110 changes: 76 additions & 34 deletions prototyping/DFATree.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def notePosition(self, i):


class NumericFarg(Farg):
"""
literal number
"""
def __init__(self, n):
self.n = n

Expand All @@ -39,6 +42,9 @@ def __str__(self):


class VecFarg(Farg):
"""
vector type like pair or maybe
"""
def __init__(self, label, *fields):
self.label = label
self.fields = fields
Expand All @@ -56,6 +62,9 @@ def __str__(self):


class VarFarg(Farg):
"""
variable
"""
def __init__(self, name):
self.name = name

Expand All @@ -66,18 +75,29 @@ def __str__(self):
return self.name


class ComparisonFarg(Farg):
def __init__(self, name):
class AssignmentFarg(Farg):
"""
name = value
"""
def __init__(self, name, value):
self.name = name
self.value = value

def accept(self, visitor, state):
return visitor.visitComparisonFarg(self, state)
return visitor.visitAssignmentFarg(self, state)

def __str__(self):
return self.name
return self.name + '=' + str(self.value)

def notePosition(self, i):
super().notePosition(i)
self.value.notePosition(i)


class WildcardFarg(Farg):
"""
Wildcard _
"""
def accept(self, visitor, state):
return visitor.visitWildcardFarg(self, state)

Expand Down Expand Up @@ -553,7 +573,7 @@ def __eq__(self, other):
return isinstance(other, DfaVarTransition) and self.var == other.var

def __str__(self):
return 'bind ' + self.var
return 'unify ' + self.var

def isWildcard(self):
return True
Expand All @@ -562,15 +582,19 @@ def isConditional(self):
return True


class DfaWildcardTransition(DfaTransition):
class DfaAssignmentTransition(DfaTransition):
def __init__(self, var):
super().__init__()
self.var = var

def __hash__(self):
return hash(('var', '*'))
return hash(('var', self.var))

def __eq__(self, other):
return isinstance(other, DfaWildcardTransition)
return isinstance(other, DfaAssignmentTransition) and self.var == other.var

def __str__(self):
return 'else'
return 'unify ' + self.var

def isWildcard(self):
return True
Expand All @@ -579,19 +603,18 @@ def isConditional(self):
return True


class DfaComparisonTransition(DfaTransition):
def __init__(self, var):
super().__init__()
self.var = var

class DfaWildcardTransition(DfaTransition):
def __hash__(self):
return hash(('cmp', self.var))
return hash(('var', '*'))

def __eq__(self, other):
return isinstance(other, DfaComparisonTransition) and self.var == other.var
return isinstance(other, DfaWildcardTransition)

def __str__(self):
return f'=={str(self.var)}'
return 'else'

def isWildcard(self):
return True

def isConditional(self):
return True
Expand Down Expand Up @@ -716,36 +739,36 @@ def key(self):
return DfaArgTransition(self.index)


class NfaVarTransition(NfaTransition):
class NfaAssignmentTransition(NfaTransition):
def __init__(self, name, to):
super().__init__(to)
self.name = name

def __str__(self):
return 'bind ' + self.name + ' ' + str(self.to)
return 'unify ' + self.name + ' ' + str(self.to)

def key(self):
return DfaVarTransition(self.name)
return DfaAssignmentTransition(self.name)


class NfaWildcardTransition(NfaTransition):
class NfaVarTransition(NfaTransition):
def __init__(self, name, to):
super().__init__(to)
self.name = name

def __str__(self):
return '(=*) ' + str(self.to)
return 'unify ' + self.name + ' ' + str(self.to)

def key(self):
return DfaWildcardTransition()
return DfaVarTransition(self.name)


class NfaComparisonTransition(NfaTransition):
def __init__(self, name, to):
super().__init__(to)
self.name = name

class NfaWildcardTransition(NfaTransition):
def __str__(self):
return f'(={self.name}) ' + str(self.to)
return '(=*) ' + str(self.to)

def key(self):
return DfaComparisonTransition(self.name)
return DfaWildcardTransition()


class FargToNfaVisitor:
Expand Down Expand Up @@ -780,12 +803,12 @@ def recursivelyVisitVec(self, fields, count, finalState):
def visitVarFarg(self, var, state):
return NfaState([NfaVarTransition(var.name, state)])

def visitAssignmentFarg(self, assignment, state):
return NfaState([NfaAssignmentTransition(assignment.name, assignment.value.accept(self, state))])

def visitWildcardFarg(self, wildcard, state):
return NfaState([NfaWildcardTransition(state)])

def visitComparisonFarg(self, comparison, state):
return NfaState([NfaComparisonTransition(comparison.name, state)])


def makeMermaid(args):
print('```plaintext')
Expand All @@ -809,7 +832,7 @@ def makeMermaid(args):

memberArgs = Compound(
Fargs('false', VarFarg('x'), VecFarg('nil')),
Fargs('true', VarFarg('x'), VecFarg('cons', ComparisonFarg('x'), WildcardFarg())),
Fargs('true', VarFarg('x'), VecFarg('cons', VarFarg('x'), WildcardFarg())),
Fargs('continue', VarFarg('x'), VecFarg('cons', WildcardFarg(), VarFarg('t')))
)
makeMermaid(memberArgs)
Expand All @@ -823,3 +846,22 @@ def makeMermaid(args):
)
makeMermaid(testArgs)

"""
Problems
--------
Because we're matching the args for all functions at once, we can't pay attention to variable names,
we have to assume the variables are bound by a previous (or subsequent) process.
That in turn means we can't handle assignment args (x=[1, 2] etc.) because x won't be bound.
And there are problems with matching common values too, like in member:
fn member {
(_, []) { false }
(x, x @ y) { true }
(y, _ @ t) { member(y, t) }
}
The parallel matching would have to bind y as well as x to the first argument, but y shouldn't be bound
in the true case because types are different and it would become a comparison with the second appearence
of y in the true branch.
"""
Loading

0 comments on commit bea9e68

Please sign in to comment.