From 2f1636a33ceef86c445d2e40fb6dac2f288c165a Mon Sep 17 00:00:00 2001 From: a Date: Thu, 6 Jul 2023 10:10:29 +0100 Subject: [PATCH] ok --- algo/ucb.go | 6 +++--- example/simple/main.go | 29 +++++++++++++++++++++++++++++ gang.go | 17 +++++++++++++++++ go.mod | 6 ++++-- helper/storage.go | 4 ++-- 5 files changed, 55 insertions(+), 7 deletions(-) create mode 100644 example/simple/main.go create mode 100644 gang.go diff --git a/algo/ucb.go b/algo/ucb.go index 2136976..7ca58fc 100644 --- a/algo/ucb.go +++ b/algo/ucb.go @@ -1,4 +1,4 @@ -package ucb +package algo import ( "math" @@ -37,10 +37,10 @@ func (u *UCB) Reset(n int) error { } -func (u *UCB) Count(res *[]int) { +func (u *UCB) Count(res []int) { u.cr.Count(res) } -func (u *UCB) Reward(res *[]float64) { +func (u *UCB) Reward(res []float64) { u.cr.Reward(res) } diff --git a/example/simple/main.go b/example/simple/main.go new file mode 100644 index 0000000..86873c1 --- /dev/null +++ b/example/simple/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "log" + + "lukechampine.com/frand" + "tuxpa.in/a/gambit" + "tuxpa.in/a/gambit/algo" +) + +func main() { + + g := &gambit.Gang{} + b := &algo.EpsilonGreedy{Epsilon: 0.1} + b.Reset(4) + g.WithBandit(b) + + n := 100 + for i := 0; i < n; i++ { + b.Update( + // select a random arm + b.Select(frand.Float64()), + // and supply a random score + float64(frand.Intn(4)), + ) + } + + log.Println(g.AllocateSolution()) +} diff --git a/gang.go b/gang.go new file mode 100644 index 0000000..6003ede --- /dev/null +++ b/gang.go @@ -0,0 +1,17 @@ +package gambit + +type Gang struct { + b Bandit +} + +func (g *Gang) WithBandit(b Bandit) { + g.b = b +} + +func (g *Gang) AllocateSolution() ([]int, []float64) { + a := make([]int, g.b.Size()) + b := make([]float64, g.b.Size()) + g.b.Count(a) + g.b.Reward(b) + return a, b +} diff --git a/go.mod b/go.mod index a187029..e9f0dd0 100644 --- a/go.mod +++ b/go.mod @@ -2,10 +2,12 @@ module tuxpa.in/a/gambit go 1.19 -require golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df +require ( + golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df + lukechampine.com/frand v1.4.2 +) require ( github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect golang.org/x/sys v0.1.0 // indirect - lukechampine.com/frand v1.4.2 // indirect ) diff --git a/helper/storage.go b/helper/storage.go index c0666b3..86d1fbb 100644 --- a/helper/storage.go +++ b/helper/storage.go @@ -13,7 +13,7 @@ type CountReward struct { } func (c *CountReward) ResetTo(size int) { - if len(c.Counts) > size { + if len(c.Counts) < size { c.Counts = make([]int, size) } c.Counts = c.Counts[:size] @@ -21,7 +21,7 @@ func (c *CountReward) ResetTo(size int) { c.Counts[idx] = 0 } - if len(c.Rewards) > size { + if len(c.Rewards) < size { c.Rewards = make([]float64, size) } c.Rewards = c.Rewards[:size]