This commit is contained in:
parent
b486de5d8c
commit
2f1636a33c
|
@ -1,4 +1,4 @@
|
||||||
package ucb
|
package algo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
"math"
|
||||||
|
@ -37,10 +37,10 @@ func (u *UCB) Reset(n int) error {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UCB) Count(res *[]int) {
|
func (u *UCB) Count(res []int) {
|
||||||
u.cr.Count(res)
|
u.cr.Count(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UCB) Reward(res *[]float64) {
|
func (u *UCB) Reward(res []float64) {
|
||||||
u.cr.Reward(res)
|
u.cr.Reward(res)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"lukechampine.com/frand"
|
||||||
|
"tuxpa.in/a/gambit"
|
||||||
|
"tuxpa.in/a/gambit/algo"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
g := &gambit.Gang{}
|
||||||
|
b := &algo.EpsilonGreedy{Epsilon: 0.1}
|
||||||
|
b.Reset(4)
|
||||||
|
g.WithBandit(b)
|
||||||
|
|
||||||
|
n := 100
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
b.Update(
|
||||||
|
// select a random arm
|
||||||
|
b.Select(frand.Float64()),
|
||||||
|
// and supply a random score
|
||||||
|
float64(frand.Intn(4)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(g.AllocateSolution())
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
package gambit
|
||||||
|
|
||||||
|
type Gang struct {
|
||||||
|
b Bandit
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gang) WithBandit(b Bandit) {
|
||||||
|
g.b = b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Gang) AllocateSolution() ([]int, []float64) {
|
||||||
|
a := make([]int, g.b.Size())
|
||||||
|
b := make([]float64, g.b.Size())
|
||||||
|
g.b.Count(a)
|
||||||
|
g.b.Reward(b)
|
||||||
|
return a, b
|
||||||
|
}
|
6
go.mod
6
go.mod
|
@ -2,10 +2,12 @@ module tuxpa.in/a/gambit
|
||||||
|
|
||||||
go 1.19
|
go 1.19
|
||||||
|
|
||||||
require golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
|
require (
|
||||||
|
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
|
||||||
|
lukechampine.com/frand v1.4.2
|
||||||
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
|
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
|
||||||
golang.org/x/sys v0.1.0 // indirect
|
golang.org/x/sys v0.1.0 // indirect
|
||||||
lukechampine.com/frand v1.4.2 // indirect
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,7 +13,7 @@ type CountReward struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CountReward) ResetTo(size int) {
|
func (c *CountReward) ResetTo(size int) {
|
||||||
if len(c.Counts) > size {
|
if len(c.Counts) < size {
|
||||||
c.Counts = make([]int, size)
|
c.Counts = make([]int, size)
|
||||||
}
|
}
|
||||||
c.Counts = c.Counts[:size]
|
c.Counts = c.Counts[:size]
|
||||||
|
@ -21,7 +21,7 @@ func (c *CountReward) ResetTo(size int) {
|
||||||
c.Counts[idx] = 0
|
c.Counts[idx] = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.Rewards) > size {
|
if len(c.Rewards) < size {
|
||||||
c.Rewards = make([]float64, size)
|
c.Rewards = make([]float64, size)
|
||||||
}
|
}
|
||||||
c.Rewards = c.Rewards[:size]
|
c.Rewards = c.Rewards[:size]
|
||||||
|
|
Loading…
Reference in New Issue