diff --git a/egg-herbie/main.rkt b/egg-herbie/main.rkt index 2076ddbe4..f39ab5a2a 100644 --- a/egg-herbie/main.rkt +++ b/egg-herbie/main.rkt @@ -2,27 +2,29 @@ (require ffi/unsafe ffi/unsafe/define + ffi/vector racket/runtime-path) (provide egraph_create egraph_destroy egraph_add_expr + egraph_add_node egraph_run egraph_copy egraph_get_stop_reason - egraph_serialize egraph_find + egraph_serialize + egraph_get_eclasses + egraph_get_eclass egraph_get_simplest egraph_get_variants - _EGraphIter - destroy_egraphiters egraph_get_cost egraph_is_unsound_detected egraph_get_times_applied egraph_get_proof - destroy_string - (struct-out EGraphIter) - (struct-out FFIRule)) + (struct-out iteration-data) + (struct-out FFIRule) + make-ffi-rule) (define-runtime-path libeggmath-path (build-path "target/release" @@ -60,6 +62,95 @@ ; Checks for a condition on MacOS if x86 Racket is being used on an ARM mac. (define-ffi-definer define-eggmath (ffi-lib libeggmath-path #:fail handle-eggmath-import-failure)) +;; Frees a Rust-allocated C-string +(define-eggmath destroy_string (_fun _pointer -> _void)) + +;; Gets the length of a Rust-allocated C-string in bytes, +;; excluding the nul terminator. +(define-eggmath string_length (_fun _pointer -> _uint32)) + +;; Converts a Racket string to a C-style string. +(define (string->_rust/string s #:raw? [raw? #f]) + (define bstr (string->bytes/utf-8 s)) + (define n (bytes-length bstr)) + (define p (malloc (if raw? 'raw 'atomic) (add1 n))) + (memcpy p bstr n) + (ptr-set! p _byte n 0) + p) + +;; Converts a non-NULL, Rust-allocated C-string to a Racket string, +;; freeing the Rust string. +(define (_rust/string->string p) + (define len (string_length p)) + (define bstr (make-bytes len)) + (memcpy bstr p len) + (destroy_string p) + (bytes->string/utf-8 bstr)) + +;; Converts a non-NULL, Rust-allocated C-string to a Racket datum +;; by repeatedly reading the string. The underlying Rust string +;; is automatically freed. 
+(define (_rust/string->data p) + (define len (string_length p)) + (define bstr (make-bytes len)) + (memcpy bstr p len) + (destroy_string p) + (for/list ([datum (in-port read (open-input-bytes bstr))]) + datum)) + +;; FFI type that converts Rust-allocated C-style strings +;; to Racket strings, automatically freeing the Rust-side allocation. +(define _rust/string + (make-ctype _pointer + (lambda (x) (and x (string->_rust/string x))) + (lambda (x) (and x (_rust/string->string x))))) + +;; FFI type that converts Rust-allocated C-style strings +;; to a Racket datum via `read`, automatically freeing the Rust-side allocation. +(define _rust/datum + (make-ctype _pointer + (lambda (x) (and x (string->_rust/string (~a x)))) + (lambda (x) (and x (first (_rust/string->data x)))))) + +;; FFI type that converts Rust-allocated C-style strings +;; to multiple Racket datum via repeated use of `read`, +;; automatically freeing the Rust-side allocation. +(define _rust/data + (make-ctype _pointer + (lambda (_) (error '_rust/data "cannot be used as an input type")) + (lambda (x) (and x (_rust/string->data x))))) + +; Egraph iteration data +; Not managed by Racket GC. +; Must call `destroy_egraphiters` to free. +(define-cstruct _EGraphIter ([numnodes _uint] [numeclasses _uint] [time _double]) #:malloc-mode 'raw) + +;; Frees an array of _EgraphIter structs +(define-eggmath destroy_egraphiters (_fun _pointer -> _void)) + +;; Racket representation of `_EGraphIter` +(struct iteration-data (num-nodes num-eclasses time)) + +;; Rewrite rule that can be passed over the FFI boundary. +;; Must be manually freed. +(define-cstruct _FFIRule ([name _pointer] [left _pointer] [right _pointer]) #:malloc-mode 'raw) + +;; Constructor for the `_FFIRule` struct. +(define (make-ffi-rule name lhs rhs) + (define name* (string->_rust/string (~a name) #:raw? #t)) + (define lhs* (string->_rust/string (~a lhs) #:raw? #t)) + (define rhs* (string->_rust/string (~a rhs) #:raw? 
#t)) + (define p (make-FFIRule name* lhs* rhs*)) + (register-finalizer p free-ffi-rule) + p) + +;; Frees a `_FFIRule` struct. +(define (free-ffi-rule rule) + (free (FFIRule-name rule)) + (free (FFIRule-left rule)) + (free (FFIRule-right rule)) + (free rule)) + ; GC'able egraph ; If Racket GC can prove unreachable, `egraph_destroy` will be called (define _egraph-pointer @@ -70,30 +161,32 @@ (register-finalizer p egraph_destroy) p))) -; Egraph iteration data -; Not managed by Racket GC. -; Must call `destroy_egraphiters` to free. -(define-cstruct _EGraphIter ([numnodes _uint] [numeclasses _uint] [time _double]) #:malloc-mode 'raw) - -; Rewrite rule -; Not managed by Racket GC. -; Must call `free` on struct and fields -(define-cstruct _FFIRule ([name _pointer] [left _pointer] [right _pointer]) #:malloc-mode 'raw) - -;; -> a pointer to an egraph +;; Constructs an e-graph instances. (define-eggmath egraph_create (_fun -> _egraph-pointer)) +;; Frees an e-graph instance. (define-eggmath egraph_destroy (_fun _egraph-pointer -> _void)) -(define-eggmath destroy_string (_fun _pointer -> _void)) +;; Copies an e-graph instance. +(define-eggmath egraph_copy (_fun _egraph-pointer -> _egraph-pointer)) -;; egraph pointer, s-expr string -> node number -(define-eggmath egraph_add_expr (_fun _egraph-pointer _string/utf-8 -> _uint)) +;; Adds an expression to the e-graph. +;; egraph -> expr -> id +(define-eggmath egraph_add_expr (_fun _egraph-pointer _rust/datum -> _uint)) -(define-eggmath destroy_egraphiters (_fun _pointer -> _void)) +; egraph -> string -> ids -> bool -> id +(define-eggmath egraph_add_node + (_fun [p : _egraph-pointer] ; egraph + [f : _rust/string] ; enode op + [v : _u32vector] ; id vector + [_uint = (u32vector-length v)] ; id vector length + [is_root : _stdbool] ; root node? + -> + _uint)) (define-eggmath egraph_is_unsound_detected (_fun _egraph-pointer -> _stdbool)) +;; Runs the egraph with a set of rules, returning the statistics of the run. 
(define-eggmath egraph_run (_fun _egraph-pointer ;; egraph (ffi-rules : (_list i _FFIRule-pointer)) ;; ffi rules @@ -106,39 +199,93 @@ _stdbool ;; simple scheduler? _stdbool ;; constant folding enabled? -> - (iterations : _EGraphIter-pointer) + (iterations : _EGraphIter-pointer) ;; array of _EgraphIter structs -> - (values iterations iterations-length iterations-ptr))) - -;; creates a fresh runner from an existing egraph -(define-eggmath egraph_copy (_fun _egraph-pointer -> _egraph-pointer)) + (begin + (define iter-data + (for/list ([i (in-range iterations-length)]) + (define ptr (ptr-add iterations i _EGraphIter)) + (iteration-data (EGraphIter-numnodes ptr) + (EGraphIter-numeclasses ptr) + (EGraphIter-time ptr)))) + (destroy_egraphiters iterations-ptr) + iter-data))) ;; gets the stop reason as an integer (define-eggmath egraph_get_stop_reason (_fun _egraph-pointer -> _uint)) ;; egraph -> string -(define-eggmath egraph_serialize (_fun _egraph-pointer -> _string)) +(define-eggmath egraph_serialize (_fun _egraph-pointer -> _rust/datum)) + +;; egraph -> uint +(define-eggmath egraph_size (_fun _egraph-pointer -> _uint)) + +;; egraph -> id -> uint +(define-eggmath egraph_eclass_size (_fun _egraph-pointer _uint -> _uint)) + +;; egraph -> id -> idx -> uint +(define-eggmath egraph_enode_size (_fun _egraph-pointer _uint _uint -> _uint)) + +;; egraph -> u32vector +(define-eggmath + egraph_get_eclasses + (_fun [e : _egraph-pointer] [v : _u32vector = (make-u32vector (egraph_size e))] -> _void -> v)) + +;; egraph -> id -> u32 -> (or symbol? number? (cons symbol u32vector)) +;; UNSAFE: `v` must be large enough to contain the child ids +(define-eggmath egraph_get_node + (_fun [e : _egraph-pointer] + [id : _uint32] + [idx : _uint32] + [v : _u32vector] + -> + [f : _rust/string] + -> + (if (zero? 
(u32vector-length v)) + (or (string->number f) (string->symbol f)) + (cons (string->symbol f) v)))) +; u32vector +(define empty-u32vec (make-u32vector 0)) + +; egraph -> id -> (vectorof (or symbol? number? (cons symbol u32vector))) +(define (egraph_get_eclass egg-ptr id) + (define n (egraph_eclass_size egg-ptr id)) + (for/vector #:length n + ([i (in-range n)]) + (define node-size (egraph_enode_size egg-ptr id i)) + (if (zero? node-size) + (egraph_get_node egg-ptr id i empty-u32vec) + (egraph_get_node egg-ptr id i (make-u32vector node-size))))) ;; egraph -> id -> id (define-eggmath egraph_find (_fun _egraph-pointer _uint -> _uint)) -;; node number -> s-expr string +;; egraph -> id -> (listof expr) (define-eggmath egraph_get_simplest (_fun _egraph-pointer _uint ;; node id _uint ;; iteration -> - _pointer)) - -(define-eggmath egraph_get_proof (_fun _egraph-pointer _string/utf-8 _string/utf-8 -> _pointer)) + _rust/datum)) ;; expr -;; node number -> (s-expr string) string +;; egraph -> id -> string -> (listof expr) (define-eggmath egraph_get_variants (_fun _egraph-pointer _uint ;; node id - _string/utf-8 ;; original expr + _rust/datum ;; original expr + -> + _rust/data)) ;; listof expr + +;; egraph -> string -> string -> string +;; TODO: in Herbie, we bail on converting the proof +;; if the string is too big. It would be more efficient +;; to bail here instead. 
+(define-eggmath egraph_get_proof + (_fun _egraph-pointer ;; egraph + _rust/datum ;; expr1 + _rust/datum ;; expr2 -> - _pointer)) ;; string pointer + _rust/string)) ;; string (define-eggmath egraph_get_cost (_fun _egraph-pointer diff --git a/egg-herbie/src/lib.rs b/egg-herbie/src/lib.rs index ecadda54d..cba0d6124 100644 --- a/egg-herbie/src/lib.rs +++ b/egg-herbie/src/lib.rs @@ -2,9 +2,9 @@ pub mod math; -use egg::{BackoffScheduler, Extractor, Id, Language, SimpleScheduler, StopReason, Symbol}; +use egg::{BackoffScheduler, Extractor, FromOp, Id, Language, SimpleScheduler, StopReason, Symbol}; use indexmap::IndexMap; -use libc::c_void; +use libc::{c_void, strlen}; use math::*; use std::cmp::min; @@ -47,6 +47,11 @@ pub unsafe extern "C" fn destroy_string(ptr: *mut c_char) { drop(CString::from_raw(ptr)) } +#[no_mangle] +pub unsafe extern "C" fn string_length(ptr: *const c_char) -> u32 { + strlen(ptr) as u32 +} + #[repr(C)] pub struct EGraphIter { numnodes: u32, @@ -69,7 +74,6 @@ pub unsafe extern "C" fn egraph_add_expr(ptr: *mut Context, expr: *const c_char) let mut context = Box::from_raw(ptr); assert_eq!(context.iteration, 0); - let rec_expr = CStr::from_ptr(expr).to_str().unwrap().parse().unwrap(); context.runner = context.runner.with_expr(&rec_expr); let id = usize::from(*context.runner.roots.last().unwrap()) @@ -81,6 +85,31 @@ pub unsafe extern "C" fn egraph_add_expr(ptr: *mut Context, expr: *const c_char) id } +#[no_mangle] +pub unsafe extern "C" fn egraph_add_node( + ptr: *mut Context, + f: *const c_char, + ids_ptr: *const u32, + num_ids: u32, + is_root: bool, +) -> u32 { + let _ = env_logger::try_init(); + // Safety: `ptr` was box allocated by `egraph_create` + let mut context = ManuallyDrop::new(Box::from_raw(ptr)); + + let f = CStr::from_ptr(f).to_str().unwrap(); + let len = num_ids as usize; + let ids: &[u32] = slice::from_raw_parts(ids_ptr, len); + let ids = ids.iter().map(|id| Id::from(*id as usize)).collect(); + let node = Math::from_op(f, ids).unwrap(); 
+ let id = context.runner.egraph.add(node); + if is_root { + context.runner.roots.push(id); + } + + usize::from(id) as u32 +} + #[no_mangle] pub unsafe extern "C" fn egraph_copy(ptr: *mut Context) -> *mut Context { // Safety: `ptr` was box allocated by `egraph_create` @@ -239,10 +268,14 @@ pub unsafe extern "C" fn egraph_find(ptr: *mut Context, id: usize) -> u32 { pub unsafe extern "C" fn egraph_serialize(ptr: *mut Context) -> *const c_char { // Safety: `ptr` was box allocated by `egraph_create` let context = ManuallyDrop::new(Box::from_raw(ptr)); + let mut ids: Vec<Id> = context.runner.egraph.classes().map(|c| c.id).collect(); + ids.sort(); + // Iterate through the eclasses and print each eclass let mut s = String::from("("); - for c in context.runner.egraph.classes() { - s.push_str(&format!("({}", c.id)); + for id in ids { + let c = &context.runner.egraph[id]; + s.push_str(&format!("({}", id)); for node in &c.nodes { if matches!(node, Math::Symbol(_) | Math::Constant(_)) { s.push_str(&format!(" {}", node)); @@ -263,6 +296,63 @@ pub unsafe extern "C" fn egraph_serialize(ptr: *mut Context) -> *const c_char { c_string.as_ptr() } +#[no_mangle] +pub unsafe extern "C" fn egraph_size(ptr: *mut Context) -> u32 { + let context = ManuallyDrop::new(Box::from_raw(ptr)); + context.runner.egraph.number_of_classes() as u32 +} + +#[no_mangle] +pub unsafe extern "C" fn egraph_eclass_size(ptr: *mut Context, id: u32) -> u32 { + let context = ManuallyDrop::new(Box::from_raw(ptr)); + let id = Id::from(id as usize); + context.runner.egraph[id].nodes.len() as u32 +} + +#[no_mangle] +pub unsafe extern "C" fn egraph_enode_size(ptr: *mut Context, id: u32, idx: u32) -> u32 { + let context = ManuallyDrop::new(Box::from_raw(ptr)); + let id = Id::from(id as usize); + let idx = idx as usize; + context.runner.egraph[id].nodes[idx].len() as u32 +} + +#[no_mangle] +pub unsafe extern "C" fn egraph_get_eclasses(ptr: *mut Context, ids_ptr: *mut u32) { + let context = 
ManuallyDrop::new(Box::from_raw(ptr)); + let mut ids: Vec<u32> = context + .runner + .egraph + .classes() + .map(|c| usize::from(c.id) as u32) + .collect(); + ids.sort(); + + for (i, id) in ids.iter().enumerate() { + std::ptr::write(ids_ptr.offset(i as isize), *id); + } +} + +#[no_mangle] +pub unsafe extern "C" fn egraph_get_node( + ptr: *mut Context, + id: u32, + idx: u32, + ids: *mut u32, +) -> *const c_char { + let context = ManuallyDrop::new(Box::from_raw(ptr)); + let id = Id::from(id as usize); + let idx = idx as usize; + + let node = &context.runner.egraph[id].nodes[idx]; + for (i, id) in node.children().iter().enumerate() { + std::ptr::write(ids.offset(i as isize), usize::from(*id) as u32); + } + + let c_string = ManuallyDrop::new(CString::new(node.to_string()).unwrap()); + c_string.as_ptr() +} + #[no_mangle] pub unsafe extern "C" fn egraph_get_simplest( ptr: *mut Context, diff --git a/infra/convert-demo.rkt b/infra/convert-demo.rkt index f25e2a3e2..123fa02e7 100644 --- a/infra/convert-demo.rkt +++ b/infra/convert-demo.rkt @@ -34,7 +34,10 @@ (define exprs-unfiltered (for/list ([test tests]) (read-expr (hash-ref test 'input) is-version-10))) - (define exprs (for/set ([expr exprs-unfiltered] #:when (not (set-member? existing-set expr))) expr)) + (define exprs + (for/set ([expr exprs-unfiltered] + #:when (not (set-member? 
existing-set expr))) + expr)) (for ([expr (in-set exprs)]) (fprintf output-file "~a\n" (make-fpcore expr))) exprs) diff --git a/infra/testApi.mjs b/infra/testApi.mjs index 0b26c3e5a..1a81be1e0 100644 --- a/infra/testApi.mjs +++ b/infra/testApi.mjs @@ -40,17 +40,6 @@ const testResult = (startResponse.status == 201 || startResponse.status == 202) assert.equal(testResult, true) const path = startResponse.headers.get("location") -// Check status endpoint -const checkStatus = await callHerbie(path, { method: 'GET' }) -// Test result depends on how fast Server responds -if (checkStatus.status == 202) { - assert.equal(checkStatus.statusText, 'Job in progress') -} else if (checkStatus.status == 201) { - assert.equal(checkStatus.statusText, 'Job complete') -} else { - assert.fail() -} - // up endpoint const up = await callHerbie("/up", { method: 'GET' }) assert.equal('Up', up.statusText) @@ -189,6 +178,20 @@ for (const e in expectedExpressions) { assert.equal(translatedExpr.result, expectedExpressions[e]) } +let counter = 0 +let cap = 100 +// Check status endpoint +let checkStatus = await callHerbie(path, { method: 'GET' }) +/* +This is testing if the /improve-start test at the beginning has been completed. The cap and counter is a sort of timeout for the test. Ends up being 10 seconds max. 
+*/ +while (checkStatus.status != 201 && counter < cap) { + counter += 1 + checkStatus = await callHerbie(path, { method: 'GET' }) + await new Promise(r => setTimeout(r, 100)); // ms +} +assert.equal(checkStatus.statusText, 'Job complete') + // Results.json endpoint const jsonResults = await callHerbie("/results.json", { method: 'GET' }) diff --git a/src/api/demo.rkt b/src/api/demo.rkt index e00c39bf7..d7f967b07 100644 --- a/src/api/demo.rkt +++ b/src/api/demo.rkt @@ -11,7 +11,8 @@ web-server/dispatch/extend web-server/http/bindings web-server/configuration/responders - web-server/managers/none) + web-server/managers/none + web-server/safety-limits) (require "../utils/common.rkt" "../config.rkt" @@ -47,7 +48,7 @@ (and (not (and (*demo-output*) ; If we've already saved to disk, skip this job (directory-exists? (build-path (*demo-output*) x)))) (let ([m (regexp-match #rx"^([0-9a-f]+)\\.[0-9a-f.]+" x)]) - (and m (completed-job? (second m)))))) + (and m (get-results-for (second m)))))) (λ (x) (let ([m (regexp-match #rx"^([0-9a-f]+)\\.[0-9a-f.]+" x)]) (get-results-for (if m (second m) x))))) @@ -73,19 +74,20 @@ [((hash-arg) (string-arg)) generate-page] [("results.json") generate-report])) -(define (generate-page req result page) +(define (generate-page req result-hash page) (define path (first (string-split (url->string (request-uri req)) "/"))) (cond - [(set-member? (all-pages result) page) + [(set-member? (all-pages result-hash) page) ;; Write page contents to disk (when (*demo-output*) (make-directory (build-path (*demo-output*) path)) - (for ([page (all-pages result)]) - (call-with-output-file (build-path (*demo-output*) path page) - (λ (out) - (with-handlers ([exn:fail? (page-error-handler result page out)]) - (make-page page out result (*demo-output*) #f))))) - (update-report result + (for ([page (all-pages result-hash)]) + (call-with-output-file + (build-path (*demo-output*) path page) + (λ (out) + (with-handlers ([exn:fail? 
(page-error-handler result-hash page out)]) + (make-page page out result-hash (*demo-output*) #f))))) + (update-report result-hash path (get-seed) (build-path (*demo-output*) "results.json") @@ -96,8 +98,8 @@ #"text" (list (header #"X-Job-Count" (string->bytes/utf-8 (~a (job-count))))) (λ (out) - (with-handlers ([exn:fail? (page-error-handler result page out)]) - (make-page page out result (*demo-output*) #f))))] + (with-handlers ([exn:fail? (page-error-handler result-hash page out)]) + (make-page page out result-hash (*demo-output*) #f))))] [else (next-dispatcher)])) (define (generate-report req) @@ -106,7 +108,7 @@ (next-dispatcher)] [else (define info - (make-report-info (get-improve-job-data) + (make-report-info (get-improve-table-data) #:seed (get-seed) #:note (if (*demo?*) "Web demo results" "Herbie results"))) (response 200 @@ -232,9 +234,9 @@ (a ([href "./index.html"]) " See what formulas other users submitted."))] [else `("all formulas submitted here are " (a ([href "./index.html"]) "logged") ".")]))))) -(define (update-report result dir seed data-file html-file) +(define (update-report result-hash dir seed data-file html-file) (define link (path-element->string (last (explode-path dir)))) - (define data (get-table-data result link)) + (define data (get-table-data-from-hash result-hash link)) (define info (if (file-exists? data-file) (let ([info (read-datafile data-file)]) @@ -337,7 +339,16 @@ (url main))) (define (check-status req job-id) - (match (is-job-finished job-id) + (define r (get-results-for job-id)) + ;; TODO return the current status from the jobs timeline + (match r + [#f + (response 202 + #"Job in progress" + (current-seconds) + #"text/plain" + (list (header #"X-Job-Count" (string->bytes/utf-8 (~a (job-count))))) + (λ (out) (display "Not done!" out)))] [(? box? timeline) (response 202 #"Job in progress" @@ -349,7 +360,7 @@ (for/list ([entry (reverse (unbox timeline))]) (format "Doing ~a\n" (hash-ref entry 'type)))) out)))] - [#f + [(? hash? 
result-hash) (response/full 201 #"Job complete" (current-seconds) @@ -398,7 +409,7 @@ (list (header #"X-Job-Count" (string->bytes/utf-8 (~a (job-count)))) (header #"X-Herbie-Job-ID" (string->bytes/utf-8 job-id)) (header #"Access-Control-Allow-Origin" (string->bytes/utf-8 "*"))) - (λ (out) (write-json (job-result-timeline job-result) out)))])) + (λ (out) (write-json (hash-ref job-result 'timeline) out)))])) ; /api/sample endpoint: test in console on demo page: ;; (await fetch('/api/sample', {method: 'POST', body: JSON.stringify({formula: "(FPCore (x) (- (sqrt (+ x 1))))", seed: 5})})).json() @@ -412,10 +423,7 @@ (define command (create-job 'sample test #:seed seed* #:pcontext #f #:profile? #f #:timeline-disabled? #t)) (define id (start-job command)) - (define result (wait-for-job id)) - (define pctx (job-result-backend result)) - (define repr (context-repr (test-context test))) - (hasheq 'points (pcontext->json pctx repr) 'job id 'path (make-path id))))) + (wait-for-job id)))) (define explanations-endpoint (post-with-json-response (lambda (post-data) @@ -423,10 +431,7 @@ (define formula (read-syntax 'web (open-input-string formula-str))) (define sample (hash-ref post-data 'sample)) (define seed (hash-ref post-data 'seed #f)) - (eprintf "Explanations job started on ~a...\n" formula-str) - (define test (parse-test formula)) - (define expr (prog->fpcore (test-input test))) (define pcontext (json->pcontext sample (test-context test))) (define command (create-job 'explanations @@ -436,11 +441,7 @@ #:profile? #f #:timeline-disabled? #t)) (define id (start-job command)) - (define result (wait-for-job id)) - (define explanations (job-result-backend result)) - - (eprintf " complete\n") - (hasheq 'explanation explanations 'job id 'path (make-path id))))) + (wait-for-job id)))) (define analyze-endpoint (post-with-json-response (lambda (post-data) @@ -458,33 +459,26 @@ #:profile? #f #:timeline-disabled? 
#t)) (define id (start-job command)) - (define result (wait-for-job id)) - (define errs - (for/list ([pt&err (job-result-backend result)]) - (define pt (first pt&err)) - (define err (second pt&err)) - (list pt (format-bits (ulps->bits err))))) - (hasheq 'points errs 'job id 'path (make-path id))))) + (wait-for-job id)))) ;; (await fetch('/api/exacts', {method: 'POST', body: JSON.stringify({formula: "(FPCore (x) (- (sqrt (+ x 1))))", points: [[1, 1]]})})).json() (define exacts-endpoint - (post-with-json-response - (lambda (post-data) - (define formula (read-syntax 'web (open-input-string (hash-ref post-data 'formula)))) - (define sample (hash-ref post-data 'sample)) - (define seed (hash-ref post-data 'seed #f)) - (define test (parse-test formula)) - (define pcontext (json->pcontext sample (test-context test))) - (define command - (create-job 'exacts - test - #:seed seed - #:pcontext pcontext - #:profile? #f - #:timeline-disabled? #t)) - (define id (start-job command)) - (define result (wait-for-job id)) - (hasheq 'points (job-result-backend result) 'job id 'path (make-path id))))) + (post-with-json-response (lambda (post-data) + (define formula + (read-syntax 'web (open-input-string (hash-ref post-data 'formula)))) + (define sample (hash-ref post-data 'sample)) + (define seed (hash-ref post-data 'seed #f)) + (define test (parse-test formula)) + (define pcontext (json->pcontext sample (test-context test))) + (define command + (create-job 'exacts + test + #:seed seed + #:pcontext pcontext + #:profile? #f + #:timeline-disabled? #t)) + (define id (start-job command)) + (wait-for-job id)))) (define calculate-endpoint (post-with-json-response (lambda (post-data) @@ -502,9 +496,7 @@ #:profile? #f #:timeline-disabled? 
#t)) (define id (start-job command)) - (define result (wait-for-job id)) - (define approx (job-result-backend result)) - (hasheq 'points approx 'job id 'path (make-path id))))) + (wait-for-job id)))) (define local-error-endpoint (post-with-json-response (lambda (post-data) @@ -523,83 +515,25 @@ #:profile? #f #:timeline-disabled? #t)) (define id (start-job command)) - (define result (wait-for-job id)) - (define local-error (job-result-backend result)) - ;; TODO: potentially unsafe if resugaring changes the AST - (define tree - (let loop ([expr expr] - [err local-error]) - (match expr - [(list op args ...) - ;; err => (List (listof Integer) List ...) - (hasheq 'e - (~a op) - 'avg-error - (format-bits (errors-score (first err))) - 'children - (map loop args (rest err)))] - [_ - ;; err => (List (listof Integer)) - (hasheq 'e - (~a expr) - 'avg-error - (format-bits (errors-score (first err))) - 'children - '())]))) - (hasheq 'tree tree 'job id 'path (make-path id))))) + (wait-for-job id)))) (define alternatives-endpoint - (post-with-json-response - (lambda (post-data) - (define formula (read-syntax 'web (open-input-string (hash-ref post-data 'formula)))) - (define sample (hash-ref post-data 'sample)) - (define seed (hash-ref post-data 'seed #f)) - (define test (parse-test formula)) - (define vars (test-vars test)) - (define repr (test-output-repr test)) - (define pcontext (json->pcontext sample (test-context test))) - (define command - (create-job 'alternatives - test - #:seed seed - #:pcontext pcontext - #:profile? #f - #:timeline-disabled? 
#f)) - (define id (start-job command)) - (define result (wait-for-job id)) - (match-define (list altns test-pcontext processed-pcontext) (job-result-backend result)) - (define splitpoints - (for/list ([alt altns]) - (splitpoints->json vars alt repr))) - - (define fpcores - (for/list ([altn altns]) - (~a (program->fpcore (alt-expr altn) (test-context test))))) - - (define histories - (for/list ([altn altns]) - (let ([os (open-output-string)]) - (parameterize ([current-output-port os]) - (write-xexpr - `(div ([id "history"]) - (ol ,@ - (render-history altn processed-pcontext test-pcontext (test-context test))))) - (get-output-string os))))) - (define derivations - (for/list ([altn altns]) - (render-json altn processed-pcontext test-pcontext (test-context test)))) - (hasheq 'alternatives - fpcores - 'histories - histories - 'derivations - derivations - 'splitpoints - splitpoints - 'job - id - 'path - (make-path id))))) + (post-with-json-response (lambda (post-data) + (define formula + (read-syntax 'web (open-input-string (hash-ref post-data 'formula)))) + (define sample (hash-ref post-data 'sample)) + (define seed (hash-ref post-data 'seed #f)) + (define test (parse-test formula)) + (define pcontext (json->pcontext sample (test-context test))) + (define command + (create-job 'alternatives + test + #:seed seed + #:pcontext pcontext + #:profile? #f + #:timeline-disabled? #t)) + (define id (start-job command)) + (wait-for-job id)))) (define ->mathjs-endpoint (post-with-json-response (lambda (post-data) @@ -618,9 +552,7 @@ (define command (create-job 'cost test #:seed #f #:pcontext #f #:profile? #f #:timeline-disabled? 
#f)) (define id (start-job command)) - (define result (wait-for-job id)) - (define cost (job-result-backend result)) - (hasheq 'cost cost 'job id 'path (make-path id))))) + (wait-for-job id)))) (define translate-endpoint (post-with-json-response (lambda (post-data) @@ -649,6 +581,8 @@ (hasheq 'result converted 'language target-lang)))) (define (run-demo #:quiet [quiet? #f] + #:threads [threads #f] + #:browser [browser? #t] #:output output #:demo? demo? #:prefix prefix @@ -659,25 +593,9 @@ (*demo-output* output) (*demo-prefix* prefix) (*demo-log* log) - - (define config - `(init rand - ,(get-seed) - flags - ,(*flags*) - num-iters - ,(*num-iterations*) - points - ,(*num-points*) - timeout - ,(*timeout*) - output-dir - ,(*demo-output*) - reeval - ,(*reeval-pts*) - demo? - ,(*demo?*))) - (start-job-server config *demo?* *demo-output*) + (unless threads + (set! threads (processor-count))) + (start-job-server threads) (unless quiet? (eprintf "Herbie ~a with seed ~a\n" *herbie-version* (get-seed)) @@ -686,10 +604,13 @@ (serve/servlet dispatch #:listen-ip (if public #f "127.0.0.1") #:port port + #:safety-limits + (make-safety-limits #:max-request-body-length + (* 5 1024 1024)) ; 5 mb body size for det44 bench mark. #:servlet-current-directory (current-directory) #:manager (create-none-manager #f) #:command-line? true - #:launch-browser? (not quiet?) + #:launch-browser? (and browser? (not quiet?)) #:banner? (not quiet?) 
#:servlets-root (web-resource) #:server-root-path (web-resource) diff --git a/src/api/sandbox.rkt b/src/api/sandbox.rkt index 7eb7f38fa..6027fff74 100644 --- a/src/api/sandbox.rkt +++ b/src/api/sandbox.rkt @@ -25,8 +25,10 @@ (submod "../utils/timeline.rkt" debug)) (provide run-herbie - get-table-data unparse-result + get-table-data + partition-pcontext + get-table-data-from-hash *reeval-pts* *timeout* (struct-out job-result) @@ -35,7 +37,7 @@ (struct job-result (command test status time timeline warnings backend)) (struct improve-result (preprocess pctxs start target end bogosity)) -(struct alt-analysis (alt train-errors test-errors)) +(struct alt-analysis (alt train-errors test-errors) #:prefab) ;; true if Racket CS <= 8.2 (define cs-places-workaround? @@ -336,6 +338,113 @@ link '())) +(define (dummy-table-row-from-hash result-hash status link) + (define test (hash-ref result-hash 'test)) + (define repr (test-output-repr test)) + (define preprocess + (if (eq? (hash-ref result-hash 'status) 'success) + (hash-ref (hash-ref result-hash 'backend) 'preprocessing) + (test-preprocess test))) + (table-row (test-name test) + (test-identifier test) + status + (prog->fpcore (test-pre test)) + preprocess + (representation-name repr) + '() ; TODO: eliminate field + (test-vars test) + (map car (hash-ref result-hash 'warnings)) + (prog->fpcore (test-input test)) + #f + (prog->fpcore (test-spec test)) + (test-output test) + #f + #f + #f + #f + #f + (hash-ref result-hash 'time) + link + '())) + +(define (get-table-data-from-hash result-hash link) + (define test (hash-ref result-hash 'test)) + (define backend (hash-ref result-hash 'backend)) + (define status (hash-ref result-hash 'status)) + (match status + ['success + (define start (hash-ref backend 'start)) + (define targets (hash-ref backend 'target)) + (define end (hash-ref backend 'end)) + (define expr-cost (platform-cost-proc (*active-platform*))) + (define repr (test-output-repr test)) + + ; starting expr analysis + 
(match-define (alt-analysis start-alt start-train-errs start-test-errs) start) + (define start-expr (alt-expr start-alt)) + (define start-train-score (errors-score start-train-errs)) + (define start-test-score (errors-score start-test-errs)) + (define start-cost (expr-cost start-expr repr)) + + (define target-cost-score + (for/list ([target targets]) + (define target-expr (alt-expr (alt-analysis-alt target))) + (define tar-cost (expr-cost target-expr repr)) + (define tar-score (errors-score (alt-analysis-test-errors target))) + + (list tar-cost tar-score))) + + ; Important to calculate value of status + (define best-score + (if (null? target-cost-score) target-cost-score (apply min (map second target-cost-score)))) + + (define end-exprs (hash-ref end 'end-alts)) + (define end-train-scores (map errors-score (hash-ref end 'end-train-scores))) + (define end-test-scores (map errors-score (hash-ref end 'end-errors))) + (define end-costs (hash-ref end 'end-costs)) + + ; terribly formatted pareto-optimal frontier + (define cost&accuracy + (list (list start-cost start-test-score) + (list (car end-costs) (car end-test-scores)) + (map list (cdr end-costs) (cdr end-test-scores) (cdr end-exprs)))) + + (define fuzz 0.1) + (define end-est-score (car end-train-scores)) + (define end-score (car end-test-scores)) + (define status + (if (not (null? 
best-score)) + (begin + (cond + [(< end-score (- best-score fuzz)) "gt-target"] + [(< end-score (+ best-score fuzz)) "eq-target"] + [(> end-score (+ start-test-score fuzz)) "lt-start"] + [(> end-score (- start-test-score fuzz)) "eq-start"] + [(> end-score (+ best-score fuzz)) "lt-target"])) + + (cond + [(and (< start-test-score 1) (< end-score (+ start-test-score 1))) "ex-start"] + [(< end-score (- start-test-score 1)) "imp-start"] + [(< end-score (+ start-test-score fuzz)) "apx-start"] + [else "uni-start"]))) + + (struct-copy table-row + (dummy-table-row-from-hash result-hash status link) + [start-est start-train-score] + [start start-test-score] + [target target-cost-score] + [result-est end-est-score] + [result end-score] + [output + (test-input (parse-test (read-syntax 'web (open-input-string (car end-exprs)))))] + [cost-accuracy cost&accuracy])] + ['failure + (match-define (list 'exn type _ ...) backend) + (define status (if type "error" "crash")) + (dummy-table-row-from-hash result-hash status link)] + ['timeout (dummy-table-row-from-hash result-hash "timeout" link)] + [_ (error 'get-table-data "unknown result type ~a" status)])) + (define (get-table-data result link) (match-define (job-result command test status time _ _ backend) result) (match status @@ -408,8 +517,8 @@ [output (car end-exprs)] [cost-accuracy cost&accuracy])] ['failure - (define exn backend) - (define status (if (exn:fail:user:herbie? exn) "error" "crash")) + (match-define (list 'exn type _ ...) 
backend) + (define status (if type "error" "crash")) (dummy-table-row result status link)] ['timeout (dummy-table-row result "timeout" link)] [_ (error 'get-table-data "unknown result type ~a" status)])) diff --git a/src/api/server.rkt b/src/api/server.rkt index c11c14e0c..78babd0f8 100644 --- a/src/api/server.rkt +++ b/src/api/server.rkt @@ -1,48 +1,43 @@ #lang racket (require openssl/sha1) +(require (only-in xml write-xexpr) + json) (require "sandbox.rkt" - "../config.rkt" - "../syntax/read.rkt") -(require (submod "../utils/timeline.rkt" debug)) + "../core/preprocess.rkt" + "../core/points.rkt" + "../reports/history.rkt" + "../reports/plot.rkt" + "../reports/common.rkt" + "../syntax/types.rkt" + "../syntax/read.rkt" + "../syntax/sugar.rkt" + "../syntax/load-plugin.rkt" + "../utils/alternative.rkt" + "../utils/common.rkt" + "../utils/errors.rkt" + "../utils/float.rkt") -(provide completed-job? - make-path +(provide make-path + get-improve-table-data + make-improve-result get-results-for - get-improve-job-data job-count is-server-up create-job start-job - is-job-finished wait-for-job start-job-server) -#| Job Server Public API section |# -; computes the path used for server URLs -(define (make-path id) - (format "~a.~a" id *herbie-commit*)) - -; Helers to isolated *completed-jobs* -(define (completed-job? id) - (hash-has-key? *completed-jobs* id)) - -; Returns #f is now job exsist for the given job-id -(define (get-results-for id) - (hash-ref *completed-jobs* id #f)) - -; I don't like how specific this function is but it keeps the API boundary. -(define (get-improve-job-data) - (for/list ([(k v) (in-hash *completed-jobs*)] - #:when (equal? (job-result-command v) 'improve)) - (get-table-data v (make-path k)))) +; verbose logging for debugging +(define verbose #f) ; Maybe change to log-level and use 'verbose? +(define (log msg . 
args) + (when verbose + (apply eprintf msg args))) -(define (job-count) - (hash-count *job-status*)) - -(define (is-server-up) - (thread-running? *worker-thread*)) +;; Job object, What herbie excepts as input for a new job. +(struct herbie-command (command test seed pcontext profile? timeline-disabled?) #:prefab) ;; Creates a command object to be passed to start-job server. ;; TODO contract? @@ -54,75 +49,69 @@ #:timeline-disabled? [timeline-disabled? #f]) (herbie-command command test seed pcontext profile? timeline-disabled?)) +; computes the path used for server URLs +(define (make-path id) + (format "~a.~a" id *herbie-commit*)) + +; Returns #f is now job exsist for the given job-id +(define (get-results-for job-id) + (define-values (a b) (place-channel)) + (place-channel-put manager (list 'result job-id b)) + (log "Getting result for job: ~a.\n" job-id) + (place-channel-get a)) + +(define (get-improve-table-data) + (define-values (a b) (place-channel)) + (place-channel-put manager (list 'improve b)) + (log "Getting improve results.\n") + (place-channel-get a)) + +(define (job-count) + (define-values (a b) (place-channel)) + (place-channel-put manager (list 'count b)) + (define count (place-channel-get a)) + (log "Current job count: ~a.\n" count) + count) + ;; Starts a job for a given command object| (define (start-job command) (define job-id (compute-job-id command)) - (if (already-computed? job-id) job-id (start-work command))) - -(define (is-job-finished job-id) - (hash-ref *job-status* job-id #f)) + (place-channel-put manager (list 'start manager command job-id)) + (log "Job ~a, Qed up for program: ~a\n" job-id (test-name (herbie-command-test command))) + job-id) (define (wait-for-job job-id) - (if (already-computed? 
job-id) (hash-ref *completed-jobs* job-id) (internal-wait-for-job job-id))) + (define-values (a b) (place-channel)) + (place-channel-put manager (list 'wait manager job-id b)) + (define finished-result (place-channel-get a)) + (log "Done waiting for: ~a\n" job-id) + finished-result) -(define (start-job-server config global-demo global-output) - ;; Pass along local global values - ;; TODO can I pull these out of config or not need ot pass them along. - (set! *demo?* global-demo) - (set! *demo-output* global-output) - (thread-send *worker-thread* config)) +; TODO refactor using this helper. +(define (manager-ask msg . args) + (define-values (a b) (place-channel)) + (place-channel-put manager (cons msg b args)) + (log "Asking manager: ~a, ~a.\n" msg args) + (place-channel-get a)) -#| End Job Server Public API section |# +(define (is-server-up) + (not (sync/timeout 0 manager-dead-event))) -;; Job object, What herbie excepts as input for a new job. -(struct herbie-command (command test seed pcontext profile? timeline-disabled?) #:transparent) - -; Private globals -; TODO I'm sure these can encapslated some how. -(define *demo?* (make-parameter false)) -(define *demo-output* (make-parameter false)) -(define *completed-jobs* (make-hash)) -(define *job-status* (make-hash)) -(define *job-sema* (make-hash)) - -(define (already-computed? job-id) - (or (hash-has-key? *completed-jobs* job-id) - (and (*demo-output*) (directory-exists? (build-path (*demo-output*) (make-path job-id)))))) - -(define (internal-wait-for-job job-id) - (eprintf "Waiting for job\n") - (define sema (hash-ref *job-sema* job-id)) - (semaphore-wait sema) - (hash-remove! *job-sema* job-id) - (hash-ref *completed-jobs* job-id)) +(define (start-job-server job-cap) + (define r (make-manager job-cap)) + (set! manager-dead-event (place-dead-evt r)) + (set! 
manager r)) -(define (compute-job-id job-info) - (sha1 (open-input-string (~s job-info)))) +(define manager #f) +(define manager-dead-event #f) -; Encapsulates semaphores and async part of jobs. -(define (start-work job) - (define job-id (compute-job-id job)) - (hash-set! *job-status* job-id (*timeline*)) - (define sema (make-semaphore)) - (hash-set! *job-sema* job-id sema) - (thread-send *worker-thread* (work job-id job sema)) - job-id) +(define (get-command herbie-result) + ; force symbol type to string. + ; This is a HACK to fix JSON parsing errors that may or may not still happen. + (~s (job-result-command herbie-result))) -; Handles semaphore and async part of a job -(struct work (id job sema)) - -(define (run-job job-info) - (match-define (work job-id info sema) job-info) - (define path (make-path job-id)) - (cond ;; Check caches if job as already been completed - [(hash-has-key? *completed-jobs* job-id) (semaphore-post sema)] - [(and (*demo-output*) (directory-exists? (build-path (*demo-output*) path))) - (semaphore-post sema)] - [else - (wrapper-run-herbie info job-id) - (hash-remove! *job-status* job-id) - (semaphore-post sema)]) - (hash-remove! *job-sema* job-id)) +(define (compute-job-id job-info) + (sha1 (open-input-string (~s job-info)))) (define (wrapper-run-herbie cmd job-id) (print-job-message (herbie-command-command cmd) job-id (test-name (herbie-command-test cmd))) @@ -133,8 +122,8 @@ #:pcontext (herbie-command-pcontext cmd) #:profile? (herbie-command-profile? cmd) #:timeline-disabled? (herbie-command-timeline-disabled? cmd))) - (hash-set! 
*completed-jobs* job-id result) - (eprintf "Job ~a complete\n" job-id)) + (eprintf "Herbie completed job: ~a\n" job-id) + result) (define (print-job-message command job-id job-str) (define job-label @@ -151,33 +140,357 @@ [_ (error 'compute-result "unknown command ~a" command)])) (eprintf "~a Job ~a started:\n ~a ~a...\n" job-label (symbol->string command) job-id job-str)) -(define *worker-thread* - (thread (λ () - (let loop ([seed #f]) - (match (thread-receive) - [`(init rand - ,vec - flags - ,flag-table - num-iters - ,iterations - points - ,points - timeout - ,timeout - output-dir - ,output - reeval - ,reeval - demo? - ,demo?) - (set! seed vec) - (*flags* flag-table) - (*num-iterations* iterations) - (*num-points* points) - (*timeout* timeout) - (*demo-output* output) - (*reeval-pts* reeval) - (*demo?* demo?)] - [job-info (run-job job-info)]) - (loop seed))))) +(define-syntax (place/context* stx) + (syntax-case stx () + [(_ name #:parameters (params ...) body ...) + (with-syntax ([(fresh ...) (generate-temporaries #'(params ...))]) + #'(let ([fresh (params)] ...) + (place/context name + (parameterize ([params fresh] ...) + body ...))))])) + +(struct work-item (command id)) + +(define (make-manager worker-count) + (place/context* + ch + #:parameters (*flags* *num-iterations* + *num-points* + *timeout* + *reeval-pts* + *node-limit* + *max-find-range-depth* + *pareto-mode* + *platform-name* + *loose-plugins*) + (parameterize ([current-error-port (open-output-nowhere)]) ; hide output + (load-herbie-plugins)) + ; not sure if the above code is actaully needed. + (define completed-work (make-hash)) + (define busy-workers (make-hash)) + (define waiting-workers (make-hash)) + (for ([i (in-range worker-count)]) + (hash-set! 
waiting-workers i (make-worker i))) + (log "~a workers ready.\n" (hash-count waiting-workers)) + (define waiting (make-hash)) + (define job-queue (list)) + (log "Manager waiting to assign work.\n") + (for ([i (in-naturals)]) + ; (eprintf "manager msg ~a handled\n" i) + (match (place-channel-get ch) + [(list 'start self command job-id) + ; Check if the work has been completed already if not assign the work. + (if (hash-has-key? completed-work job-id) + (place-channel-put self (list 'send job-id (hash-ref completed-work job-id))) + (place-channel-put self (list 'queue self job-id command)))] + [(list 'queue self job-id command) + (set! job-queue (append job-queue (list (work-item command job-id)))) + (place-channel-put self (list 'assign self))] + [(list 'assign self) + (define reassigned (make-hash)) + (for ([(wid worker) (in-hash waiting-workers)] + [job (in-list job-queue)]) + (log "Starting worker [~a] on [~a].\n" + (work-item-id job) + (test-name (herbie-command-test (work-item-command job)))) + (place-channel-put worker (list 'apply self (work-item-command job) (work-item-id job))) + (hash-set! reassigned wid worker) + (hash-set! busy-workers wid worker)) + ; remove X many jobs from the Q and update waiting-workers + (for ([(wid worker) (in-hash reassigned)]) + (hash-remove! waiting-workers wid) + (set! job-queue (cdr job-queue)))] + ; Job is finished save work and free worker. Move work to 'send state. + [(list 'finished self wid job-id result) + (log "Job ~a finished, saving result.\n" job-id) + (hash-set! completed-work job-id result) + + ; move worker to waiting list + (hash-set! waiting-workers wid (hash-ref busy-workers wid)) + (hash-remove! busy-workers wid) + + (log "waiting job ~a completed\n" job-id) + (place-channel-put self (list 'send job-id result)) + (place-channel-put self (list 'assign self))] + [(list 'wait self job-id handler) + (log "Waiting for job: ~a\n" job-id) + ; first we add the handler to the wait list. + (hash-update! 
waiting job-id (curry append (list handler)) '()) + (define result (hash-ref completed-work job-id #f)) + ; check if the job is completed or not. + (unless (false? result) + (log "Done waiting for job: ~a\n" job-id) + ; we have a result to send. + (place-channel-put self (list 'send job-id result)))] + [(list 'send job-id result) + (log "Sending result for ~a.\n" job-id) + (for ([handle (hash-ref waiting job-id '())]) + (place-channel-put handle result)) + (hash-remove! waiting job-id)] + ; Get the result for the given id, return false if no work found. + [(list 'result job-id handler) (place-channel-put handler (hash-ref completed-work job-id #f))] + ; Returns the current count of working workers. + [(list 'count handler) (place-channel-put handler (hash-count busy-workers))] + ; Retreive the improve results for results.json + [(list 'improve handler) + (define improved-list + (for/list ([(job-id result) (in-hash completed-work)] + #:when (equal? (hash-ref result 'command) "improve")) + (get-table-data-from-hash result (make-path job-id)))) + (place-channel-put handler improved-list)])))) + +(define (make-worker worker-id) + (place/context* + ch + #:parameters (*flags* *num-iterations* + *num-points* + *timeout* + *reeval-pts* + *node-limit* + *max-find-range-depth* + *pareto-mode* + *platform-name* + *loose-plugins*) + (parameterize ([current-error-port (open-output-nowhere)]) ; hide output + (load-herbie-plugins)) + (for ([_ (in-naturals)]) + (match (place-channel-get ch) + [(list 'apply manager command job-id) + (log "[~a] working on [~a].\n" job-id (test-name (herbie-command-test command))) + (define herbie-result (wrapper-run-herbie command job-id)) + (match-define (job-result kind test status time _ _ backend) herbie-result) + (define out-result + (match kind + ['alternatives (make-alternatives-result herbie-result test job-id)] + ['evaluate (make-calculate-result herbie-result job-id)] + ['cost (make-cost-result herbie-result job-id)] + ['errors 
(make-error-result herbie-result job-id)] + ['exacts (make-exacts-result herbie-result job-id)] + ['improve (make-improve-result herbie-result test job-id)] + ['local-error (make-local-error-result herbie-result test job-id)] + ['explanations (make-explanation-result herbie-result job-id)] + ['sample (make-sample-result herbie-result test job-id)] + [_ (error 'compute-result "unknown command ~a" kind)])) + (log "Job: ~a finished, returning work to manager\n" job-id) + (place-channel-put manager (list 'finished manager worker-id job-id out-result))])))) + +(define (make-explanation-result herbie-result job-id) + (define explanations (job-result-backend herbie-result)) + (hasheq 'command + (get-command herbie-result) + 'explanation + explanations + 'job + job-id + 'path + (make-path job-id))) + +(define (make-local-error-result herbie-result test job-id) + (define expr (prog->fpcore (test-input test))) + (define local-error (job-result-backend herbie-result)) + ;; TODO: potentially unsafe if resugaring changes the AST + (define tree + (let loop ([expr expr] + [err local-error]) + (match expr + [(list op args ...) + ;; err => (List (listof Integer) List ...) 
+ (hasheq 'e + (~a op) + 'avg-error + (format-bits (errors-score (first err))) + 'children + (map loop args (rest err)))] + ;; err => (List (listof Integer)) + [_ (hasheq 'e (~a expr) 'avg-error (format-bits (errors-score (first err))) 'children '())]))) + (hasheq 'command (get-command herbie-result) 'tree tree 'job job-id 'path (make-path job-id))) + +(define (make-sample-result herbie-result test job-id) + (define pctx (job-result-backend herbie-result)) + (define repr (context-repr (test-context test))) + (hasheq 'command + (get-command herbie-result) + 'points + (pcontext->json pctx repr) + 'job + job-id + 'path + (make-path job-id))) + +(define (make-calculate-result herbie-result job-id) + (hasheq 'command + (get-command herbie-result) + 'points + (job-result-backend herbie-result) + 'job + job-id + 'path + (make-path job-id))) + +(define (make-cost-result herbie-result job-id) + (hasheq 'command + (get-command herbie-result) + 'cost + (job-result-backend herbie-result) + 'job + job-id + 'path + (make-path job-id))) + +(define (make-error-result herbie-result job-id) + (define errs + (for/list ([pt&err (job-result-backend herbie-result)]) + (define pt (first pt&err)) + (define err (second pt&err)) + (list pt (format-bits (ulps->bits err))))) + (hasheq 'command (get-command herbie-result) 'points errs 'job job-id 'path (make-path job-id))) + +(define (make-exacts-result herbie-result job-id) + (hasheq 'command + (get-command herbie-result) + 'points + (job-result-backend herbie-result) + 'job + job-id + 'path + (make-path job-id))) + +(define (make-improve-result herbie-result test job-id) + (define ctx (context->json (test-context test))) + (define backend (job-result-backend herbie-result)) + (define job-time (job-result-time herbie-result)) + (define warnings (job-result-warnings herbie-result)) + (define timeline (job-result-timeline herbie-result)) + + (define repr (test-output-repr test)) + (define backend-hash + (match (job-result-status herbie-result) 
+ ['success (backend-improve-result-hash-table backend repr test)] + ['timeout #f] + ['failure (exception->datum backend)])) + + (hasheq 'command + (get-command herbie-result) + 'status + (job-result-status herbie-result) + 'test + test + 'ctx + ctx + 'time + job-time + 'warnings + warnings + 'timeline + timeline + 'backend + backend-hash + 'job + job-id + 'path + (make-path job-id))) + +(define (backend-improve-result-hash-table backend repr test) + (define pcontext (improve-result-pctxs backend)) + + (define preprocessing (improve-result-preprocess backend)) + (define end-hash-table (end-hash (improve-result-end backend) repr pcontext test)) + + (hasheq 'preprocessing + preprocessing + 'pctxs + pcontext + 'start + (improve-result-start backend) + 'target + (improve-result-target backend) + 'end + end-hash-table + 'bogosity + (improve-result-bogosity backend))) + +(define (end-hash end repr pcontexts test) + (define-values (end-alts train-errors end-errors end-costs) + (for/lists (l1 l2 l3 l4) + ([analysis end]) + (match-define (alt-analysis alt train-errors test-errs) analysis) + (values alt train-errors test-errs (alt-cost alt repr)))) + (define fpcores + (for/list ([altn end-alts]) + (~a (program->fpcore (alt-expr altn) (test-context test))))) + (define alts-histories + (for/list ([alt end-alts]) + (render-history alt (first pcontexts) (second pcontexts) (test-context test)))) + (define vars (test-vars test)) + (define end-alt (alt-analysis-alt (car end))) + (define splitpoints + (for/list ([var vars]) + (define split-var? (equal? var (regime-var end-alt))) + (if split-var? 
+ (for/list ([val (regime-splitpoints end-alt)]) + (real->ordinal (repr->real val repr) repr)) + '()))) + + (hasheq 'end-alts + fpcores + 'end-histories + alts-histories + 'end-train-scores + train-errors + 'end-errors + end-errors + 'end-costs + end-costs + 'splitpoints + splitpoints)) + +(define (context->json ctx) + (hasheq 'vars (context-vars ctx) 'repr (repr->json (context-repr ctx)))) + +(define (repr->json repr) + (hasheq 'name (representation-name repr) 'type (representation-type repr))) + +(define (make-alternatives-result herbie-result test job-id) + + (define vars (test-vars test)) + (define repr (test-output-repr test)) + + (match-define (list altns test-pcontext processed-pcontext) (job-result-backend herbie-result)) + (define splitpoints + (for/list ([alt altns]) + (for/list ([var vars]) + (define split-var? (equal? var (regime-var alt))) + (if split-var? + (for/list ([val (regime-splitpoints alt)]) + (real->ordinal (repr->real val repr) repr)) + '())))) + + (define fpcores + (for/list ([altn altns]) + (~a (program->fpcore (alt-expr altn) (test-context test))))) + + (define histories + (for/list ([altn altns]) + (let ([os (open-output-string)]) + (parameterize ([current-output-port os]) + (write-xexpr + `(div ([id "history"]) + (ol ,@(render-history altn processed-pcontext test-pcontext (test-context test))))) + (get-output-string os))))) + (define derivations + (for/list ([altn altns]) + (render-json altn processed-pcontext test-pcontext (test-context test)))) + (hasheq 'command + (get-command herbie-result) + 'alternatives + fpcores + 'histories + histories + 'derivations + derivations + 'splitpoints + splitpoints + 'job + job-id + 'path + (make-path job-id))) diff --git a/src/api/thread-pool.rkt b/src/api/thread-pool.rkt index 04ad67c8b..f44b3b94d 100644 --- a/src/api/thread-pool.rkt +++ b/src/api/thread-pool.rkt @@ -3,6 +3,7 @@ (require racket/place) (require "../utils/common.rkt" "sandbox.rkt" + "server.rkt" "../syntax/load-plugin.rkt" 
"../reports/pages.rkt" "../syntax/read.rkt" @@ -17,6 +18,7 @@ (format "~a-~a" index (substring replaced 0 (min (string-length replaced) 50)))) (define (run-test index test #:seed seed #:profile profile? #:dir dir) + (define bad-id -1) ; TODO move this code to using server.rkt (cond [dir (define dirname (graph-folder-path (test-name test) index)) @@ -31,23 +33,24 @@ #:exists 'replace (λ (pp) (run-herbie 'improve test #:seed seed #:profile? pp)))] [else (run-herbie test 'improve #:seed seed)])) - + (define improve-hash (make-improve-result result test bad-id)) (set-seed! seed) (define error? #f) - (for ([page (all-pages result)]) - (call-with-output-file (build-path rdir page) - #:exists 'replace - (λ (out) - (with-handlers ([exn:fail? (λ (e) - ((page-error-handler result page out) e) - (set! error? #t))]) - (make-page page out result #t profile?))))) - - (define out (get-table-data result dirname)) + (for ([page (all-pages improve-hash)]) + (call-with-output-file + (build-path rdir page) + #:exists 'replace + (λ (out) + (with-handlers ([exn:fail? (λ (e) + ((page-error-handler improve-hash page out) e) + (set! error? #t))]) + (make-page page out improve-hash #t profile?))))) + + (define out (get-table-data-from-hash improve-hash dirname)) (if error? (struct-copy table-row out [status "crash"]) out)] [else (define result (run-herbie 'improve test #:seed seed)) - (get-table-data result "")])) + (get-table-data-from-hash (make-improve-result result test bad-id) "")])) (define-syntax (place/context* stx) (syntax-case stx () diff --git a/src/config.rkt b/src/config.rkt index d2eb8bc6a..046fb0d4a 100644 --- a/src/config.rkt +++ b/src/config.rkt @@ -180,8 +180,13 @@ (for ([fn (in-list resetters)]) (fn))) -(define-syntax-rule (define/reset name value) - (define name - (let ([param (make-parameter value)]) - (register-resetter! 
(λ () (name value))) - param))) +(define-syntax define/reset + (syntax-rules () + ; default resetter sets parameter to `value` + [(_ name value) (define/reset name value (λ () (name value)))] + ; initial value and resetter + [(_ name value reset-fn) + (define name + (let ([param (make-parameter value)]) + (register-resetter! reset-fn) + param))])) diff --git a/src/core/batch.rkt b/src/core/batch.rkt index cbf37bf29..ee99c6963 100644 --- a/src/core/batch.rkt +++ b/src/core/batch.rkt @@ -6,7 +6,7 @@ (provide progs->batch batch->progs (struct-out batch) - get-expr + batch-ref expand-taylor) (struct batch ([nodes #:mutable] [roots #:mutable] vars [nodes-length #:mutable])) @@ -82,41 +82,25 @@ (unmunge root))) exprs) -(define (expand-taylor batch) - (define vars (batch-vars batch)) +(define (expand-taylor input-batch) + (define vars (batch-vars input-batch)) + (define nodes (batch-nodes input-batch)) + (define roots (batch-roots input-batch)) + + ; Hash to avoid duplications (define icache (reverse vars)) (define exprhash (make-hash (for/list ([var vars] [i (in-naturals)]) (cons var i)))) - ; Counts (define exprc 0) (define varc (length vars)) - ; Translates programs into an instruction sequence of operations - (define (munge prog) - (define node ; This compiles to the register machine - (match prog - [(list '- arg1 arg2) `(+ ,(munge arg1) ,(munge `(neg ,arg2)))] - [(list 'pow base 1/2) `(sqrt ,(munge base))] - [(list 'pow base 1/3) `(cbrt ,(munge base))] - [(list 'pow base 2/3) `(cbrt ,(munge `(* ,base ,base)))] - [(list 'pow base power) - #:when (exact-integer? 
power) - `(pow ,(munge base) ,(munge power))] - [(list 'pow base power) `(exp ,(munge `(* ,power (log ,base))))] - [(list 'tan args) `(/ ,(munge `(sin ,args)) ,(munge `(cos ,args)))] - [(list 'cosh args) `(* ,(munge 1/2) ,(munge `(+ (exp ,args) (/ 1 (exp ,args)))))] - [(list 'sinh args) `(* ,(munge 1/2) ,(munge `(+ (exp ,args) (neg (/ 1 (exp ,args))))))] - [(list 'tanh args) - `(/ ,(munge `(+ (exp ,args) (neg (/ 1 (exp ,args))))) - ,(munge `(+ (exp ,args) (/ 1 (exp ,args)))))] - [(list 'asinh args) `(log ,(munge `(+ ,args (sqrt (+ (* ,args ,args) 1)))))] - [(list 'acosh args) `(log ,(munge `(+ ,args (sqrt (+ (* ,args ,args) -1)))))] - [(list 'atanh args) `(* ,(munge 1/2) ,(munge `(log (/ (+ 1 ,args) (+ 1 (neg ,args))))))] - [(list op args ...) (cons op (map munge args))] - [(approx spec impl) (approx spec (munge impl))] - [_ prog])) + ; Mapping from nodes to nodes* + (define mappings (build-vector (batch-nodes-length input-batch) values)) + + ; Adding a node to hash + (define (append-node node) (hash-ref! exprhash node (lambda () @@ -124,29 +108,148 @@ (set! exprc (+ 1 exprc)) (set! icache (cons node icache)))))) - (set-batch-roots! batch (list->vector (map munge (batch->progs batch)))) - (set-batch-nodes! batch (list->vector (reverse icache))) - (set-batch-nodes-length! batch (vector-length (batch-nodes batch)))) + ; Sequential rewriting + (for ([node (in-vector nodes)] + [n (in-naturals)]) + (match node + [(list '- arg1 arg2) + (define neg-index (append-node `(neg ,(vector-ref mappings arg2)))) + (vector-set! mappings n (append-node `(+ ,(vector-ref mappings arg1) ,neg-index)))] + [(list 'pow base power) + #:when (equal? (vector-ref nodes power) 1/2) ; 1/2 is to be removed from exprhash + (vector-set! mappings n (append-node `(sqrt ,(vector-ref mappings base))))] + [(list 'pow base power) + #:when (equal? (vector-ref nodes power) 1/3) ; 1/3 is to be removed from exprhash + (vector-set! 
mappings n (append-node `(cbrt ,(vector-ref mappings base))))] + [(list 'pow base power) + #:when (equal? (vector-ref nodes power) 2/3) ; 2/3 is to be removed from exprhash + (define mult-index (append-node `(* ,(vector-ref mappings base) ,(vector-ref mappings base)))) + (vector-set! mappings n (append-node `(cbrt ,mult-index)))] + [(list 'pow base power) + #:when (exact-integer? (vector-ref nodes power)) + (vector-set! mappings + n + (append-node `(pow ,(vector-ref mappings base) ,(vector-ref mappings power))))] + [(list 'pow base power) + (define log-idx (append-node `(log ,(vector-ref mappings base)))) + (define mult-idx (append-node `(* ,(vector-ref mappings power) ,log-idx))) + (vector-set! mappings n (append-node `(exp ,mult-idx)))] + [(list 'tan args) + (define sin-idx (append-node `(sin ,(vector-ref mappings args)))) + (define cos-idx (append-node `(cos ,(vector-ref mappings args)))) + (vector-set! mappings n (append-node `(/ ,sin-idx ,cos-idx)))] + [(list 'cosh args) + (define exp-idx (append-node `(exp ,(vector-ref mappings args)))) + (define one-idx (append-node 1)) ; should it be 1 or literal 1 or smth? + (define inv-exp-idx (append-node `(/ ,one-idx ,exp-idx))) + (define add-idx (append-node `(+ ,exp-idx ,inv-exp-idx))) + (define half-idx (append-node 1/2)) + (vector-set! mappings n (append-node `(* ,half-idx ,add-idx)))] + [(list 'sinh args) + (define exp-idx (append-node `(exp ,(vector-ref mappings args)))) + (define one-idx (append-node 1)) + (define inv-exp-idx (append-node `(/ ,one-idx ,exp-idx))) + (define neg-idx (append-node `(neg ,inv-exp-idx))) + (define add-idx (append-node `(+ ,exp-idx ,neg-idx))) + (define half-idx (append-node 1/2)) + (vector-set! 
mappings n (append-node `(* ,half-idx ,add-idx)))] + [(list 'tanh args) + (define exp-idx (append-node `(exp ,(vector-ref mappings args)))) + (define one-idx (append-node 1)) + (define inv-exp-idx (append-node `(/ ,one-idx ,exp-idx))) + (define neg-idx (append-node `(neg ,inv-exp-idx))) + (define add-idx (append-node `(+ ,exp-idx ,inv-exp-idx))) + (define sub-idx (append-node `(+ ,exp-idx ,neg-idx))) + (vector-set! mappings n (append-node `(/ ,sub-idx ,add-idx)))] + [(list 'asinh args) + (define mult-idx (append-node `(* ,(vector-ref mappings args) ,(vector-ref mappings args)))) + (define one-idx (append-node 1)) + (define add-idx (append-node `(+ ,mult-idx ,one-idx))) + (define sqrt-idx (append-node `(sqrt ,add-idx))) + (define add2-idx (append-node `(+ ,(vector-ref mappings args) ,sqrt-idx))) + (vector-set! mappings n (append-node `(log ,add2-idx)))] + [(list 'acosh args) + (define mult-idx (append-node `(* ,(vector-ref mappings args) ,(vector-ref mappings args)))) + (define -one-idx (append-node -1)) + (define add-idx (append-node `(+ ,mult-idx ,-one-idx))) + (define sqrt-idx (append-node `(sqrt ,add-idx))) + (define add2-idx (append-node `(+ ,(vector-ref mappings args) ,sqrt-idx))) + (vector-set! mappings n (append-node `(log ,add2-idx)))] + [(list 'atanh args) + (define neg-idx (append-node `(neg ,(vector-ref mappings args)))) + (define one-idx (append-node 1)) + (define add-idx (append-node `(+ ,one-idx ,(vector-ref mappings args)))) + (define sub-idx (append-node `(+ ,one-idx ,neg-idx))) + (define div-idx (append-node `(/ ,add-idx ,sub-idx))) + (define log-idx (append-node `(log ,div-idx))) + (define half-idx (append-node 1/2)) + (vector-set! mappings n (append-node `(* ,half-idx ,log-idx)))] + [(list op args ...) + (vector-set! mappings n (append-node (cons op (map (curry vector-ref mappings) args))))] + [(approx spec impl) + (vector-set! mappings n (append-node (approx spec (vector-ref mappings impl))))] + [_ (vector-set! 
mappings n (append-node node))])) + + (define roots* (vector-map (curry vector-ref mappings) roots)) + (define nodes* (list->vector (reverse icache))) + + ; This may be too expensive to handle simple 1/2, 1/3 and 2/3 zombie nodes.. + #;(remove-zombie-nodes (batch nodes* roots* vars (vector-length nodes*))) + + (batch nodes* roots* vars (vector-length nodes*))) + +; The function removes any zombie nodes from batch +(define (remove-zombie-nodes input-batch) + (define nodes (batch-nodes input-batch)) + (define roots (batch-roots input-batch)) + (define nodes-length (batch-nodes-length input-batch)) + + (define zombie-mask (make-vector nodes-length #t)) + (for ([root (in-vector roots)]) + (vector-set! zombie-mask root #f)) + (for ([node (in-vector nodes (- nodes-length 1) -1 -1)] + [zmb (in-vector zombie-mask (- nodes-length 1) -1 -1)] + #:when (not zmb)) + (match node + [(list op args ...) (map (λ (n) (vector-set! zombie-mask n #f)) args)] + [(approx spec impl) (vector-set! zombie-mask impl #f)] + [_ void])) + + (define mappings (build-vector nodes-length values)) + + (define nodes* '()) + (for ([node (in-vector nodes)] + [zmb (in-vector zombie-mask)] + [n (in-naturals)]) + (if zmb + (for ([i (in-range n nodes-length)]) + (vector-set! mappings i (sub1 (vector-ref mappings i)))) + (set! nodes* + (cons (match node + [(list op args ...) (cons op (map (curry vector-ref mappings) args))] + [(approx spec impl) (approx spec (vector-ref mappings impl))] + [_ node]) + nodes*)))) + (set! nodes* (list->vector (reverse nodes*))) + (define roots* (vector-map (curry vector-ref mappings) roots)) + (batch nodes* roots* (batch-vars input-batch) (vector-length nodes*))) -(define (get-expr nodes reg) +(define (batch-ref batch reg) (define (unmunge reg) - (define node (vector-ref nodes reg)) + (define node (vector-ref (batch-nodes batch) reg)) (match node [(approx spec impl) (approx spec (unmunge impl))] [(list op regs ...) 
(cons op (map unmunge regs))] [_ node])) (unmunge reg)) +; Tests for expand-taylor (module+ test (require rackunit) (define (test-expand-taylor expr) - (define batch (progs->batch (list expr))) - (expand-taylor batch) - (car (batch->progs batch))) - - (define (test-munge-unmunge expr [ignore-approx #t]) - (define batch (progs->batch (list expr) #:ignore-approx ignore-approx)) - (check-equal? (list expr) (batch->progs batch))) + (define batch (progs->batch (list expr) #:ignore-approx #f)) + (define batch* (expand-taylor batch)) + (car (batch->progs batch*))) (check-equal? '(* 1/2 (log (/ (+ 1 x) (+ 1 (neg x))))) (test-expand-taylor '(atanh x))) (check-equal? '(log (+ x (sqrt (+ (* x x) -1)))) (test-expand-taylor '(acosh x))) @@ -160,8 +263,21 @@ (check-equal? '(+ 1 (neg (* 1/2 (+ (exp (/ (sin 3) (cos 3))) (/ 1 (exp (/ (sin 3) (cos 3)))))))) (test-expand-taylor '(- 1 (cosh (tan 3))))) (check-equal? '(exp (* a (log x))) (test-expand-taylor '(pow x a))) + (check-equal? '(+ x (sin a)) (test-expand-taylor '(+ x (sin a)))) (check-equal? '(cbrt x) (test-expand-taylor '(pow x 1/3))) + (check-equal? '(cbrt (* x x)) (test-expand-taylor '(pow x 2/3))) (check-equal? '(+ 100 (cbrt x)) (test-expand-taylor '(+ 100 (pow x 1/3)))) + (check-equal? `(+ 100 (cbrt (* x ,(approx 2 3)))) + (test-expand-taylor `(+ 100 (pow (* x ,(approx 2 3)) 1/3)))) + (check-equal? `(+ ,(approx 2 3) (cbrt x)) (test-expand-taylor `(+ ,(approx 2 3) (pow x 1/3)))) + (check-equal? `(+ (cbrt x) ,(approx 2 1/3)) (test-expand-taylor `(+ (pow x 1/3) ,(approx 2 1/3))))) + +; Tests for progs->batch and batch->progs +(module+ test + (require rackunit) + (define (test-munge-unmunge expr [ignore-approx #t]) + (define batch (progs->batch (list expr) #:ignore-approx ignore-approx)) + (check-equal? 
(list expr) (batch->progs batch))) (test-munge-unmunge '(* 1/2 (+ (exp x) (neg (/ 1 (exp x)))))) (test-munge-unmunge @@ -171,3 +287,28 @@ (test-munge-unmunge `(+ (sin ,(approx '(* 1/2 (+ (exp x) (neg (/ 1 (exp x))))) '(+ 3 (* 25 (sin 6))))) 4) #f)) + +; Tests for remove-zombie-nodes +(module+ test + (require rackunit) + (define (zombie-test #:nodes nodes #:roots roots) + (define in-batch (batch nodes roots '() (vector-length nodes))) + (define out-batch (remove-zombie-nodes in-batch)) + (batch-nodes out-batch)) + + (check-equal? (vector 0 '(sqrt 0) 2 '(pow 2 1)) + (zombie-test #:nodes (vector 0 1 '(sqrt 0) 2 '(pow 3 2)) #:roots (vector 4))) + (check-equal? (vector 0 '(sqrt 0) '(exp 1)) + (zombie-test #:nodes (vector 0 6 '(pow 0 1) '(* 2 0) '(sqrt 0) '(exp 4)) + #:roots (vector 5))) + (check-equal? (vector 0 1/2 '(+ 0 1)) + (zombie-test #:nodes (vector 0 1/2 '(+ 0 1) '(* 2 0)) #:roots (vector 2))) + (check-equal? (vector 0 (approx '(exp 2) 0)) + (zombie-test #:nodes (vector 0 1/2 '(+ 0 1) '(* 2 0) (approx '(exp 2) 0)) + #:roots (vector 4))) + (check-equal? (vector 2 1/2 (approx '(* x x) 0) '(pow 1 2)) + (zombie-test #:nodes (vector 2 1/2 '(sqrt 0) '(cbrt 0) (approx '(* x x) 0) '(pow 1 4)) + #:roots (vector 5))) + (check-equal? (vector 2 1/2 '(sqrt 0) (approx '(* x x) 0) '(pow 1 3)) + (zombie-test #:nodes (vector 2 1/2 '(sqrt 0) '(cbrt 0) (approx '(* x x) 0) '(pow 1 4)) + #:roots (vector 5 2)))) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 8c21e00cc..12b376706 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -1,20 +1,16 @@ #lang racket (require egg-herbie - (only-in ffi/unsafe - malloc - memcpy - free - cast - ptr-set! - ptr-add - _byte - _pointer - _string/utf-8 - register-finalizer)) - -(require "rules.rkt" - "programs.rkt" + (only-in ffi/vector + make-u32vector + u32vector-length + u32vector-set! 
+ u32vector-ref + list->u32vector + u32vector->list)) + +(require "programs.rkt" + "rules.rkt" "../syntax/platform.rkt" "../syntax/syntax.rkt" "../syntax/types.rkt" @@ -23,9 +19,7 @@ "../utils/timeline.rkt") (provide (struct-out egg-runner) - untyped-egg-extractor typed-egg-extractor - default-untyped-egg-cost-proc platform-egg-cost-proc default-egg-cost-proc make-egg-runner @@ -38,6 +32,18 @@ (require "../syntax/load-plugin.rkt") (load-herbie-builtins)) +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; FFI utils + +(define (u32vector-empty? x) + (zero? (u32vector-length x))) + +(define (in-u32vector vec) + (make-do-sequence + (lambda () + (define len (u32vector-length vec)) + (values (lambda (i) (u32vector-ref vec i)) add1 0 (lambda (i) (< i len)) #f #f)))) + ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; egg FFI shim ;; @@ -45,28 +51,6 @@ ;; - FFIRule: struct defined in egg-herbie ;; - EgraphIter: struct defined in egg-herbie -(define (make-raw-string s) - (define b (string->bytes/utf-8 s)) - (define n (bytes-length b)) - (define ptr (malloc 'raw (+ n 1))) - (memcpy ptr b n) - (ptr-set! 
ptr _byte n 0) - ptr) - -(define (make-ffi-rule rule) - (define name (make-raw-string (~a (rule-name rule)))) - (define lhs (make-raw-string (~a (rule-input rule)))) - (define rhs (make-raw-string (~a (rule-output rule)))) - (define p (make-FFIRule name lhs rhs)) - (register-finalizer p free-ffi-rule) - p) - -(define (free-ffi-rule rule) - (free (FFIRule-name rule)) - (free (FFIRule-left rule)) - (free (FFIRule-right rule)) - (free rule)) - ;; Wrapper around Rust-allocated egg runner (struct egraph-data (egraph-pointer ; FFI pointer to runner @@ -85,24 +69,92 @@ eg-data [egraph-pointer (egraph_copy (egraph-data-egraph-pointer eg-data))])) -;; result function is a function that takes the ids of the nodes -(define (egraph-add-expr eg-data expr ctx) - (match-define (egraph-data ptr _ _ id->spec) eg-data) - ; add the expression to the e-graph and save the root e-class id - (define egg-expr (expr->egg-expr expr eg-data ctx)) - (define root-id (egraph_add_expr ptr (~a egg-expr))) - ; record all approx specs - (let loop ([egg-expr egg-expr]) - (match egg-expr - [(? number?) (void)] - [(? symbol?) (void)] - [(list '$approx spec impl) - (define id (egraph_add_expr ptr (~a spec))) - (hash-ref! id->spec id (lambda () spec)) - (loop impl)] - [(list _ args ...) (for-each loop args)])) - ; return the id - root-id) +; Adds expressions returning the root ids +; TODO: take a batch rather than list of expressions +(define (egraph-add-exprs egg-data exprs ctx) + (match-define (egraph-data ptr herbie->egg-dict egg->herbie-dict id->spec) egg-data) + + ; lookups the egg name of a variable + (define (normalize-var x) + (hash-ref! herbie->egg-dict + x + (lambda () + (define id (hash-count herbie->egg-dict)) + (define replacement (string->symbol (format "$h~a" id))) + (hash-set! egg->herbie-dict replacement (cons x (context-lookup ctx x))) + replacement))) + + ; normalizes an approx spec + (define (normalize-spec expr) + (match expr + [(? number?) expr] + [(? symbol?) 
(normalize-var expr)] + [(list op args ...) (cons op (map normalize-spec args))])) + + ; pre-allocated id vectors for all the common cases + (define 0-vec (make-u32vector 0)) + (define 1-vec (make-u32vector 1)) + (define 2-vec (make-u32vector 2)) + (define 3-vec (make-u32vector 3)) + + (define (list->u32vec xs) + (match xs + [(list) 0-vec] + [(list x) + (u32vector-set! 1-vec 0 x) + 1-vec] + [(list x y) + (u32vector-set! 2-vec 0 x) + (u32vector-set! 2-vec 1 y) + 2-vec] + [(list x y z) + (u32vector-set! 3-vec 0 x) + (u32vector-set! 3-vec 1 y) + (u32vector-set! 3-vec 2 z) + 3-vec] + [_ (list->u32vector xs)])) + + ; node -> natural + ; inserts an expression into the e-graph, returning its e-class id. + (define (insert-node! node root?) + (match node + [(list op ids ...) (egraph_add_node ptr (symbol->string op) (list->u32vec ids) root?)] + [(? symbol? x) (egraph_add_node ptr (symbol->string x) 0-vec root?)] + [(? number? n) (egraph_add_node ptr (number->string n) 0-vec root?)])) + + ; expr -> id + ; expression cache + (define expr->id (make-hash)) + + ; expr -> natural + ; inserts an expression into the e-graph, returning its e-class id. + (define (insert! expr [root? #f]) + ; transform the expression into a node pointing + ; to its child e-classes + (define node + (match expr + [(? number?) expr] + [(? symbol?) (normalize-var expr)] + [(literal v _) v] + [(approx spec impl) + (define spec* (insert! spec)) + (define impl* (insert! impl)) + (hash-ref! id->spec + spec* + (lambda () + (define spec* (normalize-spec spec)) ; preserved spec for extraction + (define type (representation-type (repr-of impl ctx))) ; track type of spec + (cons spec* type))) + (list '$approx spec* impl*)] + [(list op args ...) (cons op (map insert! args))])) + ; always insert the node if it is a root since + ; the e-graph tracks which nodes are roots + (cond + [root? (insert-node! node #t)] + [else (hash-ref! expr->id node (lambda () (insert-node! 
node #f)))])) + + (for/list ([expr (in-list exprs)]) + (insert! expr #t))) ;; runs rules on an egraph (optional iteration limit) (define (egraph-run egraph-data ffi-rules node-limit iter-limit scheduler const-folding?) @@ -114,29 +166,22 @@ ['backoff #f] ['simple #t] [_ (error 'egraph-run "unknown scheduler: `~a`" scheduler)])) - (define-values (iterations length ptr) - (egraph_run (egraph-data-egraph-pointer egraph-data) - ffi-rules - iter_limit - node_limit - simple_scheduler? - const-folding?)) - (define iteration-data (convert-iteration-data iterations length)) - (destroy_egraphiters ptr) - iteration-data) + (egraph_run (egraph-data-egraph-pointer egraph-data) + ffi-rules + iter_limit + node_limit + simple_scheduler? + const-folding?)) (define (egraph-get-simplest egraph-data node-id iteration ctx) - (define ptr (egraph_get_simplest (egraph-data-egraph-pointer egraph-data) node-id iteration)) - (define str (cast ptr _pointer _string/utf-8)) - (destroy_string ptr) - (egg-expr->expr str egraph-data (context-repr ctx))) + (define expr (egraph_get_simplest (egraph-data-egraph-pointer egraph-data) node-id iteration)) + (egg-expr->expr expr egraph-data (context-repr ctx))) (define (egraph-get-variants egraph-data node-id orig-expr ctx) - (define expr-str (~a (expr->egg-expr orig-expr egraph-data ctx))) - (define ptr (egraph_get_variants (egraph-data-egraph-pointer egraph-data) node-id expr-str)) - (define str (cast ptr _pointer _string/utf-8)) - (destroy_string ptr) - (egg-exprs->exprs str egraph-data (context-repr ctx))) + (define egg-expr (expr->egg-expr orig-expr egraph-data ctx)) + (define exprs (egraph_get_variants (egraph-data-egraph-pointer egraph-data) node-id egg-expr)) + (for/list ([expr (in-list exprs)]) + (egg-expr->expr expr egraph-data (context-repr ctx)))) (define (egraph-is-unsound-detected egraph-data) (egraph_is_unsound_detected (egraph-data-egraph-pointer egraph-data))) @@ -155,45 +200,44 @@ [3 "unsound"] [sr (error 'egraph-stop-reason "unexpected 
stop reason ~a" sr)])) -;; An egraph is just a S-expr of the form -;; -;; egraph ::= ( ...) -;; eclass ::= ( ..+) -;; enode ::= ( ...) -;; -(define (egraph-serialize egraph-data) - (egraph_serialize (egraph-data-egraph-pointer egraph-data))) +;; Extracts the eclasses of an e-graph as a u32vector +(define (egraph-eclasses egraph-data) + (egraph_get_eclasses (egraph-data-egraph-pointer egraph-data))) + +;; Extracts the nodes of an e-class as a vector +;; where each enode is either a symbol, number, or list +(define (egraph-get-eclass egraph-data id) + (define ptr (egraph-data-egraph-pointer egraph-data)) + (define egg->herbie (egraph-data-egg->herbie-dict egraph-data)) + (define eclass (egraph_get_eclass ptr id)) + ; need to fix up any constant operators + (for ([enode (in-vector eclass)] + [i (in-naturals)]) + (when (and (symbol? enode) (not (hash-has-key? egg->herbie enode))) + (vector-set! eclass i (cons enode (make-u32vector 0))))) + eclass) (define (egraph-find egraph-data id) (egraph_find (egraph-data-egraph-pointer egraph-data) id)) (define (egraph-expr-equal? 
egraph-data expr goal ctx) - (define id1 (egraph-add-expr egraph-data expr ctx)) - (define id2 (egraph-add-expr egraph-data goal ctx)) - (= (egraph-find egraph-data id1) (egraph-find egraph-data id2))) + (match-define (list id1 id2) (egraph-add-exprs egraph-data (list expr goal) ctx)) + (= id1 id2)) ;; returns a flattened list of terms or #f if it failed to expand the proof due to budget (define (egraph-get-proof egraph-data expr goal ctx) - (define egg-expr (~a (expr->egg-expr expr egraph-data ctx))) - (define egg-goal (~a (expr->egg-expr goal egraph-data ctx))) - (define pointer (egraph_get_proof (egraph-data-egraph-pointer egraph-data) egg-expr egg-goal)) - (define res (cast pointer _pointer _string/utf-8)) - (destroy_string pointer) + (define egg-expr (expr->egg-expr expr egraph-data ctx)) + (define egg-goal (expr->egg-expr goal egraph-data ctx)) + (define str (egraph_get_proof (egraph-data-egraph-pointer egraph-data) egg-expr egg-goal)) (cond - [(< (string-length res) 10000) - (define converted (egg-exprs->exprs res egraph-data (context-repr ctx))) + [(<= (string-length str) (*proof-max-string-length*)) + (define converted + (for/list ([expr (in-port read (open-input-string str))]) + (egg-expr->expr expr egraph-data (context-repr ctx)))) (define expanded (expand-proof converted (box (*proof-max-length*)))) (if (member #f expanded) #f expanded)] [else #f])) -;; Racket representation of per-iteration runner data -(struct iteration-data (num-nodes num-eclasses time)) - -(define (convert-iteration-data egraphiters size) - (for/list ([i (in-range size)]) - (define ptr (ptr-add egraphiters i _EGraphIter)) - (iteration-data (EGraphIter-numnodes ptr) (EGraphIter-numeclasses ptr) (EGraphIter-time ptr)))) - ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; eggIR ;; @@ -273,15 +317,10 @@ [(list (? impl-exists? impl) args ...) (cons impl (map loop args (impl-info impl 'itype)))] [(list op args ...) 
(cons op (map loop args (operator-info op 'itype)))]))) -;; Parses a string from egg into a list of S-exprs. -(define (egg-exprs->exprs s egraph-data type) - (define egg->herbie (egraph-data-egg->herbie-dict egraph-data)) - (for/list ([egg-expr (in-port read (open-input-string s))]) - (egg-parsed->expr (flatten-let egg-expr) egg->herbie type))) - ;; Parses a string from egg into a single S-expr. -(define (egg-expr->expr s egraph-data type) - (first (egg-exprs->exprs s egraph-data type))) +(define (egg-expr->expr egg-expr egraph-data type) + (define egg->herbie (egraph-data-egg->herbie-dict egraph-data)) + (egg-parsed->expr (flatten-let egg-expr) egg->herbie type)) (module+ test (define repr (get-representation 'binary64)) @@ -291,19 +330,19 @@ (*context* (context-extend (*context*) 'z repr)) (define test-exprs - (list (cons '(+.f64 y x) (~a '(+.f64 $h0 $h1))) - (cons '(+.f64 x y) (~a '(+.f64 $h1 $h0))) - (cons '(-.f64 #s(literal 2 binary64) (+.f64 x y)) (~a '(-.f64 2 (+.f64 $h1 $h0)))) + (list (cons '(+.f64 y x) '(+.f64 $h0 $h1)) + (cons '(+.f64 x y) '(+.f64 $h1 $h0)) + (cons '(-.f64 #s(literal 2 binary64) (+.f64 x y)) '(-.f64 2 (+.f64 $h1 $h0))) (cons '(-.f64 z (+.f64 (+.f64 y #s(literal 2 binary64)) x)) - (~a '(-.f64 $h2 (+.f64 (+.f64 $h0 2) $h1)))) - (cons '(*.f64 x y) (~a '(*.f64 $h1 $h0))) - (cons '(+.f64 (*.f64 x y) #s(literal 2 binary64)) (~a '(+.f64 (*.f64 $h1 $h0) 2))) - (cons '(cos.f32 (PI.f32)) (~a '(cos.f32 (PI.f32)))) - (cons '(if (TRUE) x y) (~a '(if (TRUE) $h1 $h0))))) + '(-.f64 $h2 (+.f64 (+.f64 $h0 2) $h1))) + (cons '(*.f64 x y) '(*.f64 $h1 $h0)) + (cons '(+.f64 (*.f64 x y) #s(literal 2 binary64)) '(+.f64 (*.f64 $h1 $h0) 2)) + (cons '(cos.f32 (PI.f32)) '(cos.f32 (PI.f32))) + (cons '(if (TRUE) x y) '(if (TRUE) $h1 $h0)))) (let ([egg-graph (make-egraph)]) (for ([(in expected-out) (in-dict test-exprs)]) - (define out (~a (expr->egg-expr in egg-graph (*context*)))) + (define out (expr->egg-expr in egg-graph (*context*))) (define computed-in 
(egg-expr->expr out egg-graph (context-repr (*context*)))) (check-equal? out expected-out) (check-equal? computed-in in))) @@ -332,7 +371,7 @@ (let ([egg-graph (make-egraph)]) (for ([expr extended-expr-list]) (define egg-expr (expr->egg-expr expr egg-graph (*context*))) - (check-equal? (egg-expr->expr (~a egg-expr) egg-graph (context-repr (*context*))) expr)))) + (check-equal? (egg-expr->expr egg-expr egg-graph (context-repr (*context*))) expr)))) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Proofs @@ -443,12 +482,11 @@ ; non-expansive rule [else (list (rule->egg-rule ru))])) -;; egg rule cache -(define/reset *egg-rule-cache* (make-hash)) +;; egg rule cache: rule -> (cons/c rule FFI-rule) +(define/reset *egg-rule-cache* (make-hasheq)) -;; Cache mapping name to its canonical rule name -;; See `*egg-rules*` for details -(define/reset *canon-names* (make-hash)) +;; Cache mapping (expanded) rule name to its canonical rule name +(define/reset *canon-names* (make-hasheq)) ;; Tries to look up the canonical name of a rule using the cache. ;; Obviously dangerous if the cache is invalid. @@ -462,93 +500,332 @@ (for ([rule (in-list rules)]) (define egg&ffi-rules (hash-ref! (*egg-rule-cache*) - (cons (*active-platform*) rule) + rule (lambda () (for/list ([egg-rule (in-list (rule->egg-rules rule))]) (define name (rule-name egg-rule)) + (define ffi-rule + (make-ffi-rule name (rule-input egg-rule) (rule-output egg-rule))) (hash-set! (*canon-names*) name (rule-name rule)) - (cons egg-rule (make-ffi-rule egg-rule)))))) + (cons egg-rule ffi-rule))))) (for-each sow egg&ffi-rules)))) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Racket egraph ;; -;; Racket representation of an egraph; just a hashcons data structure -;; We can think of this as a read-only copy of an egraph for things like -;; platform-aware extraction and possibly ground-truth evaluation. -;; This is regraph reborn! 
+;; Racket representation of a typed egraph. +;; Given an e-graph from egg-herbie, we can split every e-class +;; by type ensuring that every term in an e-class has the same output type. +;; This trick makes extraction easier. ;; - eclasses: vector of enodes -;; - canon: map from egraph id to vector index -;; - has-leaf?: map from vector index to if the eclass contains a leaf -;; - constants: map from vector index to if the eclass contains a constant -;; - parents: parent e-classes of each e-class +;; - types: vector-map from e-class to type/representation +;; - leaf?: vector-map from e-class to boolean indicating if it contains a leaf node +;; - constants: vector-map from e-class to a number or #f +;; - specs: vector-map from e-class to an approx spec or #f +;; - parents: vector-map from e-class to its parent e-classes (as a vector) +;; - canon: map from (Rust) e-class, type to (Racket) e-class ;; - egg->herbie: data to translate egg IR to herbie IR -;; - id->spec: map from id to an approx spec or #f -(struct regraph (eclasses canon has-leaf? constants parents egg->herbie id->spec)) - -;; Constructs a Racket egraph from an S-expr representation. -(define (sexpr->regraph egraph egg->herbie id->spec) - ; total number of e-classes - (define n (length egraph)) - ; canonicalize all e-class ids to [0, n) - (define canon (make-hash)) - (define (canon-id id) - (hash-ref! canon id (lambda () (hash-count canon)))) - ; iterate through eclasses and fill data - (define eclasses (make-vector n #f)) - (define has-leaf? (make-vector n #f)) - (define constants (make-vector n #f)) - (define parents (make-vector n '())) - (for ([eclass (in-list egraph)]) - (match-define (cons egg-id egg-nodes) eclass) - (define id (canon-id egg-id)) - (define nodes - (for/vector #:length (length egg-nodes) - ([egg-node (in-list egg-nodes)]) - (match egg-node - [(? number?) ; number - (vector-set! has-leaf? id #t) - (vector-set! constants id egg-node) - egg-node] - [(? symbol?) 
; variable or constant - (vector-set! has-leaf? id #t) - (if (hash-has-key? egg->herbie egg-node) - egg-node ; variable - (list egg-node))] ; constant - [(list op child-ids ...) ; $approx / application - (cond - [(null? child-ids) ; application is a constant function - (vector-set! has-leaf? id #t) - (list op)] - [else - (define child-ids* (map canon-id child-ids)) - (for ([child-id (in-list child-ids*)]) ; update parent-child relation - (vector-set! parents child-id (cons id (vector-ref parents child-id)))) - (cons op child-ids*)])] - [_ (error 'sexpr->regraph "malformed enode: ~a" egg-node)]))) - (vector-set! eclasses id nodes)) - - ; dedup parent-child relation and convert to vector +(struct regraph (eclasses types leaf? constants specs parents canon egg->herbie)) + +;; Returns all representations (and their types) in the current platform. +(define (all-reprs/types [pform (*active-platform*)]) + (remove-duplicates (append-map (lambda (repr) (list repr (representation-type repr))) + (platform-reprs pform)))) + +;; Returns the type(s) of an enode so it can be placed in the proper e-class. +;; Typing rules: +;; - numbers: every real representation (or real type) +;; - variables: lookup in the `egg->herbie` renaming dictionary +;; - `if`: type is every representation (or type) [can prune incorrect ones] +;; - `approx`: every real representation [can prune incorrect ones] +;; - ops/impls: its output type/representation +;; NOTE: we can constrain "every" type by using the platform. +(define (enode-type enode egg->herbie) + (match enode + [(? number?) (cons 'real (platform-reprs (*active-platform*)))] ; number + [(? symbol?) ; variable + (match-define (cons _ repr) (hash-ref egg->herbie enode)) + (list repr (representation-type repr))] + [(cons f _) ; application + (cond + [(eq? f '$approx) (platform-reprs (*active-platform*))] + [(eq? f 'if) (all-reprs/types)] + [(impl-exists? 
f) (list (impl-info f 'otype))] + [else (list (operator-info f 'otype))])])) + +;; Rebuilds an e-node using typed e-classes +(define (rebuild-enode enode type lookup) + (match enode + [(? number?) enode] ; number + [(? symbol?) enode] ; variable + [(cons f ids) ; application + (cond + [(eq? f '$approx) ; approx node + (define spec (u32vector-ref ids 0)) + (define impl (u32vector-ref ids 1)) + (list '$approx (lookup spec (representation-type type)) (lookup impl type))] + [(eq? f 'if) ; if expression + (define cond (u32vector-ref ids 0)) + (define ift (u32vector-ref ids 1)) + (define iff (u32vector-ref ids 2)) + (define cond-type (if (representation? type) (get-representation 'bool) 'bool)) + (list 'if (lookup cond cond-type) (lookup ift type) (lookup iff type))] + [else + (define itypes (if (impl-exists? f) (impl-info f 'itype) (operator-info f 'itype))) + ; unsafe since we don't check that |itypes| = |ids| + ; optimize for common cases to avoid extra allocations + (cons + f + (match itypes + [(list) '()] + [(list t1) (list (lookup (u32vector-ref ids 0) t1))] + [(list t1 t2) (list (lookup (u32vector-ref ids 0) t1) (lookup (u32vector-ref ids 1) t2))] + [(list t1 t2 t3) + (list (lookup (u32vector-ref ids 0) t1) + (lookup (u32vector-ref ids 1) t2) + (lookup (u32vector-ref ids 2) t3))] + [_ (map lookup (u32vector->list ids) itypes)]))])])) + +;; Splits untyped eclasses into typed eclasses. +;; Nodes are duplicated across their possible types. +(define (split-untyped-eclasses egraph-data egg->herbie) + (define eclass-ids (egraph-eclasses egraph-data)) + (define max-id + (for/fold ([current-max 0]) ([egg-id (in-u32vector eclass-ids)]) + (max current-max egg-id))) + (define egg-id->idx (make-u32vector (+ max-id 1))) + (for ([egg-id (in-u32vector eclass-ids)] + [idx (in-naturals)]) + (u32vector-set! egg-id->idx egg-id idx)) + + (define types (all-reprs/types)) + (define type->idx (make-hasheq)) + (for ([type (in-list types)] + [idx (in-naturals)]) + (hash-set! 
type->idx type idx)) + (define num-types (hash-count type->idx)) + + ; maps (idx, type) to type eclass id + (define (idx+type->id idx type) + (+ (* idx num-types) (hash-ref type->idx type))) + + ; maps (untyped eclass id, type) to typed eclass id + (define (lookup-id eid type) + (idx+type->id (u32vector-ref egg-id->idx eid) type)) + + ; allocate enough eclasses for every (egg-id, type) combination + (define n (* (u32vector-length eclass-ids) num-types)) + (define id->eclass (make-vector n '())) + (define id->parents (make-vector n '())) + (define id->leaf? (make-vector n #f)) + + ; for each eclass, extract the enodes + ; ::= + ; | + ; | ( . ) + ; NOTE: nodes in typed eclasses are reversed relative + ; to their position in untyped eclasses + (for ([eid (in-u32vector eclass-ids)] + [idx (in-naturals)]) + (define enodes (egraph-get-eclass egraph-data eid)) + (for ([enode (in-vector enodes)]) + ; get all possible types for the enode + ; lookup its correct eclass and add the rebuilt node + (define types (enode-type enode egg->herbie)) + (for ([type (in-list types)]) + (define id (idx+type->id idx type)) + (define enode* (rebuild-enode enode type lookup-id)) + (vector-set! id->eclass id (cons enode* (vector-ref id->eclass id))) + (match enode* + [(list _ ids ...) + (if (null? ids) + (vector-set! id->leaf? id #t) + (for ([child-id (in-list ids)]) + (vector-set! id->parents child-id (cons id (vector-ref id->parents child-id)))))] + [(? symbol?) (vector-set! id->leaf? id #t)] + [(? number?) (vector-set! id->leaf? id #t)])))) + + ; dedup `id->parents` values (for ([id (in-range n)]) - (vector-set! parents id (list->vector (remove-duplicates (vector-ref parents id))))) + (vector-set! id->parents id (list->vector (remove-duplicates (vector-ref id->parents id))))) + (values id->eclass id->parents id->leaf? eclass-ids egg-id->idx type->idx)) + +;; TODO: reachable from roots? +;; Prunes e-nodes that are not well-typed. 
+;; An e-class is well-typed if it has one well-typed node +;; A node is well-typed if all of its child e-classes are well-typed. +(define (prune-ill-typed! id->eclass id->parents id->leaf?) + (define n (vector-length id->eclass)) + + ;; is the e-class well-typed? + (define typed?-vec (make-vector n #f)) + (define (eclass-well-typed? id) + (vector-ref typed?-vec id)) + + ;; is the e-node well-typed? + (define (enode-typed? enode) + (or (number? enode) (symbol? enode) (and (list? enode) (andmap eclass-well-typed? (cdr enode))))) + + (define (check-typed! dirty?-vec) + (define dirty? #f) + (define dirty?-vec* (make-vector n #f)) + (for ([id (in-range n)] + #:when (vector-ref dirty?-vec id)) + (unless (vector-ref typed?-vec id) + (when (ormap enode-typed? (vector-ref id->eclass id)) + (vector-set! typed?-vec id #t) + (define parent-ids (vector-ref id->parents id)) + (unless (vector-empty? parent-ids) + (set! dirty? #t) + (for ([parent-id (in-vector parent-ids)]) + (vector-set! dirty?-vec* parent-id #t)))))) + (when dirty? + (check-typed! dirty?-vec*))) - ; convert id->spec to a vector-map - (define id->spec* (make-vector n #f)) - (for ([(id spec) (in-hash id->spec)]) - (vector-set! id->spec* (hash-ref canon id) spec)) + ; mark all well-typed e-classes and prune nodes that are not well-typed + (check-typed! (vector-copy id->leaf?)) + (for ([id (in-range n)]) + (define eclass (vector-ref id->eclass id)) + (vector-set! id->eclass id (filter enode-typed? eclass))) - ; collect with wrapper - (regraph eclasses canon has-leaf? constants parents egg->herbie id->spec*)) + ; sanity check: every child id points to a non-empty e-class + (for ([id (in-range n)]) + (define eclass (vector-ref id->eclass id)) + (for ([enode (in-list eclass)]) + (match enode + [(list _ ids ...) + (for ([id (in-list ids)]) + (when (null? (vector-ref id->eclass id)) + (error 'prune-ill-typed! 
+ "eclass ~a is empty, eclasses ~a" + id + (for/vector #:length n + ([id (in-range n)]) + (list id (vector-ref id->eclass id))))))] + [_ (void)])))) + +;; Rebuilds eclasses and associated data after pruning. +(define (rebuild-eclasses id->eclass eclass-ids egg-id->idx type->idx) + (define n (vector-length id->eclass)) + (define remap (make-vector n #f)) + + ; build the id map + (define n* 0) + (for ([id (in-range n)]) + (define eclass (vector-ref id->eclass id)) + (unless (null? eclass) + (vector-set! remap id n*) + (set! n* (add1 n*)))) + + ; invert `type->idx` map + (define idx->type (make-hash)) + (define num-types (hash-count type->idx)) + (for ([(type idx) (in-hash type->idx)]) + (hash-set! idx->type idx type)) + + ; rebuild eclass and type vectors + ; transform each eclass from a list to a vector + (define eclasses (make-vector n* #f)) + (define types (make-vector n* #f)) + (for ([id (in-range n)]) + (define id* (vector-ref remap id)) + (when id* + (define eclass (vector-ref id->eclass id)) + (vector-set! eclasses + id* + (for/vector #:length (length eclass) + ([enode (in-list eclass)]) + (match enode + [(? number?) enode] + [(? symbol?) enode] + [(list op ids ...) + (define ids* (map (lambda (id) (vector-ref remap id)) ids)) + (cons op ids*)]))) + (vector-set! types id* (hash-ref idx->type (modulo id num-types))))) + + ; build the canonical id map + (define egg-id->id (make-hash)) + (for ([eid (in-u32vector eclass-ids)]) + (define idx (u32vector-ref egg-id->idx eid)) + (define id0 (* idx num-types)) + (for ([id (in-range id0 (+ id0 num-types))]) + (define id* (vector-ref remap id)) + (when id* + (define type (vector-ref types id*)) + (hash-set! egg-id->id (cons eid type) id*)))) + + (values eclasses types egg-id->id)) + +;; Splits untyped eclasses into typed eclasses, +;; keeping only the subset of enodes that are well-typed. 
+(define (make-typed-eclasses egraph-data egg->herbie) + ;; Step 1: split Rust-eclasses by type + (define-values (id->eclass id->parents id->leaf? eclass-ids egg-id->idx type->idx) + (split-untyped-eclasses egraph-data egg->herbie)) + + ;; Step 2: keep well-typed e-nodes + ;; An e-class is well-typed if it has one well-typed node + ;; A node is well-typed if all of its child e-classes are well-typed. + (prune-ill-typed! id->eclass id->parents id->leaf?) + + ;; Step 3: remap e-classes + ;; Any empty e-classes must be removed, so we re-map every id + (rebuild-eclasses id->eclass eclass-ids egg-id->idx type->idx)) + +;; Analyzes eclasses for their properties. +;; The result are vector-maps from e-class ids to data. +;; - parents: parent e-classes (as a vector) +;; - leaf?: does the e-class contain a leaf node +;; - constants: the e-class constant (if one exists) +(define (analyze-eclasses eclasses) + (define n (vector-length eclasses)) + (define parents (make-vector n '())) + (define leaf? (make-vector n '#f)) + (define constants (make-vector n #f)) + (for ([id (in-range n)]) + (define eclass (vector-ref eclasses id)) + (for ([enode eclass]) ; might be a list or vector + (match enode + [(? number? n) + (vector-set! leaf? id #t) + (vector-set! constants id n)] + [(? symbol?) (vector-set! leaf? id #t)] + [(list _ ids ...) + (when (null? ids) + (vector-set! leaf? id #t)) + (for ([child-id (in-list ids)]) + (vector-set! parents child-id (cons id (vector-ref parents child-id))))]))) + + ; parent map: remove duplicates, convert lists to vectors + (for ([id (in-range n)]) + (define ids (remove-duplicates (vector-ref parents id))) + (vector-set! parents id (list->vector ids))) + + (values parents leaf? constants)) ;; Constructs a Racket egraph from an S-expr representation of ;; an egraph and data to translate egg IR to herbie IR. 
(define (make-regraph egraph-data) - (define egraph-str (egraph-serialize egraph-data)) - (sexpr->regraph (read (open-input-string egraph-str)) - (egraph-data-egg->herbie-dict egraph-data) - (for/hash ([(id spec) (in-hash (egraph-data-id->spec egraph-data))]) - (values (egraph-find egraph-data id) spec)))) + (define egg->herbie (egraph-data-egg->herbie-dict egraph-data)) + (define id->spec (egraph-data-id->spec egraph-data)) + + ;; split the e-classes by type + (define-values (eclasses types canon) (make-typed-eclasses egraph-data egg->herbie)) + (define n (vector-length eclasses)) + + ;; analyze each eclass + (define-values (parents leaf? constants) (analyze-eclasses eclasses)) + + ;; rebuild id->spec map for typed e-classes + (define specs (make-vector n #f)) + (for ([(id spec&repr) (in-hash id->spec)]) + (match-define (cons spec repr) spec&repr) + (define id* (hash-ref canon (cons (egraph-find egraph-data id) repr))) + (vector-set! specs id* spec)) + + ; construct the `regraph` instance + (regraph eclasses types leaf? constants specs parents canon egg->herbie)) ;; Egraph node has children. ;; Nullary operators have no children! @@ -562,14 +839,14 @@ ;; the eclass's analysis. (define (regraph-analyze regraph eclass-proc #:analysis [analysis #f]) (define eclasses (regraph-eclasses regraph)) - (define has-leaf? (regraph-has-leaf? regraph)) + (define leaf? (regraph-leaf? regraph)) (define parents (regraph-parents regraph)) (define n (vector-length eclasses)) ; set analysis if not provided (unless analysis (set! analysis (make-vector n #f))) - (define dirty?-vec (vector-copy has-leaf?)) ; visit eclass on next pass? + (define dirty?-vec (vector-copy leaf?)) ; visit eclass on next pass? 
(define changed?-vec (make-vector n #f)) ; eclass was changed last iteration ; run the analysis @@ -596,133 +873,19 @@ ; Invariant: all eclasses have an analysis (for ([id (in-range n)]) (unless (vector-ref analysis id) + (define types (regraph-types regraph)) (error 'regraph-analyze "analysis not run on all eclasses: ~a ~a" eclass-proc (for/vector #:length n ([id (in-range n)]) + (define type (vector-ref types id)) (define eclass (vector-ref eclasses id)) (define eclass-analysis (vector-ref analysis id)) - (list id eclass eclass-analysis))))) + (list id type eclass eclass-analysis))))) analysis) -;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -;; Regraph untyped extraction -;; -;; Untyped extraction is ideal for extracting real expressions from an egraph. -;; This style of extractor associates to each eclass a best (cost, node) pair. -;; The extractor procedure simply takes an eclass id. -;; -;; Untyped cost functions take: -;; - the regraph we are extracting from -;; - a mutable cache (to possibly stash per-node data) -;; - the node we are computing cost for -;; - unary procedure to get the eclass of an id -;; - -;; The untyped extraction algorithm. -(define ((untyped-egg-extractor cost-proc) regraph) - (define eclasses (regraph-eclasses regraph)) - (define n (vector-length eclasses)) - - ; costs: mapping id to (cost, best node) - (define costs (make-vector n #f)) - - ; Checks if eclass has a cost - (define (eclass-has-cost? id) - (vector-ref costs id)) - - ; Unsafe lookup of eclass cost - (define (unsafe-eclass-cost id) - (car (vector-ref costs id))) - - ; Computes the current cost of a node if its children have a cost - ; Cost function has access to a mutable value through `cache` - (define cache (box #f)) - (define (node-cost node changed?-vec) - (if (node-has-children? node) - (let ([child-ids (cdr node)]) ; (op child ...) - ; compute the cost if at least one child eclass has a new analysis - ; ... 
and an analysis exists for all the child eclasses - (and (ormap (lambda (id) (vector-ref changed?-vec id)) child-ids) - (andmap eclass-has-cost? child-ids) - (cost-proc regraph cache node unsafe-eclass-cost))) - (cost-proc regraph cache node unsafe-eclass-cost))) - - ; Updates the cost of the current eclass. - ; Returns #t if the cost of the current eclass has improved. - (define (eclass-set-cost! _ changed?-vec iter eclass id) - ; Optimization: we only need to update node cost as needed. - ; (i) terminals, nullary operators: only compute once - ; (ii) non-nullary operators: compute when any of its child eclasses - ; have their analysis updated - (define (node-requires-update? node) - (if (node-has-children? node) - (ormap (lambda (id) (vector-ref changed?-vec id)) (cdr node)) - (= iter 0))) - - (define new-cost - (for/fold ([best #f]) ([node (in-vector eclass)]) - (cond - [(node-requires-update? node) - (define cost (node-cost node changed?-vec)) - (match* (best cost) - [(_ #f) best] - [(#f _) (cons cost node)] - [(_ _) - #:when (< cost (car best)) - (cons cost node)] - [(_ _) best])] - [else best]))) - - (cond - [new-cost - (define prev-cost (vector-ref costs id)) - (cond - [(or (not prev-cost) ; first time - (< (car new-cost) (car prev-cost))) - (vector-set! costs id new-cost) - #t] - [else #f])] - [else #f])) - - ; run the analysis - (set! costs (regraph-analyze regraph eclass-set-cost! #:analysis costs)) - - ; reconstructs the best expression at a node - (define id->spec (regraph-id->spec regraph)) - (define (build-expr id) - (match (cdr (vector-ref costs id)) - [(? number? n) n] ; number - [(? symbol? s) s] ; variable - [(list '$approx spec impl) ; approx - (match (vector-ref id->spec spec) - [#f (error 'build-expr "no initial approx node in eclass ~a" id)] - [spec-e (list '$approx spec-e (build-expr impl))])] - [(list op ids ...) 
(cons op (map build-expr ids))] ; application - [e (error 'untyped-extraction-proc "unexpected node" e)])) - - ; the actual extraction procedure - (lambda (id _) - (define cost (car (vector-ref costs id))) - (cons cost (build-expr id)))) - -;; Is fractional with odd denominator. -(define (fraction-with-odd-denominator? frac) - (and (rational? frac) (let ([denom (denominator frac)]) (and (> denom 1) (odd? denom))))) - -;; The default per-node cost function -(define (default-untyped-egg-cost-proc regraph cache node rec) - (define constants (regraph-constants regraph)) - (match node - [(? number?) 1] - [(? symbol?) 1] - [(list 'pow _ b e) ; special case for fractional pow - (define n (vector-ref constants e)) - (if (and n (fraction-with-odd-denominator? n)) +inf.0 (+ 1 (rec b) (rec e)))] - [(list _ args ...) (apply + 1 (map rec args))])) - ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Regraph typed extraction ;; @@ -741,144 +904,6 @@ ;; - an output type ;; - a default failure value ;; -;; Types are represented as one of the following: -;; - (or ..+) union of types -;; - : representation type -;; - : type-name type -;; - #f: inconclusive (possible result of `if` type during analysis) - -;; Types are equal? -(define (type/equal? ty1 ty2) - (match* (ty1 ty2) - [((list 'or tys1 ...) (list 'or tys2 ...)) - ; set equality: { tys1 ... } = { tys2 ... } - (and (andmap (lambda (ty) (member ty tys1)) tys2) (andmap (lambda (ty) (member ty tys2)) tys1))] - [(_ _) (equal? ty1 ty2)])) - -;; Subtyping relation, e.g., `t1 :< t2` -(define (subtype? ty1 ty2) - (match* (ty1 ty2) - [(_ #f) #f] - [((list 'or tys1 ...) (list 'or tys2 ...)) (andmap (lambda (ty) (member ty tys2)) tys1)] - [(_ (list 'or tys2 ...)) (member ty1 tys2)] - [(_ _) (equal? ty1 ty2)])) - -;; Applying the union operation over types -(define (type/union ty1 . tys) - (for/fold ([ty1 ty1]) ([ty2 (in-list tys)]) - (match* (ty1 ty2) - [(#f _) ty2] - [(_ #f) ty1] - [((list 'or tys1 ...) 
(list 'or tys2 ...)) (cons 'or (remove-duplicates (append tys1 tys2)))] - [((list 'or tys1 ...) _) (if (member ty2 tys1) ty1 (list* 'or ty2 tys1))] - [(_ (list 'or tys2 ...)) (if (member ty1 tys2) ty2 (list* 'or ty1 tys2))] - [(_ _) (if (equal? ty1 ty2) ty1 (list 'or ty1 ty2))]))) - -;; Applying the intersection operation over types -(define (type/intersect ty1 ty2) - (match* (ty1 ty2) - [((list 'or tys1 ...) (list 'or tys2 ...)) - (match (for/fold ([tys '()]) - ([ty (in-list tys1)] - #:when (member ty tys2)) - (cons ty tys)) - ['() #f] - [(list ty) ty] - [(list tys ...) (cons 'or tys)])] - [((list 'or tys1 ...) _) (and (member ty2 tys1) ty2)] - [(_ (list 'or tys2 ...)) (and (member ty1 tys2) ty1)] - [(_ _) (and (equal? ty1 ty2) ty1)])) - -;; Computes the set of extractable types for each eclass. -(define (regraph-eclass-types egraph) - (define egg->herbie (regraph-egg->herbie egraph)) - (define reprs (platform-reprs (*active-platform*))) - - (define (node->type analysis node) - (match node - [(? number?) - ; NOTE: a number by itself is untyped, but we can constrain - ; the type of the number by the platform - (for/fold ([ty #f]) - ([repr (in-list reprs)] - #:when (eq? (representation-type repr) 'real)) - (type/union ty repr (representation-type repr)))] - [(? symbol?) - (define repr (cdr (hash-ref egg->herbie node))) - (type/union repr (representation-type repr))] - [(list '$approx _ impl) (vector-ref analysis impl)] - [(list 'if _ ift iff) - (define ift-types (vector-ref analysis ift)) - (define iff-types (vector-ref analysis iff)) - (and ift-types iff-types (type/intersect ift-types iff-types))] - [(list (? impl-exists? impl) ids ...) - (and (andmap subtype? (impl-info impl 'itype) (map (curry vector-ref analysis) ids)) - (impl-info impl 'otype))] - [(list op ids ...) - (and (andmap subtype? (operator-info op 'itype) (map (curry vector-ref analysis) ids)) - (operator-info op 'otype))])) - - ;; Type analysis - (define (eclass-set-type! 
analysis changed?-vec iter eclass id) - (define ty (vector-ref analysis id)) - (define ty* - (if (= iter 0) - ; first iteration: only run analysis on leaves - (for/fold ([ty ty]) - ([node (in-vector eclass)] - #:unless (node-has-children? node)) - (type/union ty (node->type analysis node))) - ; other iterations: run only on non-leaves with updated children - (for/fold ([ty ty]) - ([node (in-vector eclass)] - #:when (and (node-has-children? node) - (ormap (lambda (id) (vector-ref changed?-vec id)) (cdr node)) - (andmap (lambda (id) (vector-ref analysis id)) (cdr node)))) - (type/union ty (node->type analysis node))))) - (vector-set! analysis id ty*) - (not (type/equal? ty ty*))) - - (regraph-analyze egraph eclass-set-type!)) - -;; Computes the return type (or `#f`) of each node. -;; If a node is not well-typed, the type is `#f`. -(define (regraph-node-types regraph) - (define eclasses (regraph-eclasses regraph)) - (define egg->herbie (regraph-egg->herbie regraph)) - (define n (vector-length eclasses)) - - ; Compute the extractable types - (define eclass-types (regraph-eclass-types regraph)) - (define reprs (platform-reprs (*active-platform*))) - (define (node->type node) - (match node - [(? number?) - ; NOTE: a number by itself is untyped, but we can constrain - ; the type of the number by the platform - (for/fold ([ty #f]) - ([repr (in-list reprs)] - #:when (eq? (representation-type repr) 'real)) - (type/union ty repr (representation-type repr)))] - [(? symbol?) - (define repr (cdr (hash-ref egg->herbie node))) - (type/union repr (representation-type repr))] - [(list '$approx _ impl) (vector-ref eclass-types impl)] - [(list 'if _ ift iff) - (define ift-types (vector-ref eclass-types ift)) - (define iff-types (vector-ref eclass-types iff)) - (and ift-types iff-types (type/intersect ift-types iff-types))] - [(list (? impl-exists? op) _ ...) (impl-info op 'otype)] - [(list op _ ...) 
(operator-info op 'otype)])) - - ; Construct the result - (for/vector #:length n - ([id (in-range n)]) - (define eclass (vector-ref eclasses id)) - (define ty (vector-ref eclass-types id)) - (for/vector #:length (vector-length eclass) - ([node (in-vector eclass)]) - (define node-ty (node->type node)) - (and (subtype? node-ty ty) node-ty)))) ;; The typed extraction algorithm. ;; Extraction is partial, that is, the result of the extraction @@ -886,107 +911,46 @@ ;; at a particular id with a particular output type. (define ((typed-egg-extractor cost-proc) regraph) (define eclasses (regraph-eclasses regraph)) + (define types (regraph-types regraph)) (define n (vector-length eclasses)) - ; compute the extractable types of each well-typed node - (define eclass-types (regraph-node-types regraph)) - - ; costs: mapping eclass id to a table from type to (cost, node) pair - (define costs - (for/vector #:length n - ([id (in-range n)]) - (define table (make-hash)) - (define node-types (vector-ref eclass-types id)) - (for ([ty (in-vector node-types)]) - (match ty - [(list 'or tys ...) - (for ([ty (in-list tys)] - #:when (representation? ty)) - (hash-set! table ty #f))] - [(? representation?) (hash-set! table ty #f)] - [(? type-name?) (void)] - [#f (void)])) - table)) - - ; Checks if eclass has a cost - (define (eclass-has-cost? id type) - (define eclass-costs (vector-ref costs id)) - (hash-ref eclass-costs type #f)) - - ; Unsafe lookup of eclass cost - (define (unsafe-eclass-cost id type failure) - (define eclass-costs (vector-ref costs id)) - (cond - [(hash-ref eclass-costs type #f) - => ; (cost . node) or #f - (lambda (cost) (and cost (car cost)))] - [else failure])) - - ; Unsafe lookup of best eclass node. - ; Returns `#f` if no best eclass exists. - (define (unsafe-best-node id type) - (define eclass-costs (vector-ref costs id)) - (cond - [(hash-ref eclass-costs type #f) - => ; (cost . 
node) or #f - (lambda (cost) (and cost (cdr cost)))] - [else #f])) - - ; We cache whether it is safe to apply the cost function on a given node - ; for a particular type; once `#t` we need not check the `cost` vector - ; to know if it is safe. - (define ready?-vec - (for/vector #:length n - ([id (in-range n)]) - (define eclass (vector-ref eclasses id)) - (define node-types (vector-ref eclass-types id)) - (for/vector #:length (vector-length eclass) - ([ty (in-vector node-types)]) - (match ty - [(list 'or tys ...) (map (lambda (ty) (and (representation? ty) (box #f))) tys)] - [(? representation?) (box #f)] - [(? type-name?) #f] - [#f #f])))) - - (define (slow-node-ready? node type) + ; e-class costs + (define costs (make-vector n #f)) + + ; looks up the cost + (define (unsafe-eclass-cost id) + (car (vector-ref costs id))) + + ; do its children e-classes have a cost + (define (node-ready? node) (match node - [(list '$approx _ impl) (eclass-has-cost? impl type)] - [(list 'if cond ift iff) - (and (eclass-has-cost? cond (get-representation 'bool)) - (eclass-has-cost? ift type) - (eclass-has-cost? iff type))] - [(list (? impl-exists? op) args ...) (andmap eclass-has-cost? args (impl-info op 'itype))] - [(list op args ...) (andmap eclass-has-cost? args (operator-info op 'itype))])) - - ; Computes the current cost of a node if its children have a cost - ; Cost function has access to a mutable value through `cache` + [(? number?) #t] + [(? symbol?) #t] + [(list '$approx _ impl) (vector-ref costs impl)] + [(list _ ids ...) (andmap (lambda (id) (vector-ref costs id)) ids)])) + + ; computes cost of a node (as long as each of its children have costs) + ; cost function has access to a mutable value through `cache` (define cache (box #f)) - (define (node-cost node type ready?) - (and (or (not (node-has-children? node)) - (unbox ready?) - (let ([v (slow-node-ready? node type)]) - (set-box! ready? 
v) - v)) - (cost-proc regraph cache node type unsafe-eclass-cost))) - - ; Updates the cost of the current eclass. - ; Returns #t if the cost of the current eclass has improved. + (define (node-cost node type) + (and (node-ready? node) (cost-proc regraph cache node type unsafe-eclass-cost))) + + ; updates the cost of the current eclass. + ; returns whether the cost of the current eclass has improved. (define (eclass-set-cost! _ changed?-vec iter eclass id) - (define node-types (vector-ref eclass-types id)) - (define eclass-costs (vector-ref costs id)) - (define ready?/node (vector-ref ready?-vec id)) + (define type (vector-ref types id)) (define updated? #f) - ; Update cost information - (define (update-cost! type new-cost node) + ; update cost information + (define (update-cost! new-cost node) (when new-cost - (define prev-cost/node (hash-ref eclass-costs type #f)) - (when (or (not prev-cost/node) ; first cost - (< new-cost (car prev-cost/node))) ; better cost - (hash-set! eclass-costs type (cons new-cost node)) + (define prev-cost&node (vector-ref costs id)) + (when (or (not prev-cost&node) ; first cost + (< new-cost (car prev-cost&node))) ; better cost + (vector-set! costs id (cons new-cost node)) (set! updated? #t)))) - ; Optimization: we only need to update node cost as needed. + ; optimization: we only need to update node cost as needed. ; (i) terminals, nullary operators: only compute once ; (ii) non-nullary operators: compute when any of its child eclasses ; have their analysis updated @@ -995,82 +959,42 @@ (ormap (lambda (id) (vector-ref changed?-vec id)) (cdr node)) (= iter 0))) - ; Iterate over the nodes - (for ([node (in-vector eclass)] - [ty (in-vector node-types)] - [ready? (in-vector ready?/node)]) - (match ty - [(list 'or tys ...) ; node is a union type (only for some `if` nodes) - (for ([ty (in-list tys)] - [ready? (in-list ready?)]) - (when (and (representation? ty) (node-requires-update? 
node)) - (define new-cost (node-cost node ty ready?)) - (update-cost! ty new-cost node)))] - [(? representation?) ; node has a specific reprsentation - (when (node-requires-update? node) - (define new-cost (node-cost node ty ready?)) - (update-cost! ty new-cost node))] - [(? type-name?) (void)] ; type - [#f (void)])) ; no type + ; iterate over each node + (for ([node (in-vector eclass)]) + (when (node-requires-update? node) + (define new-cost (node-cost node type)) + (update-cost! new-cost node))) updated?) ; run the analysis (regraph-analyze regraph eclass-set-cost! #:analysis costs) - ; invariant: all eclasses have a cost for all types - (for ([cost (in-vector costs)]) - (unless (andmap identity (hash-values cost)) - (error 'typed-egg-extractor "costs not computed for all eclasses ~a" costs))) - ; rebuilds the extracted procedure - (define id->spec (regraph-id->spec regraph)) - (define (build-expr id type) - (let/ec - return - (let loop ([id id] - [type type]) - (match (unsafe-best-node id type) - [(? number? n) n] ; number - [(? symbol? s) s] ; variable - [(list '$approx spec impl) ; approx - (match (vector-ref id->spec spec) - [#f (error 'build-expr "no initial approx node in eclass ~a" id)] - [spec-e (list '$approx spec-e (build-expr impl type))])] - [(list 'if cond ift iff) ; if expression - (list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type))] - ; expression of impls - [(list (? impl-exists? impl) ids ...) (cons impl (map loop ids (impl-info impl 'itype)))] - ; expression of operators - [(list (? operator-exists? op) ids ...) (cons op (map loop ids (operator-info op 'itype)))] - [_ (return #f)])))) + (define id->spec (regraph-specs regraph)) + (define (build-expr id) + (let loop ([id id]) + (match (cdr (vector-ref costs id)) + [(? number? n) n] ; number + [(? symbol? 
s) s] ; variable + [(list '$approx spec impl) ; approx + (match (vector-ref id->spec spec) + [#f (error 'build-expr "no initial approx node in eclass ~a" id)] + [spec-e (list '$approx spec-e (build-expr impl))])] + ; if expression + [(list 'if cond ift iff) (list 'if (loop cond) (loop ift) (loop iff))] + ; expression of impls + [(list (? impl-exists? impl) ids ...) (cons impl (map loop ids))] + ; expression of operators + [(list (? operator-exists? op) ids ...) (cons op (map loop ids))]))) ; the actual extraction procedure - (lambda (id type) (cons (unsafe-eclass-cost id type +inf.0) (build-expr id type)))) + ; as long as the `id` is valid, extraction will work + (lambda (id) (cons (unsafe-eclass-cost id) (build-expr id)))) -;; Per-node cost function according to the platform -;; `rec` takes an id, type, and failure value -(define (platform-egg-cost-proc regraph cache node type rec) - (define egg->herbie (regraph-egg->herbie regraph)) - (define node-cost-proc (platform-node-cost-proc (*active-platform*))) - (match node - ; numbers (repr is unused) - [(? number? n) ((node-cost-proc (literal n type) type))] - [(? symbol?) ; variables (`egg->herbie` has the repr) - (define repr (cdr (hash-ref egg->herbie node))) - ((node-cost-proc node repr))] - ; approx node - [(list '$approx _ impl) (rec impl type +inf.0)] - [(list 'if cond ift iff) ; if expression - (define cost-proc (node-cost-proc node type)) - (cost-proc (rec cond (get-representation 'bool) +inf.0) - (rec ift type +inf.0) - (rec iff type +inf.0))] - [(list (? impl-exists? impl) args ...) ; impls - (define cost-proc (node-cost-proc node type)) - (define itypes (impl-info impl 'itype)) - (apply cost-proc (map (lambda (arg itype) (rec arg itype +inf.0)) args itypes))] - [(list _ ...) +inf.0])) ; specs +;; Is fractional with odd denominator. +(define (fraction-with-odd-denominator? frac) + (and (rational? frac) (let ([denom (denominator frac)]) (and (> denom 1) (odd? 
denom))))) ;; Old cost model version (define (default-egg-cost-proc regraph cache node type rec) @@ -1078,85 +1002,106 @@ [(? number?) 1] [(? symbol?) 1] ; approx node - [(list '$approx _ impl) (rec impl type +inf.0)] - [(list 'if cond ift iff) - (+ 1 (rec cond (get-representation 'bool) +inf.0) (rec ift type +inf.0) (rec iff type +inf.0))] + [(list '$approx _ impl) (rec impl)] + [(list 'if cond ift iff) (+ 1 (rec cond) (rec ift) (rec iff))] [(list (? impl-exists? impl) args ...) - (define itypes (impl-info impl 'itype)) - (if (equal? (impl->operator impl) 'pow) - (match args - [(list b e) - (define n (vector-ref (regraph-constants regraph) e)) - (if (fraction-with-odd-denominator? n) - +inf.0 - (apply + 1 (map (lambda (arg itype) (rec arg itype +inf.0)) args itypes)))]) - (apply + 1 (map (lambda (arg itype) (rec arg itype +inf.0)) args itypes)))] - [(list _ ...) +inf.0])) + (cond + [(equal? (impl->operator impl) 'pow) + (match-define (list b e) args) + (define n (vector-ref (regraph-constants regraph) e)) + (if (fraction-with-odd-denominator? n) +inf.0 (+ 1 (rec b) (rec e)))] + [else (apply + 1 (map rec args))])] + [(list 'pow b e) + (define n (vector-ref (regraph-constants regraph) e)) + (if (fraction-with-odd-denominator? n) +inf.0 (+ 1 (rec b) (rec e)))] + [(list _ args ...) (apply + 1 (map rec args))])) + +;; Per-node cost function according to the platform +;; `rec` takes an e-class id and returns its cost +(define (platform-egg-cost-proc regraph cache node type rec) + (cond + [(representation? type) + (define egg->herbie (regraph-egg->herbie regraph)) + (define node-cost-proc (platform-node-cost-proc (*active-platform*))) + (match node + ; numbers (repr is unused) + [(? number? n) ((node-cost-proc (literal n type) type))] + [(? symbol?) 
; variables (`egg->herbie` has the repr) + (define repr (cdr (hash-ref egg->herbie node))) + ((node-cost-proc node repr))] + ; approx node + [(list '$approx _ impl) (rec impl)] + [(list 'if cond ift iff) ; if expression + (define cost-proc (node-cost-proc node type)) + (cost-proc (rec cond) (rec ift) (rec iff))] + [(list (? impl-exists?) args ...) ; impls + (define cost-proc (node-cost-proc node type)) + (apply cost-proc (map rec args))])] + [else (default-egg-cost-proc regraph cache node type rec)])) ;; Extracts the best expression according to the extractor. ;; Result is a single element list. (define (regraph-extract-best regraph extract id type) - ; extract expr (extraction is partial) - (define id* (hash-ref (regraph-canon regraph) id)) - (match-define (cons _ egg-expr) (extract id* type)) - ; translate egg IR to Herbie IR + (define egg->herbie (regraph-egg->herbie regraph)) + (define canon (regraph-canon regraph)) + ; extract expr + (define key (cons id type)) (cond - [egg-expr - (define egg->herbie (regraph-egg->herbie regraph)) + ; at least one extractable expression + [(hash-has-key? canon key) + (define id* (hash-ref canon key)) + (match-define (cons _ egg-expr) (extract id*)) (list (egg-parsed->expr (flatten-let egg-expr) egg->herbie type))] + ; no extractable expressions [else (list)])) ;; Extracts multiple expressions according to the extractor (define (regraph-extract-variants regraph extract id type) ; regraph fields (define eclasses (regraph-eclasses regraph)) - (define id->spec (regraph-id->spec regraph)) + (define id->spec (regraph-specs regraph)) (define canon (regraph-canon regraph)) ; extract expressions - (define id* (hash-ref canon id)) - (define egg-exprs - (reap [sow] - (for ([enode (vector-ref eclasses id*)]) - (match enode - [(? number?) (sow enode)] - [(? symbol?) 
(sow enode)] - [(list '$approx spec impl) - (match (vector-ref id->spec spec) - [#f (error 'regraph-extract-variants "no initial approx node in eclass ~a" id*)] - [spec-e - (match-define (cons _ impl*) (extract impl type)) - (when impl* - (sow (list '$approx spec-e impl*)))])] - [(list 'if cond ift iff) - (match-define (cons _ cond*) - (extract cond (if (representation? type) (get-representation 'bool) 'bool))) - (match-define (cons _ ift*) (extract ift type)) - (match-define (cons _ iff*) (extract iff type)) - (when (and cond* ift* iff*) ; guard against failed extraction - (sow (list 'if cond* ift* iff*)))] - [(list (? impl-exists? impl) ids ...) - (when (equal? (impl-info impl 'otype) type) - (define args - (for/list ([id (in-list ids)] - [itype (in-list (impl-info impl 'itype))]) - (match-define (cons _ expr) (extract id itype)) - expr)) - (when (andmap identity args) ; guard against failed extraction - (sow (cons impl args))))] - [(list (? operator-exists? op) ids ...) - (when (equal? (operator-info op 'otype) type) - (define args - (for/list ([id (in-list ids)] - [itype (in-list (operator-info op 'itype))]) - (match-define (cons _ expr) (extract id itype)) - expr)) - (when (andmap identity args) ; guard against failed extraction - (sow (cons op args))))])))) - ; translate egg IR to Herbie IR - (define egg->herbie (regraph-egg->herbie regraph)) - (for/list ([egg-expr (in-list egg-exprs)]) - (egg-parsed->expr (flatten-let egg-expr) egg->herbie type))) + (define key (cons id type)) + (cond + ; at least one extractable expression + [(hash-has-key? canon key) + (define id* (hash-ref canon key)) + (define egg-exprs + (for/list ([enode (vector-ref eclasses id*)]) + (match enode + [(? number?) enode] + [(? symbol?) 
enode] + [(list '$approx spec impl) + (define spec* (vector-ref id->spec spec)) + (unless spec* + (error 'regraph-extract-variants "no initial approx node in eclass ~a" id*)) + (match-define (cons _ impl*) (extract impl)) + (list '$approx spec* impl*)] + [(list 'if cond ift iff) + (match-define (cons _ cond*) (extract cond)) + (match-define (cons _ ift*) (extract ift)) + (match-define (cons _ iff*) (extract iff)) + (list 'if cond* ift* iff*)] + [(list (? impl-exists? impl) ids ...) + (define args + (for/list ([id (in-list ids)]) + (match-define (cons _ expr) (extract id)) + expr)) + (cons impl args)] + [(list (? operator-exists? op) ids ...) + (define args + (for/list ([id (in-list ids)]) + (match-define (cons _ expr) (extract id)) + expr)) + (cons op args)]))) + ; translate egg IR to Herbie IR + (define egg->herbie (regraph-egg->herbie regraph)) + (for/list ([egg-expr (in-list egg-exprs)]) + (egg-parsed->expr (flatten-let egg-expr) egg->herbie type))] + ; no extractable expressions + [else (list)])) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Scheduler @@ -1194,9 +1139,7 @@ (define egg-graph (make-egraph)) ; insert expressions into the e-graph - (define root-ids - (for/list ([expr (in-list exprs)]) - (egraph-add-expr egg-graph expr ctx))) + (define root-ids (egraph-add-exprs egg-graph exprs ctx)) ; run the schedule (define rule-apps (make-hash)) diff --git a/src/core/localize.rkt b/src/core/localize.rkt index 856a1760b..46907ec4b 100644 --- a/src/core/localize.rkt +++ b/src/core/localize.rkt @@ -127,13 +127,18 @@ (define subexprs-fn (eval-progs-real (map prog->spec exprs-list) ctx-list)) ; Mutable error hack, this is bad - (define errs (make-hash (map list exprs-list))) + (define errs + (for/vector #:length (vector-length roots) + ([node (in-vector roots)]) + (make-vector (pcontext-length (*pcontext*))))) - (for ([(pt ex) (in-pcontext (*pcontext*))]) + (for ([(pt ex) (in-pcontext (*pcontext*))] + [pt-idx (in-naturals)]) 
(define exacts (list->vector (apply subexprs-fn pt))) (for ([expr (in-list exprs-list)] [root (in-vector roots)] - [exact (in-vector exacts)]) + [exact (in-vector exacts)] + [expr-idx (in-naturals)]) (define err (match (vector-ref nodes root) [(? literal?) 1] @@ -149,11 +154,13 @@ (vector-ref exacts (vector-member idx roots)))) ; arg's index mapping to exact (define approx (apply (impl-info f 'fl) argapprox)) (ulp-difference exact approx repr)])) - (hash-update! errs expr (curry cons err)))) + (vector-set! (vector-ref errs expr-idx) pt-idx err))) + (define n 0) (for/list ([subexprs (in-list subexprss)]) (for*/hash ([subexpr (in-list subexprs)]) - (values subexpr (reverse (hash-ref errs subexpr)))))) + (begin0 (values subexpr (vector->list (vector-ref errs n))) + (set! n (add1 n)))))) ;; Compute the local error of every subexpression of `prog` ;; and returns the error information as an S-expr in the diff --git a/src/core/patch.rkt b/src/core/patch.rkt index a1a84fe20..67705dd66 100644 --- a/src/core/patch.rkt +++ b/src/core/patch.rkt @@ -17,12 +17,12 @@ ;;;;;;;;;;;;;;;;;;;;;;;;;;;; Simplify ;;;;;;;;;;;;;;;;;;;;;;;;;;;; -(define (lower-approximations approxs approx->prev) +(define (lower-approximations approxs) (timeline-event! 
'simplify) (define reprs (for/list ([approx (in-list approxs)]) - (define prev (hash-ref approx->prev approx)) + (define prev (car (alt-prevs approx))) (repr-of (alt-expr prev) (*context*)))) ; generate real rules @@ -51,7 +51,7 @@ (for ([altn (in-list approxs)] [outputs (in-list simplification-options)]) (match-define (cons _ simplified) outputs) - (define prev (hash-ref approx->prev altn)) + (define prev (car (alt-prevs altn))) (for ([expr (in-list simplified)]) (define spec (prog->spec (alt-expr prev))) (sow (alt (approx spec expr) `(simplify ,runner #f #f) (list altn) '())))))) @@ -71,16 +71,27 @@ #;(exp ,exp-x ,log-x) #;(log ,log-x ,exp-x)))) -(define (taylor-alt altn) - (define expr (prog->spec (alt-expr altn))) +(define (taylor-alts altns) + (define exprs + (for/list ([altn (in-list altns)]) + (prog->spec (alt-expr altn)))) + (define free-vars (map free-variables exprs)) + (define vars (list->set (append* free-vars))) + (reap [sow] - (for* ([var (free-variables expr)] + (for* ([var (in-set vars)] [transform-type transforms-to-try]) (match-define (list name f finv) transform-type) - (define timeline-stop! (timeline-start! 'series (~a expr) (~a var) (~a name))) - (define genexpr (approximate expr var #:transform (cons f finv))) - (for ([_ (in-range (*taylor-order-limit*))]) - (sow (alt (genexpr) `(taylor ,name ,var) (list altn) '()))) + (define timeline-stop! (timeline-start! 'series (~a exprs) (~a var) (~a name))) + (define genexprs (approximate exprs var #:transform (cons f finv))) + (for ([genexpr (in-list genexprs)] + [altn (in-list altns)] + [fv (in-list free-vars)] + #:when (member var fv)) ; check whether var exists in expr at all + (for ([_ (in-range (*taylor-order-limit*))]) + (define gen (genexpr)) + (unless (spec-has-nan? gen) + (sow (alt gen `(taylor ,name ,var) (list altn) '()))))) (timeline-stop!)))) (define (spec-has-nan? expr) @@ -90,18 +101,11 @@ (timeline-event! 'series) (timeline-push! 
'inputs (map ~a altns)) - (define approx->prev (make-hasheq)) - (define approxs - (reap [sow] - (for ([altn (in-list altns)]) - (for ([approximation (taylor-alt altn)]) - (unless (spec-has-nan? (alt-expr approximation)) - (hash-set! approx->prev approximation altn) - (sow approximation)))))) + (define approxs (taylor-alts altns)) (timeline-push! 'outputs (map ~a approxs)) (timeline-push! 'count (length altns) (length approxs)) - (lower-approximations approxs approx->prev)) + (lower-approximations approxs)) ;;;;;;;;;;;;;;;;;;;;;;;;;;;; Recursive Rewrite ;;;;;;;;;;;;;;;;;;;;;;;;;;;; diff --git a/src/core/simplify.rkt b/src/core/simplify.rkt index f194fb98a..17c5dbbd2 100644 --- a/src/core/simplify.rkt +++ b/src/core/simplify.rkt @@ -52,7 +52,7 @@ (make-egg-runner args (map (lambda (_) 'real) args) `((,(*simplify-rules*) . ((node . ,(*node-limit*))))))) - (define extractor (untyped-egg-extractor default-untyped-egg-cost-proc)) + (define extractor (typed-egg-extractor default-egg-cost-proc)) (map last (simplify-batch runner extractor))) (define test-exprs diff --git a/src/core/taylor.rkt b/src/core/taylor.rkt index 4c060eeab..ffa5963c9 100644 --- a/src/core/taylor.rkt +++ b/src/core/taylor.rkt @@ -4,29 +4,35 @@ (require "../utils/common.rkt" "programs.rkt" "reduce.rkt" - "../syntax/syntax.rkt") -(provide approximate) - -(define (approximate expr var #:transform [tform (cons identity identity)] #:iters [iters 5]) - (define expr* (simplify (replace-expression expr var ((car tform) var)))) - (match-define (cons offset coeffs) (taylor var expr*)) + "../syntax/syntax.rkt" + "batch.rkt") - (define i 0) - (define terms '()) - - (define (next [iter 0]) - (define coeff (simplify (replace-expression (coeffs i) var ((cdr tform) var)))) - (set! i (+ i 1)) - (match coeff - [0 - (if (< iter iters) - (next (+ iter 1)) - (simplify (make-horner ((cdr tform) var) (reverse terms))))] - [_ - (set! 
terms (cons (cons coeff (- i offset 1)) terms)) - (simplify (make-horner ((cdr tform) var) (reverse terms)))])) +(provide approximate) - next) +(define (approximate exprs var #:transform [tform (cons identity identity)] #:iters [iters 5]) + (define exprs* + (for/list ([expr (in-list exprs)]) + (simplify (replace-expression expr var ((car tform) var))))) + (define batch (expand-taylor (progs->batch exprs*))) + + (define taylor-approxs (taylor var batch)) + (for/list ([root (in-vector (batch-roots batch))]) + (match-define (cons offset coeffs) (vector-ref taylor-approxs root)) + (define i 0) + (define terms '()) + + (define (next [iter 0]) + (define coeff (simplify (replace-expression (coeffs i) var ((cdr tform) var)))) + (set! i (+ i 1)) + (match coeff + [0 + (if (< iter iters) + (next (+ iter 1)) + (simplify (make-horner ((cdr tform) var) (reverse terms))))] + [_ + (set! terms (cons (cons coeff (- i offset 1)) terms)) + (simplify (make-horner ((cdr tform) var) (reverse terms)))])) + next)) (define (make-horner var terms [start 0]) (match terms @@ -68,90 +74,67 @@ (for/list ([i (in-range 0 (+ k 1))]) (map (curry cons i) (n-sum-to (- n 1) (- k i)))))])))) -(define (taylor var expr) - (define var-cache (hash-ref! (series-cache) var (λ () (make-hash)))) - (hash-ref! var-cache expr (λ () (taylor* var expr)))) - -(define (taylor* var expr) +(define (taylor var expr-batch) "Return a pair (e, n), such that expr ~= e var^n" - (match expr - [(? (curry equal? var)) (taylor-exact 0 1)] - [(? number?) (taylor-exact expr)] - [(? variable?) (taylor-exact expr)] - [`(,const) (taylor-exact expr)] - [`(+ ,args ...) 
(apply taylor-add (map (curry taylor var) args))] - [`(neg ,arg) (taylor-negate ((curry taylor var) arg))] - [`(- ,arg1 ,arg2) (taylor var `(+ ,arg1 (neg ,arg2)))] - [`(* ,left ,right) (taylor-mult (taylor var left) (taylor var right))] - [`(/ 1 ,arg) (taylor-invert (taylor var arg))] - [`(/ ,num ,den) (taylor-quotient (taylor var num) (taylor var den))] - [`(sqrt ,arg) (taylor-sqrt var (taylor var arg))] - [`(cbrt ,arg) (taylor-cbrt var (taylor var arg))] - [`(exp ,arg) - (let ([arg* (normalize-series (taylor var arg))]) - (if (positive? (car arg*)) (taylor-exact expr) (taylor-exp (zero-series arg*))))] - [`(sin ,arg) - (let ([arg* (normalize-series (taylor var arg))]) - (cond - [(positive? (car arg*)) (taylor-exact expr)] - [(= (car arg*) 0) - ; Our taylor-sin function assumes that a0 is 0, - ; because that way it is especially simple. We correct for this here - ; We use the identity sin (x + y) = sin x cos y + cos x sin y - (taylor-add - (taylor-mult (taylor-exact `(sin ,((cdr arg*) 0))) (taylor-cos (zero-series arg*))) - (taylor-mult (taylor-exact `(cos ,((cdr arg*) 0))) (taylor-sin (zero-series arg*))))] - [else (taylor-sin (zero-series arg*))]))] - [`(cos ,arg) - (let ([arg* (normalize-series (taylor var arg))]) - (cond - [(positive? (car arg*)) (taylor-exact expr)] - [(= (car arg*) 0) - ; Our taylor-cos function assumes that a0 is 0, - ; because that way it is especially simple. We correct for this here - ; We use the identity cos (x + y) = cos x cos y - sin x sin y - (taylor-add (taylor-mult (taylor-exact `(cos ,((cdr arg*) 0))) - (taylor-cos (zero-series arg*))) - (taylor-negate (taylor-mult (taylor-exact `(sin ,((cdr arg*) 0))) - (taylor-sin (zero-series arg*)))))] - [else (taylor-cos (zero-series arg*))]))] - [`(tan ,arg) (taylor var `(/ (sin ,arg) (cos ,arg)))] - [`(log ,arg) (taylor-log var (taylor var arg))] - [`(pow ,base ,(? exact-integer? 
power)) (taylor-pow (normalize-series (taylor var base)) power)] - [`(pow ,base 1/2) (taylor-sqrt var (taylor var base))] - [`(pow ,base 1/3) (taylor-cbrt var (taylor var base))] - [`(pow ,base 2/3) - (define tx (taylor var base)) - (taylor-cbrt var (taylor-mult tx tx))] - [`(pow ,base ,power) (taylor var `(exp (* ,power (log ,base))))] - [`(sinh ,arg) - (define exparg (taylor var `(exp ,arg))) - (taylor-mult (taylor-exact 1/2) (taylor-add exparg (taylor-negate (taylor-invert exparg))))] - [`(cosh ,arg) - (define exparg (taylor var `(exp ,arg))) - (taylor-mult (taylor-exact 1/2) (taylor-add exparg (taylor-invert exparg)))] - [`(tanh ,arg) - (define exparg (taylor var `(exp ,arg))) - (define expinv (taylor-invert exparg)) - (define x+ (taylor-add exparg expinv)) - (define x- (taylor-add exparg (taylor-negate expinv))) - (taylor-quotient x- x+)] - [`(asinh ,x) - (define tx (taylor var x)) - (taylor-log var - (taylor-add tx (taylor-sqrt var (taylor-add (taylor-mult tx tx) (taylor-exact 1)))))] - [`(acosh ,x) - (define tx (taylor var x)) - (taylor-log var - (taylor-add tx - (taylor-sqrt var (taylor-add (taylor-mult tx tx) (taylor-exact -1)))))] - [`(atanh ,x) - (define tx (taylor var x)) - (taylor-mult (taylor-exact 1/2) - (taylor-log var - (taylor-quotient (taylor-add (taylor-exact 1) tx) - (taylor-add (taylor-exact 1) (taylor-negate tx)))))] - [_ (taylor-exact expr)])) + (define nodes (batch-nodes expr-batch)) + (define taylor-approxs (make-vector (batch-nodes-length expr-batch))) ; vector of approximations + + (for ([node (in-vector nodes)] + [n (in-naturals)]) + (define approx + (match node + [(? (curry equal? var)) (taylor-exact 0 1)] + [(? number?) (taylor-exact node)] + [(? variable?) (taylor-exact node)] + [`(,const) (taylor-exact node)] + [`(+ ,args ...) 
(apply taylor-add (map (curry vector-ref taylor-approxs) args))] + [`(neg ,arg) (taylor-negate ((curry vector-ref taylor-approxs) arg))] + [`(* ,left ,right) + (taylor-mult (vector-ref taylor-approxs left) (vector-ref taylor-approxs right))] + [`(/ ,num ,den) + #:when (equal? (vector-ref nodes num) 1) + (taylor-invert (vector-ref taylor-approxs den))] + [`(/ ,num ,den) + (taylor-quotient (vector-ref taylor-approxs num) (vector-ref taylor-approxs den))] + [`(sqrt ,arg) (taylor-sqrt var (vector-ref taylor-approxs arg))] + [`(cbrt ,arg) (taylor-cbrt var (vector-ref taylor-approxs arg))] + [`(exp ,arg) + (let ([arg* (normalize-series (vector-ref taylor-approxs arg))]) + (if (positive? (car arg*)) + (taylor-exact (batch-ref expr-batch n)) + (taylor-exp (zero-series arg*))))] + [`(sin ,arg) + (let ([arg* (normalize-series (vector-ref taylor-approxs arg))]) + (cond + [(positive? (car arg*)) (taylor-exact (batch-ref expr-batch n))] + [(= (car arg*) 0) + ; Our taylor-sin function assumes that a0 is 0, + ; because that way it is especially simple. We correct for this here + ; We use the identity sin (x + y) = sin x cos y + cos x sin y + (taylor-add + (taylor-mult (taylor-exact `(sin ,((cdr arg*) 0))) (taylor-cos (zero-series arg*))) + (taylor-mult (taylor-exact `(cos ,((cdr arg*) 0))) (taylor-sin (zero-series arg*))))] + [else (taylor-sin (zero-series arg*))]))] + [`(cos ,arg) + (let ([arg* (normalize-series (vector-ref taylor-approxs arg))]) + (cond + [(positive? (car arg*)) (taylor-exact (batch-ref expr-batch n))] + [(= (car arg*) 0) + ; Our taylor-cos function assumes that a0 is 0, + ; because that way it is especially simple. 
We correct for this here + ; We use the identity cos (x + y) = cos x cos y - sin x sin y + (taylor-add (taylor-mult (taylor-exact `(cos ,((cdr arg*) 0))) + (taylor-cos (zero-series arg*))) + (taylor-negate (taylor-mult (taylor-exact `(sin ,((cdr arg*) 0))) + (taylor-sin (zero-series arg*)))))] + [else (taylor-cos (zero-series arg*))]))] + [`(log ,arg) (taylor-log var (vector-ref taylor-approxs arg))] + [`(pow ,base ,power) + #:when (exact-integer? (vector-ref nodes power)) + (taylor-pow (normalize-series (vector-ref taylor-approxs base)) (vector-ref nodes power))] + [_ (taylor-exact (batch-ref expr-batch n))])) + (vector-set! taylor-approxs n approx)) + taylor-approxs) ; A taylor series is represented by a function f : nat -> expr, ; representing the coefficients (the 1 / n! terms not included), @@ -462,11 +445,18 @@ (require rackunit "../syntax/types.rkt" "../syntax/load-plugin.rkt") - (check-pred exact-integer? (car (taylor 'x '(pow x 1.0))))) + (define batch (progs->batch (list '(pow x 1.0)))) + (set! batch (expand-taylor batch)) + (define root (vector-ref (batch-roots batch) 0)) + + (check-pred exact-integer? (car (vector-ref (taylor 'x batch) root)))) (module+ test (define (coeffs expr #:n [n 7]) - (match-define fn (zero-series (taylor 'x expr))) + (define batch (progs->batch (list expr))) + (set! batch (expand-taylor batch)) + (define root (vector-ref (batch-roots batch) 0)) + (match-define fn (zero-series (vector-ref (taylor 'x batch) root))) (build-list n fn)) (check-equal? (coeffs '(sin x)) '(0 1 0 -1/6 0 1/120 0)) diff --git a/src/main.rkt b/src/main.rkt index 362b6d728..8e272913a 100644 --- a/src/main.rkt +++ b/src/main.rkt @@ -49,6 +49,7 @@ (module+ main (define quiet? #f) + (define browser? #t) (define demo-output #f) (define demo-log #f) (define demo-prefix "/") @@ -153,8 +154,15 @@ [("--prefix") prefix "Prefix for proxying demo" (set! demo-prefix prefix)] [("--demo") "Run in Herbie web demo mode. Changes some text" (set! demo? 
true)] [("--quiet") "Print a smaller banner and don't start a browser." (set! quiet? true)] + [("--threads") + num + "How many jobs to run in parallel: Processor count is the default." + (set! threads (string->thread-count num))] + [("--no-browser") "Run the web demo but don't start a browser." (set! browser? #f)] #:args () (run-demo #:quiet quiet? + #:threads threads + #:browser browser? #:output demo-output #:log demo-log #:prefix demo-prefix diff --git a/src/platforms/binary32.rkt b/src/platforms/binary32.rkt index d7d20be2c..a4f215108 100644 --- a/src/platforms/binary32.rkt +++ b/src/platforms/binary32.rkt @@ -46,11 +46,31 @@ (begin (define-libm-impls/binary32* (itype ... otype) name ...) ...)) -(define-operator-impl (neg neg.f32 binary32) binary32 [fl fl32-]) -(define-operator-impl (+ +.f32 binary32 binary32) binary32 [fl fl32+]) -(define-operator-impl (- -.f32 binary32 binary32) binary32 [fl fl32-]) -(define-operator-impl (* *.f32 binary32 binary32) binary32 [fl fl32*]) -(define-operator-impl (/ /.f32 binary32 binary32) binary32 [fl fl32/]) +(define-operator-impl (neg.f32 [x : binary32]) + binary32 + #:spec (neg x) + #:fpcore (! :precision binary32 (- x)) + #:fl fl32-) +(define-operator-impl (+.f32 [x : binary32] [y : binary32]) + binary32 + #:spec (+ x y) + #:fpcore (! :precision binary32 (+ x y)) + #:fl fl32+) +(define-operator-impl (-.f32 [x : binary32] [y : binary32]) + binary32 + #:spec (- x y) + #:fpcore (! :precision binary32 (- x y)) + #:fl fl32-) +(define-operator-impl (*.f32 [x : binary32] [y : binary32]) + binary32 + #:spec (* x y) + #:fpcore (! :precision binary32 (* x y)) + #:fl fl32*) +(define-operator-impl (/.f32 [x : binary32] [y : binary32]) + binary32 + #:spec (/ x y) + #:fpcore (! 
:precision binary32 (/ x y)) + #:fl fl32/) (define-libm-impls/binary32 [(binary32 binary32) (acos acosh @@ -96,6 +116,16 @@ [<= <=.f32 <=] [>= >=.f32 >=]) -(define-operator-impl (cast binary64->binary32 binary64) binary32 [fl (curryr ->float32)]) - -(define-operator-impl (cast binary32->binary64 binary32) binary64 [fl identity]) +(define-operator-impl (binary64->binary32 [x : binary64]) + binary32 + #:spec x + #:fpcore (! :precision binary32 (cast x)) + #:fl (curryr ->float32) + #:op cast) + +(define-operator-impl (binary32->binary64 [x : binary32]) + binary64 + #:spec x + #:fpcore (! :precision binary64 (cast x)) + #:fl identity + #:op cast) diff --git a/src/platforms/binary64.rkt b/src/platforms/binary64.rkt index a3ef3cccd..620a0aa1c 100644 --- a/src/platforms/binary64.rkt +++ b/src/platforms/binary64.rkt @@ -46,11 +46,31 @@ (begin (define-libm-impls/binary64* (itype ... otype) name ...) ...)) -(define-operator-impl (neg neg.f64 binary64) binary64 [fl -]) -(define-operator-impl (+ +.f64 binary64 binary64) binary64 [fl +]) -(define-operator-impl (- -.f64 binary64 binary64) binary64 [fl -]) -(define-operator-impl (* *.f64 binary64 binary64) binary64 [fl *]) -(define-operator-impl (/ /.f64 binary64 binary64) binary64 [fl /]) +(define-operator-impl (neg.f64 [x : binary64]) + binary64 + #:spec (neg x) + #:fpcore (! :precision binary64 (- x)) + #:fl -) +(define-operator-impl (+.f64 [x : binary64] [y : binary64]) + binary64 + #:spec (+ x y) + #:fpcore (! :precision binary64 (+ x y)) + #:fl +) +(define-operator-impl (-.f64 [x : binary64] [y : binary64]) + binary64 + #:spec (- x y) + #:fpcore (! :precision binary64 (- x y)) + #:fl -) +(define-operator-impl (*.f64 [x : binary64] [y : binary64]) + binary64 + #:spec (* x y) + #:fpcore (! :precision binary64 (* x y)) + #:fl *) +(define-operator-impl (/.f64 [x : binary64] [y : binary64]) + binary64 + #:spec (/ x y) + #:fpcore (! 
:precision binary64 (/ x y)) + #:fl /) (define-libm-impls/binary64 [(binary64 binary64) (acos acosh diff --git a/src/platforms/bool.rkt b/src/platforms/bool.rkt index 527a61498..ce92e500b 100644 --- a/src/platforms/bool.rkt +++ b/src/platforms/bool.rkt @@ -29,11 +29,11 @@ (define (or-fn . as) (ormap identity as)) -(define-operator-impl (not not bool) bool [fl not]) +(define-operator-impl (not [x : bool]) bool #:spec (not x) #:fpcore (! (not x)) #:fl not) -(define-operator-impl (and and bool bool) bool [fl and-fn]) +(define-operator-impl (and [x : bool] [y : bool]) bool #:spec (and x y) #:fl and-fn) -(define-operator-impl (or or bool bool) bool [fl or-fn]) +(define-operator-impl (or [x : bool] [y : bool]) bool #:spec (or x y) #:fl or-fn) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;; rules ;;;;;;;;;;;;;;;;;;;;;;;;;;;;; diff --git a/src/platforms/fallback.rkt b/src/platforms/fallback.rkt index 64a52cee3..ffe054382 100644 --- a/src/platforms/fallback.rkt +++ b/src/platforms/fallback.rkt @@ -12,19 +12,9 @@ (module test racket/base ) -;;;;;;;;;;;;;;;;;;;;;;;;;;;;; representation ;;;;;;;;;;;;;;;;;;;;;;;;;;;;; - -(define-representation (racket real flonum?) - bigfloat->flonum - bf - (shift 63 ordinal->flonum) - (unshift 63 flonum->ordinal) - 64 - (disjoin nan? infinite?)) - ;;;;;;;;;;;;;;;;;;;;;;;;;;;;; constants ;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -(define-constants racket +(define-constants binary64 [PI PI.rkt pi] [E E.rkt (exp 1.0)] [INFINITY INFINITY.rkt +inf.0] @@ -34,21 +24,22 @@ (define-syntax (define-fallback-operator stx) (syntax-case stx (real fl) - [(_ (op real ...) [fl fn] [key value] ...) - (let* ([num-args (length (cdr (syntax-e (cadr (syntax-e stx)))))] - [sym2-append (λ (x y) - (string->symbol (string-append (symbol->string x) (symbol->string y))))] - [name (sym2-append (syntax-e (car (syntax-e (cadr (syntax-e stx))))) '.rkt)]) - #`(define-operator-impl (op #,name #,@(build-list num-args (λ (_) #'racket))) - racket - [fl fn] - [key value] ...))])) + [(_ (name tsig ...) 
fields ...) + (let ([name (syntax-e #'name)]) + (with-syntax ([name (string->symbol (format "~a.rkt" name))]) + #'(define-operator-impl (name tsig ...) binary64 fields ...)))])) (define-syntax-rule (define-1ary-fallback-operator op fn) - (define-fallback-operator (op real) [fl fn])) + (define-fallback-operator (op [x : binary64]) + #:spec (op x) + #:fpcore (! :precision binary64 :math-library racket (op x)) + #:fl fn)) (define-syntax-rule (define-2ary-fallback-operator op fn) - (define-fallback-operator (op real real) [fl fn])) + (define-fallback-operator (op [x : binary64] [y : binary64]) + #:spec (op x y) + #:fpcore (! :precision binary64 :math-library racket (op x y)) + #:fl fn)) (define-syntax-rule (define-1ary-fallback-operators [op fn] ...) (begin @@ -130,9 +121,14 @@ [pow (no-complex expt)] [remainder remainder]) -(define-fallback-operator (fma real real real) [fl (from-bigfloat bffma)]) +(define-operator-impl (fma.rkt [x : binary64] [y : binary64] [z : binary64]) + binary64 + #:spec (+ (* x y) z) + #:fpcore (! :precision binary64 :math-library racket (fma x y z)) + #:fl (from-bigfloat bffma) + #:op fma) -(define-comparator-impls racket +(define-comparator-impls binary64 [== ==.rkt =] [!= !=.rkt (negate =)] [< <.rkt <] diff --git a/src/platforms/runtime/libm.rkt b/src/platforms/runtime/libm.rkt index fb02bb73c..0659e71f3 100644 --- a/src/platforms/runtime/libm.rkt +++ b/src/platforms/runtime/libm.rkt @@ -60,8 +60,9 @@ (unless (identifier? #'name) (oops! "expected identifier" #'name)) (with-syntax ([(citype ...) (map repr->type (syntax->list #'(itype ...)))] - [cotype (repr->type #'otype)]) + [cotype (repr->type #'otype)] + [(var ...) (generate-temporaries #'(itype ...))]) #'(begin (define-libm proc (cname citype ... cotype)) (when proc - (define-operator-impl (id name itype ...) otype [fl proc] attrib ...)))))])) + (define-operator-impl (name [var : itype] ...) 
otype #:fl proc #:op id)))))])) diff --git a/src/platforms/runtime/utils.rkt b/src/platforms/runtime/utils.rkt index 60a4fdf61..d58f1bbe3 100644 --- a/src/platforms/runtime/utils.rkt +++ b/src/platforms/runtime/utils.rkt @@ -20,8 +20,8 @@ (define-syntax-rule (define-constants repr [name impl-name value] ...) (begin - (define-operator-impl (name impl-name) repr [fl (const value)]) ...)) + (define-operator-impl (impl-name) repr #:spec (name) #:fl (const value)) ...)) (define-syntax-rule (define-comparator-impls repr [name impl-name impl-fn] ...) (begin - (define-operator-impl (name impl-name repr repr) bool [fl impl-fn]) ...)) + (define-operator-impl (impl-name [x : repr] [y : repr]) bool #:spec (name x y) #:fl impl-fn) ...)) diff --git a/src/reports/make-graph.rkt b/src/reports/make-graph.rkt index ed0f203a3..1756472d8 100644 --- a/src/reports/make-graph.rkt +++ b/src/reports/make-graph.rkt @@ -49,14 +49,22 @@ (body (h2 "Result page for the " ,(~a command) " command is not available right now."))) out)) -(define (make-graph result out output? profile?) - (match-define (job-result _ test _ time _ warnings backend) result) - (define vars (test-vars test)) +(define (make-graph result-hash out output? profile?) 
+ (define backend (hash-ref result-hash 'backend)) + (define test (hash-ref result-hash 'test)) + (define time (hash-ref result-hash 'time)) + (define warnings (hash-ref result-hash 'warnings)) (define repr (test-output-repr test)) (define repr-bits (representation-total-bits repr)) (define ctx (test-context test)) (define identifier (test-identifier test)) - (match-define (improve-result preprocessing pctxs start targets end bogosity) backend) + + (define preprocessing (hash-ref backend 'preprocessing)) + (define pctxs (hash-ref backend 'pctxs)) + (define start (hash-ref backend 'start)) + (define targets (hash-ref backend 'target)) + (define end (hash-ref backend 'end)) + (define bogosity (hash-ref backend 'bogosity)) (match-define (alt-analysis start-alt _ start-error) start) (define start-cost (alt-cost start-alt repr)) @@ -69,11 +77,9 @@ (for/list ([target targets]) (alt-cost (alt-analysis-alt target) repr))) - (define-values (end-alts end-errors end-costs) - (for/lists (l1 l2 l3) - ([analysis end]) - (match-define (alt-analysis alt _ test-errs) analysis) - (values alt test-errs (alt-cost alt repr)))) + (define end-alts (hash-ref end 'end-alts)) + (define end-errors (hash-ref end 'end-errors)) + (define end-costs (hash-ref end 'end-costs)) (define speedup (let ([better (for/list ([alt end-alts] @@ -83,9 +89,6 @@ (/ start-cost cost))]) (and (not (null? 
better)) (apply max better)))) - (match-define (list train-pctx test-pctx) pctxs) - - (define end-alt (car end-alts)) (define end-error (car end-errors)) (write-html @@ -167,11 +170,14 @@ ,(render-help "report.html#alternatives")) ,body)) ,@(for/list ([i (in-naturals 1)] - [alt end-alts] + [alt-fpcore end-alts] [errs end-errors] - [cost end-costs]) + [cost end-costs] + [history (hash-ref end 'end-histories)]) + (define formula (read-syntax 'web (open-input-string alt-fpcore))) + (define expr (parse-test formula)) (define-values (dropdown body) - (render-program (alt-expr alt) ctx #:ident identifier #:instructions preprocessing)) + (render-program (test-input expr) ctx #:ident identifier #:instructions preprocessing)) `(section ([id ,(format "alternative~a" i)] (class "programs")) (h2 "Alternative " ,(~a i) @@ -184,9 +190,7 @@ ,dropdown ,(render-help "report.html#alternatives")) ,body - (details (summary "Derivation") - (ol ((class "history")) - ,@(render-history alt train-pctx test-pctx ctx))))) + (details (summary "Derivation") (ol ((class "history")) ,@history)))) ,@(for/list ([i (in-naturals 1)] [target (in-list targets)] [target-error (in-list list-target-error)] diff --git a/src/reports/pages.rkt b/src/reports/pages.rkt index 96f0338be..212a9f50a 100644 --- a/src/reports/pages.rkt +++ b/src/reports/pages.rkt @@ -2,30 +2,23 @@ (require json) (require "../syntax/read.rkt" - "../syntax/sugar.rkt" - "../syntax/syntax.rkt" - "../syntax/types.rkt" - "../utils/alternative.rkt" - "../utils/float.rkt" - "../core/points.rkt" - "../api/sandbox.rkt" - "common.rkt" "timeline.rkt" "plot.rkt" "make-graph.rkt" "traceback.rkt") + (provide all-pages make-page page-error-handler) -(define (all-pages result) - (define good? (eq? (job-result-status result) 'success)) +(define (all-pages result-hash) + (define good? (eq? 
(hash-ref result-hash 'status) 'success)) (define default-pages '("graph.html" "timeline.html" "timeline.json")) (define success-pages '("points.json")) (append default-pages (if good? success-pages empty))) -(define ((page-error-handler result page out) e) - (define test (job-result-test result)) +(define ((page-error-handler result-hash page out) e) + (define test (hash-ref result-hash 'test)) (eprintf "Error generating `~a` for \"~a\":\n ~a\n" page (test-name test) (exn-message e)) (eprintf "context:\n") (for ([(fn loc) (in-dict (continuation-mark-set->context (exn-continuation-marks e)))]) @@ -37,21 +30,21 @@ ((error-display-handler) (exn-message e) e) (display "" out))) -(define (make-page page out result output? profile?) - (define test (job-result-test result)) - (define status (job-result-status result)) - (define ctx (test-context test)) +(define (make-page page out result-hash output? profile?) + (define test (hash-ref result-hash 'test)) + (define status (hash-ref result-hash 'status)) (match page ["graph.html" (match status ['success - (define command (job-result-command result)) + (define command (hash-ref result-hash 'command)) (match command - ['improve (make-graph result out output? profile?)] + ["improve" (make-graph result-hash out output? 
profile?)] [else (dummy-graph command out)])] - ['timeout (make-traceback result out profile?)] - ['failure (make-traceback result out profile?)] + ['timeout (make-traceback result-hash out)] + ['failure (make-traceback result-hash out)] [_ (error 'make-page "unknown result type ~a" status)])] - ["timeline.html" (make-timeline (test-name test) (job-result-timeline result) out #:path "..")] - ["timeline.json" (write-json (job-result-timeline result) out)] - ["points.json" (write-json (make-points-json result ctx) out)])) + ["timeline.html" + (make-timeline (test-name test) (hash-ref result-hash 'timeline) out #:path "..")] + ["timeline.json" (write-json (hash-ref result-hash 'timeline) out)] + ["points.json" (write-json (make-points-json result-hash) out)])) diff --git a/src/reports/plot.rkt b/src/reports/plot.rkt index cbfa47d46..928b809e9 100644 --- a/src/reports/plot.rkt +++ b/src/reports/plot.rkt @@ -14,10 +14,15 @@ "../api/sandbox.rkt") (provide make-points-json + regime-var + regime-splitpoints + real->ordinal splitpoints->json) (define (all-same? 
pts idx) - (= 1 (set-count (for/set ([pt pts]) (list-ref pt idx))))) + (= 1 + (set-count (for/set ([pt pts]) + (list-ref pt idx))))) (define (ulps->bits-tenths x) (string->number (real->decimal-string (ulps->bits x) 1))) @@ -30,14 +35,21 @@ (real->ordinal (repr->real val repr) repr)) '()))) -(define (make-points-json result repr) - (match-define (job-result _ test _ _ _ _ (improve-result _ pctxs start targets end _)) result) +(define (make-points-json result-hash) + (define test (hash-ref result-hash 'test)) + (define backend (hash-ref result-hash 'backend)) + (define pctxs (hash-ref backend 'pctxs)) + (define start (hash-ref backend 'start)) + (define targets (hash-ref backend 'target)) + (define end (hash-ref backend 'end)) + (define repr (test-output-repr test)) (define start-errors (alt-analysis-test-errors start)) (define target-errors (map alt-analysis-test-errors targets)) - (define end-errors (map alt-analysis-test-errors end)) + (define end-errors (hash-ref end 'end-errors)) + (define newpoints (pcontext-points (second pctxs))) ; Immediately convert points to reals to handle posits @@ -79,8 +91,7 @@ (string-replace (~r val #:notation 'exponential #:precision 0) "1e" "e"))) (list tick-str (real->ordinal val repr))))) - (define end-alt (alt-analysis-alt (car end))) - (define splitpoints (splitpoints->json vars end-alt repr)) + (define splitpoints (hash-ref end 'splitpoints)) ; NOTE ordinals *should* be passed as strings so we can detect truncation if ; necessary, but this isn't implemented yet. diff --git a/src/reports/traceback.rkt b/src/reports/traceback.rkt index d91cd4308..17ccc2df5 100644 --- a/src/reports/traceback.rkt +++ b/src/reports/traceback.rkt @@ -2,16 +2,71 @@ (require (only-in xml write-xexpr xexpr?)) (require "../utils/common.rkt" - "../utils/errors.rkt" "../syntax/read.rkt" - "../api/sandbox.rkt" "common.rkt") + (provide make-traceback) -(define (make-traceback result out profile?) 
- ;; Called with timeout or failure results - (match-define (job-result command test status time timeline warnings backend) result) - (define exn (if (eq? status 'failure) backend 'timeout)) +(define (make-traceback result-hash out) + (match (hash-ref result-hash 'status) + ['timeout (render-timeout result-hash out)] + ['failure (render-failure result-hash out)] + [status (error 'make-traceback "unexpected status ~a" status)])) + +(define (render-failure result-hash out) + (define test (hash-ref result-hash 'test)) + (define warnings (hash-ref result-hash 'warnings)) + (define backend (hash-ref result-hash 'backend)) + + ; unpack the exception + (match-define (list 'exn type msg url extra traceback) backend) + + (write-html + `(html + (head (meta ((charset "utf-8"))) + (title "Exception for " ,(~a (test-name test))) + (link ((rel "stylesheet") (type "text/css") (href "../report.css"))) + ,@js-tex-include + (script ([src "../report.js"]))) + (body ,(render-menu (~a (test-name test)) + (list '("Report" . "../index.html") '("Metrics" . "timeline.html"))) + ,(render-warnings warnings) + ,(render-specification test) + ,(if type + `(section ([id "user-error"] (class "error")) + (h2 ,(~a msg) " " (a ([href ,url]) "(more)")) + ,(if (eq? type 'syntax) (render-syntax-errors msg extra) "")) + "") + ,(if type + "" + `(,@(render-reproduction test #:bug? 
#t) + (section ([id "backtrace"]) (h2 "Backtrace") ,(render-traceback msg traceback)))))) + out)) + +(define (render-syntax-errors msg locations) + `(table (thead (th ([colspan "2"]) ,msg) (th "L") (th "C")) + (tbody ,@(for/list ([location (in-list locations)]) + (match-define (list msg src line col pos) location) + `(tr (td ((class "procedure")) ,(~a msg)) + (td ,(~a src)) + (td ,(~a (or line ""))) + (td ,(~a (or col pos)))))))) + +(define (render-traceback msg traceback) + `(table + (thead (th ([colspan "2"]) ,msg) (th "L") (th "C")) + (tbody + ,@ + (for/list ([(name loc) (in-dict traceback)]) + (match loc + [(list file line col) + `(tr (td ((class "procedure")) ,(~a name)) (td ,(~a file)) (td ,(~a line)) (td ,(~a col)))] + [#f `(tr (td ((class "procedure")) ,(~a name)) (td ([colspan "3"]) "unknown"))]))))) + +(define (render-timeout result-hash out) + (define test (hash-ref result-hash 'test)) + (define time (hash-ref result-hash 'time)) + (define warnings (hash-ref result-hash 'warnings)) (write-html `(html (head (meta ((charset "utf-8"))) @@ -23,42 +78,7 @@ (list '("Report" . "../index.html") '("Metrics" . "timeline.html"))) ,(render-warnings warnings) ,(render-specification test) - ,(match exn - [(? exn:fail:user:herbie?) - `(section - ([id "user-error"] (class "error")) - (h2 ,(~a (exn-message exn)) " " (a ([href ,(herbie-error-url exn)]) "(more)")) - ,(if (exn:fail:user:herbie:syntax? exn) (render-syntax-errors exn) ""))] - ['timeout - `(section ([id "user-error"] (class "error")) - (h2 "Timeout after " ,(format-time time)) - (p "Use the " (code "--timeout") " flag to change the timeout."))] - [_ ""]) - ,(match exn - [(? exn:fail:user:herbie?) ""] - [(? exn?) - `(,@(render-reproduction test #:bug? 
#t) - (section ([id "backtrace"]) (h2 "Backtrace") ,(render-traceback exn)))] - [_ ""]))) + (section ([id "user-error"] (class "error")) + (h2 "Timeout after " ,(format-time time)) + (p "Use the " (code "--timeout") " flag to change the timeout.")))) out)) - -(define (render-syntax-errors exn) - `(table (thead (th ([colspan "2"]) ,(exn-message exn)) (th "L") (th "C")) - (tbody ,@(for/list ([(stx msg) (in-dict (exn:fail:user:herbie:syntax-locations exn))]) - `(tr (td ((class "procedure")) ,(~a msg)) - (td ,(~a (syntax-source stx))) - (td ,(or (~a (syntax-line stx) ""))) - (td ,(or (~a (syntax-column stx)) (~a (syntax-position stx))))))))) - -(define (render-traceback exn) - `(table (thead (th ([colspan "2"]) ,(exn-message exn)) (th "L") (th "C")) - (tbody ,@(for/list ([tb (continuation-mark-set->context (exn-continuation-marks exn))]) - (match (cdr tb) - [(srcloc file line col _ _) - `(tr (td ((class "procedure")) ,(~a (or (car tb) "(unnamed)"))) - (td ,(~a file)) - (td ,(~a line)) - (td ,(~a col)))] - [#f - `(tr (td ((class "procedure")) ,(~a (or (car tb) "(unnamed)"))) - (td ([colspan "3"]) "unknown"))]))))) diff --git a/src/syntax/syntax.rkt b/src/syntax/syntax.rkt index b4ab3bf6f..8022e5792 100644 --- a/src/syntax/syntax.rkt +++ b/src/syntax/syntax.rkt @@ -120,7 +120,7 @@ (hash-ref operators-to-impls op)) ;; Checks an "accelerator" specification -(define (check-accelerator-spec! name itypes otype spec) +(define (check-spec! name itypes otype spec) (define (bad! fmt . args) (error name "~a in `~a`" (apply format fmt args) spec)) @@ -212,7 +212,7 @@ (define deprecated? (dict-ref attrib-dict 'deprecated #f)) ; check the spec if it is provided (when spec - (check-accelerator-spec! name itypes otype spec) + (check-spec! name itypes otype spec) (set! 
spec (expand-accelerators spec))) ; update tables (define info (operator name itypes* otype* spec deprecated?)) @@ -362,7 +362,7 @@ ;; An "operator implementation" implements a mathematical operator for ;; a particular set of representations satisfying the types described ;; by the `itype` and `otype` properties of the operator. -(struct operator-impl (name op itype otype fl)) +(struct operator-impl (name op ctx spec fpcore fl)) ;; Operator implementation table ;; Tracks implementations that are loaded into Racket's runtime @@ -381,8 +381,10 @@ (error 'impl-info "Unknown operator implementation ~a" impl)) (define info (hash-ref operator-impls impl)) (case field - [(itype) (operator-impl-itype info)] - [(otype) (operator-impl-otype info)] + [(itype) (context-var-reprs (operator-impl-ctx info))] + [(otype) (context-repr (operator-impl-ctx info))] + [(spec) (operator-impl-spec info)] + [(fpcore) (operator-impl-fpcore info)] [(fl) (operator-impl-fl info)])) ;; Like `operator-all-impls`, but filters for only active implementations. @@ -408,14 +410,60 @@ (define (clear-active-operator-impls!) (set-clear! active-operator-impls)) -;; Registers an operator implementation `name` or real operator `op`. -;; The input and output representations must satisfy the types -;; specified by the `itype` and `otype` fields for `op`. -(define (register-operator-impl! op name ireprs orepr attrib-dict) +;; Registers an operator implementation `name` +;; fl, spec,, and fpcore can be synthesize from an operator +(define (register-operator-impl! op + name + args + orepr + #:fl [fl #f] + #:spec [spec #f] + #:fpcore [fpcore #f]) + ;; Check if spec is given (if not, infer it from the operator which is required) + (define vars (map car args)) + (unless (= (length vars) (length (remove-duplicates vars))) + (raise-herbie-syntax-error "Duplicate variable names in ~a" vars)) + (define ireprs (map cdr args)) + (unless spec + (unless op + (raise-herbie-syntax-error "Missing required operator")) + (set! 
spec `(,op ,@vars))) + (check-spec! name + (map representation-type ireprs) + (representation-type orepr) + (list 'lambda vars spec)) + + ;; Infer operator from spec + (define new-op op) + (if op + op + (let loop ([expr spec] + [operator #f]) + (match expr + [`(,(? symbol? op) ,args ...) + (if (null? args) + (set! new-op op) + (for ([a (in-list args)]) + (if operator + (raise-herbie-syntax-error "Could not infer operator from ~a" spec) + (loop a op))))] + [_ (set! new-op operator)]))) + + (define bool-repr (get-representation 'bool)) + (if fpcore + ;; Verify fpcore is well formed + (match fpcore + [`(! ,props ... (,operator ,args ...)) (void)] + [`(,operator ,args ...) (void)] + [_ (raise-herbie-syntax-error "Invalid fpcore given" fpcore)]) + (if (equal? orepr bool-repr) + (set! fpcore `(,new-op ,@vars)) + (set! fpcore `(! :precision ,(representation-name orepr) (,new-op ,@vars))))) + (define op-info (hash-ref operators - op + new-op (lambda () (raise-herbie-missing-error "Cannot register `~a`, operator `~a` does not exist" name op)))) @@ -428,7 +476,7 @@ (raise-herbie-missing-error "Cannot register `~a` as an implementation of `~a`: expected ~a arguments, got ~a" name - op + new-op expect-arity actual-arity)) (for ([repr (in-list (cons orepr ireprs))] @@ -436,7 +484,7 @@ (unless (equal? (representation-type repr) type) "Cannot register `~a` as implementation of `~a`: ~a is not a representation of ~a" name - op + new-op repr type)) @@ -449,31 +497,83 @@ (define-values (_ exs) (real-apply compiler pt)) (if exs (first exs) fail)) (sym-append 'synth: name))) - ;; Get floating-point implementation (define fl-proc (cond - [(assoc 'fl attrib-dict) - => - cdr] ; user-provided implementation - [(operator-accelerator? op) ; Rival-synthesized accelerator implementation + [fl fl] ; user-provided implementation + [(operator-accelerator? new-op) ; Rival-synthesized accelerator implementation (match-define `(,(or 'lambda 'λ) (,vars ...) 
,body) (operator-spec op-info)) (synth-fl-impl name vars body)] [else ; Rival-synthesized operator implementation (define vars (build-list (length ireprs) (lambda (i) (string->symbol (format "x~a" i))))) - (synth-fl-impl name vars `(,op ,@vars))])) + (synth-fl-impl name vars `(,new-op ,@vars))])) ; update tables - (define impl (operator-impl name op-info ireprs orepr fl-proc)) + (define impl (operator-impl name op-info (context vars orepr ireprs) spec fpcore fl-proc)) (hash-set! operator-impls name impl) - (hash-update! operators-to-impls op (curry cons name))) - -(define-syntax-rule (define-operator-impl (operator name atypes ...) rtype [key value] ...) - (register-operator-impl! 'operator - 'name - (list (get-representation 'atypes) ...) - (get-representation 'rtype) - (list (cons 'key value) ...))) + (hash-update! operators-to-impls new-op (curry cons name))) + +(define-syntax (define-operator-impl stx) + (define (oops! why [sub-stx #f]) + (raise-syntax-error 'define-operator-impl why stx sub-stx)) + (syntax-case stx (:) + [(_ (id [var : repr] ...) rtype fields ...) + (let ([impl-id #'id] + [fields #'(fields ...)]) + (unless (identifier? impl-id) + (oops! "impl id is not a valid identifier" impl-id)) + (for ([var (in-list (syntax->list #'(var ...)))]) + (unless (identifier? var) + (oops! "given id is not a valid identifier" var))) + (define operator #f) + (define spec #f) + (define core #f) + (define fl-expr #f) + (let loop ([fields fields]) + (syntax-case fields () + [() + (with-syntax ([impl-id impl-id] + [operator operator] + [spec spec] + [core core] + [fl-expr fl-expr]) + #'(register-operator-impl! 'operator + 'impl-id + (list (cons 'var (get-representation 'repr)) ...) + (get-representation 'rtype) + #:fl fl-expr + #:spec 'spec + #:fpcore 'core))] + [(#:spec expr rest ...) + (cond + [spec (oops! "multiple #:spec clauses" stx)] + [else + (set! spec #'expr) + (loop #'(rest ...))])] + [(#:spec) (oops! 
"expected value after keyword `#:spec`" stx)] + [(#:fpcore expr rest ...) + (cond + [core (oops! "multiple #:fpcore clauses" stx)] + [else + (set! core #'expr) + (loop #'(rest ...))])] + [(#:fpcore) (oops! "expected value after keyword `#:fpcore`" stx)] + [(#:fl expr rest ...) + (cond + [fl-expr (oops! "multiple #:fl clauses" stx)] + [else + (set! fl-expr #'expr) + (loop #'(rest ...))])] + [(#:fl) (oops! "expected value after keyword `#:fl`" stx)] + [(#:op name rest ...) + (cond + [operator (oops! "multiple #:op clauses" stx)] + [else + (set! operator #'name) + (loop #'(rest ...))])] + [(#:op) (oops! "expected value after keyword `#:op`" stx)] + [_ (oops! "bad syntax" fields)])))] + [_ (oops! "bad syntax")])) ;; Among active implementations, looks up an implementation with ;; the operator name `name` and argument representations `ireprs`. @@ -515,6 +615,15 @@ (define shift-val (expt 2 bits)) (λ (x) (+ (fn x) shift-val))) + ; for testing: also in /reprs/bool.rkt + (define-representation (bool bool boolean?) + identity + identity + (λ (x) (= x 0)) + (λ (x) (if x 1 0)) + 1 + (const #f)) + ; for testing: also in /reprs/binary64.rkt (define-representation (binary64 real flonum?) bigfloat->flonum @@ -525,9 +634,17 @@ (conjoin number? nan?)) ; correctly-rounded log1pmd(x) for binary64 - (define-operator-impl (log1pmd log1pmd.f64 binary64) binary64) + (define-operator-impl (log1pmd.f64 [x : binary64]) + binary64 + #:spec (- (log1p x) (log1p (neg x))) + #:fpcore (! :precision binary64 (log1pmd x)) + #:op log1pmd) ; correctly-rounded sin(x) for binary64 - (define-operator-impl (sin sin.acc.f64 binary64) binary64) + (define-operator-impl (sin.acc.f64 [x : binary64]) + binary64 + #:spec (sin x) + #:fpcore (! :precision binary64 (sin x)) + #:fl sin) (define log1pmd-proc (impl-info 'log1pmd.f64 'fl)) (define log1pmd-vals '((0.0 . 0.0) (0.5 . 1.0986122886681098) (-0.5 . -1.0986122886681098))) @@ -583,14 +700,14 @@ (and (symbol? op) (or (and (hash-has-key? operators op) (null? 
(operator-itype (hash-ref operators op)))) (and (hash-has-key? operator-impls op) - (null? (operator-impl-itype (hash-ref operator-impls op))))))) + (null? (context-vars (operator-impl-ctx (hash-ref operator-impls op)))))))) (define (variable? var) (and (symbol? var) (or (not (hash-has-key? operators var)) (not (null? (operator-itype (hash-ref operators var))))) (or (not (hash-has-key? operator-impls var)) - (not (null? (operator-impl-itype (hash-ref operator-impls var))))))) + (not (null? (context-vars (operator-impl-ctx (hash-ref operator-impls var)))))))) ;; Floating-point expressions require that number ;; be rounded to a particular precision. diff --git a/src/syntax/types.rkt b/src/syntax/types.rkt index 52618fb86..0565ecc98 100644 --- a/src/syntax/types.rkt +++ b/src/syntax/types.rkt @@ -26,7 +26,8 @@ (define (type-name? x) (hash-has-key? type-dict x)) -(define-syntax-rule (define-type name _ ...) (hash-set! type-dict 'name #t)) +(define-syntax-rule (define-type name _ ...) + (hash-set! type-dict 'name #t)) (define-type real) (define-type bool) diff --git a/src/utils/alternative.rkt b/src/utils/alternative.rkt index cea8ff37c..9db5415dc 100644 --- a/src/utils/alternative.rkt +++ b/src/utils/alternative.rkt @@ -20,10 +20,7 @@ ;; from one program to another. ;; They are a labeled linked list of changes. 
-(struct alt (expr event prevs preprocessing) - #:methods gen:custom-write - [(define (write-proc alt port mode) - (fprintf port "#" (alt-expr alt)))]) +(struct alt (expr event prevs preprocessing) #:prefab) (define (make-alt expr) (alt expr 'start '() '())) diff --git a/src/utils/errors.rkt b/src/utils/errors.rkt index f0160799a..417859c8f 100644 --- a/src/utils/errors.rkt +++ b/src/utils/errors.rkt @@ -5,6 +5,7 @@ raise-herbie-sampling-error raise-herbie-missing-error syntax->error-format-string + exception->datum herbie-error->string herbie-error-url (struct-out exn:fail:user:herbie) @@ -60,6 +61,35 @@ (or (syntax-line stx) "") (or (syntax-column stx) (syntax-position stx)))) +(define (traceback->datum exn) + (define ctx (continuation-mark-set->context (exn-continuation-marks exn))) + (for/list ([(name loc) (in-dict ctx)]) + (define name* (or name "(unnamed)")) + (match loc + [(srcloc file line col _ _) (cons name* (list file line col))] + [#f (cons name* #f)]))) + +(define (syntax-locations->datum exn) + (for/list ([(stx msg) (in-dict (exn:fail:user:herbie:syntax-locations exn))]) + (list msg (syntax-source stx) (syntax-line stx) (syntax-column stx) (syntax-position stx)))) + +(define (exception->datum exn) + (match exn + [(? exn:fail:user:herbie:missing?) + (list 'exn 'missing (exn-message exn) (herbie-error-url exn) #f (traceback->datum exn))] + [(? exn:fail:user:herbie:sampling?) + (list 'exn 'sampling (exn-message exn) (herbie-error-url exn) #f (traceback->datum exn))] + [(? exn:fail:user:herbie:syntax?) + (list 'exn + 'syntax + (exn-message exn) + (herbie-error-url exn) + (syntax-locations->datum exn) + (traceback->datum exn))] + [(? exn:fail:user:herbie?) + (list 'exn 'herbie (exn-message exn) (herbie-error-url exn) '() (traceback->datum exn))] + [(? exn?) 
(list 'exn #f (exn-message exn) #f '() (traceback->datum exn))])) + (define (herbie-error->string err) (call-with-output-string (λ (p) diff --git a/src/utils/timeline.rkt b/src/utils/timeline.rkt index 9253d6b3e..9841e1481 100644 --- a/src/utils/timeline.rkt +++ b/src/utils/timeline.rkt @@ -20,7 +20,7 @@ ;; This is a box so we can get a reference outside the engine, and so ;; access its value even in a timeout. ;; Important: Use 'eq?' based hash tables, process may freeze otherwise -(define/reset *timeline* (box '())) +(define/reset *timeline* (box '()) (lambda () (set-box! (*timeline*) '()))) (define *timeline-active-key* #f) (define *timeline-active-value* #f) diff --git a/www/doc/2.1/options.html b/www/doc/2.1/options.html index 4278b7e2e..1df823156 100644 --- a/www/doc/2.1/options.html +++ b/www/doc/2.1/options.html @@ -151,6 +151,9 @@

Web shell options

automatically started to show the Herbie page. This option also shrinks the text printed on start up. +
--no-browser
+
This flag disables the default behavior of opening the Herbie page in your default browser.
+
--public
When set, users on other computers can connect to the demo and use it. (In other words, the server listens