Skip to content

Commit

Permalink
a preparation for rewriting insert! function
Browse files Browse the repository at this point in the history
  • Loading branch information
AYadrov committed Aug 30, 2024
1 parent c3d7996 commit 138b9c0
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 21 deletions.
11 changes: 10 additions & 1 deletion src/core/batch.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
deref ; Batchref -> Expr
batch-replace ; Batch -> Lambda -> Batch
egg-nodes->batch ; Nodes -> Spec-maps -> Batch -> (Listof Root)
batchref->expr) ; Batchref -> Expr
batchref->expr ; Batchref -> Expr
batch-extract-exprs) ; Batch -> (Listof Root) -> (Listof Expr)

;; This function defines the recursive structure of expressions

Expand Down Expand Up @@ -80,6 +81,14 @@
(timeline-push! 'compiler size (batch-length final)))
final)

(define (batch-extract-exprs b roots)
(define exprs (make-vector (batch-length b)))
(for ([node (in-vector (batch-nodes b))]
[idx (in-naturals)])
(vector-set! exprs idx (expr-recurse node (lambda (x) (vector-ref exprs x)))))
(for/list ([root roots])
(vector-ref exprs root)))

(define (batch->progs b)
(define exprs (make-vector (batch-length b)))
(for ([node (in-vector (batch-nodes b))]
Expand Down
36 changes: 28 additions & 8 deletions src/core/egg-herbie.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@

; Adds expressions returning the root ids
; TODO: take a batch rather than list of expressions
(define (egraph-add-exprs egg-data exprs ctx)
(define (egraph-add-exprs egg-data batch roots ctx)
(match-define (egraph-data ptr herbie->egg-dict egg->herbie-dict id->spec) egg-data)

; lookups the egg name of a variable
Expand Down Expand Up @@ -129,6 +129,21 @@
; expression cache
(define expr->id (make-hash))

;--------------------------------- EXPERIMENTAL
#;(define (insert-batchref! ref)
(match ref
[(batchref b idx)
(define nodes (batch-nodes b))
(insert! idx #t)]
[_ (error insert-batchref! "Pointer is not a batchref!")]))

#;(define (insert! ref [root #f])
(match ref
[(batchref b idx) ...]
[(define node (match))]))

;---------------------------------

; expr -> natural
; inserts an expresison into the e-graph, returning its e-class id.
(define (insert! expr [root? #f])
Expand Down Expand Up @@ -156,6 +171,7 @@
[root? (insert-node! node #t)]
[else (hash-ref! expr->id node (lambda () (insert-node! node #f)))]))

(define exprs (batch-extract-exprs batch roots))
(for/list ([expr (in-list exprs)])
(insert! expr #t)))

Expand Down Expand Up @@ -224,7 +240,8 @@
(egraph_find (egraph-data-egraph-pointer egraph-data) id))

(define (egraph-expr-equal? egraph-data expr goal ctx)
(match-define (list id1 id2) (egraph-add-exprs egraph-data (list expr goal) ctx))
(define batch (progs->batch (list expr goal)))
(match-define (list id1 id2) (egraph-add-exprs egraph-data batch (batch-roots batch) ctx))
(= id1 id2))

;; returns a flattened list of terms or #f if it failed to expand the proof due to budget
Expand Down Expand Up @@ -1195,12 +1212,12 @@
(loop (sub1 num-iters)))]
[else (values egg-graph iteration-data)])))

(define (egraph-run-schedule exprs schedule ctx)
(define (egraph-run-schedule batch roots schedule ctx)
; allocate the e-graph
(define egg-graph (make-egraph))

; insert expressions into the e-graph
(define root-ids (egraph-add-exprs egg-graph exprs ctx))
(define root-ids (egraph-add-exprs egg-graph batch roots ctx))

; run the schedule
(define rule-apps (make-hash))
Expand Down Expand Up @@ -1247,7 +1264,7 @@

;; Herbie's version of an egg runner.
;; Defines parameters for running rewrite rules with egg
(struct egg-runner (exprs reprs schedule ctx)
(struct egg-runner (batch roots reprs schedule ctx)
#:transparent ; for equality
#:methods gen:custom-write ; for abbreviated printing
[(define (write-proc alt port mode)
Expand All @@ -1264,7 +1281,7 @@
;; - scheduler: `(scheduler . <name>)` [default: backoff]
;; - `simple`: run all rules without banning
;; - `backoff`: ban rules if the fire too much
(define (make-egg-runner exprs reprs schedule #:context [ctx (*context*)])
(define (make-egg-runner batch roots reprs schedule #:context [ctx (*context*)])
(define (oops! fmt . args)
(apply error 'verify-schedule! fmt args))
; verify the schedule
Expand All @@ -1285,7 +1302,7 @@
[_ (oops! "in instruction `~a`, unknown parameter `~a`" instr param)]))]
[_ (oops! "expected `(<rules> . <params>)`, got `~a`" instr)]))
; make the runner
(egg-runner exprs reprs schedule ctx))
(egg-runner batch roots reprs schedule ctx))

;; Runs egg using an egg runner.
;;
Expand All @@ -1297,7 +1314,10 @@
;; Run egg using runner
(define ctx (egg-runner-ctx runner))
(define-values (root-ids egg-graph)
(egraph-run-schedule (egg-runner-exprs runner) (egg-runner-schedule runner) ctx))
(egraph-run-schedule (egg-runner-batch runner)
(egg-runner-roots runner)
(egg-runner-schedule runner)
ctx))
; Perform extraction
(match cmd
[`(single . ,extractor) ; single expression extraction
Expand Down
4 changes: 3 additions & 1 deletion src/core/localize.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@
(define lowering-rules (platform-lowering-rules))

; egg runner (2-phases for real rewrites and implementation selection)
(define batch (progs->batch progs))
(define runner
(make-egg-runner progs
(make-egg-runner batch
(batch-roots batch)
reprs
`((,lifting-rules . ((iteration . 1) (scheduler . simple)))
(,rules . ((node . ,(*node-limit*))))
Expand Down
6 changes: 4 additions & 2 deletions src/core/mainloop.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
"preprocess.rkt"
"programs.rkt"
"../utils/timeline.rkt"
"soundiness.rkt")
"soundiness.rkt"
"batch.rkt")
(provide run-improve!)

;; The Herbie main loop goes through a simple iterative process:
Expand Down Expand Up @@ -374,7 +375,8 @@
; egg runner
(define exprs (map alt-expr alts))
(define reprs (map (lambda (expr) (repr-of expr (*context*))) exprs))
(define runner (make-egg-runner exprs reprs schedule))
(define batch (progs->batch exprs))
(define runner (make-egg-runner batch (batch-roots batch) reprs schedule))

; run egg
(define simplified
Expand Down
7 changes: 5 additions & 2 deletions src/core/patch.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
`((,lowering-rules . ((iteration . 1) (scheduler . simple))))))

; run egg
(define runner (make-egg-runner (map alt-expr approxs) reprs schedule))
(define batch (progs->batch (map alt-expr approxs)))
(define runner (make-egg-runner batch (batch-roots batch) reprs schedule))
(define simplification-options
(simplify-batch runner
(typed-egg-extractor
Expand Down Expand Up @@ -133,7 +134,9 @@
(define exprs (map alt-expr altns))
(define reprs (map (curryr repr-of (*context*)) exprs))
(timeline-push! 'inputs (map ~a exprs))
(define runner (make-egg-runner exprs reprs schedule #:context (*context*)))

(define batch (progs->batch exprs))
(define runner (make-egg-runner batch (batch-roots batch) reprs schedule #:context (*context*)))
; variantss is a (listof roots))
(define rootss (run-egg runner `(multi . ,extractor)))

Expand Down
11 changes: 8 additions & 3 deletions src/core/preprocess.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
"programs.rkt"
"points.rkt"
"../utils/timeline.rkt"
"../utils/float.rkt")
"../utils/float.rkt"
"batch.rkt")

(provide find-preprocessing
preprocess-pcontext
Expand Down Expand Up @@ -66,7 +67,8 @@
(,lowering-rules . ((iteration . 1) (scheduler . simple)))))

; egg query
(define runner (make-egg-runner (list expr) (list (context-repr ctx)) schedule))
(define batch (progs->batch (list expr)))
(define runner (make-egg-runner batch (batch-roots batch) (list (context-repr ctx)) schedule))

; run egg
(define simplified
Expand Down Expand Up @@ -100,8 +102,11 @@

;; make egg runner
(define rules (real-rules (*simplify-rules*)))

(define batch (progs->batch specs))
(define runner
(make-egg-runner specs
(make-egg-runner batch
(batch-roots batch)
(map (lambda (_) (context-repr ctx)) specs)
`((,rules . ((node . ,(*node-limit*)))))))

Expand Down
12 changes: 8 additions & 4 deletions src/core/simplify.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
"../utils/errors.rkt"
"rules.rkt"
"../utils/alternative.rkt"
"egg-herbie.rkt")
"egg-herbie.rkt"
"batch.rkt")

(provide simplify-batch)

Expand All @@ -20,13 +21,14 @@
;; if the input specifies proofs, it instead returns proofs for these expressions
(define/contract (simplify-batch runner extractor)
(-> egg-runner? procedure? (listof (listof expr?)))
(timeline-push! 'inputs (map ~a (egg-runner-exprs runner)))
(timeline-push! 'inputs
(map ~a (batch-extract-exprs (egg-runner-batch runner) (egg-runner-roots runner))))
(timeline-push! 'method "egg-herbie")

(define simplifieds (run-egg runner (cons 'single extractor)))
(define out
(for/list ([simplified simplifieds]
[expr (egg-runner-exprs runner)])
[expr (batch-extract-exprs (egg-runner-batch runner) (egg-runner-roots runner))])
(remove-duplicates (cons expr simplified))))

(timeline-push! 'outputs (map ~a (apply append out)))
Expand All @@ -48,8 +50,10 @@
(string-append "Rule failed: " (symbol->string (rule-name rule)))))

(define (test-simplify . args)
(define batch (progs->batch args))
(define runner
(make-egg-runner args
(make-egg-runner batch
(batch-roots batch)
(map (lambda (_) 'real) args)
`((,(*simplify-rules*) . ((node . ,(*node-limit*)))))))
(define extractor (typed-egg-extractor default-egg-cost-proc))
Expand Down

0 comments on commit 138b9c0

Please sign in to comment.