Skip to content

Commit

Permalink
inference
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Dec 1, 2023
1 parent c6f2530 commit 2b0211c
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 84 deletions.
2 changes: 1 addition & 1 deletion deps.edn
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
org.mentat/emmy {:mvn/version "0.31.0"}
org.mentat/emmy-viewers {:mvn/version "0.3.1"}
io.github.inferenceql/gen.clj
{:git/sha "743e9d25da926d1d3f43ada66ccbdbd60e4af156"}
{:git/sha "8d427e95dc3a10c94b40305dac366f270b921c7c"}

io.github.nextjournal/clerk {:git/sha "d80187013d7b7b96db3d8b114b8d99f687170668"}
io.github.nextjournal/clerk.render
Expand Down
189 changes: 106 additions & 83 deletions notebooks/finance/model.clj
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,6 @@
0.0
simulation))

;; TODO add separate generative function - `create-business`,

(declare simulate-business)

;; Draft generative function (takes no arguments) that samples business-level
;; parameters and then runs the business simulation.
;;
;; Traced addresses:
;; - `:retention`        drawn via `kixi/uniform` with arguments 0 and 100
;; - `:cost-of-service`  drawn via `kixi/uniform` with arguments 0 and 80
;; - `:business`         the `(simulate-business {...})` form below
;;
;; NOTE(review): `cost-of-service` and `simulation` are bound but never used,
;; and the `let` has no body expression, so it evaluates to nil. This looks
;; intentional while the TODOs below are open -- confirm.
(def create-business
(gen []
(let [retention-mean (dynamic/trace! :retention kixi/uniform 0 100)
cost-of-service (dynamic/trace! :cost-of-service kixi/uniform 0 80)
;; Only the sampled retention value is forwarded to the simulator here.
simulation (dynamic/trace! :business
(simulate-business
{:retention-mean retention-mean
}))]
;; TODO add more elements.
;;
;; TODO this is where `:total-value` should actually be traced.
)))

;; TODO first version - teach them what the simulator is, and what they're
;; learning as they see multiple simulations at once. (this is sliders attached
;; to simulate-business, basically what we have now.)
Expand All @@ -94,35 +77,53 @@
;;
;; TODO is there a version of this business that has value more than something?
;;
;; - constrain on total value > sometime.
;; - constrain on total value > 10000.
;;
;; 3 pages in the thing.

(defn sim->choicemaps
  "Indexes the rows of `sim` by their `:period` entry.

  Returns a map of `period => row`, where `row` is the complete original
  record. If several rows share a period, the last one wins.

  TODO does this need to be more efficient? does it matter?"
  [sim]
  (reduce (fn [by-period {:keys [period] :as row}]
            (assoc by-period period row))
          {}
          sim))

;; Generative function that records one simulation row into the trace.
;;
;; For every `[k v]` entry of `row`, traces address `k` using `kixi/normal`
;; called with `v` and `wiggle` (presumably mean and standard deviation --
;; TODO confirm against the kixi distribution signature).
;;
;; The body is a `doseq`, so it is evaluated for its tracing effect only.
(def record-row!
(gen [row wiggle]
(doseq [[k v] row]
(dynamic/trace! k kixi/normal v wiggle))))

;; Generative function that records every row of `simulation` into the trace.
;;
;; Each row is traced at the address given by its own `:period` value, by
;; invoking `record-row!` with that row and the shared `wiggle` argument.
;;
;; Returns `simulation` unchanged.
(def record-business!
(gen [simulation wiggle]
(doseq [{:keys [period] :as row} simulation]
(dynamic/trace! period record-row! row wiggle))
simulation))

(def simulate-business
(gen [{:keys [users cost] :or {users 0 cost 0}}
{:keys [periods
retention-mean
retention-rate
cost-of-service
revenue-per-paying
free->pay
ad-spend-increase
viral-growth-kicker]
:or {retention-mean 0.8
viral-growth-kicker
wiggle]
:or {retention-rate 0.8
periods 20
cost-of-service 0
revenue-per-paying 0
free->pay 0
ad-spend-increase 0
viral-growth-kicker 0}}]
(let [retention-rate (dynamic/trace! :retention kixi/normal retention-mean 0.05)
cost-per-user (dynamic/trace! :cost-of kixi/normal cost-of-service 0.01)

pay-rate (/ free->pay 100)
viral-growth-kicker 0
wiggle 0.1}}]
(let [pay-rate (/ free->pay 100)
paying (round (* pay-rate users))
revenue (* revenue-per-paying paying)
init-cpa (cpa cost users)
spend-rate (inc (/ ad-spend-increase 100.0))
cost-of-service (* cost-per-user users)
cost-of-service (* cost-of-service users)
initial {:period 0
:ad-spend cost
:total-new-users users
Expand Down Expand Up @@ -154,18 +155,9 @@
(+ (* retention-rate
(:cumulative-paying-users prev))
new-paying-users))
cost-of-service (* cost-per-user users-sum)
cost-of-service (* cost-of-service users-sum)
total-cost (+ new-spend cost-of-service)
revenue (* revenue-per-paying paying-sum)]
;; TODO turn each of these into a trace call like [:period i]....
;;
;; TODO turn the noise into a business-level parameter,
;; so that we can loosen things up in case inference
;; goes nuts.
;;
;;
;; TODO add a "simulation config" parameter that we can
;; set with a slider.
{:period (inc (:period prev))
:ad-spend new-spend
:total-new-users total-new-users
Expand All @@ -180,20 +172,38 @@
:cumulative-users users-sum
:cumulative-paying-users paying-sum
:cpa (cpa new-spend total-new-users)}))
initial))
value (total-value simulation)]
initial))]
(dynamic/splice! record-business! simulation wiggle))))

;; Generative function that samples a full business configuration, runs
;; `simulate-business` with it, and traces the resulting total value.
;;
;; Arguments:
;; - `initial-data`: passed through untouched as the first argument of
;;   `simulate-business`.
;; - options map with keys `:value-target` and `:periods`.
;;
;; Traced addresses, in order: `:free->pay`, `:ad-spend-increase`,
;; `:viral-growth-kicker`, `:retention-rate`, `:cost-of-service`,
;; `:revenue-per-paying`, `:simulation`, `:total-value`, and `:profitable?`
;; (the last one only when `value-target` is truthy).
;;
;; Returns the value bound to `simulation`.
(def create-business
(gen [initial-data {:keys [value-target periods]}]
;; Each business parameter is drawn via `kixi/uniform` with the bounds shown.
(let [free->pay (dynamic/trace! :free->pay kixi/uniform 0 100)
ad-spend-increase (dynamic/trace! :ad-spend-increase kixi/uniform 0 100)
viral-growth-kicker (dynamic/trace! :viral-growth-kicker kixi/uniform 0 2)
retention-rate (dynamic/trace! :retention-rate kixi/uniform 0 1)
cost-of-service (dynamic/trace! :cost-of-service kixi/uniform 0 80)
revenue-per-paying (dynamic/trace! :revenue-per-paying kixi/uniform 0 80)
;; Run the simulator as a traced sub-call under `:simulation`, forwarding
;; the sampled parameters together with the caller-supplied `periods`.
simulation (dynamic/trace!
:simulation
simulate-business
initial-data
{:periods periods
:free->pay free->pay
:ad-spend-increase ad-spend-increase
:viral-growth-kicker viral-growth-kicker
:retention-rate retention-rate
:cost-of-service cost-of-service
:revenue-per-paying revenue-per-paying})
value (total-value simulation)]
;; Trace the computed value through `kixi/normal` with arguments `value`
;; and 0.001 (presumably mean and a very tight spread -- TODO confirm).
(dynamic/trace! :total-value kixi/normal value 0.001)
;; Only traced when a target was supplied: the bernoulli argument is 1.0
;; when `value` exceeds `value-target`, and 0.0 otherwise.
(when value-target
(dynamic/trace! :profitable?
kixi/bernoulli
(if (> value value-target)
1.0
0.0)))
simulation)))

(defn simulation-choices
  "Runs `simulate-business` repeatedly and returns the choices of each run.

  Takes a map with `:initial-data` and `:config`. The number of runs is the
  `:trials` entry of `config`, falling back to 10 when it is absent. Each
  element of the returned lazy sequence is the choices of one simulated
  trace, converted with `choicemap/->map`."
  [{:keys [initial-data config]}]
  (let [trial-count (:trials config 10)
        run-once    (fn []
                      (-> (gf/simulate simulate-business [initial-data config])
                          (trace/get-choices)
                          (choicemap/->map)))]
    (repeatedly trial-count run-once)))

(defn choices->scatterplot [data]
{:schema "https://vega.github.io/schema/vega-lite/v5.json"
:embed/opts {:actions false}
Expand Down Expand Up @@ -234,14 +244,19 @@
;;
;; Notice that each entry maps to a slider in the hovering control panel.
;;
;; Simulation parameters:
;;
;; - `:trials`: total number of simulations to generate
;; - `:periods`: number of time periods to simulate the business
;; - `:value-target`: the goal business value, used for importance sampling.
;; - `:n-samples`: number of samples used in importance sampling

;; Business configuration:
;;
;; - `:free->pay`: percentage of users that convert from free => paid each time period
;; - `:ad-spend-increase`: percentage increase in ad spending per period
;; - `:viral-growth-kicker`: ratio of new viral users to last time period's total new users
;; - `:retention-mean`: average user retention rate, constant per simulation
;; - `:retention-rate`: average user retention rate, constant per simulation
;; - `:cost-of-service`: total cost to serve a user
;; - `:revenue-per-paying`: revenue per paying customer per time period

Expand All @@ -253,9 +268,10 @@
:free->pay 10
:ad-spend-increase 5
:viral-growth-kicker 0.75
:retention-mean 0.8
:retention-rate 0.8
:cost-of-service 0.1
:revenue-per-paying 6.446})
:revenue-per-paying 6.446
:prefix 0})

;; `simulate-business` lets us simulate the history of a business with all of
;; the assumptions from `config`:
Expand All @@ -278,39 +294,37 @@

(total-value simulation)

;; ## Simulating Many Trials
;; ## Interactive Simulation
;;
;; The following charts are tied to the slider hovering over the page. Play with
;; the sliders and watch the charts update.
;;
;; The first chart shows the revenue+cost chart from above, but tied to the
;; sliders.
;;
;; The next two charts are generated by simulating many businesses and plotting aggregates.

;; The scatterplots show pairs of (random choice, total value) for each type of
;; random choice made by the model.
;;
;; The final chart shows a histogram of total business value generated by all
;; trials.
;; Then, the table from above.

(def schema
{"Simulation Params"
(leva/folder
{:trials {:min 0 :max 1000 :step 5}
:periods {:min 1 :max 50 :step 1}
:value-target {:min 0 :max 10000000 :step 1000}
:n-samples {:min 0 :max 100 :step 5}}
{:periods {:min 1 :max 50 :step 1}}
{:order -1})

"Business Config"
(leva/folder
{:retention-mean {:min 0 :max 1 :step 0.01}
{:retention-rate {:min 0 :max 1 :step 0.01}
:free->pay {:min 0 :max 100 :step 0.01}
:ad-spend-increase {:min 0 :max 100 :step 0.01}
:viral-growth-kicker {:min 0 :max 1 :step 0.01}
:cost-of-service {:min 0 :max 10 :step 0.01}
:revenue-per-paying {:min 0 :max 70 :step 0.01}})})
:revenue-per-paying {:min 0 :max 70 :step 0.01}})

"Inference"
(leva/folder
{:value-target {:min 0 :max 10000000 :step 1000}
:n-samples {:min 0 :max 100 :step 5}
:trials {:min 0 :max 1000 :step 5}
:prefix {:min 0 :max 100 :step 10}})})

^{::clerk/visibility {:code :hide}}
(ev/with-let [!state config]
Expand All @@ -320,33 +334,33 @@
:schema schema})
(list 'let ['config {:initial-data `initial-data
:config (list 'deref !state)}
'data (list `simulation-choices 'config)
'trial (list `simulate-business
'sim (list `simulate-business
`initial-data
(list 'deref !state))]
[:<>
['nextjournal.clerk.render/render-vega-lite
(list `revenue+cost-schema 'trial)]
['nextjournal.clerk.render/render-vega-lite
(list `choices->scatterplot 'data)]
['nextjournal.clerk.render/render-vega-lite
(list `value-histogram 'data)]])])
(list `revenue+cost-schema 'sim)]
['nextjournal.clerk.render/inspect
(list `->table 'sim)]])])

;; ## Inference
;;
;; Not a great visualization, but here's a start:

(defn do-inference
[{:keys [initial-data config]}]
(let [{:keys [trials value-target n-samples]} config
[{:keys [initial-data config sim]}]
(let [{:keys [prefix periods trials n-samples]} config
prefix (take (Math/floor (* periods (/ prefix 100)))
sim)
infer (fn []
(-> (importance/resampling
simulate-business
create-business
[initial-data config]
{:total-value value-target}
{:simulation (reduce into (sim->choicemaps prefix))}
n-samples)
(:trace)
(trace/get-choices)
(choicemap/get-values-shallow)
(choicemap/->map)))]
(repeatedly trials infer)))

Expand All @@ -357,16 +371,25 @@
:data {:values data}
:layer
[{:mark :point
:encoding {:x {:field :retention :type "quantitative"}
:y {:field :cost-of :type "quantitative"}}}]})
:encoding {:x {:field :retention-rate :type "quantitative"}
:y {:field :cost-of-service :type "quantitative"}}}]})

^{::clerk/visibility {:code :hide}}
(ev/with-let [!state config]
[:<>
(leva/controls
{:atom !state :schema schema})
['nextjournal.clerk.render/render-vega-lite
(list `infer-scatterplot
(list `do-inference
{:initial-data `initial-data
:config (list 'deref !state)}))]])
(list 'let ['sim (list `simulate-business
`initial-data
(list 'deref !state))
'inf (list `do-inference
{:initial-data `initial-data
:config (list 'deref !state)
:sim 'sim})]
[:<>
['nextjournal.clerk.render/inspect
(list `->table 'sim)]
['nextjournal.clerk.render/inspect
(list `first 'inf)]
['nextjournal.clerk.render/render-vega-lite
(list `infer-scatterplot 'inf)]])])

0 comments on commit 2b0211c

Please sign in to comment.