This commit is contained in:
parent
b486de5d8c
commit
2f1636a33c
|
@ -1,4 +1,4 @@
|
|||
package ucb
|
||||
package algo
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
@ -37,10 +37,10 @@ func (u *UCB) Reset(n int) error {
|
|||
|
||||
}
|
||||
|
||||
func (u *UCB) Count(res *[]int) {
|
||||
func (u *UCB) Count(res []int) {
|
||||
u.cr.Count(res)
|
||||
}
|
||||
|
||||
func (u *UCB) Reward(res *[]float64) {
|
||||
func (u *UCB) Reward(res []float64) {
|
||||
u.cr.Reward(res)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"lukechampine.com/frand"
|
||||
"tuxpa.in/a/gambit"
|
||||
"tuxpa.in/a/gambit/algo"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
g := &gambit.Gang{}
|
||||
b := &algo.EpsilonGreedy{Epsilon: 0.1}
|
||||
b.Reset(4)
|
||||
g.WithBandit(b)
|
||||
|
||||
n := 100
|
||||
for i := 0; i < n; i++ {
|
||||
b.Update(
|
||||
// select a random arm
|
||||
b.Select(frand.Float64()),
|
||||
// and supply a random score
|
||||
float64(frand.Intn(4)),
|
||||
)
|
||||
}
|
||||
|
||||
log.Println(g.AllocateSolution())
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package gambit
|
||||
|
||||
type Gang struct {
|
||||
b Bandit
|
||||
}
|
||||
|
||||
func (g *Gang) WithBandit(b Bandit) {
|
||||
g.b = b
|
||||
}
|
||||
|
||||
func (g *Gang) AllocateSolution() ([]int, []float64) {
|
||||
a := make([]int, g.b.Size())
|
||||
b := make([]float64, g.b.Size())
|
||||
g.b.Count(a)
|
||||
g.b.Reward(b)
|
||||
return a, b
|
||||
}
|
6
go.mod
6
go.mod
|
@ -2,10 +2,12 @@ module tuxpa.in/a/gambit
|
|||
|
||||
go 1.19
|
||||
|
||||
require golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
|
||||
require (
|
||||
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
|
||||
lukechampine.com/frand v1.4.2
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
|
||||
golang.org/x/sys v0.1.0 // indirect
|
||||
lukechampine.com/frand v1.4.2 // indirect
|
||||
)
|
||||
|
|
|
@ -13,7 +13,7 @@ type CountReward struct {
|
|||
}
|
||||
|
||||
func (c *CountReward) ResetTo(size int) {
|
||||
if len(c.Counts) > size {
|
||||
if len(c.Counts) < size {
|
||||
c.Counts = make([]int, size)
|
||||
}
|
||||
c.Counts = c.Counts[:size]
|
||||
|
@ -21,7 +21,7 @@ func (c *CountReward) ResetTo(size int) {
|
|||
c.Counts[idx] = 0
|
||||
}
|
||||
|
||||
if len(c.Rewards) > size {
|
||||
if len(c.Rewards) < size {
|
||||
c.Rewards = make([]float64, size)
|
||||
}
|
||||
c.Rewards = c.Rewards[:size]
|
||||
|
|
Loading…
Reference in New Issue