91 lines
1.5 KiB
Go
91 lines
1.5 KiB
Go
package helper
|
|
|
|
import (
|
|
"errors"
|
|
"math"
|
|
|
|
"golang.org/x/exp/constraints"
|
|
)
|
|
|
|
// CountReward keeps two parallel per-index tallies that are maintained by
// Update and cleared by ResetTo.
type CountReward struct {
	// Counts[k] is the number of successful Update calls made for index k.
	Counts []int
	// Rewards[k] is the running mean of the rewards passed to Update for
	// index k.
	Rewards []float64
}
|
|
|
|
func (c *CountReward) ResetTo(size int) {
|
|
if len(c.Counts) < size {
|
|
c.Counts = make([]int, size)
|
|
}
|
|
c.Counts = c.Counts[:size]
|
|
for idx := range c.Counts {
|
|
c.Counts[idx] = 0
|
|
}
|
|
|
|
if len(c.Rewards) < size {
|
|
c.Rewards = make([]float64, size)
|
|
}
|
|
c.Rewards = c.Rewards[:size]
|
|
for idx := range c.Rewards {
|
|
c.Rewards[idx] = 0
|
|
}
|
|
}
|
|
func (c *CountReward) CountMax() (i int) {
|
|
xs := c.Counts
|
|
i = math.MinInt
|
|
for _, v := range xs {
|
|
if v > i {
|
|
i = v
|
|
}
|
|
}
|
|
return
|
|
}
|
|
func (c *CountReward) RewardMax() (i float64) {
|
|
xs := c.Rewards
|
|
i = math.Inf(-1)
|
|
for _, v := range xs {
|
|
if v > i {
|
|
i = v
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *CountReward) Update(a int, r float64) error {
|
|
if a < 0 || a >= c.Size() || r < 0 {
|
|
return errors.New("TODO")
|
|
}
|
|
c.Counts[a]++
|
|
dec := float64(c.Counts[a])
|
|
c.Rewards[a] = (c.Rewards[a]*(dec-1) + r) / dec
|
|
return nil
|
|
}
|
|
func (c *CountReward) CountSum() (i int) {
|
|
return sum(c.Counts)
|
|
}
|
|
func (c *CountReward) RewardSum() (i float64) {
|
|
return sum(c.Rewards)
|
|
}
|
|
|
|
func (c *CountReward) Size() int {
|
|
return len(c.Counts)
|
|
}
|
|
|
|
func (c *CountReward) Count(res []int) {
|
|
copy(res, c.Counts)
|
|
}
|
|
|
|
func (c *CountReward) Reward(res []float64) {
|
|
copy(res, c.Rewards)
|
|
}
|
|
|
|
func sum[T Numeric](n []T) (i T) {
|
|
for _, v := range n {
|
|
i = i + v
|
|
}
|
|
return
|
|
}
|
|
|
|
type Numeric interface {
|
|
constraints.Integer | constraints.Float
|
|
}
|