Commit 256c0f95 authored by Valentin Reis
Browse files

Initial Exp4R code refactor complete.

parent 004cb9ab
Pipeline #9897 failed with stages
in 12 seconds
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
-- |
-- Module : HBandit.Exp4R
-- Copyright : (c) 2019, UChicago Argonne, LLC.
-- License : MIT
-- Maintainer : fre@freux.fr
--
-- The contextual exponential-weight algorithm for Exploration and Exploitation
-- with Experts and Risk Constraints (EXP4R). See [1]
--
-- - [1] Sun, W., Dey, D. & Kapoor, A.. (2017). Safety-Aware Algorithms for
-- Adversarial Contextual Bandit. Proceedings of the 34th International
-- Conference on Machine Learning, in PMLR 70:3280-3288
module HBandit.Exp4R
( -- * Interface
Feedback (..),
-- * State
Exp4R (..),
LastAction (..),
-- * Configuration
Exp4RCfg (..),
-- * Experts
ObliviousRep (..),
-- * internal
mkMu,
mkDelta,
lambdaInitial,
)
where
import Control.Lens
import Data.Generics.Product
import Data.List.NonEmpty as NE
import HBandit.Class
import HBandit.Types
import HBandit.Util
import Protolude
import qualified Refined as R
import qualified Refined.Unsafe as R
-- | The EXP4R state.
--
-- @a@ is the action type and @er@ the expert representation; @s@ does not
-- occur in any field.
data Exp4R s a er = Exp4R
  { -- | Round counter; 'initCtx' starts it at 1.
    t :: Int,
    -- | Horizon \(T\), as used by 'mkMu'.
    horizon :: R.Refined R.Positive Int,
    -- | The action sampled by the previous step together with its sampling
    -- probabilities, or 'Nothing' before the first action.
    lastAction :: Maybe (LastAction a),
    -- | Number of actions \(K\) (length of the configured action list).
    k :: Int,
    -- | Number of experts \(N\) (length of the configured expert list).
    n :: Int,
    -- | Current value of \(\lambda\); starts at 'lambdaInitial' and is kept
    -- non-negative by the update in 'stepCtx'.
    lambda :: R.Refined R.NonNegative Double,
    -- | Risk threshold subtracted in the \(\lambda\) update.
    constraint :: ZeroOne Double,
    -- | Each expert paired with its current weight.
    experts :: NonEmpty (ZeroOne Double, er)
  }
  deriving (Generic)
-- | Record of the most recently sampled action, kept so that the feedback
-- received on the next step can be turned into weight updates.
data LastAction a = LastAction
  { -- | The action that was sampled.
    action :: a,
    -- | Probability of that action under the combined arm distribution.
    globalProbabilityOfSample :: ZeroOne Double,
    -- | Probability each expert assigned to that action, in expert order.
    perExpertProbabilityOfSample :: NonEmpty (ZeroOne Double)
  }
  deriving (Generic)
-- | Feedback observed for the previously played action.
data Feedback = Feedback
  { -- | Observed cost, as a 'ZeroOne' value.
    cost :: ZeroOne Double,
    -- | Observed risk, as a 'ZeroOne' value.
    risk :: ZeroOne Double
  }
  deriving (Generic)
-- | Configuration from which 'initCtx' builds the initial 'Exp4R' state.
data Exp4RCfg s a er = Exp4RCfg
  { -- | The experts; each starts with the same weight.
    expertsCfg :: NonEmpty er,
    -- | Risk threshold, copied into the state's @constraint@ field.
    constraintCfg :: ZeroOne Double,
    -- | Horizon, copied into the state's @horizon@ field.
    horizonCfg :: R.Refined R.Positive Int,
    -- | The available actions; their count becomes the state's @k@.
    as :: NonEmpty a
  }
  deriving (Generic)
-- | EXP4R as a 'ContextualBandit': 'initCtx' builds the initial state from an
-- 'Exp4RCfg', and 'stepCtx' first consumes the feedback for the previously
-- played action (if there is one) and then samples the next action.
instance
  (Eq a, ExpertRepresentation er s a) =>
  ContextualBandit (Exp4R s a er) (Exp4RCfg s a er) s a (Maybe Feedback) er
  where
  -- Initial state: round 1, no action played yet, lambda at 'lambdaInitial'
  -- and the same weight 1/N on each of the N configured experts.
  initCtx Exp4RCfg {..} =
    Exp4R
      { t = 1,
        lastAction = Nothing,
        k = NE.length as,
        n = NE.length expertsCfg,
        lambda = lambdaInitial,
        constraint = constraintCfg,
        horizon = horizonCfg,
        experts = (R.unsafeRefine (1 / fromIntegral (NE.length expertsCfg)),) <$> expertsCfg
      }

  -- One round: @g@ is the random generator, @feedback@ the outcome of the
  -- previously returned action and @s@ the current context. Returns the newly
  -- sampled action and the advanced generator, and records the action with its
  -- probabilities in @lastAction@.
  --
  -- Panics if an action was played earlier but @feedback@ is 'Nothing', or if
  -- one of the internal lookups/normalizations fails.
  stepCtx g feedback s = do
    lam <- R.unrefine <$> use (field @"lambda")
    beta <- R.unrefine <$> use (field @"constraint")
    mu <- get <&> mkMu
    delta <- get <&> mkDelta
    -- Update phase: skipped entirely while no action has been played yet.
    use (field @"lastAction")
      >>= traverse_
        ( \(LastAction _ (R.unrefine -> p_a) (fmap R.unrefine -> pPolicy_a)) ->
            feedback & \case
              -- NOTE(review): this branch fires when an earlier action exists
              -- but no feedback was supplied for it; the message text is kept
              -- unchanged.
              Nothing -> panic "exp4R usage error: don't give feedback on first action."
              Just (Feedback (R.unrefine -> c) (R.unrefine -> r)) -> do
                -- Weights before this update; both the weight update and the
                -- lambda update below are computed from these.
                oldWeights <- use (field @"experts") <&> fmap (R.unrefine . fst)
                let -- Per-expert loss estimate: the expert's probability of the
                    -- played action times (c + lambda * r) / p_a. The risk term
                    -- must be scaled by lambda, otherwise the multiplier never
                    -- influences the expert weights.
                    lossEstimates :: NonEmpty Double
                    lossEstimates = pPolicy_a <&> (* ((lam * r + c) / p_a))
                    wUpdate = NE.zipWith (\w x -> w * exp (negate (mu * x))) oldWeights lossEstimates
                    wDenom = getSum $ sconcat $ Sum <$> wUpdate
                field @"experts" %= NE.zipWith (\w' (_, e) -> (unsafeNormalizePanic w' wDenom, e)) wUpdate
                -- Old-weight average of the per-expert risk estimates.
                let riskEstimate = getSum $ sconcat (Sum <$> NE.zipWith (\w p -> w * r * p / p_a) oldWeights pPolicy_a)
                field @"lambda" .= R.unsafeRefine (max 0 (lam + mu * (riskEstimate - beta - delta * mu * lam)))
        )
    -- Read the experts only now, so that the weights updated just above are
    -- the ones used to draw the next action (reading them before the update
    -- would make every draw lag one round of feedback behind).
    weightedExperts <- use (field @"experts") <&> fmap (fmap represent)
    let weightedAdviceMatrix :: NonEmpty (ZeroOne Double, NonEmpty (ZeroOne Double, a))
        weightedAdviceMatrix = weightedExperts <&> fmap ($ s)
        -- Every (weight * probability, action) pair, over all experts.
        dirtyArmDistribution :: NonEmpty (Double, a)
        dirtyArmDistribution = sconcat $ weightedAdviceMatrix <&> \(wi, advices) -> advices <&> \(p, ai) -> (R.unrefine p * R.unrefine wi, ai)
        -- The same mass, summed per distinct action.
        dirtyarmDistribution' :: NonEmpty (Double, a)
        dirtyarmDistribution' = groupAllWith1 snd dirtyArmDistribution <&> \gs -> (getSum $ sconcat (gs <&> Sum . fst), snd $ NE.head gs)
        armDistribution :: NonEmpty (ZeroOne Double, a)
        armDistribution = normalizeDistribution dirtyarmDistribution' & \case
          Nothing -> panic "internal Exp4R algorithm failure: distribution normalization failed."
          Just d -> d
        (a, g') = sampleWL armDistribution g
        -- Probability of the sampled action under the combined distribution.
        p_a = find (\x -> snd x == a) armDistribution & \case
          Nothing -> panic "internal Exp4R algorithm failure: arm pull issue."
          Just p -> fst p
        -- Probability each expert assigned to the sampled action.
        probabilityOf_a :: NonEmpty (ZeroOne Double)
        probabilityOf_a =
          snd <$> weightedAdviceMatrix
            <&> \e ->
              case find (\x -> snd x == a) e of
                Nothing -> panic "internal Exp4R algorithm failure: weight computation"
                Just (p, _) -> p
    field @"lastAction" ?= LastAction a p_a probabilityOf_a
    return (a, g')
-- | Learning rate \( \mu = \sqrt{\frac{\ln N }{ (T(K+4))}} \), with \(N\) the
-- number of experts, \(T\) the horizon and \(K\) the number of actions stored
-- in the state.
mkMu :: Exp4R s a er -> Double
mkMu Exp4R {n = expertCount, k = armCount, horizon = rounds} =
  sqrt (logExperts / scale)
  where
    logExperts = log (fromIntegral expertCount)
    scale = fromIntegral (R.unrefine rounds * (armCount + 4))
-- | \( \delta = 3K \), where \(K\) is the number of actions stored in the
-- state.
mkDelta :: Exp4R s a er -> Double
mkDelta Exp4R {k = armCount} = fromIntegral (3 * armCount)
-- | Starting value of \(\lambda\): \( \lambda_1 = 0 \).
lambdaInitial :: R.Refined R.NonNegative Double
lambdaInitial = R.unsafeRefine (0 :: Double)
-- | Oblivious expert representation: a fixed list of (probability, action)
-- pairs.
newtype ObliviousRep a = ObliviousRep (NonEmpty (ZeroOne Double, a))
  deriving (Generic)

-- | For the unit context, the representation is the stored advice itself.
instance ExpertRepresentation (ObliviousRep a) () a where
  represent (ObliviousRep advice) () = advice
......@@ -15,12 +15,20 @@
-- Adversarial Contextual Bandit. Proceedings of the 34th International
-- Conference on Machine Learning, in PMLR 70:3280-3288
module HBandit.Exp4R
( -- * State
Exp4R (..),
Exp4RCfg (..),
( -- * Interface
Feedback (..),
-- * State
Exp4R (..),
LastAction (..),
OCER (..),
-- * Configuration
Exp4RCfg (..),
-- * Experts
ObliviousRep (..),
-- * internal
mkMu,
mkDelta,
lambdaInitial,
......@@ -55,6 +63,7 @@ data Exp4R s a er
}
deriving (Generic)
-- | Encapsulator for 'last action taken'
data LastAction a
= LastAction
{ action :: a,
......@@ -63,6 +72,7 @@ data LastAction a
}
deriving (Generic)
-- | Constructor for feedback from the environment.
data Feedback
= Feedback
{ cost :: ZeroOne Double,
......@@ -70,6 +80,7 @@ data Feedback
}
deriving (Generic)
-- | Hyperparameters.
data Exp4RCfg s a er
= Exp4RCfg
{ expertsCfg :: NonEmpty er,
......@@ -79,12 +90,6 @@ data Exp4RCfg s a er
}
deriving (Generic)
-- | Oblivious Categorical Expert Representation
data OCER a = OCER (NonEmpty (ZeroOne Double, a)) deriving (Generic)
instance ExpertRepresentation (OCER a) () a where
represent (OCER l) () = l
instance
(Eq a, ExpertRepresentation er s a) =>
ContextualBandit (Exp4R s a er) (Exp4RCfg s a er) s a (Maybe Feedback) er
......@@ -102,60 +107,63 @@ instance
experts = (R.unsafeRefine (1 / fromIntegral (NE.length expertsCfg)),) <$> expertsCfg
}
stepCtx g feedback s =
do
expertRepresentations <- use (field @"experts")
weightedExperts <- use (field @"experts") <&> fmap (fmap represent)
lam <- R.unrefine <$> use (field @"lambda")
beta <- use (field @"constraint")
mu <- get <&> mkMu
delta <- get <&> mkDelta
use (field @"lastAction") >>= \case
Nothing -> return ()
Just (LastAction _ p_a pPolicy_a) -> feedback & \case
Nothing -> panic "exp4R usage error: don't give feedback on first action."
Just f -> do
let cHat :: Double
cHat = R.unrefine (cost f) / R.unrefine p_a
rHat :: Double
rHat = R.unrefine (risk f) / R.unrefine p_a
yHats :: NonEmpty Double
yHats = (\p -> cHat * R.unrefine p) <$> pPolicy_a
zHats :: NonEmpty Double
zHats = (\p -> rHat * R.unrefine p) <$> pPolicy_a
wOld :: NonEmpty (ZeroOne Double)
wOld = fst <$> weightedExperts
expTerms :: NonEmpty Double
expTerms = NE.zipWith (\y z -> y + lam * z) yHats zHats
wUpdate = NE.zipWith (\(R.unrefine -> w) x -> w * exp (- mu * x)) wOld expTerms
wDenom = getSum $ sconcat $ Sum <$> wUpdate
field @"experts" .= NE.zipWith (\(_, e) w' -> (unsafeNormalizePanic w' wDenom, e)) expertRepresentations wUpdate
field @"lambda" .= R.unsafeRefine (max 0 (lam + mu * (((R.unrefine <$> wOld) `neDot` zHats) - R.unrefine beta - delta * mu * lam)))
let weightedAdviceMatrix :: NonEmpty (ZeroOne Double, NonEmpty (ZeroOne Double, a))
weightedAdviceMatrix = weightedExperts <&> \(wi, pi_i) -> (wi, pi_i s)
dirtyArmDistribution :: NonEmpty (Double, a)
dirtyArmDistribution = sconcat $ weightedAdviceMatrix <&> \(wi, advices) -> advices <&> \(p, ai) -> (R.unrefine p * R.unrefine wi, ai)
dirtyarmDistribution' :: NonEmpty (Double, a)
dirtyarmDistribution' = groupAllWith1 snd dirtyArmDistribution <&> \gs -> (getSum $ sconcat (gs <&> Sum . fst), snd $ NE.head gs)
armDistribution :: NonEmpty (ZeroOne Double, a)
armDistribution = normalizeDistribution dirtyarmDistribution' & \case
Nothing -> panic "internal Exp4R algorithm failure: distribution normalization failed."
Just d -> d
(a, g') = sampleWL armDistribution g
p_a = find (\x -> snd x == a) armDistribution & \case
Nothing -> panic "internal Exp4R algorithm failure: arm pull issue."
Just p -> fst p
probabilityOf_a :: NonEmpty (ZeroOne Double)
probabilityOf_a = snd <$> weightedAdviceMatrix
<&> \e ->
case find (\x -> snd x == a) e of
Nothing -> panic "internal Exp4R algorithm failure: weight computation"
Just (p, _) -> p
field @"lastAction" ?= LastAction a p_a probabilityOf_a
return (a, g')
neDot :: (Num a) => NonEmpty a -> NonEmpty a -> a
neDot x y = getSum $ sconcat (Sum <$> NE.zipWith (*) x y)
stepCtx g feedback s = do
weightedExperts <- use (field @"experts") <&> fmap (fmap represent)
lam <- R.unrefine <$> use (field @"lambda")
beta <- use (field @"constraint")
mu <- get <&> mkMu
delta <- get <&> mkDelta
use (field @"lastAction")
>>= traverse_ \(LastAction _ (R.unrefine -> p_a) (fmap R.unrefine -> pPolicy_a)) ->
fromMaybe
(panic "exp4R usage error: do not give feedback on first action.")
( feedback
<&> \(Feedback (R.unrefine -> c) (R.unrefine -> r)) -> do
let numeratorTerm (R.unrefine -> w, _) p = w * exp (- mu * (p * (lam * r + c) / p_a))
let wUpdate = NE.zipWith numeratorTerm weightedExperts pPolicy_a
wDenom = getSum $ sconcat $ Sum <$> wUpdate
field @"experts" %= NE.zipWith (\w' (_, e) -> (unsafeNormalizePanic w' wDenom, e)) wUpdate
let fDot (R.unrefine -> wi, _) p = wi * r * p / p_a
let dotted = getSum $ sconcat (Sum <$> NE.zipWith fDot weightedExperts pPolicy_a)
field @"lambda" .= R.unsafeRefine (max 0 (lam + mu * (dotted - R.unrefine beta - delta * mu * lam)))
)
let weightedAdviceMatrix :: NonEmpty (ZeroOne Double, NonEmpty (ZeroOne Double, a))
weightedAdviceMatrix = weightedExperts <&> fmap ($ s)
armDistribution :: NonEmpty (ZeroOne Double, a)
armDistribution =
fromMaybe
(panic "internal Exp4R algorithm failure: distribution normalization failed.")
(combineAdvice weightedAdviceMatrix)
(a, g') = sampleWL armDistribution g
p_a =
fst $
fromMaybe
(panic "internal Exp4R algorithm failure: arm pull issue.")
(find (\x -> snd x == a) armDistribution)
probabilityOf_a :: NonEmpty (ZeroOne Double)
probabilityOf_a = snd <$> weightedAdviceMatrix
<&> \e ->
( fst $
fromMaybe
(panic "internal Exp4R algorithm failure: weight computation")
(find (\x -> snd x == a) e)
)
field @"lastAction" ?= LastAction a p_a probabilityOf_a
return (a, g')
-- | Merge weighted expert advice into a single distribution to sample from.
--
-- Each expert's probability for an action is multiplied by that expert's
-- weight, the products are summed per distinct action, and the result is
-- handed to 'normalizeDistribution', whose 'Maybe' result is returned as is.
combineAdvice ::
  (Ord a) =>
  NonEmpty (ZeroOne Double, NonEmpty (ZeroOne Double, a)) ->
  Maybe (NonEmpty (ZeroOne Double, a))
combineAdvice weightedAdviceMatrix =
  normalizeDistribution (totalPerAction <$> groupAllWith1 snd scaledAdvice)
  where
    -- All (weight * probability, action) pairs, expert after expert.
    scaledAdvice = do
      (wi, advices) <- weightedAdviceMatrix
      (p, ai) <- advices
      pure (R.unrefine p * R.unrefine wi, ai)
    -- Collapse one group of entries sharing an action into a single entry.
    totalPerAction gs = (getSum (sconcat (Sum . fst <$> gs)), snd (NE.head gs))
-- | \( \mu = \sqrt{\frac{\ln N }{ (T(K+4))}} \)
mkMu :: Exp4R s a er -> Double
......@@ -169,3 +177,9 @@ mkDelta Exp4R {..} = fromIntegral $ 3 * k
-- | \( \lambda_1 = 0 \)
lambdaInitial :: R.Refined R.NonNegative Double
lambdaInitial = R.unsafeRefine 0
-- | Oblivious Categorical Expert Representation: a fixed list of
-- (probability, action) pairs.
--
-- Declared as a @newtype@ because it wraps a single field in a single
-- constructor.
newtype ObliviousRep a = ObliviousRep (NonEmpty (ZeroOne Double, a))
  deriving (Generic)

-- | For the unit context, the representation is the stored advice itself.
instance ExpertRepresentation (ObliviousRep a) () a where
  represent (ObliviousRep l) () = l
......@@ -10,7 +10,7 @@ import Data.Sequence
import H.Prelude as R
import HBandit.Class
import HBandit.Exp4R
import HBandit.Types
import HBandit.Types as HBT
import Protolude
import Refined hiding (NonEmpty)
import Refined.Unsafe
......@@ -32,7 +32,7 @@ data GameState
{ historyActions :: Seq Int,
historyCosts :: Seq Double,
historyConstraints :: Seq Double,
bandit :: Exp4R () Int
bandit :: Exp4R () Int (ObliviousRep Int)
}
deriving (Generic)
......@@ -40,7 +40,7 @@ onePass ::
( MonadState GameState m,
MonadIO m,
(Functor (Zoomed m' (Int, StdGen))),
(Zoom m' m (Exp4R () Int) GameState)
(Zoom m' m (Exp4R () Int (ObliviousRep Int)) GameState)
) =>
[(ZO, ZO, ZO, ZO, ZO, ZO)] ->
m ()
......@@ -76,7 +76,7 @@ plot1pass one_cost two_cost three_cost one_risk two_risk three_risk = do
let b = initCtx $ Exp4RCfg
{ expertsCfg = expertsC,
constraintCfg = unsafeRefine 0.5,
horizonCfg = unsafeRefine 500000,
horizonCfg = unsafeRefine 5000,
as = [1, 2, 3]
}
( GameState
......@@ -127,14 +127,15 @@ plot1pass one_cost two_cost three_cost one_risk two_risk three_risk = do
ggsave("risk.pdf", riskPlot)
|]
where
expertsC :: NonEmpty (() -> NonEmpty (ZeroOne Double, Int))
expertsC = [expert1, expert2, expert3, expert4, expert5, expert6]
expert1 () = [(HBandit.Types.one, 1 :: Int), (HBandit.Types.zero, 2 :: Int), (HBandit.Types.zero, 3 :: Int)]
expert2 () = [(HBandit.Types.zero, 1 :: Int), (HBandit.Types.one, 2 :: Int), (HBandit.Types.zero, 3 :: Int)]
expert3 () = [(HBandit.Types.zero, 1 :: Int), (HBandit.Types.zero, 2 :: Int), (HBandit.Types.one, 3 :: Int)]
expert4 () = [(HBandit.Types.zero, 1 :: Int), (unsafeRefine 0.5, 2 :: Int), (unsafeRefine 0.5, 3 :: Int)]
expert5 () = [(unsafeRefine 0.5, 1 :: Int), (unsafeRefine 0.5, 2 :: Int), (HBandit.Types.zero, 3 :: Int)]
expert6 () = [(unsafeRefine 0.5, 1 :: Int), (HBandit.Types.zero, 2 :: Int), (unsafeRefine 0.5, 3 :: Int)]
expertsC :: NonEmpty (ObliviousRep Int)
expertsC =
[ ObliviousRep [(HBT.one, 1 :: Int), (HBT.zero, 2 :: Int), (HBT.zero, 3 :: Int)],
ObliviousRep [(HBT.zero, 1 :: Int), (HBT.one, 2 :: Int), (HBT.zero, 3 :: Int)],
ObliviousRep [(HBT.zero, 1 :: Int), (HBT.zero, 2 :: Int), (HBT.one, 3 :: Int)],
ObliviousRep [(HBT.zero, 1 :: Int), (unsafeRefine 0.5, 2 :: Int), (unsafeRefine 0.5, 3 :: Int)],
ObliviousRep [(unsafeRefine 0.5, 1 :: Int), (HBT.zero, 2 :: Int), (unsafeRefine 0.5, 3 :: Int)],
ObliviousRep [(unsafeRefine 0.5, 1 :: Int), (unsafeRefine 0.5, 2 :: Int), (HBT.zero, 3 :: Int)]
]
p = ZipList . fmap unsafeRefine
(ZipList dataset) =
(\a b c d e f -> (a, b, c, d, e, f) :: (ZO, ZO, ZO, ZO, ZO, ZO))
......@@ -169,4 +170,4 @@ main =
$ do
for_ rpackages rrequire
[r| theme_set(theme_bw()) |]
void $ experiment 500000
void $ experiment 5000
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment