Skip to content

Commit

Permalink
feat: add map-shaped categorical to commons
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Dec 12, 2023
1 parent eefa9af commit ad29921
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
24 changes: 23 additions & 1 deletion src/gen/distribution/commons_math.clj
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,34 @@
(defn uniform-discrete-distribution [low high]
(UniformIntegerDistribution. (rng) low high))

(defn categorical-distribution [probabilities]
(defn- v->categorical [probabilities]
(let [n (count probabilities)
ks (int-array (range n))
vs (double-array probabilities)]
(EnumeratedIntegerDistribution. (rng) ks vs)))

(defn- m->categorical [probabilities]
(let [ks (keys probabilities)
vs (vals probabilities)
k->i (zipmap ks (range))
i->k (zipmap (range) ks)]
(-> (v->categorical vs)
(d/->Encoded k->i i->k))))

(defn categorical-distribution
"Given either
- a sequence of `probabilities` that sum to 1.0
- a map of object => probability (whose values sum to 1.0)
returns a distribution that produces samples of an integer in the range $[0,
n)$ (where `n == (count probabilities)`), or of a map key (for map-shaped
`probabilities`)."
[probabilities]
(if (map? probabilities)
(m->categorical probabilities)
(v->categorical probabilities)))

;; ## Primitive generative functions

(def bernoulli
Expand Down
3 changes: 3 additions & 0 deletions test/gen/distribution/commons_math_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
(deftest beta-tests
(dt/beta-tests commons/beta-distribution))

(deftest categorical-tests
(dt/categorical-tests commons/categorical-distribution))

(deftest uniform-tests
(dt/uniform-tests commons/uniform-distribution))

Expand Down
19 changes: 19 additions & 0 deletions test/gen/distribution_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@
(Math/exp (dist/logpdf (->bernoulli p) (not v)))))
"All options sum to 1")))

(defn categorical-tests [->cat]
(checking "map => categorical properties"
[p (gen-double 0 1)]
(let [dist (->cat {:true p :false (- 1 p)})]
(is (ish? (Math/log p) (dist/logpdf dist :true))
"prob of `:true` matches `p`")

(is (ish? (Math/log (- 1 p)) (dist/logpdf dist :false))
"prob of `:false` matches `1-p`")))

(checking "vector => categorical properties"
[p (gen-double 0 1)]
(let [dist (->cat [p (- 1 p)])]
(is (ish? (Math/log p) (dist/logpdf dist 0))
"prob of `1` matches `p`")

(is (ish? (Math/log (- 1 p)) (dist/logpdf dist 1))
"prob of `0` matches `1-p`"))))

(defn bernoulli-gfi-tests [bernoulli-dist]
(testing "bernoulli-call-no-args"
(is (boolean? (bernoulli-dist))))
Expand Down

0 comments on commit ad29921

Please sign in to comment.