From f64868472aae50ed8203ef4d3ec9de7a7cf96fd9 Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Mon, 11 Nov 2019 16:18:20 +0300
Subject: [PATCH] - stats: fix read-write race

* periodicFlush() operation doesn't result in an inconsistent state at any time
* stats readers use the last unit ID properly, excluding the possibility
 that the unit ID has already changed but the new unit isn't created yet
---
 stats/stats_unit.go | 41 ++++++++++++++++++++++-------------------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/stats/stats_unit.go b/stats/stats_unit.go
index 3db14d5b..72ccfc36 100644
--- a/stats/stats_unit.go
+++ b/stats/stats_unit.go
@@ -207,6 +207,13 @@ func btoi(b []byte) uint64 {
 }
 
 // Flush the current unit to DB and delete an old unit when a new hour is started
+// If a unit must be flushed:
+// . lock DB
+// . atomically set a new empty unit as the current one and get the old unit
+//   It is important to do this inside the DB lock, so that readers won't get inconsistent results.
+// . write the unit to DB
+// . remove the stale unit from DB
+// . unlock DB
 func (s *statsCtx) periodicFlush() {
 	for {
 		s.unitLock.Lock()
@@ -222,12 +229,13 @@ func (s *statsCtx) periodicFlush() {
 			continue
 		}
 
+		tx := s.beginTxn(true)
+
 		nu := unit{}
 		s.initUnit(&nu, id)
 		u := s.swapUnit(&nu)
 		udb := serialize(u)
 
-		tx := s.beginTxn(true)
 		if tx == nil {
 			continue
 		}
@@ -455,15 +463,20 @@ func (s *statsCtx) Update(e Entry) {
 	s.unitLock.Unlock()
 }
 
-func (s *statsCtx) loadUnits(lastID uint32) []*unitDB {
+func (s *statsCtx) loadUnits() ([]*unitDB, uint32) {
 	tx := s.beginTxn(false)
 	if tx == nil {
-		return nil
+		return nil, 0
 	}
 
+	s.unitLock.Lock()
+	curUnit := serialize(s.unit)
+	curID := s.unit.id
+	s.unitLock.Unlock()
+
 	units := []*unitDB{} //per-hour units
-	firstID := lastID - s.limit + 1
-	for i := firstID; i != lastID; i++ {
+	firstID := curID - s.limit + 1
+	for i := firstID; i != curID; i++ {
 		u := s.loadUnitFromDB(tx, i)
 		if u == nil {
 			u = &unitDB{}
@@ -474,20 +487,13 @@ func (s *statsCtx) loadUnits(lastID uint32) []*unitDB {
 
 	_ = tx.Rollback()
 
-	s.unitLock.Lock()
-	cu := serialize(s.unit)
-	cuID := s.unit.id
-	s.unitLock.Unlock()
-	if cuID != lastID {
-		units = units[1:]
-	}
-	units = append(units, cu)
+	units = append(units, curUnit)
 
 	if len(units) != int(s.limit) {
 		log.Fatalf("len(units) != s.limit: %d %d", len(units), s.limit)
 	}
 
-	return units
+	return units, firstID
 }
 
 /* Algorithm:
@@ -521,9 +527,7 @@ func (s *statsCtx) loadUnits(lastID uint32) []*unitDB {
 func (s *statsCtx) getData(timeUnit TimeUnit) map[string]interface{} {
 	d := map[string]interface{}{}
 
-	lastID := s.conf.UnitID()
-	firstID := lastID - s.limit + 1
-	units := s.loadUnits(lastID)
+	units, firstID := s.loadUnits()
 	if units == nil {
 		return nil
 	}
@@ -699,8 +703,7 @@ func (s *statsCtx) getData(timeUnit TimeUnit) map[string]interface{} {
 }
 
 func (s *statsCtx) GetTopClientsIP(limit uint) []string {
-	lastID := s.conf.UnitID()
-	units := s.loadUnits(lastID)
+	units, _ := s.loadUnits()
 	if units == nil {
 		return nil
 	}