Commit 85fa4da5 authored by Valentin Reis's avatar Valentin Reis
Browse files

Merge branch 'fix-negativeRisk' into 'master'

[fix] Fixes the refinement type on Risk for Exp4R to be negative

See merge request !8
parents 5533471e 0d775a3f
Pipeline #11858 passed with stages
in 33 seconds
......@@ -32,7 +32,7 @@ nix:package:
tags:
- kvm
- nix
script: nix-build -A hbandit --no-build-output
script: nix-build -A haskellPackages.hbandit --no-build-output
make:readme:
......
......@@ -12,7 +12,7 @@ SHELL := $(shell which bash)
NIX_PATH := nixpkgs=./.
.PHONY: all
all: hbandit.nix ghcid pre-commit
all: ghcid pre-commit
#generating the vendored cabal file.
......@@ -35,7 +35,7 @@ ci-%:
.PHONY: ghcid
ghcid: ghcid-hbandit
ghcid-hbandit: hbandit.cabal .hlint.yaml hbandit.nix
ghcid-hbandit: hbandit.cabal .hlint.yaml
@nix-shell -E '
with import <nixpkgs> {};
with haskellPackages;
......@@ -50,7 +50,7 @@ ghcid-hbandit: hbandit.cabal .hlint.yaml hbandit.nix
-l
'
ghcid-test: hbandit.cabal .hlint.yaml hbandit.nix
ghcid-test: hbandit.cabal .hlint.yaml
@nix-shell --pure --run bash <<< '
ghcid --command "cabal v2-repl test " \
--restart=hbandit.cabal \
......@@ -99,7 +99,7 @@ ormolu:
'
.PHONY: doc
doc: hbandit.cabal hbandit.nix
doc: hbandit.cabal
@nix-shell -E '
with import <nixpkgs> {};
with haskellPackages;
......@@ -116,5 +116,5 @@ clean:
rm -rf .build
rm -rf dist*
rm -f extras/main.hs
rm -f hbandit.nix
rm -f
rm -f hbandit.cabal
......@@ -51,7 +51,7 @@ library
refined -any,
intervals -any,
MonadRandom -any,
list-extras -any,
-- list-extras -any,
lens -any,
generic-lens -any
......
......@@ -53,7 +53,7 @@ data Exp4R s a er
k :: Int,
n :: Int,
lambda :: R.Refined R.NonNegative Double,
constraint :: ZeroOne Double,
constraint :: Double,
experts ::
NonEmpty
( ZeroOne Double,
......@@ -75,7 +75,7 @@ data LastAction a
data Feedback
= Feedback
{ cost :: ZeroOne Double,
risk :: ZeroOne Double
risk :: R.Refined R.NonPositive Double
}
deriving (Generic)
......@@ -83,7 +83,7 @@ data Feedback
data Exp4RCfg s a er
= Exp4RCfg
{ expertsCfg :: NonEmpty er,
constraintCfg :: ZeroOne Double,
constraintCfg :: Double,
horizonCfg :: R.Refined R.Positive Int,
as :: NonEmpty a
}
......@@ -145,7 +145,7 @@ update
lam <- R.unrefine <$> use #lambda
delta <- get <&> mkDelta
mu <- get <&> mkMu
beta <- use #constraint <&> R.unrefine
beta <- use #constraint
let numeratorTerm (R.unrefine -> w, _) p =
w * exp (- mu * (p * (lam * r + c) / p_a))
wUpdate = NE.zipWith numeratorTerm weightedAdvice pPolicy_a
......
......@@ -18,9 +18,9 @@ where
import Bandit.Class
import Bandit.EpsGreedy
import Bandit.Types
import Bandit.Util (argmax')
import Control.Lens
import Data.Generics.Labels ()
import Data.List.Extras.Argmax
import Protolude
import Refined
......@@ -114,7 +114,7 @@ pickreturn ::
ExploreExploit a ->
m a
pickreturn t phiInv alpha (ExploreExploit weights) =
return . action . argmax f $ toList weights
return . action . argmax' f $ toList weights
where
f Weight {..} =
averageLoss + phiInv (alpha * log (fromIntegral t) / fromIntegral hits)
......@@ -11,6 +11,8 @@ module Bandit.Util
unsafeNormalizePanic,
normalizeDistribution,
normalizedSum,
argmax,
argmax',
)
where
......@@ -80,3 +82,41 @@ unsafeNormalizePanic v m =
fromMaybe
(panic "normalizePanic error.")
(normalize v m)
-- | Calls 'panic' with a message saying that the named function was
-- applied to an empty list.
--
-- The argument is the bare name of the calling function. It is
-- appended to the module prefix, so the message reads e.g.
-- @Bandit.Util.argmax': empty list@.
emptyListError :: Text -> a
emptyListError fun = panic $ "Bandit.Util." <> fun <> ": empty list"
-- | Returns an element of the list whose image under the scoring
-- function is greatest. An empty list is reported through
-- 'emptyListError'.
--
-- A later element only replaces the current best when its score is
-- strictly greater ('>'), so among equal scores the earliest one is
-- kept.
argmax' :: (Ord b) => (a -> b) -> [a] -> a
argmax' score candidates = case candidates of
  [] -> emptyListError "argmax'"
  headElem : rest -> _argmaxBy (>) score rest (headElem, score headElem)
-- | Total counterpart of 'argmax'': yields @Nothing@ for an empty
-- list and otherwise wraps the result of 'argmax'' in @Just@.
argmax :: (Ord b) => (a -> b) -> [a] -> Maybe a
argmax score candidates = case candidates of
  [] -> Nothing
  _ : _ -> Just (argmax' score candidates)
-- | Variant of 'argmaxBy' that returns the element directly instead
-- of wrapping it in 'Maybe'. An empty list is reported through
-- 'emptyListError'.
--
-- A candidate replaces the current best only when the supplied
-- ordering reports its score as 'GT' relative to the best score.
argmaxBy' :: (b -> b -> Ordering) -> (a -> b) -> [a] -> a
argmaxBy' ordering score candidates = case candidates of
  [] -> emptyListError "argmaxBy'"
  headElem : rest ->
    _argmaxBy strictlyGreater score rest (headElem, score headElem)
  where
    strictlyGreater lhs rhs = ordering lhs rhs == GT
-- | Finds the element of the list that maximizes a scoring function
-- under a caller-supplied ordering. Yields @Nothing@ for an empty
-- list and otherwise wraps the result of 'argmaxBy'' in @Just@.
argmaxBy :: (b -> b -> Ordering) -> (a -> b) -> [a] -> Maybe a
argmaxBy ordering score candidates = case candidates of
  [] -> Nothing
  _ : _ -> Just (argmaxBy' ordering score candidates)
-- | Tail-recursive worker shared by the @argmax@ family.
--
-- Walks the list while carrying the best element seen so far together
-- with its score. The final argument is the initial (element, score)
-- pair. A candidate replaces the incumbent only when
-- @isBetterThan candidateScore incumbentScore@ holds, so on ties the
-- earlier element is kept.
_argmaxBy :: (b -> b -> Bool) -> (a -> b) -> [a] -> (a, b) -> a
_argmaxBy isBetterThan f = loop
  where
    loop [] (best, _) = best
    -- Force the accumulator pair at every step so that no chain of
    -- suspended 'pick' calls builds up along the list.
    loop (candidate : rest) incumbent = loop rest $! pick candidate incumbent
    pick candidate incumbent@(_, incumbentScore)
      | candidateScore `isBetterThan` incumbentScore = (candidate, candidateScore)
      | otherwise = incumbent
      where
        candidateScore = f candidate
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment