ok
This commit is contained in:
commit
ef5261b577
54
algo/ucb.go
Normal file
54
algo/ucb.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
package ucb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"tuxpa.in/a/gambit/helper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UCB struct {
|
||||||
|
cr helper.CountReward
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ucb *UCB) Select(r float64) int {
|
||||||
|
a := len(ucb.cr.Counts)
|
||||||
|
for _, v := range ucb.cr.Counts {
|
||||||
|
if v == 0 {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sz := len(ucb.cr.Counts)
|
||||||
|
var res float64
|
||||||
|
for _, v := range ucb.cr.Counts {
|
||||||
|
ans := math.Sqrt((2.0 * math.Log(float64(sz))) / float64(v))
|
||||||
|
if ans > res {
|
||||||
|
res = ans
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return int(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UCB) Update(a int, r float64) error {
|
||||||
|
if a < 0 || a >= len(u.cr.Rewards) || r < 0 {
|
||||||
|
return errors.New("TODO")
|
||||||
|
}
|
||||||
|
u.cr.Counts[a]++
|
||||||
|
dec := float64(u.cr.Counts[a])
|
||||||
|
u.cr.Rewards[a] = (u.cr.Rewards[a]*(dec-1) + r) / dec
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UCB) Reset(n int) error {
|
||||||
|
u.cr.ResetTo(n)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UCB) Count(res *[]int) {
|
||||||
|
u.cr.Count(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UCB) Reward(res *[]float64) {
|
||||||
|
u.cr.Reward(res)
|
||||||
|
}
|
10
gambit.go
Normal file
10
gambit.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package gambit
|
||||||
|
|
||||||
|
type Bandit interface {
|
||||||
|
Select(r float64) int
|
||||||
|
Update(a int, r float64) error
|
||||||
|
Reset(n int) error
|
||||||
|
|
||||||
|
Count(res *[]int)
|
||||||
|
Reward(res *[]float64)
|
||||||
|
}
|
42
helper/mutex.go
Normal file
42
helper/mutex.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"tuxpa.in/a/gambit"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Sync struct {
|
||||||
|
gambit.Bandit
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sync) Select(r float64) int {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.Select(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sync) Update(a int, r float64) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.Update(a, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sync) Reset(n int) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.Reset(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sync) Count(res *[]int) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.Count(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sync) Reward(res *[]float64) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.Reward(res)
|
||||||
|
}
|
43
helper/storage.go
Normal file
43
helper/storage.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
type CountReward struct {
|
||||||
|
Counts []int
|
||||||
|
Rewards []float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CountReward) ResetTo(size int) {
|
||||||
|
if len(c.Counts) > size {
|
||||||
|
c.Counts = make([]int, size)
|
||||||
|
}
|
||||||
|
c.Counts = c.Counts[:size]
|
||||||
|
|
||||||
|
if len(c.Rewards) > size {
|
||||||
|
c.Rewards = make([]float64, size)
|
||||||
|
}
|
||||||
|
c.Rewards = c.Rewards[:size]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CountReward) Count(res *[]int) {
|
||||||
|
if res == nil {
|
||||||
|
r := make([]int, len(c.Counts))
|
||||||
|
res = &r
|
||||||
|
}
|
||||||
|
if len(c.Counts) < len(*res) {
|
||||||
|
*res = append(*res, len(c.Counts)-len(*res))
|
||||||
|
}
|
||||||
|
|
||||||
|
(*res) = (*res)[:len(c.Counts)]
|
||||||
|
copy(*res, c.Counts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CountReward) Reward(res *[]float64) {
|
||||||
|
if res == nil {
|
||||||
|
r := make([]float64, len(c.Rewards))
|
||||||
|
res = &r
|
||||||
|
}
|
||||||
|
if len(c.Rewards) < len(*res) {
|
||||||
|
*res = append(*res, make([]float64, len(c.Rewards)-len(*res))...)
|
||||||
|
}
|
||||||
|
(*res) = (*res)[:len(c.Rewards)]
|
||||||
|
copy(*res, c.Rewards)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user