Commit 0b93da4e authored by Valentin Reis's avatar Valentin Reis
Browse files

Adds expert representations.

parent 9ebd2593
Pipeline #9893 failed with stages
in 11 seconds
......@@ -11,6 +11,7 @@
module HBandit.Class
( -- * Generalized Bandit
Bandit (..),
ExpertRepresentation (..),
ContextualBandit (..),
-- * Discrete Multi-Armed-Bandits
......@@ -22,6 +23,7 @@ where
import Protolude
import System.Random
import HBandit.Types
-- | @Bandit b hyper a l@ is the class for a bandit algorithm. This is mostly
-- here to help structure the library itself. We have the following bandit
......@@ -55,7 +57,7 @@ class Bandit b hyper a l | b -> l, b -> hyper, b -> a where
-- | @step loss@ iterates the bandit process one step forward.
step :: (RandomGen g, MonadState b m) => g -> l -> m (a, g)
class ContextualBandit b hyper s a l | b -> l, b -> hyper, b -> s, b -> a where
class ContextualBandit b hyper s a l er | b -> l, b -> hyper, b -> s, b -> a, b-> er where
-- | @initCtx hyper@ returns the initial state of the algorithm.
initCtx :: hyper -> b
......@@ -63,6 +65,9 @@ class ContextualBandit b hyper s a l | b -> l, b -> hyper, b -> s, b -> a where
-- | @stepCtx g loss s@ iterates the bandit process one step forward.
stepCtx :: (RandomGen g, MonadState b m, Ord a) => g -> l -> s -> m (a, g)
-- | An expert representation @er@ that can be interpreted as an advice
-- function: for every input of type @s@ it yields a non-empty list of
-- @('ZeroOne' 'Double', a)@ pairs. The functional dependencies state that
-- @er@ alone determines both @s@ and @a@.
class ExpertRepresentation er s a | er -> s, er -> a where
  -- | Interpret the representation as a function from @s@ to weighted
  -- values of type @a@.
  -- NOTE(review): the weights are presumably a probability distribution
  -- over actions — confirm against the implementations.
  represent :: er -> s -> NonEmpty (ZeroOne Double, a)
-- | A non-empty collection of values of type @a@.
-- NOTE(review): presumably the arms of a discrete multi-armed bandit —
-- confirm against callers.
newtype Arms a
  = Arms (NonEmpty a)
  deriving (Show, Generic)
......
......@@ -20,6 +20,7 @@ module HBandit.Exp4R
Exp4RCfg (..),
Feedback (..),
LastAction (..),
OCER (..),
mkMu,
mkDelta,
lambdaInitial,
......@@ -37,7 +38,7 @@ import qualified Refined as R
import qualified Refined.Unsafe as R
-- | The EXP4R state
data Exp4R s a
data Exp4R s a er
= Exp4R
{ t :: Int,
horizon :: R.Refined R.Positive Int,
......@@ -49,7 +50,7 @@ data Exp4R s a
experts ::
NonEmpty
( ZeroOne Double,
s -> NonEmpty (ZeroOne Double, a)
er
)
}
deriving (Generic)
......@@ -67,19 +68,26 @@ data Feedback
{ cost :: ZeroOne Double,
risk :: ZeroOne Double
}
deriving (Generic)
data Exp4RCfg s a
data Exp4RCfg s a er
= Exp4RCfg
{ expertsCfg :: NonEmpty (s -> NonEmpty (ZeroOne Double, a)),
{ expertsCfg :: NonEmpty er,
constraintCfg :: ZeroOne Double,
horizonCfg :: R.Refined R.Positive Int,
as :: NonEmpty a
}
deriving (Generic)
-- | Oblivious Categorical Expert Representation: wraps a fixed non-empty
-- list of @('ZeroOne' 'Double', a)@ pairs.
data OCER a
  = OCER (NonEmpty (ZeroOne Double, a))
  deriving (Generic)
-- | An 'OCER' ignores its input: the only accepted input is @()@, and the
-- wrapped list is returned unchanged.
instance ExpertRepresentation (OCER a) () a where
  represent (OCER advice) () = advice
instance
(Eq a) =>
ContextualBandit (Exp4R s a) (Exp4RCfg s a) s a (Maybe Feedback)
(Eq a, ExpertRepresentation er s a) =>
ContextualBandit (Exp4R s a er) (Exp4RCfg s a er) s a (Maybe Feedback) er
where
initCtx Exp4RCfg {..} =
......@@ -96,7 +104,8 @@ instance
stepCtx g feedback s =
do
weightedExperts <- use (field @"experts")
expertRepresentations <- use (field @"experts")
weightedExperts <- use (field @"experts") <&> fmap (fmap represent)
lam <- R.unrefine <$> use (field @"lambda")
beta <- use (field @"constraint")
mu <- get <&> mkMu
......@@ -120,7 +129,7 @@ instance
expTerms = NE.zipWith (\y z -> y + lam * z) yHats zHats
wUpdate = NE.zipWith (\(R.unrefine -> w) x -> w * exp (- mu * x)) wOld expTerms
wDenom = getSum $ sconcat $ Sum <$> wUpdate
field @"experts" .= NE.zipWith (\(_, e) w' -> (unsafeNormalizePanic w' wDenom, e)) weightedExperts wUpdate
field @"experts" .= NE.zipWith (\(_, e) w' -> (unsafeNormalizePanic w' wDenom, e)) expertRepresentations wUpdate
field @"lambda" .= R.unsafeRefine (max 0 (lam + mu * (((R.unrefine <$> wOld) `neDot` zHats) - R.unrefine beta - delta * mu * lam)))
let weightedAdviceMatrix :: NonEmpty (ZeroOne Double, NonEmpty (ZeroOne Double, a))
weightedAdviceMatrix = weightedExperts <&> \(wi, pi_i) -> (wi, pi_i s)
......@@ -149,12 +158,12 @@ neDot :: (Num a) => NonEmpty a -> NonEmpty a -> a
-- Dot product of two non-empty lists: the sum of the pairwise products.
-- 'NE.zipWith' stops at the shorter of the two inputs.
neDot xs ys = getSum (sconcat (NE.zipWith (\p q -> Sum (p * q)) xs ys))
-- | \( \mu = \sqrt{\frac{\ln N }{ (T(K+4))}} \)
mkMu :: Exp4R s a -> Double
mkMu :: Exp4R s a er -> Double
-- Evaluates sqrt (ln n / (horizon * (k + 4))). 'horizon' is a refined
-- positive Int and is unwrapped before the multiplication.
-- NOTE(review): n and k are presumably fields of Exp4R brought into scope
-- by the record wildcard — their declarations are not visible here.
mkMu Exp4R {..} =
  sqrt (log (fromIntegral n) / fromIntegral (R.unrefine horizon * (k + 4)))
-- | \( \delta = 3K \)
mkDelta :: Exp4R s a -> Double
mkDelta :: Exp4R s a er -> Double
-- Three times k, converted to Double.
-- NOTE(review): k is presumably a field of Exp4R brought into scope by the
-- record wildcard — its declaration is not visible here.
mkDelta Exp4R {..} = fromIntegral (3 * k)
-- | \( \lambda_1 = 0 \)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment