40 lines
647 B
Go
40 lines
647 B
Go
|
package algo
|
||
|
|
||
|
import (
|
||
|
"lukechampine.com/frand"
|
||
|
"tuxpa.in/a/gambit/helper"
|
||
|
)
|
||
|
|
||
|
type EpsilonGreedy struct {
|
||
|
Epsilon float64
|
||
|
cr helper.CountReward
|
||
|
}
|
||
|
|
||
|
func (u *EpsilonGreedy) Select(r float64) int {
|
||
|
if r > u.Epsilon {
|
||
|
return int(u.cr.RewardMax())
|
||
|
}
|
||
|
return frand.Intn(u.cr.Size())
|
||
|
}
|
||
|
|
||
|
func (u *EpsilonGreedy) Update(a int, r float64) error {
|
||
|
return u.cr.Update(a, r)
|
||
|
}
|
||
|
|
||
|
func (u *EpsilonGreedy) Reset(n int) error {
|
||
|
u.cr.ResetTo(n)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (u *EpsilonGreedy) Size() int {
|
||
|
return u.cr.Size()
|
||
|
}
|
||
|
|
||
|
func (u *EpsilonGreedy) Count(res []int) {
|
||
|
u.cr.Count(res)
|
||
|
}
|
||
|
|
||
|
func (u *EpsilonGreedy) Reward(res []float64) {
|
||
|
u.cr.Reward(res)
|
||
|
}
|