-- | Operations on vectors of `Default`.
module MCSP.Algorithms.Vector (
    Default,

    -- * Initialization
    zeros,
    replicate,

    -- * Element-wise Operations
    map,
    choose,
    (.+),
    (.-),
    (.*),
    (.*.),
    sum,

    -- * Sorting
    sort,
    sortOn,
    sortBy,
    argSort,
    sortLike,

    -- * Statistics
    normalized,
    standardized,

    -- * Monadic Operations
    sumM,
    replicateM,

    -- ** Random Operations
    uniformN,
    uniformSN,
    uniformRN,
    weighted,
    weightedN,
    choice,
) where

import Control.Applicative (pure)
import Control.Exception.Extra (errorWithoutStackTrace)
import Control.Monad (Monad, join, sequence)
import Data.Bool (Bool (..), bool, otherwise, (||))
import Data.Eq (Eq (..))
import Data.Foldable1 (foldl1')
import Data.Function (id, on, ($), (.))
import Data.Functor ((<$>))
import Data.Int (Int)
import Data.List.NonEmpty (NonEmpty)
import Data.Ord (Ord (..), Ordering)
import Data.Vector.Algorithms.Merge qualified as Vector (sort, sortBy)
import Data.Vector.Generic qualified as Vector (maximumOn, sum)
import Data.Vector.Unboxed (
    Unbox,
    Vector,
    create,
    length,
    map,
    modify,
    null,
    replicate,
    replicateM,
    unsafeBackpermute,
    unsafeIndex,
    zipWith,
 )
import Data.Vector.Unboxed.Mutable (generate)
import GHC.Float (Double, Float, Floating, sqrt)
import GHC.Num (Num (..))
import GHC.Real (Fractional, fromIntegral, (/))
import Text.Printf (printf)

import MCSP.System.Random (Random, Variate, uniformR, weightedChoice)

-- | Default element type used in specialized vector operations.
--
-- Most polymorphic functions in this module carry a @SPECIALIZE@ pragma for @Vector Default@.
type Default = Float

-- | Checks that both input vectors have the same length before applying an operation.
--
-- When the lengths differ, an error (without a stack trace) is raised that reports both lengths.
--
-- >>> withSameLength (zipWith (+)) [1, 2, 3] [4, 5, 6]
-- [5.0,7.0,9.0]
--
-- >>> withSameLength (zipWith (+)) [1, 2, 3] [4, 5]
-- *** Exception: length mismatch: 3 != 2
withSameLength :: (Unbox a, Unbox b) => (Vector a -> Vector b -> c) -> Vector a -> Vector b -> c
withSameLength f v1 v2
    | n1 == n2 = f v1 v2
    | otherwise = errorWithoutStackTrace $ printf "length mismatch: %d != %d" n1 n2
  where
    n1 = length v1
    n2 = length v2
{-# INLINE withSameLength #-}

-- -------------- --
-- Initialization --
-- -------------- --

-- | Create a vector of zeros given the length.
--
-- >>> zeros 3
-- [0.0,0.0,0.0]
--
-- >>> zeros 0
-- []
zeros :: (Unbox a, Num a) => Int -> Vector a
zeros s = replicate s 0
{-# SPECIALIZE zeros :: Int -> Vector Default #-}

-- ----------------------- --
-- Element-Wise Operations --
-- ----------------------- --

-- | Case analysis for the Bool type.
--
-- @`choose` x y p@ evaluates to @x@ in every position that is `False` in @p@, and evaluates to @y@
-- everywhere else. Works as a vector version of `bool`.
--
-- The name is taken from
-- [numpy.choose](https://numpy.org/doc/stable/reference/generated/numpy.choose.html).
--
-- >>> choose 10 (-5) [True, False, False, True]
-- [-5.0,10.0,10.0,-5.0]
choose :: Unbox a => a -> a -> Vector Bool -> Vector a
choose falsy truthy = map (bool falsy truthy)
{-# SPECIALIZE choose :: Default -> Default -> Vector Bool -> Vector Default #-}

-- Fixities of the element-wise operators: the multiplicative ones (7) bind tighter than the
-- additive ones (6), and all are left-associative, so @a .+ b .* c@ parses as @a .+ (b .* c)@.
infixl 7 .*, .*.
infixl 6 .+, .-

-- | Element-wise addition.
--
-- Both vectors must have the same length; otherwise an error is raised, just like `.-` and `.*`
-- (previously mismatched inputs were silently truncated to the shorter length).
--
-- >>> [1, 2, 3] .+ [5, 6, 7]
-- [6.0,8.0,10.0]
--
-- >>> [1, 2, 3] .+ [5, 6]
-- *** Exception: length mismatch: 3 != 2
(.+) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a
(.+) = withSameLength $ zipWith (+)
{-# SPECIALIZE (.+) :: Vector Default -> Vector Default -> Vector Default #-}

-- | Element-wise subtraction.
--
-- Both vectors must have the same length; otherwise an error is raised.
--
-- >>> [1, 2, 3] .- [5, 6, 7]
-- [-4.0,-4.0,-4.0]
(.-) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a
(.-) = withSameLength $ zipWith (-)
{-# SPECIALIZE (.-) :: Vector Default -> Vector Default -> Vector Default #-}

-- | Element-wise multiplication.
--
-- Both vectors must have the same length; otherwise an error is raised.
--
-- >>> [1, 2, 3] .* [5, 6, 7]
-- [5.0,12.0,21.0]
(.*) :: (Unbox a, Num a) => Vector a -> Vector a -> Vector a
(.*) = withSameLength $ zipWith (*)
{-# SPECIALIZE (.*) :: Vector Default -> Vector Default -> Vector Default #-}

-- | Multiplication by a scalar.
--
-- Every element of the vector is multiplied by the same @factor@.
--
-- >>> 3 .*. [5, 6, 7]
-- [15.0,18.0,21.0]
(.*.) :: (Unbox a, Num a) => a -> Vector a -> Vector a
factor .*. vector = map (factor *) vector
{-# SPECIALIZE (.*.) :: Default -> Vector Default -> Vector Default #-}

-- | Element-wise sum of all vectors.
--
-- Folds the non-empty collection from the left with `.+`.
--
-- >>> sum [[1, 2], [3, 4], [5, 6]]
-- [9.0,12.0]
sum :: (Unbox a, Num a) => NonEmpty (Vector a) -> Vector a
sum = foldl1' (.+)
{-# INLINE sum #-}
{-# SPECIALIZE INLINE sum :: NonEmpty (Vector Default) -> Vector Default #-}

-- ------- --
-- Sorting --
-- ------- --

-- | Sorts a vector using the default comparison.
--
-- Uses the merge sort from "Data.Vector.Algorithms.Merge" on a mutable copy of the input.
--
-- >>> sort [3, 1, 2]
-- [1.0,2.0,3.0]
sort :: (Unbox a, Ord a) => Vector a -> Vector a
sort = modify Vector.sort
{-# SPECIALIZE sort :: Vector Default -> Vector Default #-}

-- | Sorts a vector using a custom comparison.
--
-- Uses the merge sort from "Data.Vector.Algorithms.Merge" on a mutable copy of the input.
--
-- >>> import Data.Ord (Ordering (..))
-- >>> sortBy (\x y -> if x * x < y * y then LT else GT) [-3, -1, 2]
-- [-1.0,2.0,-3.0]
sortBy :: Unbox a => (a -> a -> Ordering) -> Vector a -> Vector a
sortBy cmp = modify (Vector.sortBy cmp)

-- | Sorts a vector by comparing the results of a key function applied to each element.
--
-- The key is recomputed for both operands on every comparison, so @key@ should be cheap.
--
-- >>> sortOn (\x -> x * x) [-3, -1, 2]
-- [-1.0,2.0,-3.0]
sortOn :: (Unbox a, Ord b) => (a -> b) -> Vector a -> Vector a
sortOn key = sortBy (compare `on` key)
{-# SPECIALIZE sortOn :: Unbox a => (a -> Default) -> Vector a -> Vector a #-}

-- | Returns the indices that would sort the vector.
--
-- Builds the identity permutation @[0 .. length vec - 1]@ and sorts it by the value each index
-- points to in @vec@.
--
-- >>> argSort [30, 10, 20]
-- [1,2,0]
--
-- >>> argSort $ argSort [30, 10, 20]
-- [2,0,1]
argSort :: (Unbox a, Ord a) => Vector a -> Vector Int
argSort vec = create $ do
    index <- generate (length vec) id
    -- SAFETY: index only holds positions 0 .. length vec - 1, so every lookup in vec is in bounds
    Vector.sortBy (compare `on` unsafeIndex vec) index
    pure index
{-# SPECIALIZE argSort :: Vector Default -> Vector Int #-}
{-# SPECIALIZE argSort :: Vector Int -> Vector Int #-}

-- | Sort a vector based on the values of another vector.
--
-- Both vectors must have the same length; otherwise an error is raised.
--
-- >>> sortLike [30, 10, 20] [0.2, 0.9, 0.1]
-- [20.0,30.0,10.0]
sortLike :: (Unbox a, Unbox b, Ord b) => Vector a -> Vector b -> Vector a
sortLike = withSameLength $ \x y ->
    -- SAFETY: backpermute is safe here because both vectors have the same length,
    -- so every index produced by argSort y is a valid index into x
    unsafeBackpermute x (argSort y)
{-# SPECIALIZE sortLike :: Unbox a => Vector a -> Vector Default -> Vector a #-}
{-# SPECIALIZE sortLike :: Vector Default -> Vector Int -> Vector Default #-}

-- ---------- --
-- Statistics --
-- ---------- --

-- | Normalize a vector by its maximum absolute value.
--
-- Empty vectors and vectors whose maximum absolute value is zero are returned unchanged.
--
-- >>> normalized [1, 2, 5, 10]
-- [0.1,0.2,0.5,1.0]
--
-- >>> normalized [1, 2, 5, -10]
-- [0.1,0.2,0.5,-1.0]
--
-- >>> normalized []
-- []
--
-- >>> normalized [0]
-- [0.0]
normalized :: (Unbox a, Fractional a, Ord a) => Vector a -> Vector a
normalized vector
    -- `||` short-circuits, so absMax is only forced for non-empty vectors
    | null vector || absMax == 0 = vector
    | otherwise = map (/ absMax) vector
  where
    absMax = abs (Vector.maximumOn abs vector)
{-# SPECIALIZE normalized :: Vector Default -> Vector Default #-}

-- | Average value in a vector.
--
-- The empty vector has mean @0@ by convention, avoiding a division by zero.
--
-- >>> mean [1, 2, 5, 10]
-- 4.5
--
-- >>> mean []
-- 0.0
mean :: (Unbox a, Fractional a) => Vector a -> a
mean vector
    | null vector = 0
    | otherwise = Vector.sum vector / fromIntegral (length vector)
{-# SPECIALIZE mean :: Vector Default -> Default #-}

-- | Variance of the values in a vector.
--
-- This is the population variance: the `mean` of the squared deviations from the mean
-- (divided by @n@, not @n - 1@).
--
-- >>> variance [1, 2, 5, 10]
-- 12.25
--
-- >>> variance []
-- 0.0
variance :: (Unbox a, Fractional a) => Vector a -> a
variance vector = mean (dev .* dev)
  where
    u = mean vector
    dev = map (\x -> x - u) vector
{-# SPECIALIZE variance :: Vector Default -> Default #-}

-- | Standard deviation of the values in a vector.
--
-- Square root of the (population) `variance`.
--
-- >>> stdev [1, 2, 5, 10]
-- 3.5
stdev :: (Unbox a, Floating a) => Vector a -> a
stdev = sqrt . variance
{-# SPECIALIZE stdev :: Vector Default -> Default #-}

-- | Shift and scale the values so that the mean becomes zero and the standard deviation one.
--
-- Empty vectors are returned unchanged. When the standard deviation is zero (all values equal),
-- the values are only centered, since dividing by zero is not possible.
--
-- >>> standardized [1, 2, 5, 10]
-- [-1.0,-0.7142857142857143,0.14285714285714285,1.5714285714285714]
--
-- >>> standardized []
-- []
--
-- >>> standardized [1, 1]
-- [0.0,0.0]
standardized :: (Unbox a, Floating a, Eq a) => Vector a -> Vector a
standardized vector
    | null vector = vector
    | s == 0 = map (\x -> x - u) vector
    | otherwise = map (\x -> (x - u) / s) vector
  where
    u = mean vector
    s = stdev vector
{-# SPECIALIZE standardized :: Vector Default -> Vector Default #-}

-- ------------------ --
-- Monadic Operations --
-- ------------------ --

-- | Lifted version of `sum`.
--
-- Runs every action with `sequence`, then sums the resulting vectors element-wise.
--
-- >>> import MCSP.System.Random (generateWith)
-- >>> generateWith (2,3) $ sumM [uniformN 4, uniformSN 4]
-- [0.17218197108856648,0.21998774703644852,-0.16158831286684616,1.0345635554897776]
sumM :: (Unbox a, Num a, Monad m) => NonEmpty (m (Vector a)) -> m (Vector a)
sumM values = sum <$> sequence values
{-# INLINE sumM #-}
{-# SPECIALIZE INLINE sumM :: Monad m => NonEmpty (m (Vector Default)) -> m (Vector Default) #-}
{-# SPECIALIZE INLINE sumM :: NonEmpty (Random (Vector Default)) -> Random (Vector Default) #-}

-- | Generate multiple uniformly distributed values in the given range.
--
-- Replicated version of `uniformR`: draws @count@ independent values between @lo@ and @hi@.
--
-- >>> import MCSP.System.Random (generateWith)
-- >>> generateWith (2,3) $ uniformRN 5 50 3
-- [5.889456173182222,21.532789204576318,30.885279630109455]
uniformRN :: (Unbox a, Variate a) => a -> a -> Int -> Random (Vector a)
uniformRN lo hi count = replicateM count (uniformR lo hi)
{-# SPECIALIZE uniformRN :: Default -> Default -> Int -> Random (Vector Default) #-}

-- | Generate multiple uniformly distributed values between @[0,1]@.
--
-- Replicated version of `uniform`; equivalent to @`uniformRN` 0 1@.
--
-- >>> import MCSP.System.Random (generateWith)
-- >>> generateWith (2,3) $ uniformN 3
-- [1.9765692737382712e-2,0.3673953156572515,0.5752284362246546]
uniformN :: (Unbox a, Variate a, Num a) => Int -> Random (Vector a)
uniformN = uniformRN 0 1
{-# SPECIALIZE uniformN :: Int -> Random (Vector Default) #-}

-- | Generate multiple uniformly distributed values between @[-1,1]@.
--
-- Signed version of `uniformN`; equivalent to @`uniformRN` (-1) 1@.
--
-- >>> import MCSP.System.Random (generateWith)
-- >>> generateWith (2,3) $ uniformSN 3
-- [-0.9604686145252346,-0.26520936868549705,0.15045687244930916]
uniformSN :: (Unbox a, Variate a, Num a) => Int -> Random (Vector a)
uniformSN = uniformRN (-1) 1
{-# SPECIALIZE uniformSN :: Int -> Random (Vector Default) #-}

-- | Multiplies the vector by a single random value between @[0,maxWeight]@.
--
-- Randomized version of `.*.`: one weight is drawn and applied to every element.
-- See also `weightedN`.
--
-- >>> import MCSP.System.Random (generateWith)
-- >>> generateWith (2,3) $ weighted 10 [1, 2, 10]
-- [0.19765692737382712,0.39531385474765424,1.9765692737382712]
weighted :: (Unbox a, Variate a, Num a) => a -> Vector a -> Random (Vector a)
weighted maxWeight vec = (.*. vec) <$> uniformR 0 maxWeight
{-# SPECIALIZE weighted :: Default -> Vector Default -> Random (Vector Default) #-}

-- | Multiplies the vector by multiple random values between @[0,maxWeight]@.
--
-- Randomized version of `.*`: an independent weight is drawn for each element.
-- See also `weighted`.
--
-- >>> import MCSP.System.Random (generateWith)
-- >>> generateWith (2,3) $ weightedN 10 [1, 2, 10]
-- [0.19765692737382712,7.34790631314503,57.52284362246546]
weightedN :: (Unbox a, Variate a, Num a) => a -> Vector a -> Random (Vector a)
weightedN maxWeight vec = (.* vec) <$> uniformRN 0 maxWeight (length vec)
{-# SPECIALIZE weightedN :: Default -> Vector Default -> Random (Vector Default) #-}

-- | Choose randomly between multiple `Random` actions, with probability proportional to their
-- given weights, then run the chosen action.
--
-- >>> import MCSP.System.Random (generateWith)
-- >>> generateWith (1,2) $ replicateM 10 $ choice [(1, pure 'a'), (2, pure 'b')]
-- "abbbaabbbb"
--
-- >>> generateWith (1,2) $ replicateM 10 $ choice [(1, uniformR (-1) 0), (2, uniformR 0 1)]
-- [-0.11816487538074749,0.5798377716767166,0.12231072251084052,0.754750234725723,0.5163453222019222,0.9673060222002038,-0.28900858364465354,0.609061325679456,-0.15187385001852494,0.4987697781636008]
choice :: NonEmpty (Double, Random a) -> Random a
choice options = join (weightedChoice options)
{-# INLINE choice #-}