Tutorial.hs 13.5 KB
Newer Older
1

2
{-| This module serves as an introduction to the `hbandit` Multi-Armed Bandit library.
3
4
-}

5
module Bandit.Tutorial (
6
7
8
9
10
11
12
13
14
15
16
17
18
-- *** Setup

-- | The code snippets displayed in this tutorial require the following list of extensions and modules.

-- |
-- >  {-# LANGUAGE LambdaCase #-}
-- >  {-# LANGUAGE FlexibleContexts #-}
-- >  {-# LANGUAGE DeriveGeneric #-}
-- >  {-# LANGUAGE OverloadedStrings #-}
-- >  {-# LANGUAGE ViewPatterns #-}
-- >  {-# LANGUAGE OverloadedLists #-}
-- >  {-# LANGUAGE OverloadedLabels #-}
-- >  {-# LANGUAGE DataKinds #-}
19
-- >  {-# LANGUAGE FlexibleInstances  #-}
20
-- >  {-# LANGUAGE ScopedTypeVariables #-}
21
-- >  {-# LANGUAGE StandaloneDeriving #-}
22
23
-- >  {-# LANGUAGE NoImplicitPrelude #-}
-- >  {-# LANGUAGE TemplateHaskell #-}
24
25
-- >  {-# LANGUAGE DeriveAnyClass #-}
-- >  {-# LANGUAGE RecordWildCards #-}
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
-- >  {-# LANGUAGE QuasiQuotes #-}
-- >  import Protolude
-- >  import Text.Pretty.Simple
-- >  import System.Random
-- >  import Data.List.NonEmpty as NonEmpty hiding (init)
-- >  import Refined hiding (NonEmpty)
-- >  import Refined.Unsafe
-- >  import Data.Sequence as Sequence
-- >  import Data.Generics.Product
-- >  import Prelude ((!!))
-- >  import Data.Coerce
-- >  import Data.Generics.Labels
-- >  import Data.Functor.Compose
-- >  import H.Prelude.Interactive
-- >  import System.IO hiding (print)
-- >  import Control.Monad.Primitive
-- >  import qualified Language.R.Instance as R
-- >  import Control.Lens
44
45
46
-- >  import Bandit
-- >  import Bandit.EpsGreedy
-- >  import Bandit.Exp3
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
-- >  import qualified Data.Text.Lazy.Encoding as T
-- >  import qualified Data.Text.Lazy as T
-- >  import Data.Aeson hiding ((.=))

-- 
-- print' :: (Show a) => a -> IO ()
-- print' a = Protolude.putText $ "@" <> toS (pShowNoColor a) <> "@"
-- 
-- putText' :: Text -> IO ()
-- putText' t = Protolude.putText $ "@" <> t <> "@"
-- 
-- rrequire :: (MonadR m, Literal lib a) => lib -> m ()
-- rrequire p = void [r| suppressMessages(require(p_hs,character.only=TRUE)) |]
-- 
-- rpackages :: [Text]
-- rpackages = [ "svglite", "dplyr", "tidyr", "purrr", "ggplot2" , "jsonlite"]
-- 
-- main :: IO ()
-- main = do
--   R.initialize R.defaultConfig
--   for_ rpackages rrequire
--   [r| theme_set(
--       theme_bw() +
--       theme(
--         panel.background = element_rect(fill = "transparent"), 
--         plot.background = element_rect(fill = "transparent", color = NA), 
--         legend.background = element_rect(fill = "transparent"), 
--         legend.box.background = element_rect(fill = "transparent") 
--       )) |]

-- * Non-contextual
-- | We'll first cover the case of simple MABs that do not use context information.

-- ** Classes
--
-- | The main algorithm class for non-contextual bandits is 'Bandit'. This class gives
-- types for a basic bandit game between a learner and an environment, where the
-- learner has access to a random generator and is defined via a stateful 'step'
-- function. All non-contextual bandit algorithms in this library are instances of this class.
86
Bandit.Class.Bandit(..)
87
88
89
90
91
92

-- *** Example instance: Epsilon-Greedy
--
-- | Let's take a look at the instance for the classic fixed-rate \(\epsilon\)-Greedy
-- algorithm. The necessary hyperparameters are the number of arms and the rate value,
-- as the 'EpsGreedyHyper' datatype shows.
93
,Bandit.EpsGreedy.EpsGreedyHyper(..)
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

-- | Let's use that instance on some toy data with a few rounds.
--
-- First, we define the @onePass@ function that takes a hyperparameter, an initial random
-- generator for the bandit, and a deterministic oblivious adversary (represented as a list
-- of @a -> l@ functions, one per round), and runs the bandit game over all rounds of the
-- adversary to produce a history of losses and actions:

-- | 
-- >  data GameState b a l
-- >    = GameState
-- >        { historyActions :: NonEmpty a,
-- >          historyLosses :: Seq l,
-- >          bandit :: b,
-- >          stdGen :: StdGen
-- >        }
-- >    deriving (Generic, Show)
-- >  
-- >  onePass :: (Bandit b hyper a l) =>
-- >    hyper ->       -- ^ hyperparameter
-- >    StdGen ->      -- ^ random generator initial value
-- >    [(a -> l)] ->  -- ^ oblivious deterministic adversary
-- >    GameState b a l
-- >  onePass hyper g adversary = runGame initialGame
-- >   where
119
-- >    (initialBanditState, initialAction, g') = Bandit.init g hyper
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
-- >    initialGame =  GameState
-- >      { historyActions = [initialAction],
-- >        historyLosses = [],
-- >        bandit = initialBanditState,
-- >        stdGen = g'
-- >      }
-- >    runGame = execState game
-- >    game = for_ adversary iteration
-- >    iteration actionToLoss = do
-- >      (actionToLoss . NonEmpty.head -> loss) <- use #historyActions
-- >      oldGen <- use #stdGen
-- >      (action, newGen) <- zoom #bandit $ step oldGen loss
-- >      #stdGen .= newGen
-- >      #historyActions %= (action NonEmpty.<|)
-- >      #historyLosses %= (loss Sequence.<|)

136
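-- | Specializing this to the 'EpsGreedy' datatype on a small toy dataset, using a fixed rate.
-- Since @onePass@ prepends to both histories, the printed series are in most-recent-first
-- order, and the action series has one more element than the loss series (the action
-- chosen after the last round, which is never scored):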
137
138

-- | 
139
-- >  runOnePassEG :: StdGen -> GameState (EpsGreedy Bool FixedRate) Bool Double
140
141
142
-- >  runOnePassEG g = onePass hyper g (getZipList $ f <$> ZipList [40, 2, 10] <*> ZipList [4, 44 ,3] )
-- >   where
-- >    f a b = \case True -> a; False -> b
143
-- >    hyper = EpsGreedyHyper {rateRep = (FixedRate 0.5), arms = Bandit.Arms [True, False]}
144
145
146
147
148
149
150
151
152
153
154
155
156
157
-- >  
-- >  printOnePassEG :: IO ()
-- >  printOnePassEG = putText $
-- >    "Action series:" <> 
-- >    show  (historyActions gs ^.. traversed) <>
-- >    "\nLoss series:" <> 
-- >    show ( historyLosses gs ^.. traversed)
-- >   where gs = runOnePassEG (mkStdGen 1)

-- |
-- >>>  printOnePassEG
-- $eg

-- *** Other classes
158
159
-- | Some other, more restrictive classes are available in [Bandit.Class](Bandit-Class.html) for convenience. See for
-- example 'Bandit.Class.ParameterFreeMAB', which exposes a hyperparameter-free interface for
160
-- algorithms that don't need any information besides the arm count. Those classes are not necessary
161
162
163
-- per se, and the 'Bandit' class is always sufficient. Some instances make aggressive use
-- of type refinement through the @refined@ package. The \(\left[0,1\right]\) interval is particularly useful:

164
,Bandit.Types.ZeroOne
165
166

-- ** Algorithm comparison
Valentin Reis's avatar
Valentin Reis committed
167
-- | This subsection runs bandit experiments on an example dataset with some of the @Bandit@ instances.
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
-- The data for this tutorial is generated in R using the [inline-r](https://hackage.haskell.org/package/inline-r) package.
-- Let's define a simple problem with three Gaussian arms. We will clamp all cost values to \(\left[0,1\right]\).

-- | 
-- >  generateGaussianData ::
-- >    Int ->                -- ^ number of rounds
-- >    [ZeroOne Double] ->   -- ^ arm averages
-- >    IO [[Double]]         -- ^ dataset
-- >  generateGaussianData (fromInteger . toInteger -> n :: Double) avgs = 
-- >    (mapM generate (unrefine <$> avgs ))
-- >    where
-- >      generate :: (MonadR m, Functor m) => Double -> m [Double]
-- >      generate mu = gen01TS mu <&> fromSomeSEXP
-- >      gen01TS :: (MonadR m) => Double -> m (SomeSEXP (PrimState m))
-- >      gen01TS mu = [r| pmax(0,pmin(1,rnorm(n_hs, mean=mu_hs, sd=0.1))) |]
-- >  
-- >  refineDataset ::  [[Double]] -> [[ZeroOne Double]]
-- >  refineDataset = (fmap.fmap) unsafeRefine

-- Let's generate data for a 3-arm problem and observe the distribution of costs.

-- |
-- >>>  dataset <- generateGaussianData 400 (unsafeRefine <$> [0.1, 0.5, 0.6])
-- >>>  let d :: Text
-- >>>      d = show $ Protolude.transpose dataset
-- >>>  [r| print(summary(jsonlite::fromJSON(d_hs))) |]
-- $summaryProblem

-- |
-- >>>  [r| 
-- >>>    data <- as.data.frame(jsonlite::fromJSON(d_hs))
-- >>>    data_mutated = data %>% gather("arm", "cost", 1:ncol(data))
-- >>>    ggplot(data_mutated, aes(arm, cost, group=factor(arm)))+ geom_boxplot()
-- >>>  |]
-- $summaryPlot

-- | Here is a helper that converts a per-arm dataset to the @[action -> loss]@ adversary format:

-- | 
-- >  toAdversary :: [[a]] -> [Int -> a]
-- >  toAdversary xss = Protolude.transpose xss <&> listToAdversary
-- >   where
-- >    listToAdversary :: [a] -> Int -> a
-- >    listToAdversary l i = l Prelude.!! i

-- | Let's define some experiments:

-- | 
-- >  exp3 :: [[Double]] -> StdGen -> GameState (Exp3 Int) Int (ZeroOne Double)
-- >  exp3 dataset g = 
-- >    onePass
219
-- >      (Bandit.Arms [0..2])
220
221
222
-- >      g
-- >      (toAdversary $ refineDataset dataset)
-- >                   
223
224
-- >  greedy :: (Rate r) => [[Double]] -> StdGen -> r -> GameState (EpsGreedy Int r) Int (Double)
-- >  greedy dataset g r =  
225
-- >    onePass
226
-- >      (EpsGreedyHyper {rateRep = r, arms = Bandit.Arms [0..2]})
227
228
229
-- >      g
-- >      (toAdversary dataset)
-- >  
230
231
232
233
234
235
236
237
238
239
240
-- >  data SimResult t = SimResult {
-- >    t :: t Int,
-- >    seed :: t Int,
-- >    greedy05 :: t Double,
-- >    greedy03 :: t Double,
-- >    greedysqrt05 :: t Double,
-- >    exp3pf :: t Double
-- >  } deriving (Generic)
-- >  
-- >  simulation :: Int -> Int -> IO (SimResult [])
-- >  simulation tmax seed@(mkStdGen -> g) = do
241
-- >    dataset <- generateGaussianData tmax (unsafeRefine <$> [0.1, 0.5, 0.6])
242
243
244
245
246
247
248
249
250
251
-- >    return $ SimResult {
-- >               t = [1 .. tmax],
-- >               seed = Protolude.replicate tmax seed,
-- >               greedy05 = extract $ greedy dataset g (FixedRate 0.5),
-- >               greedy03 = extract $ greedy dataset g (FixedRate 0.3),
-- >               greedysqrt05 = extract $ greedy dataset g (InverseSqrtRate 0.5),
-- >               exp3pf = fmap unrefine . extract $ exp3 dataset g
-- >             }
-- >   where 
-- >     extract = Protolude.toList . Sequence.reverse . historyLosses
252
-- >  
253
254
255
256
257
258
259
260
261
262
-- >  instance Semigroup (SimResult []) where
-- >    x <> y = SimResult {
-- >               t = f Main.t,
-- >               seed = f seed, 
-- >               greedy05 = f greedy05, 
-- >               greedy03 = f greedy03, 
-- >               greedysqrt05 = f greedysqrt05,
-- >               exp3pf = f exp3pf
-- >             }
-- >             where f accessor = accessor x <> accessor y
263
-- >  
264
265
-- >  instance Monoid (SimResult []) where
-- >    mempty = SimResult mempty mempty mempty mempty mempty mempty
266
-- >  
267
268
269
-- >  instance ToJSON (SimResult []) where
-- >    toJSON SimResult{..} = 
-- >      toJSON (t, seed, greedy05, greedy03, greedysqrt05, exp3pf)
270
271

-- |
272
273
-- >>>  results <- forM ([2..10] ::[Int]) (simulation 400)
-- >>>  let exported = T.unpack $ T.decodeUtf8 $ encode $ mconcat results
274
275
-- >>>  [r|
-- >>>    data.frame(t(jsonlite::fromJSON(exported_hs))) %>%
276
-- >>>      rename(t = X1, iteration = X2, greedy05= X3, greedy03=X4, greedysqrt05=X5,exp3=X6 ) %>%
277
278
279
280
281
282
283
-- >>>      summary %>%
-- >>>      print
-- >>>  |]
-- $expe

-- |
-- >>>  [r| data.frame(t(jsonlite::fromJSON(exported_hs))) %>%
284
-- >>>        rename(t = X1, iteration = X2, greedy05= X3, greedy03=X4, greedysqrt05=X5,exp3=X6 ) %>%
285
286
287
288
289
290
-- >>>        gather("strategy", "loss", -t, -iteration) %>%
-- >>>        mutate(strategy=factor(strategy)) %>% 
-- >>>        group_by(strategy,iteration) %>%
-- >>>        mutate(regret = cumsum(loss-0.1)) %>% 
-- >>>        ungroup() %>%
-- >>>        ggplot(., aes(t, regret, color=strategy, group=interaction(strategy, iteration))) +
291
-- >>>          geom_line(alpha=0.5) + ylab("External Regret")
292
293
294
295
-- >>>  |]
-- $regretPlot

  ) where
296
297
298
import Bandit.Class
import Bandit.Types
import Bandit.EpsGreedy
299
300
301
302
303
304

--   pass

-- > Resolving dependencies...
-- > Build profile: -w ghc-8.6.5 -O1
-- > In order, the following will be built (use -v for more details):
Valentin Reis's avatar
Valentin Reis committed
305
-- >  - hbandit-1.0.0 (lib) (configuration changed)
306
-- >  - fake-package-0 (exe:script) (configuration changed)
Valentin Reis's avatar
Valentin Reis committed
307
-- > Configuring library for hbandit-1.0.0..
308
309
-- > Preprocessing library for hbandit-1.0.0..
-- > Building library for hbandit-1.0.0..
310
311
312
313
314
315
316
317
318
319
320
-- > [8 of 8] Compiling Bandit.Tutorial  ( src/Bandit/Tutorial.hs, /home/fre/workspace/hbandit/dist-newstyle/build/x86_64-linux/ghc-8.6.5/hbandit-1.0.0/build/Bandit/Tutorial.o )
-- > Configuring executable 'script' for fake-package-0..
-- > Preprocessing executable 'script' for fake-package-0..
-- > Building executable 'script' for fake-package-0..
-- > [1 of 1] Compiling Main             ( Main.hs, /home/fre/workspace/hbandit/dist-newstyle/build/x86_64-linux/ghc-8.6.5/fake-package-0/x/script/build/script/script-tmp/Main.o )
-- > Linking /home/fre/workspace/hbandit/dist-newstyle/build/x86_64-linux/ghc-8.6.5/fake-package-0/x/script/build/script/script ...
-- $eg
-- > Action series:[True,True,False,True]
-- > Loss series:[10.0,44.0,40.0]
-- $summaryProblem
-- >        V1                V2               V3        
Valentin Reis's avatar
Valentin Reis committed
321
322
323
324
325
326
-- >  Min.   :0.00000   Min.   :0.2234   Min.   :0.3124  
-- >  1st Qu.:0.02779   1st Qu.:0.4327   1st Qu.:0.5324  
-- >  Median :0.09837   Median :0.4924   Median :0.6035  
-- >  Mean   :0.10480   Mean   :0.4951   Mean   :0.5999  
-- >  3rd Qu.:0.15828   3rd Qu.:0.5543   3rd Qu.:0.6707  
-- >  Max.   :0.37662   Max.   :0.7331   Max.   :0.9077  
327
328
329
330
331
-- $summaryPlot
-- <<literate/summaryPlot.png>>
-- $expe
-- >        t           iteration     greedy05          greedy03      
-- >  Min.   :  1.0   Min.   : 2   Min.   :0.00000   Min.   :0.00000  
Valentin Reis's avatar
Valentin Reis committed
332
333
334
335
336
337
338
339
340
341
342
343
-- >  1st Qu.:100.8   1st Qu.: 4   1st Qu.:0.05422   1st Qu.:0.04456  
-- >  Median :200.5   Median : 6   Median :0.13669   Median :0.11669  
-- >  Mean   :200.5   Mean   : 6   Mean   :0.19376   Mean   :0.15407  
-- >  3rd Qu.:300.2   3rd Qu.: 8   3rd Qu.:0.26790   3rd Qu.:0.20092  
-- >  Max.   :400.0   Max.   :10   Max.   :0.80432   Max.   :0.80432  
-- >   greedysqrt05          exp3       
-- >  Min.   :0.00000   Min.   :0.0000  
-- >  1st Qu.:0.03851   1st Qu.:0.0435  
-- >  Median :0.10307   Median :0.1122  
-- >  Mean   :0.11850   Mean   :0.1475  
-- >  3rd Qu.:0.17220   3rd Qu.:0.1912  
-- >  Max.   :0.80432   Max.   :0.8889  
344
345
-- $regretPlot
-- <<literate/regretPlot.png>>
346