Commit da4a3f2a authored by Valentin Reis's avatar Valentin Reis
Browse files

[fix] bug fixes, cleaning the source

This fixes a long standing bug in the weight update code and
adds the UCB family of algorithms.
parent 3ed8a35e
Pipeline #10618 failed with stages
in 9 seconds
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
......@@ -27,6 +27,7 @@ library
Bandit.Class
Bandit.EpsGreedy
Bandit.Exp3
Bandit.UCB
Bandit.Exp4R
Bandit.Types
Bandit.Util
......@@ -59,5 +60,6 @@ library
refined -any,
intervals -any,
MonadRandom -any,
list-extras -any,
lens -any,
generic-lens -any
......@@ -102,6 +102,9 @@ let deps =
nobound "primitive"
, containers =
nobound "containers"
, list-extras =
nobound "list-extras"
, bytestring =
nobound "bytestring"
, storable-endian =
......@@ -121,6 +124,7 @@ let allmodules =
, "Bandit.Class"
, "Bandit.EpsGreedy"
, "Bandit.Exp3"
, "Bandit.UCB"
, "Bandit.Exp4R"
, "Bandit.Types"
, "Bandit.Util"
......@@ -133,6 +137,7 @@ let libdep =
, deps.refined
, deps.intervals
, deps.monadRandom
, deps.list-extras
, deps.lens
, deps.generic-lens
]
......
{ mkDerivation, base, generic-lens, intervals, lens, MonadRandom
, protolude, random, refined, stdenv
{ mkDerivation, base, generic-lens, intervals, lens, list-extras
, MonadRandom, protolude, random, refined, stdenv
}:
mkDerivation {
pname = "hbandit";
version = "1.0.0";
src = ./.;
libraryHaskellDepends = [
base generic-lens intervals lens MonadRandom protolude random
refined
base generic-lens intervals lens list-extras MonadRandom protolude
random refined
];
description = "hbandit";
license = stdenv.lib.licenses.bsd3;
......
......@@ -4,9 +4,6 @@ haskellPackages.shellFor {
packages = p: [ haskellPackages.hbandit ];
withHoogle = true;
buildInputs = [
(rWrapper.override {
packages = with rPackages; [ ggplot2 svglite dplyr msgpackR knitr ];
})
ghcid
dhall
pythonPackages.nbconvert
......@@ -20,7 +17,6 @@ haskellPackages.shellFor {
cabal-install
];
shellHook = ''
export R_LIBS_SITE=${builtins.readFile r-libs-site}
export LOCALE_ARCHIVE=${glibcLocales}/lib/locale/locale-archive
export LANG=en_US.UTF-8
export NIX_GHC="${haskellPackages.hbandit.env.NIX_GHC}"
......
......@@ -19,6 +19,7 @@ module Bandit.Class
-- hyperparameters.
ExpertRepresentation (..),
Rate (..),
InvLFPhi (..),
)
where
......@@ -50,7 +51,6 @@ import System.Random
-- * @l@ is a superset of admissible losses \(\mathbb{L}\) (statically
-- known).
class Bandit b hyper a l | b -> l, b -> hyper, b -> a where
-- | Init hyper returns the initial state of the algorithm and the
-- first action.
init :: (RandomGen g) => g -> hyper -> (b, a, g)
......@@ -63,7 +63,6 @@ class Bandit b hyper a l | b -> l, b -> hyper, b -> a where
--
-- * @er@ is an expert representation (see 'ExpertRepresentation')
class (ExpertRepresentation er s a) => ContextualBandit b hyper s a l er | b -> l, b -> hyper, b -> s, b -> a, b -> er where
-- | Init hyper returns the initial state of the algorithm
initCtx :: hyper -> b
......@@ -92,3 +91,13 @@ instance Rate FixedRate where
instance Rate InverseSqrtRate where
toRate x t = coerce x / sqrt (fromIntegral t)
-- | An instance of 'InvLFPhi' provides the inverse of the Legendre-Fenchel
-- transform of the convex function \(\Phi\) that parametrizes a UCB learner.
--
-- @toInvLFPhi r@ returns that inverse as a function on 'Double'.
class InvLFPhi a where
toInvLFPhi :: a -> Double -> Double
-- | For 'AlphaUCBInvLFPhi' the function is \(x \mapsto \sqrt{x / 2}\).
-- The tag value itself carries no data and is ignored.
instance InvLFPhi AlphaUCBInvLFPhi where
  -- @x@ is already a 'Double', so no coercion is needed before dividing.
  toInvLFPhi _ x = sqrt (x / 2)
{-# OPTIONS_GHC -fno-warn-partial-fields #-}
-- |
-- Module : Bandit.EpsGreedy
-- Copyright : (c) 2019, UChicago Argonne, LLC.
......@@ -14,9 +12,13 @@
module Bandit.EpsGreedy
( EpsGreedy (..),
Weight (..),
Screening (..),
EpsGreedyHyper (..),
Params (..),
ExploreExploit (..),
pickRandom,
updateAvgLoss,
updateWeight,
)
where
......@@ -50,10 +52,9 @@ data Screening a
deriving (Show, Generic)
-- | The sampling procedure has started.
data ExploreExploit a
newtype ExploreExploit a
= ExploreExploit
{ k :: Int,
weights :: NonEmpty (Weight a)
{ weights :: NonEmpty (Weight a)
}
deriving (Show, Generic)
......@@ -81,16 +82,17 @@ data EpsGreedyHyper a r
-- | The variable rate \(\epsilon\)-Greedy MAB algorithm.
-- Offers no interesting guarantees, works well in practice.
instance (Rate r, Eq a) => Bandit (EpsGreedy a r) (EpsGreedyHyper a r) a Double where
init g (EpsGreedyHyper r (Arms (a :| as))) =
( EpsGreedy
{ t = 1,
rate = r,
lastAction = a,
params = InitialScreening $ Screening
{ screened = [],
screenQueue = as
}
params =
InitialScreening $
Screening
{ screened = [],
screenQueue = as
}
},
a,
g
......@@ -100,33 +102,31 @@ instance (Rate r, Eq a) => Bandit (EpsGreedy a r) (EpsGreedyHyper a r) a Double
oldAction <- use #lastAction
schedule <- use #rate <&> toRate
e <- use #t <&> schedule
#t += 1
(a, newGen) <- use #params >>= \case
InitialScreening sg ->
case screenQueue sg of
(a : as) -> do
#params . #_InitialScreening .= Screening
{ screened = (l, oldAction) : screened sg,
screenQueue = as
}
#params
.= InitialScreening
( Screening
{ screened = (l, oldAction) : screened sg,
screenQueue = as
}
)
return (a, g)
[] -> do
let ee = ExploreExploit
{ k = length (screened sg) + 1,
weights = toW <$> ((l, oldAction) :| screened sg)
}
#params . #_Started .= ee
let ee =
ExploreExploit
{ weights = toW <$> ((l, oldAction) :| screened sg)
}
#params .= Started ee
pickreturn e g ee
Started s -> do
let eeg =
s
{ weights = weights s <&> \w ->
if action w == oldAction
then updateAvgLoss l w
else w
}
pickreturn e g eeg
Started ee -> do
  -- Fold the loss observed for the previously played action into its weight.
  let ee' = ee & #weights %~ updateWeight oldAction l
  #params . #_Started .= ee'
  -- Select from the freshly updated weights; passing the pre-update @ee@
  -- would make every decision lag one observation behind.
  pickreturn e g ee'
#lastAction .= a
#t += 1
return (a, newGen)
-- | Action selection and return
......@@ -153,12 +153,21 @@ pickRandom ExploreExploit {..} =
w2tuple :: Weight b -> (Double, b)
w2tuple (Weight _avgloss _hits action) = (1, action)
-- | rudimentary online mean accumulator.
-- | online mean accumulator.
updateAvgLoss :: Double -> Weight a -> Weight a
updateAvgLoss l (Weight avgloss hits action) =
Weight
( (avgloss * fromIntegral hits + l)
/ (fromIntegral hits + 1)
)
(hits + 1)
action
-- Online mean accumulator: folds the observation @x@ into the running
-- average stored in @w@ and increments its hit counter.
updateAvgLoss x w = w &~ do
  #hits += 1
  -- @n@ is read after the increment, so it already counts @x@. The
  -- incremental mean update is therefore avg + (x - avg) / n; dividing by
  -- @n + 1@ would under-weight every new observation.
  n <- use #hits <&> fromIntegral
  avg <- use #averageLoss
  #averageLoss += (x - avg) / n
-- | Fold a loss into the weight of one action.
--
-- Every weight whose 'action' equals the first argument is passed through
-- 'updateAvgLoss' with the given loss; all other weights are returned
-- untouched.
updateWeight ::
  (Eq a) =>
  a ->
  Double ->
  NonEmpty (Weight a) ->
  NonEmpty (Weight a)
updateWeight played loss = fmap refresh
  where
    refresh w
      | action w == played = updateAvgLoss loss w
      | otherwise = w
......@@ -39,15 +39,15 @@ data Exp3 a
k :: Int,
weights :: NonEmpty (Weight a)
}
deriving (Generic)
deriving (Show, Generic)
-- | Probability of picking an action
newtype Probability = Probability {getProbability :: Double}
deriving (Generic)
deriving (Show, Generic)
-- | Cumulative loss counter for an action
newtype CumulativeLoss = CumulativeLoss {getCumulativeLoss :: Double}
deriving (Generic)
deriving (Show, Generic)
-- | Exp3 weight for one action
data Weight a
......@@ -56,14 +56,13 @@ data Weight a
cumulativeLoss :: CumulativeLoss,
action :: a
}
deriving (Generic)
deriving (Show, Generic)
-- | The Exponential-weight algorithm for Exploration and Exploitation (EXP3).
instance
(Eq a) =>
Bandit (Exp3 a) (Arms a) a (ZeroOne Double)
where
init g (Arms as) =
( Exp3
{ t = 1,
......
......@@ -10,9 +10,11 @@ module Bandit.Types
ObliviousRep (..),
FixedRate (..),
InverseSqrtRate (..),
AlphaUCBInvLFPhi (..),
ZeroOne,
Bandit.Types.zero,
Bandit.Types.one,
rewardCostBijection,
)
where
......@@ -31,6 +33,9 @@ zero = unsafeRefine 0
one :: (Ord a, Num a) => ZeroOne a
one = unsafeRefine 1
-- | Maps a value @x@ of the unit interval to @1 - x@, turning a cost into
-- a reward and vice versa. The result is wrapped with 'unsafeRefine',
-- i.e. without re-checking the interval predicate.
rewardCostBijection :: (Ord a, Num a) => ZeroOne a -> ZeroOne a
rewardCostBijection x =
  let complemented = 1 - unrefine x
   in unsafeRefine complemented
-- | Arms a represents a set of possible actions.
newtype Arms a = Arms (Protolude.NonEmpty a)
deriving (Show, Generic)
......@@ -38,10 +43,13 @@ newtype Arms a = Arms (Protolude.NonEmpty a)
-- | Oblivious Categorical Expert Representation
newtype ObliviousRep a
= ObliviousRep (Protolude.NonEmpty (ZeroOne Double, a))
deriving (Generic)
deriving (Show, Generic)
newtype FixedRate = FixedRate Double
deriving (Generic)
deriving (Show, Generic)
newtype InverseSqrtRate = InverseSqrtRate Double
deriving (Generic)
deriving (Show, Generic)
-- | Data-less tag naming the \(\alpha\)-UCB choice of inverse
-- Legendre-Fenchel transform. Its @InvLFPhi@ instance is defined in
-- "Bandit.Class".
data AlphaUCBInvLFPhi = AlphaUCBInvLFPhi
deriving (Show, Generic)
{-# OPTIONS_GHC -fno-warn-partial-fields #-}
-- |
-- Module : Bandit.UCB
-- Copyright : (c) 2019, UChicago Argonne, LLC.
-- License : MIT
-- Maintainer : fre@freux.fr
--
-- This module implements the UCB family of algorithms.
module Bandit.UCB
( UCB (..),
UCBHyper (..),
hyperAlphaUCB,
hyperUCB1,
)
where
import Bandit.Class
import Bandit.EpsGreedy
import Bandit.Types
import Control.Lens
import Data.Generics.Labels ()
import Data.List.Extras.Argmax
import Protolude
import Refined
-- | Hyperparameters of a UCB learner over actions of type @a@.
--
-- * @invLFPhiUCB@: value whose 'InvLFPhi' instance supplies the function
--   applied to the exploration term; copied into the learner state by @init@.
-- * @alphaUCB@: coefficient multiplying @log t / hits@ inside that function.
-- * @armsUCB@: the arms to play; @init@ returns the first one and queues
--   the remaining ones for the initial screening.
data UCBHyper a r
= UCBHyper
{ invLFPhiUCB :: r,
alphaUCB :: Double,
armsUCB :: Arms a
}
deriving (Show, Generic)
-- | Builds the hyperparameters of an \(\alpha\)-UCB learner from the
-- coefficient \(\alpha\) and the set of arms, using the
-- 'AlphaUCBInvLFPhi' tag.
hyperAlphaUCB :: Double -> Arms a -> UCBHyper a AlphaUCBInvLFPhi
hyperAlphaUCB coefficient arms =
  UCBHyper
    { invLFPhiUCB = AlphaUCBInvLFPhi,
      alphaUCB = coefficient,
      armsUCB = arms
    }
-- | Hyperparameters for parameter-free UCB1: \(\alpha\)-UCB with the
-- coefficient fixed to @4@.
hyperUCB1 :: Arms a -> UCBHyper a AlphaUCBInvLFPhi
hyperUCB1 arms = hyperAlphaUCB ucb1Alpha arms
  where
    ucb1Alpha = 4
-- | State for the UCB algorithm.
--
-- * @t@: round counter; set to 1 by @init@ and incremented by every @step@.
-- * @invLFPhi@: the 'InvLFPhi' value taken from the hyperparameters.
-- * @alpha@: the coefficient taken from the hyperparameters.
-- * @lastAction@: the action most recently returned to the caller, i.e. the
--   one the next observed loss is attributed to.
-- * @params@: either the initial screening queue or the per-arm weights
--   (the 'Params' type is shared with "Bandit.EpsGreedy").
data UCB a p
= UCB
{ t :: Int,
invLFPhi :: p,
alpha :: Double,
lastAction :: a,
params :: Params a
}
deriving (Show, Generic)
-- | Turns the single observation made for an arm during screening into
-- its initial weight: the observed value becomes the average and the hit
-- count starts at 1.
toW :: (Double, a) -> Weight a
toW (firstValue, arm) =
  Weight
    { averageLoss = firstValue,
      hits = 1,
      action = arm
    }
-- | The UCB family of MAB algorithms.
--
-- Incoming losses in \([0,1]\) are mapped through 'rewardCostBijection'
-- before being accumulated, so the per-arm averages kept in the state are
-- averages of @1 - loss@. Every arm is first played once (initial
-- screening); afterwards each step plays the arm selected by 'pickreturn'.
-- The random generator is threaded through unchanged.
instance (InvLFPhi p, Eq a) => Bandit (UCB a p) (UCBHyper a p) a (ZeroOne Double) where
  -- The first arm is returned as the first action; the remaining arms wait
  -- in the screening queue.
  init g (UCBHyper invLFPhiUCB alphaUCB (Arms (a :| as))) =
    ( UCB
        { t = 1,
          alpha = alphaUCB,
          invLFPhi = invLFPhiUCB,
          lastAction = a,
          params =
            InitialScreening $
              Screening
                { screened = [],
                  screenQueue = as
                }
        },
      a,
      g
    )

  -- The view pattern converts the observed loss to @1 - loss@ as a plain
  -- 'Double'; it is attributed to the action stored in @lastAction@.
  step g (unrefine . rewardCostBijection -> l) = do
    oldAction <- use #lastAction
    invLFPhiFunc <- use #invLFPhi <&> toInvLFPhi
    alphaValue <- use #alpha
    #t += 1
    t <- use #t
    a <- use #params >>= \case
      InitialScreening sg ->
        case screenQueue sg of
          (a : as) -> do
            -- Assign the constructor directly rather than through the
            -- @#_InitialScreening@ prism: a prism setter silently does
            -- nothing when the state holds another constructor.
            #params
              .= InitialScreening
                ( Screening
                    { screened = (l, oldAction) : screened sg,
                      screenQueue = as
                    }
                )
            return a
          [] -> do
            -- Queue exhausted: the current observation belongs to the last
            -- screened arm; build the weights and start selecting.
            let ee =
                  ExploreExploit
                    { weights = toW <$> ((l, oldAction) :| screened sg)
                    }
            #params .= Started ee
            pickreturn t invLFPhiFunc alphaValue ee
      Started ee -> do
        let ee' = ee & #weights %~ updateWeight oldAction l
        #params .= Started ee'
        pickreturn t invLFPhiFunc alphaValue ee'
    #lastAction .= a
    return (a, g)
-- | Action selection: returns the action whose index is largest.
--
-- The index of a weight is its stored average plus the supplied function
-- applied to @coefficient * log t / hits@, where @t@ is the round counter
-- passed as first argument. The state is not read or modified.
pickreturn ::
  (MonadState (UCB a r) m) =>
  Int ->
  (Double -> Double) ->
  Double ->
  ExploreExploit a ->
  m a
pickreturn now invPhi coefficient (ExploreExploit ws) = return (action best)
  where
    best = argmax index (toList ws)
    index (Weight average count _) =
      average
        + invPhi (coefficient * log (fromIntegral now) / fromIntegral count)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment