{-# LANGUAGE GADTs, MultiParamTypeClasses, FlexibleInstances, FlexibleContexts #-}
module Data.Random.Distribution.Multinomial where
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Binomial
-- | Draw a multinomial sample in the plain 'RVar' monad.
--
-- @multinomial ps n@ wraps its arguments in the 'Multinomial' constructor and
-- samples the result with 'rvar'.
--
-- * @ps@ - one weight per category.
-- * @n@  - the total count to distribute over the categories.
--
-- All sampling behaviour comes from the @'Distribution' ('Multinomial' p) [a]@
-- instance selected by the constraint; this function adds nothing of its own.
multinomial :: Distribution (Multinomial p) [a] => [p] -> a -> RVar [a]
multinomial ps n = rvar (Multinomial ps n)
-- | Draw a multinomial sample in 'RVarT' over an arbitrary base monad @m@.
--
-- @multinomialT ps n@ wraps its arguments in the 'Multinomial' constructor and
-- samples the result with 'rvarT'.
--
-- * @ps@ - one weight per category.
-- * @n@  - the total count to distribute over the categories.
--
-- All sampling behaviour comes from the @'Distribution' ('Multinomial' p) [a]@
-- instance selected by the constraint; this function adds nothing of its own.
multinomialT :: Distribution (Multinomial p) [a] => [p] -> a -> RVarT m [a]
multinomialT ps n = rvarT (Multinomial ps n)
-- | A multinomial distribution described by a list of per-category weights
-- (of type @p@) and a total count.
--
-- The GADT constructor pins the sample type to a list of counts: a value
-- built from weights @[p]@ and a count of type @a@ has type
-- @'Multinomial' p [a]@, so the second type parameter is always a list type.
data Multinomial p a where
    -- Must stay indented: at column 0 the layout rule would close the
    -- @where@ block before the constructor signature.
    Multinomial :: [p] -> a -> Multinomial p [a]
-- | Samples a multinomial vector one category at a time.
--
-- Each weight is paired with the sum of itself and every later weight
-- (computed once with @'scanr' (+) 0@). The count for a category is then a
-- binomial draw from the count still unassigned, with success probability
-- @weight / tailSum@. Because of this renormalisation the weights are only
-- used as ratios to the remaining tail sum.
--
-- The result has exactly one entry per weight, in the same order. An empty
-- weight list yields an empty result.
instance (Num a, Eq a, Fractional p, Distribution (Binomial p) a) => Distribution (Multinomial p) [a] where
    rvarT (Multinomial ps0 t) = go t (zip ps0 (scanr (+) 0 ps0)) id
      where
        -- @go remaining weightsWithTailSums k@
        --
        -- * @remaining@ - count not yet assigned to a category.
        -- * the list    - each remaining weight paired with its tail sum.
        -- * @k@         - difference-list continuation that prepends the
        --                 counts already drawn, so the output is built in
        --                 order without reversing.
        --
        -- Clause order matters: the list-shape clauses come first so that
        -- the numeric literal pattern is only tested when at least two
        -- categories remain.

        -- No categories at all: nothing to distribute over.
        go _ [] k = return (k [])
        -- Last category takes whatever is left; no random draw is made.
        go n [_] k = return (k [n])
        -- Nothing left to distribute: every remaining category gets 0. This
        -- clause makes no binomial draw and never evaluates the
        -- weight / tail-sum quotient.
        go 0 (_ : rest) k = go 0 rest (k . (0 :))
        -- General step: split the remaining count between this category and
        -- all the later ones.
        go n ((p, psum) : rest) k = do
            x <- binomialT n (p / psum)
            go (n - x) rest (k . (x :))