From e2a69cf5d557bcf92780c3610fa454e33c8b86db Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Wed, 7 Feb 2024 12:13:17 -0700 Subject: [PATCH] feat!: port to genjl style interfaces (#60) --- examples/editable.clj | 13 +- examples/intro_to_modeling.clj | 81 ++--- examples/introduction.clj | 48 +-- src/data_readers.cljc | 4 +- src/gen/choice_map.cljc | 30 -- src/gen/choicemap.cljc | 2 +- src/gen/distribution.cljc | 143 ++++++-- src/gen/distribution/commons_math.clj | 16 +- src/gen/distribution/java_util.clj | 6 +- src/gen/distribution/kixi.cljc | 16 +- src/gen/dynamic.cljc | 488 ++++++++++++++++++++++---- src/gen/dynamic/choice_map.cljc | 230 ------------ src/gen/dynamic/trace.cljc | 276 --------------- src/gen/generative_function.cljc | 3 +- src/gen/sci.cljc | 22 +- src/gen/trace.cljc | 8 +- test/gen/distribution_test.cljc | 64 ++-- test/gen/dynamic/choice_map_test.cljc | 57 --- test/gen/dynamic/trace_test.cljc | 87 ----- test/gen/dynamic_test.cljc | 232 +++++++----- test/gen/sci_test.cljc | 9 +- 21 files changed, 807 insertions(+), 1028 deletions(-) delete mode 100644 src/gen/choice_map.cljc delete mode 100644 src/gen/dynamic/choice_map.cljc delete mode 100644 src/gen/dynamic/trace.cljc delete mode 100644 test/gen/dynamic/choice_map_test.cljc delete mode 100644 test/gen/dynamic/trace_test.cljc diff --git a/examples/editable.clj b/examples/editable.clj index be07ae0..0a741f1 100644 --- a/examples/editable.clj +++ b/examples/editable.clj @@ -5,12 +5,11 @@ `intro-to-modeling`, and also changed the `pmap` call in `prepeatedly` into `map`, since we don't have `pmap` available in the browser." {:nextjournal.clerk/toc true} - (:require [gen.choice-map] - [gen.dynamic :as dynamic :refer [gen]] - [gen.dynamic.choice-map :refer [choice-map]] + (:require [gen.choicemap :refer [choicemap]] [gen.clerk.callout :as callout] [gen.clerk.viewer :as viewer] [gen.distribution.kixi :as dist] + [gen.dynamic :as dynamic :refer [gen]] [gen.generative-function :as gf] [gen.inference.importance :as importance] [gen.trace :as trace] @@ -586,7 +585,7 @@ math/PI ;; them to be inferred. (let [observations (reduce (fn [observations [i y]] (assoc observations [:y i] y)) - (choice-map {}) + (choicemap {}) (map-indexed vector ys))] (:trace (importance/resampling model [xs] observations amount-of-computation)))) @@ -652,7 +651,7 @@ math/PI ;; For example: -(def predicting-constraints (choice-map {:slope 0 :intercept 0})) +(def predicting-constraints (choicemap {:slope 0 :intercept 0})) (def predicting-trace (:trace (gf/generate line-model [xs] predicting-constraints))) (def predict-opts @@ -684,7 +683,7 @@ math/PI ;; constraints. (let [constraints (reduce (fn [cm param-addr] (assoc cm param-addr (get trace param-addr))) - (choice-map {}) + (choicemap {}) param-addrs) ;; Run the model with new x coordinates, and with parameters @@ -944,7 +943,7 @@ math/PI ;; collisions for complex models. ;; Hierarchical traces are represented using nested choice maps -;; (`gen.dynamic.choice-map/ChoiceMap`). Hierarchical addresses can be accessed +;; (`gen.dynamic.choicemap/ChoiceMap`). Hierarchical addresses can be accessed ;; using `clojure.core` functions like `clojure.core/get-in`. (get-in bar-with-key-trace [:z :y]) diff --git a/examples/intro_to_modeling.clj b/examples/intro_to_modeling.clj index 0ff88d4..5ea5c44 100644 --- a/examples/intro_to_modeling.clj +++ b/examples/intro_to_modeling.clj @@ -1,12 +1,11 @@ ^{:nextjournal.clerk/visibility {:code :hide :result :hide}} (ns intro-to-modeling {:nextjournal.clerk/toc true} - (:require [gen.choice-map] - [gen.dynamic :as dynamic :refer [gen]] - [gen.dynamic.choice-map :refer [choice-map]] + (:require [gen.choicemap :as choicemap :refer [choicemap]] [gen.clerk.callout :as callout] [gen.clerk.viewer :as viewer] [gen.distribution.kixi :as dist] + [gen.dynamic :as dynamic :refer [gen]] [gen.generative-function :as gf] [gen.inference.importance :as importance] [gen.trace :as trace] @@ -51,7 +50,7 @@ ;; - **Modeling**: You first need to frame the problem — and any assumptions you ;; bring to the table — as a probabilistic model. A huge variety of problems ;; can be viewed from a modeling & inference lens, if you set them up -;; properly. **This notebook is about how to think of problems in this light, +;; properly. **This notebook is about how to think of problems in this light, ;; and how to use Gen** **to formally specify your assumptions and the tasks ;; you wish to solve.** @@ -61,7 +60,7 @@ ;; proposal distributions. With enough computation, the algorithm can in ;; theory solve any modeling and inference problem, but in practice, for most ;; problems of interest, it is too slow to achieve accurate results in a -;; reasonable amount of time. **Future tutorials introduce some of Gen's** +;; reasonable amount of time. **Future tutorials introduce some of Gen's** ;; **programmable inference features**, which let you tailor the inference ;; algorithm for use with more complex models (Gen will still automate the ;; math!). @@ -222,32 +221,27 @@ ^{::clerk/visibility {:result :hide}} (def line-model (gen [xs] - - ;; We begin by sampling a slope and intercept for the line. Before we have + ;; We begin by sampling a slope and intercept for the line. Before we have ;; seen the data, we don't know the values of these parameters, so we treat - ;; them as random choices. The distributions they are drawn from represent our - ;; prior beliefs about the parameters: in this case, that neither the slope - ;; nor the intercept will be more than a couple points away from 0. - + ;; them as random choices. The distributions they are drawn from represent + ;; our prior beliefs about the parameters: in this case, that neither the + ;; slope nor the intercept will be more than a couple points away from 0. (let [slope (dynamic/trace! :slope dist/normal 0 1) intercept (dynamic/trace! :intercept dist/normal 0 2) ;; We define a function to compute y for a given x. - y (fn [x] (+ (* slope x) intercept))] ;; Given the slope and intercept, we can sample y coordinates for each of ;; the x coordinates in our input vector. - (doseq [[i x] (map vector (range) xs)] (dynamic/trace! [:y i] dist/normal (y x) 0.1)) - ;; Most of the time, we don't care about the return - ;; value of a model, only the random choices it makes. - ;; It can sometimems be useful to return something - ;; meaningful, however; here, we return the function `y`. + ;; Most of the time, we don't care about the return value of a model, only + ;; the random choices it makes. It can sometimems be useful to return + ;; something meaningful, however; here, we return the function `y`. y))) ;; The generative function takes as an argument a vector of x-coordinates. We @@ -304,7 +298,6 @@ ;; (require '[gen.trace :as trace]) ;; ``` - (trace/get-args trace) ;; The trace also contains the value of the random choices, stored in a map from @@ -326,15 +319,6 @@ (:slope (trace/get-choices trace)) -;; We can also read the value of a random choice directly from the trace, -;; without having to use `gen.trace/get-choices` first: - -(get trace :slope) - -(trace :slope) - -(:slope trace) - ;; The return value is also recorded in the trace, and is accessible with the ;; `trace/get-retval` API method: @@ -352,8 +336,9 @@ [trace & {:keys [clip x-domain y-domain] :or {clip false}}] (let [[xs] (trace/get-args trace) ; Pull out the xs from the trace. y (trace/get-retval trace) ; Pull out the return value, useful for plotting. + choices (trace/get-choices trace) ys (for [i (range (count xs))] - (trace [:y i])) + (choicemap/get-value choices [:y i])) data (mapv (fn [x y] {:x x :y y}) xs @@ -472,8 +457,9 @@ math/PI min-x (apply min xs) max-x (apply max xs) y (trace/get-retval trace) ; Pull out the return value, useful for plotting. + choices (trace/get-choices trace) ys (for [i (range (count xs))] - (get trace [:y i])) + (choicemap/get-value choices [:y i])) data (mapv (fn [x y] {:x x :y y}) xs @@ -585,7 +571,7 @@ math/PI ;; ```clojure ;; (require '[gen.inference.importance :as importance] -;; '[gen.dynamic.choice-map :refer [choice-map]]) +;; '[gen.choicemap :as choicemap :refer [choicemap]]) ;; ``` (defn do-inference @@ -595,7 +581,7 @@ math/PI ;; them to be inferred. (let [observations (reduce (fn [observations [i y]] (assoc observations [:y i] y)) - (choice-map {}) + (choicemap {}) (map-indexed vector ys))] (:trace (importance/resampling model [xs] observations amount-of-computation)))) @@ -661,7 +647,7 @@ math/PI ;; For example: -(def predicting-constraints (choice-map {:slope 0 :intercept 0})) +(def predicting-constraints {:slope 0 :intercept 0}) (def predicting-trace (:trace (gf/generate line-model [xs] predicting-constraints))) (def predict-opts @@ -691,17 +677,19 @@ math/PI [model trace new-xs param-addrs] ;; Copy parameter values from the inferred trace (`trace`) into a fresh set of ;; constraints. - (let [constraints (reduce (fn [cm param-addr] - (assoc cm param-addr (get trace param-addr))) - (choice-map {}) + (let [prev-choices (trace/get-choices trace) + constraints (reduce (fn [cm param-addr] + (assoc cm param-addr + (choicemap/get-value prev-choices param-addr))) + (choicemap) param-addrs) - ;; Run the model with new x coordinates, and with parameters ;; fixed to be the inferred values. - {new-trace :trace} (gf/generate model [new-xs] constraints)] + {new-trace :trace} (gf/generate model [new-xs] constraints) + choices (trace/get-choices new-trace)] ;; Pull out the y-values and return them. - (mapv #(get new-trace [:y %]) + (mapv #(choicemap/get-value choices [:y %]) (range (count new-xs))))) ;; To illustrate, we call the function above given the previous trace (which @@ -953,8 +941,8 @@ math/PI ;; collisions for complex models. ;; Hierarchical traces are represented using nested choice maps -;; (`gen.dynamic.choice-map/ChoiceMap`). Hierarchical addresses can be accessed -;; using `clojure.core` functions like `clojure.core/get-in`. +;; (`gen.choicemap/IChoiceMap`). Hierarchical addresses can be accessed using +;; `clojure.core` functions like `clojure.core/get-in`. (get-in bar-with-key-trace [:z :y]) @@ -1209,15 +1197,16 @@ math/PI (defn render-changepoint-model-trace [trace] - (let [[xs] (trace/get-args trace) - ys (for [i (range (count xs))] - (trace [:y i])) - node (trace/get-retval trace) + (let [[xs] (trace/get-args trace) + choices (trace/get-choices trace) + ys (for [i (range (count xs))] + (choicemap/get-value choices [:y i])) + node (trace/get-retval trace) node-layer (node-vl-spec node) data-layer (scatter-spec xs ys :color "grey" :fillOpacity 0.3 :strokeOpacity 1.0)] - (clerk/vl {:schema "https://vega.github.io/schema/vega-lite/v5.json" + (clerk/vl {:schema "https://vega.github.io/schema/vega-lite/v5.json" :embed/opts {:actions false} - :layer [node-layer data-layer]}))) + :layer [node-layer data-layer]}))) {::clerk/visibility {:result :show}} diff --git a/examples/introduction.clj b/examples/introduction.clj index 9ac713e..2190e4e 100644 --- a/examples/introduction.clj +++ b/examples/introduction.clj @@ -4,6 +4,7 @@ [clojure.repl :as repl] [gen.distribution.commons-math :as dist] [gen.dynamic :as dynamic :refer [gen]] + [gen.choicemap :as choicemap] [gen.generative-function :as gf] [gen.trace :as trace] [nextjournal.clerk :as clerk])) @@ -298,12 +299,14 @@ (let [data (apply concat (for [p ps] (->> (repeatedly #(trace/get-choices (gf/simulate gen-f [p]))) - (filter (fn [trace] - (= observed-fp (get trace :fp)))) - (take 1000) - (mapv (fn [trace] + (filter (fn [choices] + (= observed-fp + (choicemap/get-value choices :fp)))) + (take 100) + (mapv (fn [choices] {:p p - :if-test (get trace :if-test)})))))] + :if-test + (choicemap/get-value choices :if-test)})))))] (clerk/vl {:schema "https://vega.github.io/schema/vega-lite/v5.json" :embed/opts {:actions false} :data {:values data} @@ -437,30 +440,9 @@ ;; where `:a` is always true and `:c` is always false. We first construct a ;; choice map containing these constraints: -(require '[gen.dynamic.choice-map :as dynamic.choice-map] - '[gen.choice-map :as choice-map]) - (def constraints - (dynamic.choice-map/choice-map - :a true - :c false)) - -#_ -(choice-map/submaps - (dynamic.choice-map/choice-map - :a true - :c false)) - -;; The `gen.dynamic.choice-map/choice-map` constructor above took two elements -;; of the form (address, value). This is equivalent to constructing an empty -;; choice map and then populating it: - -(def choices - (assoc (dynamic.choice-map/choice-map) - :a true - :c false)) - -(choice-map/submaps choices) + {:a true + :c false}) ;; Then, we pass the constraints as the third argument to ;; `gen.generative-function/generate`, after the function itself and the @@ -613,7 +595,7 @@ (def samples (let [n-particles [1 10 100]] (zipmap n-particles - (mapv #(draw-samples % #gen/choice-map {:c false}) + (mapv #(draw-samples % {:c false}) n-particles)))) (clerk/vl {:schema "https://vega.github.io/schema/vega-lite/v5.json" @@ -696,13 +678,13 @@ ;; Consider the function `foo` from above. Let's obtain an initial trace: -(def update-trace (:trace (gf/generate foo [0.3] #gen/choice-map {:a true :b true :c true}))) +(def update-trace (:trace (gf/generate foo [0.3] {:a true :b true :c true}))) (trace/get-choices update-trace) ;; Now, we use the `update` function, to change the value of `:c` from `true` to ;; `false`: -(def updated (trace/update update-trace #gen/choice-map {:c false})) +(def updated (trace/update update-trace {:c false})) (trace/get-choices (:trace updated)) ;; The `update` function returns the new trace, as well as a weight, which the @@ -724,9 +706,9 @@ ;; Doing an update can also cause some addresses to leave the choice map ;; altogether. For example, if we set `:a` to `false`, then choice at address -;; `:b` is no longer include in the choice map. +;; `:b` is no longer included in the choice map. -(def update-a-true (trace/update update-trace #gen/choice-map {:a false})) +(def update-a-true (trace/update update-trace {:a false})) (trace/get-choices (:trace update-a-true)) ;; The *discard* choice map that is returned by `update` contains the valus for diff --git a/src/data_readers.cljc b/src/data_readers.cljc index 92459e9..319dbe8 100644 --- a/src/data_readers.cljc +++ b/src/data_readers.cljc @@ -1,2 +1,2 @@ -{gen/choice gen.dynamic.choice-map/parse-choice - gen/choice-map gen.dynamic.choice-map/parse-choice-map} +{gen/choice gen.choicemap/parse-choice + gen/choicemap gen.choicemap/parse-choicemap} diff --git a/src/gen/choice_map.cljc b/src/gen/choice_map.cljc deleted file mode 100644 index cc9d30d..0000000 --- a/src/gen/choice_map.cljc +++ /dev/null @@ -1,30 +0,0 @@ -(ns gen.choice-map - "Protocols that constitute the choice map interface.") - -;; https://www.gen.dev/docs/stable/ref/choice_maps/#Choice-Maps-1 - -;; [x] Gen.get_value — Function. ; (get (values cm) k) -;; [x] Gen.has_value — Function. ; (contains? (values cm) k) -;; [x] Gen.get_submap — Function. ; (get (submaps cm) k) -;; [x] Gen.get_values_shallow — Function. ; (values cm) -;; [x] Gen.get_value — Function. (get (values cm) k) -;; [x] Gen.get_submaps_shallow — Function. ; (submaps cm) -;; [x] Gen.to_array — Function. ; (into [] (values cm)) -;; [ ] Gen.from_array — Function. -;; [ ] Gen.get_selected — Function. - -(defprotocol Value - :extend-via-metadata true - (value [cm] "Returns the value for ")) - -(defprotocol Values - :extend-via-metadata true - (values [cm] "Returns an associative data structure mapping keys to values.")) - -(defprotocol Submaps - :extend-via-metadata true - (submaps [cm] "Returns an associative data structure mapping keys to submaps.")) - -(defprotocol Leaf - :extend-via-metadata true - (leaf-value [cm] "Returns the value for")) diff --git a/src/gen/choicemap.cljc b/src/gen/choicemap.cljc index afacc24..95a672c 100644 --- a/src/gen/choicemap.cljc +++ b/src/gen/choicemap.cljc @@ -670,7 +670,7 @@ ;; ### ChoiceMap interactions -(defn- equiv +(defn ^:no-doc equiv "Returns true if `r` is a choicemap with equivalent submaps to `l`, false otherwise. diff --git a/src/gen/distribution.cljc b/src/gen/distribution.cljc index b4d0763..dcdeab8 100644 --- a/src/gen/distribution.cljc +++ b/src/gen/distribution.cljc @@ -1,9 +1,11 @@ (ns gen.distribution "Collection of protocols and functions for working with primitive distributions." - (:require [gen.dynamic.choice-map :as cm] + (:require [clojure.pprint :as pprint] + [gen.choicemap :as choicemap] + [gen.diff :as diff] [gen.generative-function :as gf] - [gen.dynamic.trace :as trace]) + [gen.trace :as trace]) #?(:clj (:import (clojure.lang IFn)))) @@ -27,6 +29,74 @@ (and (satisfies? LogPDF t) (satisfies? Sample t))) +;; ## Combinators +;; +;; The [[Encoded]] type creates a new distribution from a base distribution +;; `dist`. This new distribution transforms values on the way in to `logpdf` +;; using an `encode` function, and decodes sampled values via `decode`. +;; +;; This is useful for building distributions like categorical distributions that +;; might produce and score arbitrary Clojure values, but lean on some existing +;; numeric base implementation. + +(defrecord Encoded [dist encode decode] + LogPDF + (logpdf [_ v] + (logpdf dist (encode v))) + + Sample + (sample [_] + (decode (sample dist)))) + +(defn encoded + "Given a distribution-producing function `ctor`, returns a constructor for a new + distribution that + + - encodes each value `v` into `(encode v)` before passage to [[logpdf]] + - decodes each value `v` sampled from the base distribution into `(decode + v)`" + [ctor encode decode] + (comp #(->Encoded % encode decode) ctor)) + +;; ## Primitive Trace +;; +;; [[Trace]] above tracks map-like associations of address to traced value. At +;; the bottom of the tree represented by these associations is a primitive +;; trace, usually generated by a primitive probability distribution. +;; +;; [[Trace]] is a simplified version of [[Trace]] (and an +;; implementer of the [[gen.trace/ITrace]] interface) designed for a single +;; value. + +(defrecord Trace [gen-fn args val score] + trace/ITrace + (get-args [_] args) + (get-retval [_] val) + (get-choices [_] (choicemap/->Choice val)) + (get-gen-fn [_] gen-fn) + (get-score [_] score) + + trace/IUpdate + (-update [_ _ _ constraint] + (let [current (choicemap/->Choice val)] + (if (choicemap/has-value? constraint) + (-> (gf/generate gen-fn args constraint) + (update :weight - score) + (assoc :change diff/unknown-change + :discard current)) + (-> (gf/generate gen-fn args current) + (update :weight - score) + (assoc :change diff/no-change + :discard choicemap/EMPTY)))))) + +#?(:clj + (defmethod print-method Trace + [^Trace t ^java.io.Writer w] + (print-method (trace/trace->map t) w))) + +(defmethod pprint/simple-dispatch Trace [^Trace t] + (pprint/simple-dispatch (trace/trace->map t))) + ;; ## Primitive Generative Functions ;; The [[gen.distribution/GenerativeFn]] type wraps a constructor `ctor` (a @@ -39,22 +109,48 @@ ;; ;; This type provides support for all primitive distributions. -(defrecord GenerativeFn [ctor] +(defrecord GenerativeFn [ctor arity] gf/IGenerativeFunction + (has-argument-grads [_] (repeat arity false)) + + (accepts-output-grad? [_] false) + + (get-params [_] ()) + (simulate [this args] (let [dist (apply ctor args) val (sample dist) score (logpdf dist val)] - (trace/->PrimitiveTrace this args val score))) + (->Trace this args val score))) gf/IGenerate - (-generate [gf args constraint] - (assert (cm/choice? constraint)) - (let [dist (apply ctor args) - val (cm/unwrap constraint) + (-generate [this args constraint] + (if (= constraint choicemap/EMPTY) + (gf/generate this args) + (do (assert (choicemap/has-value? constraint)) + (let [val (choicemap/get-value constraint) + dist (apply ctor args) + weight (logpdf dist val)] + {:weight weight + :trace (->Trace this args val weight)})))) + + gf/IAssess + (-assess [_ args choice] + (assert (choicemap/has-value? choice)) + (let [val (choicemap/get-value choice) + dist (apply ctor args) weight (logpdf dist val)] {:weight weight - :trace (trace/->PrimitiveTrace gf args val weight)})) + :retval val})) + + gf/IPropose + (propose [_ args] + (let [dist (apply ctor args) + val (sample dist) + weight (logpdf dist val)] + {:choices (choicemap/->Choice val) + :weight weight + :retval val})) #?@(:clj [IFn @@ -151,32 +247,3 @@ (sample (ctor a b c d e f g h i j k l m n o p q r s t))) (-invoke [_ a b c d e f g h i j k l m n o p q r s t rest] (sample (apply ctor a b c d e f g h i j k l m n o p q r s t rest)))])) - -;; ## Combinators -;; -;; The [[Encoded]] type creates a new distribution from a base distribution -;; `dist`. This new distribution transforms values on the way in to `logpdf` -;; using an `encode` function, and decodes sampled values via `decode`. -;; -;; This is useful for building distributions like categorical distributions that -;; might produce and score arbitrary Clojure values, but lean on some existing -;; numeric base implementation. - -(defrecord Encoded [dist encode decode] - LogPDF - (logpdf [_ v] - (logpdf dist (encode v))) - - Sample - (sample [_] - (decode (sample dist)))) - -(defn encoded - "Given a distribution-producing function `ctor`, returns a constructor for a new - distribution that - - - encodes each value `v` into `(encode v)` before passage to [[logpdf]] - - decodes each value `v` sampled from the base distribution into `(decode - v)`" - [ctor encode decode] - (comp #(->Encoded % encode decode) ctor)) diff --git a/src/gen/distribution/commons_math.clj b/src/gen/distribution/commons_math.clj index 2808fef..4fc3b5b 100644 --- a/src/gen/distribution/commons_math.clj +++ b/src/gen/distribution/commons_math.clj @@ -127,30 +127,30 @@ ;; ## Primitive generative functions (def bernoulli - (d/->GenerativeFn bernoulli-distribution)) + (d/->GenerativeFn bernoulli-distribution 1)) (def beta - (d/->GenerativeFn beta-distribution)) + (d/->GenerativeFn beta-distribution 2)) (def gamma - (d/->GenerativeFn gamma-distribution)) + (d/->GenerativeFn gamma-distribution 2)) (def student-t - (d/->GenerativeFn student-t-distribution)) + (d/->GenerativeFn student-t-distribution 3)) (def normal - (d/->GenerativeFn normal-distribution)) + (d/->GenerativeFn normal-distribution 2)) (def uniform - (d/->GenerativeFn uniform-distribution)) + (d/->GenerativeFn uniform-distribution 2)) (def uniform-discrete "Sample an integer from the uniform distribution on the set `{low low+1 ... high-1 high}`." - (d/->GenerativeFn uniform-discrete-distribution)) + (d/->GenerativeFn uniform-discrete-distribution 2)) (def categorical "Given a sequence of probabilities probs where `(reduce + probs)` is 1, sample an integer `i` from the set #{1 2 ... (count probs)} with probability `(nth probs i)`." - (d/->GenerativeFn categorical-distribution)) + (d/->GenerativeFn categorical-distribution 1)) diff --git a/src/gen/distribution/java_util.clj b/src/gen/distribution/java_util.clj index 8291778..ed7e9cd 100644 --- a/src/gen/distribution/java_util.clj +++ b/src/gen/distribution/java_util.clj @@ -55,10 +55,10 @@ ;; ## Primitive generative functions (def bernoulli - (d/->GenerativeFn bernoulli-distribution)) + (d/->GenerativeFn bernoulli-distribution 1)) (def uniform - (d/->GenerativeFn uniform-distribution)) + (d/->GenerativeFn uniform-distribution 2)) (def normal - (d/->GenerativeFn normal-distribution)) + (d/->GenerativeFn normal-distribution 2)) diff --git a/src/gen/distribution/kixi.cljc b/src/gen/distribution/kixi.cljc index 7b46113..d31a79f 100644 --- a/src/gen/distribution/kixi.cljc +++ b/src/gen/distribution/kixi.cljc @@ -141,25 +141,25 @@ ;; ## Primitive generative functions (def bernoulli - (d/->GenerativeFn bernoulli-distribution)) + (d/->GenerativeFn bernoulli-distribution 1)) (def beta - (d/->GenerativeFn beta-distribution)) + (d/->GenerativeFn beta-distribution 2)) (def cauchy - (d/->GenerativeFn cauchy-distribution)) + (d/->GenerativeFn cauchy-distribution 2)) (def exponential - (d/->GenerativeFn exponential-distribution)) + (d/->GenerativeFn exponential-distribution 1)) (def uniform - (d/->GenerativeFn uniform-distribution)) + (d/->GenerativeFn uniform-distribution 2)) (def normal - (d/->GenerativeFn normal-distribution)) + (d/->GenerativeFn normal-distribution 2)) (def gamma - (d/->GenerativeFn gamma-distribution)) + (d/->GenerativeFn gamma-distribution 2)) (def student-t - (d/->GenerativeFn student-t-distribution)) + (d/->GenerativeFn student-t-distribution 3)) diff --git a/src/gen/dynamic.cljc b/src/gen/dynamic.cljc index a3caa2d..7658ab4 100644 --- a/src/gen/dynamic.cljc +++ b/src/gen/dynamic.cljc @@ -1,9 +1,13 @@ (ns gen.dynamic - (:require [clojure.walk :as walk] - [gen.choice-map :as choice-map] - [gen.dynamic.trace :as dynamic.trace] + (:require [clojure.pprint :as pprint] + [clojure.walk :as walk] + [gen.choicemap :as choicemap] + [gen.diff :as diff] [gen.generative-function :as gf] [gen.trace :as trace]) + #?(:clj + (:import (clojure.lang Associative IFn IObj IPersistentMap + MapEntry))) #?(:cljs (:require-macros [gen.dynamic :refer [untraced]]))) @@ -17,57 +21,343 @@ (throw (ex-info "Illegal usage of `splice!` out of `gen`." {}))) +;; ## trace impl + +(defn no-op + ([gf args] + (apply gf args)) + ([_k gf args] + (apply gf args))) + +(def ^:dynamic *trace* + "Applies the generative function gf to args. Dynamically rebound by functions + like `gf/simulate`, `gf/generate`, `trace/update`, etc." + no-op) + +(defn active-trace + "Returns the currently-active tracing function, bound to [[*trace*]]. + + NOTE: Prefer `([[active-trace]])` to `[[*trace*]]`, as direct access to + `[[*trace*]]` won't reflect new bindings when accessed inside of an SCI + environment." + [] *trace*) + +;; TODO move `trace!` to `gen`. + (defmacro untraced [& body] - `(binding [dynamic.trace/*trace* dynamic.trace/no-op] + `(binding [*trace* no-op] ~@body)) -(defrecord DynamicDSLFunction [clojure-fn] +;; ## Choice Map for address-like trace + +(defrecord Call [subtrace score noise]) + +(deftype ChoiceMap [m] + choicemap/IChoiceMap + (-has-value? [_] false) + (-get-value [_] nil) + (has-submap? [_ k] (contains? m k)) + (get-submap [this k] (.invoke ^IFn this k choicemap/EMPTY)) + + (get-values-shallow [_] + (persistent! + (reduce-kv + (fn [acc k v] + (let [m (trace/get-choices (:subtrace v))] + (if (choicemap/-has-value? m) + (assoc! acc k (choicemap/-get-value m)) + acc))) + (transient {}) + m))) + + (get-submaps-shallow [_] + (persistent! + (reduce-kv + (fn [acc k v] + (assoc! acc k (trace/get-choices (:subtrace v)))) + (transient {}) + m))) + + #?@(:clj + [Object + (equals [this that] (choicemap/equiv this that)) + (toString [this] (pr-str this)) + + IFn + (invoke [this k] (.invoke ^IFn this k nil)) + (invoke [_ k not-found] + (if-let [v (get m k)] + (trace/get-choices (:subtrace v)) + not-found)) + + IObj + (meta [_] (meta m)) + (withMeta [_ meta-m] + (ChoiceMap. + (with-meta m meta-m))) + + IPersistentMap + (assocEx [_ _ _] (throw (Exception.))) + (assoc [_ _ _] + (throw + (ex-info "ChoiceMap instances are read-only." {}))) + (without [m k] + (ChoiceMap. (dissoc m k))) + + Associative + (containsKey [_ k] (contains? m k)) + (entryAt [this k] + (when (contains? m k) + (MapEntry/create k (.invoke ^IFn this k nil)))) + (cons [_ _] + (throw + (ex-info "ChoiceMap instances are read-only." {}))) + + (count [_] (count m)) + (seq [_] + (when-let [kvs (seq m)] + (map (fn [[k v]] + (MapEntry/create k (trace/get-choices (:subtrace v)))) + kvs))) + + (empty [_] choicemap/EMPTY) + (valAt [this k] (.invoke ^IFn this k nil)) + (valAt [this k not-found] (.invoke ^IFn this k not-found)) + (equiv [this that] (choicemap/equiv this that)) + + Iterable + (iterator [this] + (.iterator ^Iterable (choicemap/get-submaps-shallow this)))] + + :cljs + [Object + (toString [_] (pr-str m)) + (equiv [this that] (choicemap/equiv this that)) + + IPrintWithWriter + (-pr-writer [_ writer opts] + (-pr-writer m writer opts)) + + IFn + (-invoke [this k] (-invoke this k nil)) + (-invoke [_ k not-found] + (if-let [v (get m k)] + (trace/get-choices (:subtrace v)) + not-found)) + + IMeta + (-meta [_] (-meta m)) + + IWithMeta + (-with-meta [_ meta-m] + (ChoiceMap. + (-with-meta m meta-m))) + + IEmptyableCollection + (-empty [_] choicemap/EMPTY) + + IEquiv + (-equiv [this that] (choicemap/equiv this that)) + + ISeqable + (-seq [_] (-seq m)) + + ICounted + (-count [_] (-count m)) + + ILookup + (-lookup [_ k] (-invoke m k nil)) + (-lookup [_ k not-found] (-invoke m k not-found)) + + IAssociative + (-assoc [_ _ _] + (throw + (ex-info "ChoiceMap instances are read-only." {}))) + (-contains-key? [_ k] (-contains-key? m k)) + + IMap + (-dissoc [_ k] + (ChoiceMap. + (dissoc m k)))])) + +#?(:clj + (defmethod print-method ChoiceMap + [^ChoiceMap cm ^java.io.Writer w] + (-> (choicemap/get-submaps-shallow cm) + (print-method w)))) + +(defmethod pprint/simple-dispatch ChoiceMap [cm] + (pprint/simple-dispatch + (choicemap/get-submaps-shallow cm))) + +(deftype Trace [gen-fn trie score noise args retval] + trace/ITrace + (get-args [_] args) + (get-retval [_] retval) + (get-gen-fn [_] gen-fn) + (get-choices [_] (->ChoiceMap trie)) + (get-score [_] score)) + +#?(:clj + (defmethod print-method Trace + [^Trace t ^java.io.Writer w] + (print-method (trace/trace->map t) w))) + +(defmethod pprint/simple-dispatch Trace [^Trace t] + (pprint/simple-dispatch (trace/trace->map t))) + +(defn trace + "Returns a new bare trace. + + TODO pad args with defaults if available." + [gen-fn args] + (Trace. gen-fn {} 0.0 0.0 args nil)) + +(defn validate-empty! + [^Trace trace addr] + (when (contains? (.-trie trace) addr) + (throw + (ex-info + "Subtrace already present at address. The same address cannot be reused + for multiple random choices." + {:addr addr})))) + +(defn with-retval [^Trace trace retval] + (Trace. (.-gen-fn trace) + (.-trie trace) + (.-score trace) + (.-noise trace) + (.-args trace) + retval)) + +(defn add-call + "TODO handle noise." + [^Trace trace k subtrace] + (validate-empty! trace k) + (let [trie (.-trie trace) + score (trace/get-score subtrace) + noise 0.0 #_ (trace/project subtrace nil) + call (->Call subtrace score noise)] + (Trace. (.-gen-fn trace) + (assoc trie k call) + (+ (.-score trace) score) + (+ (.-noise trace) noise) + (.-args trace) + (.-retval trace)))) + +(defn ^:no-doc trace:= [^Trace this that] + (and (instance? Trace that) + (let [^Trace that that] + (and (= (.-gen-fn this) (.-gen-fn that)) + (= (.-trie this) (.-trie that)) + (= (.-score this) (.-score that)) + (= (.-noise this) (.-noise that)) + (= (.-args this) (.-args that)) + (= (.-retval this) (.-retval that)))))) + +;; ## Update State +(defn ^:no-doc combine + "Combine trace update states. careful not to add " + [v k {:keys [trace weight discard]}] + {:trace (add-call (:trace v) k trace) + :weight (+ (:weight v) weight) + :discard (if (choicemap/empty? discard) + (:discard v) + (assoc (:discard v) k discard))}) + +;; ## Update impl +;; +;; TODO figure out what these notes mean!! + +;; TODO this feels weird that we need something like this... +;; +;; TODO can we add exec to the protocol? NO but we can do `exec` if we move all +;; this nonsense into `gen.dynamic`... that would work! + +(defn ^:no-doc extract-unvisited [^Trace prev-trace new-trace] + (let [visited-m (choicemap/get-submaps-shallow + (trace/get-choices new-trace)) + unvisited-trie (apply dissoc + (.-trie prev-trace) + (keys visited-m)) + to-subtract (reduce-kv (fn [acc _ v] (+ acc (:score v))) + 0.0 + unvisited-trie)] + + [to-subtract (->ChoiceMap unvisited-trie)])) + +(defn assert-all-visited! [^Trace trace constraints] + (when-let [unvisited (keys + (apply dissoc + (choicemap/get-submaps-shallow constraints) + (keys (.-trie trace))))] + (throw (ex-info "Some constraints weren't visited: " + {:unvisited unvisited})))) + +(declare apply-inner) + +(extend-type Trace + trace/IUpdate + (-update [this args _ constraints] + (let [gen-fn (trace/get-gen-fn this) + state (atom {:trace (trace gen-fn args) + :weight 0.0 + :discard (choicemap/choicemap)})] + (binding + [*trace* + (fn + ([gf args] + (apply-inner gf args)) + ([k gen-fn args] + (validate-empty! (:trace @state) k) + + (let [k-constraints (choicemap/get-submap constraints k) + new-state + ;; TODO this is a spot where we'll need to check the + ;; previous value. + (if-let [prev-subtrace (:subtrace + (get (.-trie this) k))] + (do + (assert + (= gen-fn (trace/get-gen-fn prev-subtrace)) + (str "Generative function changed at address " k ".")) + (trace/update prev-subtrace k-constraints)) + (gf/generate gen-fn args k-constraints))] + (swap! state combine k new-state) + (trace/get-retval + (:trace new-state)))))] + (let [retval (apply-inner gen-fn args) + {:keys [trace weight discard]} @state + [to-subtract unvisited] (extract-unvisited this trace)] + (assert-all-visited! trace constraints) + {:trace (with-retval trace retval) + :change diff/unknown-change + :weight (- weight to-subtract) + :discard (choicemap/merge discard unvisited)}))))) + +;; so we are going to remove the score of the unvisited stuff as we go up. Does +;; that work? + +(defrecord DynamicDSLFunction [clojure-fn has-argument-grads accepts-output-grad?] gf/IGenerativeFunction + (has-argument-grads [_] has-argument-grads) + (accepts-output-grad? [_] accepts-output-grad?) + (get-params [_] ()) (simulate [gf args] - (let [trace (atom (dynamic.trace/trace gf args))] - (binding [dynamic.trace/*splice* - (fn [gf args] - (let [subtrace (gf/simulate gf args)] - (swap! trace dynamic.trace/merge-subtraces subtrace) - (trace/get-retval subtrace))) - - dynamic.trace/*trace* - (fn [k gf args] - (dynamic.trace/validate-empty! @trace k) - (let [subtrace (gf/simulate gf args)] - (swap! trace dynamic.trace/assoc-subtrace k subtrace) - (trace/get-retval subtrace)))] - (let [retval (apply clojure-fn args)] - (swap! trace dynamic.trace/with-retval retval) - @trace)))) - - gf/IGenerate - (-generate [gf args constraints] - (let [state (atom {:trace (dynamic.trace/trace gf args) - :weight 0})] - (binding [dynamic.trace/*splice* - (fn [gf args] - (let [{subtrace :trace - weight :weight} - (gf/generate gf args constraints)] - (swap! state update :trace dynamic.trace/merge-subtraces subtrace) - (swap! state update :weight + weight) - (trace/get-retval subtrace))) - - dynamic.trace/*trace* - (fn [k gf args] - (dynamic.trace/validate-empty! (:trace @state) k) - (let [{subtrace :trace :as ret} - (if-let [k-constraints (get (choice-map/submaps constraints) k)] - (gf/generate gf args k-constraints) - (gf/generate gf args))] - (swap! state dynamic.trace/combine k ret) - (trace/get-retval subtrace)))] + (let [!trace (atom (trace gf args))] + (binding [*trace* + (fn + ([gf args] + (apply-inner gf args)) + ([k gf args] + (validate-empty! @!trace k) + (let [subtrace (gf/simulate gf args)] + (swap! !trace add-call k subtrace) + (trace/get-retval subtrace))))] (let [retval (apply clojure-fn args) - trace (:trace @state)] - {:trace (dynamic.trace/with-retval trace retval) - :weight (:weight @state)})))) + trace @!trace] + (with-retval trace retval))))) #?@(:clj [clojure.lang.IFn @@ -166,6 +456,9 @@ (-invoke [_ x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11 x12 x13 x14 x15 x16 x17 x18 x19 x20 xs] (untraced (apply clojure-fn x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11 x12 x13 x14 x15 x16 x17 x18 x19 x20 xs)))])) +(defn ^:no-doc apply-inner [^DynamicDSLFunction gf args] + (apply (.-clojure-fn gf) args)) + ;; The following two functions use a brittle form of macro-rewriting; we should ;; really look at the namespace and local macro environments to try and see if a ;; particular symbol is bound to `#'gen.dynamic/{trace!,splice!}`. See @@ -186,25 +479,94 @@ (defn ^:no-doc gen-body [& xs] (let [name (when (simple-symbol? (first xs)) (first xs)) - [params & body] (if name (rest xs) xs)] - `(->DynamicDSLFunction - (fn ~@(when name [name]) - ~params - ~@(walk/postwalk - (fn [form] - (cond (trace-form? form) - (let [[addr gf & xs] (rest form)] - `((dynamic.trace/active-trace) ~addr ~gf [~@xs])) - - (splice-form? form) - (let [[gf & xs] (rest form)] - `((dynamic.trace/active-splice) ~gf [~@xs])) - - :else form)) - body))))) + [params & body] (if name (rest xs) xs) + has-arg-grads (mapv (constantly false) params) + accepts-output-grad? false] + `(-> (fn ~@(when name [name]) + ~params + ~@(walk/postwalk + (fn [form] + (cond (trace-form? form) + (let [[addr gf & xs] (rest form)] + `((active-trace) ~addr ~gf [~@xs])) + + (splice-form? form) + (let [[gf & xs] (rest form)] + `((active-trace) ~gf [~@xs])) + + :else form)) + body)) + (->DynamicDSLFunction ~has-arg-grads ~accepts-output-grad?)))) (defmacro gen "Defines a generative function." [& args] {:clj-kondo/lint-as 'clojure.core/fn} (apply gen-body args)) + +;; ## Generate impl + +(defn ^:no-doc assoc-state + "combine by adding weights?" + [state k {:keys [trace weight]}] + (-> state + (update :trace add-call k trace) + (update :weight + weight))) + +;; TODO figure out visited / unvisited?? + +(extend-type DynamicDSLFunction + gf/IGenerate + (-generate [gf args constraints] + (let [trace (trace gf args) + !state (atom {:trace trace :weight 0.0})] + (binding [*trace* + (fn + ([gf args] + (apply-inner gf args)) + ([k gf args] + (validate-empty! (:trace @!state) k) + (let [{subtrace :trace :as ret} + (let [k-constraints (choicemap/get-submap constraints k)] + (gf/generate gf args k-constraints))] + (swap! !state assoc-state k ret) + (trace/get-retval subtrace))))] + (let [retval (apply-inner gf args) + state @!state] + (update state :trace with-retval retval))))) + + gf/IAssess + (-assess [gf args choices] + (let [!weight (atom 0.0)] + (binding [*trace* + (fn + ([gf args] + (apply-inner gf args)) + ([k gf args] + (let [{:keys [weight retval]} + (let [k-choices (choicemap/get-submap choices k)] + (gf/assess gf args k-choices))] + (swap! !weight + weight) + retval)))] + (let [retval (apply-inner gf args)] + {:weight @!weight + :retval retval})))) + + gf/IPropose + (propose [gf args] + (let [!state (atom {:choices (choicemap/choicemap) + :weight 0.0})] + (binding [*trace* + (fn + ([gf args] + (apply-inner gf args)) + ([k gf args] + (let [{:keys [submap weight retval]} (gf/propose gf args)] + (swap! !state + (fn [m] + (-> m + (update :choices assoc k submap) + (update :weight + weight)))) + retval)))] + (let [retval (apply-inner gf args)] + (assoc @!state :retval retval)))))) diff --git a/src/gen/dynamic/choice_map.cljc b/src/gen/dynamic/choice_map.cljc deleted file mode 100644 index 9e197ce..0000000 --- a/src/gen/dynamic/choice_map.cljc +++ /dev/null @@ -1,230 +0,0 @@ -(ns gen.dynamic.choice-map - (:require [gen.choice-map :as choice-map]) - #?(:clj - (:import (clojure.lang Associative IFn IObj IPersistentMap - IMapIterable MapEntry)))) - -;; https://blog.wsscode.com/guide-to-custom-map-types/ -;; https://github.com/originrose/lazy-map/blob/119dda207fef90c1e26e6c01aa63e6cfb45c1fa8/src/lazy_map/core.clj#L197-L278 - -(defrecord Choice [choice] - choice-map/Value - (value [_] choice)) - -#?(:clj - (defmethod print-method Choice [choice ^java.io.Writer w] - (.write w "#gen/choice ") - (.write w (pr-str (choice-map/value choice))))) - -(defn choice? - "Returns `true` if `x` is an instance of `Choice`." - [x] - (instance? Choice x)) - -(defn choice - "Creates a new leaf chioce map with `x` as its value." - [x] - (if (instance? Choice x) - x - (->Choice x))) - -(declare choice-map choice-map? unwrap) - -(defn auto-get-choice - [x] - (if (instance? Choice x) - (choice-map/value x) - x)) - -(deftype ChoiceMap [m] - choice-map/Submaps - (submaps [_] m) - - #?@(:cljs - [Object - (toString [this] (pr-str this)) - (equiv [this other] (-equiv this other)) - - IPrintWithWriter - (-pr-writer [cm writer _] - (write-all - writer - "#gen/choice-map " - (str (unwrap cm)))) - - IFn - (-invoke [_ k] (auto-get-choice (get m k))) - - IMeta - (-meta [_] (-meta m)) - - IWithMeta - (-with-meta [_ meta-m] (ChoiceMap. (-with-meta m meta-m))) - - - ICloneable - (-clone [_] (ChoiceMap. (-clone m))) - - IIterable - (-iterator [_] (-iterator m)) - - ICollection - (-conj [_ entry] - (if (vector? entry) - (ChoiceMap. - (-assoc m (-nth entry 0) (choice (-nth entry 1)))) - (ChoiceMap. - (reduce-kv (fn [acc k v] - (assoc acc k (choice v))) - m - entry)))) - - IEmptyableCollection - (-empty [_] (ChoiceMap. (-empty m))) - - IEquiv - (-equiv [_ o] (and (instance? ChoiceMap o) (= m (.-m ^ChoiceMap o)))) - - IHash - (-hash [_] (-hash m)) - - ISeqable - (-seq [_] - (when-let [kvs (seq m)] - (map (fn [[k v]] - (MapEntry. k (auto-get-choice v) nil)) - kvs))) - - ICounted - (-count [_] (-count m)) - - ILookup - (-lookup [_ k] (auto-get-choice (-lookup m k))) - (-lookup [_ k not-found] - (let [v (-lookup m k ::not-found)] - (if (= v ::not-found) - not-found - (auto-get-choice v)))) - - IAssociative - (-assoc [_ k v] (ChoiceMap. (-assoc m k (choice v)))) - (-contains-key? [_ k] (-contains-key? m k)) - - IFind - (-find [_ k] - (when-let [v (-find m k)] - (MapEntry. (-key v) (auto-get-choice (-val v)) nil))) - - IMap - (-dissoc [_ k] (ChoiceMap. (-dissoc m k))) - - IKVReduce - (-kv-reduce [_ f init] - (-kv-reduce m - (fn [acc k v] - (f acc k (auto-get-choice v))) - init))] - - :clj - [Object - (equals [_ o] (and (instance? ChoiceMap o) (= m (.-m ^ChoiceMap o)))) - (toString [this] (pr-str this)) - - IFn - (invoke [this k] (.valAt this k)) - (invoke [this k not-found] (.valAt this k not-found)) - - IObj - (meta [_] (meta m)) - (withMeta [_ meta-m] (ChoiceMap. (with-meta m meta-m))) - - IPersistentMap - (assocEx [_ _ _] (throw (Exception.))) - (assoc [_ k v] - (ChoiceMap. (.assoc ^IPersistentMap m k (choice v)))) - (without [_ k] - (ChoiceMap. (.without ^IPersistentMap m k))) - - Associative - (containsKey [_ k] (contains? m k)) - (entryAt [_ k] - (when (contains? m k) - (MapEntry/create k (auto-get-choice (get m k))))) - (cons [this o] - (if (map? o) - (reduce-kv assoc this o) - (let [[k v] o] - (ChoiceMap. (assoc m k (choice v)))))) - (count [_] (count m)) - (seq [_] - (when-let [kvs (seq m)] - (map (fn [[k v]] - (MapEntry/create k (auto-get-choice v))) - kvs))) - (empty [_] (ChoiceMap. (empty m))) - (valAt [_ k] - (auto-get-choice (get m k))) - (valAt [_ k not-found] - (auto-get-choice (get m k not-found))) - (equiv [_ o] - (and (instance? ChoiceMap o) (= m (.-m ^ChoiceMap o)))) - - IMapIterable - (keyIterator [_] - (.iterator ^Iterable (keys m))) - (valIterator [_] - (.iterator ^Iterable (map auto-get-choice m))) - - Iterable - (iterator [this] - (if-let [xs (.seq this)] - (.iterator ^Iterable xs) - (.iterator {})))])) - -(defn unwrap - "If `m` is a [[Choice]] or [[ChoiceMap]], returns `m` stripped of its wrappers. - Else, returns `m`" - [m] - (cond (choice? m) (:choice m) - (map? m) (update-vals m unwrap) - :else m)) - -#?(:clj - (defmethod print-method ChoiceMap [^ChoiceMap cm ^java.io.Writer w] - (.write w "#gen/choice-map ") - (print-method (unwrap cm) w))) - -(defn choice-map - [& {:as m}] - (->ChoiceMap - (update-vals m (fn [x] - (cond (or (instance? Choice x) - (instance? ChoiceMap x)) - x - - (map? x) - (choice-map x) - - :else - (Choice. x)))))) - -(defn choice-map? [x] - (instance? ChoiceMap x)) - -;; ## Reader literals - -(defn ^:no-doc parse-choice - "Implementation of a reader literal that turns literal forms into calls - to [[choice]]. - - Installed by default under `#gen/choice`." - [form] - `(choice ~form)) - -(defn ^:no-doc parse-choice-map - "Implementation of a reader literal that turns literal map forms into calls - to [[choice-map]]. - - Installed by default under `#gen/choice-map`." - [form] - `(choice-map ~form)) diff --git a/src/gen/dynamic/trace.cljc b/src/gen/dynamic/trace.cljc deleted file mode 100644 index f590e8f..0000000 --- a/src/gen/dynamic/trace.cljc +++ /dev/null @@ -1,276 +0,0 @@ -(ns gen.dynamic.trace - (:refer-clojure :exclude [=]) - (:require [clojure.core :as core] - [gen.choice-map :as choice-map] - [gen.diff :as diff] - [gen.dynamic.choice-map :as cm] - [gen.generative-function :as gf] - [gen.trace :as trace]) - #?(:cljs - (:require-macros [gen.dynamic.trace])) - #?(:clj - (:import - (clojure.lang Associative IFn IObj IMapIterable)))) - -(defn no-op - ([gf args] - (apply gf args)) - ([_k gf args] - (apply gf args))) - -(def ^:dynamic *trace* - "Applies the generative function gf to args. Dynamically rebound by functions - like `gf/simulate`, `gf/generate`, `trace/update`, etc." - no-op) - -(def ^:dynamic *splice* - "Applies the generative function gf to args. Dynamically rebound by functions - like `gf/simulate`, `gf/generate`, `trace/update`, etc." - no-op) - -(defn active-trace - "Returns the currently-active tracing function, bound to [[*trace*]]. - - NOTE: Prefer `([[active-trace]])` to `[[*trace*]]`, as direct access to - `[[*trace*]]` won't reflect new bindings when accessed inside of an SCI - environment." - [] *trace*) - -(defn active-splice - "Returns the currently-active tracing function, bound to [[*splice*]]. - - NOTE: Prefer `([[active-splice]])` to `[[*splice*]]`, as direct access to - `[[*splice*]]` won't reflect new bindings when accessed inside of an SCI - environment." - [] - *splice*) - -(declare assoc-subtrace update-trace trace =) - -(deftype Trace [gf args subtraces retval] - trace/ITrace - (get-args [_] args) - (get-choices [_] - (cm/->ChoiceMap (update-vals subtraces trace/get-choices))) - (get-gen-fn [_] gf) - (get-retval [_] retval) - (get-score [_] - ;; TODO Handle untraced randomness. - (let [v (vals subtraces)] - (transduce (map trace/get-score) + 0.0 v))) - - trace/IUpdate - (-update [this _ _ constraints] - (update-trace this constraints)) - - #?@(:cljs - [Object - (equiv [this other] (-equiv this other)) - - IFn - (-invoke [this k] (-lookup this k)) - (-invoke [this k not-found] (-lookup this k not-found)) - - IMeta - (-meta [_] (meta subtraces)) - - IWithMeta - (-with-meta [_ m] (Trace. gf args (with-meta subtraces m) retval)) - - - ;; ICloneable - ;; (-clone [_] (Trace. (-clone m))) - - IIterable - (-iterator [this] (-iterator (trace/get-choices this))) - - ;; ICollection - ;; (-conj [_ entry]) - - ;; IEmptyableCollection - ;; (-empty [_]) - - IEquiv - (-equiv [this that] (= this that)) - - ;; IHash - ;; (-hash [_] (-hash m)) - - ISeqable - (-seq [this] (-seq (trace/get-choices this))) - - ICounted - (-count [_] (-count subtraces)) - - ILookup - (-lookup [this k] - (-lookup (trace/get-choices this) k)) - (-lookup [this k not-found] - (-lookup (trace/get-choices this) k not-found)) - - IAssociative - ;; (-assoc [_ k v] (Trace. (-assoc m k (choice v)))) - (-contains-key? [_ k] (-contains-key? subtraces k)) - - IFind - (-find [this k] - (-find (trace/get-choices this) k))] - - :clj - [Object - (equals [this that] (= this that)) - - IFn - (invoke [this k] (.valAt this k)) - (invoke [this k not-found] (.valAt this k not-found)) - - IObj - (meta [_] (meta subtraces)) - (withMeta [_ m] (Trace. gf args (with-meta subtraces m) retval)) - - Associative - (containsKey [_ k] (contains? subtraces k)) - (entryAt [_ k] (.entryAt ^Associative subtraces k)) - (count [_] (count subtraces)) - (seq [this] (seq (trace/get-choices this))) - (valAt [this k] - (get (trace/get-choices this) k)) - (valAt [this k not-found] - (get (trace/get-choices this) k not-found)) - (equiv [this that] (= this that)) - ;; TODO missing `cons`, `empty`? - - IMapIterable - (keyIterator [this] - (.iterator ^Iterable (keys (trace/get-choices this)))) - (valIterator [this] - (.iterator ^Iterable (vals (trace/get-choices this)))) - - Iterable - (iterator [this] - (.iterator ^Iterable (trace/get-choices this)))])) - -(defn ^:no-doc = [^Trace this that] - (and (instance? Trace that) - (let [^Trace that that] - (and (core/= (.-gf this) (.-gf that)) - (core/= (.-args this) (.-args that)) - (core/= (.-subtraces this) (.-subtraces that)) - (core/= (.-retval this) (.-retval that)))))) - -(defn trace - [gf args] - (->Trace gf args {} nil)) - -(defn with-retval [^Trace t v] - (->Trace (.-gf t) (.-args t) (.-subtraces t) v)) - -(defn validate-empty! [t addr] - (when (contains? t addr) - (throw (ex-info "Value or subtrace already present at address. The same - address cannot be reused for multiple random choices." - {:addr addr})))) - -(defn assoc-subtrace - [^Trace t addr subt] - (validate-empty! t addr) - (->Trace (.-gf t) - (.-args t) - (assoc (.-subtraces t) addr subt) - (.-retval t))) - -(defn merge-subtraces - [^Trace t1 ^Trace t2] - (reduce-kv assoc-subtrace - t1 - (.-subtraces t2))) - -(defn ^:no-doc combine - "combine by adding weights?" - [v k {:keys [trace weight discard]}] - (-> v - (update :trace assoc-subtrace k trace) - (update :weight + weight) - (cond-> discard (update :discard assoc k discard)))) - -(defn update-trace [^Trace this constraints] - (let [gf (trace/get-gen-fn this) - state (atom {:trace (trace gf (trace/get-args this)) - :weight 0 - :discard (cm/choice-map)})] - (binding [*splice* - (fn [& _] - (throw (ex-info "Not yet implemented." {}))) - - *trace* - (fn [k gf args] - (validate-empty! (:trace @state) k) - (let [k-constraints (get (choice-map/submaps constraints) k) - {subtrace :trace :as ret} - (if-let [prev-subtrace (get (.-subtraces this) k)] - (trace/update prev-subtrace k-constraints) - (gf/generate gf args k-constraints))] - (swap! state combine k ret) - (trace/get-retval subtrace)))] - (let [retval (apply (:clojure-fn gf) (trace/get-args this)) - {:keys [trace weight discard]} @state - unvisited (apply dissoc - (trace/get-choices this) - (keys (trace/get-choices trace)))] - - {:trace (with-retval trace retval) - :weight weight - :discard (merge discard unvisited)})))) - -;; ## Primitive Trace -;; -;; [[Trace]] above tracks map-like associations of address to traced value. At -;; the bottom of the tree represented by these associations is a primitive -;; trace, usually generated by a primitive probability distribution. -;; -;; [[PrimitiveTrace]] is a simplified version of [[Trace]] (and an implementer -;; of the [[gen.trace]] interface) designed for a single value. - -(declare update-primitive) - -(defrecord PrimitiveTrace [gf args val score] - trace/ITrace - (get-args [_] args) - (get-choices [_] (cm/choice val)) - (get-retval [_] val) - (get-gen-fn [_] gf) - (get-score [_] score) - - trace/IUpdate - (-update [trace _ _ constraint] - (update-primitive trace constraint))) - -(defn ^:no-doc update-primitive - "Accepts a [[PrimitiveTrace]] instance `t` and a - single [[gen.dynamic.choice-map/Choice]] and returns a new object with keys - `:trace`, `:weight` and `:change`." - [t constraint] - {:pre [(instance? PrimitiveTrace t)]} - (cond (cm/choice? constraint) - (-> (trace/get-gen-fn t) - (gf/generate (trace/get-args t) constraint) - (update :weight - (trace/get-score t)) - (core/assoc :change diff/unknown-change - :discard (trace/get-choices t))) - - (nil? constraint) - {:trace t - :weight 0.0 - :change diff/unknown-change} - - (map? constraint) - (throw - (ex-info - "Expected a value at address but found a sub-assignment." - {:sub-assignment constraint})) - - :else - (throw - (ex-info - "non-nil, non-Choice constraint not allowed." - {:sub-assignment constraint})))) diff --git a/src/gen/generative_function.cljc b/src/gen/generative_function.cljc index f3487ba..e025e97 100644 --- a/src/gen/generative_function.cljc +++ b/src/gen/generative_function.cljc @@ -109,8 +109,7 @@ {:trace trace :weight 0.0})) ([gf args constraints] - ;; TODO re-enable after full dynamic conversion. - (let [constraints constraints #_(choicemap/choicemap constraints)] + (let [constraints (choicemap/choicemap constraints)] (-generate gf args constraints)))) (defn ^:no-doc default-propose diff --git a/src/gen/sci.cljc b/src/gen/sci.cljc index 47db93f..c1b808b 100644 --- a/src/gen/sci.cljc +++ b/src/gen/sci.cljc @@ -1,6 +1,7 @@ (ns gen.sci "Functions for installation of all namespaces into an SCI context." - (:require [gen.choice-map] + (:require [gen.array] + [gen.choicemap] [gen.clerk.callout] [gen.clerk.viewer] [gen.diff] @@ -8,30 +9,33 @@ [gen.distribution.kixi] [gen.distribution.math.log-likelihood] [gen.dynamic] - [gen.dynamic.choice-map] - [gen.dynamic.trace] [gen.generative-function] [gen.inference.importance] [gen.trace] [sci.core :as sci] [sci.ctx-store])) -(def gen-macro ^:sci/macro +(def ^:no-doc gen-macro ^:sci/macro (fn [_&form _&env & args] (apply gen.dynamic/gen-body args))) +(def ^:no-doc untraced-macro ^:sci/macro + (fn [_&form _&env & body] + `(binding [gen.dynamic/*trace* no-op] + ~@body))) + (def namespaces - {'gen.clerk.callout (sci/copy-ns gen.clerk.callout (sci/create-ns 'gen.clerk.callout)) + {'gen.array (sci/copy-ns gen.array (sci/create-ns 'gen.array)) + 'gen.choicemap (sci/copy-ns gen.choicemap (sci/create-ns 'gen.choicemap)) + 'gen.clerk.callout (sci/copy-ns gen.clerk.callout (sci/create-ns 'gen.clerk.callout)) 'gen.clerk.viewer (sci/copy-ns gen.clerk.viewer (sci/create-ns 'gen.clerk.viewer)) - 'gen.choice-map (sci/copy-ns gen.choice-map (sci/create-ns 'gen.choice-map)) 'gen.diff (sci/copy-ns gen.diff (sci/create-ns 'gen.diff)) 'gen.distribution (sci/copy-ns gen.distribution (sci/create-ns 'gen.distribution)) 'gen.distribution.kixi (sci/copy-ns gen.distribution.kixi (sci/create-ns 'gen.distribution.kixi)) 'gen.distribution.math.log-likelihood (sci/copy-ns gen.distribution.math.log-likelihood (sci/create-ns 'gen.distribution.math.log-likelihood)) 'gen.dynamic (-> (sci/copy-ns gen.dynamic (sci/create-ns 'gen.dynamic)) - (assoc 'gen gen-macro)) - 'gen.dynamic.choice-map (sci/copy-ns gen.dynamic.choice-map (sci/create-ns 'gen.dynamic.choice-map)) - 'gen.dynamic.trace (sci/copy-ns gen.dynamic.trace (sci/create-ns 'gen.dynamic.trace)) + (assoc 'gen gen-macro + 'untraced untraced-macro)) 'gen.generative-function (sci/copy-ns gen.generative-function (sci/create-ns 'gen.generative-function)) 'gen.inference.importance (sci/copy-ns gen.inference.importance (sci/create-ns 'gen.inference.importance)) 'gen.trace (sci/copy-ns gen.trace (sci/create-ns 'gen.trace))}) diff --git a/src/gen/trace.cljc b/src/gen/trace.cljc index ba07297..46e899e 100644 --- a/src/gen/trace.cljc +++ b/src/gen/trace.cljc @@ -1,7 +1,7 @@ (ns gen.trace "Defines the [[ITrace]] abstraction and its API." (:refer-clojure :exclude [update]) - (:require #_[gen.choicemap :as choicemap] + (:require [gen.choicemap :as choicemap] [gen.diff :as diff])) ;; ## ITrace @@ -223,12 +223,10 @@ ([trace constraints] (let [args (get-args trace) diffs (repeat (count args) diff/no-change) - ;; TODO re-enable after full dynamic conversion. - constraints constraints #_(choicemap/choicemap constraints)] + constraints (choicemap/choicemap constraints)] (-update trace args diffs constraints))) ([trace args argdiffs constraints] - ;; TODO re-enable after full dynamic conversion. - (let [constraints constraints #_(choicemap/choicemap constraints)] + (let [constraints (choicemap/choicemap constraints)] (-update trace args argdiffs constraints)))) (defn trace->map diff --git a/test/gen/distribution_test.cljc b/test/gen/distribution_test.cljc index f90a3e7..280ffd6 100644 --- a/test/gen/distribution_test.cljc +++ b/test/gen/distribution_test.cljc @@ -2,12 +2,12 @@ (:require [com.gfredericks.test.chuck.clojure-test :refer [checking]] [clojure.test :refer [is testing]] [clojure.test.check.generators :as gen] + [gen.choicemap :as choicemap] [gen.diff :as diff] [gen.distribution :as dist] - [gen.dynamic.choice-map :as choice-map] [gen.generative-function :as gf] - [gen.trace :as trace] [gen.generators :refer [gen-double within]] + [gen.trace :as trace] [same.core :refer [ish? zeroish? with-comparator]])) (defn gamma-tests [->gamma] @@ -33,6 +33,19 @@ (is (= -5.992380837839856 (dist/logpdf (->beta 0.001 1) 0.4))) (is (= -6.397440480839912 (dist/logpdf (->beta 1 0.001) 0.4))))) +(defn primitive-gfi-tests [gf args] + (let [trace (gf/simulate gf args)] + (is (= gf (trace/get-gen-fn trace)) + "distribution round trips through the trace ") + + (is (= args (trace/get-args trace)) + "distribution round trips through the trace ") + + (let [choice (trace/get-choices trace)] + (is (= (trace/get-retval trace) + (choicemap/get-value choice)) + "primitive distributions return a single choice.")))) + (defn bernoulli-tests [->bernoulli] (checking "Bernoulli properties" [p (gen-double 0 1) @@ -66,25 +79,30 @@ "prob of `0` matches `1-p`")))) (defn bernoulli-gfi-tests [bernoulli-dist] + (primitive-gfi-tests bernoulli-dist [0.5]) + + (checking "bernoulli dist has proper logpdf" [p (gen-double 0 1)] + (let [trace (gf/simulate bernoulli-dist [p])] + (is (ish? (if (trace/get-retval trace) + p + (- 1 p)) + (Math/exp + (trace/get-score trace)))))) + (testing "bernoulli-call-no-args" (is (boolean? (bernoulli-dist)))) (testing "bernoulli-call-args" (is (boolean? (bernoulli-dist 0.5)))) - (testing "bernoulli-gf" - (is (= bernoulli-dist (trace/get-gen-fn (gf/simulate bernoulli-dist []))))) - - (testing "bernoulli-args" - (is (= [0.5] (trace/get-args (gf/simulate bernoulli-dist [0.5]))))) - (testing "bernoulli-retval" (is (boolean? (trace/get-retval (gf/simulate bernoulli-dist [0.5]))))) (testing "bernoulli-choices-noargs" (is (boolean? - (choice-map/unwrap - (trace/get-choices (gf/simulate bernoulli-dist [])))))) + (choicemap/get-value + (trace/get-choices + (gf/simulate bernoulli-dist [])))))) (testing "bernoulli-update-weight" (is (= 1.0 @@ -102,29 +120,29 @@ (Math/exp))))) (testing "bernoulli-update-discard" - (is (nil? - (-> (gf/generate bernoulli-dist [0.3] #gen/choice true) + (is (choicemap/empty? + (-> (gf/generate bernoulli-dist [0.3] true) (:trace) - (trace/update nil) + (trace/update choicemap/EMPTY) (:discard)))) (is (= #gen/choice true - (-> (gf/generate bernoulli-dist [0.3] #gen/choice true) + (-> (gf/generate bernoulli-dist [0.3] true) (:trace) - (trace/update #gen/choice false) + (trace/update false) (:discard))))) (testing "bernoulli-update-change" - (is (= diff/unknown-change - (-> (gf/generate bernoulli-dist [0.3] #gen/choice true) + (is (= diff/no-change + (-> (gf/generate bernoulli-dist [0.3] true) (:trace) - (trace/update nil) + (trace/update choicemap/EMPTY) (:change)))) (is (= diff/unknown-change - (-> (gf/generate bernoulli-dist [0.3] #gen/choice true) + (-> (gf/generate bernoulli-dist [0.3] true) (:trace) - (trace/update #gen/choice false) + (trace/update false) (:change)))))) (defn cauchy-tests [->cauchy] @@ -187,7 +205,7 @@ (checking "Normal properties" [mu (gen-double -10 10) sigma (gen-double 0.001 10) - v (gen-double -100 100) + v (gen-double -50 50) shift (gen-double -10 10)] (is (ish? (dist/logpdf (->normal 0.0 sigma) v) (dist/logpdf (->normal 0.0 sigma) (- v))) @@ -206,8 +224,8 @@ (defn uniform-tests [->uniform] (checking "(log of the) Beta function is symmetrical" - [min (gen-double -10 0) - max (gen-double 0 10) + [min (gen-double -10 -0.1) + max (gen-double 0.1 10) v (gen-double -10 10)] (let [log-l (dist/logpdf (->uniform min max) v)] (if (<= min v max) diff --git a/test/gen/dynamic/choice_map_test.cljc b/test/gen/dynamic/choice_map_test.cljc deleted file mode 100644 index f8018be..0000000 --- a/test/gen/dynamic/choice_map_test.cljc +++ /dev/null @@ -1,57 +0,0 @@ -(ns gen.dynamic.choice-map-test - (:refer-clojure :exclude [empty empty?]) - (:require [clojure.core :as clojure] - [clojure.test :refer [deftest is]] - [clojure.test.check.generators :as gen] - [com.gfredericks.test.chuck.clojure-test :refer [checking]] - [gen.choice-map :as choice-map] - [gen.dynamic.choice-map :as dynamic.choice-map])) - -(def gen-choice-map - (comp (partial gen/fmap dynamic.choice-map/choice-map) - gen/map)) - -(deftest choice - (is (dynamic.choice-map/choice? (dynamic.choice-map/choice nil))) - (is (dynamic.choice-map/choice? #gen/choice nil)) - (is (dynamic.choice-map/choice? (dynamic.choice-map/choice :x))) - (is (dynamic.choice-map/choice? #gen/choice :x)) - (is (dynamic.choice-map/choice? (dynamic.choice-map/choice [:x]))) - (is (dynamic.choice-map/choice? #gen/choice [:x])) - (is (dynamic.choice-map/choice? (dynamic.choice-map/choice {:x 0}))) - (is (dynamic.choice-map/choice? #gen/choice {:x 0})) - (is (not (dynamic.choice-map/choice? nil))) - (is (not (dynamic.choice-map/choice? :x)))) - -(deftest choice-map? - (is (dynamic.choice-map/choice-map? #gen/choice-map {})) - (is (not (dynamic.choice-map/choice-map? {})))) - -(deftest choice-map-value - (is (= nil (choice-map/value #gen/choice nil))) - (is (= :x (choice-map/value #gen/choice :x)))) - -(deftest empty? - (is (clojure/empty? (dynamic.choice-map/choice-map))) - (is (clojure/empty? #gen/choice-map {})) - #_{:clj-kondo/ignore [:not-empty?]} - (is (not (clojure/empty? #gen/choice-map {:x 0})))) - -(defn iterable-seq [^Iterable iter] - (when (.hasNext iter) - (lazy-seq - (cons (.next iter) - (iterable-seq iter))))) - -(deftest interface-tests - (checking "Interface tests for choice maps" - [m (gen-choice-map gen/keyword gen/any-equatable)] - #?(:clj - (is (= (seq m) - (iterable-seq - (.iterator ^Iterable m))) - "iterator impl matches seq")) - - (is (= m (dynamic.choice-map/choice-map - (zipmap (keys m) (vals m)))) - "keys and vals work correctly"))) diff --git a/test/gen/dynamic/trace_test.cljc b/test/gen/dynamic/trace_test.cljc deleted file mode 100644 index dc4b9ce..0000000 --- a/test/gen/dynamic/trace_test.cljc +++ /dev/null @@ -1,87 +0,0 @@ -(ns gen.dynamic.trace-test - (:refer-clojure :exclude [empty? get keys seq vals]) - (:require [clojure.core :as clojure] - [clojure.test :refer [deftest is]] - [gen.dynamic :refer [gen]] - [gen.dynamic.choice-map :as dynamic.choice-map] - [gen.dynamic.trace :as dynamic.trace] - [gen.trace :as trace])) - -(deftest binding-tests - (letfn [(f [_] "hi!")] - (binding [dynamic.trace/*trace* f - dynamic.trace/*splice* f] - (is (= f (dynamic.trace/active-trace)) - "active-trace reflects dynamic bindings") - - (is (= f (dynamic.trace/active-splice)) - "active-splice reflects dynamic bindings")))) - -(defn choice-trace - [x] - (reify trace/ITrace - (get-choices [_] - (dynamic.choice-map/choice x)))) - -(deftest empty? - (let [trace (dynamic.trace/trace (gen []) [])] - (is (clojure/empty? trace)))) - -(deftest gf - (let [gf (gen [])] - (is (= gf (trace/get-gen-fn (dynamic.trace/trace gf [])))))) - -(deftest args - (is (= [] (trace/get-args (dynamic.trace/trace (gen []) [])))) - (is (= [0] (trace/get-args (dynamic.trace/trace (gen [x] x) [0])))) - (is (= [0 1] (trace/get-args (dynamic.trace/trace (gen [x y] (+ x y)) [0 1]))))) - -(deftest call-position - (let [trace (dynamic.trace/trace (gen []) [])] - (is (nil? (trace :addr)))) - (let [trace (-> (dynamic.trace/trace (gen []) []) - (dynamic.trace/assoc-subtrace :addr (choice-trace :x)))] - (is (= :x (trace :addr))))) - -(deftest keys - (is (= {} (into {} (dynamic.trace/trace (gen []) []))) - "iterator works on an empty trace") - - (is (= #{:addr} - (-> (dynamic.trace/trace (gen []) []) - (dynamic.trace/assoc-subtrace :addr (choice-trace :x)) - (clojure/keys) - (set)))) - (is (= #{:addr1 :addr2} - (-> (dynamic.trace/trace (gen []) []) - (dynamic.trace/assoc-subtrace :addr1 (choice-trace :x)) - (dynamic.trace/assoc-subtrace :addr2 (choice-trace :y)) - (clojure/keys) - (set))))) - -(deftest vals - (is (= #{:x} - (-> (dynamic.trace/trace (gen []) []) - (dynamic.trace/assoc-subtrace :addr (choice-trace :x)) - (clojure/vals) - (set)))) - (is (= #{:x :y} - (-> (dynamic.trace/trace (gen []) []) - (dynamic.trace/assoc-subtrace :addr1 (choice-trace :x)) - (dynamic.trace/assoc-subtrace :addr2 (choice-trace :y)) - (clojure/vals) - (set))))) - -(deftest seq - (let [trace (-> (dynamic.trace/trace (gen []) []) - (dynamic.trace/assoc-subtrace :addr0 (choice-trace :x)) - (dynamic.trace/assoc-subtrace :addr1 (choice-trace :y)) - (dynamic.trace/assoc-subtrace :addr2 (choice-trace :z)))] - (is (every? map-entry? (clojure/seq trace))))) - -(deftest get - (let [trace (dynamic.trace/trace (gen []) [])] - (is (nil? (clojure/get trace :addr)))) - (let [trace (-> (dynamic.trace/trace (gen []) []) - (dynamic.trace/assoc-subtrace :addr (choice-trace :x)))] - (is (= :x (clojure/get trace :addr))))) diff --git a/test/gen/dynamic_test.cljc b/test/gen/dynamic_test.cljc index c339a3d..49d2ad9 100644 --- a/test/gen/dynamic_test.cljc +++ b/test/gen/dynamic_test.cljc @@ -1,117 +1,159 @@ (ns gen.dynamic-test - (:require [clojure.math :as math] - [clojure.test :refer [deftest is]] - [gen.distribution.kixi :as d] + (:require [clojure.test :refer [deftest is testing]] + [clojure.test.check.generators :as gen] + [com.gfredericks.test.chuck.clojure-test :refer [checking]] + [gen.choicemap :as choicemap] + [gen.distribution.kixi :as kixi] [gen.dynamic :as dynamic :refer [gen]] [gen.generative-function :as gf] [gen.trace :as trace])) -(deftest call - (is (nil? ((gen [])))) - (is (= 0 ((gen [] 0)))) - (is (nil? ((gen [_]) 0))) - (is (= 0 ((gen [x] x) 0)))) - -(deftest gf - (let [gf (gen []) - trace (gf/simulate gf [])] - (is (= gf (trace/get-gen-fn trace))))) - -(deftest trace-form?-false - (is (not (dynamic/trace-form? '()))) - (is (not (dynamic/trace-form? '(trace))))) - -(deftest trace-form?-true - (is (dynamic/trace-form? `(dynamic/trace!))) - (is (dynamic/trace-form? `(dynamic/trace! :x))) - (is (dynamic/trace-form? `(dynamic/trace! ~'x))) - (is (dynamic/trace-form? `(dynamic/trace! :x :y)))) - -(deftest trace-args - (is (= [0] (trace/get-args (gf/simulate (gen [& _]) [0])))) - (is (= [0 1] (trace/get-args (gf/simulate (gen [& _]) [0 1]))))) - -(deftest simulate-trace - (let [gf (gen [] (dynamic/trace! :addr d/bernoulli)) - trace (gf/simulate gf []) - choice-map (trace/get-choices trace)] - (is (= #{:addr} (set (keys trace)))) - (is (= #{:addr} (set (keys choice-map)))) - (is (boolean? (:addr trace))) - (is (boolean? (:addr choice-map))))) - -(deftest simulate-splice - (let [gf0 (gen [] (dynamic/trace! :addr d/bernoulli)) - gf1 (gen [] (dynamic/splice! gf0)) - trace (gf/simulate gf1 []) - choice-map (trace/get-choices trace)] - (is (= #{:addr} (set (keys trace)))) - (is (= #{:addr} (set (keys choice-map)))) - (is (boolean? (:addr trace))) - (is (boolean? (:addr choice-map))))) - -(deftest generate-trace-trace - (let [gf (gen [] (dynamic/trace! :addr d/bernoulli)) - trace (:trace (gf/generate gf [])) - choice-map (trace/get-choices trace)] - (is (= #{:addr} (set (keys trace)))) - (is (= #{:addr} (set (keys choice-map)))) - (is (boolean? (:addr trace))) - (is (boolean? (:addr choice-map))))) - -(deftest generate-splice-trace - (let [gf0 (gen [] (dynamic/trace! :addr d/bernoulli)) - gf1 (gen [] (dynamic/splice! gf0)) - trace (:trace (gf/generate gf1 [])) - choice-map (trace/get-choices trace)] - (is (= #{:addr} (set (keys trace)))) - (is (= #{:addr} (set (keys choice-map)))) - (is (boolean? (:addr trace))) - (is (boolean? (:addr choice-map))))) - -(deftest generate-call-trace - (let [gf0 (gen [] (dynamic/trace! :addr d/bernoulli)) - gf1 (gen [] (dynamic/untraced (gf0))) - trace (:trace (gf/generate gf1 [])) - choice-map (trace/get-choices trace)] - (is (empty? trace)) - (is (empty? choice-map)))) - -(deftest generate-call-splice - (let [gf0 (gen [] (d/bernoulli)) - gf1 (gen [] (gf0)) - trace (:trace (gf/generate gf1 [])) - choice-map (trace/get-choices trace)] - (is (empty? trace)) - (is (empty? choice-map)))) +(deftest binding-tests + (letfn [(f [_] "hi!")] + (binding [dynamic/*trace* f] + (is (= f (dynamic/active-trace)) + "active-trace reflects dynamic bindings")))) + +(deftest gen-fn-tests + (is (nil? ((gen []))) + "no-arity, no return function returns nil on call") + + (checking "round-trip through functions" [x gen/any-equatable] + (is (= x ((gen [] x))))) + + (checking "round-trip through functions" + [xs (gen/vector gen/small-integer 5)] + (let [gf (gen [a b c d e] + (+ a b c d e)) + trace (gf/simulate gf xs)] + (is (= gf (trace/get-gen-fn trace)) + "distribution round trips through the trace.") + + (is (= xs (trace/get-args trace)) + "args round-trip through the trace.") + + (is (empty? (trace/get-choices trace)) + "we made no choices!") + + (is (= (apply gf xs) + (trace/get-retval trace)) + "deterministic functions match the retval.")))) + +(deftest trace-form-tests + (testing "incorrect trace forms return false." + (is (not (dynamic/trace-form? '())) + "no trace call") + + (is (not (dynamic/trace-form? '(g/trace!))) + "unknown prefix")) + + (testing "proper trace forms return true." + (is (dynamic/trace-form? `(dynamic/trace!)) + "correct, but trace will fail due to no args.") + (is (dynamic/trace-form? `(dynamic/trace! :x)) + "address only, we are lenient here!") + + (is (dynamic/trace-form? `(gen.dynamic/trace! ~'x)) + "different blessed prefixes work") + + (is (dynamic/trace-form? '(trace! :x :y)) + "for now, this special symbol works."))) + +(deftest gfi-tests + (testing "subtleties of nested tracing" + (let [gf (gen [p] (dynamic/trace! :addr kixi/bernoulli p)) + trace (gf/simulate gf [0.5])] + (is (= (choicemap/choicemap + {:addr (trace/get-retval trace)}) + (trace/get-choices trace)) + "trace choices match retval"))) + + (testing "trace inside splice should bubble up" + (let [gf0 (gen [] (dynamic/trace! :addr kixi/bernoulli)) + gf1 (gen [] (dynamic/splice! gf0)) + trace (gf/simulate gf1 [])] + (is (= (choicemap/choicemap + {:addr (trace/get-retval trace)}) + (trace/get-choices trace)) + "works for simulate") + + (let [trace (:trace (gf/generate gf1 [] choicemap/EMPTY))] + (is (= (choicemap/choicemap + {:addr (trace/get-retval trace)}) + (trace/get-choices trace)) + "with generate")))) + + (testing "trace inside of trace should nest" + (let [inner (gen [] (dynamic/trace! :inner kixi/bernoulli)) + outer (gen [] (dynamic/trace! :outer inner)) + trace (gf/simulate outer [])] + (is (= (choicemap/choicemap + {:outer + {:inner + (trace/get-retval trace)}}) + (trace/get-choices trace)) + "with simulate") + + (let [trace (:trace (gf/generate outer [] choicemap/EMPTY))] + (is (= (choicemap/choicemap + {:outer + {:inner + (trace/get-retval trace)}}) + (trace/get-choices trace)) + "with generate")))) + + (testing "explicit untracing" + (let [inner (gen [] (dynamic/trace! :addr kixi/bernoulli)) + outer (gen [] (dynamic/untraced (inner))) + trace (:trace (gf/generate outer []))] + (is (empty? + (trace/get-choices trace)) + "untraced turns off tracing"))) + + (testing "implicit untraced randomness" + (let [inner (gen [] (dynamic/trace! :addr kixi/bernoulli)) + outer (gen [] (inner)) + trace (:trace (gf/generate outer []))] + (is (empty? + (trace/get-choices trace)) + "function calls induce untraced randomness."))) + + (testing "generate-call-splice" + (let [inner (gen [] (kixi/bernoulli)) + outer (gen [] (inner)) + trace (:trace (gf/generate outer []))] + (is (empty? + (trace/get-choices trace)))))) (deftest score - (is (= 0.5 (math/exp (trace/get-score (gf/simulate d/bernoulli [0.5]))))) (let [trace (gf/simulate (gen [] - (dynamic/trace! :addr d/bernoulli 0.5)) + (dynamic/trace! :addr kixi/bernoulli 0.5)) [])] - (is (= 0.5 (math/exp (trace/get-score trace)))))) + (is (= 0.5 (Math/exp + (trace/get-score trace)))))) (deftest update-discard-yes (let [gf (gen [] - (dynamic/trace! :discarded d/bernoulli 0))] - (is (= #gen/choice-map {:discarded false} + (dynamic/trace! :discarded kixi/bernoulli 0))] + (is (= #gen/choicemap {:discarded false} (-> (gf/simulate gf []) - (trace/update #gen/choice-map {:discarded true}) + (trace/update {:discarded true}) (:discard)))))) (deftest update-discard-no (let [gf (gen [] - (dynamic/trace! :not-discarded d/bernoulli 0))] - (is (empty? (-> (gf/simulate gf []) - (trace/update #gen/choice-map {:discarded true}) - (:discard)))))) + (dynamic/trace! :not-discarded kixi/bernoulli 0))] + (try (-> (gf/simulate gf []) + (trace/update {:discarded true})) + (catch #?(:clj clojure.lang.ExceptionInfo :cljs js/Error) e + (is (= {:unvisited [:discarded]} + (ex-data e))))))) (deftest update-discard-both (let [gf (gen [] - (dynamic/trace! :discarded d/bernoulli 0) - (dynamic/trace! :not-discarded d/bernoulli 1))] - (is (= #gen/choice-map {:discarded false} + (dynamic/trace! :discarded kixi/bernoulli 0) + (dynamic/trace! :not-discarded kixi/bernoulli 1))] + (is (= #gen/choicemap {:discarded false} (-> (gf/simulate gf []) - (trace/update #gen/choice-map {:discarded true}) + (trace/update {:discarded true}) (:discard)))))) diff --git a/test/gen/sci_test.cljc b/test/gen/sci_test.cljc index 8556fc6..7bee184 100644 --- a/test/gen/sci_test.cljc +++ b/test/gen/sci_test.cljc @@ -13,9 +13,8 @@ (deftest sci-tests (testing "Check that we can evaluate a model inside SCI." (eval - '(do (require '[gen.distribution.kixi :as dist] + '(do (require '[gen.distribution.kixi :as kixi] '[gen.dynamic :as dynamic :refer [gen]] - '[gen.dynamic.trace :as dt] '[gen.generative-function :as gf] '[gen.trace :as trace]) (def line-model @@ -26,8 +25,8 @@ ;; prior beliefs about the parameters: in this case, that neither the slope ;; nor the intercept will be more than a couple points away from 0. - (let [slope (dynamic/trace! :slope dist/normal 0 1) - intercept (dynamic/trace! :intercept dist/normal 0 2) + (let [slope (dynamic/trace! :slope kixi/normal 0 1) + intercept (dynamic/trace! :intercept kixi/normal 0 2) ;; We define a function to compute y for a given x. @@ -39,7 +38,7 @@ ;; the x coordinates in our input vector. (doseq [[i x] (map vector (range) xs)] - (dynamic/trace! [:y i] dist/normal (y x) 0.1)) + (dynamic/trace! [:y i] kixi/normal (y x) 0.1)) ;; Most of the time, we don't care about the return ;; value of a model, only the random choices it makes.