47 lines
723 B
Go
47 lines
723 B
Go
package algo
|
|
|
|
import (
|
|
"math"
|
|
|
|
"tuxpa.in/a/gambit/helper"
|
|
)
|
|
|
|
type UCB struct {
|
|
cr helper.CountReward
|
|
}
|
|
|
|
func (ucb *UCB) Select(r float64) int {
|
|
a := len(ucb.cr.Counts)
|
|
for _, v := range ucb.cr.Counts {
|
|
if v == 0 {
|
|
return a
|
|
}
|
|
}
|
|
sz := len(ucb.cr.Counts)
|
|
var res float64
|
|
for idx, v := range ucb.cr.Counts {
|
|
ans := math.Sqrt((2.0 * math.Log(float64(sz))) / float64(v))
|
|
if ans > res {
|
|
res = float64(a) + ucb.cr.Rewards[idx]
|
|
}
|
|
}
|
|
return int(res)
|
|
}
|
|
|
|
func (u *UCB) Update(a int, r float64) error {
|
|
return u.cr.Update(a, r)
|
|
}
|
|
func (u *UCB) Reset(n int) error {
|
|
u.cr.ResetTo(n)
|
|
return nil
|
|
|
|
}
|
|
|
|
func (u *UCB) Count(res []int) {
|
|
u.cr.Count(res)
|
|
}
|
|
|
|
func (u *UCB) Reward(res []float64) {
|
|
u.cr.Reward(res)
|
|
}
|