Skip to content

Commit

Permalink
Type-checking logic added
Browse files Browse the repository at this point in the history
  • Loading branch information
varun10p committed Sep 25, 2024
1 parent a4efa40 commit 3a50b18
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 70 deletions.
2 changes: 1 addition & 1 deletion src/core/mainloop.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@
(timeline-event! 'simplify)

; egg schedule (only mathematical rewrites)
(define rules (*fp-safe-simplify-rules*))
(define rules (append (*fp-safe-simplify-rules*) (real-rules (*simplify-rules*))))
(define schedule `((,rules . ((node . ,(*node-limit*)) (const-fold? . #f)))))

; egg runner
Expand Down
21 changes: 3 additions & 18 deletions src/platforms/bool.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,16 @@
#:spec (not x)
#:fpcore (! (not x))
#:fl not
#:identities
([not-true (not (TRUE)) (FALSE)] [not-false (not (FALSE)) (TRUE)]
[not-not (not (not a)) a]
[not-and (not (and a b)) (or (not a) (not b))]
[not-or (not (or a b)) (and (not a) (not b))]
[not-lt (not (< x y)) (>= x y)]
[not-gt (not (> x y)) (<= x y)]
[not-lte (not (<= x y)) (> x y)]
[not-gte (not (>= x y)) (< x y)]))
#:identities (#:exact (not a)))

(define-operator-impl (and [x : bool] [y : bool])
bool
#:spec (and x y)
#:fl and-fn
#:identities
([and-true-l (and (TRUE) a) a] [and-true-r (and a (TRUE)) a]
[and-false-l (and (FALSE) a) (FALSE)]
[and-false-r (and a (FALSE)) (FALSE)]
[and-same (and a a) a]))
#:identities (#:exact (and a b)))

(define-operator-impl (or [x : bool] [y : bool])
bool
#:spec (or x y)
#:fl or-fn
#:identities ([or-true-l (or (TRUE) a) (TRUE)] [or-true-r (or a (TRUE)) (TRUE)]
[or-false-l (or (FALSE) a) a]
[or-false-r (or a (FALSE)) a]
[or-same (or a a) a]))
#:identities (#: exact (or a b)))
3 changes: 2 additions & 1 deletion src/syntax/matcher.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

#lang racket

(provide pattern-match
(provide merge-bindings
pattern-match
pattern-substitute)

;; Unions two bindings. Returns #f if they disagree.
Expand Down
91 changes: 80 additions & 11 deletions src/syntax/platform.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"../core/programs.rkt"
"../core/rules.rkt"
"matcher.rkt"
"sugar.rkt"
"syntax.rkt"
"types.rkt")

Expand Down Expand Up @@ -549,18 +550,86 @@
(string-join (map (lambda (subst) (~a (cdr subst))) isubst) "-"))))
(sow (rule name* input* output* itypes* repr)))))]))))

(define (expr-otype expr)
(match expr
[(? literal?) #f]
[(? variable?) #f]
[(list 'if cond ift iff) (expr-otype ift)]
[(list op args ...) (impl-info op 'otype)]))

(define (type-verify expr otype)
(match expr
[(? literal?) '()]
[(? variable?) '((cons expr otype))]
[(list 'if cond ift iff)
(define bool-repr (get-representation 'bool))
(define combined
(merge-bindings (type-verify cond bool-repr)
(merge-bindings (type-verify ift otype) (type-verify iff otype))))
(unless combined
(error 'type-verify "Variable types do not match in ~a" expr))
combined]
[(list op args ...)
(define op-otype (impl-info op 'otype))
(when (not (equal? op-otype otype))
(error 'type-verify "Operator ~a has type ~a, expected ~a" op op-otype otype))
(define bindings '())
(for ([arg (in-list args)]
[itype (in-list (impl-info op 'itype))])
(define combined (merge-bindings bindings (type-verify arg itype)))
(unless combined
(error 'type-verify "Variable types do not match in ~a" expr))
(set! bindings combined))
bindings]))

(define (expr->prog expr repr)
(match expr
[(? literal?) (literal (get-representation repr) expr)]
[(? variable?) expr]
[`(if ,cond ,ift ,iff)
`(if ,(expr->prog cond repr) ,(expr->prog ift repr) ,(expr->prog iff repr))]
[`(,impl ,args ...) `(impl ,@(map (λ (arg) (expr->prog arg (impl-info impl 'itype))) args))]))

(define (*fp-safe-simplify-rules*)
(reap [sow]
(for ([impl (in-list (platform-impls (*active-platform*)))])
(define rules (impl-info impl 'identities))
(for ([name (in-hash-keys rules)])
(match-define (list input output vars) (hash-ref rules name))
(define itype (car (impl-info impl 'itype)))
(define r
(rule name
input
output
(for/hash ([v (in-list vars)])
(values v itype))
(impl-info impl 'otype)))
(sow r)))))
(for ([identity (in-list rules)])
(match identity
[(list 'exact name expr)
(when (not (expr-otype expr))
(error "Exact identity expr cannot infer type"))
(define otype (expr-otype expr))
(define var-types (type-verify expr otype))
(define prog (expr->prog expr otype))
(define r
(rule name
prog
(prog->spec prog)
(for/hash ([binding (in-list var-types)])
(values (car binding) (cdr binding)))
(impl-info impl 'otype)))
(sow r)]
[(list 'commutes name expr rev-expr)
(define vars (impl-info impl 'vars))
(define itype (car (impl-info impl 'itype)))
(define r
(rule name
(expr->prog expr)
(expr->prog rev-expr)
(for/hash ([v (in-list vars)])
(values v itype))
(impl-info impl 'otype))) ; Commutes by definition the types are matching
(sow r)]
[(list 'directed name lhs rhs)
(define lotype (expr-otype lhs))
(define rotype (expr-otype rhs))
(define var-types (merge-bindings (type-verify lhs lotype) (type-verify rhs rotype)))
(define r
(rule name
(expr->prog lhs)
(expr->prog rhs)
(for/hash ([binding (in-list var-types)])
(values (car binding) (cdr binding)))
(impl-info impl 'otype)))
(sow r)])))))
83 changes: 44 additions & 39 deletions src/syntax/syntax.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -389,52 +389,57 @@
name)]))

; make hash table
(define rules (make-hasheq))
(define count 0)
(define rules '())
(define rule-names (make-hasheq))
(define commutes? #f)
(when identities
(for ([ident (in-list identities)])
(match ident
[(list ident-name lhs-expr rhs-expr)
(cond
[(hash-has-key? rules ident-name)
(raise-herbie-syntax-error "Duplicate identity ~a" ident-name)]
[else
(hash-set! rules
(string->symbol (format "~a-~a" (symbol->string ident-name) name))
(list lhs-expr
rhs-expr
(remove-duplicates (append (free-variables lhs-expr)
(free-variables rhs-expr)))))])]
[(list 'exact expr)
(hash-set! rules
(gensym (string->symbol (format "~a-exact-~a" name count)))
(list expr expr (free-variables expr)))
(set! count (+ count 1))]
[(list 'commutes)
(cond
[commutes? (raise-herbie-syntax-error "Commutes identity already defined")]
[(hash-has-key? rules (string->symbol (format "~a-commutes" name)))
(raise-herbie-syntax-error "Commutes identity already manually defined")]
[(not (equal? (length vars) 2))
(raise-herbie-syntax-error "Cannot commute a non 2-ary operator")]
[else
(set! commutes? #t)
(hash-set! rules
(string->symbol (format "~a-commutes" name))
(list `(,name ,@vars) `(,name ,@(reverse vars)) vars))])])))
(set! rules
(for/list ([ident (in-list identities)]
[i (in-naturals)])
(match ident
[(list ident-name lhs-expr rhs-expr)
(cond
[(hash-has-key? rule-names ident-name)
(raise-herbie-syntax-error "Duplicate identity ~a" ident-name)]
[(not (well-formed? lhs-expr))
(raise-herbie-syntax-error "Ill-formed identity expression ~a" lhs-expr)]
[(not (well-formed? rhs-expr))
(raise-herbie-syntax-error "Ill-formed identity expression ~a" rhs-expr)]
[else
(define rule-name (string->symbol (format "~a-~a" ident-name name)))
(hash-set! rule-names rule-name #f)
(list 'directed rule-name lhs-expr rhs-expr)])]
[(list 'exact expr)
(cond
[(not (well-formed? expr))
(raise-herbie-syntax-error "Ill-formed identity expression ~a" expr)]
[else
(define rule-name (gensym (string->symbol (format "~a-exact-~a" name i))))
(hash-set! rule-names rule-name #f)
(list 'exact rule-name expr)])]
[(list 'commutes)
(cond
[commutes? (error "Commutes identity already defined")]
[(hash-has-key? rule-names (string->symbol (format "~a-commutes" name)))
(error "Commutes identity already manually defined")]
[(not (equal? (length vars) 2))
(raise-herbie-syntax-error "Cannot commute a non 2-ary operator")]
[else
(set! commutes? #t)
(define rule-name (string->symbol (format "~a-commutes" name)))
(hash-set! rule-names rule-name #f)
(list 'commutes rule-name `(,name ,@vars) `(,name ,@(reverse vars)))])]))))

; update tables
(define impl (operator-impl name ctx spec fpcore* fl-proc* rules))
(hash-set! operator-impls name impl))

(define (free-variables prog)
(match prog
[(? literal?) '()]
[(? number?) '()]
[(? variable?) (list prog)]
[(approx _ impl) (free-variables impl)]
[(list _ args ...) (remove-duplicates (append-map free-variables args))]))
(define (well-formed? expr)
(match expr
[(? number?) #t]
[(? variable?) #t]
[`(,impl ,args ...) (andmap well-formed? args)]
[_ #f]))

(define-syntax (define-operator-impl stx)
(define (oops! why [sub-stx #f])
Expand Down

0 comments on commit 3a50b18

Please sign in to comment.