Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sampling for multivariate normal distribution #293

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/Distribution.idr
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ limitations under the License.
||| This module contains definitions for probability distributions.
module Distribution

import Control.Monad.State
import Data.Nat

import Constants
import Literal
import Tensor
import Constants

||| A joint, or multivariate distribution over a tensor of floating point values, where the first
||| two central moments (mean and covariance) are known. Every sub-event is assumed to have the
Expand Down Expand Up @@ -51,6 +53,10 @@ public export
interface Distribution dist =>
ClosedFormDistribution (0 event : Shape)
(0 dist : (0 event : Shape) -> (0 dim : Nat) -> Type) where
||| Produce `n` IID samples from this distribution.
-- which interface does this belong to?
sample : dist event dim -> {n : _} -> Rand $ Tensor (n :: dim :: event) F64

||| The probability density function of the distribution at the specified point.
pdf : dist event (S d) -> Tensor (S d :: event) F64 -> Tensor [] F64

Expand Down Expand Up @@ -78,6 +84,9 @@ Distribution Gaussian where
||| **NOTE** `cdf` is implemented only for univariate `Gaussian`.
export
ClosedFormDistribution [1] Gaussian where
sample (MkGaussian mean cov) key =
pure $ expand 2 $ (broadcast mean + cholesky (squeeze cov) @@ !(normal key)).T

pdf (MkGaussian {d} mean cov) x =
let cholCov = cholesky (squeeze {to=[S d, S d]} cov)
tri = cholCov |\ squeeze (x - mean)
Expand Down
12 changes: 4 additions & 8 deletions src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,7 @@ trace x with (x)
||| The state is updated each time a new value is generated.
public export
Rand : Type -> Type
Rand = State (Tensor [1] U64)
Rand a = Tensor [] U64 -> State (Tensor [1] U64) a

inf : Tensor [] F64
inf = fromDouble (1.0 / 0.0)
Expand Down Expand Up @@ -1279,12 +1279,8 @@ inf = fromDouble (1.0 / 0.0)
||| @bound A bound of the samples. See full docstring for details.
||| @bound' A bound of the samples. See full docstring for details.
export
uniform :
{shape : _} ->
(key : Tensor [] U64) ->
(bound, bound' : Tensor shape F64) ->
Rand (Tensor shape F64)
uniform (MkTensor keyGraph key) bound bound' =
uniform : {shape : _} -> (bound, bound' : Tensor shape F64) -> Rand (Tensor shape F64)
uniform bound bound' (MkTensor keyGraph key) =
let MkTensor minvalGraph minval = min bound bound'
MkTensor maxvalGraph maxval = max bound bound'
in ST $ \(MkTensor initialStateGraph initialState) =>
Expand Down Expand Up @@ -1320,7 +1316,7 @@ uniform (MkTensor keyGraph key) bound bound' =
|||
||| @key Determines the stream of generated samples.
export
normal : {shape : _} -> (key : Tensor [] U64) -> Rand (Tensor shape F64)
normal : {shape : _} -> Rand (Tensor shape F64)
normal (MkTensor keyGraph key) =
ST $ \(MkTensor initialStateGraph initialState) =>
let valueGraph = NormalFloatingPointDistributionValue keyGraph initialStateGraph ThreeFry shape
Expand Down
8 changes: 4 additions & 4 deletions test/Unit/TestTensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ uniform = withTests 20 . property $ do
seed = fromLiteral seed

samples : Tensor [1000, 10] F64 :=
evalState seed (uniform key (broadcast bound) (broadcast bound'))
evalState seed (uniform (broadcast bound) (broadcast bound') key)

uniformCdf : Tensor [1000, 10] F64 -> Tensor [1000, 10] F64
uniformCdf x = (x - broadcast bound) / broadcast (bound' - bound)
Expand All @@ -1101,7 +1101,7 @@ uniformForEqualBounds = withTests 20 . property $ do
key = fromLiteral key
seed = fromLiteral seed

samples : Tensor [6] F64 = evalState seed (uniform key bound bound)
samples : Tensor [6] F64 = evalState seed (uniform bound bound key)

samples ===# fromLiteral [nan, nan, nan, -1.0, 0.0, 1.0]

Expand All @@ -1118,7 +1118,7 @@ uniformSeedIsUpdated = withTests 20 . property $ do
key = fromLiteral key
seed = fromLiteral seed

rng = uniform key {shape=[10]} (broadcast bound) (broadcast bound')
rng = uniform {shape=[10]} (broadcast bound) (broadcast bound') key
(seed', sample) = runState seed rng
(seed'', sample') = runState seed' rng

Expand All @@ -1139,7 +1139,7 @@ uniformIsReproducible = withTests 20 . property $ do
key = fromLiteral key
seed = fromLiteral seed

rng = uniform {shape=[10]} key (broadcast bound) (broadcast bound')
rng = uniform {shape=[10]} (broadcast bound) (broadcast bound') key
sample = evalState seed rng
sample' = evalState seed rng

Expand Down