gambit/algo/ucb.go

47 lines
723 B
Go
Raw Normal View History

2023-07-06 09:10:29 +00:00
package algo
2023-07-05 22:22:48 +00:00
import (
"math"
"tuxpa.in/a/gambit/helper"
)
// UCB is a bandit arm selector whose state lives in a helper.CountReward.
// The name presumably stands for "upper confidence bound"; Select applies a
// sqrt(2*ln(..)/count) exploration term to the stored counts.
type UCB struct {
// cr holds the Counts and Rewards slices that Select reads, and is the
// target of every Update, Reset, Count and Reward call.
cr helper.CountReward
}
func (ucb *UCB) Select(r float64) int {
a := len(ucb.cr.Counts)
for _, v := range ucb.cr.Counts {
if v == 0 {
return a
}
}
sz := len(ucb.cr.Counts)
var res float64
2023-07-06 08:32:50 +00:00
for idx, v := range ucb.cr.Counts {
2023-07-05 22:22:48 +00:00
ans := math.Sqrt((2.0 * math.Log(float64(sz))) / float64(v))
if ans > res {
2023-07-06 08:32:50 +00:00
res = float64(a) + ucb.cr.Rewards[idx]
2023-07-05 22:22:48 +00:00
}
}
return int(res)
}
func (u *UCB) Update(a int, r float64) error {
2023-07-06 08:59:17 +00:00
return u.cr.Update(a, r)
2023-07-05 22:22:48 +00:00
}
// Reset calls ResetTo(n) on the underlying helper.CountReward and always
// returns nil.
//
// NOTE(review): n is presumably the number of arms to reset to — confirm
// against helper.CountReward.ResetTo.
func (u *UCB) Reset(n int) error {
u.cr.ResetTo(n)
return nil
}
2023-07-06 09:10:29 +00:00
func (u *UCB) Count(res []int) {
2023-07-05 22:22:48 +00:00
u.cr.Count(res)
}
2023-07-06 09:10:29 +00:00
func (u *UCB) Reward(res []float64) {
2023-07-05 22:22:48 +00:00
u.cr.Reward(res)
}