UCB.hs 3.07 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
{-# OPTIONS_GHC -fno-warn-partial-fields #-}

-- |
-- Module      : Bandit.UCB
-- Copyright   : (c) 2019, UChicago Argonne, LLC.
-- License     : MIT
-- Maintainer  : fre@freux.fr
--
-- This module implements the UCB family of algorithms.
module Bandit.UCB
  ( UCB (..),
    UCBHyper (..),
    hyperAlphaUCB,
    hyperUCB1,
  )
where

import Bandit.Class
import Bandit.EpsGreedy
import Bandit.Types
21
import Bandit.Util (argmax')
22
23
24
25
26
import Control.Lens
import Data.Generics.Labels ()
import Protolude
import Refined

Valentin Reis's avatar
Valentin Reis committed
27
-- | Hyperparameter for \((\alpha,\phi)\)-UCB
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
data UCBHyper a r
  = UCBHyper
      { -- | Value stored as the 'invLFPhi' field of the 'UCB' state by @init@;
        -- @step@ turns it into a @Double -> Double@ function with
        -- @toInvLFPhi@ before selecting an arm.
        invLFPhiUCB :: r,
        -- | The \(\alpha\) parameter. It multiplies the @log t / hits@ term
        -- inside the exploration bonus computed by @pickreturn@.
        alphaUCB :: Double,
        -- | The arms to choose from. @init@ plays the first one and places
        -- the remaining ones in the screening queue.
        armsUCB :: Arms a
      }
  deriving (Show, Generic)

-- | Build the hyperparameter for \(\alpha\)-UCB from a value of \(\alpha\)
-- and the set of arms. The 'invLFPhiUCB' field is fixed to
-- 'AlphaUCBInvLFPhi'.
hyperAlphaUCB :: Double -> Arms a -> UCBHyper a AlphaUCBInvLFPhi
hyperAlphaUCB alphaValue arms =
  UCBHyper
    { invLFPhiUCB = AlphaUCBInvLFPhi,
      alphaUCB = alphaValue,
      armsUCB = arms
    }

-- | Hyperparameter for parameter-free UCB1: this is 'hyperAlphaUCB' with
-- \(\alpha\) fixed to 4, so only the arms have to be supplied.
hyperUCB1 :: Arms a -> UCBHyper a AlphaUCBInvLFPhi
hyperUCB1 arms = hyperAlphaUCB ucb1Alpha arms
  where
    ucb1Alpha = 4

-- | State for the UCB algorithm.
data UCB a p
  = UCB
      { -- | Round counter: set to 1 by @init@ and incremented at the start
        -- of every @step@, before it is used in the exploration bonus.
        t :: Int,
        -- | Copied from 'invLFPhiUCB' by @init@; converted with @toInvLFPhi@
        -- in @step@.
        invLFPhi :: p,
        -- | Copied from 'alphaUCB' by @init@.
        alpha :: Double,
        -- | The action most recently returned by @init@ or @step@. The
        -- feedback received by the next @step@ is recorded against it.
        lastAction :: a,
        -- | Learner phase: @InitialScreening@ while arms are still queued
        -- for their first play, @Started@ afterwards.
        params :: Params a
      }
  deriving (Show, Generic)

-- | Turn one @(loss, action)@ observation into a 'Weight' whose second
-- constructor argument is 1 (presumably the hit count of an arm that has
-- been played exactly once -- confirm against "Bandit.Types").
toW :: (Double, a) -> Weight a
toW (observedLoss, arm) = Weight observedLoss singlePull arm
  where
    singlePull = 1

-- | The \((\alpha,\phi)\)-UCB MAB algorithm.
--
-- The first rounds form an initial screening phase in which the arms are
-- played one after the other. Once the screening queue is empty, every round
-- records the incoming loss against the previously played arm and selects
-- the next arm with 'pickreturn'.
instance (InvLFPhi p, Eq a) => Bandit (UCB a p) (UCBHyper a p) a (ZeroOne Double) where

  -- Play the first arm right away and queue the remaining arms for screening.
  -- The generator is returned untouched.
  init g (UCBHyper phi alphaValue (Arms (firstArm :| otherArms))) =
    ( UCB
        { t = 1,
          alpha = alphaValue,
          invLFPhi = phi,
          lastAction = firstArm,
          params =
            InitialScreening $
              Screening
                { screened = [],
                  screenQueue = otherArms
                }
        },
      firstArm,
      g
    )

  -- The view pattern converts the incoming 'ZeroOne' feedback into a plain
  -- 'Double' (presumably a reward-to-cost conversion, going by the name
  -- @rewardCostBijection@ -- confirm) that is recorded as the loss of the
  -- previously played arm. The generator is returned untouched.
  step g (unrefine . rewardCostBijection -> l) = do
    oldAction <- use #lastAction
    invLFPhiFunc <- use #invLFPhi <&> toInvLFPhi
    alphaValue <- use #alpha
    -- Advance the round counter before it enters the exploration bonus.
    #t += 1
    currentRound <- use #t
    nextAction <-
      use #params >>= \case
        InitialScreening sg ->
          case screenQueue sg of
            -- Still screening: record the loss and play the next queued arm.
            (queued : remaining) -> do
              #params . #_InitialScreening
                .= Screening
                  { screened = (l, oldAction) : screened sg,
                    screenQueue = remaining
                  }
              return queued
            -- Screening is over: turn every recorded observation (including
            -- the one just received) into a weight and start selecting arms
            -- with 'pickreturn'.
            [] -> do
              let ee =
                    ExploreExploit
                      { weights = toW <$> ((l, oldAction) :| screened sg)
                      }
              #params .= Started ee
              pickreturn currentRound invLFPhiFunc alphaValue ee
        Started ee -> do
          -- Fold the new loss into the weight of the arm that produced it.
          let ee' = ee & #weights %~ updateWeight oldAction l
          #params .= Started ee'
          pickreturn currentRound invLFPhiFunc alphaValue ee'
    #lastAction .= nextAction
    return (nextAction, g)

-- | Action selection and return.
--
-- Scores every arm with its average loss plus the exploration term
-- @phiInv (alpha * log t / hits)@ and returns the action of the arm that
-- @argmax'@ selects for that score. The state is not read or modified.
--
-- NOTE(review): the score adds the bonus to a field named @averageLoss@ and
-- is handed to @argmax'@. If @argmax'@ maximises, this favours high-loss
-- arms; a loss-based UCB would usually minimise @averageLoss - bonus@.
-- Confirm against "Bandit.Util" before changing the selection rule.
pickreturn ::
  (MonadState (UCB a r) m) =>
  -- | Current round number (the @t@ inside the logarithm).
  Int ->
  -- | Function applied to the scaled @log t / hits@ term.
  (Double -> Double) ->
  -- | The \(\alpha\) multiplier.
  Double ->
  -- | Current per-arm weights.
  ExploreExploit a ->
  m a
pickreturn currentRound phiInv alphaValue (ExploreExploit arms) =
  return . action . argmax' index $ toList arms
  where
    -- The logarithm does not depend on the arm, so compute it once.
    logRound = log (fromIntegral currentRound)
    index Weight {..} =
      averageLoss + phiInv (alphaValue * logRound / fromIntegral hits)