diff --git a/src/core/mainloop.rkt b/src/core/mainloop.rkt index a59617aa9..4809225d7 100644 --- a/src/core/mainloop.rkt +++ b/src/core/mainloop.rkt @@ -368,7 +368,7 @@ (timeline-event! 'simplify) ; egg schedule (only mathematical rewrites) - (define rules (*fp-safe-simplify-rules*)) + (define rules (append (*fp-safe-simplify-rules*) (real-rules (*simplify-rules*)))) (define schedule `((,rules . ((node . ,(*node-limit*)) (const-fold? . #f))))) ; egg runner diff --git a/src/platforms/bool.rkt b/src/platforms/bool.rkt index 54908d803..61d1a77e4 100644 --- a/src/platforms/bool.rkt +++ b/src/platforms/bool.rkt @@ -34,31 +34,16 @@ #:spec (not x) #:fpcore (! (not x)) #:fl not - #:identities - ([not-true (not (TRUE)) (FALSE)] [not-false (not (FALSE)) (TRUE)] - [not-not (not (not a)) a] - [not-and (not (and a b)) (or (not a) (not b))] - [not-or (not (or a b)) (and (not a) (not b))] - [not-lt (not (< x y)) (>= x y)] - [not-gt (not (> x y)) (<= x y)] - [not-lte (not (<= x y)) (> x y)] - [not-gte (not (>= x y)) (< x y)])) + #:identities (#:exact (not a))) (define-operator-impl (and [x : bool] [y : bool]) bool #:spec (and x y) #:fl and-fn - #:identities - ([and-true-l (and (TRUE) a) a] [and-true-r (and a (TRUE)) a] - [and-false-l (and (FALSE) a) (FALSE)] - [and-false-r (and a (FALSE)) (FALSE)] - [and-same (and a a) a])) + #:identities (#:exact (and a b))) (define-operator-impl (or [x : bool] [y : bool]) bool #:spec (or x y) #:fl or-fn - #:identities ([or-true-l (or (TRUE) a) (TRUE)] [or-true-r (or a (TRUE)) (TRUE)] - [or-false-l (or (FALSE) a) a] - [or-false-r (or a (FALSE)) a] - [or-same (or a a) a])) + #:identities (#: exact (or a b))) diff --git a/src/syntax/matcher.rkt b/src/syntax/matcher.rkt index 12d0c2699..93e75e860 100644 --- a/src/syntax/matcher.rkt +++ b/src/syntax/matcher.rkt @@ -2,7 +2,8 @@ #lang racket -(provide pattern-match +(provide merge-bindings + pattern-match pattern-substitute) ;; Unions two bindings. Returns #f if they disagree. diff --git a/src/syntax/platform.rkt b/src/syntax/platform.rkt index 56f182752..20d324b2e 100644 --- a/src/syntax/platform.rkt +++ b/src/syntax/platform.rkt @@ -5,6 +5,7 @@ "../core/programs.rkt" "../core/rules.rkt" "matcher.rkt" + "sugar.rkt" "syntax.rkt" "types.rkt") @@ -549,18 +550,86 @@ (string-join (map (lambda (subst) (~a (cdr subst))) isubst) "-")))) (sow (rule name* input* output* itypes* repr)))))])))) +(define (expr-otype expr) + (match expr + [(? literal?) #f] + [(? variable?) #f] + [(list 'if cond ift iff) (expr-otype ift)] + [(list op args ...) (impl-info op 'otype)])) + +(define (type-verify expr otype) + (match expr + [(? literal?) '()] + [(? variable?) '((cons expr otype))] + [(list 'if cond ift iff) + (define bool-repr (get-representation 'bool)) + (define combined + (merge-bindings (type-verify cond bool-repr) + (merge-bindings (type-verify ift otype) (type-verify iff otype)))) + (unless combined + (error 'type-verify "Variable types do not match in ~a" expr)) + combined] + [(list op args ...) + (define op-otype (impl-info op 'otype)) + (when (not (equal? op-otype otype)) + (error 'type-verify "Operator ~a has type ~a, expected ~a" op op-otype otype)) + (define bindings '()) + (for ([arg (in-list args)] + [itype (in-list (impl-info op 'itype))]) + (define combined (merge-bindings bindings (type-verify arg itype))) + (unless combined + (error 'type-verify "Variable types do not match in ~a" expr)) + (set! bindings combined)) + bindings])) + +(define (expr->prog expr repr) + (match expr + [(? literal?) (literal (get-representation repr) expr)] + [(? variable?) expr] + [`(if ,cond ,ift ,iff) + `(if ,(expr->prog cond repr) ,(expr->prog ift repr) ,(expr->prog iff repr))] + [`(,impl ,args ...) `(impl ,@(map (λ (arg) (expr->prog arg (impl-info impl 'itype))) args))])) + (define (*fp-safe-simplify-rules*) (reap [sow] (for ([impl (in-list (platform-impls (*active-platform*)))]) (define rules (impl-info impl 'identities)) - (for ([name (in-hash-keys rules)]) - (match-define (list input output vars) (hash-ref rules name)) - (define itype (car (impl-info impl 'itype))) - (define r - (rule name - input - output - (for/hash ([v (in-list vars)]) - (values v itype)) - (impl-info impl 'otype))) - (sow r))))) + (for ([identity (in-list rules)]) + (match identity + [(list 'exact name expr) + (when (not (expr-otype expr)) + (error "Exact identity expr cannot infer type")) + (define otype (expr-otype expr)) + (define var-types (type-verify expr otype)) + (define prog (expr->prog expr otype)) + (define r + (rule name + prog + (prog->spec prog) + (for/hash ([binding (in-list var-types)]) + (values (car binding) (cdr binding))) + (impl-info impl 'otype))) + (sow r)] + [(list 'commutes name expr rev-expr) + (define vars (impl-info impl 'vars)) + (define itype (car (impl-info impl 'itype))) + (define r + (rule name + (expr->prog expr) + (expr->prog rev-expr) + (for/hash ([v (in-list vars)]) + (values v itype)) + (impl-info impl 'otype))) ; Commutes by definition the types are matching + (sow r)] + [(list 'directed name lhs rhs) + (define lotype (expr-otype lhs)) + (define rotype (expr-otype rhs)) + (define var-types (merge-bindings (type-verify lhs lotype) (type-verify rhs rotype))) + (define r + (rule name + (expr->prog lhs) + (expr->prog rhs) + (for/hash ([binding (in-list var-types)]) + (values (car binding) (cdr binding))) + (impl-info impl 'otype))) + (sow r)]))))) diff --git a/src/syntax/syntax.rkt b/src/syntax/syntax.rkt index 8f96badd7..0164774b0 100644 --- a/src/syntax/syntax.rkt +++ b/src/syntax/syntax.rkt @@ -389,52 +389,57 @@ name)])) ; make hash table - (define rules (make-hasheq)) - (define count 0) + (define rules '()) + (define rule-names (make-hasheq)) (define commutes? #f) (when identities - (for ([ident (in-list identities)]) - (match ident - [(list ident-name lhs-expr rhs-expr) - (cond - [(hash-has-key? rules ident-name) - (raise-herbie-syntax-error "Duplicate identity ~a" ident-name)] - [else - (hash-set! rules - (string->symbol (format "~a-~a" (symbol->string ident-name) name)) - (list lhs-expr - rhs-expr - (remove-duplicates (append (free-variables lhs-expr) - (free-variables rhs-expr)))))])] - [(list 'exact expr) - (hash-set! rules - (gensym (string->symbol (format "~a-exact-~a" name count))) - (list expr expr (free-variables expr))) - (set! count (+ count 1))] - [(list 'commutes) - (cond - [commutes? (raise-herbie-syntax-error "Commutes identity already defined")] - [(hash-has-key? rules (string->symbol (format "~a-commutes" name))) - (raise-herbie-syntax-error "Commutes identity already manually defined")] - [(not (equal? (length vars) 2)) - (raise-herbie-syntax-error "Cannot commute a non 2-ary operator")] - [else - (set! commutes? #t) - (hash-set! rules - (string->symbol (format "~a-commutes" name)) - (list `(,name ,@vars) `(,name ,@(reverse vars)) vars))])]))) + (set! rules + (for/list ([ident (in-list identities)] + [i (in-naturals)]) + (match ident + [(list ident-name lhs-expr rhs-expr) + (cond + [(hash-has-key? rule-names ident-name) + (raise-herbie-syntax-error "Duplicate identity ~a" ident-name)] + [(not (well-formed? lhs-expr)) + (raise-herbie-syntax-error "Ill-formed identity expression ~a" lhs-expr)] + [(not (well-formed? rhs-expr)) + (raise-herbie-syntax-error "Ill-formed identity expression ~a" rhs-expr)] + [else + (define rule-name (string->symbol (format "~a-~a" ident-name name))) + (hash-set! rule-names rule-name #f) + (list 'directed rule-name lhs-expr rhs-expr)])] + [(list 'exact expr) + (cond + [(not (well-formed? expr)) + (raise-herbie-syntax-error "Ill-formed identity expression ~a" expr)] + [else + (define rule-name (gensym (string->symbol (format "~a-exact-~a" name i)))) + (hash-set! rule-names rule-name #f) + (list 'exact rule-name expr)])] + [(list 'commutes) + (cond + [commutes? (error "Commutes identity already defined")] + [(hash-has-key? rule-names (string->symbol (format "~a-commutes" name))) + (error "Commutes identity already manually defined")] + [(not (equal? (length vars) 2)) + (raise-herbie-syntax-error "Cannot commute a non 2-ary operator")] + [else + (set! commutes? #t) + (define rule-name (string->symbol (format "~a-commutes" name))) + (hash-set! rule-names rule-name #f) + (list 'commutes rule-name `(,name ,@vars) `(,name ,@(reverse vars)))])])))) ; update tables (define impl (operator-impl name ctx spec fpcore* fl-proc* rules)) (hash-set! operator-impls name impl)) -(define (free-variables prog) - (match prog - [(? literal?) '()] - [(? number?) '()] - [(? variable?) (list prog)] - [(approx _ impl) (free-variables impl)] - [(list _ args ...) (remove-duplicates (append-map free-variables args))])) +(define (well-formed? expr) + (match expr + [(? number?) #t] + [(? variable?) #t] + [`(,impl ,args ...) (andmap well-formed? args)] + [_ #f])) (define-syntax (define-operator-impl stx) (define (oops! why [sub-stx #f])