Skip to content

Commit

Permalink
Rewrite the eclass reading
Browse files Browse the repository at this point in the history
  • Loading branch information
pavpanchekha committed Aug 30, 2024
1 parent fc30cd4 commit 78a7765
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 65 deletions.
94 changes: 79 additions & 15 deletions egg-herbie/main.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
egraph_find
egraph_serialize
egraph_get_eclasses
egraph_get_eclass
in-egraph-enodes
egraph_get_simplest
egraph_get_variants
egraph_get_cost
Expand Down Expand Up @@ -241,21 +241,85 @@
->
[f : _rust/string]
->
(if (zero? (u32vector-length v))
(values
(if (zero? (u32vector-length v))
(or (string->number f) (string->symbol f))
(cons (string->symbol f) v))))
; u32vector
(define empty-u32vec (make-u32vector 0))

; egraph -> id -> (vectorof (or symbol? number? (cons symbol u32vector)))
; Collects every enode of eclass `id` into a fresh vector, in enode-index order.
(define (egraph_get_eclass egg-ptr id)
  (define num-enodes (egraph_eclass_size egg-ptr id))
  (build-vector
   num-enodes
   (lambda (enode-idx)
     (define arity (egraph_enode_size egg-ptr id enode-idx))
     ; childless enodes share one empty buffer instead of allocating a new one
     (define buf
       (if (zero? arity)
           empty-u32vec
           (make-u32vector arity)))
     (egraph_get_node egg-ptr id enode-idx buf))))
(string->symbol f))
v)))

;; This fairly long and ugly method is the core of our egg FFI.
;; It iterates over all enodes in the egraph, returning for each:
;;
;; - The eclass id. This is normalized to range over 0..n-1
;; - The operator name, which is a symbol or a number
;; - A u32vector of child eclass ids, which are also normalized.
;;
;; If an enode is an atom, like a symbol or a number, there aren't any
;; children. The same holds if it's a zero-argument operator; these
;; cases need to be disambiguated by the caller.
;;
;; This method is also extremely performance sensitive, so it is written
;; to maximize performance. This leads to a key constraint:
;;
;; - The u32vector of child eclass ids must not be retained by the caller
;; past the current iteration; it may be overwritten by the next one.
;;
;; The caller must be careful!

(define (in-egraph-enodes egg-ptr)
  ;; Produces a sequence yielding three values per enode: the normalized
  ;; eclass index, the operator, and a u32vector of normalized child indices.
  (define eclass-ids (egraph_get_eclasses egg-ptr))
  (define num-eclasses (u32vector-length eclass-ids))

  ;; Dense lookup table from an egg eclass id to its position in
  ;; `eclass-ids`; that position is the normalized id handed to the caller.
  (define max-id
    (for/fold ([current-max 0]) ([idx (in-range num-eclasses)])
      (max current-max (u32vector-ref eclass-ids idx))))
  (define egg-id->idx (make-u32vector (+ max-id 1)))
  (for ([idx (in-range num-eclasses)])
    (u32vector-set! egg-id->idx (u32vector-ref eclass-ids idx) idx))

  ;; Shared argument buffers for enodes with at most three children. They are
  ;; reused on every iteration, which is why the yielded u32vector must not be
  ;; kept by the caller.
  (define 0-vec (make-u32vector 0))
  (define 1-vec (make-u32vector 1))
  (define 2-vec (make-u32vector 2))
  (define 3-vec (make-u32vector 3))

  ;; An egraph without eclasses has no first eclass to look up. Use
  ;; placeholders in that case: the continuation test below rejects the
  ;; initial position, so they are never read.
  (define no-eclasses? (zero? num-eclasses))
  (define first-eclass (if no-eclasses? 0 (u32vector-ref eclass-ids 0)))
  (define first-eclass-size
    (if no-eclasses? 0 (egraph_eclass_size egg-ptr first-eclass)))

  ;; The iteration position is a mutable vector
  ;;   #(eclass-idx eclass-id eclass-size enode-idx)
  ;; that is updated in place rather than reallocated for every enode.
  (make-do-sequence
   (lambda ()
     (values
      ;; position -> element
      (lambda (pos)
        (match-define (vector eclass-idx eclass-id _ enode-idx) pos)
        (define node-size (egraph_enode_size egg-ptr eclass-id enode-idx))
        (define vec
          (match node-size
            [0 0-vec]
            [1 1-vec]
            [2 2-vec]
            [3 3-vec]
            [n (make-u32vector n)]))
        (define-values (op args) (egraph_get_node egg-ptr eclass-id enode-idx vec))
        ;; Rewrite the child ids in place from egg ids to normalized indices.
        (for ([arg-idx (in-range (u32vector-length args))])
          (define sub-eclass-id (u32vector-ref args arg-idx))
          (u32vector-set! args arg-idx (u32vector-ref egg-id->idx sub-eclass-id)))
        (values eclass-idx op args))
      ;; position -> next position
      (lambda (pos)
        (match-define (vector eclass-idx _ eclass-size enode-idx) pos)
        (define enode-idx* (add1 enode-idx))
        (cond
          [(< enode-idx* eclass-size)
           ;; More enodes remain in the current eclass.
           (vector-set! pos 3 enode-idx*)]
          [else
           ;; Move on to the next eclass. Past the last one only the index is
           ;; bumped, which makes the continuation test fail.
           (define eclass-idx* (add1 eclass-idx))
           (vector-set! pos 0 eclass-idx*)
           (when (< eclass-idx* num-eclasses)
             (define eclass-id* (u32vector-ref eclass-ids eclass-idx*))
             (vector-set! pos 1 eclass-id*)
             (vector-set! pos 2 (egraph_eclass_size egg-ptr eclass-id*))
             (vector-set! pos 3 0))])
        pos)
      ;; initial position
      (vector 0 first-eclass first-eclass-size 0)
      ;; continue while the eclass index is in range
      (lambda (pos) (< (vector-ref pos 0) num-eclasses))
      #f
      #f))))

;; egraph -> id -> id
;; FFI binding: takes an egraph pointer and an unsigned integer id and
;; returns an unsigned integer id.
(define-eggmath egraph_find (_fun _egraph-pointer _uint -> _uint))
Expand Down
84 changes: 34 additions & 50 deletions src/core/egg-herbie.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -205,19 +205,6 @@
;; Returns the eclass ids of the egraph held by `egraph-data`, obtained by
;; calling `egraph_get_eclasses` on its raw egraph pointer.
(define (egraph-eclasses egraph-data)
(egraph_get_eclasses (egraph-data-egraph-pointer egraph-data)))

;; Returns the enodes of eclass `id` as a vector. Each entry is a symbol, a
;; number, or a pair of an operator symbol and a u32vector of children.
(define (egraph-get-eclass egraph-data id)
  (define ptr (egraph-data-egraph-pointer egraph-data))
  (define egg->herbie (egraph-data-egg->herbie-dict egraph-data))
  (define enodes (egraph_get_eclass ptr id))
  ;; A bare symbol with no entry in `egg->herbie` is a constant operator;
  ;; rewrite it in place as an operator applied to zero children.
  (for ([pos (in-range (vector-length enodes))])
    (define node (vector-ref enodes pos))
    (unless (or (not (symbol? node)) (hash-has-key? egg->herbie node))
      (vector-set! enodes pos (cons node (make-u32vector 0)))))
  enodes)

;; Calls `egraph_find` with the raw egraph pointer stored in `egraph-data`
;; and the given `id`, and returns its result.
(define (egraph-find egraph-data id)
(egraph_find (egraph-data-egraph-pointer egraph-data) id))

Expand Down Expand Up @@ -592,31 +579,26 @@
;; Nodes are duplicated across their possible types.
(define (split-untyped-eclasses egraph-data egg->herbie)
(define eclass-ids (egraph-eclasses egraph-data))
(define max-id
(for/fold ([current-max 0]) ([egg-id (in-u32vector eclass-ids)])
(max current-max egg-id)))
(define egg-id->idx (make-u32vector (+ max-id 1)))
(for ([egg-id (in-u32vector eclass-ids)]
[idx (in-naturals)])
(u32vector-set! egg-id->idx egg-id idx))
(define egg-ptr (egraph-data-egraph-pointer egraph-data))

(define egg-id->idx
(for/hasheq ([idx (in-naturals)]
[egg-id (in-u32vector eclass-ids)])
(values egg-id idx)))

(define types (all-reprs/types))
(define type->idx (make-hasheq))
(for ([type (in-list types)]
[idx (in-naturals)])
(hash-set! type->idx type idx))
(define type->idx
(for/hasheq ([type (in-list types)] [idx (in-naturals)])
(values type idx)))
(define num-types (hash-count type->idx))

; allocate enough eclasses for every (egg-id, type) combination
(define n (* (u32vector-length eclass-ids) num-types))

; maps (idx, type) to type eclass id
(define (idx+type->id idx type)
(+ (* idx num-types) (hash-ref type->idx type)))

; maps (untyped eclass id, type) to typed eclass id
(define (lookup-id eid type)
(idx+type->id (u32vector-ref egg-id->idx eid) type))

; allocate enough eclasses for every (egg-id, type) combination
(define n (* (u32vector-length eclass-ids) num-types))
(define id->eclass (make-vector n '()))
(define id->parents (make-vector n '()))
(define id->leaf? (make-vector n #f))
Expand All @@ -627,25 +609,27 @@
; | (<symbol> . <u32vector>)
; NOTE: nodes in typed eclasses are reversed relative
; to their position in untyped eclasses
(for ([eid (in-u32vector eclass-ids)]
[idx (in-naturals)])
(define enodes (egraph-get-eclass egraph-data eid))
(for ([enode (in-vector enodes)])
; get all possible types for the enode
; lookup its correct eclass and add the rebuilt node
(define types (enode-type enode egg->herbie))
(for ([type (in-list types)])
(define id (idx+type->id idx type))
(define enode* (rebuild-enode enode type lookup-id))
(vector-set! id->eclass id (cons enode* (vector-ref id->eclass id)))
(match enode*
[(list _ ids ...)
(if (null? ids)
(vector-set! id->leaf? id #t)
(for ([child-id (in-list ids)])
(vector-set! id->parents child-id (cons id (vector-ref id->parents child-id)))))]
[(? symbol?) (vector-set! id->leaf? id #t)]
[(? number?) (vector-set! id->leaf? id #t)]))))
(for ([(eclass op args) (in-egraph-enodes egg-ptr)])
(define enode
(if (or (number? op)
(and (hash-has-key? egg->herbie op) (zero? (u32vector-length args))))
op
(cons op args)))
; get all possible types for the enode
; lookup its correct eclass and add the rebuilt node
(define types (enode-type enode egg->herbie))
(for ([type (in-list types)])
(define id (idx+type->id eclass type))
(define enode* (rebuild-enode enode type idx+type->id))
(vector-set! id->eclass id (cons enode* (vector-ref id->eclass id)))
(match enode*
[(list _)
(vector-set! id->leaf? id #t)]
[(list _ ids ...)
(for ([child-id (in-list ids)])
(vector-set! id->parents child-id (cons id (vector-ref id->parents child-id))))]
[(? symbol?) (vector-set! id->leaf? id #t)]
[(? number?) (vector-set! id->leaf? id #t)])))

; dedup `id->parents` values
(for ([id (in-range n)])
Expand Down Expand Up @@ -748,7 +732,7 @@
; build the canonical id map
(define egg-id->id (make-hash))
(for ([eid (in-u32vector eclass-ids)])
(define idx (u32vector-ref egg-id->idx eid))
(define idx (hash-ref egg-id->idx eid))
(define id0 (* idx num-types))
(for ([id (in-range id0 (+ id0 num-types))])
(define id* (vector-ref remap id))
Expand Down

0 comments on commit 78a7765

Please sign in to comment.