diff --git a/infra/merge.rkt b/infra/merge.rkt index 9498e58cb..c68d93a8c 100644 --- a/infra/merge.rkt +++ b/infra/merge.rkt @@ -5,7 +5,8 @@ "../src/utils/profile.rkt" "../src/api/datafile.rkt" "../src/reports/timeline.rkt" - "../src/syntax/load-plugin.rkt") + "../src/syntax/load-plugin.rkt" + "../src/reports/common.rkt") (define (merge-timelines outdir . dirs) (define tls @@ -21,7 +22,8 @@ (curry write-json joint-tl)) (call-with-output-file (build-path outdir "timeline.html") #:exists 'replace - (λ (out) (make-timeline "Herbie run" joint-tl out #:info info)))) + (λ (out) + (write-html (make-timeline "Herbie run" joint-tl #:info info) out)))) (define (merge-profiles outdir . dirs) (define pfs diff --git a/src/api/demo.rkt b/src/api/demo.rkt index d7f967b07..4b39079e3 100644 --- a/src/api/demo.rkt +++ b/src/api/demo.rkt @@ -18,25 +18,18 @@ "../config.rkt" "../syntax/read.rkt" "../utils/errors.rkt") -(require "../syntax/types.rkt" - "../syntax/sugar.rkt" - "../utils/alternative.rkt" +(require "../syntax/sugar.rkt" "../core/points.rkt" - "../api/sandbox.rkt" - "../utils/float.rkt") + "../api/sandbox.rkt") (require "datafile.rkt" "../reports/pages.rkt" "../reports/common.rkt" "../reports/core2mathjs.rkt" - "../reports/history.rkt" - "../reports/plot.rkt" "server.rkt") (provide run-demo) -(define *demo?* (make-parameter false)) (define *demo-prefix* (make-parameter "/")) -(define *demo-output* (make-parameter false)) (define *demo-log* (make-parameter false)) (define (add-prefix url) @@ -48,9 +41,9 @@ (and (not (and (*demo-output*) ; If we've already saved to disk, skip this job (directory-exists? (build-path (*demo-output*) x)))) (let ([m (regexp-match #rx"^([0-9a-f]+)\\.[0-9a-f.]+" x)]) - (and m (get-results-for (second m)))))) + (and m (server-check-on (second m)))))) (λ (x) - (let ([m (regexp-match #rx"^([0-9a-f]+)\\.[0-9a-f.]+" x)]) (get-results-for (if m (second m) x))))) + (let ([m (regexp-match #rx"^([0-9a-f]+)\\.[0-9a-f.]+" x)]) (server-check-on (if m (second m) x))))) (define-bidi-match-expander hash-arg hash-arg/m hash-arg/m) @@ -74,32 +67,10 @@ [((hash-arg) (string-arg)) generate-page] [("results.json") generate-report])) -(define (generate-page req result-hash page) +(define (generate-page req job-id page) (define path (first (string-split (url->string (request-uri req)) "/"))) (cond - [(set-member? (all-pages result-hash) page) - ;; Write page contents to disk - (when (*demo-output*) - (make-directory (build-path (*demo-output*) path)) - (for ([page (all-pages result-hash)]) - (call-with-output-file - (build-path (*demo-output*) path page) - (λ (out) - (with-handlers ([exn:fail? (page-error-handler result-hash page out)]) - (make-page page out result-hash (*demo-output*) #f))))) - (update-report result-hash - path - (get-seed) - (build-path (*demo-output*) "results.json") - (build-path (*demo-output*) "index.html"))) - (response 200 - #"OK" - (current-seconds) - #"text" - (list (header #"X-Job-Count" (string->bytes/utf-8 (~a (job-count))))) - (λ (out) - (with-handlers ([exn:fail? (page-error-handler result-hash page out)]) - (make-page page out result-hash (*demo-output*) #f))))] + [(check-and-send path job-id page)] [else (next-dispatcher)])) (define (generate-report req) @@ -234,19 +205,6 @@ (a ([href "./index.html"]) " See what formulas other users submitted."))] [else `("all formulas submitted here are " (a ([href "./index.html"]) "logged") ".")]))))) -(define (update-report result-hash dir seed data-file html-file) - (define link (path-element->string (last (explode-path dir)))) - (define data (get-table-data-from-hash result-hash link)) - (define info - (if (file-exists? data-file) - (let ([info (read-datafile data-file)]) - (struct-copy report-info info [tests (cons data (report-info-tests info))])) - (make-report-info (list data) #:seed seed #:note (if (*demo?*) "Web demo results" "")))) - (define tmp-file (build-path (*demo-output*) "results.tmp")) - (write-datafile tmp-file info) - (rename-file-or-directory tmp-file data-file #t) - (copy-file (web-resource "report.html") html-file #t)) - (define (post-with-json-response fn) (lambda (req) (define post-body (request-post-data/raw req)) @@ -339,27 +297,7 @@ (url main))) (define (check-status req job-id) - (define r (get-results-for job-id)) - ;; TODO return the current status from the jobs timeline - (match r - [#f - (response 202 - #"Job in progress" - (current-seconds) - #"text/plain" - (list (header #"X-Job-Count" (string->bytes/utf-8 (~a (job-count))))) - (λ (out) (display "Not done!" out)))] - [(? box? timeline) - (response 202 - #"Job in progress" - (current-seconds) - #"text/plain" - (list (header #"X-Job-Count" (string->bytes/utf-8 (~a (job-count))))) - (λ (out) - (display (apply string-append - (for/list ([entry (reverse (unbox timeline))]) - (format "Doing ~a\n" (hash-ref entry 'type)))) - out)))] + (match (get-timeline-for job-id) [(? hash? result-hash) (response/full 201 #"Job complete" @@ -370,7 +308,18 @@ (add-prefix (format "~a.~a/graph.html" job-id *herbie-commit*)))) (header #"X-Job-Count" (string->bytes/utf-8 (~a (job-count)))) (header #"X-Herbie-Job-ID" (string->bytes/utf-8 job-id))) - '())])) + '())] + [timeline + (response 202 + #"Job in progress" + (current-seconds) + #"text/plain" + (list (header #"X-Job-Count" (string->bytes/utf-8 (~a (job-count))))) + (λ (out) + (display (apply string-append + (for/list ([entry timeline]) + (format "Doing ~a\n" (hash-ref entry 'type)))) + out)))])) (define (check-up req) (response/full (if (is-server-up) 200 500) @@ -505,7 +454,7 @@ (define sample (hash-ref post-data 'sample)) (define seed (hash-ref post-data 'seed #f)) (define test (parse-test formula)) - (define expr (prog->fpcore (test-input test))) + (define expr (prog->fpcore (test-input test) (test-context test))) (define pcontext (json->pcontext sample (test-context test))) (define command (create-job 'local-error diff --git a/src/api/run.rkt b/src/api/run.rkt index 7d9b6f4a0..25a0f8190 100644 --- a/src/api/run.rkt +++ b/src/api/run.rkt @@ -11,7 +11,8 @@ "../core/sampling.rkt" "../reports/pages.rkt" "thread-pool.rkt" - "../reports/timeline.rkt") + "../reports/timeline.rkt" + "../reports/common.rkt") (provide make-report rerun-report @@ -84,9 +85,10 @@ (define profile (merge-profile-jsons (read-json-files info dir "profile.json"))) (call-with-output-file (build-path dir "profile.json") (curry write-json profile) #:exists 'replace) - (call-with-output-file (build-path dir "timeline.html") - #:exists 'replace - (λ (out) (make-timeline "Herbie run" timeline out #:info info #:path "."))) + (call-with-output-file + (build-path dir "timeline.html") + #:exists 'replace + (λ (out) (write-html (make-timeline "Herbie run" timeline #:info info #:path ".") out))) ; Delete old files (let* ([expected-dirs (map string->path diff --git a/src/api/sandbox.rkt b/src/api/sandbox.rkt index 6027fff74..9eab291f6 100644 --- a/src/api/sandbox.rkt +++ b/src/api/sandbox.rkt @@ -319,15 +319,15 @@ (table-row (test-name test) (test-identifier test) status - (prog->fpcore (test-pre test)) + (prog->fpcore (test-pre test) (test-context test)) preprocess (representation-name repr) '() ; TODO: eliminate field (test-vars test) (map car (job-result-warnings result)) - (prog->fpcore (test-input test)) + (prog->fpcore (test-input test) (test-context test)) #f - (prog->fpcore (test-spec test)) + (prog->fpcore (test-spec test) (test-context test)) (test-output test) #f #f @@ -348,15 +348,15 @@ (table-row (test-name test) (test-identifier test) status - (prog->fpcore (test-pre test)) + (prog->fpcore (test-pre test) (test-context test)) preprocess (representation-name repr) '() ; TODO: eliminate field (test-vars test) (map car (hash-ref result-hash 'warnings)) - (prog->fpcore (test-input test)) + (prog->fpcore (test-input test) (test-context test)) #f - (prog->fpcore (test-spec test)) + (prog->fpcore (test-spec test) (test-context test)) (test-output test) #f #f @@ -398,7 +398,7 @@ (define best-score (if (null? target-cost-score) target-cost-score (apply min (map second target-cost-score)))) - (define end-exprs (hash-ref end 'end-alts)) + (define end-exprs (hash-ref end 'end-exprs)) (define end-train-scores (map errors-score (hash-ref end 'end-train-scores))) (define end-test-scores (map errors-score (hash-ref end 'end-errors))) (define end-costs (hash-ref end 'end-costs)) @@ -435,8 +435,7 @@ [target target-cost-score] [result-est end-est-score] [result end-score] - [output - (test-input (parse-test (read-syntax 'web (open-input-string (car end-exprs)))))] + [output (car end-exprs)] [cost-accuracy cost&accuracy])] ['failure (match-define (list 'exn type _ ...) backend) @@ -524,12 +523,11 @@ [_ (error 'get-table-data "unknown result type ~a" status)])) (define (unparse-result row #:expr [expr #f] #:description [descr #f]) + (define vars (table-row-vars row)) (define repr (get-representation (table-row-precision row))) + (define ctx (context vars repr (map (const repr) vars))) ; TODO: this seems wrong (define expr* (or expr (table-row-output row) (table-row-input row))) - (define top - (if (table-row-identifier row) - (list (table-row-identifier row) (table-row-vars row)) - (list (table-row-vars row)))) + (define top (if (table-row-identifier row) (list (table-row-identifier row) vars) (list vars))) `(FPCore ,@top :herbie-status ,(string->symbol (table-row-status row)) @@ -555,4 +553,4 @@ ,@(append (for/list ([(target enabled?) (in-dict (table-row-target-prog row))] #:when enabled?) `(:alt ,target))) - ,(prog->fpcore expr*))) + ,(prog->fpcore expr* ctx))) diff --git a/src/api/server.rkt b/src/api/server.rkt index 78babd0f8..ead92475b 100644 --- a/src/api/server.rkt +++ b/src/api/server.rkt @@ -3,9 +3,10 @@ (require openssl/sha1) (require (only-in xml write-xexpr) json) +(require net/url) +(require web-server/http) (require "sandbox.rkt" - "../core/preprocess.rkt" "../core/points.rkt" "../reports/history.rkt" "../reports/plot.rkt" @@ -17,18 +18,29 @@ "../utils/alternative.rkt" "../utils/common.rkt" "../utils/errors.rkt" - "../utils/float.rkt") + "../utils/float.rkt" + "../reports/pages.rkt" + "datafile.rkt" + (submod "../utils/timeline.rkt" debug)) (provide make-path get-improve-table-data make-improve-result + server-check-on get-results-for + get-timeline-for job-count is-server-up create-job start-job wait-for-job - start-job-server) + start-job-server + check-and-send + *demo?* + *demo-output*) + +(define *demo?* (make-parameter false)) +(define *demo-output* (make-parameter false)) ; verbose logging for debugging (define verbose #f) ; Maybe change to log-level and use 'verbose? @@ -49,6 +61,45 @@ #:timeline-disabled? [timeline-disabled? #f]) (herbie-command command test seed pcontext profile? timeline-disabled?)) +;; TODO move these side worker/manager +(define (check-and-send path job-id page) + (define result-hash (get-results-for job-id)) + (cond + [(set-member? (all-pages result-hash) page) + ;; Write page contents to disk + (when (*demo-output*) + (write-results-to-disk result-hash path)) + (response 200 + #"OK" + (current-seconds) + #"text" + (list (header #"X-Job-Count" (string->bytes/utf-8 (~a (job-count))))) + (λ (out) + (with-handlers ([exn:fail? (page-error-handler result-hash page out)]) + (make-page page out result-hash (*demo-output*) #f))))] + [else #f])) + +(define (write-results-to-disk result-hash path) + (make-directory (build-path (*demo-output*) path)) + (for ([page (all-pages result-hash)]) + (call-with-output-file (build-path (*demo-output*) path page) + (λ (out) + (with-handlers ([exn:fail? (page-error-handler result-hash page out)]) + (make-page page out result-hash (*demo-output*) #f))))) + (define link (path-element->string (last (explode-path path)))) + (define data (get-table-data-from-hash result-hash link)) + (define data-file (build-path (*demo-output*) "results.json")) + (define html-file (build-path (*demo-output*) "index.html")) + (define info + (if (file-exists? data-file) + (let ([info (read-datafile data-file)]) + (struct-copy report-info info [tests (cons data (report-info-tests info))])) + (make-report-info (list data) #:seed (get-seed) #:note (if (*demo?*) "Web demo results" "")))) + (define tmp-file (build-path (*demo-output*) "results.tmp")) + (write-datafile tmp-file info) + (rename-file-or-directory tmp-file data-file #t) + (copy-file (web-resource "report.html") html-file #t)) + ; computes the path used for server URLs (define (make-path id) (format "~a.~a" id *herbie-commit*)) @@ -60,6 +111,19 @@ (log "Getting result for job: ~a.\n" job-id) (place-channel-get a)) +(define (get-timeline-for job-id) + (define-values (a b) (place-channel)) + (place-channel-put manager (list 'timeline job-id b)) + (log "Getting timeline for job: ~a.\n" job-id) + (place-channel-get a)) + +; Returns #f if there is no job returns the job-id if there is a completed job. +(define (server-check-on job-id) + (define-values (a b) (place-channel)) + (place-channel-put manager (list 'check job-id b)) + (log "Checking on: ~a.\n" job-id) + (place-channel-get a)) + (define (get-improve-table-data) (define-values (a b) (place-channel)) (place-channel-put manager (list 'improve b)) @@ -169,6 +233,7 @@ (define completed-work (make-hash)) (define busy-workers (make-hash)) (define waiting-workers (make-hash)) + (define current-jobs (make-hash)) (for ([i (in-range worker-count)]) (hash-set! waiting-workers i (make-worker i))) (log "~a workers ready.\n" (hash-count waiting-workers)) @@ -193,6 +258,7 @@ (log "Starting worker [~a] on [~a].\n" (work-item-id job) (test-name (herbie-command-test (work-item-command job)))) + (hash-set! current-jobs (work-item-id job) wid) (place-channel-put worker (list 'apply self (work-item-command job) (work-item-id job))) (hash-set! reassigned wid worker) (hash-set! busy-workers wid worker)) @@ -206,6 +272,7 @@ (hash-set! completed-work job-id result) ; move worker to waiting list + (hash-remove! current-jobs job-id) (hash-set! waiting-workers wid (hash-ref busy-workers wid)) (hash-remove! busy-workers wid) @@ -229,6 +296,20 @@ (hash-remove! waiting job-id)] ; Get the result for the given id, return false if no work found. [(list 'result job-id handler) (place-channel-put handler (hash-ref completed-work job-id #f))] + [(list 'timeline job-id handler) + (define wid (hash-ref current-jobs job-id #f)) + (cond + [wid + (log "Worker[~a] working on ~a.\n" wid job-id) + (define-values (a b) (place-channel)) + (place-channel-put (hash-ref busy-workers wid) (list 'timeline b)) + (define requested-timeline (place-channel-get a)) + (place-channel-put handler requested-timeline)] + [else + (log "Job complete, no timeline, send result.\n") + (place-channel-put handler (hash-ref completed-work job-id #f))])] + [(list 'check job-id handler) + (place-channel-put handler (if (hash-has-key? completed-work job-id) job-id #f))] ; Returns the current count of working workers. [(list 'count handler) (place-channel-put handler (hash-count busy-workers))] ; Retreive the improve results for results.json @@ -253,26 +334,46 @@ *loose-plugins*) (parameterize ([current-error-port (open-output-nowhere)]) ; hide output (load-herbie-plugins)) + (define worker-thread + (thread (λ () + (let loop ([seed #f]) + (match (thread-receive) + [job-info (run-job job-info)]) + (loop seed))))) + (define timeline #f) + (define current-job-id #f) (for ([_ (in-naturals)]) (match (place-channel-get ch) [(list 'apply manager command job-id) + (set! timeline (*timeline*)) + (set! current-job-id job-id) (log "[~a] working on [~a].\n" job-id (test-name (herbie-command-test command))) - (define herbie-result (wrapper-run-herbie command job-id)) - (match-define (job-result kind test status time _ _ backend) herbie-result) - (define out-result - (match kind - ['alternatives (make-alternatives-result herbie-result test job-id)] - ['evaluate (make-calculate-result herbie-result job-id)] - ['cost (make-cost-result herbie-result job-id)] - ['errors (make-error-result herbie-result job-id)] - ['exacts (make-exacts-result herbie-result job-id)] - ['improve (make-improve-result herbie-result test job-id)] - ['local-error (make-local-error-result herbie-result test job-id)] - ['explanations (make-explanation-result herbie-result job-id)] - ['sample (make-sample-result herbie-result test job-id)] - [_ (error 'compute-result "unknown command ~a" kind)])) - (log "Job: ~a finished, returning work to manager\n" job-id) - (place-channel-put manager (list 'finished manager worker-id job-id out-result))])))) + (thread-send worker-thread (work manager worker-id job-id command))] + [(list 'timeline handler) + (log "Timeline requested from worker[~a] for job ~a\n" worker-id current-job-id) + (place-channel-put handler (reverse (unbox timeline)))])))) + +(struct work (manager worker-id job-id job)) + +(define (run-job job-info) + (match-define (work manager worker-id job-id command) job-info) + (log "run-job: ~a, ~a\n" worker-id job-id) + (define herbie-result (wrapper-run-herbie command job-id)) + (match-define (job-result kind test status time _ _ backend) herbie-result) + (define out-result + (match kind + ['alternatives (make-alternatives-result herbie-result test job-id)] + ['evaluate (make-calculate-result herbie-result job-id)] + ['cost (make-cost-result herbie-result job-id)] + ['errors (make-error-result herbie-result job-id)] + ['exacts (make-exacts-result herbie-result job-id)] + ['improve (make-improve-result herbie-result test job-id)] + ['local-error (make-local-error-result herbie-result test job-id)] + ['explanations (make-explanation-result herbie-result job-id)] + ['sample (make-sample-result herbie-result test job-id)] + [_ (error 'compute-result "unknown command ~a" kind)])) + (log "Job: ~a finished, returning work to manager\n" job-id) + (place-channel-put manager (list 'finished manager worker-id job-id out-result))) (define (make-explanation-result herbie-result job-id) (define explanations (job-result-backend herbie-result)) @@ -286,7 +387,7 @@ (make-path job-id))) (define (make-local-error-result herbie-result test job-id) - (define expr (prog->fpcore (test-input test))) + (define expr (prog->fpcore (test-input test) (test-context test))) (define local-error (job-result-backend herbie-result)) ;; TODO: potentially unsafe if resugaring changes the AST (define tree @@ -410,14 +511,13 @@ (improve-result-bogosity backend))) (define (end-hash end repr pcontexts test) + (define-values (end-alts train-errors end-errors end-costs) (for/lists (l1 l2 l3 l4) ([analysis end]) (match-define (alt-analysis alt train-errors test-errs) analysis) (values alt train-errors test-errs (alt-cost alt repr)))) - (define fpcores - (for/list ([altn end-alts]) - (~a (program->fpcore (alt-expr altn) (test-context test))))) + (define alts-histories (for/list ([alt end-alts]) (render-history alt (first pcontexts) (second pcontexts) (test-context test)))) @@ -431,8 +531,8 @@ (real->ordinal (repr->real val repr) repr)) '()))) - (hasheq 'end-alts - fpcores + (hasheq 'end-exprs + (map alt-expr end-alts) 'end-histories alts-histories 'end-train-scores diff --git a/src/config.rkt b/src/config.rkt index 046fb0d4a..5e007ebb6 100644 --- a/src/config.rkt +++ b/src/config.rkt @@ -138,7 +138,7 @@ (define *platform-name* (make-parameter 'default)) ;; True iff using the old cost function -(define *egraph-platform-cost* (make-parameter #t)) +(define *egraph-platform-cost* (make-parameter #f)) ;; Plugins loaded locally rather than through Racket. (define *loose-plugins* (make-parameter '())) diff --git a/src/core/alt-table.rkt b/src/core/alt-table.rkt index 49f96b471..1384a5d86 100644 --- a/src/core/alt-table.rkt +++ b/src/core/alt-table.rkt @@ -40,11 +40,11 @@ (alt-table (make-immutable-hash (for/list ([(pt ex) (in-pcontext pcontext)] [err (errors (alt-expr initial-alt) pcontext ctx)]) (cons pt (list (pareto-point cost err (list initial-alt)))))) - (hash initial-alt - (for/list ([(pt ex) (in-pcontext pcontext)]) - pt)) - (hash initial-alt #f) - (hash initial-alt cost) + (hasheq initial-alt + (for/list ([(pt ex) (in-pcontext pcontext)]) + pt)) + (hasheq initial-alt #f) + (hasheq initial-alt cost) pcontext (list initial-alt))) @@ -192,7 +192,7 @@ [ppt (in-list curve)] [alt (in-list (pareto-point-data ppt))]) (hash-set! alt->points* alt (cons pt (hash-ref alt->points* alt '())))) - (make-immutable-hash (hash->list alt->points*))) + (make-immutable-hasheq (hash->list alt->points*))) (define (atab-add-altn atab altn errs cost) (match-define (alt-table point->alts alt->points alt->done? alt->cost pcontext all-alts) atab) diff --git a/src/core/batch.rkt b/src/core/batch.rkt index ae3beb7a1..ac22d48e1 100644 --- a/src/core/batch.rkt +++ b/src/core/batch.rkt @@ -1,521 +1,170 @@ #lang racket (require "../utils/timeline.rkt" - "../syntax/syntax.rkt" - (only-in "../utils/alternative.rkt" alt)) + "../syntax/syntax.rkt") -(provide (struct-out batch) - exprs->batch ; (List-of Expr) -> Batch - batch->exprs ; Batch -> (List-of Expr) - batch-expr-roots ; Batch -> (Vector-of Node-ptr) - ; batch-get-expr - batch-ref ; Batch -> Alt-idx -> (Vector Node-ptr Event Alt-idx Preprocessing) - expand-taylor ; Batch -> Batch - empty-batch ; Batch - ;nodes->batch - ;batch-add-expr! - #;batch->alts) +(provide progs->batch + batch->progs + (struct-out batch) + batch-length + batch-ref + deref + batch-replace) -;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -;; Structure batch +;; This function defines the recursive structure of expressions -;; - roots: a mapping to the current alternatives in the alts -;; - alts: a vector that store alternatives, highly linked to avoid duplications -;; alts[i]: (vector alt-expr - a pointer to node in nodes vector that stores expressions -;; event - an event of the alternative -;; prevs - a pointer to another alt in alts -;; (listof preprocessing)) - a lits of preprocessing for the current alt -;; - nodes: a main dataset of operations the expressions are build on top of -;; - nodes-length: length of nodes vector -;; - vars: list of free variables inside batch -;; - exprhash: a hash that maps a node to its index in nodes vector -;; - altshash: do we need it? +(define (expr-recurse expr f) + (match expr + [(approx spec impl) (approx spec (f impl))] + [(list op args ...) (cons op (map f args))] + [_ expr])) -(struct batch - ([roots #:mutable] [alts #:mutable] - [nodes #:mutable] - [nodes-length #:mutable] - [vars #:mutable] - [exprhash #:mutable] - [altshash #:mutable])) -;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Batches store these recursive structures, flattened -;; This function is needed only for debugging -#;(define (batch->alts in-batch) - (for/list ([root (in-vector (batch-roots batch))] - [event (in-vector (batch-events in-batch))] - [prev (in-vector (batch-prevs in-batch))] - [preprocessing (in-vector (batch-preprocessings in-batch))]) - (alt (batch-get-expr in-batch alt-expr) - event - (map (curry batch-get-expr in-batch) prev) - preprocessing))) +(struct batch ([nodes #:mutable] [roots #:mutable] vars)) -(define (exprs->batch exprs - #:timeline-push [timeline-push #f] - #:vars [vars '()] - #:ignore-approx [ignore-approx #t]) - (define icache (reverse vars)) - (define exprhash - (make-hash (for/list ([var vars] - [i (in-naturals)]) - (cons var i)))) - ; Counts - (define size 0) - (define exprc 0) - (define varc (length vars)) - - ; Translates programs into an instruction sequence of operations - (define (munge-ignore-approx prog) - (set! size (+ 1 size)) - (match prog ; approx nodes are ignored - [(approx _ impl) (munge-ignore-approx impl)] - [_ - (define node ; This compiles to the register machine - (match prog - [(list op args ...) (cons op (map munge-ignore-approx args))] - [_ prog])) - (hash-ref! exprhash - node - (lambda () - (begin0 (+ exprc varc) ; store in cache, update exprs, exprc - (set! exprc (+ 1 exprc)) - (set! icache (cons node icache)))))])) +(define (batch-length b) + (cond + [(batch? b) (vector-length (batch-nodes b))] + [(mutable-batch? b) (hash-count (mutable-batch-index b))] + [else (error 'batch-length "Invalid batch" b)])) - ; Translates programs into an instruction sequence of operations - (define (munge-include-approx prog) - (set! size (+ 1 size)) - (define node ; This compiles to the register machine - (match prog - [(approx spec impl) (approx spec (munge-include-approx impl))] - [(list op args ...) (cons op (map munge-include-approx args))] - [_ prog])) - (hash-ref! exprhash - node - (lambda () - (begin0 (+ exprc varc) ; store in cache, update exprs, exprc - (set! exprc (+ 1 exprc)) - (set! icache (cons node icache)))))) - - (define exprs-roots (map (if ignore-approx munge-ignore-approx munge-include-approx) exprs)) - - ; Maybe we do not need althash - (define althash (make-hash)) - (define altc 0) - ; This creates an alternative vector with no history - (define alts - (for/vector ([expr-root (in-list exprs-roots)]) - (vector expr-root '() '() '()))) - (for ([alt (in-vector alts)] - [n (in-naturals)]) - (hash-set! althash alt n)) - (define roots (build-vector (vector-length alts) values)) - - (define nodes (list->vector (reverse icache))) - (define nodes-length (vector-length nodes)) - - (when timeline-push - (timeline-push! 'compiler (+ varc size) (+ exprc varc))) - (batch roots alts nodes nodes-length vars exprhash althash)) - -; Simply recovering alt-expressions from a batch without considering history of previous alts etc.. -(define (batch->exprs batch) - (define roots (batch-roots batch)) - (define nodes (batch-nodes batch)) - (define alts (batch-alts batch)) - - (define (unmunge reg) - (define node (vector-ref nodes reg)) - (match node - [(approx spec impl) (approx spec (unmunge impl))] - [(list '$approx spec impl) - (list '$approx - spec - (unmunge impl))] ; this row is to be deleted and needed only in egg-herbie.rkt - [(list op regs ...) (cons op (map unmunge regs))] - [_ node])) +(struct mutable-batch ([nodes #:mutable] [index #:mutable] [vars #:mutable])) - (define exprs - (for/list ([root (in-vector roots)]) - (define alt-expr (vector-ref (vector-ref alts root) 0)) - (unmunge alt-expr))) - exprs) +(define (make-mutable-batch) + (mutable-batch '() (make-hash) '())) -; Function extracts enodes to in-batch -; nodes: (listof '(cost op arg1-index arg2-index)) -(define (nodes->batch nodes id->spec in-batch) - ; Mapping from nodes to nodes* - (define icache '()) - (define exprhash (hash-copy (batch-exprhash in-batch))) - (define exprc 0) - (define roots '()) +(define (batch-push! b term) + (define hashcons (mutable-batch-index b)) + (hash-ref! hashcons + term + (lambda () + (let ([new-idx (hash-count hashcons)]) + (hash-set! hashcons term new-idx) + (set-mutable-batch-nodes! b (cons term (mutable-batch-nodes b))) + (when (symbol? term) + (set-mutable-batch-vars! b (cons term (mutable-batch-vars b)))) + new-idx)))) - ; Adding a node to hash - (define (append-node node) - (hash-ref! exprhash - node - (lambda () - (begin0 exprc - (set! exprc (+ 1 exprc)) - (set! icache (cons node icache)))))) +(define (mutable-batch->immutable b roots) + (batch (list->vector (reverse (mutable-batch-nodes b))) roots (reverse (mutable-batch-vars b)))) - ; adds nodes to a batch - (define (add-node id) - (match (cdr (vector-ref nodes id)) - [(? number? n) (append-node n)] ; number - [(? symbol? s) (append-node s)] ; variable - [(list '$approx spec impl) ; approx - (match (vector-ref id->spec spec) - [#f (error nodes->batch "no initial approx node in eclass ~a" id)] - [spec-e (append-node (list '$approx spec-e (add-node impl)))])] - ; if expression - [(list 'if cond ift iff) (append-node (list 'if (add-node cond) (add-node ift) (add-node iff)))] - ; expression of impls - [(list (? impl-exists? impl) ids ...) (append-node (cons impl (map add-node ids)))] - ; expression of operators - [(list (? operator-exists? op) ids ...) (append-node (cons op (map add-node ids)))])) +(struct batchref (batch idx)) - ; This function is better to be eliminated because it duplicates add-node at some point - (define (add-root enode) - (define idx - (match enode - [(? number?) (append-node enode)] - [(? symbol?) (append-node enode)] - [(list '$approx spec impl) - (define spec* (vector-ref id->spec spec)) - (unless spec* - (error 'regraph-extract-variants "no initial approx node in eclass")) - (define impl-idx (add-node impl)) - (append-node (list '$approx spec* impl-idx))] - [(list 'if cond ift iff) - (define cond-idx (add-node cond)) - (define ift-idx (add-node ift)) - (define iff-idx (add-node iff)) - (append-node (list 'if cond-idx ift-idx iff-idx))] - [(list (? impl-exists? impl) ids ...) - (define args - (for/list ([id (in-list ids)]) - (add-node id))) - (append-node (cons impl args))] - [(list (? operator-exists? op) ids ...) - (define args - (for/list ([id (in-list ids)]) - (add-node id))) - (append-node (cons op args))])) - (set! roots (cons idx roots))) +(define (deref x) + (match-define (batchref b idx) x) + (expr-recurse (vector-ref (batch-nodes b) idx) (lambda (ref) (batchref b ref)))) - (define (finalize-batch) - (define exprs-roots (remove-duplicates (reverse roots))) - ; Maybe we do not need althash - (define althash (make-hash)) - (define altc 0) - ; This creates an alternative vector with no history - (define alts - (for/vector ([expr-root (in-list exprs-roots)]) - (vector expr-root '() '() '()))) - (for ([alt (in-vector alts)] - [n (in-naturals)]) - (hash-set! althash alt n)) - (define roots* (build-vector (vector-length alts) values)) +(define (progs->batch exprs #:timeline-push [timeline-push #f] #:vars [vars '()]) - (define nodes* (list->vector (reverse icache))) - (define nodes-length (vector-length nodes*)) - (batch roots* alts nodes* nodes-length '() exprhash althash)) + (define out (make-mutable-batch)) + (for ([var (in-list vars)]) + (batch-push! out var)) - (define (clean-batch) - (set! exprc 0) - (set! icache '()) - (set! roots '()) - (set! exprhash (make-hash))) + (define size 0) + (define (munge prog) + (set! size (+ 1 size)) + (batch-push! out (expr-recurse prog munge))) - (values add-root clean-batch finalize-batch)) + (define roots (list->vector (map munge exprs))) + (define final (mutable-batch->immutable out roots)) + (when timeline-push + (timeline-push! 'compiler size (batch-length final))) + final) + +(define (batch->progs b) + (define exprs (make-vector (batch-length b))) + (for ([node (in-vector (batch-nodes b))] + [idx (in-naturals)]) + (vector-set! exprs idx (expr-recurse node (lambda (x) (vector-ref exprs x))))) + (for/list ([root (batch-roots b)]) + (vector-ref exprs root))) + +(define (batch-replace b f) + (define out (make-mutable-batch)) + (define mapping (make-vector (batch-length b) -1)) + (for ([node (in-vector (batch-nodes b))] + [idx (in-naturals)]) + (define replacement (f (expr-recurse node (lambda (x) (batchref b x))))) + (define final-idx + (let loop ([expr replacement]) + (match expr + [(batchref b* idx) + (unless (eq? b* b) + (error 'batch-replace "Replacement ~a references the wrong batch ~a" replacement b*)) + (when (= -1 (vector-ref mapping idx)) + (error 'batch-replace "Replacement ~a references unknown index ~a" replacement idx)) + (vector-ref mapping idx)] + [_ (batch-push! out (expr-recurse expr loop))]))) + (vector-set! mapping idx final-idx)) + (define roots (vector-map (curry vector-ref mapping) (batch-roots b))) + (mutable-batch->immutable out roots)) -(define (expand-taylor input-batch) - (define vars (batch-vars input-batch)) +; The function removes any zombie nodes from batch +(define (remove-zombie-nodes input-batch) (define nodes (batch-nodes input-batch)) + (define roots (batch-roots input-batch)) + (define nodes-length (batch-length input-batch)) + + (define zombie-mask (make-vector nodes-length #t)) + (for ([root (in-vector roots)]) + (vector-set! zombie-mask root #f)) + (for ([node (in-vector nodes (- nodes-length 1) -1 -1)] + [zmb (in-vector zombie-mask (- nodes-length 1) -1 -1)] + #:when (not zmb)) + (match node + [(list op args ...) (map (λ (n) (vector-set! zombie-mask n #f)) args)] + [(approx spec impl) (vector-set! zombie-mask impl #f)] + [_ void])) - ; Hash to avoid duplications - (define icache (reverse vars)) - (define exprhash - (make-hash (for/list ([var vars] - [i (in-naturals)]) - (cons var i)))) - (define exprc 0) - (define varc (length vars)) - - ; Mapping from nodes to nodes* - (define mappings (build-vector (batch-nodes-length input-batch) values)) - - ; Adding a node to hash - (define (append-node node) - (hash-ref! exprhash - node - (lambda () - (begin0 (+ exprc varc) ; store in cache, update exprs, exprc - (set! exprc (+ 1 exprc)) - (set! icache (cons node icache)))))) + (define mappings (build-vector nodes-length values)) - ; Sequential rewriting + (define nodes* '()) (for ([node (in-vector nodes)] + [zmb (in-vector zombie-mask)] [n (in-naturals)]) + (if zmb + (for ([i (in-range n nodes-length)]) + (vector-set! mappings i (sub1 (vector-ref mappings i)))) + (set! nodes* + (cons (match node + [(list op args ...) (cons op (map (curry vector-ref mappings) args))] + [(approx spec impl) (approx spec (vector-ref mappings impl))] + [_ node]) + nodes*)))) + (set! nodes* (list->vector (reverse nodes*))) + (define roots* (vector-map (curry vector-ref mappings) roots)) + (batch nodes* roots* (batch-vars input-batch))) + +(define (batch-ref batch reg) + (define (unmunge reg) + (define node (vector-ref (batch-nodes batch) reg)) (match node - [(list '- arg1 arg2) - (define neg-index (append-node `(neg ,(vector-ref mappings arg2)))) - (vector-set! mappings n (append-node `(+ ,(vector-ref mappings arg1) ,neg-index)))] - [(list 'pow base power) - #:when (equal? (vector-ref nodes power) 1/2) ; 1/2 is to be removed from exprhash, it is zombie - (vector-set! mappings n (append-node `(sqrt ,(vector-ref mappings base))))] - [(list 'pow base power) - #:when (equal? (vector-ref nodes power) 1/3) ; 1/3 is to be removed from exprhash, it is zombie - (vector-set! mappings n (append-node `(cbrt ,(vector-ref mappings base))))] - [(list 'pow base power) - #:when (equal? (vector-ref nodes power) 2/3) ; 2/3 is to be removed from exprhash, it is zombie - (define mult-index (append-node `(* ,(vector-ref mappings base) ,(vector-ref mappings base)))) - (vector-set! mappings n (append-node `(cbrt ,mult-index)))] - [(list 'pow base power) - #:when (exact-integer? (vector-ref nodes power)) - (vector-set! mappings - n - (append-node `(pow ,(vector-ref mappings base) ,(vector-ref mappings power))))] - [(list 'pow base power) - (define log-idx (append-node `(log ,(vector-ref mappings base)))) - (define mult-idx (append-node `(* ,(vector-ref mappings power) ,log-idx))) - (vector-set! mappings n (append-node `(exp ,mult-idx)))] - [(list 'tan args) - (define sin-idx (append-node `(sin ,(vector-ref mappings args)))) - (define cos-idx (append-node `(cos ,(vector-ref mappings args)))) - (vector-set! mappings n (append-node `(/ ,sin-idx ,cos-idx)))] - [(list 'cosh args) - (define exp-idx (append-node `(exp ,(vector-ref mappings args)))) - (define one-idx (append-node 1)) - (define inv-exp-idx (append-node `(/ ,one-idx ,exp-idx))) - (define add-idx (append-node `(+ ,exp-idx ,inv-exp-idx))) - (define half-idx (append-node 1/2)) - (vector-set! mappings n (append-node `(* ,half-idx ,add-idx)))] - [(list 'sinh args) - (define exp-idx (append-node `(exp ,(vector-ref mappings args)))) - (define one-idx (append-node 1)) - (define inv-exp-idx (append-node `(/ ,one-idx ,exp-idx))) - (define neg-idx (append-node `(neg ,inv-exp-idx))) - (define add-idx (append-node `(+ ,exp-idx ,neg-idx))) - (define half-idx (append-node 1/2)) - (vector-set! mappings n (append-node `(* ,half-idx ,add-idx)))] - [(list 'tanh args) - (define exp-idx (append-node `(exp ,(vector-ref mappings args)))) - (define one-idx (append-node 1)) - (define inv-exp-idx (append-node `(/ ,one-idx ,exp-idx))) - (define neg-idx (append-node `(neg ,inv-exp-idx))) - (define add-idx (append-node `(+ ,exp-idx ,inv-exp-idx))) - (define sub-idx (append-node `(+ ,exp-idx ,neg-idx))) - (vector-set! mappings n (append-node `(/ ,sub-idx ,add-idx)))] - [(list 'asinh args) - (define mult-idx (append-node `(* ,(vector-ref mappings args) ,(vector-ref mappings args)))) - (define one-idx (append-node 1)) - (define add-idx (append-node `(+ ,mult-idx ,one-idx))) - (define sqrt-idx (append-node `(sqrt ,add-idx))) - (define add2-idx (append-node `(+ ,(vector-ref mappings args) ,sqrt-idx))) - (vector-set! mappings n (append-node `(log ,add2-idx)))] - [(list 'acosh args) - (define mult-idx (append-node `(* ,(vector-ref mappings args) ,(vector-ref mappings args)))) - (define -one-idx (append-node -1)) - (define add-idx (append-node `(+ ,mult-idx ,-one-idx))) - (define sqrt-idx (append-node `(sqrt ,add-idx))) - (define add2-idx (append-node `(+ ,(vector-ref mappings args) ,sqrt-idx))) - (vector-set! mappings n (append-node `(log ,add2-idx)))] - [(list 'atanh args) - (define neg-idx (append-node `(neg ,(vector-ref mappings args)))) - (define one-idx (append-node 1)) - (define add-idx (append-node `(+ ,one-idx ,(vector-ref mappings args)))) - (define sub-idx (append-node `(+ ,one-idx ,neg-idx))) - (define div-idx (append-node `(/ ,add-idx ,sub-idx))) - (define log-idx (append-node `(log ,div-idx))) - (define half-idx (append-node 1/2)) - (vector-set! mappings n (append-node `(* ,half-idx ,log-idx)))] - [(list op args ...) - (vector-set! mappings n (append-node (cons op (map (curry vector-ref mappings) args))))] - [(approx spec impl) - (vector-set! mappings n (append-node (approx spec (vector-ref mappings impl))))] - [_ (vector-set! mappings n (append-node node))])) - - (define nodes* (list->vector (reverse icache))) - (define roots* - (vector-copy (batch-roots input-batch))) ; roots indexes to alternatives stay the same - (define nodes-length* (vector-length nodes*)) - - ; Remap references to expressions inside alts vector - (define alts* - (for/vector ([alt (in-vector (batch-alts input-batch))]) - (match-define (vector alt-expr event prev preprocessing) alt) - (vector (vector-ref mappings alt-expr) event prev preprocessing))) - - ; This may be too expensive to handle simple 1/2, 1/3 and 2/3 zombie nodes.. - #;(remove-zombie-nodes (batch nodes* roots* vars (vector-length nodes*))) - - (batch roots* alts* nodes* nodes-length* vars exprhash (make-hash))) - -; Updates in-batch by adding new expressions -; Returns list of new roots -#;(define (batch-add-expr! in-batch expr #:ignore-approx [ignore-approx #f]) - (define exprhash (batch-exprhash in-batch)) - (define icache '()) - (define exprc (hash-count exprhash)) - - (define (munge-ignore-approx prog) - (match prog ; approx nodes are ignored - [(approx _ impl) (munge-ignore-approx impl)] - [_ - (define node ; This compiles to the register machine - (match prog - [(list op args ...) (cons op (map munge-ignore-approx args))] - [_ prog])) - (hash-ref! exprhash - node - (lambda () - (begin0 exprc ; store in cache, update exprs, exprc - (set! exprc (+ 1 exprc)) - (set! icache (cons node icache)))))])) - - ; Translates programs into an instruction sequence of operations - (define (munge-include-approx prog) - (define node ; This compiles to the register machine - (match prog - [(approx spec impl) (approx spec (munge-include-approx impl))] - [(list op args ...) (cons op (map munge-include-approx args))] - [_ prog])) - (hash-ref! exprhash - node - (lambda () - (begin0 exprc ; store in cache, update exprs, exprc - (set! exprc (+ 1 exprc)) - (set! icache (cons node icache)))))) - - (define root (if ignore-approx (munge-ignore-approx expr) (munge-include-approx expr))) - (set-batch-alt-exprs! in-batch (vector-append (batch-alt-exprs in-batch) (vector root))) - (set-batch-nodes! in-batch (vector-append (batch-nodes in-batch) (list->vector (reverse icache)))) - (set-batch-nodes-length! in-batch (vector-length (batch-nodes in-batch))) - (set-batch-exprhash! in-batch exprhash) - root) - -; The function removes any zombie nodes from batch -; TODO: reconstruct exprhash -#;(define (remove-zombie-nodes input-batch) - (define nodes (batch-nodes input-batch)) - (define roots (batch-roots input-batch)) - (define nodes-length (batch-nodes-length input-batch)) - - (define zombie-mask (make-vector nodes-length #t)) - (for ([root (in-vector roots)]) - (vector-set! zombie-mask root #f)) - (for ([node (in-vector nodes (- nodes-length 1) -1 -1)] - [zmb (in-vector zombie-mask (- nodes-length 1) -1 -1)] - #:when (not zmb)) - (match node - [(list op args ...) (map (λ (n) (vector-set! zombie-mask n #f)) args)] - [(approx spec impl) (vector-set! zombie-mask impl #f)] - [_ void])) - - (define mappings (build-vector nodes-length values)) - - (define nodes* '()) - (for ([node (in-vector nodes)] - [zmb (in-vector zombie-mask)] - [n (in-naturals)]) - (if zmb - (for ([i (in-range n nodes-length)]) - (vector-set! mappings i (sub1 (vector-ref mappings i)))) - (set! nodes* - (cons (match node - [(list op args ...) (cons op (map (curry vector-ref mappings) args))] - [(approx spec impl) (approx spec (vector-ref mappings impl))] - [_ node]) - nodes*)))) - (set! nodes* (list->vector (reverse nodes*))) - (define roots* (vector-map (curry vector-ref mappings) roots)) - (batch nodes* roots* (batch-vars input-batch) (vector-length nodes*))) - -(define (empty-batch) - (batch (make-vector 0) (make-vector 0) (make-vector 0) 0 '() (make-hash) (make-hash))) - -#;(define (in-batch batch) - (for/stream ([alt-expr (in-vector (batch-alt-exprs batch))] - [event (in-vector (batch-events batch))] - [prev (in-vector (batch-prevs batch))] - [preprocessing (in-vector (batch-preprocessings batch))] - [n (in-naturals)]) - (values n alt-expr event prev preprocessing))) - -(define (batch-ref batch idx) - (define alt (vector-ref (batch-alts batch) idx)) - alt) - -(define (batch-expr-roots batch) - (for/vector ([root (in-vector (batch-roots batch))]) - (define alt (vector-ref (batch-alts batch) root)) - (vector-ref alt 0))) - -; Function returns a recovered expression from nodes with index idx -#;(define (batch-get-expr batch idx) - (define (unmunge idx) - (define node (vector-ref (batch-nodes batch) idx)) - (match node - [(approx spec impl) (approx spec (unmunge impl))] - [(list op regs ...) (cons op (map unmunge regs))] - [_ node])) - (unmunge idx)) - -; Tests for expand-taylor -(module+ test - (require rackunit) - (define (test-expand-taylor expr) - (define batch (exprs->batch (list expr) #:ignore-approx #f)) - (define batch* (expand-taylor batch)) - (car (batch->exprs batch*))) - - (check-equal? '(* 1/2 (log (/ (+ 1 x) (+ 1 (neg x))))) (test-expand-taylor '(atanh x))) - (check-equal? '(log (+ x (sqrt (+ (* x x) -1)))) (test-expand-taylor '(acosh x))) - (check-equal? '(log (+ x (sqrt (+ (* x x) 1)))) (test-expand-taylor '(asinh x))) - (check-equal? '(/ (+ (exp x) (neg (/ 1 (exp x)))) (+ (exp x) (/ 1 (exp x)))) - (test-expand-taylor '(tanh x))) - (check-equal? '(* 1/2 (+ (exp x) (neg (/ 1 (exp x))))) (test-expand-taylor '(sinh x))) - (check-equal? '(+ 1 (neg (+ 2 (neg 3)))) (test-expand-taylor '(- 1 (- 2 3)))) - (check-equal? '(* 1/2 (+ (exp x) (/ 1 (exp x)))) (test-expand-taylor '(cosh x))) - (check-equal? '(/ (sin x) (cos x)) (test-expand-taylor '(tan x))) - (check-equal? '(+ 1 (neg (* 1/2 (+ (exp (/ (sin 3) (cos 3))) (/ 1 (exp (/ (sin 3) (cos 3)))))))) - (test-expand-taylor '(- 1 (cosh (tan 3))))) - (check-equal? '(exp (* a (log x))) (test-expand-taylor '(pow x a))) - (check-equal? '(+ x (sin a)) (test-expand-taylor '(+ x (sin a)))) - (check-equal? '(cbrt x) (test-expand-taylor '(pow x 1/3))) - (check-equal? '(cbrt (* x x)) (test-expand-taylor '(pow x 2/3))) - (check-equal? '(+ 100 (cbrt x)) (test-expand-taylor '(+ 100 (pow x 1/3)))) - (check-equal? `(+ 100 (cbrt (* x ,(approx 2 3)))) - (test-expand-taylor `(+ 100 (pow (* x ,(approx 2 3)) 1/3)))) - (check-equal? `(+ ,(approx 2 3) (cbrt x)) (test-expand-taylor `(+ ,(approx 2 3) (pow x 1/3)))) - (check-equal? `(+ (cbrt x) ,(approx 2 1/3)) (test-expand-taylor `(+ (pow x 1/3) ,(approx 2 1/3))))) + [(approx spec impl) (approx spec (unmunge impl))] + [(list op regs ...) (cons op (map unmunge regs))] + [_ node])) + (unmunge reg)) ; Tests for progs->batch and batch->progs (module+ test (require rackunit) - (define (test-munge-unmunge expr [ignore-approx #t]) - (define batch (exprs->batch expr #:ignore-approx ignore-approx)) - (check-equal? expr (batch->exprs batch))) + (define (test-munge-unmunge expr) + (define batch (progs->batch (list expr))) + (check-equal? (list expr) (batch->progs batch))) - (test-munge-unmunge (list '(* 1/2 (+ (exp x) (neg (/ 1 (exp x))))))) + (test-munge-unmunge '(* 1/2 (+ (exp x) (neg (/ 1 (exp x)))))) (test-munge-unmunge - (list '(+ 1 (neg (* 1/2 (+ (exp (/ (sin 3) (cos 3))) (/ 1 (exp (/ (sin 3) (cos 3)))))))))) - (test-munge-unmunge (list '(cbrt x))) - (test-munge-unmunge (list '(x))) + '(+ 1 (neg (* 1/2 (+ (exp (/ (sin 3) (cos 3))) (/ 1 (exp (/ (sin 3) (cos 3))))))))) + (test-munge-unmunge '(cbrt x)) + (test-munge-unmunge '(x)) (test-munge-unmunge - (list `(+ (sin ,(approx '(* 1/2 (+ (exp x) (neg (/ 1 (exp x))))) '(+ 3 (* 25 (sin 6))))) 4)) - #f) - (test-munge-unmunge (list '())) - (test-munge-unmunge (list '(x) '(cbrt x)))) + `(+ (sin ,(approx '(* 1/2 (+ (exp x) (neg (/ 1 (exp x))))) '(+ 3 (* 25 (sin 6))))) 4))) ; Tests for remove-zombie-nodes -#; (module+ test (require rackunit) (define (zombie-test #:nodes nodes #:roots roots) - (define in-batch (batch nodes roots '() (vector-length nodes) (make-hash))) + (define in-batch (batch nodes roots '())) (define out-batch (remove-zombie-nodes in-batch)) (batch-nodes out-batch)) @@ -535,16 +184,3 @@ (check-equal? (vector 2 1/2 '(sqrt 0) (approx '(* x x) 0) '(pow 1 3)) (zombie-test #:nodes (vector 2 1/2 '(sqrt 0) '(cbrt 0) (approx '(* x x) 0) '(pow 1 4)) #:roots (vector 5 2)))) - -#;(module+ test - (require rackunit) - (define (batch-add!-test exprs) - (define batch (empty-batch)) - (for ([expr (in-list exprs)]) - (batch-add-expr! batch expr)) - (check-equal? exprs (batch->progs batch))) - - (batch-add!-test '((* 3 (pow 5 (tan x))) (* (pow 3) (tan x)))) - (batch-add!-test '((* 3 (pow 5 (tan x))) (* 3 (pow 5 (tan (* (pow 3) (tan x))))) - (* (pow 3) (tan x)))) - (batch-add!-test '((* 2 3) (pow 2 3) (pow 2 (exp 3)) (+ (* 2 3) 1)))) diff --git a/src/core/bsearch.rkt b/src/core/bsearch.rkt index 95cdd67e9..792595874 100644 --- a/src/core/bsearch.rkt +++ b/src/core/bsearch.rkt @@ -44,10 +44,9 @@ (for/fold ([expr (alt-expr (list-ref alts (sp-cidx (last splitpoints))))]) ([splitpoint (cdr (reverse splitpoints))]) (define repr (repr-of (sp-bexpr splitpoint) ctx)) - (define <=-operator (get-parametric-operator '<= repr repr)) - `(if (,<=-operator ,(sp-bexpr splitpoint) - ,(literal (repr->real (sp-point splitpoint) repr) - (representation-name repr))) + (define <=-impl (get-fpcore-impl '<= '() (list repr repr))) + `(if (,<=-impl ,(sp-bexpr splitpoint) + ,(literal (repr->real (sp-point splitpoint) repr) (representation-name repr))) ,(alt-expr (list-ref alts (sp-cidx splitpoint))) ,expr))) @@ -130,8 +129,7 @@ ; Not totally clear if this should actually use the precondition (define start-real-compiler - (and start-prog - (make-real-compiler (list (expand-accelerators (prog->spec start-prog))) (list ctx*)))) + (and start-prog (make-real-compiler (list (prog->spec start-prog)) (list ctx*)))) (define (prepend-macro v) (prepend-argument start-real-compiler v (*pcontext*))) diff --git a/src/core/compiler.rkt b/src/core/compiler.rkt index 7d624455f..787e8c7ad 100644 --- a/src/core/compiler.rkt +++ b/src/core/compiler.rkt @@ -50,16 +50,21 @@ (define (if-proc c a b) (if c a b)) +(define (batch-remove-approx batch) + (batch-replace batch + (lambda (node) + (match node + [(approx spec impl) impl] + [node node])))) + ;; Translates a Herbie IR into an interpretable IR. ;; Requires some hooks to complete the translation. (define (make-compiler exprs vars) (define num-vars (length vars)) - - ; only here we use weird arguments for exprs->batch. Can it be a separate function? - (define batch (exprs->batch exprs #:timeline-push #t #:vars vars #:ignore-approx #t)) + (define batch (batch-remove-approx (progs->batch exprs #:timeline-push #t #:vars vars))) (define instructions - (for/vector #:length (- (batch-nodes-length batch) num-vars) + (for/vector #:length (- (batch-length batch) num-vars) ([node (in-vector (batch-nodes batch) num-vars)]) (match node [(literal value (app get-representation repr)) (list (const (real->repr value repr)))] diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 2567e93c7..8ae3bf839 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -11,6 +11,7 @@ (require "programs.rkt" "rules.rkt" + "../syntax/matcher.rkt" "../syntax/platform.rkt" "../syntax/syntax.rkt" "../syntax/types.rkt" @@ -639,10 +640,13 @@ ;; Nodes are duplicated across their possible types. (define (split-untyped-eclasses egraph-data egg->herbie) (define eclass-ids (egraph-eclasses egraph-data)) - (define egg-id->idx (make-hash)) + (define max-id + (for/fold ([current-max 0]) ([egg-id (in-u32vector eclass-ids)]) + (max current-max egg-id))) + (define egg-id->idx (make-u32vector (+ max-id 1))) (for ([egg-id (in-u32vector eclass-ids)] [idx (in-naturals)]) - (hash-set! egg-id->idx egg-id idx)) + (u32vector-set! egg-id->idx egg-id idx)) (define types (all-reprs/types)) (define type->idx (make-hasheq)) @@ -657,7 +661,7 @@ ; maps (untyped eclass id, type) to typed eclass id (define (lookup-id eid type) - (idx+type->id (hash-ref egg-id->idx eid) type)) + (idx+type->id (u32vector-ref egg-id->idx eid) type)) ; allocate enough eclasses for every (egg-id, type) combination (define n (* (u32vector-length eclass-ids) num-types)) @@ -694,7 +698,7 @@ ; dedup `id->parents` values (for ([id (in-range n)]) (vector-set! id->parents id (list->vector (remove-duplicates (vector-ref id->parents id))))) - (values id->eclass id->parents id->leaf? egg-id->idx type->idx)) + (values id->eclass id->parents id->leaf? eclass-ids egg-id->idx type->idx)) ;; TODO: reachable from roots? ;; Prunes e-nodes that are not well-typed. @@ -751,7 +755,7 @@ [_ (void)])))) ;; Rebuilds eclasses and associated data after pruning. -(define (rebuild-eclasses id->eclass egg-id->idx type->idx) +(define (rebuild-eclasses id->eclass eclass-ids egg-id->idx type->idx) (define n (vector-length id->eclass)) (define remap (make-vector n #f)) @@ -791,7 +795,8 @@ ; build the canonical id map (define egg-id->id (make-hash)) - (for ([(eid idx) (in-hash egg-id->idx)]) + (for ([eid (in-u32vector eclass-ids)]) + (define idx (u32vector-ref egg-id->idx eid)) (define id0 (* idx num-types)) (for ([id (in-range id0 (+ id0 num-types))]) (define id* (vector-ref remap id)) @@ -805,7 +810,7 @@ ;; keeping only the subset of enodes that are well-typed. (define (make-typed-eclasses egraph-data egg->herbie) ;; Step 1: split Rust-eclasses by type - (define-values (id->eclass id->parents id->leaf? egg-id->idx type->idx) + (define-values (id->eclass id->parents id->leaf? eclass-ids egg-id->idx type->idx) (split-untyped-eclasses egraph-data egg->herbie)) ;; Step 2: keep well-typed e-nodes @@ -815,7 +820,7 @@ ;; Step 3: remap e-classes ;; Any empty e-classes must be removed, so we re-map every id - (rebuild-eclasses id->eclass egg-id->idx type->idx)) + (rebuild-eclasses id->eclass eclass-ids egg-id->idx type->idx)) ;; Analyzes eclasses for their properties. ;; The result are vector-maps from e-class ids to data. @@ -1039,6 +1044,19 @@ (define (fraction-with-odd-denominator? frac) (and (rational? frac) (let ([denom (denominator frac)]) (and (> denom 1) (odd? denom))))) +;; Decompose an e-node representing an impl of `(pow b e)`. +;; Returns either `#f` or the `(cons b e)` +(define (pow-impl-args impl args) + (define vars (impl-info impl 'vars)) + (match (impl-info impl 'spec) + [(list 'pow b e) + #:when (set-member? vars e) + (define env (map cons vars args)) + (define b* (dict-ref env b b)) + (define e* (dict-ref env e e)) + (cons b* e*)] + [_ #f])) + ;; Old cost model version (define (default-egg-cost-proc regraph cache node type rec) (match node @@ -1048,12 +1066,12 @@ [(list '$approx _ impl) (rec impl)] [(list 'if cond ift iff) (+ 1 (rec cond) (rec ift) (rec iff))] [(list (? impl-exists? impl) args ...) - (cond - [(equal? (impl->operator impl) 'pow) - (match-define (list b e) args) - (define n (vector-ref (regraph-constants regraph) e)) - (if (fraction-with-odd-denominator? n) +inf.0 (+ 1 (rec b) (rec e)))] - [else (apply + 1 (map rec args))])] + (match (pow-impl-args impl args) + [(cons _ e) + #:when (let ([n (vector-ref (regraph-constants regraph) e)]) + (fraction-with-odd-denominator? n)) + +inf.0] + [_ (apply + 1 (map rec args))])] [(list 'pow b e) (define n (vector-ref (regraph-constants regraph) e)) (if (fraction-with-odd-denominator? n) +inf.0 (+ 1 (rec b) (rec e)))] diff --git a/src/core/localize.rkt b/src/core/localize.rkt index 6da4cba8a..21ecdb540 100644 --- a/src/core/localize.rkt +++ b/src/core/localize.rkt @@ -120,7 +120,7 @@ (for/list ([subexpr (in-list exprs-list)]) (struct-copy context ctx [repr (repr-of subexpr ctx)]))) - (define expr-batch (progs->batch exprs-list #:ignore-approx #f)) + (define expr-batch (progs->batch exprs-list)) (define nodes (batch-nodes expr-batch)) (define roots (batch-roots expr-batch)) (define expr-roots diff --git a/src/core/mainloop.rkt b/src/core/mainloop.rkt index c8850f2ec..dda971923 100644 --- a/src/core/mainloop.rkt +++ b/src/core/mainloop.rkt @@ -10,6 +10,7 @@ "regimes.rkt" "simplify.rkt" "../utils/alternative.rkt" + "../utils/errors.rkt" "../utils/common.rkt" "explain.rkt" "patch.rkt" @@ -353,7 +354,9 @@ [(and (flag-set? 'reduce 'regimes) (> (length alts) 1) (equal? (representation-type repr) 'real) - (not (null? (context-vars ctx)))) + (not (null? (context-vars ctx))) + (with-handlers ([exn:fail:user:herbie:missing? (const #f)]) + (get-fpcore-impl '<= '() (list repr repr)))) (define opts (pareto-regimes (sort alts < #:key (curryr alt-cost repr)) ctx)) (for/list ([opt (in-list opts)]) (combine-alts opt ctx))] diff --git a/src/core/preprocess.rkt b/src/core/preprocess.rkt index 690565c13..7cffac5c9 100644 --- a/src/core/preprocess.rkt +++ b/src/core/preprocess.rkt @@ -21,13 +21,13 @@ (define (has-fabs-neg-impls? repr) (with-handlers ([exn:fail:user:herbie? (const #f)]) - (get-parametric-operator 'neg repr) - (get-parametric-operator 'fabs repr) + (get-fpcore-impl '- (repr->prop repr) (list repr)) + (get-fpcore-impl 'fabs (repr->prop repr) (list repr)) #t)) (define (has-copysign-impl? repr) (with-handlers ([exn:fail:user:herbie? (const #f)]) - (get-parametric-operator 'copysign repr repr) + (get-fpcore-impl 'copysign (repr->prop repr) (list repr repr)) #t)) ;; The even identities: f(x) = f(-x) @@ -168,18 +168,21 @@ (values (list-set* x indices sorted) y)))] [(list 'abs variable) (define index (index-of variables variable)) - (define abs - (impl-info (get-parametric-operator 'fabs (list-ref (context-var-reprs context) index)) 'fl)) - (lambda (x y) (values (list-update x index abs) y))] + (define var-repr (context-lookup context variable)) + (define abs-proc (impl-info (get-fpcore-impl 'fabs (repr->prop var-repr) (list var-repr)) 'fl)) + (lambda (x y) (values (list-update x index abs-proc) y))] [(list 'negabs variable) (define index (index-of variables variable)) - (define negate-variable - (impl-info (get-parametric-operator 'neg (list-ref (context-var-reprs context) index)) 'fl)) - (define negate-expression (impl-info (get-parametric-operator 'neg (context-repr context)) 'fl)) + (define var-repr (context-lookup context variable)) + (define neg-var (impl-info (get-fpcore-impl '- (repr->prop var-repr) (list var-repr)) 'fl)) + + (define repr (context-repr context)) + (define neg-expr (impl-info (get-fpcore-impl '- (repr->prop repr) (list repr)) 'fl)) + (lambda (x y) ;; Negation is involutive, i.e. it is its own inverse, so t^1(y') = -y' (if (negative? (repr->real (list-ref x index) (context-repr context))) - (values (list-update x index negate-variable) (negate-expression y)) + (values (list-update x index neg-var) (neg-expr y)) (values x y)))])) ; until fixed point, iterate through preprocessing attempting to drop preprocessing with no effect on error diff --git a/src/core/rules.rkt b/src/core/rules.rkt index 6ac5b4f7d..7fbcf1e81 100644 --- a/src/core/rules.rkt +++ b/src/core/rules.rkt @@ -93,19 +93,8 @@ (for ([rule (in-list rules)]) (sow rule)))))) -;; Spec contains no accelerators -(define (spec-has-accelerator? spec) - (match spec - [(list (? operator-accelerator?) _ ...) #t] - [(list _ args ...) (ormap spec-has-accelerator? args)] - [_ #f])) - (define (real-rules rules) - (filter-not (lambda (rule) - (or (representation? (rule-otype rule)) - (spec-has-accelerator? (rule-input rule)) - (spec-has-accelerator? (rule-output rule)))) - rules)) + (filter-not (lambda (rule) (representation? (rule-otype rule))) rules)) ;; ;; Rule loading @@ -732,32 +721,34 @@ [asinh-2 (acosh (+ (* 2 (* x x)) 1)) (* 2 (asinh x))] [acosh-2 (acosh (- (* 2 (* x x)) 1)) (* 2 (acosh x))]) -; Specialized numerical functions -(define-ruleset* special-numerical-reduce - (numerics simplify) - #:type ([x real] [y real] [z real]) - [log1p-expm1 (log1p (expm1 x)) x] - [hypot-1-def (sqrt (+ 1 (* y y))) (hypot 1 y)] - [fmm-def (- (* x y) z) (fma x y (neg z))] - [fmm-undef (fma x y (neg z)) (- (* x y) z)]) - -(define-ruleset* special-numerical-expand - (numerics) - #:type ([x real] [y real]) - [log1p-expm1-u x (log1p (expm1 x))] - [expm1-log1p-u x (expm1 (log1p x))]) - (define-ruleset* erf-rules (special simplify) #:type ([x real]) [erf-odd (erf (neg x)) (neg (erf x))]) -(define-ruleset* numerics-papers - (numerics) - #:type ([a real] [b real] [c real] [d real]) - ; "Further Analysis of Kahan's Algorithm for - ; the Accurate Computation of 2x2 Determinants" - ; Jeannerod et al., Mathematics of Computation, 2013 - ; - ; a * b - c * d ===> fma(a, b, -(d * c)) + fma(-d, c, d * c) - [prod-diff (- (* a b) (* c d)) (+ (fma a b (neg (* d c))) (fma (neg d) c (* d c)))]) +; Specialized numerical functions +; TODO: These are technically rules over impls +; +; (define-ruleset* special-numerical-reduce +; (numerics simplify) +; #:type ([x real] [y real] [z real]) +; [log1p-expm1 (log1p (expm1 x)) x] +; [hypot-1-def (sqrt (+ 1 (* y y))) (hypot 1 y)] +; [fmm-def (- (* x y) z) (fma x y (neg z))] +; [fmm-undef (fma x y (neg z)) (- (* x y) z)]) + +; (define-ruleset* special-numerical-expand +; (numerics) +; #:type ([x real] [y real]) +; [log1p-expm1-u x (log1p (expm1 x))] +; [expm1-log1p-u x (expm1 (log1p x))]) + +; (define-ruleset* numerics-papers +; (numerics) +; #:type ([a real] [b real] [c real] [d real]) +; ; "Further Analysis of Kahan's Algorithm for +; ; the Accurate Computation of 2x2 Determinants" +; ; Jeannerod et al., Mathematics of Computation, 2013 +; ; +; ; a * b - c * d ===> fma(a, b, -(d * c)) + fma(-d, c, d * c) +; [prod-diff (- (* a b) (* c d)) (+ (fma a b (neg (* d c))) (fma (neg d) c (* d c)))]) ;; Sound because it's about soundness over real numbers (define-ruleset* compare-reduce diff --git a/src/core/taylor.rkt b/src/core/taylor.rkt index 377866490..dc1a5fe9d 100644 --- a/src/core/taylor.rkt +++ b/src/core/taylor.rkt @@ -34,6 +34,59 @@ (simplify (make-horner ((cdr tform) var) (reverse terms)))])) next)) +;; Our Taylor expander prefers sin, cos, exp, log, neg over trig, htrig, pow, and subtraction +(define (expand-taylor input-batch) + (batch-replace + input-batch + (lambda (node) + (match node + [(list '- ref1 ref2) `(+ ,ref1 (neg ,ref2))] + [(list 'pow base (app deref 1/2)) `(sqrt ,base)] + [(list 'pow base (app deref 1/3)) `(cbrt ,base)] + [(list 'pow base (app deref 2/3)) `(cbrt (* ,base ,base))] + [(list 'pow base power) + #:when (exact-integer? (deref power)) + `(pow ,base ,power)] + [(list 'pow base power) `(exp (* ,power (log ,base)))] + [(list 'tan arg) `(/ (sin ,arg) (cos ,arg))] + [(list 'cosh arg) `(* 1/2 (+ (exp ,arg) (/ 1 (exp ,arg))))] + [(list 'sinh arg) `(* 1/2 (+ (exp ,arg) (/ -1 (exp ,arg))))] + [(list 'tanh arg) `(/ (+ (exp ,arg) (neg (/ 1 (exp ,arg)))) (+ (exp ,arg) (/ 1 (exp ,arg))))] + [(list 'asinh arg) `(log (+ ,arg (sqrt (+ (* ,arg ,arg) 1))))] + [(list 'acosh arg) `(log (+ ,arg (sqrt (+ (* ,arg ,arg) -1))))] + [(list 'atanh arg) `(* 1/2 (log (/ (+ 1 ,arg) (+ 1 (neg ,arg)))))] + [_ node])))) + +; Tests for expand-taylor +(module+ test + (require rackunit) + + (define (test-expand-taylor expr) + (define batch (progs->batch (list expr))) + (define batch* (expand-taylor batch)) + (car (batch->progs batch*))) + + (check-equal? '(* 1/2 (log (/ (+ 1 x) (+ 1 (neg x))))) (test-expand-taylor '(atanh x))) + (check-equal? '(log (+ x (sqrt (+ (* x x) -1)))) (test-expand-taylor '(acosh x))) + (check-equal? '(log (+ x (sqrt (+ (* x x) 1)))) (test-expand-taylor '(asinh x))) + (check-equal? '(/ (+ (exp x) (neg (/ 1 (exp x)))) (+ (exp x) (/ 1 (exp x)))) + (test-expand-taylor '(tanh x))) + (check-equal? '(* 1/2 (+ (exp x) (/ -1 (exp x)))) (test-expand-taylor '(sinh x))) + (check-equal? '(+ 1 (neg (+ 2 (neg 3)))) (test-expand-taylor '(- 1 (- 2 3)))) + (check-equal? '(* 1/2 (+ (exp x) (/ 1 (exp x)))) (test-expand-taylor '(cosh x))) + (check-equal? '(/ (sin x) (cos x)) (test-expand-taylor '(tan x))) + (check-equal? '(+ 1 (neg (* 1/2 (+ (exp (/ (sin 3) (cos 3))) (/ 1 (exp (/ (sin 3) (cos 3)))))))) + (test-expand-taylor '(- 1 (cosh (tan 3))))) + (check-equal? '(exp (* a (log x))) (test-expand-taylor '(pow x a))) + (check-equal? '(+ x (sin a)) (test-expand-taylor '(+ x (sin a)))) + (check-equal? '(cbrt x) (test-expand-taylor '(pow x 1/3))) + (check-equal? '(cbrt (* x x)) (test-expand-taylor '(pow x 2/3))) + (check-equal? '(+ 100 (cbrt x)) (test-expand-taylor '(+ 100 (pow x 1/3)))) + (check-equal? `(+ 100 (cbrt (* x ,(approx 2 3)))) + (test-expand-taylor `(+ 100 (pow (* x ,(approx 2 3)) 1/3)))) + (check-equal? `(+ ,(approx 2 3) (cbrt x)) (test-expand-taylor `(+ ,(approx 2 3) (pow x 1/3)))) + (check-equal? `(+ (cbrt x) ,(approx 2 1/3)) (test-expand-taylor `(+ (pow x 1/3) ,(approx 2 1/3))))) + (define (make-horner var terms [start 0]) (match terms ['() 0] @@ -77,7 +130,7 @@ (define (taylor var expr-batch) "Return a pair (e, n), such that expr ~= e var^n" (define nodes (batch-nodes expr-batch)) - (define taylor-approxs (make-vector (batch-nodes-length expr-batch))) ; vector of approximations + (define taylor-approxs (make-vector (batch-length expr-batch))) ; vector of approximations (for ([node (in-vector nodes)] [n (in-naturals)]) diff --git a/src/platforms/binary32.rkt b/src/platforms/binary32.rkt index a4f215108..e7f5142db 100644 --- a/src/platforms/binary32.rkt +++ b/src/platforms/binary32.rkt @@ -51,27 +51,39 @@ #:spec (neg x) #:fpcore (! :precision binary32 (- x)) #:fl fl32-) + (define-operator-impl (+.f32 [x : binary32] [y : binary32]) binary32 #:spec (+ x y) #:fpcore (! :precision binary32 (+ x y)) #:fl fl32+) + (define-operator-impl (-.f32 [x : binary32] [y : binary32]) binary32 #:spec (- x y) #:fpcore (! :precision binary32 (- x y)) #:fl fl32-) + (define-operator-impl (*.f32 [x : binary32] [y : binary32]) binary32 #:spec (* x y) #:fpcore (! :precision binary32 (* x y)) #:fl fl32*) + (define-operator-impl (/.f32 [x : binary32] [y : binary32]) binary32 #:spec (/ x y) #:fpcore (! :precision binary32 (/ x y)) #:fl fl32/) +(define-comparator-impls binary32 + [== ==.f32 =] + [!= !=.f32 (negate =)] + [< <.f32 <] + [> >.f32 >] + [<= <=.f32 <=] + [>= >=.f32 >=]) + (define-libm-impls/binary32 [(binary32 binary32) (acos acosh asin @@ -83,16 +95,13 @@ cos cosh erf - erfc exp exp2 - expm1 fabs floor lgamma log log10 - log1p log2 logb rint @@ -105,27 +114,57 @@ tgamma trunc)] [(binary32 binary32 binary32) - (atan2 copysign fdim fmax fmin fmod hypot pow remainder)] - [(binary32 binary32 binary32 binary32) (fma)]) - -(define-comparator-impls binary32 - [== ==.f32 =] - [!= !=.f32 (negate =)] - [< <.f32 <] - [> >.f32 >] - [<= <=.f32 <=] - [>= >=.f32 >=]) + (atan2 copysign fdim fmax fmin fmod pow remainder)]) + +(define-libm c_erfcf (erfc float float)) +(define-libm c_expm1f (expm1 float float)) +(define-libm c_log1pf (log1p float float)) +(define-libm c_hypotf (hypot float float float)) +(define-libm c_fmaf (fma float float float float)) + +(when c_erfcf + (define-operator-impl (erfc.f32 [x : binary32]) + binary32 + #:spec (- 1 (erf x)) + #:fpcore (! :precision binary32 (erfc x)) + #:fl c_erfcf)) + +(when c_expm1f + (define-operator-impl (expm1.f32 [x : binary32]) + binary32 + #:spec (- (exp x) 1) + #:fpcore (! :precision binary32 (expm1 x)) + #:fl c_expm1f)) + +(when c_log1pf + (define-operator-impl (log1p.f32 [x : binary32]) + binary32 + #:spec (log (+ 1 x)) + #:fpcore (! :precision binary32 (log1p x)) + #:fl c_log1pf)) + +(when c_hypotf + (define-operator-impl (hypot.f32 [x : binary32] [y : binary32]) + binary32 + #:spec (sqrt (+ (* x x) (* y y))) + #:fpcore (! :precision binary32 (hypot x y)) + #:fl c_hypotf)) + +(when c_fmaf + (define-operator-impl (fma.f32 [x : binary32] [y : binary32] [z : binary32]) + binary32 + #:spec (+ (* x y) z) + #:fpcore (! :precision binary32 (fma x y z)) + #:fl c_fmaf)) (define-operator-impl (binary64->binary32 [x : binary64]) binary32 #:spec x #:fpcore (! :precision binary32 (cast x)) - #:fl (curryr ->float32) - #:op cast) + #:fl (curryr ->float32)) (define-operator-impl (binary32->binary64 [x : binary32]) binary64 #:spec x #:fpcore (! :precision binary64 (cast x)) - #:fl identity - #:op cast) + #:fl identity) diff --git a/src/platforms/binary64.rkt b/src/platforms/binary64.rkt index 620a0aa1c..e42705515 100644 --- a/src/platforms/binary64.rkt +++ b/src/platforms/binary64.rkt @@ -51,21 +51,25 @@ #:spec (neg x) #:fpcore (! :precision binary64 (- x)) #:fl -) + (define-operator-impl (+.f64 [x : binary64] [y : binary64]) binary64 #:spec (+ x y) #:fpcore (! :precision binary64 (+ x y)) #:fl +) + (define-operator-impl (-.f64 [x : binary64] [y : binary64]) binary64 #:spec (- x y) #:fpcore (! :precision binary64 (- x y)) #:fl -) + (define-operator-impl (*.f64 [x : binary64] [y : binary64]) binary64 #:spec (* x y) #:fpcore (! :precision binary64 (* x y)) #:fl *) + (define-operator-impl (/.f64 [x : binary64] [y : binary64]) binary64 #:spec (/ x y) @@ -83,16 +87,13 @@ cos cosh erf - erfc exp exp2 - expm1 fabs floor lgamma log log10 - log1p log2 logb rint @@ -105,8 +106,48 @@ tgamma trunc)] [(binary64 binary64 binary64) - (atan2 copysign fdim fmax fmin fmod hypot pow remainder)] - [(binary64 binary64 binary64 binary64) (fma)]) + (atan2 copysign fdim fmax fmin fmod pow remainder)]) + +(define-libm c_erfc (erfc double double)) +(define-libm c_expm1 (expm1 double double)) +(define-libm c_log1p (log1p double double)) +(define-libm c_hypot (hypot double double double)) +(define-libm c_fma (fma double double double double)) + +(when c_erfc + (define-operator-impl (erfc.f64 [x : binary64]) + binary64 + #:spec (- 1 (erf x)) + #:fpcore (! :precision binary64 (erfc x)) + #:fl c_erfc)) + +(when c_expm1 + (define-operator-impl (expm1.f64 [x : binary64]) + binary64 + #:spec (- (exp x) 1) + #:fpcore (! :precision binary64 (expm1 x)) + #:fl c_expm1)) + +(when c_log1p + (define-operator-impl (log1p.f64 [x : binary64]) + binary64 + #:spec (log (+ 1 x)) + #:fpcore (! :precision binary64 (log1p x)) + #:fl c_log1p)) + +(when c_hypot + (define-operator-impl (hypot.f64 [x : binary64] [y : binary64]) + binary64 + #:spec (sqrt (+ (* x x) (* y y))) + #:fpcore (! :precision binary64 (hypot x y)) + #:fl c_hypot)) + +(when c_fma + (define-operator-impl (fma.f64 [x : binary64] [y : binary64] [z : binary64]) + binary64 + #:spec (+ (* x y) z) + #:fpcore (! :precision binary64 (fma x y z)) + #:fl c_fma)) (define-comparator-impls binary64 [== ==.f64 =] diff --git a/src/platforms/default.rkt b/src/platforms/default.rkt index 7e69dd125..4d088478d 100644 --- a/src/platforms/default.rkt +++ b/src/platforms/default.rkt @@ -80,7 +80,6 @@ lgamma.f64 log.f64 log10.f64 - log1p.f64 log2.f64 logb.f64 pow.f64 @@ -122,7 +121,6 @@ lgamma.f32 log.f32 log10.f32 - log1p.f32 log2.f32 logb.f32 pow.f32 diff --git a/src/platforms/fallback.rkt b/src/platforms/fallback.rkt index ffe054382..10a768e29 100644 --- a/src/platforms/fallback.rkt +++ b/src/platforms/fallback.rkt @@ -4,7 +4,6 @@ (require math/base math/bigfloat - math/flonum math/special-functions) (require "runtime/utils.rkt") @@ -37,8 +36,8 @@ (define-syntax-rule (define-2ary-fallback-operator op fn) (define-fallback-operator (op [x : binary64] [y : binary64]) - #:spec (op x) - #:fpcore (! :precision binary64 :math-library racket (op x)) + #:spec (op x y) + #:fpcore (! :precision binary64 :math-library racket (op x y)) #:fl fn)) (define-syntax-rule (define-1ary-fallback-operators [op fn] ...) @@ -75,16 +74,13 @@ [cos cos] [cosh cosh] [erf (no-complex erf)] - [erfc erfc] [exp exp] [exp2 (no-complex (λ (x) (expt 2 x)))] - [expm1 (from-bigfloat bfexpm1)] [fabs abs] [floor floor] [lgamma log-gamma] [log (no-complex log)] [log10 (no-complex (λ (x) (log x 10)))] - [log1p (from-bigfloat bflog1p)] [log2 (from-bigfloat bflog2)] [logb (λ (x) (floor (bigfloat->flonum (bflog2 (bf (abs x))))))] [rint round] @@ -117,16 +113,38 @@ [(nan? y) x] [else (min x y)]))] [fmod (from-bigfloat bffmod)] - [hypot (from-bigfloat bfhypot)] [pow (no-complex expt)] [remainder remainder]) +(define-operator-impl (erfc.rkt [x : binary64]) + binary64 + #:spec (- 1 (erf x)) + #:fpcore (! :precision binary64 :math-library racket (erfc x)) + #:fl erfc) + +(define-operator-impl (expm1.rkt [x : binary64]) + binary64 + #:spec (- (exp x) 1) + #:fpcore (! :precision binary64 :math-library racket (expm1 x)) + #:fl (from-bigfloat bfexpm1)) + +(define-operator-impl (log1p.rkt [x : binary64]) + binary64 + #:spec (log (+ 1 x)) + #:fpcore (! :precision binary64 :math-library racket (log1p x)) + #:fl (from-bigfloat bflog1p)) + +(define-operator-impl (hypot.rkt [x : binary64] [y : binary64]) + binary64 + #:spec (sqrt (+ (* x x) (* y y))) + #:fpcore (! :precision binary64 :math-library racket (hypot x y)) + #:fl (from-bigfloat bfhypot)) + (define-operator-impl (fma.rkt [x : binary64] [y : binary64] [z : binary64]) binary64 #:spec (+ (* x y) z) #:fpcore (! :precision binary64 :math-library racket (fma x y z)) - #:fl (from-bigfloat bffma) - #:op fma) + #:fl (from-bigfloat bffma)) (define-comparator-impls binary64 [== ==.rkt =] diff --git a/src/platforms/runtime/libm.rkt b/src/platforms/runtime/libm.rkt index 0659e71f3..50e5d79f3 100644 --- a/src/platforms/runtime/libm.rkt +++ b/src/platforms/runtime/libm.rkt @@ -51,18 +51,30 @@ [integer #'integer] [_ (oops! "unknown type" repr)])) (syntax-case stx () - [(_ cname (id name itype ...) otype attrib ...) - (begin - (unless (identifier? #'cname) - (oops! "expected identifier" #'cname)) - (unless (identifier? #'id) - (oops! "expected identifier" #'id)) - (unless (identifier? #'name) - (oops! "expected identifier" #'name)) - (with-syntax ([(citype ...) (map repr->type (syntax->list #'(itype ...)))] - [cotype (repr->type #'otype)] - [(var ...) (generate-temporaries #'(itype ...))]) + [(_ cname (op name itype ...) otype fields ...) + (let ([op #'op] + [name #'name] + [cname #'cname] + [itypes (syntax->list #'(itype ...))]) + (unless (identifier? op) + (oops! "expected identifier" op)) + (unless (identifier? name) + (oops! "expected identifier" name)) + (unless (identifier? cname) + (oops! "expected identifier" cname)) + (with-syntax ([op op] + [name name] + [cname cname] + [(var ...) (build-list (length itypes) + (lambda (i) (string->symbol (format "x~a" i))))] + [(itype ...) itypes] + [(citype ...) (map repr->type itypes)] + [cotype (repr->type #'otype)]) #'(begin (define-libm proc (cname citype ... cotype)) (when proc - (define-operator-impl (name [var : itype] ...) otype #:fl proc #:op id)))))])) + (define-operator-impl (name [var : itype] ...) + otype + #:spec (op var ...) + #:fl proc + fields ...)))))])) diff --git a/src/reports/common.rkt b/src/reports/common.rkt index 5f78d5c42..149ad372a 100644 --- a/src/reports/common.rkt +++ b/src/reports/common.rkt @@ -56,7 +56,7 @@ (write-xexpr xexpr out)) (define (program->fpcore expr ctx #:ident [ident #f]) - (define body (prog->fpcore expr)) + (define body (prog->fpcore expr ctx)) (if ident (list 'FPCore ident (context-vars ctx) body) (list 'FPCore (context-vars ctx) body))) (define (fpcore-add-props core props) @@ -164,8 +164,8 @@ (define r (list-ref (context-var-reprs c) p)) (define c* (struct-copy context c [vars (list-set (context-vars c) p x*)])) (define c** (context-extend c* x-sign r)) - (define e* - (list (get-parametric-operator '* r (context-repr c)) x-sign (replace-expression e x x*))) + (define *-impl (get-fpcore-impl '* (repr->prop (context-repr c)) (list r (context-repr c)))) + (define e* (list *-impl x-sign (replace-expression e x x*))) (cons e* c**)] [_ (cons e c)])) @@ -185,16 +185,19 @@ [(list 'abs x) (define x* (string->symbol (string-append (symbol->string x) "_m"))) (define r (list-ref (context-var-reprs ctx) (index-of (context-vars ctx) x))) - (define e (list (get-parametric-operator 'fabs r) x)) + (define fabs-impl (get-fpcore-impl 'fabs (repr->prop r) (list r))) + (define e (list fabs-impl x)) (define c (context (list x) r r)) (format "~a = ~a" x* (converter* e c))] [(list 'negabs x) + ; TODO: why are x* and x-sign unused? (define x* (string->symbol (format "~a_m" x))) (define r (context-lookup ctx x)) - (define p (representation-name r)) - (define e* (list (get-parametric-operator 'fabs r) x)) + (define fabs-impl (get-fpcore-impl 'fabs (repr->prop r) (list r))) + (define copysign-impl (get-fpcore-impl 'copysign (repr->prop r) (list r r))) + (define e* (list fabs-impl x)) (define x-sign (string->symbol (format "~a_s" x))) - (define e-sign (list (get-parametric-operator 'copysign r r) (literal 1 p) x)) + (define e-sign (list copysign-impl (literal 1 (representation-name r)) x)) (define c (context (list x) r r)) (list (format "~a = ~a" (format "~a\\_m" x) (converter* e* c)) (format "~a = ~a" (format "~a\\_s" x) (converter* e-sign c)))] @@ -294,8 +297,9 @@ `(div ,(if (equal? precondition '(TRUE)) "" - `(div ([id "precondition"]) - (div ((class "program math")) "\\[" ,(expr->tex (prog->fpcore precondition)) "\\]"))) + `(div + ([id "precondition"]) + (div ((class "program math")) "\\[" ,(expr->tex (prog->fpcore precondition ctx)) "\\]"))) (div ((class "implementation") [data-language "Math"]) (div ((class "program math")) "\\[" ,math-out "\\]")) ,@(for/list ([(lang out) (in-dict versions)]) @@ -317,23 +321,25 @@ (-> test? string?) (define output-repr (test-output-repr test)) (string-join - (filter - identity - (list - (if (test-identifier test) - (format "(FPCore ~a ~a" (test-identifier test) (test-vars test)) - (format "(FPCore ~a" (test-vars test))) - (format " :name ~s" (test-name test)) - (format " :precision ~s" (representation-name (test-output-repr test))) - (if (equal? (test-pre test) '(TRUE)) #f (format " :pre ~a" (prog->fpcore (test-pre test)))) - (if (equal? (test-expected test) #t) #f (format " :herbie-expected ~a" (test-expected test))) - (and (test-output test) - (not (null? (test-output test))) - (format "\n~a" - (string-join (map (lambda (exp) (format " :alt\n ~a\n" (car exp))) - (test-output test)) - "\n"))) - (format " ~a)" (prog->fpcore (test-input test))))) + (filter identity + (list (if (test-identifier test) + (format "(FPCore ~a ~a" (test-identifier test) (test-vars test)) + (format "(FPCore ~a" (test-vars test))) + (format " :name ~s" (test-name test)) + (format " :precision ~s" (representation-name (test-output-repr test))) + (if (equal? (test-pre test) '(TRUE)) + #f + (format " :pre ~a" (prog->fpcore (test-pre test) (test-context test)))) + (if (equal? (test-expected test) #t) + #f + (format " :herbie-expected ~a" (test-expected test))) + (and (test-output test) + (not (null? (test-output test))) + (format "\n~a" + (string-join (map (lambda (exp) (format " :alt\n ~a\n" (car exp))) + (test-output test)) + "\n"))) + (format " ~a)" (prog->fpcore (test-input test) (test-context test))))) "\n")) (define (format-percent num den) diff --git a/src/reports/history.rkt b/src/reports/history.rkt index d7ddf078a..6ef558b5c 100644 --- a/src/reports/history.rkt +++ b/src/reports/history.rkt @@ -86,8 +86,16 @@ [(? number?) expr] [(? literal?) (literal-value expr)] [(approx _ impl) (loop impl)] - [`(if ,cond ,ift ,iff) `(if ,(loop cond) ,(loop ift) ,(loop ift))] - [`(,(? impl-exists? impl) ,args ...) `(,(impl->operator impl) ,@(map loop args))] + [`(if ,cond ,ift ,iff) `(if ,(loop cond) ,(loop ift) ,(loop iff))] + [`(,(? impl-exists? impl) ,args ...) + ; use the FPCore operator without rounding properties + (define args* (map loop args)) + (define vars (impl-info impl 'vars)) + (define pattern + (match (impl-info impl 'fpcore) + [(list '! _ ... body) body] + [body body])) + (replace-vars (map cons vars args*) pattern)] [`(,op ,args ...) `(,op ,@(map loop args))]))) `(FPCore ,(context-vars ctx) ,expr*)) diff --git a/src/reports/make-graph.rkt b/src/reports/make-graph.rkt index 1756472d8..66ccafda7 100644 --- a/src/reports/make-graph.rkt +++ b/src/reports/make-graph.rkt @@ -6,16 +6,12 @@ (require "../utils/common.rkt" "../core/points.rkt" "../utils/float.rkt" - "../core/programs.rkt" "../utils/alternative.rkt" "../syntax/types.rkt" "../syntax/read.rkt" "../core/bsearch.rkt" "../api/sandbox.rkt" - "common.rkt" - "history.rkt" - "../syntax/sugar.rkt" - "timeline.rkt") + "common.rkt") (provide make-graph dummy-graph) @@ -36,20 +32,18 @@ [(list op args ...) (ormap list? args)] [_ #f])) -(define (dummy-graph command out) - (write-html - `(html (head (meta ([charset "utf-8"])) - (title "Result page for the " ,(~a command) " command is not available right now.") - ,@js-tex-include - (script ([src "https://unpkg.com/mathjs@4.4.2/dist/math.min.js"])) - (script ([src "https://unpkg.com/d3@6.7.0/dist/d3.min.js"])) - (script ([src "https://unpkg.com/@observablehq/plot@0.4.3/dist/plot.umd.min.js"])) - (link ([rel "stylesheet"] [type "text/css"] [href "../report.css"])) - (script ([src "../report.js"]))) - (body (h2 "Result page for the " ,(~a command) " command is not available right now."))) - out)) - -(define (make-graph result-hash out output? profile?) +(define (dummy-graph command) + `(html (head (meta ([charset "utf-8"])) + (title "Result page for the " ,(~a command) " command is not available right now.") + ,@js-tex-include + (script ([src "https://unpkg.com/mathjs@4.4.2/dist/math.min.js"])) + (script ([src "https://unpkg.com/d3@6.7.0/dist/d3.min.js"])) + (script ([src "https://unpkg.com/@observablehq/plot@0.4.3/dist/plot.umd.min.js"])) + (link ([rel "stylesheet"] [type "text/css"] [href "../report.css"])) + (script ([src "../report.js"]))) + (body (h2 "Result page for the " ,(~a command) " command is not available right now.")))) + +(define (make-graph result-hash output? profile?) (define backend (hash-ref result-hash 'backend)) (define test (hash-ref result-hash 'test)) (define time (hash-ref result-hash 'time)) @@ -77,13 +71,12 @@ (for/list ([target targets]) (alt-cost (alt-analysis-alt target) repr))) - (define end-alts (hash-ref end 'end-alts)) + (define end-exprs (hash-ref end 'end-exprs)) (define end-errors (hash-ref end 'end-errors)) (define end-costs (hash-ref end 'end-costs)) (define speedup - (let ([better (for/list ([alt end-alts] - [err end-errors] + (let ([better (for/list ([err end-errors] [cost end-costs] #:when (<= (errors-score err) (errors-score start-error))) (/ start-cost cost))]) @@ -91,124 +84,120 @@ (define end-error (car end-errors)) - (write-html - `(html - (head (meta ([charset "utf-8"])) - (title "Result for " ,(~a (test-name test))) - ,@js-tex-include - (script ([src "https://unpkg.com/mathjs@4.4.2/dist/math.min.js"])) - (script ([src "https://unpkg.com/d3@6.7.0/dist/d3.min.js"])) - (script ([src "https://unpkg.com/@observablehq/plot@0.4.3/dist/plot.umd.min.js"])) - (link ([rel "stylesheet"] [type "text/css"] [href "../report.css"])) - (script ([src "../report.js"]))) - (body - ,(render-menu #:path ".." - (~a (test-name test)) - (if output? - (list '("Report" . "../index.html") '("Metrics" . "timeline.html")) - (list '("Metrics" . "timeline.html")))) - (div ([id "large"]) - ,(render-comparison - "Percentage Accurate" - (format-accuracy (errors-score start-error) repr-bits #:unit "%") - (format-accuracy (errors-score end-error) repr-bits #:unit "%") - #:title - (format "Minimum Accuracy: ~a → ~a" - (format-accuracy (apply max (map ulps->bits start-error)) repr-bits #:unit "%") - (format-accuracy (apply max (map ulps->bits end-error)) repr-bits #:unit "%"))) - ,(render-large "Time" (format-time time)) - ,(render-large "Alternatives" (~a (length end-alts))) - ,(if (*pareto-mode*) - (render-large "Speedup" - (if speedup (~r speedup #:precision '(= 1)) "N/A") - "×" - #:title "Relative speed of fastest alternative that improves accuracy.") - "")) - ,(render-warnings warnings) - ,(render-specification test #:bogosity bogosity) - (figure ([id "graphs"]) - (h2 "Local Percentage Accuracy vs " - (span ([id "variables"])) - (a ((class "help-button float") [href ,(doc-url "report.html#graph")] + `(html + (head (meta ([charset "utf-8"])) + (title "Result for " ,(~a (test-name test))) + ,@js-tex-include + (script ([src "https://unpkg.com/mathjs@4.4.2/dist/math.min.js"])) + (script ([src "https://unpkg.com/d3@6.7.0/dist/d3.min.js"])) + (script ([src "https://unpkg.com/@observablehq/plot@0.4.3/dist/plot.umd.min.js"])) + (link ([rel "stylesheet"] [type "text/css"] [href "../report.css"])) + (script ([src "../report.js"]))) + (body + ,(render-menu #:path ".." + (~a (test-name test)) + (if output? + (list '("Report" . "../index.html") '("Metrics" . "timeline.html")) + (list '("Metrics" . "timeline.html")))) + (div ([id "large"]) + ,(render-comparison + "Percentage Accurate" + (format-accuracy (errors-score start-error) repr-bits #:unit "%") + (format-accuracy (errors-score end-error) repr-bits #:unit "%") + #:title + (format "Minimum Accuracy: ~a → ~a" + (format-accuracy (apply max (map ulps->bits start-error)) repr-bits #:unit "%") + (format-accuracy (apply max (map ulps->bits end-error)) repr-bits #:unit "%"))) + ,(render-large "Time" (format-time time)) + ,(render-large "Alternatives" (~a (length end-exprs))) + ,(if (*pareto-mode*) + (render-large "Speedup" + (if speedup (~r speedup #:precision '(= 1)) "N/A") + "×" + #:title "Relative speed of fastest alternative that improves accuracy.") + "")) + ,(render-warnings warnings) + ,(render-specification test #:bogosity bogosity) + (figure ([id "graphs"]) + (h2 "Local Percentage Accuracy vs " + (span ([id "variables"])) + (a ((class "help-button float") [href ,(doc-url "report.html#graph")] + [target "_blank"]) + "?")) + (svg) + (div ([id "functions"])) + (figcaption "The average percentage accuracy by input value. Horizontal axis shows " + "value of an input variable; the variable is choosen in the title. " + "Vertical axis is accuracy; higher is better. Red represent the original " + "program, while blue represents Herbie's suggestion. " + "These can be toggled with buttons below the plot. " + "The line is an average while dots represent individual samples.")) + (section ([id "cost-accuracy"] (class "section") [data-benchmark-name ,(~a (test-name test))]) + ; TODO : Show all Developer Target Accuracy + (h2 "Accuracy vs Speed" + (a ((class "help-button float") [href ,(doc-url "report.html#cost-accuracy")] [target "_blank"]) "?")) - (svg) - (div ([id "functions"])) - (figcaption "The average percentage accuracy by input value. Horizontal axis shows " - "value of an input variable; the variable is choosen in the title. " - "Vertical axis is accuracy; higher is better. Red represent the original " - "program, while blue represents Herbie's suggestion. " - "These can be toggled with buttons below the plot. " - "The line is an average while dots represent individual samples.")) - (section ([id "cost-accuracy"] (class "section") [data-benchmark-name ,(~a (test-name test))]) - ; TODO : Show all Developer Target Accuracy - (h2 "Accuracy vs Speed" - (a ((class "help-button float") [href ,(doc-url "report.html#cost-accuracy")] - [target "_blank"]) - "?")) - (div ((class "figure-row")) - (svg) - (div (p "Herbie found " ,(~a (length end-alts)) " alternatives:") - (table (thead (tr (th "Alternative") - (th ((class "numeric")) "Accuracy") - (th ((class "numeric")) "Speedup"))) - (tbody)))) - (figcaption "The accuracy (vertical axis) and speed (horizontal axis) of each " - "alternatives. Up and to the right is better. The red square shows " - "the initial program, and each blue circle shows an alternative." - "The line shows the best available speed-accuracy tradeoffs.")) - ,(let-values ([(dropdown body) (render-program (alt-expr start-alt) ctx #:ident identifier)]) - `(section ([id "initial"] (class "programs")) - (h2 "Initial Program" + (div ((class "figure-row")) + (svg) + (div (p "Herbie found " ,(~a (length end-exprs)) " alternatives:") + (table (thead (tr (th "Alternative") + (th ((class "numeric")) "Accuracy") + (th ((class "numeric")) "Speedup"))) + (tbody)))) + (figcaption "The accuracy (vertical axis) and speed (horizontal axis) of each " + "alternatives. Up and to the right is better. The red square shows " + "the initial program, and each blue circle shows an alternative." + "The line shows the best available speed-accuracy tradeoffs.")) + ,(let-values ([(dropdown body) (render-program (alt-expr start-alt) ctx #:ident identifier)]) + `(section ([id "initial"] (class "programs")) + (h2 "Initial Program" + ": " + (span ((class "subhead")) + (data ,(format-accuracy (errors-score start-error) repr-bits #:unit "%")) + " accurate, " + (data "1.0×") + " speedup") + ,dropdown + ,(render-help "report.html#alternatives")) + ,body)) + ,@(for/list ([i (in-naturals 1)] + [expr end-exprs] + [errs end-errors] + [cost end-costs] + [history (hash-ref end 'end-histories)]) + (define-values (dropdown body) + (render-program expr ctx #:ident identifier #:instructions preprocessing)) + `(section ([id ,(format "alternative~a" i)] (class "programs")) + (h2 "Alternative " + ,(~a i) ": " (span ((class "subhead")) - (data ,(format-accuracy (errors-score start-error) repr-bits #:unit "%")) + (data ,(format-accuracy (errors-score errs) repr-bits #:unit "%")) " accurate, " - (data "1.0×") + (data ,(~r (/ (alt-cost start-alt repr) cost) #:precision '(= 1)) "×") " speedup") ,dropdown ,(render-help "report.html#alternatives")) - ,body)) - ,@(for/list ([i (in-naturals 1)] - [alt-fpcore end-alts] - [errs end-errors] - [cost end-costs] - [history (hash-ref end 'end-histories)]) - (define formula (read-syntax 'web (open-input-string alt-fpcore))) - (define expr (parse-test formula)) - (define-values (dropdown body) - (render-program (test-input expr) ctx #:ident identifier #:instructions preprocessing)) - `(section ([id ,(format "alternative~a" i)] (class "programs")) - (h2 "Alternative " - ,(~a i) - ": " - (span ((class "subhead")) - (data ,(format-accuracy (errors-score errs) repr-bits #:unit "%")) - " accurate, " - (data ,(~r (/ (alt-cost start-alt repr) cost) #:precision '(= 1)) "×") - " speedup") - ,dropdown - ,(render-help "report.html#alternatives")) - ,body - (details (summary "Derivation") (ol ((class "history")) ,@history)))) - ,@(for/list ([i (in-naturals 1)] - [target (in-list targets)] - [target-error (in-list list-target-error)] - [target-cost (in-list list-target-cost)]) - (let-values ([(dropdown body) - (render-program (alt-expr (alt-analysis-alt target)) ctx #:ident identifier)]) - `(section - ([id ,(format "target~a" i)] (class "programs")) - (h2 "Developer Target " - ,(~a i) - ": " - (span ((class "subhead")) - (data ,(format-accuracy (errors-score target-error) repr-bits #:unit "%")) - " accurate, " - (data ,(~r (/ (alt-cost start-alt repr) target-cost) #:precision '(= 1)) "×") - " speedup") - ,dropdown - ,(render-help "report.html#target")) - ,body))) - ,(render-reproduction test))) - out)) + ,body + (details (summary "Derivation") (ol ((class "history")) ,@history)))) + ,@(for/list ([i (in-naturals 1)] + [target (in-list targets)] + [target-error (in-list list-target-error)] + [target-cost (in-list list-target-cost)]) + (let-values ([(dropdown body) + (render-program (alt-expr (alt-analysis-alt target)) ctx #:ident identifier)]) + `(section + ([id ,(format "target~a" i)] (class "programs")) + (h2 "Developer Target " + ,(~a i) + ": " + (span ((class "subhead")) + (data ,(format-accuracy (errors-score target-error) repr-bits #:unit "%")) + " accurate, " + (data ,(~r (/ (alt-cost start-alt repr) target-cost) #:precision '(= 1)) "×") + " speedup") + ,dropdown + ,(render-help "report.html#target")) + ,body))) + ,(render-reproduction test)))) diff --git a/src/reports/pages.rkt b/src/reports/pages.rkt index 212a9f50a..5939b5bbe 100644 --- a/src/reports/pages.rkt +++ b/src/reports/pages.rkt @@ -5,7 +5,8 @@ "timeline.rkt" "plot.rkt" "make-graph.rkt" - "traceback.rkt") + "traceback.rkt" + "common.rkt") (provide all-pages make-page @@ -31,20 +32,24 @@ (display "" out))) (define (make-page page out result-hash output? profile?) - (define test (hash-ref result-hash 'test)) - (define status (hash-ref result-hash 'status)) (match page - ["graph.html" - (match status - ['success - (define command (hash-ref result-hash 'command)) - (match command - ["improve" (make-graph result-hash out output? profile?)] - [else (dummy-graph command out)])] - ['timeout (make-traceback result-hash out)] - ['failure (make-traceback result-hash out)] - [_ (error 'make-page "unknown result type ~a" status)])] + ["graph.html" (write-html (make-graph-html result-hash output? profile?) out)] ["timeline.html" - (make-timeline (test-name test) (hash-ref result-hash 'timeline) out #:path "..")] + (write-html (make-timeline (test-name (hash-ref result-hash 'test)) + (hash-ref result-hash 'timeline) + #:path "..") + out)] ["timeline.json" (write-json (hash-ref result-hash 'timeline) out)] ["points.json" (write-json (make-points-json result-hash) out)])) + +(define (make-graph-html result-hash output? profile?) + (define status (hash-ref result-hash 'status)) + (match status + ['success + (define command (hash-ref result-hash 'command)) + (match command + ["improve" (make-graph result-hash output? profile?)] + [else (dummy-graph command)])] + ['timeout (make-traceback result-hash)] + ['failure (make-traceback result-hash)] + [_ (error 'make-graph-html "unknown result type ~a" status)])) diff --git a/src/reports/plot.rkt b/src/reports/plot.rkt index 928b809e9..5488e88f5 100644 --- a/src/reports/plot.rkt +++ b/src/reports/plot.rkt @@ -2,12 +2,10 @@ (require math/bigfloat math/flonum) -(require "../utils/common.rkt" - "../core/points.rkt" +(require "../core/points.rkt" "../utils/float.rkt" "../core/programs.rkt" "../syntax/types.rkt" - "../syntax/syntax.rkt" "../syntax/read.rkt" "../utils/alternative.rkt" "../core/bsearch.rkt" @@ -105,12 +103,18 @@ ; bits of error for the output on each point ; ticks: array of size n where each entry is 13 or so tick values as [ordinal, string] pairs ; splitpoints: array with the ordinal splitpoints - `#hasheq((bits . ,bits) - (vars . ,(map symbol->string vars)) - (points . ,json-points) - (error . ,error-entries) - (ticks_by_varidx . ,ticks) - (splitpoints_by_varidx . ,splitpoints))) + (hasheq 'bits + bits + 'vars + (map symbol->string vars) + 'points + json-points + 'error + error-entries + 'ticks_by_varidx + ticks + 'splitpoints_by_varidx + splitpoints)) ;; Repr conversions diff --git a/src/reports/timeline.rkt b/src/reports/timeline.rkt index 63bc37b15..bf3cc3e27 100644 --- a/src/reports/timeline.rkt +++ b/src/reports/timeline.rkt @@ -15,20 +15,18 @@ ;; This first part handles timelines for a single Herbie run -(define (make-timeline name timeline out #:info [info #f] #:path [path "."]) - (write-html - `(html (head (meta ([charset "utf-8"])) - (title "Metrics for " ,(~a name)) - (link ([rel "stylesheet"] [type "text/css"] - [href ,(if info "report.css" "../report.css")])) - (script ([src ,(if info "report.js" "../report.js")]))) - (body ,(render-menu (~a name) - #:path path - (if info `(("Report" . "index.html")) `(("Details" . "graph.html")))) - ,(if info (render-about info) "") - ,(render-timeline timeline) - ,(render-profile))) - out)) +(define (make-timeline name timeline #:info [info #f] #:path [path "."]) + `(html (head (meta ([charset "utf-8"])) + (title "Metrics for " ,(~a name)) + (link ([rel "stylesheet"] [type "text/css"] + [href ,(if info "report.css" "../report.css")])) + (script ([src ,(if info "report.js" "../report.js")]))) + (body ,(render-menu (~a name) + #:path path + (if info `(("Report" . "index.html")) `(("Details" . "graph.html")))) + ,(if info (render-about info) "") + ,(render-timeline timeline) + ,(render-profile)))) (define/contract (render-timeline timeline) (-> timeline? xexpr?) diff --git a/src/reports/traceback.rkt b/src/reports/traceback.rkt index 17ccc2df5..90c174b06 100644 --- a/src/reports/traceback.rkt +++ b/src/reports/traceback.rkt @@ -7,13 +7,13 @@ (provide make-traceback) -(define (make-traceback result-hash out) +(define (make-traceback result-hash) (match (hash-ref result-hash 'status) - ['timeout (render-timeout result-hash out)] - ['failure (render-failure result-hash out)] + ['timeout (render-timeout result-hash)] + ['failure (render-failure result-hash)] [status (error 'make-traceback "unexpected status ~a" status)])) -(define (render-failure result-hash out) +(define (render-failure result-hash) (define test (hash-ref result-hash 'test)) (define warnings (hash-ref result-hash 'warnings)) (define backend (hash-ref result-hash 'backend)) @@ -21,27 +21,25 @@ ; unpack the exception (match-define (list 'exn type msg url extra traceback) backend) - (write-html - `(html - (head (meta ((charset "utf-8"))) - (title "Exception for " ,(~a (test-name test))) - (link ((rel "stylesheet") (type "text/css") (href "../report.css"))) - ,@js-tex-include - (script ([src "../report.js"]))) - (body ,(render-menu (~a (test-name test)) - (list '("Report" . "../index.html") '("Metrics" . "timeline.html"))) - ,(render-warnings warnings) - ,(render-specification test) - ,(if type - `(section ([id "user-error"] (class "error")) - (h2 ,(~a msg) " " (a ([href ,url]) "(more)")) - ,(if (eq? type 'syntax) (render-syntax-errors msg extra) "")) - "") - ,(if type - "" - `(,@(render-reproduction test #:bug? #t) - (section ([id "backtrace"]) (h2 "Backtrace") ,(render-traceback msg traceback)))))) - out)) + `(html + (head (meta ((charset "utf-8"))) + (title "Exception for " ,(~a (test-name test))) + (link ((rel "stylesheet") (type "text/css") (href "../report.css"))) + ,@js-tex-include + (script ([src "../report.js"]))) + (body ,(render-menu (~a (test-name test)) + (list '("Report" . "../index.html") '("Metrics" . "timeline.html"))) + ,(render-warnings warnings) + ,(render-specification test) + ,(if type + `(section ([id "user-error"] (class "error")) + (h2 ,(~a msg) " " (a ([href ,url]) "(more)")) + ,(if (eq? type 'syntax) (render-syntax-errors msg extra) "")) + "") + ,(if type + "" + `(,@(render-reproduction test #:bug? #t) + (section ([id "backtrace"]) (h2 "Backtrace") ,(render-traceback msg traceback))))))) (define (render-syntax-errors msg locations) `(table (thead (th ([colspan "2"]) ,msg) (th "L") (th "C")) @@ -63,22 +61,20 @@ `(tr (td ((class "procedure")) ,(~a name)) (td ,(~a file)) (td ,(~a line)) (td ,(~a col)))] [#f `(tr (td ((class "procedure")) ,(~a name)) (td ([colspan "3"]) "unknown"))]))))) -(define (render-timeout result-hash out) +(define (render-timeout result-hash) (define test (hash-ref result-hash 'test)) (define time (hash-ref result-hash 'time)) (define warnings (hash-ref result-hash 'warnings)) - (write-html - `(html (head (meta ((charset "utf-8"))) - (title "Exception for " ,(~a (test-name test))) - (link ((rel "stylesheet") (type "text/css") (href "../report.css"))) - ,@js-tex-include - (script ([src "../report.js"]))) - (body ,(render-menu (~a (test-name test)) - (list '("Report" . "../index.html") '("Metrics" . "timeline.html"))) - ,(render-warnings warnings) - ,(render-specification test) - (section ([id "user-error"] (class "error")) - (h2 "Timeout after " ,(format-time time)) - (p "Use the " (code "--timeout") " flag to change the timeout.")))) - out)) + `(html (head (meta ((charset "utf-8"))) + (title "Exception for " ,(~a (test-name test))) + (link ((rel "stylesheet") (type "text/css") (href "../report.css"))) + ,@js-tex-include + (script ([src "../report.js"]))) + (body ,(render-menu (~a (test-name test)) + (list '("Report" . "../index.html") '("Metrics" . "timeline.html"))) + ,(render-warnings warnings) + ,(render-specification test) + (section ([id "user-error"] (class "error")) + (h2 "Timeout after " ,(format-time time)) + (p "Use the " (code "--timeout") " flag to change the timeout."))))) diff --git a/src/syntax/matcher.rkt b/src/syntax/matcher.rkt new file mode 100644 index 000000000..12d0c2699 --- /dev/null +++ b/src/syntax/matcher.rkt @@ -0,0 +1,36 @@ +;; Minimal pattern matcher/substituter for S-expressions + +#lang racket + +(provide pattern-match + pattern-substitute) + +;; Unions two bindings. Returns #f if they disagree. +(define (merge-bindings binding1 binding2) + (and binding1 + binding2 + (let/ec quit + (for/fold ([binding binding1]) ([(k v) (in-dict binding2)]) + (dict-update binding k (λ (x) (if (equal? x v) v (quit #f))) v))))) + +;; Pattern matcher that returns a substitution or #f. +;; A substitution is an association list of symbols and expressions. +(define (pattern-match pattern expr) + (match* (pattern expr) + [((? number?) _) (and (equal? pattern expr) '())] + [((? symbol?) _) (list (cons pattern expr))] + [((list phead prest ...) (list head rest ...)) + (and (equal? phead head) + (= (length prest) (length rest)) + (for/fold ([bindings '()]) + ([pat (in-list prest)] + [term (in-list rest)]) + (merge-bindings bindings (pattern-match pat term))))] + [(_ _) #f])) + +(define (pattern-substitute pattern bindings) + ; pattern binding -> expr + (match pattern + [(? number?) pattern] + [(? symbol?) (dict-ref bindings pattern)] + [(list phead pargs ...) (cons phead (map (curryr pattern-substitute bindings) pargs))])) diff --git a/src/syntax/platform.rkt b/src/syntax/platform.rkt index 6b7beb9c0..c8751f640 100644 --- a/src/syntax/platform.rkt +++ b/src/syntax/platform.rkt @@ -1,11 +1,10 @@ #lang racket -(require (for-syntax racket/match)) - (require "../utils/common.rkt" "../utils/errors.rkt" "../core/programs.rkt" "../core/rules.rkt" + "matcher.rkt" "syntax.rkt" "types.rkt") @@ -242,10 +241,7 @@ ;; Casts between representations in a platform. (define (platform-casts pform) - (reap [sow] - (for ([impl (in-list (platform-impls pform))]) - (when (eq? (impl->operator impl) 'cast) - (sow impl))))) + (filter cast-impl? (platform-impls pform))) ;; Merger for costs. (define (merge-cost pform-costs key #:optional? [optional? #f]) @@ -311,7 +307,8 @@ (define reprs* (filter repr-supported? (platform-reprs pform))) (define impls* (filter (λ (impl) - (and (op-supported? (impl->operator impl)) + (define spec (impl-info impl 'spec)) + (and (andmap op-supported? (ops-in-expr spec)) (repr-supported? (impl-info impl 'otype)) (andmap repr-supported? (impl-info impl 'itype)))) (platform-impls pform))) @@ -429,16 +426,9 @@ ;; Synthesizes the LHS and RHS of lifting/lowering rules. (define (impl->rule-parts impl) - (define op (impl->operator impl)) - (cond - [(operator-accelerator? op) - (define spec (operator-info op 'spec)) - (match-define `(,(or 'lambda 'λ) (,vars ...) ,body) spec) - (values vars body (cons impl vars))] - [else - (define itypes (operator-info op 'itype)) - (define vars (map (lambda (_) (gensym)) itypes)) - (values vars (cons op vars) (cons impl vars))])) + (define vars (impl-info impl 'vars)) + (define spec (impl-info impl 'spec)) + (values vars spec (cons impl vars))) ;; Synthesizes lifting rules for a given platform. (define (platform-lifting-rules [pform (*active-platform*)]) @@ -466,11 +456,10 @@ (hash-ref! (*lowering-rules*) (cons impl pform) (lambda () - (define op (impl->operator impl)) (define name (sym-append 'lower- impl)) - (define itypes (operator-info op 'itype)) - (define otype (operator-info op 'otype)) (define-values (vars spec-expr impl-expr) (impl->rule-parts impl)) + (define itypes (map representation-type (impl-info impl 'itype))) + (define otype (representation-type (impl-info impl 'otype))) (rule name spec-expr impl-expr (map cons vars itypes) otype))))) ;; All possible assignments of implementations. @@ -483,8 +472,9 @@ [(list 'if rest ...) (loop rest assigns)] [(list (? (curryr assq assigns)) rest ...) (loop rest assigns)] [(list op rest ...) - (for ([impl (operator-all-impls op)]) - (when (set-member? impls impl) + (for ([impl (in-set impls)]) + (define pattern (cons op (map (lambda _ (gensym)) (operator-info op 'itype)))) + (when (pattern-match (impl-info impl 'spec) pattern) (loop rest (cons (cons op impl) assigns))))])))) ;; Attempts to lower a specification to an expression using @@ -550,5 +540,10 @@ (when (and input* output*) (define itypes* (merge-envs ienv oenv)) (when itypes* - (define name* (sym-append name '_ (repr->symbol repr))) + (define name* + (string->symbol + (format "~a-~a-~a" + name + (representation-name repr) + (string-join (map (lambda (subst) (~a (cdr subst))) isubst) "-")))) (sow (rule name* input* output* itypes* repr)))))])))) diff --git a/src/syntax/read.rkt b/src/syntax/read.rkt index 3d9f76221..90961ec69 100644 --- a/src/syntax/read.rkt +++ b/src/syntax/read.rkt @@ -144,28 +144,22 @@ [(list 'FPCore name (list args ...) props ... body) (values name args props body)] [(list 'FPCore (list args ...) props ... body) (values #f args props body)])) - ;; TODO(interface): Currently, this code doesn't fire because annotations aren't - ;; allowed for variables because of the syntax checker yet. This should run correctly - ;; once the syntax checker is updated to the FPBench 1.1 standard. - (define arg-names - (for/list ([arg args]) - (if (list? arg) (last arg) arg))) + (define prop-dict (props->dict props)) + (define default-prec (dict-ref prop-dict ':precision (*default-precision*))) - (define prop-dict - (let loop ([props props]) - (match props - ['() '()] - [(list prop val rest ...) (cons (cons prop val) (loop rest))]))) + (define-values (var-names var-precs) + (for/lists (var-names var-precs) + ([var (in-list args)]) + (match var + [(list '! props ... name) + (define prop-dict (props->dict props)) + (define arg-prec (dict-ref prop-dict ':precision default-prec)) + (values name arg-prec)] + [(? symbol? name) (values name default-prec)]))) - (define default-prec (dict-ref prop-dict ':precision (*default-precision*))) (define default-repr (get-representation default-prec)) - (define var-reprs - (for/list ([arg args] - [arg-name arg-names]) - (if (and (list? arg) (set-member? args ':precision)) - (get-representation (cadr (member ':precision args))) - default-repr))) - (define ctx (context arg-names default-repr var-reprs)) + (define var-reprs (map get-representation var-precs)) + (define ctx (context var-names default-repr var-reprs)) ;; Named fpcores need to be added to function table (when func-name @@ -192,12 +186,12 @@ (cons val #t))]))) (define spec (fpcore->prog (dict-ref prop-dict ':spec body) ctx)) - (check-unused-variables arg-names body* pre*) - (check-weird-variables arg-names) + (check-unused-variables var-names body* pre*) + (check-weird-variables var-names) (test (~a name) func-name - arg-names + var-names body* targets (dict-ref prop-dict ':herbie-expected #t) @@ -205,8 +199,8 @@ pre* (dict-ref prop-dict ':herbie-preprocess empty) (representation-name default-repr) - (for/list ([var arg-names] - [repr var-reprs]) + (for/list ([var (in-list var-names)] + [repr (in-list var-reprs)]) (cons var (representation-name repr))))) (define (check-unused-variables vars precondition expr) diff --git a/src/syntax/sugar.rkt b/src/syntax/sugar.rkt index 4e93764cf..b002c48ef 100644 --- a/src/syntax/sugar.rkt +++ b/src/syntax/sugar.rkt @@ -1,12 +1,5 @@ -#lang racket - -(require "types.rkt" - "syntax.rkt") - -(provide fpcore->prog - prog->fpcore - prog->spec) - +;; Expression conversions +;; ;; Herbie uses three expression languages. ;; All formats are S-expressions with variables, numbers, and applications. ;; @@ -64,6 +57,18 @@ ;; ::= ;; +#lang racket + +(require "../core/programs.rkt" + "../utils/common.rkt" + "matcher.rkt" + "syntax.rkt" + "types.rkt") + +(provide fpcore->prog + prog->fpcore + prog->spec) + ;; Expression pre-processing for normalizing expressions. ;; Used for conversion from FPCore to other IRs. (define (expand-expr expr) @@ -135,97 +140,224 @@ ; other [_ expr]))) -;; Prop list to dict -(define (props->dict props) - (let loop ([props props] - [dict '()]) - (match props - [(list key val rest ...) (loop rest (dict-set dict key val))] - [(list key) (error 'props->dict "unmatched key" key)] - [(list) dict]))) +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; FPCore -> LImpl -;; Translates from FPCore to an LImpl +;; Translates an FPCore operator application into +;; an LImpl operator application. +(define (fpcore->impl-app op prop-dict args ctx) + (define ireprs (map (lambda (arg) (repr-of arg ctx)) args)) + (define impl (get-fpcore-impl op prop-dict ireprs)) + (define vars (impl-info impl 'vars)) + (define pattern + (match (impl-info impl 'fpcore) + [(list '! _ ... body) body] + [body body])) + (define subst (pattern-match pattern (cons op args))) + (pattern-substitute (cons impl vars) subst)) + +;; Translates from FPCore to an LImpl. (define (fpcore->prog prog ctx) - (define-values (expr* _) - (let loop ([expr (expand-expr prog)] - [ctx ctx]) - (match expr - [`(FPCore ,name (,vars ...) ,props ... ,body) - (define-values (body* repr*) (loop body ctx)) - (values `(FPCore ,name ,vars ,@props ,body*) repr*)] - [`(FPCore (,vars ...) ,props ... ,body) - (define-values (body* repr*) (loop body ctx)) - (values `(FPCore ,vars ,@props ,body*) repr*)] - [`(if ,cond ,ift ,iff) - (define-values (cond* cond-repr) (loop cond ctx)) - (define-values (ift* ift-repr) (loop ift ctx)) - (define-values (iff* iff-repr) (loop iff ctx)) - (values `(if ,cond* ,ift* ,iff*) ift-repr)] - [`(! ,props ... ,body) - (define props* (props->dict props)) - (loop body - (match (dict-ref props* ':precision #f) - [#f ctx] - [prec (struct-copy context ctx [repr (get-representation prec)])]))] - [`(cast ,body) - (define repr (context-repr ctx)) - (define-values (body* repr*) (loop body ctx)) - (if (equal? repr* repr) ; check if cast is redundant - (values body* repr) - (values (list (get-cast-impl repr* repr) body*) repr))] - [`(,(? constant-operator? x)) - (define cnst (get-parametric-constant x (context-repr ctx))) - (values (list cnst) (impl-info cnst 'otype))] - [(list 'neg arg) ; non-standard but useful - (define-values (arg* atype) (loop arg ctx)) - (define op* (get-parametric-operator 'neg atype)) - (values (list op* arg*) (impl-info op* 'otype))] - [`(,op ,args ...) - (define-values (args* atypes) (for/lists (args* atypes) ([arg args]) (loop arg ctx))) - ;; Match guaranteed to succeed because we ran type-check first - (define op* (apply get-parametric-operator op atypes)) - (values (cons op* args*) (impl-info op* 'otype))] - [(? variable?) (values expr (context-lookup ctx expr))] - [(? number?) - (define prec (representation-name (context-repr ctx))) - (define num - (match expr - [(or +inf.0 -inf.0 +nan.0) expr] - [(? exact?) expr] - [_ (inexact->exact expr)])) - (values (literal num prec) (context-repr ctx))]))) - expr*) + (let loop ([expr (expand-expr prog)] + [prop-dict (repr->prop (context-repr ctx))]) + (match expr + [(? number? n) + (literal (match n + [(or +inf.0 -inf.0 +nan.0) expr] + [(? exact?) expr] + [_ (inexact->exact expr)]) + (dict-ref prop-dict ':precision))] + [(? variable?) expr] + [(list 'if cond ift iff) + (define cond* (loop cond prop-dict)) + (define ift* (loop ift prop-dict)) + (define iff* (loop iff prop-dict)) + (list 'if cond* ift* iff*)] + [(list '! props ... body) (loop body (apply dict-set prop-dict props))] + [(list 'neg arg) ; non-standard but useful [TODO: remove] + (define arg* (loop arg prop-dict)) + (fpcore->impl-app '- prop-dict (list arg*) ctx)] + [(list 'cast arg) ; special case: unnecessary casts + (define arg* (loop arg prop-dict)) + (define repr (get-representation (dict-ref prop-dict ':precision))) + (if (equal? (repr-of arg* ctx) repr) arg* (fpcore->impl-app 'cast prop-dict (list arg*) ctx))] + [(list op args ...) + (define args* (map (lambda (arg) (loop arg prop-dict)) args)) + (fpcore->impl-app op prop-dict args* ctx)]))) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; LImpl -> FPCore +;; Translates from LImpl to an FPCore + +;; TODO: this process uses a batch-like data structure +;; but _without_ deduplication since different use sites +;; of a particular subexpression may have different +;; parent rounding contexts. Would be nice to explore +;; if the batch data structure can be used. + +;; Instruction vector index +(struct index (v) #:prefab) + +;; Translates a literal (LImpl) to an FPCore expr +(define (literal->fpcore x) + (match x + [(literal -inf.0 _) '(- INFINITY)] + [(literal +inf.0 _) 'INFINITY] + [(literal v (or 'binary64 'binary32)) (exact->inexact v)] + [(literal v _) v])) + +;; Step 1. +;; Translates from LImpl to a series of let bindings such that each +;; local variable is bound once and used at most once. The result is an +;; instruction vector, representing the let bindings; the operator +;; implementation for each instruction, and the final "root" operation/literal. +;; Except for let-bound variables, the subexpressions are in FPCore. + +(define (prog->let-exprs expr) + (define instrs '()) + (define (push! impl node) + (define id (length instrs)) + (set! instrs (cons (cons node impl) instrs)) + (index id)) + + (define (munge expr #:root? [root? #f]) + (match expr + [(? literal?) (literal->fpcore expr)] + [(? symbol?) expr] + [(approx _ impl) (munge impl)] + [(list 'if cond ift iff) (list 'if (munge cond) (munge ift) (munge iff))] + [(list (? impl-exists? impl) args ...) + (define args* (map munge args)) + (define vars (impl-info impl 'vars)) + (define node (replace-vars (map cons vars args*) (impl-info impl 'fpcore))) + (if root? node (push! impl node))])) + + (define root (munge expr #:root? #t)) + (cons (list->vector (reverse instrs)) root)) + +;; Step 2. +;; Inlines let bindings; let-inlining is generally unsound with +;; rounding properties (the parent context may change), +;; so we only inline those that result in the same operator +;; implementation when converting back from FPCore to LImpl. + +(define (inline! root ivec ctx) + (define global-prop-dict (repr->prop (context-repr ctx))) + (let loop ([node root] + [prop-dict global-prop-dict]) + (match node + [(? number?) node] ; number + [(? symbol?) node] ; variable + [(index idx) ; let-bound variable + ; we check what happens if we inline + (match-define (cons expr impl) (vector-ref ivec idx)) + (define impl* + (match expr + [(list '! props ... (list op _ ...)) + ; rounding context updated parent context + (define prop-dict* (apply dict-set prop-dict props)) + (get-fpcore-impl op prop-dict* (impl-info impl 'itype))] + ; rounding context inherited from parent context + [(list op _ ...) (get-fpcore-impl op prop-dict (impl-info impl 'itype))])) + (cond + [(equal? impl impl*) ; inlining is safe + (define expr* (loop expr prop-dict)) + (vector-set! ivec idx #f) + expr*] + [else ; inlining is not safe + (define expr* (loop expr global-prop-dict)) + (vector-set! ivec idx expr*) + node])] + [(list '! props ... body) ; explicit rounding context + (define prop-dict* (props->dict props)) + (define body* (loop body prop-dict*)) + (define new-prop-dict + (for/list ([(k v) (in-dict prop-dict*)] + #:unless (and (dict-has-key? prop-dict k) (equal? (dict-ref prop-dict k) v))) + (cons k v))) + (if (null? new-prop-dict) body* `(! ,@(dict->props new-prop-dict) ,body*))] + [(list op args ...) ; operator application + (define args* (map (lambda (e) (loop e prop-dict)) args)) + `(,op ,@args*)]))) + +;; Step 3. +;; Construct the final FPCore expression using remaining let-bindings +;; and the let-free body from the previous step. + +(define (reachable-indices ivec expr) + (define reachable (mutable-set)) + (let loop ([expr expr]) + (match expr + [(? number?) (void)] + [(? symbol?) (void)] + [(index idx) + (set-add! reachable idx) + (loop (vector-ref ivec idx))] + [(list _ args ...) (for-each loop args)])) + reachable) + +(define (remove-indices id->name expr) + (let loop ([expr expr]) + (match expr + [(? number?) expr] + [(? symbol?) expr] + [(index idx) (hash-ref id->name idx)] + [(list '! props ... body) `(! ,@props ,(loop body))] + [(list op args ...) `(,op ,@(map loop args))]))) + +(define (build-expr expr ivec ctx) + ; variable generation + (define vars (list->mutable-seteq (context-vars ctx))) + (define counter 0) + (define (gensym) + (set! counter (add1 counter)) + (match (string->symbol (format "t~a" counter)) + [(? (curry set-member? vars)) (gensym)] + [x + (set-add! vars x) + x])) + + ; need fresh variables for reachable, non-inlined subexpressions + (define reachable (reachable-indices ivec expr)) + (define id->name (make-hash)) + (for ([expr (in-vector ivec)] + [idx (in-naturals)]) + (when (and expr (set-member? reachable idx)) + (hash-set! id->name idx (gensym)))) + + (for/fold ([body (remove-indices id->name expr)]) ([idx (in-list (sort (hash-keys id->name) >))]) + (define var (hash-ref id->name idx)) + (define val (remove-indices id->name (vector-ref ivec idx))) + `(let ([,var ,val]) ,body))) ;; Translates from LImpl to an FPCore. -(define (prog->fpcore expr) - (match expr - [`(if ,cond ,ift ,iff) `(if ,(prog->fpcore cond) ,(prog->fpcore ift) ,(prog->fpcore iff))] - [`(,(? cast-impl? impl) ,body) - (define prec (representation-name (impl-info impl 'otype))) - `(! :precision ,prec (cast ,(prog->fpcore body)))] - [`(,impl) (impl->operator impl)] - [`(,impl ,args ...) - (define op (impl->operator impl)) - (define args* (map prog->fpcore args)) - (match (cons op args*) - [`(neg ,arg) `(- ,arg)] - [expr expr])] - [(approx _ impl) (prog->fpcore impl)] - [(? variable?) expr] - [(? literal?) - (match (literal-value expr) - [-inf.0 '(- INFINITY)] - [+inf.0 'INFINITY] - [+nan.0 'NAN] - [v (if (set-member? '(binary64 binary32) (literal-precision expr)) (exact->inexact v) v)])])) +;; The implementation of this procedure is complicated since +;; (1) every operator implementation requires certain (FPCore) rounding properties +;; (2) rounding contexts have lexical scoping +(define (prog->fpcore prog ctx) + ; step 1: convert to an instruction vector where + ; each expression is evaluated under explicit rounding contexts + (match-define (cons ivec root) (prog->let-exprs prog)) + + ; step 2: inline nodes + (define body (inline! root ivec ctx)) + + ; step 3: construct the actual FPCore expression from + ; the remaining let-bindings and body + (build-expr body ivec ctx)) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; LImpl -> LSpec ;; Translates an LImpl to a LSpec. (define (prog->spec expr) - (expand-accelerators - (match expr - [`(if ,cond ,ift ,iff) `(if ,(prog->spec cond) ,(prog->spec ift) ,(prog->spec iff))] - [`(,(? cast-impl? impl) ,body) `(,impl ,(prog->spec body))] - [`(,impl ,args ...) `(,(impl->operator impl) ,@(map prog->spec args))] - [(approx spec _) spec] - [(? variable?) expr] - [(? literal?) (literal-value expr)]))) + (match expr + [(? literal?) (literal-value expr)] + [(? variable?) expr] + [(approx spec _) spec] + [`(if ,cond ,ift ,iff) `(if ,(prog->spec cond) ,(prog->spec ift) ,(prog->spec iff))] + [`(,impl ,args ...) + (define vars (impl-info impl 'vars)) + (define spec (impl-info impl 'spec)) + (define env (map cons vars (map prog->spec args))) + (replace-vars env spec)])) diff --git a/src/syntax/syntax-check.rkt b/src/syntax/syntax-check.rkt index 6398a89b9..4d53745b2 100644 --- a/src/syntax/syntax-check.rkt +++ b/src/syntax/syntax-check.rkt @@ -67,6 +67,17 @@ ;; These expand by associativity so we don't check the number of arguments (for ([arg args]) (loop arg vars))] + [#`(,(? (curry set-member? '(erfc expm1 log1p hypot fma)) op) #,args ...) + ; FPCore operators that are composite in Herbie + (define arity + (case op + [(erfc expm1 log1p) 1] + [(hypot) 2] + [(fma) 3])) + (unless (= arity (length args)) + (error! stx "Operator ~a given ~a arguments (expects ~a)" op (length args) arity)) + (for ([arg (in-list args)]) + (loop arg vars))] [#`(#,f-syntax #,args ...) (define f (syntax->datum f-syntax)) (cond diff --git a/src/syntax/syntax.rkt b/src/syntax/syntax.rkt index 8022e5792..2640d3939 100644 --- a/src/syntax/syntax.rkt +++ b/src/syntax/syntax.rkt @@ -5,6 +5,7 @@ (require "../utils/common.rkt" "../utils/errors.rkt" "../core/rival.rkt" + "matcher.rkt" "types.rkt") (provide (rename-out [operator-or-impl? operator?]) @@ -14,23 +15,18 @@ constant-operator? operator-exists? operator-deprecated? - operator-accelerator? operator-info all-operators all-constants - all-accelerators - expand-accelerators impl-exists? impl-info - impl->operator - operator-all-impls - operator-active-impls + all-operator-impls + (rename-out [all-active-operator-impls active-operator-impls]) activate-operator-impl! clear-active-operator-impls! *functions* register-function! - get-parametric-operator - get-parametric-constant + get-fpcore-impl get-cast-impl generate-cast-impl cast-impl?) @@ -58,11 +54,8 @@ ;; A real operator requires ;; - a (unique) name ;; - input and output types -;; - optionally a specification [#f by default] ;; - optionally a deprecated? flag [#f by default] -;; Operator implementations _implement_ a real operator -;; for a particular set of input and output representations. -(struct operator (name itype otype spec deprecated)) +(struct operator (name itype otype deprecated)) ;; All real operators (define operators (make-hasheq)) @@ -75,10 +68,6 @@ (define (operator-deprecated? op) (operator-deprecated (hash-ref operators op))) -;; Checks if an operator is an "accelerator". -(define (operator-accelerator? op) - (and (hash-has-key? operators op) (operator-spec (hash-ref operators op)))) - ;; Returns all operators. (define (all-operators) (sort (hash-keys operators) symbol symbol? (or/c 'itype 'otype 'spec) any/c) + (-> symbol? (or/c 'itype 'otype) any/c) (unless (hash-has-key? operators op) (error 'operator-info "Unknown operator ~a" op)) (define info (hash-ref operators op)) (case field [(itype) (operator-itype info)] - [(otype) (operator-otype info)] - [(spec) (operator-spec info)])) - -;; Map from operator to its implementations -(define operators-to-impls (make-hasheq)) - -;; All implementations of an operator `op`. -;; Panics if the operator is not found. -(define (operator-all-impls op) - (unless (hash-has-key? operators op) - (error 'operator-info "Unknown operator ~a" op)) - (hash-ref operators-to-impls op)) - -;; Checks an "accelerator" specification -(define (check-spec! name itypes otype spec) - (define (bad! fmt . args) - (error name "~a in `~a`" (apply format fmt args) spec)) - - (define (type-error! expr actual-ty expect-ty) - (bad! "expression `~a` has type `~a`, expected `~a`" expr actual-ty expect-ty)) - - (define-values (vars body) - (match spec - [`(,(or 'lambda 'λ) (,vars ...) ,spec) - (for ([var (in-list vars)]) - (unless (symbol? var) - (bad! "expected symbol `~a` in `~a`" var spec))) - (values vars spec)] - [_ (bad! "malformed specification, expected `(lambda )`")])) - - (unless (= (length itypes) (length vars)) - (bad! "arity mismatch; expected ~a, got ~a" (length itypes) (length vars))) - - (define env (map cons vars itypes)) - (define actual-ty - (let type-of ([expr body]) - (match expr - [(? number?) 'real] - [(? symbol?) - (cond - [(assq expr env) - => - cdr] - [else (bad! "unbound variable `~a`" expr)])] - [`(if ,cond ,ift ,iff) - (define cond-ty (type-of cond)) - (unless (equal? cond-ty 'bool) - (type-error! cond cond-ty 'bool)) - (define ift-ty (type-of ift)) - (define iff-ty (type-of iff)) - (unless (equal? ift-ty iff-ty) - (type-error! iff iff-ty ift-ty)) - ift-ty] - [`(,op ,args ...) - (unless (operator-exists? op) - (bad! "expected operator at `~a`, got `~a` in `~a`" expr op)) - (define itypes (operator-info op 'itype)) - (for ([arg (in-list args)] - [itype (in-list itypes)]) - (define arg-ty (type-of arg)) - (unless (equal? itype arg-ty) - (type-error! arg arg-ty itype))) - (operator-info op 'otype)] - [_ (bad! "expected an expression, got `~a`" expr)]))) - - (unless (equal? actual-ty otype) - (type-error! body actual-ty otype))) - -;; Applies a substitution. -;; Slightly different than `replace-vars` in `programs.rkt`. -(define (replace-vars expr env) - (let loop ([expr expr]) - (match expr - [(? number?) expr] - [(? symbol?) (cdr (assq expr env))] - [`(,op ,args ...) `(,op ,@(map loop args))]))) - -;; Expands an "accelerator" specification. -;; Any nested accelerator is unfolded into its definition. -(define (expand-accelerators spec) - (let loop ([expr spec]) - (match expr - [(? number?) expr] - [(? symbol?) expr] - [`(,(? operator-accelerator? op) ,args ...) - (define spec (operator-info op 'spec)) - (match-define `(,(or 'lambda 'λ) (,vars ...) ,body) spec) - (define env (map cons vars (map loop args))) - (replace-vars body env)] - [`(,op ,args ...) `(,op ,@(map loop args))]))) + [(otype) (operator-otype info)])) ;; Registers an operator with an attribute mapping. ;; Panics if an operator with name `name` has already been registered. @@ -205,42 +98,23 @@ (define (register-operator! name itypes otype attrib-dict) (when (hash-has-key? operators name) (error 'register-operator! "operator already registered: ~a" name)) - ; extract relevant fields + ; extract relevant fields and update tables (define itypes* (dict-ref attrib-dict 'itype itypes)) (define otype* (dict-ref attrib-dict 'otype otype)) - (define spec (dict-ref attrib-dict 'spec #f)) (define deprecated? (dict-ref attrib-dict 'deprecated #f)) - ; check the spec if it is provided - (when spec - (check-spec! name itypes otype spec) - (set! spec (expand-accelerators spec))) - ; update tables - (define info (operator name itypes* otype* spec deprecated?)) - (hash-set! operators name info) - (hash-set! operators-to-impls name '())) + (define info (operator name itypes* otype* deprecated?)) + (hash-set! operators name info)) -;; Syntactic form for `register-operator!`. -;; Special translations for +;; Syntactic form for `register-operator!` (define-syntax (define-operator stx) (define (bad! why [what #f]) (raise-syntax-error 'define-operator why stx what)) - - (define (attribute-val key val) - (syntax-case key (spec) - [spec - (with-syntax ([val val]) - (syntax 'val))] - [_ val])) - (syntax-case stx () [(_ (id itype ...) otype [key val] ...) - (let ([id #'id] - [keys (syntax->list #'(key ...))] - [vals (syntax->list #'(val ...))]) + (let ([id #'id]) (unless (identifier? id) (bad! "expected identifier" id)) - (with-syntax ([id id] - [(val ...) (map attribute-val keys vals)]) + (with-syntax ([id id]) #'(register-operator! 'id '(itype ...) 'otype (list (cons 'key val) ...))))])) (define-syntax define-operators @@ -324,45 +198,32 @@ [pow : real real -> real] [remainder : real real -> real]) -;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -;; Accelerator operators - -(define-operator (cast real) real [spec (lambda (x) x)]) - -(define-operator (erfc real) real [spec (lambda (x) (- 1 (erf x)))]) - -(define-operator (expm1 real) real [spec (lambda (x) (- (exp x) 1))]) - -(define-operator (log1p real) real [spec (lambda (x) (log (+ 1 x)))]) - -(define-operator (hypot real real) real [spec (lambda (x y) (sqrt (+ (* x x) (* y y))))]) - -(define-operator (fma real real real) real [spec (lambda (x y z) (+ (* x y) z))]) - (module+ test ; check expected number of operators - (check-equal? (length (all-operators)) 63) + (check-equal? (length (all-operators)) 57) ; check that Rival supports all non-accelerator operators - (for ([op (in-list (all-operators))] - #:unless (operator-accelerator? op)) + (for ([op (in-list (all-operators))]) (define vars (map (lambda (_) (gensym)) (operator-info op 'itype))) (define disc (discretization 64 #f #f)) ; fake arguments - (rival-compile (list `(,op ,@vars)) vars (list disc))) - - ; test accelerator operator - ; log1pmd(x) = log1p(x) - log1p(-x) - (define-operator (log1pmd real) real [spec (lambda (x) (- (log1p x) (log1p (neg x))))])) + (rival-compile (list `(,op ,@vars)) vars (list disc)))) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Operator implementations ;; Floating-point operations that approximate mathematical operations -;; Operator implementations -;; An "operator implementation" implements a mathematical operator for -;; a particular set of representations satisfying the types described -;; by the `itype` and `otype` properties of the operator. -(struct operator-impl (name op ctx spec fpcore fl)) +;; Operator implementations _approximate_ a program of +;; mathematical operators with fixed input and output representations. +;; +;; An operator implementation requires +;; - a (unique) name +;; - input variables/representations +;; - output representation +;; - a specification it approximates +;; - its FPCore representation +;; - a floating-point implementation +;; +(struct operator-impl (name ctx spec fpcore fl)) ;; Operator implementation table ;; Tracks implementations that are loaded into Racket's runtime @@ -376,28 +237,25 @@ ;; Looks up a property `field` of an real operator `op`. ;; Panics if the operator is not found. (define/contract (impl-info impl field) - (-> symbol? (or/c 'itype 'otype 'fl) any/c) + (-> symbol? (or/c 'vars 'itype 'otype 'spec 'fpcore 'fl) any/c) (unless (hash-has-key? operator-impls impl) (error 'impl-info "Unknown operator implementation ~a" impl)) (define info (hash-ref operator-impls impl)) (case field + [(vars) (context-vars (operator-impl-ctx info))] [(itype) (context-var-reprs (operator-impl-ctx info))] [(otype) (context-repr (operator-impl-ctx info))] [(spec) (operator-impl-spec info)] [(fpcore) (operator-impl-fpcore info)] [(fl) (operator-impl-fl info)])) -;; Like `operator-all-impls`, but filters for only active implementations. -(define (operator-active-impls name) - (filter (curry set-member? active-operator-impls) (operator-all-impls name))) +;; Returns all operator implementations. +(define (all-operator-impls) + (sort (hash-keys operator-impls) symboloperator name) - (unless (hash-has-key? operator-impls name) - (raise-herbie-missing-error "Unknown operator implementation ~a" name)) - (define impl (hash-ref operator-impls name)) - (operator-name (operator-impl-op impl))) +;; Returns all active operator implementations. +(define (all-active-operator-impls) + (sort (set->list active-operator-impls) symbolrepr orepr) +nan.bf)) - (procedure-rename (lambda pt - (define-values (_ exs) (real-apply compiler pt)) - (if exs (first exs) fail)) - (sym-append 'synth: name))) - ;; Get floating-point implementation - (define fl-proc - (cond - [fl fl] ; user-provided implementation - [(operator-accelerator? new-op) ; Rival-synthesized accelerator implementation - (match-define `(,(or 'lambda 'λ) (,vars ...) ,body) (operator-spec op-info)) - (synth-fl-impl name vars body)] - [else ; Rival-synthesized operator implementation - (define vars (build-list (length ireprs) (lambda (i) (string->symbol (format "x~a" i))))) - (synth-fl-impl name vars `(,new-op ,@vars))])) +;; Collects all operators + +;; Checks a specification. +(define (check-spec! name ctx spec) + (define (bad! fmt . args) + (error name "~a in `~a`" (apply format fmt args) spec)) + + (define (type-error! expr actual-ty expect-ty) + (bad! "expression `~a` has type `~a`, expected `~a`" expr actual-ty expect-ty)) + + (match-define (context vars repr var-reprs) ctx) + (define itypes (map representation-type var-reprs)) + (define otype (representation-type repr)) + + (unless (= (length itypes) (length vars)) + (bad! "arity mismatch; expected ~a, got ~a" (length itypes) (length vars))) + + (define env (map cons vars itypes)) + (define actual-ty + (let type-of ([expr spec]) + (match expr + [(? number?) 'real] + [(? symbol?) + (cond + [(assq expr env) + => + cdr] + [else (bad! "unbound variable `~a`" expr)])] + [`(if ,cond ,ift ,iff) + (define cond-ty (type-of cond)) + (unless (equal? cond-ty 'bool) + (type-error! cond cond-ty 'bool)) + (define ift-ty (type-of ift)) + (define iff-ty (type-of iff)) + (unless (equal? ift-ty iff-ty) + (type-error! iff iff-ty ift-ty)) + ift-ty] + [`(,op ,args ...) + (unless (operator-exists? op) + (bad! "at `~a`, `~a` not an operator" expr op)) + (define itypes (operator-info op 'itype)) + (unless (= (length itypes) (length args)) + (bad! "arity mismatch at `~a`: expected `~a`, got `~a`" + expr + (length itypes) + (length args))) + (for ([arg (in-list args)] + [itype (in-list itypes)]) + (define arg-ty (type-of arg)) + (unless (equal? itype arg-ty) + (type-error! arg arg-ty itype))) + (operator-info op 'otype)] + [_ (bad! "expected an expression, got `~a`" expr)]))) + (unless (equal? actual-ty otype) + (type-error! spec actual-ty otype))) + +; Registers an operator implementation `name` with context `ctx` and spec `spec. +; Can optionally specify a floating-point implementation and fpcore translation. +(define/contract (register-operator-impl! name ctx spec #:fl [fl-proc #f] #:fpcore [fpcore #f]) + (->* (symbol? context? any/c) (#:fl (or/c procedure? #f) #:fpcore any/c) void?) + ; check specification + (check-spec! name ctx spec) + (define vars (context-vars ctx)) + ; synthesize operator (if the spec contains exactly one operator) + (define op + (match spec + [(list op (or (? number?) (? symbol?)) ...) op] + [_ #f])) + ; check or synthesize FPCore translatin + (define fpcore* + (cond + [fpcore ; provided -> TODO: check free variables, props + (match fpcore + [`(! ,props ... (,op ,args ...)) + (unless (even? (length props)) + (error 'register-operator-impl! "~a: umatched property in ~a" name fpcore)) + (unless (symbol? op) + (error 'register-operator-impl! "~a: expected symbol `~a`" name op)) + (for ([arg (in-list args)]) + (unless (or (symbol? arg) (number? arg)) + (error 'register-operator-impl! "~a: expected terminal `~a`" name arg)))] + [`(,op ,args ...) + (unless (symbol? op) + (error 'register-operator-impl! "~a: expected symbol `~a`" name op)) + (for ([arg (in-list args)]) + (unless (or (symbol? arg) (number? arg)) + (error 'register-operator-impl! "~a: expected terminal `~a`" name arg)))] + [_ (error 'register-operator-impl! "Invalid fpcore for ~a: ~a" name fpcore)]) + fpcore] + [else ; not provided => need to generate it + (define repr (context-repr ctx)) + (define bool-repr (get-representation 'bool)) + (if (equal? repr bool-repr) + `(,op ,@vars) ; special case: boolean-valued operations do not need a precision annotation + `(! :precision ,(representation-name repr) (,op ,@vars)))])) + ; check or synthesize floating-point operation + (define fl-proc* + (cond + [fl-proc ; provided => check arity + (unless (procedure-arity-includes? fl-proc (length vars) #t) + (error 'register-operator-impl! + "~a: procedure does not accept ~a arguments" + name + (length vars))) + fl-proc] + [else ; need to generate + (define compiler (make-real-compiler (list spec) (list ctx))) + (define fail ((representation-bf->repr (context-repr ctx)) +nan.bf)) + (procedure-rename (lambda pt + (define-values (_ exs) (real-apply compiler pt)) + (if exs (first exs) fail)) + name)])) ; update tables - (define impl (operator-impl name op-info (context vars orepr ireprs) spec fpcore fl-proc)) - (hash-set! operator-impls name impl) - (hash-update! operators-to-impls new-op (curry cons name))) + (define impl (operator-impl name ctx spec fpcore* fl-proc*)) + (hash-set! operator-impls name impl)) (define-syntax (define-operator-impl stx) (define (oops! why [sub-stx #f]) (raise-syntax-error 'define-operator-impl why stx sub-stx)) (syntax-case stx (:) [(_ (id [var : repr] ...) rtype fields ...) - (let ([impl-id #'id] + (let ([id #'id] + [vars (syntax->list #'(var ...))] [fields #'(fields ...)]) - (unless (identifier? impl-id) - (oops! "impl id is not a valid identifier" impl-id)) - (for ([var (in-list (syntax->list #'(var ...)))]) + (unless (identifier? id) + (oops! "expected identifier" id)) + (for ([var (in-list vars)]) (unless (identifier? var) - (oops! "given id is not a valid identifier" var))) - (define operator #f) + (oops! "expected identifier" var))) (define spec #f) (define core #f) (define fl-expr #f) (let loop ([fields fields]) (syntax-case fields () [() - (with-syntax ([impl-id impl-id] - [operator operator] + (unless spec + (oops! "missing `#:spec` keyword")) + (with-syntax ([id id] [spec spec] [core core] [fl-expr fl-expr]) - #'(register-operator-impl! 'operator - 'impl-id - (list (cons 'var (get-representation 'repr)) ...) - (get-representation 'rtype) + #'(register-operator-impl! 'id + (context '(var ...) + (get-representation 'rtype) + (list (get-representation 'repr) ...)) + 'spec #:fl fl-expr - #:spec 'spec #:fpcore 'core))] [(#:spec expr rest ...) (cond @@ -565,108 +438,71 @@ (set! fl-expr #'expr) (loop #'(rest ...))])] [(#:fl) (oops! "expected value after keyword `#:fl`" stx)] - [(#:op name rest ...) - (cond - [operator (oops! "multiple #:op clauses" stx)] - [else - (set! operator #'name) - (loop #'(rest ...))])] - [(#:op) (oops! "expected value after keyword `#:op`" stx)] + ; bad [_ (oops! "bad syntax" fields)])))] [_ (oops! "bad syntax")])) -;; Among active implementations, looks up an implementation with -;; the operator name `name` and argument representations `ireprs`. -(define (get-parametric-operator #:all? [all? #f] name . ireprs) - (define get-impls (if all? operator-all-impls operator-active-impls)) - (let/ec k - (for/first ([impl (get-impls name)] - #:when (equal? (impl-info impl 'itype) ireprs)) - (k impl)) - (raise-herbie-missing-error - "Could not find operator implementation for ~a with ~a" - name - (string-join (map (λ (r) (format "<~a>" (representation-name r))) ireprs) " ")))) - -;; Among active implementations, looks up an implementation of -;; a constant (nullary operator) with the operator name `name` -;; and representation `repr`. -(define (get-parametric-constant name repr #:all? [all? #f]) - (define get-impls (if all? operator-all-impls operator-active-impls)) - (let/ec k - (for ([impl (get-impls name)]) - (define rtype (impl-info impl 'otype)) - (when (or (equal? rtype repr) (equal? (representation-type rtype) 'bool)) - (k impl))) - (raise-herbie-missing-error "Could not find constant implementation for ~a with ~a" - name - (format "<~a>" (representation-name repr))))) - -(module+ test - (require math/flonum - math/bigfloat - (submod "types.rkt" internals)) - - (define (shift bits fn) - (define shift-val (expt 2 bits)) - (λ (x) (fn (- x shift-val)))) - - (define (unshift bits fn) - (define shift-val (expt 2 bits)) - (λ (x) (+ (fn x) shift-val))) - - ; for testing: also in /reprs/bool.rkt - (define-representation (bool bool boolean?) - identity - identity - (λ (x) (= x 0)) - (λ (x) (if x 1 0)) - 1 - (const #f)) - - ; for testing: also in /reprs/binary64.rkt - (define-representation (binary64 real flonum?) - bigfloat->flonum - bf - (shift 63 ordinal->flonum) - (unshift 63 flonum->ordinal) - 64 - (conjoin number? nan?)) - - ; correctly-rounded log1pmd(x) for binary64 - (define-operator-impl (log1pmd.f64 [x : binary64]) - binary64 - #:spec (- (log1p x) (log1p (neg x))) - #:fpcore (! :precision binary64 (log1pmd x)) - #:op log1pmd) - ; correctly-rounded sin(x) for binary64 - (define-operator-impl (sin.acc.f64 [x : binary64]) - binary64 - #:spec (sin x) - #:fpcore (! :precision binary64 (sin x)) - #:fl sin) - - (define log1pmd-proc (impl-info 'log1pmd.f64 'fl)) - (define log1pmd-vals '((0.0 . 0.0) (0.5 . 1.0986122886681098) (-0.5 . -1.0986122886681098))) - (for ([(pt out) (in-dict log1pmd-vals)]) - (check-equal? (log1pmd-proc pt) out (format "log1pmd(~a) = ~a" pt out))) - - (define sin-proc (impl-info 'sin.acc.f64 'fl)) - (define sin-vals '((0.0 . 0.0) (1.0 . 0.8414709848078965) (-1.0 . -0.8414709848078965))) - (for ([(pt out) (in-dict sin-vals)]) - (check-equal? (sin-proc pt) out (format "sin(~a) = ~a" pt out))) - - (void)) +;; Extracts the `fpcore` field of an operator implementation +;; as a property dictionary and expression. +(define (impl->fpcore impl) + (match (impl-info impl 'fpcore) + [(list '! props ... body) (values (props->dict props) body)] + [body (values '() body)])) + +;; For a given FPCore operator, rounding context, and input representations, +;; finds the best operator implementation. Panics if none can be found. +(define/contract (get-fpcore-impl op prop-dict ireprs #:impls [all-impls (all-active-operator-impls)]) + (->* (symbol? prop-dict/c (listof representation?)) (#:impls (listof symbol?)) symbol?) + ; gather all implementations that have the same spec, input representations, + ; and its FPCore translation has properties that are found in `prop-dict` + (define impls + (reap [sow] + (for ([impl (in-list all-impls)]) + (when (equal? ireprs (impl-info impl 'itype)) + (define-values (prop-dict* expr) (impl->fpcore impl)) + (define pattern (cons op (map (lambda (_) (gensym)) ireprs))) + (when (and (andmap (lambda (prop) (member prop prop-dict)) prop-dict*) + (pattern-match pattern expr)) + (sow impl)))))) + ; check that we have any matching impls + (when (null? impls) + (raise-herbie-missing-error + "No implementation for `~a` under rounding context `~a` with types `~a`" + op + prop-dict + (string-join (map (λ (r) (format "<~a>" (representation-name r))) ireprs) " "))) + ; ; we rank implementations and select the highest scoring one + (define scores + (for/list ([impl (in-list impls)]) + (define-values (prop-dict* _) (impl->fpcore impl)) + (define num-matching (count (lambda (prop) (member prop prop-dict*)) prop-dict)) + (cons num-matching (- (length prop-dict) num-matching)))) + ; select the best implementation + ; sort first by the number of matched properties, + ; then tie break on the number of extraneous properties + (match-define (list (cons _ best) _ ...) + (sort (map cons scores impls) + (lambda (x y) + (cond + [(> (car x) (car y)) #t] + [(< (car x) (car y)) #f] + [else (> (cdr x) (cdr y))])) + #:key car)) + best) ;; Casts and precision changes (define (cast-impl? x) - (and (symbol? x) (set-member? (operator-all-impls 'cast) x))) + (and (symbol? x) + (impl-exists? x) + (match (impl-info x 'vars) + [(list v) + #:when (eq? (impl-info x 'spec) v) + #t] + [_ #f]))) -(define (get-cast-impl irepr orepr #:all? [all? #f]) - (define get-impls (if all? operator-all-impls operator-active-impls)) - (for/or ([name (get-impls 'cast)]) - (and (equal? (impl-info name 'otype) orepr) (equal? (first (impl-info name 'itype)) irepr) name))) +(define (get-cast-impl irepr orepr #:impls [impls (all-active-operator-impls)]) + (get-fpcore-impl 'cast (repr->prop orepr) (list irepr) #:impls impls)) ; Similar to representation generators, conversion generators ; allow Herbie to query plugins for optimized implementations @@ -699,17 +535,15 @@ (define (constant-operator? op) (and (symbol? op) (or (and (hash-has-key? operators op) (null? (operator-itype (hash-ref operators op)))) - (and (hash-has-key? operator-impls op) - (null? (context-vars (operator-impl-ctx (hash-ref operator-impls op)))))))) + (and (hash-has-key? operator-impls op) (null? (impl-info op 'vars)))))) (define (variable? var) (and (symbol? var) (or (not (hash-has-key? operators var)) (not (null? (operator-itype (hash-ref operators var))))) - (or (not (hash-has-key? operator-impls var)) - (not (null? (context-vars (operator-impl-ctx (hash-ref operator-impls var)))))))) + (or (not (hash-has-key? operator-impls var)) (not (null? (impl-info var 'vars)))))) -;; Floating-point expressions require that number +;; Floating-point expressions require that numbers ;; be rounded to a particular precision. (struct literal (value precision) #:prefab) diff --git a/src/syntax/test-syntax.rkt b/src/syntax/test-syntax.rkt new file mode 100644 index 000000000..6004c713c --- /dev/null +++ b/src/syntax/test-syntax.rkt @@ -0,0 +1,57 @@ +#lang racket + +(require "load-plugin.rkt" + "syntax.rkt" + "types.rkt" + (submod "syntax.rkt" internals)) + +(module+ test + (require rackunit + math/bigfloat) + + (load-herbie-builtins) + + ; log1pmd(x) = log1p(x) - log1p(-x) + + (define-operator-impl (log1pmd.f64 [x : binary64]) + binary64 + #:spec (- (log (+ 1 x)) (log (+ 1 (neg x)))) + #:fpcore (! :precision binary64 (log1pmd x))) + + (define log1pmd-proc (impl-info 'log1pmd.f64 'fl)) + (define log1pmd-vals '((0.0 . 0.0) (0.5 . 1.0986122886681098) (-0.5 . -1.0986122886681098))) + (for ([(pt out) (in-dict log1pmd-vals)]) + (check-equal? (log1pmd-proc pt) out (format "log1pmd(~a) = ~a" pt out))) + + ; fast sine + + (define-operator-impl (fast-sin.f64 [x : binary64]) + binary64 + #:spec (sin x) + #:fpcore (! :precision binary64 :math-library fast (sin x)) + #:fl (lambda (x) + (parameterize ([bf-precision 12]) + (bigfloat->flonum (bfsin (bf x)))))) + + (define sin-proc (impl-info 'fast-sin.f64 'fl)) + (define sin-vals '((0.0 . 0.0) (1.0 . 0.841552734375) (-1.0 . -0.841552734375))) + (for ([(pt out) (in-dict sin-vals)]) + (check-equal? (sin-proc pt) out (format "sin(~a) = ~a" pt out))) + + ; get-fpcore-impl + + (define f64 (get-representation 'binary64)) + (define (get-impl op props itypes) + (get-fpcore-impl op props itypes #:impls (all-operator-impls))) + + (check-equal? (get-impl '+ '((:precision . binary64)) (list f64 f64)) '+.f64) + (check-equal? (get-impl '+ '((:precision . binary64)) (list f64 f64)) '+.f64) + (check-equal? (get-impl '+ '((:precision . binary64) (:description . "test")) (list f64 f64)) + '+.f64) + + (check-equal? (get-impl 'log1pmd '((:precision . binary64)) (list f64)) 'log1pmd.f64) + (check-equal? (get-impl 'sin '((:precision . binary64)) (list f64)) 'sin.f64) + (check-equal? (get-impl 'sin '((:precision . binary64) (:math-library . fast)) (list f64)) + 'fast-sin.f64) + + (void)) diff --git a/src/syntax/type-check.rkt b/src/syntax/type-check.rkt index 66f1f4c3d..bf22dbbb0 100644 --- a/src/syntax/type-check.rkt +++ b/src/syntax/type-check.rkt @@ -9,40 +9,45 @@ (define (assert-program-typed! stx) (define-values (vars props body) (match (syntax-e stx) - [(list (app syntax-e 'FPCore) (app syntax-e name) (app syntax-e (list vars ...)) props ... body) + [(list (app syntax-e 'FPCore) _ (app syntax-e (list vars ...)) props ... body) (values vars props body)] [(list (app syntax-e 'FPCore) (app syntax-e (list vars ...)) props ... body) (values vars props body)])) - (define props* - (let loop ([props props]) - (match props - [(list) (list)] - [(list (app syntax-e prop) value rest ...) - (cons (cons prop (syntax->datum value)) (loop rest))]))) - (define type (get-representation (dict-ref props* ':precision 'binary64))) - (assert-expression-type! body - type - #:env (for/hash ([var vars]) - (values (syntax-e var) type)))) - -(define (assert-expression-type! stx expected-rtype #:env [env #hash()]) - (define errs - (reap [sow] - (define (error! stx fmt . args) - (sow (cons stx - (apply format - fmt - (for/list ([arg args]) - (if (representation? arg) (representation-name arg) arg)))))) - (define actual-rtype (expression->type stx env expected-rtype error!)) - (unless (equal? expected-rtype actual-rtype) - (error! stx "Expected program of type ~a, got type ~a" expected-rtype actual-rtype)))) + + (define default-dict `((:precision . ,(*default-precision*)))) + (define prop-dict (apply dict-set* default-dict (map syntax->datum props))) + (define prec (dict-ref prop-dict ':precision)) + + (define-values (var-names var-precs) + (for/lists (var-names var-precs) + ([var (in-list vars)]) + (match (syntax->datum var) + [(list '! props ... name) + (define prop-dict (props->dict props)) + (define arg-prec (dict-ref prop-dict ':precision prec)) + (values name arg-prec)] + [(? symbol? name) (values name prec)]))) + + (define ctx (context var-names (get-representation prec) (map get-representation var-precs))) + (assert-expression-type! body prop-dict ctx)) + +(define (assert-expression-type! stx props ctx) + (define errs '()) + (define (error! stx fmt . args) + (define args* + (for/list ([arg (in-list args)]) + (match arg + [(? representation?) (representation-name arg)] + [_ arg]))) + (set! errs (cons (cons stx (apply format fmt args*)) errs))) + + (define repr (expression->type stx props ctx error!)) + (unless (equal? repr (context-repr ctx)) + (error! stx "Expected program of type ~a, got type ~a" (context-repr ctx) repr)) + (unless (null? errs) (raise-herbie-syntax-error "Program has type errors" #:locations errs))) -(define (repr-has-type? repr type) - (and repr (equal? (representation-type repr) type))) - (define (application->string op types) (format "(~a ~a)" op @@ -50,170 +55,110 @@ (if t (format "<~a>" (representation-name t)) "")) " "))) -(define (resolve-missing-op! stx op actual-types error!) - (define active-impls (operator-active-impls op)) - (cond - [(null? active-impls) - ; no active implementations - (define all-impls (operator-all-impls op)) - (cond - ; no implementations at all - [(null? all-impls) (error! stx "No implementations of ~a found; check plugins" op)] - [else - ; found in-active implementations - (error! stx - "No implementations of `~a` in platform, but found inactive implementations ~a" - op - (string-join (for/list ([impl all-impls]) - (application->string op (impl-info impl 'itype))) - " or "))])] - [else - ; active implementations were found - (error! stx - "Invalid arguments to ~a; found ~a, but got ~a" - op - (string-join (for/list ([impl active-impls]) - (application->string op (impl-info impl 'itype))) - " or ") - (application->string op actual-types))])) - -(define (expression->type stx env type error!) - (match stx - [#`,(? number?) type] - [#`,(? constant-operator? x) - (define cnst* - (with-handlers ([exn:fail:user:herbie:missing? (const #f)]) - (get-parametric-constant x type))) - (cond - [cnst* (impl-info cnst* 'otype)] - [else - ; implementation not supported so try to report a useful error - (define active-impls (operator-active-impls x)) - (cond - [(null? active-impls) - ; no active implementations - (define all-impls (operator-all-impls x)) - (cond - ; no implementations at all - [(null? all-impls) (error! stx "No implementations of ~a found; check plugins" x)] - [else - ; found in-active implementations - (error! stx - (string-append "No implementations of `~a` in platform, " - "but found inactive implementations for ~a") - x - (string-join (for/list ([impl all-impls]) - (format "<~a>" (representation-name (impl-info impl 'otype)))) - " or "))])] - [else - ; active implementations were found - (error! stx - "No implementation for ~a with ~a; found implementations for ~a" - x - (format "<~a>" (representation-name type)) - (string-join (for/list ([impl active-impls]) - (format "<~a>" (representation-name (impl-info impl 'otype)))) - " or "))]) - type])] - [#`,(? variable? x) - (define vtype (dict-ref env x)) - (unless (or (equal? type vtype) (repr-has-type? vtype 'bool)) - (error! stx "Expected a variable of type ~a, but got ~a" type vtype)) - vtype] - [#`(let ([,id #,expr] ...) #,body) - (define env2 - (for/fold ([env2 env]) - ([var id] - [val expr]) - (dict-set env2 var (expression->type val env type error!)))) - (expression->type body env2 type error!)] - [#`(let* ([,id #,expr] ...) #,body) - (define env2 - (for/fold ([env2 env]) - ([var id] - [val expr]) - (dict-set env2 var (expression->type val env2 type error!)))) - (expression->type body env2 type error!)] - [#`(if #,branch #,ifstmt #,elsestmt) - (define branch-type (expression->type branch env type error!)) - (unless (repr-has-type? branch-type 'bool) - (error! stx "If statement has non-boolean type ~a for branch" branch-type)) - (define ifstmt-type (expression->type ifstmt env type error!)) - (define elsestmt-type (expression->type elsestmt env type error!)) - (unless (equal? ifstmt-type elsestmt-type) - (error! stx - "If statement has different types for if (~a) and else (~a)" - ifstmt-type - elsestmt-type)) - ifstmt-type] - [#`(! #,props ... #,body) - (define props* (apply hash-set* (hash) (map syntax-e props))) - (cond - [(hash-has-key? props* ':precision) - (define itype (get-representation (hash-ref props* ':precision))) - (define rtype (expression->type body env itype error!)) - (unless (equal? rtype itype) - (error! stx "Annotation promised precision ~a, but got ~a" itype rtype)) - type] - [else (expression->type body env type error!)])] - [#`(- #,arg) - ; special case: unary negation - (define actual-type (expression->type arg env type error!)) - (define op* - (with-handlers ([exn:fail:user:herbie:missing? (const #f)]) - (get-parametric-operator 'neg actual-type))) - (cond - [op* (impl-info op* 'otype)] - [else - (resolve-missing-op! stx '- (list actual-type) error!) - actual-type])] - [#`(,(? operator-exists? op) #,exprs ...) - (define actual-types - (for/list ([arg exprs]) - (expression->type arg env type error!))) - (define op* - (with-handlers ([exn:fail:user:herbie:missing? (const #f)]) - (apply get-parametric-operator op actual-types))) - (cond - [op* (impl-info op* 'otype)] - [else - ; implementation not supported so try to report a useful error - (resolve-missing-op! stx op actual-types error!) - type])] - [#`(,(? (curry hash-has-key? (*functions*)) fname) #,exprs ...) - (match-define (list vars prec _) (hash-ref (*functions*) fname)) - (define repr (get-representation prec)) - (define actual-types - (for/list ([arg exprs]) - (expression->type arg env type error!))) - (define expected (map (const repr) vars)) - (if (andmap equal? actual-types expected) - repr - (begin - (error! stx - "Invalid arguments to ~a; expects ~a but got ~a" - fname - fname - (application->string fname expected) - (application->string fname actual-types)) - type))])) +(define (expression->type stx prop-dict ctx error!) + (let loop ([stx stx] + [prop-dict prop-dict] + [ctx ctx]) + (match stx + [#`,(? number?) (get-representation (dict-ref prop-dict ':precision))] + [#`,(? variable? x) (context-lookup ctx x)] + [#`,(? constant-operator? op) + (define impl + (with-handlers ([exn:fail:user:herbie:missing? (const #f)]) + (get-fpcore-impl op prop-dict '()))) + (match impl + [#f ; no implementation found + (error! stx "No implementation of `~a` in platform for context `~a`" op prop-dict) + (get-representation (dict-ref prop-dict ':precision))] + [_ (impl-info impl 'otype)])] + [#`(let ([,ids #,exprs] ...) #,body) + (define ctx* + (for/fold ([ctx* ctx]) + ([id (in-list ids)] + [expr (in-list exprs)]) + (context-extend ctx* id (loop expr prop-dict ctx)))) + (loop body prop-dict ctx*)] + [#`(let* ([,ids #,exprs] ...) #,body) + (define ctx* + (for/fold ([ctx* ctx]) + ([id (in-list ids)] + [expr (in-list exprs)]) + (context-extend ctx* id (loop expr prop-dict ctx*)))) + (loop body prop-dict ctx*)] + [#`(if #,branch #,ifstmt #,elsestmt) + (define cond-ctx (struct-copy context ctx [repr (get-representation 'bool)])) + (define cond-repr (loop branch prop-dict cond-ctx)) + (unless (equal? (representation-type cond-repr) 'bool) + (error! stx "If statement has non-boolean type ~a for branch" cond-repr)) + (define ift-repr (loop ifstmt prop-dict ctx)) + (define iff-repr (loop elsestmt prop-dict ctx)) + (unless (equal? ift-repr iff-repr) + (error! stx "If statement has different types for if (~a) and else (~a)" ift-repr iff-repr)) + ift-repr] + [#`(! #,props ... #,body) (loop body (apply dict-set prop-dict props) ctx)] + [#`(,(? (curry hash-has-key? (*functions*)) fname) #,args ...) + ; TODO: inline functions expect uniform types, this is clearly wrong + (match-define (list vars prec _) (hash-ref (*functions*) fname)) + (define repr (get-representation prec)) + (define ireprs (map (lambda (arg) (loop arg prop-dict ctx)) args)) + (define expected (map (const repr) vars)) + (unless (andmap equal? ireprs expected) + (error! stx + "Invalid arguments to ~a; expects ~a but got ~a" + fname + fname + (application->string fname expected) + (application->string fname ireprs))) + repr] + [#`(cast #,arg) + (define irepr (loop arg prop-dict ctx)) + (define repr (get-representation (dict-ref prop-dict ':precision))) + (cond + [(equal? irepr repr) repr] + [else + (define impl + (with-handlers ([exn:fail:user:herbie:missing? (const #f)]) + (get-fpcore-impl 'cast prop-dict (list irepr)))) + (match impl + [#f ; no implementation found + (error! stx + "No implementation of `~a` in platform for context `~a`" + (application->string 'cast (list irepr)) + prop-dict) + (get-representation (dict-ref prop-dict ':precision))] + [_ (impl-info impl 'otype)])])] + [#`(,(? symbol? op) #,args ...) + (define ireprs (map (lambda (arg) (loop arg prop-dict ctx)) args)) + (define impl + (with-handlers ([exn:fail:user:herbie:missing? (const #f)]) + (get-fpcore-impl op prop-dict ireprs))) + (match impl + [#f ; no implementation found + (error! stx + "No implementation of `~a` in platform for context `~a`" + (application->string op ireprs) + prop-dict) + (get-representation (dict-ref prop-dict ':precision))] + [_ (impl-info impl 'otype)])]))) (module+ test (require rackunit) (require "load-plugin.rkt") (load-herbie-builtins) - (define (fail stx msg . args) + (define (fail! stx msg . args) (error (apply format msg args) stx)) - (define (check-types env-type rtype expr #:env [env #hash()]) - (check-equal? (expression->type expr env env-type fail) rtype)) + (define (check-types env-type rtype expr #:env [env '()]) + (define ctx (context (map car env) env-type (map cdr env))) + (define repr (expression->type expr (repr->prop env-type) ctx fail!)) + (check-equal? repr rtype)) - (define (check-fails type expr #:env [env #hash()]) - (check-equal? (let ([v #f]) - (expression->type expr env type (lambda _ (set! v #t))) - v) - #t)) + (define (check-fails type expr #:env [env '()]) + (define fail? #f) + (define ctx (context (map car env) type (map cdr env))) + (expression->type expr (repr->prop type) ctx (lambda _ (set! fail? #t))) + (check-true fail?)) (define (get-representation 'bool)) (define (get-representation 'binary64)) diff --git a/src/syntax/types.rkt b/src/syntax/types.rkt index 0565ecc98..e0c8d98e3 100644 --- a/src/syntax/types.rkt +++ b/src/syntax/types.rkt @@ -8,6 +8,7 @@ get-representation repr-exists? repr->symbol + repr->prop (struct-out context) *context* context-extend @@ -49,6 +50,12 @@ (define repr-name (representation-name repr)) (string->symbol (string-replace* (~a repr-name) replace-table))) +;; Converts a representation into a rounding property +(define (repr->prop repr) + (match (representation-type repr) + ['bool '()] + ['real (list (cons ':precision (representation-name repr)))])) + ;; Repr / operator generation ;; Some plugins might define 'parameterized' reprs (e.g. fixed point with ;; m integer and n fractional bits). Since defining an infinite number of reprs diff --git a/src/utils/common.rkt b/src/utils/common.rkt index dde147d9f..a51507868 100644 --- a/src/utils/common.rkt +++ b/src/utils/common.rkt @@ -22,12 +22,16 @@ quasisyntax dict sym-append + gen-vars string-replace* format-time format-bits format-accuracy format-cost web-resource + prop-dict/c + props->dict + dict->props (all-from-out "../config.rkt")) (module+ test @@ -268,11 +272,38 @@ (define (web-resource [name #f]) (if name (build-path web-resource-path name) web-resource-path)) -(define (sym-append . args) - (string->symbol (apply string-append (map ~a args)))) - (define/contract (string-replace* str changes) (-> string? (listof (cons/c string? string?)) string?) (for/fold ([str str]) ([change changes]) (match-define (cons from to) change) (string-replace str from to))) + +;; Symbol generation + +(define (sym-append . args) + (string->symbol (apply string-append (map ~a args)))) + +;; Generates a list of variables names. +(define/contract (gen-vars n) + (-> natural? (listof symbol?)) + (build-list n (lambda (i) (string->symbol (format "x~a" i))))) + +;; FPCore properties + +(define prop-dict/c (listof (cons/c symbol? any/c))) + +;; Prop list to dict +(define/contract (props->dict props) + (-> list? (listof (cons/c symbol? any/c))) + (let loop ([props props] + [dict '()]) + (match props + [(list key val rest ...) (loop rest (dict-set dict key val))] + [(list key) (error 'props->dict "unmatched key" key)] + [(list) dict]))) + +(define/contract (dict->props prop-dict) + (-> (listof (cons/c symbol? any/c)) list?) + (apply append + (for/list ([(k v) (in-dict prop-dict)]) + (list k v))))