{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts,
    UndecidableInstances, TemplateHaskell,
    BangPatterns
  #-}

{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}

module Data.Random.Distribution.Binomial where

import Data.Random.Internal.TH

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Beta
import Data.Random.Distribution.Uniform

import Numeric.SpecFunctions ( stirlingError )
import Numeric.SpecFunctions.Extra ( bd0 )
import Numeric ( log1p )

    -- algorithm from Knuth's TAOCP, 3rd ed., p 136
    -- specific choice of cutoff size taken from gsl source
    -- note that although it's fast enough for large (eg, 2^10000)
    -- @Integer@s, it's not accurate enough when using @Double@ as
    -- the @b@ parameter.
-- | Sample a binomially distributed count for an integral number of trials
-- and a floating-point success probability.
--
-- While more than 10 trials remain, a single Beta variate is drawn and used
-- to roughly halve the remaining problem; once 10 or fewer trials remain they
-- are simulated one by one with standard uniform draws.
--
-- Calls 'error' if the number of trials is negative.
integralBinomial :: (Integral a, Floating b, Ord b, Distribution Beta b, Distribution StdUniform b) => a -> b -> RVarT m a
integralBinomial = go 0
  where
    -- @go acc trials prob@: @acc@ successes are already accounted for and
    -- @trials@ Bernoulli trials with success probability @prob@ remain.
    go :: (Integral a, Floating b, Ord b, Distribution Beta b, Distribution StdUniform b) => a -> a -> b -> RVarT m a
    go !acc !trials !prob
      | trials > 10 = do
          let lower = 1 + trials `div` 2
              upper = 1 + trials - lower
          pivot <- betaT (fromIntegral lower) (fromIntegral upper)
          if pivot >= prob
            then go acc (lower - 1) (prob / pivot)
            else go (acc + lower) (upper - 1) ((prob - pivot) / (1 - pivot))
      | otherwise = tally acc trials
      where
        -- Direct simulation: one uniform draw per remaining trial.
        tally !hits remaining
          | remaining == 0 = return hits
          | remaining > 0 = do
              u <- stdUniformT
              tally (if u < prob then hits + 1 else hits) (remaining - 1)
          | otherwise = error "integralBinomial: negative number of trials specified"

-- | Cumulative distribution function of the binomial distribution:
-- @integralBinomialCDF t p x@ is the probability of observing at most @x@
-- successes in @t@ trials with success probability @p@.
--
-- Values of @x@ below the support give @0@ and values at or beyond @t@ give
-- @1@, so no mass-function terms outside @[0, t]@ are ever evaluated.
-- Otherwise the result is the sum of 'integralBinomialPDF' over @[0 .. x]@.
integralBinomialCDF :: (Integral a, Real b) => a -> b -> a -> Double
integralBinomialCDF t p x
  | x < 0     = 0
  | x >= t    = 1
  | otherwise = sum (map (integralBinomialPDF t p) [0 .. x])

-- | Probability mass function of the binomial distribution. The probability
-- of exactly k successes in n trials is
--
-- \[
-- f(k;n,p) = \Pr(X = k) = \binom n k  p^k(1-p)^{n-k}
-- \]
--
-- The distribution parameters come first and the point at which the mass
-- function is evaluated comes last, i.e. \(f(2;4,0.5)\) is computed by
-- @integralBinomialPDF 4 0.5 2@.
--
-- Implemented as the exponential of 'integralBinomialLogPdf'.
integralBinomialPDF :: (Integral a, Real b) => a -> b -> a -> Double
integralBinomialPDF t p x = exp (integralBinomialLogPdf t p x)

-- | Natural logarithm of the binomial probability mass function. Arguments
-- are the number of trials, the success probability and the number of
-- successes, in that order.
--
-- We use the method given in \"Fast and accurate computation of
-- binomial probabilities, Loader, C\",
-- <http://octave.1599824.n4.nabble.com/attachment/3829107/0/loader2000Fast.pdf>
--
-- Because the result is a /log/-probability, an outcome that is certain
-- yields @0@ and an impossible outcome yields negative infinity.
integralBinomialLogPdf :: (Integral a, Real b) => a -> b -> a -> Double
integralBinomialLogPdf nI pR xI
  -- Outside the support [0, n] the mass is zero.
  | xI < 0 || xI > nI     = negInf
  -- Degenerate distributions: all mass sits on 0 (p == 0) or on n (p == 1).
  | p == 0.0 && xI == 0   = 0.0
  | p == 0.0              = negInf
  | p == 1.0 && xI == nI  = 0.0
  | p == 1.0              = negInf
  -- Boundary outcomes have closed forms; log1p keeps precision for small p.
  |             xI == 0   = n * log1p (negate p)
  |             xI == nI  = n * log p
  -- General case: saddle-point expansion built from Stirling error terms
  -- and deviance terms.
  | otherwise = lc - 0.5 * lf
  where
    negInf :: Double
    negInf = negate (1 / 0)

    n, x, p :: Double
    n = fromIntegral nI
    x = fromIntegral xI
    p = realToFrac pR

    lc, lf :: Double
    lc = stirlingError n -
         stirlingError x -
         stirlingError (n - x) -
         bd0 x (n * p) -
         bd0 (n - x) (n * (1 - p))
    lf = log (2 * pi) + log x + log1p (negate (x / n))

-- would it be valid to repeat the above computation using fractional @t@?
-- obviously something different would have to be done with @count@ as well...
{-# SPECIALIZE floatingBinomial :: Float  -> Float  -> RVar Float  #-}
{-# SPECIALIZE floatingBinomial :: Float  -> Double -> RVar Float  #-}
{-# SPECIALIZE floatingBinomial :: Double -> Float  -> RVar Double #-}
{-# SPECIALIZE floatingBinomial :: Double -> Double -> RVar Double #-}
-- | Binomial sampler for a fractional trial count: the count is truncated to
-- an 'Integer', the 'Integer'-valued binomial is sampled, and the result is
-- converted back with 'fromInteger'.
floatingBinomial :: (RealFrac a, Distribution (Binomial b) Integer) => a -> b -> RVar a
floatingBinomial t p = fromInteger <$> rvar integerDist
  where
    integerDist = Binomial (truncate t :: Integer) p

-- | CDF for a fractional trial count and evaluation point: the trial count
-- is truncated and the evaluation point is floored to 'Integer' before
-- delegating to the 'Integer' instance's 'cdf'.
floatingBinomialCDF :: (CDF (Binomial b) Integer, RealFrac a) => a -> b -> a -> Double
floatingBinomialCDF t p x = cdf integerDist (floor x)
  where
    integerDist = Binomial (truncate t :: Integer) p

-- | Mass function for a fractional trial count and evaluation point: the
-- trial count is truncated and the evaluation point is floored to 'Integer'
-- before delegating to the 'Integer' instance's 'pdf'.
floatingBinomialPDF :: (PDF (Binomial b) Integer, RealFrac a) => a -> b -> a -> Double
floatingBinomialPDF t p x = pdf integerDist (floor x)
  where
    integerDist = Binomial (truncate t :: Integer) p

-- | Log mass function for a fractional trial count and evaluation point: the
-- trial count is truncated and the evaluation point is floored to 'Integer'
-- before delegating to the 'Integer' instance's 'logPdf'.
floatingBinomialLogPDF :: (PDF (Binomial b) Integer, RealFrac a) => a -> b -> a -> Double
floatingBinomialLogPDF t p x = logPdf integerDist (floor x)
  where
    integerDist = Binomial (truncate t :: Integer) p

{-# SPECIALIZE binomial :: Int     -> Float  -> RVar Int #-}
{-# SPECIALIZE binomial :: Int     -> Double -> RVar Int #-}
{-# SPECIALIZE binomial :: Integer -> Float  -> RVar Integer #-}
{-# SPECIALIZE binomial :: Integer -> Double -> RVar Integer #-}
{-# SPECIALIZE binomial :: Float   -> Float  -> RVar Float  #-}
{-# SPECIALIZE binomial :: Float   -> Double -> RVar Float  #-}
{-# SPECIALIZE binomial :: Double  -> Float  -> RVar Double #-}
{-# SPECIALIZE binomial :: Double  -> Double -> RVar Double #-}
-- | Random variable for a 'Binomial' distribution built from the given
-- number of trials and success probability (in that order).
binomial :: Distribution (Binomial b) a => a -> b -> RVar a
binomial t = rvar . Binomial t

{-# SPECIALIZE binomialT :: Int     -> Float  -> RVarT m Int #-}
{-# SPECIALIZE binomialT :: Int     -> Double -> RVarT m Int #-}
{-# SPECIALIZE binomialT :: Integer -> Float  -> RVarT m Integer #-}
{-# SPECIALIZE binomialT :: Integer -> Double -> RVarT m Integer #-}
{-# SPECIALIZE binomialT :: Float   -> Float  -> RVarT m Float  #-}
{-# SPECIALIZE binomialT :: Float   -> Double -> RVarT m Float  #-}
{-# SPECIALIZE binomialT :: Double  -> Float  -> RVarT m Double #-}
{-# SPECIALIZE binomialT :: Double  -> Double -> RVarT m Double #-}
-- | 'RVarT' counterpart of 'binomial': a random variable for a 'Binomial'
-- distribution built from the given number of trials and success
-- probability (in that order).
binomialT :: Distribution (Binomial b) a => a -> b -> RVarT m a
binomialT t = rvarT . Binomial t

-- | Binomial distribution tag. The constructor takes the number of trials
-- (of type @a@, which is also the type of the sampled value) followed by the
-- success probability (of type @b@). Note that the type parameters are
-- listed in the opposite order from the constructor fields.
data Binomial b a = Binomial a b

-- Instances for integral trial counts, written once against 'Int'.
-- NOTE(review): 'replicateInstances' presumably emits a copy of each
-- declaration for every type in 'integralTypes', substituting it for 'Int' --
-- confirm against "Data.Random.Internal.TH".
$( replicateInstances ''Int integralTypes [d|
        -- Sampling delegates to 'integralBinomial'.
        instance ( Floating b, Ord b
                 , Distribution Beta b
                 , Distribution StdUniform b
                 ) => Distribution (Binomial b) Int
            where
                rvarT (Binomial t p) = integralBinomial t p
        -- Cumulative probabilities delegate to 'integralBinomialCDF'.
        instance ( Real b , Distribution (Binomial b) Int
                 ) => CDF (Binomial b) Int
            where cdf  (Binomial t p) = integralBinomialCDF t p
        -- Mass function and its logarithm delegate to 'integralBinomialPDF'
        -- and 'integralBinomialLogPdf'.
        instance ( Real b , Distribution (Binomial b) Int
                 ) => PDF (Binomial b) Int
            where pdf (Binomial t p) = integralBinomialPDF t p
                  logPdf (Binomial t p) = integralBinomialLogPdf t p
    |])

-- Instances for floating-point trial counts, written once against 'Float'.
-- Each one requires the corresponding 'Integer' instance and delegates to the
-- @floatingBinomial*@ helpers.
-- NOTE(review): 'replicateInstances' presumably emits a copy of each
-- declaration for every type in 'realFloatTypes', substituting it for
-- 'Float' -- confirm against "Data.Random.Internal.TH".
$( replicateInstances ''Float realFloatTypes [d|
        -- Unlike the integral instances, this one defines 'rvar' rather than
        -- 'rvarT'.
        instance Distribution (Binomial b) Integer
              => Distribution (Binomial b) Float
              where rvar (Binomial t p) = floatingBinomial t p
        instance CDF (Binomial b) Integer
              => CDF (Binomial b) Float
              where cdf  (Binomial t p) = floatingBinomialCDF t p
        instance PDF (Binomial b) Integer
              => PDF (Binomial b) Float
              where pdf (Binomial t p) = floatingBinomialPDF t p
                    logPdf (Binomial t p) = floatingBinomialLogPDF t p
    |])