Commit 62c40b9f authored by Valentin Reis's avatar Valentin Reis
Browse files

Merge branch 'develop' of xgitlab.cels.anl.gov:argo/hbandit into develop

parents 63b7ec73 256c0f95
Pipeline #9898 failed with stages
in 11 seconds
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -63,6 +63,7 @@ data Exp4R s a er
}
deriving (Generic)
-- | Encapsulator for 'last action taken'
data LastAction a
= LastAction
{ action :: a,
......@@ -71,6 +72,7 @@ data LastAction a
}
deriving (Generic)
-- | Constructor for feedback from the environment.
data Feedback
= Feedback
{ cost :: ZeroOne Double,
......@@ -78,6 +80,7 @@ data Feedback
}
deriving (Generic)
-- | Hyperparameters.
data Exp4RCfg s a er
= Exp4RCfg
{ expertsCfg :: NonEmpty er,
......@@ -111,43 +114,57 @@ instance
mu <- get <&> mkMu
delta <- get <&> mkDelta
use (field @"lastAction")
>>= traverse_
( \(LastAction _ (R.unrefine -> p_a) (fmap R.unrefine -> pPolicy_a)) -> feedback & \case
Nothing -> panic "exp4R usage error: don't give feedback on first action."
Just (Feedback (R.unrefine -> c) (R.unrefine -> r)) -> do
let expTerms :: NonEmpty Double
expTerms = pPolicy_a <&> (* ((r + c) / p_a))
wUpdate = NE.zipWith (\(R.unrefine -> w, _) x -> w * exp (- mu * x)) weightedExperts expTerms
wDenom = getSum $ sconcat $ Sum <$> wUpdate
field @"experts" %= NE.zipWith (\w' (_, e) -> (unsafeNormalizePanic w' wDenom, e)) wUpdate
let fDot (R.unrefine -> wi, _) p = wi * r * p / p_a
let dotted = getSum $ sconcat (Sum <$> NE.zipWith fDot weightedExperts pPolicy_a)
field @"lambda" .= R.unsafeRefine (max 0 (lam + mu * (dotted - R.unrefine beta - delta * mu * lam)))
)
>>= traverse_ \(LastAction _ (R.unrefine -> p_a) (fmap R.unrefine -> pPolicy_a)) ->
fromMaybe
(panic "exp4R usage error: do not give feedback on first action.")
( feedback
<&> \(Feedback (R.unrefine -> c) (R.unrefine -> r)) -> do
let numeratorTerm (R.unrefine -> w, _) p = w * exp (- mu * (p * (lam * r + c) / p_a))
let wUpdate = NE.zipWith numeratorTerm weightedExperts pPolicy_a
wDenom = getSum $ sconcat $ Sum <$> wUpdate
field @"experts" %= NE.zipWith (\w' (_, e) -> (unsafeNormalizePanic w' wDenom, e)) wUpdate
let fDot (R.unrefine -> wi, _) p = wi * r * p / p_a
let dotted = getSum $ sconcat (Sum <$> NE.zipWith fDot weightedExperts pPolicy_a)
field @"lambda" .= R.unsafeRefine (max 0 (lam + mu * (dotted - R.unrefine beta - delta * mu * lam)))
)
let weightedAdviceMatrix :: NonEmpty (ZeroOne Double, NonEmpty (ZeroOne Double, a))
weightedAdviceMatrix = weightedExperts <&> fmap ($ s)
armDistribution :: NonEmpty (ZeroOne Double, a)
armDistribution = normalizeDistribution dirtyarmDistribution' & \case
Nothing -> panic "internal Exp4R algorithm failure: distribution normalization failed."
Just d -> d
where
dirtyarmDistribution' :: NonEmpty (Double, a)
dirtyarmDistribution' = groupAllWith1 snd dirtyArmDistribution <&> \gs -> (getSum $ sconcat (gs <&> Sum . fst), snd $ NE.head gs)
dirtyArmDistribution :: NonEmpty (Double, a)
dirtyArmDistribution = sconcat $ weightedAdviceMatrix <&> \(wi, advices) -> advices <&> \(p, ai) -> (R.unrefine p * R.unrefine wi, ai)
armDistribution =
fromMaybe
(panic "internal Exp4R algorithm failure: distribution normalization failed.")
(combineAdvice weightedAdviceMatrix)
(a, g') = sampleWL armDistribution g
p_a = find (\x -> snd x == a) armDistribution & \case
Nothing -> panic "internal Exp4R algorithm failure: arm pull issue."
Just p -> fst p
p_a =
fst $
fromMaybe
(panic "internal Exp4R algorithm failure: arm pull issue.")
(find (\x -> snd x == a) armDistribution)
probabilityOf_a :: NonEmpty (ZeroOne Double)
probabilityOf_a = snd <$> weightedAdviceMatrix
<&> \e ->
case find (\x -> snd x == a) e of
Nothing -> panic "internal Exp4R algorithm failure: weight computation"
Just (p, _) -> p
( fst $
fromMaybe
(panic "internal Exp4R algorithm failure: weight computation")
(find (\x -> snd x == a) e)
)
field @"lastAction" ?= LastAction a p_a probabilityOf_a
return (a, g')
-- | combineAdvice turns weighted expert advice into a probability distribution to
-- sample from.
combineAdvice ::
(Ord a) =>
NonEmpty (ZeroOne Double, NonEmpty (ZeroOne Double, a)) ->
Maybe (NonEmpty (ZeroOne Double, a))
combineAdvice weightedAdviceMatrix = normalizeDistribution $
groupAllWith1 snd dirtyArmDistribution
<&> \gs -> (getSum $ sconcat (gs <&> Sum . fst), snd $ NE.head gs)
where
dirtyArmDistribution = sconcat $
weightedAdviceMatrix
<&> \(wi, advices) -> advices <&> \(p, ai) -> (R.unrefine p * R.unrefine wi, ai)
-- | \( \mu = \sqrt{\frac{\ln N }{ (T(K+4))}} \)
mkMu :: Exp4R s a er -> Double
mkMu Exp4R {..} =
......@@ -161,8 +178,8 @@ mkDelta Exp4R {..} = fromIntegral $ 3 * k
lambdaInitial :: R.Refined R.NonNegative Double
lambdaInitial = R.unsafeRefine 0
-- | Oblivious Expert Representation
newtype ObliviousRep a = ObliviousRep (NonEmpty (ZeroOne Double, a)) deriving (Generic)
-- | Oblivious Categorical Expert Representation
data ObliviousRep a = ObliviousRep (NonEmpty (ZeroOne Double, a)) deriving (Generic)
instance ExpertRepresentation (ObliviousRep a) () a where
represent (ObliviousRep l) () = l
......@@ -76,7 +76,7 @@ plot1pass one_cost two_cost three_cost one_risk two_risk three_risk = do
let b = initCtx $ Exp4RCfg
{ expertsCfg = expertsC,
constraintCfg = unsafeRefine 0.5,
horizonCfg = unsafeRefine 500000,
horizonCfg = unsafeRefine 5000,
as = [1, 2, 3]
}
( GameState
......@@ -170,4 +170,4 @@ main =
$ do
for_ rpackages rrequire
[r| theme_set(theme_bw()) |]
void $ experiment 500000
void $ experiment 5000
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment