From ef5261b5777fa4f8da48fd7aee1b746b7bd09870 Mon Sep 17 00:00:00 2001 From: a Date: Wed, 5 Jul 2023 23:22:48 +0100 Subject: [PATCH] ok --- algo/ucb.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++ gambit.go | 10 +++++++++ go.mod | 3 +++ helper/mutex.go | 42 ++++++++++++++++++++++++++++++++++++ helper/storage.go | 43 +++++++++++++++++++++++++++++++++++++ 5 files changed, 152 insertions(+) create mode 100644 algo/ucb.go create mode 100644 gambit.go create mode 100644 go.mod create mode 100644 helper/mutex.go create mode 100644 helper/storage.go diff --git a/algo/ucb.go b/algo/ucb.go new file mode 100644 index 0000000..d6a59bc --- /dev/null +++ b/algo/ucb.go @@ -0,0 +1,54 @@ +package ucb + +import ( + "errors" + "math" + + "tuxpa.in/a/gambit/helper" +) + +type UCB struct { + cr helper.CountReward +} + +func (ucb *UCB) Select(r float64) int { + a := len(ucb.cr.Counts) + for _, v := range ucb.cr.Counts { + if v == 0 { + return a + } + } + sz := len(ucb.cr.Counts) + var res float64 + for _, v := range ucb.cr.Counts { + ans := math.Sqrt((2.0 * math.Log(float64(sz))) / float64(v)) + if ans > res { + res = ans + } + } + return int(res) +} + +func (u *UCB) Update(a int, r float64) error { + if a < 0 || a >= len(u.cr.Rewards) || r < 0 { + return errors.New("TODO") + } + u.cr.Counts[a]++ + dec := float64(u.cr.Counts[a]) + u.cr.Rewards[a] = (u.cr.Rewards[a]*(dec-1) + r) / dec + return nil +} + +func (u *UCB) Reset(n int) error { + u.cr.ResetTo(n) + return nil + +} + +func (u *UCB) Count(res *[]int) { + u.cr.Count(res) +} + +func (u *UCB) Reward(res *[]float64) { + u.cr.Reward(res) +} diff --git a/gambit.go b/gambit.go new file mode 100644 index 0000000..a6bfad3 --- /dev/null +++ b/gambit.go @@ -0,0 +1,10 @@ +package gambit + +type Bandit interface { + Select(r float64) int + Update(a int, r float64) error + Reset(n int) error + + Count(res *[]int) + Reward(res *[]float64) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..710d834 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module tuxpa.in/a/gambit + +go 1.19 diff --git a/helper/mutex.go b/helper/mutex.go new file mode 100644 index 0000000..0d17fe3 --- /dev/null +++ b/helper/mutex.go @@ -0,0 +1,42 @@ +package helper + +import ( + "sync" + + "tuxpa.in/a/gambit" +) + +type Sync struct { + gambit.Bandit + mu sync.RWMutex +} + +func (s *Sync) Select(r float64) int { + s.mu.Lock() + defer s.mu.Unlock() + return s.Select(r) +} + +func (s *Sync) Update(a int, r float64) error { + s.mu.Lock() + defer s.mu.Unlock() + return s.Update(a, r) +} + +func (s *Sync) Reset(n int) error { + s.mu.Lock() + defer s.mu.Unlock() + return s.Reset(n) +} + +func (s *Sync) Count(res *[]int) { + s.mu.Lock() + defer s.mu.Unlock() + s.Count(res) +} + +func (s *Sync) Reward(res *[]float64) { + s.mu.Lock() + defer s.mu.Unlock() + s.Reward(res) +} diff --git a/helper/storage.go b/helper/storage.go new file mode 100644 index 0000000..37e3234 --- /dev/null +++ b/helper/storage.go @@ -0,0 +1,43 @@ +package helper + +type CountReward struct { + Counts []int + Rewards []float64 +} + +func (c *CountReward) ResetTo(size int) { + if len(c.Counts) > size { + c.Counts = make([]int, size) + } + c.Counts = c.Counts[:size] + + if len(c.Rewards) > size { + c.Rewards = make([]float64, size) + } + c.Rewards = c.Rewards[:size] +} + +func (c *CountReward) Count(res *[]int) { + if res == nil { + r := make([]int, len(c.Counts)) + res = &r + } + if len(c.Counts) < len(*res) { + *res = append(*res, len(c.Counts)-len(*res)) + } + + (*res) = (*res)[:len(c.Counts)] + copy(*res, c.Counts) +} + +func (c *CountReward) Reward(res *[]float64) { + if res == nil { + r := make([]float64, len(c.Rewards)) + res = &r + } + if len(c.Rewards) < len(*res) { + *res = append(*res, make([]float64, len(c.Rewards)-len(*res))...) + } + (*res) = (*res)[:len(c.Rewards)] + copy(*res, c.Rewards) +}