diff --git a/algo/ucb.go b/algo/ucb.go index d6a59bc..2bfe906 100644 --- a/algo/ucb.go +++ b/algo/ucb.go @@ -20,10 +20,10 @@ func (ucb *UCB) Select(r float64) int { } sz := len(ucb.cr.Counts) var res float64 - for _, v := range ucb.cr.Counts { + for idx, v := range ucb.cr.Counts { ans := math.Sqrt((2.0 * math.Log(float64(sz))) / float64(v)) if ans > res { - res = ans + res = float64(a) + ucb.cr.Rewards[idx] } } return int(res)