Skip to content

Commit

Permalink
aditya-egglog-single No key found for neg.f64
Browse files Browse the repository at this point in the history
  • Loading branch information
adityaakhileshwaran committed Nov 27, 2024
1 parent d8e8509 commit 73227b6
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 57 deletions.
145 changes: 97 additions & 48 deletions src/core/egglog-herbie.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
(oops! "expected list of rules: `~a`" rules))
(for ([param (in-list params)])
(match param
[(cons 'node (? nonnegative-integer?)) (void)] ; exists
[(cons 'node (? nonnegative-integer?)) (void)] ; exists (node 5)
[(cons 'iteration (? nonnegative-integer?)) (void)] ; exists (run 5)
[(cons 'const-fold? (? boolean?)) (void)] ; AVOID
[(cons 'scheduler mode) ; AVOID
Expand Down Expand Up @@ -184,45 +184,110 @@

; (ruleset math)
; (ruleset fp-safe)
(printf "before tag\n")
;; 2. User Rules which comes from schedule (need to be translated)
(define schedule
(define tag-schedule
(for/list ([i (in-naturals 1)] ; Start index `i` from 1
[element (in-list (egg-runner-schedule runner))])
(printf "element : ~a\n" element)
(printf "car element : ~a\n\n" (car element))
(match (car element)
['lift (void)] ; lift and lower in prelude
['lower (void)]
[_
((egglog-rewrite-rules (car element) (string-append "?tag" (number->string i)))
. (cdr element))])))
; (printf "element : ~a\n" element)
; (printf "car element : ~a\n" (car element))

(define rule-type (car element))
(define schedule-params (cdr element))

(define tag
(match rule-type
['lift 'lifting]
['lower 'lowering]
[_
(define curr-tag (string->symbol (string-append "?tag" (number->string i))))
(printf "rules ~a\n" rule-type)
(set! program (append program (egglog-rewrite-rules rule-type curr-tag)))
curr-tag]))

; (printf "tag : ~a\n\n\n" tag)

(cons tag schedule-params)))
(printf "after tag\n")

(printf "finished schedule \n")
; (printf "finished schedule \n")
; (printf "schedule ~a\n\n" tag-schedule)

(set! program (append program (map cons schedule)))
; (set! program (append program (map cons schedule)))

;; 3. Inserting expressions -> (egglog-add-exprs curr-batch (egglog-runner-ctx))
(set! program (append program (egglog-add-exprs curr-batch (egglog-runner-ctx))))
(set! program (append program (egglog-add-exprs curr-batch (egg-runner-ctx runner))))

;; 4. Running the schedule
;; TODO:

; `((run-schedule (saturate lifting) (saturate math) (saturate lowering) ))
(set! program
(append program
'((cons 'run-schedule
(cons (saturate 'lifting)
(cons (saturate 'math)
(cons (saturate 'lowering)
(for/list ([i (in-range 1 (length schedule))])
(saturate (string-append "?tag"
(number->string i)))))))))))
(define run-schedule '())

(for ([(tag schedule-params) (in-dict tag-schedule)])
(match tag
[(or 'lifting 'lowering) (set! run-schedule (append run-schedule (list (list 'saturate tag))))]
[_
; Set Tag
(set! run-schedule (append run-schedule (list (list 'saturate tag))))

; Set params
(define is-node-present (dict-ref schedule-params 'node #f))
(define is-iteration-present (dict-ref schedule-params 'iteration #f))


(match* (is-node-present is-iteration-present)
[((? nonnegative-integer? node-amt) (? nonnegative-integer? iter-amt))
(set! run-schedule (append run-schedule (list (list 'iter iter-amt) (list 'run node-amt))))]

[(#f (? nonnegative-integer? iter-amt))
(set! run-schedule (append run-schedule (list (list 'iter iter-amt))))]
[((? nonnegative-integer? node-amt) #f)
(set! run-schedule (append run-schedule (list (list 'run node-amt))))]
[(#f #f) (error "lmao")])]))

; (define run-schedule
; (for/list ([(tag schedule-params) (in-dict tag-schedule)])
; (match tag
; [(or 'lifting 'lowering) (list (list 'saturate tag))]
; [_
; (define is-node-present (dict-ref schedule-params 'node #f))
; (define is-iteration-present (dict-ref schedule-params 'iteration #f))

; (match* (is-node-present is-iteration-present)
; [((? nonnegative-integer? node-amt) (? nonnegative-integer? iter-amt))
; (list (list 'iter iter-amt) (list 'run node-amt))]

; [(#f (? nonnegative-integer? iter-amt)) (list (list 'iter iter-amt))]
; [((? nonnegative-integer? node-amt) #f) (list (list 'run node-amt))]
; [(#f #f)
; (error "lmao")])])))

; (printf "run-schedule ~a \n\n" run-schedule)

(set! program (append program `(run-schedule ,run-schedule)))
; [(#f iter-amt) (list 'iter iter-amt)]
; [(node-amt #f) (list 'run node-amt)]
; [(#f #f) (error "lmao")])]))))))

; (define egglog-scheduling-params
; (for/list ([param (in-list schedule-params)])
; (match param
; [(cons 'node num) (list 'run num)] ; exists (node 5)
; [(cons 'iteration num) (list 'run num)] ; exists (run 5)
; [_ (error "wrong")])))

; (list (list 'saturate tag) egglog-scheduling-params)]))))))
;; dict-ref
(printf "finished run-schedule \n")

;; 5. Extraction -> should just need root ids
(for ([root (egg-runner-roots runner)])
(set! program
(append program
'((extract (lower (lift (string-append "?r" (number->string root)) "binary64")))))))
; not only binary64
(printf "reached end\n")

;; 6. Call process-egglog

Expand Down Expand Up @@ -486,32 +551,16 @@
[(list op args ...) `(,(hash-ref id->e1 op) ,@(map loop args))])))

(define (egglog-rewrite-rules rules tag)
; (printf "before-rules ~a\n" rules)

(define actual-rules
(match rules
[`(quote lift) (platform-lifting-rules)]
[`(quote lower) (platform-lowering-rules)]
[_ rules]))

; (printf "after-rules ~a\n\n" actual-rules)

(define return-list
(for/list ([rule (in-list actual-rules)])
; (printf "other-rule ~a\n" rule)
(if (not (representation? (rule-output rule)))
`(rewrite ,(expr->e1-pattern (rule-input rule))
,(expr->e1-pattern (rule-output rule))
:ruleset
,tag)
`(rewrite ,(expr->e2-pattern (rule-input rule) (rule-otype rule))
,(expr->e2-pattern (rule-output rule) (rule-otype rule))
:ruleset
,tag))))

(printf "finished for loop \n")

return-list)
(for/list ([rule (in-list rules)])
(if (not (representation? (rule-output rule)))
`(rewrite ,(expr->e1-pattern (rule-input rule))
,(expr->e1-pattern (rule-output rule))
:ruleset
,tag)
`(rewrite ,(expr->e2-pattern (rule-input rule) (rule-otype rule))
,(expr->e2-pattern (rule-output rule) (rule-otype rule))
:ruleset
,tag))))

(define (egglog-add-exprs batch ctx)
(define egglog-exprs '())
Expand Down
5 changes: 3 additions & 2 deletions src/core/localize.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@
(make-egg-runner batch
(batch-roots batch)
reprs
`((lift . ((iteration . 1))) (,rules . ((node . ,(*node-limit*))))
(lower . ((iteration . 1))))))
`((lift . ((iteration . 1)))
(,rules . ((node . ,(*node-limit*))))
(lower . ((iteration . 1))))))

; run egg
(define simplified
Expand Down
2 changes: 1 addition & 1 deletion src/core/mainloop.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@

; egg schedule (only mathematical rewrites)
(define rules (append (*fp-safe-simplify-rules*) (real-rules (*simplify-rules*))))
(define schedule `((,rules . ((node . ,(*node-limit*)) (const-fold? . #f)))))
(define schedule `((,rules . ((node . ,(*node-limit*))))))

; egg runner
(define exprs (map alt-expr alts))
Expand Down
8 changes: 4 additions & 4 deletions src/core/patch.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
(define schedule
(if (flag-set? 'generate 'simplify)
; if simplify enabled, 2-phases for real rewrites and implementation selection
`((,rules . ((node . ,(*node-limit*)))) ('lower . ((iteration . 1) (scheduler . simple))))
`((,rules . ((node . ,(*node-limit*)))) (lower . ((iteration . 1))))
; if disabled, only implementation selection
`(('lower . ((iteration . 1) (scheduler . simple))))))
`((lower . ((iteration . 1))))))

(define roots
(for/vector ([approx (in-list approxs)])
Expand Down Expand Up @@ -142,9 +142,9 @@

; egg schedule (3-phases for mathematical rewrites and implementation selection)
(define schedule
`(('lift . ((iteration . 1) (scheduler . simple)))
`((lift . ((iteration . 1)))
(,rules . ((node . ,(*node-limit*))))
('lower . ((iteration . 1) (scheduler . simple)))))
(lower . ((iteration . 1)))))

; run egg
(define exprs (map (compose debatchref alt-expr) altns))
Expand Down
4 changes: 2 additions & 2 deletions src/core/preprocess.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@

; egg schedule (3-phases for mathematical rewrites and implementation selection)
(define schedule
`((lift . ((iteration . 1)))
(,rules . ((node . ,(*node-limit*))))
`((lift . ((iteration . 1)))
(,rules . ((node . ,(*node-limit*))))
(lower . ((iteration . 1)))))

; egg query
Expand Down

0 comments on commit 73227b6

Please sign in to comment.