Commit 0d775a3f authored by Valentin Reis's avatar Valentin Reis
Browse files

[fix] Fixes the refinement type on Risk for Exp4R to be negative

parent 5533471e
Pipeline #11856 passed with stages
in 38 seconds
...@@ -32,7 +32,7 @@ nix:package: ...@@ -32,7 +32,7 @@ nix:package:
tags: tags:
- kvm - kvm
- nix - nix
script: nix-build -A hbandit --no-build-output script: nix-build -A haskellPackages.hbandit --no-build-output
make:readme: make:readme:
......
...@@ -12,7 +12,7 @@ SHELL := $(shell which bash) ...@@ -12,7 +12,7 @@ SHELL := $(shell which bash)
NIX_PATH := nixpkgs=./. NIX_PATH := nixpkgs=./.
.PHONY: all .PHONY: all
all: hbandit.nix ghcid pre-commit all: ghcid pre-commit
#generating the vendored cabal file. #generating the vendored cabal file.
...@@ -35,7 +35,7 @@ ci-%: ...@@ -35,7 +35,7 @@ ci-%:
.PHONY: ghcid .PHONY: ghcid
ghcid: ghcid-hbandit ghcid: ghcid-hbandit
ghcid-hbandit: hbandit.cabal .hlint.yaml hbandit.nix ghcid-hbandit: hbandit.cabal .hlint.yaml
@nix-shell -E ' @nix-shell -E '
with import <nixpkgs> {}; with import <nixpkgs> {};
with haskellPackages; with haskellPackages;
...@@ -50,7 +50,7 @@ ghcid-hbandit: hbandit.cabal .hlint.yaml hbandit.nix ...@@ -50,7 +50,7 @@ ghcid-hbandit: hbandit.cabal .hlint.yaml hbandit.nix
-l -l
' '
ghcid-test: hbandit.cabal .hlint.yaml hbandit.nix ghcid-test: hbandit.cabal .hlint.yaml
@nix-shell --pure --run bash <<< ' @nix-shell --pure --run bash <<< '
ghcid --command "cabal v2-repl test " \ ghcid --command "cabal v2-repl test " \
--restart=hbandit.cabal \ --restart=hbandit.cabal \
...@@ -99,7 +99,7 @@ ormolu: ...@@ -99,7 +99,7 @@ ormolu:
' '
.PHONY: doc .PHONY: doc
doc: hbandit.cabal hbandit.nix doc: hbandit.cabal
@nix-shell -E ' @nix-shell -E '
with import <nixpkgs> {}; with import <nixpkgs> {};
with haskellPackages; with haskellPackages;
...@@ -116,5 +116,5 @@ clean: ...@@ -116,5 +116,5 @@ clean:
rm -rf .build rm -rf .build
rm -rf dist* rm -rf dist*
rm -f extras/main.hs rm -f extras/main.hs
rm -f hbandit.nix rm -f
rm -f hbandit.cabal rm -f hbandit.cabal
...@@ -51,7 +51,7 @@ library ...@@ -51,7 +51,7 @@ library
refined -any, refined -any,
intervals -any, intervals -any,
MonadRandom -any, MonadRandom -any,
list-extras -any, -- list-extras -any,
lens -any, lens -any,
generic-lens -any generic-lens -any
......
...@@ -53,7 +53,7 @@ data Exp4R s a er ...@@ -53,7 +53,7 @@ data Exp4R s a er
k :: Int, k :: Int,
n :: Int, n :: Int,
lambda :: R.Refined R.NonNegative Double, lambda :: R.Refined R.NonNegative Double,
constraint :: ZeroOne Double, constraint :: Double,
experts :: experts ::
NonEmpty NonEmpty
( ZeroOne Double, ( ZeroOne Double,
...@@ -75,7 +75,7 @@ data LastAction a ...@@ -75,7 +75,7 @@ data LastAction a
data Feedback data Feedback
= Feedback = Feedback
{ cost :: ZeroOne Double, { cost :: ZeroOne Double,
risk :: ZeroOne Double risk :: R.Refined R.NonPositive Double
} }
deriving (Generic) deriving (Generic)
...@@ -83,7 +83,7 @@ data Feedback ...@@ -83,7 +83,7 @@ data Feedback
data Exp4RCfg s a er data Exp4RCfg s a er
= Exp4RCfg = Exp4RCfg
{ expertsCfg :: NonEmpty er, { expertsCfg :: NonEmpty er,
constraintCfg :: ZeroOne Double, constraintCfg :: Double,
horizonCfg :: R.Refined R.Positive Int, horizonCfg :: R.Refined R.Positive Int,
as :: NonEmpty a as :: NonEmpty a
} }
...@@ -145,7 +145,7 @@ update ...@@ -145,7 +145,7 @@ update
lam <- R.unrefine <$> use #lambda lam <- R.unrefine <$> use #lambda
delta <- get <&> mkDelta delta <- get <&> mkDelta
mu <- get <&> mkMu mu <- get <&> mkMu
beta <- use #constraint <&> R.unrefine beta <- use #constraint
let numeratorTerm (R.unrefine -> w, _) p = let numeratorTerm (R.unrefine -> w, _) p =
w * exp (- mu * (p * (lam * r + c) / p_a)) w * exp (- mu * (p * (lam * r + c) / p_a))
wUpdate = NE.zipWith numeratorTerm weightedAdvice pPolicy_a wUpdate = NE.zipWith numeratorTerm weightedAdvice pPolicy_a
......
...@@ -18,9 +18,9 @@ where ...@@ -18,9 +18,9 @@ where
import Bandit.Class import Bandit.Class
import Bandit.EpsGreedy import Bandit.EpsGreedy
import Bandit.Types import Bandit.Types
import Bandit.Util (argmax')
import Control.Lens import Control.Lens
import Data.Generics.Labels () import Data.Generics.Labels ()
import Data.List.Extras.Argmax
import Protolude import Protolude
import Refined import Refined
...@@ -114,7 +114,7 @@ pickreturn :: ...@@ -114,7 +114,7 @@ pickreturn ::
ExploreExploit a -> ExploreExploit a ->
m a m a
pickreturn t phiInv alpha (ExploreExploit weights) = pickreturn t phiInv alpha (ExploreExploit weights) =
return . action . argmax f $ toList weights return . action . argmax' f $ toList weights
where where
f Weight {..} = f Weight {..} =
averageLoss + phiInv (alpha * log (fromIntegral t) / fromIntegral hits) averageLoss + phiInv (alpha * log (fromIntegral t) / fromIntegral hits)
...@@ -11,6 +11,8 @@ module Bandit.Util ...@@ -11,6 +11,8 @@ module Bandit.Util
unsafeNormalizePanic, unsafeNormalizePanic,
normalizeDistribution, normalizeDistribution,
normalizedSum, normalizedSum,
argmax,
argmax',
) )
where where
...@@ -80,3 +82,41 @@ unsafeNormalizePanic v m = ...@@ -80,3 +82,41 @@ unsafeNormalizePanic v m =
fromMaybe fromMaybe
(panic "normalizePanic error.") (panic "normalizePanic error.")
(normalize v m) (normalize v m)
-- | Abort with a panic reporting that a function of this module was handed
-- an empty list.
--
-- The argument is the bare name of the offending function (for example
-- @"argmax'"@); the module prefix is prepended here, so callers must not
-- include it themselves.
emptyListError :: Text -> a
emptyListError fun = panic $ "Bandit.Util." <> fun <> ": empty list"
-- | Element of the list whose image under the scoring function is largest.
--
-- Partial: an empty list ends in 'emptyListError'. The comparison is a
-- strict @(>)@, so when several elements share the best score the one that
-- appears first in the list is returned.
argmax' :: (Ord b) => (a -> b) -> [a] -> a
argmax' f candidates = case candidates of
  [] -> emptyListError "argmax'"
  first : rest -> _argmaxBy (>) f rest (first, f first)
-- | Total variant of 'argmax'': yields @Nothing@ for an empty list and
-- otherwise wraps the result of 'argmax'' in @Just@.
argmax :: (Ord b) => (a -> b) -> [a] -> Maybe a
argmax f candidates = case candidates of
  [] -> Nothing
  _ : _ -> Just (argmax' f candidates)
-- | Partial counterpart of 'argmaxBy': selects the element whose score is
-- greatest under the supplied ordering, and ends in 'emptyListError' when
-- the list is empty instead of returning @Nothing@.
argmaxBy' :: (b -> b -> Ordering) -> (a -> b) -> [a] -> a
argmaxBy' ord f candidates = case candidates of
  [] -> emptyListError "argmaxBy'"
  first : rest -> _argmaxBy strictlyGreater f rest (first, f first)
  where
    -- A new score only wins when the ordering ranks it strictly above the
    -- current best.
    strictlyGreater lhs rhs = ord lhs rhs == GT
-- | Element of the list that maximizes the scoring function with respect to
-- a caller-supplied ordering; @Nothing@ when the list is empty.
argmaxBy :: (b -> b -> Ordering) -> (a -> b) -> [a] -> Maybe a
argmaxBy ord f = pick
  where
    pick [] = Nothing
    pick nonEmpty = Just (argmaxBy' ord f nonEmpty)
-- | Tail-recursive worker shared by the @argmax@ family.
--
-- Walks the list while carrying the best element seen so far together with
-- its score. The last argument seeds that pair. A candidate replaces the
-- current best only when its score @isBetterThan@ the stored one, so on a
-- tie the earlier element is kept.
_argmaxBy :: (b -> b -> Bool) -> (a -> b) -> [a] -> (a, b) -> a
_argmaxBy isBetterThan f = loop
  where
    loop [] (bestElem, _) = bestElem
    loop (candidate : rest) best =
      let next = challenge candidate best
       in -- Force the pair before recursing so no chain of thunks builds up.
          next `seq` loop rest next
    challenge candidate best@(_, bestScore)
      | score `isBetterThan` bestScore = (candidate, score)
      | otherwise = best
      where
        score = f candidate
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment