Skip to content

Commit

Permalink
Merge pull request #1026 from herbie-fp/zane-true-error
Browse files Browse the repository at this point in the history
Concrete Values for Odyssey
  • Loading branch information
zaneenders authored Nov 14, 2024
2 parents 6cd3051 + 6ed1a1d commit 1a83c9f
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 118 deletions.
109 changes: 105 additions & 4 deletions infra/testApi.mjs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { strict as assert } from 'node:assert'; // use strict equality everywhere

// Future TODO: before this API becomes set in stone/offered publicly, we should change the results of these methods to be just the output data rather than duplicating input values.

// Reusable testing data
const SAMPLE_SIZE = 8000
const FPCoreFormula = '(FPCore (x) (- (sqrt (+ x 1)) (sqrt x)))'
const FPCoreFormula2 = '(FPCore (x) (- (sqrt (+ x 1))))'
const FPCoreFormula3 = '(FPCore (x) (if (<= (- (sqrt (+ x 1.0)) (sqrt x)) 0.05) (* 0.5 (sqrt (/ 1.0 x))) (fma (fma (- 0.125) x 0.5) x (- 1.0 (sqrt x)))))'
const eval_sample = [[[1], -1.4142135623730951]]

// improve endpoint
Expand Down Expand Up @@ -90,7 +90,7 @@ assert.deepEqual(sample.points[1], sample2.points[1])
const explainBody = {
method: 'POST',
body: JSON.stringify({
formula: FPCoreFormula, sample: sample2.points
formula: FPCoreFormula, sample: sample.points
})
}
const explain = await (await fetch(makeEndpoint("/api/explanations"), explainBody)).json()
Expand Down Expand Up @@ -144,7 +144,7 @@ assert.deepEqual(calculateAsyncResult.points, [[[1], -1.4142135623730951]])
// Local error endpoint
const localErrorBody = {
method: 'POST', body: JSON.stringify({
formula: FPCoreFormula, sample: sample2.points
formula: FPCoreFormula, sample: sample.points
})
}
const localError = await (await fetch(makeEndpoint("/api/localerror"), localErrorBody)).json()
Expand All @@ -166,6 +166,107 @@ const localError2 = await (await fetch(makeEndpoint("/api/localerror"), {
})).json()
// Test that different sample points produce different job ids ensuring that different results are served for these inputs.
assert.notEqual(localError1.job, localError2.job)
// Assert local error works for default example.
const ignoredValue = 1e+308
'(FPCore (1e-100) (- (sqrt (+ x 1)) (sqrt x)))'
const localError5 = await (await fetch(makeEndpoint("/api/localerror"), {
method: 'POST', body: JSON.stringify({
formula: FPCoreFormula, sample: [[[1e-100], ignoredValue]], seed: 5
})
})).json()

// avg_error, actual_value, exact_value, absolute_difference, ulps_error
// root node
checkLocalErrorNode(localError5.tree, [],
'-', '0.0', '1.0', '1.0', '1e-50', '0.0')
// left sqrt
checkLocalErrorNode(localError5.tree, [0],
'sqrt', '0.0', '1.0', '1.0', '5e-101', '0.0')
// right sqrt
checkLocalErrorNode(localError5.tree, [1],
'sqrt', '0.0', '1e-50', '1e-50', '2.379726195519099e-68', '0.0')
// plus
checkLocalErrorNode(localError5.tree, [0, 0],
'+', '0.0', '1.0', '1.0', '1e-100', '0.0')
// var x
checkLocalErrorNode(localError5.tree, [0, 0, 0],
'x', '0.0', '1e-100', '1e-100', 'equal', '0.0')
// literal 1
checkLocalErrorNode(localError5.tree, [0, 0, 1],
'1.0', '0.0', '1.0', '1.0', 'equal', '0.0')

// '(FPCore (1e100) (- (sqrt (+ x 1)) (sqrt x)))'
const localError6 = await (await fetch(makeEndpoint("/api/localerror"), {
method: 'POST', body: JSON.stringify({
formula: FPCoreFormula, sample: [[[1e100], ignoredValue]], seed: 5
})
})).json()
// avg_error, actual_value, exact_value, absolute_error, ulps_error
// root node
checkLocalErrorNode(localError6.tree, [],
'-', '61.7', '0.0', '5e-51', '5e-51', '61.74124908607812')
// left sqrt
checkLocalErrorNode(localError6.tree, [0],
'sqrt', '0.0', '1e+50', '1e+50', '6.834625285603891e+33', '0.0')
// right sqrt
checkLocalErrorNode(localError6.tree, [1],
'sqrt', '0.0', '1e+50', '1e+50', '6.834625285603891e+33', '0.0')
// plus
checkLocalErrorNode(localError6.tree, [0, 0],
'+', '0.0', '1e+100', '1e+100', '1.0', '0.0')
// var x
checkLocalErrorNode(localError6.tree, [0, 0, 0],
'x', '0.0', '1e+100', '1e+100', 'equal', '0.0')
// literal 1
checkLocalErrorNode(localError6.tree, [0, 0, 1],
'1.0', '0.0', '1.0', '1.0', 'equal', '0.0')

// Test a large number `2e269` to trigger NaNs in local error
const localError7 = await (await fetch(makeEndpoint("/api/localerror"), {
method: 'POST', body: JSON.stringify({
formula: FPCoreFormula3, sample: [[[2e269], ignoredValue]], seed: 5
})
})).json()
// Test against conditionals expressions
checkLocalErrorNode(localError7.tree, [0],
'<=', '0.0', 'true', 'true', 'invalid', '0.0')
// TODO a bug in Rival
// checkLocalErrorNode(localError7.tree, [0, 0],
// '-', '61.2', '0.0', '1.1180339887498948e-135', '1.1180339887498948e-135', '61.16647760559045')
checkLocalErrorNode(localError7.tree, [0, 1],
'0.05', '0.0', '0.05', '0.05', 'invalid', '0.0')
checkLocalErrorNode(localError7.tree, [2],
'fma', '0.0', '-inf.0', '-inf.0', 'invalid', '0.0')

/// root: The root node of the local error tree.
/// path: the path to get to the node you want to test.
/// name: Name of the node you are testing
/// avg_error: Average Error
/// actual_value: Value of the node
/// exact_value: The correct evaluation of the expression
/// absolute_difference: The ABS of the error at the node |approx - exact|
/// ulps_error: ulps of error at this node.
function checkLocalErrorNode(root, path, name,
avg_error, actual_value, exact_value, absolute_difference, ulps_error) {
const node = getNodeFromPath(root, path)
// console.log(node) // Helpful for seeing which node is failing a test
assert.equal(node['e'], name)
assert.equal(node['avg-error'], avg_error)
assert.equal(node['actual-value'], actual_value)
assert.equal(node['exact-value'], exact_value)
assert.equal(node['abs-error-difference'], absolute_difference)
assert.equal(node['ulps-error'], ulps_error)
}

function getNodeFromPath(node, path) {
if (path.length > 0) {
const index = path.shift()
const child = node['children'][index]
return getNodeFromPath(child, path)
} else {
return node
}
}

// Alternatives endpoint
const altBody = {
Expand Down Expand Up @@ -257,4 +358,4 @@ async function callAsyncAndWaitJSONResult(endpoint, body) {
}
const result = await fetch(makeEndpoint(`/api/result/${jobJSON.job}`), { method: 'GET' })
return await result.json()
}
}
2 changes: 1 addition & 1 deletion src/api/sandbox.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@

(define-values (train-pcontext test-pcontext) (partition-pcontext pcontext))
(*pcontext* test-pcontext)
(local-error-as-tree test (*context*)))
(local-error-as-tree (test-input test) (*context*)))

(define (get-explanations test pcontext)
(unless pcontext
Expand Down
12 changes: 3 additions & 9 deletions src/core/explain.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,10 @@
[else #t]))

(define (actual-errors expr pcontext)

(define errs
(match-define (cons subexprs pt-errorss)
(parameterize ([*pcontext* pcontext])
(first (compute-local-errors (list (all-subexpressions expr)) (*context*)))))

(define pruned (make-hash))
(for ([(k v) (in-hash errs)])
(hash-set! pruned k (hash-ref v 'errs)))
(define idk (flip-lists (hash->list pruned)))
(match-define (cons subexprs pt-errorss) idk)
(flip-lists (hash->list (first (compute-local-errors (list (all-subexpressions expr))
(*context*)))))))

(define pt-worst-subexpr
(append* (reap [sow]
Expand Down
Loading

0 comments on commit 1a83c9f

Please sign in to comment.