Commit 63b7ec73 authored by Valentin Reis's avatar Valentin Reis
Browse files

refactor exp4r

parent 004cb9ab
......@@ -15,12 +15,20 @@
-- Adversarial Contextual Bandit. Proceedings of the 34th International
-- Conference on Machine Learning, in PMLR 70:3280-3288
module HBandit.Exp4R
( -- * State
Exp4R (..),
Exp4RCfg (..),
( -- * Interface
Feedback (..),
-- * State
Exp4R (..),
LastAction (..),
OCER (..),
-- * Configuration
Exp4RCfg (..),
-- * Experts
ObliviousRep (..),
-- * internal
mkMu,
mkDelta,
lambdaInitial,
......@@ -79,12 +87,6 @@ data Exp4RCfg s a er
}
deriving (Generic)
-- | Oblivious Categorical Expert Representation
data OCER a = OCER (NonEmpty (ZeroOne Double, a)) deriving (Generic)
instance ExpertRepresentation (OCER a) () a where
represent (OCER l) () = l
instance
(Eq a, ExpertRepresentation er s a) =>
ContextualBandit (Exp4R s a er) (Exp4RCfg s a er) s a (Maybe Feedback) er
......@@ -102,60 +104,49 @@ instance
experts = (R.unsafeRefine (1 / fromIntegral (NE.length expertsCfg)),) <$> expertsCfg
}
stepCtx g feedback s =
do
expertRepresentations <- use (field @"experts")
weightedExperts <- use (field @"experts") <&> fmap (fmap represent)
lam <- R.unrefine <$> use (field @"lambda")
beta <- use (field @"constraint")
mu <- get <&> mkMu
delta <- get <&> mkDelta
use (field @"lastAction") >>= \case
Nothing -> return ()
Just (LastAction _ p_a pPolicy_a) -> feedback & \case
Nothing -> panic "exp4R usage error: don't give feedback on first action."
Just f -> do
let cHat :: Double
cHat = R.unrefine (cost f) / R.unrefine p_a
rHat :: Double
rHat = R.unrefine (risk f) / R.unrefine p_a
yHats :: NonEmpty Double
yHats = (\p -> cHat * R.unrefine p) <$> pPolicy_a
zHats :: NonEmpty Double
zHats = (\p -> rHat * R.unrefine p) <$> pPolicy_a
wOld :: NonEmpty (ZeroOne Double)
wOld = fst <$> weightedExperts
expTerms :: NonEmpty Double
expTerms = NE.zipWith (\y z -> y + lam * z) yHats zHats
wUpdate = NE.zipWith (\(R.unrefine -> w) x -> w * exp (- mu * x)) wOld expTerms
wDenom = getSum $ sconcat $ Sum <$> wUpdate
field @"experts" .= NE.zipWith (\(_, e) w' -> (unsafeNormalizePanic w' wDenom, e)) expertRepresentations wUpdate
field @"lambda" .= R.unsafeRefine (max 0 (lam + mu * (((R.unrefine <$> wOld) `neDot` zHats) - R.unrefine beta - delta * mu * lam)))
let weightedAdviceMatrix :: NonEmpty (ZeroOne Double, NonEmpty (ZeroOne Double, a))
weightedAdviceMatrix = weightedExperts <&> \(wi, pi_i) -> (wi, pi_i s)
dirtyArmDistribution :: NonEmpty (Double, a)
dirtyArmDistribution = sconcat $ weightedAdviceMatrix <&> \(wi, advices) -> advices <&> \(p, ai) -> (R.unrefine p * R.unrefine wi, ai)
dirtyarmDistribution' :: NonEmpty (Double, a)
dirtyarmDistribution' = groupAllWith1 snd dirtyArmDistribution <&> \gs -> (getSum $ sconcat (gs <&> Sum . fst), snd $ NE.head gs)
armDistribution :: NonEmpty (ZeroOne Double, a)
armDistribution = normalizeDistribution dirtyarmDistribution' & \case
Nothing -> panic "internal Exp4R algorithm failure: distribution normalization failed."
Just d -> d
(a, g') = sampleWL armDistribution g
p_a = find (\x -> snd x == a) armDistribution & \case
Nothing -> panic "internal Exp4R algorithm failure: arm pull issue."
Just p -> fst p
probabilityOf_a :: NonEmpty (ZeroOne Double)
probabilityOf_a = snd <$> weightedAdviceMatrix
<&> \e ->
case find (\x -> snd x == a) e of
Nothing -> panic "internal Exp4R algorithm failure: weight computation"
Just (p, _) -> p
field @"lastAction" ?= LastAction a p_a probabilityOf_a
return (a, g')
neDot :: (Num a) => NonEmpty a -> NonEmpty a -> a
neDot x y = getSum $ sconcat (Sum <$> NE.zipWith (*) x y)
stepCtx g feedback s = do
weightedExperts <- use (field @"experts") <&> fmap (fmap represent)
lam <- R.unrefine <$> use (field @"lambda")
beta <- use (field @"constraint")
mu <- get <&> mkMu
delta <- get <&> mkDelta
use (field @"lastAction")
>>= traverse_
( \(LastAction _ (R.unrefine -> p_a) (fmap R.unrefine -> pPolicy_a)) -> feedback & \case
Nothing -> panic "exp4R usage error: don't give feedback on first action."
Just (Feedback (R.unrefine -> c) (R.unrefine -> r)) -> do
let expTerms :: NonEmpty Double
expTerms = pPolicy_a <&> (* ((r + c) / p_a))
wUpdate = NE.zipWith (\(R.unrefine -> w, _) x -> w * exp (- mu * x)) weightedExperts expTerms
wDenom = getSum $ sconcat $ Sum <$> wUpdate
field @"experts" %= NE.zipWith (\w' (_, e) -> (unsafeNormalizePanic w' wDenom, e)) wUpdate
let fDot (R.unrefine -> wi, _) p = wi * r * p / p_a
let dotted = getSum $ sconcat (Sum <$> NE.zipWith fDot weightedExperts pPolicy_a)
field @"lambda" .= R.unsafeRefine (max 0 (lam + mu * (dotted - R.unrefine beta - delta * mu * lam)))
)
let weightedAdviceMatrix :: NonEmpty (ZeroOne Double, NonEmpty (ZeroOne Double, a))
weightedAdviceMatrix = weightedExperts <&> fmap ($ s)
armDistribution :: NonEmpty (ZeroOne Double, a)
armDistribution = normalizeDistribution dirtyarmDistribution' & \case
Nothing -> panic "internal Exp4R algorithm failure: distribution normalization failed."
Just d -> d
where
dirtyarmDistribution' :: NonEmpty (Double, a)
dirtyarmDistribution' = groupAllWith1 snd dirtyArmDistribution <&> \gs -> (getSum $ sconcat (gs <&> Sum . fst), snd $ NE.head gs)
dirtyArmDistribution :: NonEmpty (Double, a)
dirtyArmDistribution = sconcat $ weightedAdviceMatrix <&> \(wi, advices) -> advices <&> \(p, ai) -> (R.unrefine p * R.unrefine wi, ai)
(a, g') = sampleWL armDistribution g
p_a = find (\x -> snd x == a) armDistribution & \case
Nothing -> panic "internal Exp4R algorithm failure: arm pull issue."
Just p -> fst p
probabilityOf_a :: NonEmpty (ZeroOne Double)
probabilityOf_a = snd <$> weightedAdviceMatrix
<&> \e ->
case find (\x -> snd x == a) e of
Nothing -> panic "internal Exp4R algorithm failure: weight computation"
Just (p, _) -> p
field @"lastAction" ?= LastAction a p_a probabilityOf_a
return (a, g')
-- | \( \mu = \sqrt{\frac{\ln N }{ (T(K+4))}} \)
mkMu :: Exp4R s a er -> Double
......@@ -169,3 +160,9 @@ mkDelta Exp4R {..} = fromIntegral $ 3 * k
-- | \( \lambda_1 = 0 \)
lambdaInitial :: R.Refined R.NonNegative Double
lambdaInitial = R.unsafeRefine 0
-- | Oblivious Expert Representation
newtype ObliviousRep a = ObliviousRep (NonEmpty (ZeroOne Double, a)) deriving (Generic)
instance ExpertRepresentation (ObliviousRep a) () a where
represent (ObliviousRep l) () = l
......@@ -10,7 +10,7 @@ import Data.Sequence
import H.Prelude as R
import HBandit.Class
import HBandit.Exp4R
import HBandit.Types
import HBandit.Types as HBT
import Protolude
import Refined hiding (NonEmpty)
import Refined.Unsafe
......@@ -32,7 +32,7 @@ data GameState
{ historyActions :: Seq Int,
historyCosts :: Seq Double,
historyConstraints :: Seq Double,
bandit :: Exp4R () Int
bandit :: Exp4R () Int (ObliviousRep Int)
}
deriving (Generic)
......@@ -40,7 +40,7 @@ onePass ::
( MonadState GameState m,
MonadIO m,
(Functor (Zoomed m' (Int, StdGen))),
(Zoom m' m (Exp4R () Int) GameState)
(Zoom m' m (Exp4R () Int (ObliviousRep Int)) GameState)
) =>
[(ZO, ZO, ZO, ZO, ZO, ZO)] ->
m ()
......@@ -127,14 +127,15 @@ plot1pass one_cost two_cost three_cost one_risk two_risk three_risk = do
ggsave("risk.pdf", riskPlot)
|]
where
expertsC :: NonEmpty (() -> NonEmpty (ZeroOne Double, Int))
expertsC = [expert1, expert2, expert3, expert4, expert5, expert6]
expert1 () = [(HBandit.Types.one, 1 :: Int), (HBandit.Types.zero, 2 :: Int), (HBandit.Types.zero, 3 :: Int)]
expert2 () = [(HBandit.Types.zero, 1 :: Int), (HBandit.Types.one, 2 :: Int), (HBandit.Types.zero, 3 :: Int)]
expert3 () = [(HBandit.Types.zero, 1 :: Int), (HBandit.Types.zero, 2 :: Int), (HBandit.Types.one, 3 :: Int)]
expert4 () = [(HBandit.Types.zero, 1 :: Int), (unsafeRefine 0.5, 2 :: Int), (unsafeRefine 0.5, 3 :: Int)]
expert5 () = [(unsafeRefine 0.5, 1 :: Int), (unsafeRefine 0.5, 2 :: Int), (HBandit.Types.zero, 3 :: Int)]
expert6 () = [(unsafeRefine 0.5, 1 :: Int), (HBandit.Types.zero, 2 :: Int), (unsafeRefine 0.5, 3 :: Int)]
expertsC :: NonEmpty (ObliviousRep Int)
expertsC =
[ ObliviousRep [(HBT.one, 1 :: Int), (HBT.zero, 2 :: Int), (HBT.zero, 3 :: Int)],
ObliviousRep [(HBT.zero, 1 :: Int), (HBT.one, 2 :: Int), (HBT.zero, 3 :: Int)],
ObliviousRep [(HBT.zero, 1 :: Int), (HBT.zero, 2 :: Int), (HBT.one, 3 :: Int)],
ObliviousRep [(HBT.zero, 1 :: Int), (unsafeRefine 0.5, 2 :: Int), (unsafeRefine 0.5, 3 :: Int)],
ObliviousRep [(unsafeRefine 0.5, 1 :: Int), (HBT.zero, 2 :: Int), (unsafeRefine 0.5, 3 :: Int)],
ObliviousRep [(unsafeRefine 0.5, 1 :: Int), (unsafeRefine 0.5, 2 :: Int), (HBT.zero, 3 :: Int)]
]
p = ZipList . fmap unsafeRefine
(ZipList dataset) =
(\a b c d e f -> (a, b, c, d, e, f) :: (ZO, ZO, ZO, ZO, ZO, ZO))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment