From 37f6d38c498740c7c3db59c31da031cd5c305694 Mon Sep 17 00:00:00 2001
From: Eugene Bujak <hmage@hmage.net>
Date: Tue, 9 Oct 2018 04:45:05 +0300
Subject: [PATCH 1/2] Implement online stats calculation in coredns plugin
 instead of scraping prometheus.

---
 control.go                       | 121 +++-------
 coredns_plugin/coredns_plugin.go |  30 +--
 coredns_plugin/coredns_stats.go  | 391 +++++++++++++++++++++++++++++++
 coredns_plugin/querylog.go       |   8 +-
 coredns_plugin/querylog_file.go  |   3 +-
 coredns_plugin/querylog_top.go   |  71 ++++--
 6 files changed, 489 insertions(+), 135 deletions(-)
 create mode 100644 coredns_plugin/coredns_stats.go

diff --git a/control.go b/control.go
index a1b88900..d00d5e07 100644
--- a/control.go
+++ b/control.go
@@ -222,115 +222,68 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
 // stats
 // -----
 func handleStats(w http.ResponseWriter, r *http.Request) {
-	histrical := generateMapFromStats(&statistics.PerHour, 0, 24)
-	// sum them up
-	summed := map[string]interface{}{}
-	for key, values := range histrical {
-		summedValue := 0.0
-		floats, ok := values.([]float64)
-		if !ok {
-			continue
-		}
-		for _, v := range floats {
-			summedValue += v
-		}
-		summed[key] = summedValue
+	resp, err := client.Get("http://127.0.0.1:8618/stats")
+	if err != nil {
+		errortext := fmt.Sprintf("Couldn't get stats_top from coredns: %T %s\n", err, err)
+		log.Println(errortext)
+		http.Error(w, errortext, http.StatusBadGateway)
+		return
+	}
+	if resp != nil && resp.Body != nil {
+		defer resp.Body.Close()
 	}
-	summed["stats_period"] = "24 hours"
 
-	json, err := json.Marshal(summed)
+	// read the body entirely
+	body, err := ioutil.ReadAll(resp.Body)
 	if err != nil {
-		errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
+		errortext := fmt.Sprintf("Couldn't read response body: %s", err)
 		log.Println(errortext)
-		http.Error(w, errortext, 500)
+		http.Error(w, errortext, http.StatusBadGateway)
 		return
 	}
+
+	// forward body entirely with status code
 	w.Header().Set("Content-Type", "application/json")
-	_, err = w.Write(json)
+	w.Header().Set("Content-Length", strconv.Itoa(len(body)))
+	w.WriteHeader(resp.StatusCode)
+	_, err = w.Write(body)
 	if err != nil {
-		errortext := fmt.Sprintf("Unable to write response json: %s", err)
+		errortext := fmt.Sprintf("Couldn't write body: %s", err)
 		log.Println(errortext)
-		http.Error(w, errortext, 500)
-		return
+		http.Error(w, errortext, http.StatusInternalServerError)
 	}
 }
 
 func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
-	// handle time unit and prepare our time window size
-	now := time.Now()
-	timeUnitString := r.URL.Query().Get("time_unit")
-	var stats *periodicStats
-	var timeUnit time.Duration
-	switch timeUnitString {
-	case "seconds":
-		timeUnit = time.Second
-		stats = &statistics.PerSecond
-	case "minutes":
-		timeUnit = time.Minute
-		stats = &statistics.PerMinute
-	case "hours":
-		timeUnit = time.Hour
-		stats = &statistics.PerHour
-	case "days":
-		timeUnit = time.Hour * 24
-		stats = &statistics.PerDay
-	default:
-		http.Error(w, "Must specify valid time_unit parameter", 400)
-		return
-	}
-
-	// parse start and end time
-	startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
+	resp, err := client.Get("http://127.0.0.1:8618/stats_history?" + r.URL.RawQuery)
 	if err != nil {
-		errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
+		errortext := fmt.Sprintf("Couldn't get stats_top from coredns: %T %s\n", err, err)
 		log.Println(errortext)
-		http.Error(w, errortext, 400)
+		http.Error(w, errortext, http.StatusBadGateway)
 		return
 	}
-	endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
+	if resp != nil && resp.Body != nil {
+		defer resp.Body.Close()
+	}
+
+	// read the body entirely
+	body, err := ioutil.ReadAll(resp.Body)
 	if err != nil {
-		errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
+		errortext := fmt.Sprintf("Couldn't read response body: %s", err)
 		log.Println(errortext)
-		http.Error(w, errortext, 400)
+		http.Error(w, errortext, http.StatusBadGateway)
 		return
 	}
 
-	// check if start and time times are within supported time range
-	timeRange := timeUnit * statsHistoryElements
-	if startTime.Add(timeRange).Before(now) {
-		http.Error(w, "start_time parameter is outside of supported range", 501)
-		return
-	}
-	if endTime.Add(timeRange).Before(now) {
-		http.Error(w, "end_time parameter is outside of supported range", 501)
-		return
-	}
-
-	// calculate start and end of our array
-	// basically it's how many hours/minutes/etc have passed since now
-	start := int(now.Sub(endTime) / timeUnit)
-	end := int(now.Sub(startTime) / timeUnit)
-
-	// swap them around if they're inverted
-	if start > end {
-		start, end = end, start
-	}
-
-	data := generateMapFromStats(stats, start, end)
-	json, err := json.Marshal(data)
-	if err != nil {
-		errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, 500)
-		return
-	}
+	// forward body entirely with status code
 	w.Header().Set("Content-Type", "application/json")
-	_, err = w.Write(json)
+	w.Header().Set("Content-Length", strconv.Itoa(len(body)))
+	w.WriteHeader(resp.StatusCode)
+	_, err = w.Write(body)
 	if err != nil {
-		errortext := fmt.Sprintf("Unable to write response json: %s", err)
+		errortext := fmt.Sprintf("Couldn't write body: %s", err)
 		log.Println(errortext)
-		http.Error(w, errortext, 500)
-		return
+		http.Error(w, errortext, http.StatusInternalServerError)
 	}
 }
 
diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go
index 474c3f22..195a9218 100644
--- a/coredns_plugin/coredns_plugin.go
+++ b/coredns_plugin/coredns_plugin.go
@@ -68,27 +68,6 @@ var defaultPluginSettings = plugSettings{
 	BlockedTTL:            3600, // in seconds
 }
 
-func newDNSCounter(name string, help string) prometheus.Counter {
-	return prometheus.NewCounter(prometheus.CounterOpts{
-		Namespace: plugin.Namespace,
-		Subsystem: "dnsfilter",
-		Name:      name,
-		Help:      help,
-	})
-}
-
-var (
-	requests             = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.")
-	filtered             = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.")
-	filteredLists        = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.")
-	filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.")
-	filteredParental     = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.")
-	filteredInvalid      = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.")
-	whitelisted          = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.")
-	safesearch           = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.")
-	errorsTotal          = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.")
-)
-
 //
 // coredns handling functions
 //
@@ -183,10 +162,10 @@ func setupPlugin(c *caddy.Controller) (*plug, error) {
 		}
 	}
 
-	log.Printf("Loading top from querylog")
-	err := loadTopFromFiles()
+	log.Printf("Loading stats from querylog")
+	err := fillStatsFromQueryLog()
 	if err != nil {
-		log.Printf("Failed to load top from querylog: %s", err)
+		log.Printf("Failed to load stats from querylog: %s", err)
 		return nil, err
 	}
 
@@ -229,6 +208,7 @@ func setup(c *caddy.Controller) error {
 			x.MustRegister(whitelisted)
 			x.MustRegister(safesearch)
 			x.MustRegister(errorsTotal)
+			x.MustRegister(elapsedTime)
 			x.MustRegister(p)
 		}
 		return nil
@@ -562,6 +542,8 @@ func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
 	}
 
 	// log
+	elapsed := time.Since(start)
+	elapsedTime.Observe(elapsed.Seconds())
 	if p.settings.QueryLogEnabled {
 		logRequest(r, rrw.Msg, result, time.Since(start), ip)
 	}
diff --git a/coredns_plugin/coredns_stats.go b/coredns_plugin/coredns_stats.go
new file mode 100644
index 00000000..d57ba397
--- /dev/null
+++ b/coredns_plugin/coredns_stats.go
@@ -0,0 +1,391 @@
+package dnsfilter
+
+import (
+	"encoding/json"
+	"fmt"
+	"log"
+	"net/http"
+	"sync"
+	"time"
+
+	"github.com/coredns/coredns/plugin"
+	"github.com/prometheus/client_golang/prometheus"
+)
+
+var (
+	requests             = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.")
+	filtered             = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.")
+	filteredLists        = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.")
+	filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.")
+	filteredParental     = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.")
+	filteredInvalid      = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.")
+	whitelisted          = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.")
+	safesearch           = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.")
+	errorsTotal          = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.")
+	elapsedTime          = newDNSHistogram("request_duration", "Histogram of the time (in seconds) each request took.")
+)
+
+// entries for single time period (for example all per-second entries)
+type statsEntries map[string][statsHistoryElements]float64
+
+// how far back to keep the stats
+const statsHistoryElements = 60 + 1 // +1 for calculating delta
+
+// each periodic stat is a map of arrays
+type periodicStats struct {
+	Entries    statsEntries
+	period     time.Duration // how long one entry lasts
+	LastRotate time.Time     // last time this data was rotated
+
+	sync.RWMutex
+}
+
+type stats struct {
+	PerSecond periodicStats
+	PerMinute periodicStats
+	PerHour   periodicStats
+	PerDay    periodicStats
+}
+
+// per-second/per-minute/per-hour/per-day stats
+var statistics stats
+
+func initPeriodicStats(periodic *periodicStats, period time.Duration) {
+	periodic.Entries = statsEntries{}
+	periodic.LastRotate = time.Now()
+	periodic.period = period
+}
+
+func init() {
+	purgeStats()
+}
+
+func purgeStats() {
+	initPeriodicStats(&statistics.PerSecond, time.Second)
+	initPeriodicStats(&statistics.PerMinute, time.Minute)
+	initPeriodicStats(&statistics.PerHour, time.Hour)
+	initPeriodicStats(&statistics.PerDay, time.Hour*24)
+}
+
+func (p *periodicStats) Inc(name string, when time.Time) {
+	// calculate how many periods ago this happened
+	elapsed := int64(time.Since(when) / p.period)
+	// trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed)
+	if elapsed >= statsHistoryElements {
+		return // outside of our timeframe
+	}
+	p.Lock()
+	currentValues := p.Entries[name]
+	currentValues[elapsed]++
+	p.Entries[name] = currentValues
+	p.Unlock()
+}
+
+func (p *periodicStats) Observe(name string, when time.Time, value float64) {
+	// calculate how many periods ago this happened
+	elapsed := int64(time.Since(when) / p.period)
+	// trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed)
+	if elapsed >= statsHistoryElements {
+		return // outside of our timeframe
+	}
+	p.Lock()
+	{
+		countname := name + "_count"
+		currentValues := p.Entries[countname]
+		value := currentValues[elapsed]
+		// trace("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1)
+		value += 1
+		currentValues[elapsed] = value
+		p.Entries[countname] = currentValues
+	}
+	{
+		totalname := name + "_sum"
+		currentValues := p.Entries[totalname]
+		currentValues[elapsed] += value
+		p.Entries[totalname] = currentValues
+	}
+	p.Unlock()
+}
+
+func (p *periodicStats) statsRotate(now time.Time) {
+	p.Lock()
+	rotations := int64(now.Sub(p.LastRotate) / p.period)
+	if rotations > statsHistoryElements {
+		rotations = statsHistoryElements
+	}
+	// calculate how many times we should rotate
+	for r := int64(0); r < rotations; r++ {
+		for key, values := range p.Entries {
+			newValues := [statsHistoryElements]float64{}
+			for i := 1; i < len(values); i++ {
+				newValues[i] = values[i-1]
+			}
+			p.Entries[key] = newValues
+		}
+	}
+	if rotations > 0 {
+		p.LastRotate = now
+	}
+	p.Unlock()
+}
+
+func statsRotator() {
+	for range time.Tick(time.Second) {
+		now := time.Now()
+		statistics.PerSecond.statsRotate(now)
+		statistics.PerMinute.statsRotate(now)
+		statistics.PerHour.statsRotate(now)
+		statistics.PerDay.statsRotate(now)
+	}
+}
+
+// counter that wraps around prometheus Counter but also adds to periodic stats
+type counter struct {
+	name  string // used as key in periodic stats
+	value int64
+	prom  prometheus.Counter
+}
+
+func newDNSCounter(name string, help string) *counter {
+	// trace("called")
+	c := &counter{}
+	c.prom = prometheus.NewCounter(prometheus.CounterOpts{
+		Namespace: plugin.Namespace,
+		Subsystem: "dnsfilter",
+		Name:      name,
+		Help:      help,
+	})
+	c.name = name
+
+	return c
+}
+
+func (c *counter) IncWithTime(when time.Time) {
+	statistics.PerSecond.Inc(c.name, when)
+	statistics.PerMinute.Inc(c.name, when)
+	statistics.PerHour.Inc(c.name, when)
+	statistics.PerDay.Inc(c.name, when)
+	c.value++
+	c.prom.Inc()
+}
+
+func (c *counter) Inc() {
+	c.IncWithTime(time.Now())
+}
+
+func (c *counter) Describe(ch chan<- *prometheus.Desc) {
+	c.prom.Describe(ch)
+}
+
+func (c *counter) Collect(ch chan<- prometheus.Metric) {
+	c.prom.Collect(ch)
+}
+
+type histogram struct {
+	name  string // used as key in periodic stats
+	count int64
+	total float64
+	prom  prometheus.Histogram
+}
+
+func newDNSHistogram(name string, help string) *histogram {
+	// trace("called")
+	h := &histogram{}
+	h.prom = prometheus.NewHistogram(prometheus.HistogramOpts{
+		Namespace: plugin.Namespace,
+		Subsystem: "dnsfilter",
+		Name:      name,
+		Help:      help,
+	})
+	h.name = name
+
+	return h
+}
+
+func (h *histogram) ObserveWithTime(value float64, when time.Time) {
+	statistics.PerSecond.Observe(h.name, when, value)
+	statistics.PerMinute.Observe(h.name, when, value)
+	statistics.PerHour.Observe(h.name, when, value)
+	statistics.PerDay.Observe(h.name, when, value)
+	h.count++
+	h.total += value
+	h.prom.Observe(value)
+}
+
+func (h *histogram) Observe(value float64) {
+	h.ObserveWithTime(value, time.Now())
+}
+
+func (h *histogram) Describe(ch chan<- *prometheus.Desc) {
+	h.prom.Describe(ch)
+}
+
+func (h *histogram) Collect(ch chan<- prometheus.Metric) {
+	h.prom.Collect(ch)
+}
+
+// -----
+// stats
+// -----
+func handleStats(w http.ResponseWriter, r *http.Request) {
+	histrical := generateMapFromStats(&statistics.PerHour, 0, 24)
+	// sum them up
+	summed := map[string]interface{}{}
+	for key, values := range histrical {
+		summedValue := 0.0
+		floats, ok := values.([]float64)
+		if !ok {
+			continue
+		}
+		for _, v := range floats {
+			summedValue += v
+		}
+		summed[key] = summedValue
+	}
+	summed["stats_period"] = "24 hours"
+
+	json, err := json.Marshal(summed)
+	if err != nil {
+		errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 500)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	_, err = w.Write(json)
+	if err != nil {
+		errortext := fmt.Sprintf("Unable to write response json: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 500)
+		return
+	}
+}
+
+func generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} {
+	// clamp
+	start = clamp(start, 0, statsHistoryElements)
+	end = clamp(end, 0, statsHistoryElements)
+
+	avgProcessingTime := make([]float64, 0)
+
+	count := getReversedSlice(stats.Entries[elapsedTime.name+"_count"], start, end)
+	sum := getReversedSlice(stats.Entries[elapsedTime.name+"_sum"], start, end)
+	for i := 0; i < len(count); i++ {
+		var avg float64
+		if count[i] != 0 {
+			avg = sum[i] / count[i]
+			avg *= 1000
+		}
+		avgProcessingTime = append(avgProcessingTime, avg)
+	}
+
+	result := map[string]interface{}{
+		"dns_queries":           getReversedSlice(stats.Entries[requests.name], start, end),
+		"blocked_filtering":     getReversedSlice(stats.Entries[filtered.name], start, end),
+		"replaced_safebrowsing": getReversedSlice(stats.Entries[filteredSafebrowsing.name], start, end),
+		"replaced_safesearch":   getReversedSlice(stats.Entries[safesearch.name], start, end),
+		"replaced_parental":     getReversedSlice(stats.Entries[filteredParental.name], start, end),
+		"avg_processing_time":   avgProcessingTime,
+	}
+	return result
+}
+
+func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
+	// handle time unit and prepare our time window size
+	now := time.Now()
+	timeUnitString := r.URL.Query().Get("time_unit")
+	var stats *periodicStats
+	var timeUnit time.Duration
+	switch timeUnitString {
+	case "seconds":
+		timeUnit = time.Second
+		stats = &statistics.PerSecond
+	case "minutes":
+		timeUnit = time.Minute
+		stats = &statistics.PerMinute
+	case "hours":
+		timeUnit = time.Hour
+		stats = &statistics.PerHour
+	case "days":
+		timeUnit = time.Hour * 24
+		stats = &statistics.PerDay
+	default:
+		http.Error(w, "Must specify valid time_unit parameter", 400)
+		return
+	}
+
+	// parse start and end time
+	startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
+	if err != nil {
+		errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 400)
+		return
+	}
+	endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
+	if err != nil {
+		errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 400)
+		return
+	}
+
+	// check if start and time times are within supported time range
+	timeRange := timeUnit * statsHistoryElements
+	if startTime.Add(timeRange).Before(now) {
+		http.Error(w, "start_time parameter is outside of supported range", 501)
+		return
+	}
+	if endTime.Add(timeRange).Before(now) {
+		http.Error(w, "end_time parameter is outside of supported range", 501)
+		return
+	}
+
+	// calculate start and end of our array
+	// basically it's how many hours/minutes/etc have passed since now
+	start := int(now.Sub(endTime) / timeUnit)
+	end := int(now.Sub(startTime) / timeUnit)
+
+	// swap them around if they're inverted
+	if start > end {
+		start, end = end, start
+	}
+
+	data := generateMapFromStats(stats, start, end)
+	json, err := json.Marshal(data)
+	if err != nil {
+		errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 500)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	_, err = w.Write(json)
+	if err != nil {
+		errortext := fmt.Sprintf("Unable to write response json: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, 500)
+		return
+	}
+}
+
+func clamp(value, low, high int) int {
+	if value < low {
+		return low
+	}
+	if value > high {
+		return high
+	}
+	return value
+}
+
+// --------------------------
+// helper functions for stats
+// --------------------------
+func getReversedSlice(input [statsHistoryElements]float64, start int, end int) []float64 {
+	output := make([]float64, 0)
+	for i := start; i <= end; i++ {
+		output = append([]float64{input[i]}, output...)
+	}
+	return output
+}
diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go
index 9e45315f..ecf9185b 100644
--- a/coredns_plugin/querylog.go
+++ b/coredns_plugin/querylog.go
@@ -88,7 +88,7 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
 	logBufferLock.Unlock()
 
 	// add it to running top
-	err = runningTop.addEntry(&entry, now)
+	err = runningTop.addEntry(&entry, question, now)
 	if err != nil {
 		log.Printf("Failed to add entry to running top: %s", err)
 		// don't do failure, just log
@@ -100,7 +100,6 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
 		// do it in separate goroutine -- we are stalling DNS response this whole time
 		go flushToFile(flushBuffer)
 	}
-	return
 }
 
 func handleQueryLog(w http.ResponseWriter, r *http.Request) {
@@ -114,7 +113,7 @@ func handleQueryLog(w http.ResponseWriter, r *http.Request) {
 	if needRefresh {
 		// need to get fresh data
 		logBufferLock.RLock()
-		values := logBuffer
+		values = logBuffer
 		logBufferLock.RUnlock()
 
 		if len(values) < queryLogCacheSize {
@@ -238,9 +237,12 @@ func startQueryLogServer() {
 
 	go periodicQueryLogRotate()
 	go periodicHourlyTopRotate()
+	go statsRotator()
 
 	http.HandleFunc("/querylog", handleQueryLog)
+	http.HandleFunc("/stats", handleStats)
 	http.HandleFunc("/stats_top", handleStatsTop)
+	http.HandleFunc("/stats_history", handleStatsHistory)
 	if err := http.ListenAndServe(listenAddr, nil); err != nil {
 		log.Fatalf("error in ListenAndServe: %s", err)
 	}
diff --git a/coredns_plugin/querylog_file.go b/coredns_plugin/querylog_file.go
index 7025fcd3..17bd23df 100644
--- a/coredns_plugin/querylog_file.go
+++ b/coredns_plugin/querylog_file.go
@@ -204,7 +204,8 @@ func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, ti
 				return err
 			}
 		}
-		log.Printf("file \"%s\": read %d entries", file, i)
+		elapsed := time.Since(now)
+		log.Printf("file \"%s\": read %d entries in %v, %v/entry", file, i, elapsed, elapsed/time.Duration(i))
 	}
 	return nil
 }
diff --git a/coredns_plugin/querylog_top.go b/coredns_plugin/querylog_top.go
index 7e92a7f5..6ff531f2 100644
--- a/coredns_plugin/querylog_top.go
+++ b/coredns_plugin/querylog_top.go
@@ -14,6 +14,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/AdguardTeam/AdguardDNS/dnsfilter"
 	"github.com/bluele/gcache"
 	"github.com/miekg/dns"
 )
@@ -155,16 +156,7 @@ func (top *hourTop) lockedGetClients(key string) (int, error) {
 	return top.lockedGetValue(key, top.clients)
 }
 
-func (r *dayTop) addEntry(entry *logEntry, now time.Time) error {
-	if len(entry.Question) == 0 {
-		log.Printf("entry question is absent, skipping")
-		return nil
-	}
-
-	if entry.Time.After(now) {
-		log.Printf("t %v vs %v is in the future, ignoring", entry.Time, now)
-		return nil
-	}
+func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
 	// figure out which hour bucket it belongs to
 	hour := int(now.Sub(entry.Time).Hours())
 	if hour >= 24 {
@@ -172,17 +164,6 @@ func (r *dayTop) addEntry(entry *logEntry, now time.Time) error {
 		return nil
 	}
 
-	q := new(dns.Msg)
-	if err := q.Unpack(entry.Question); err != nil {
-		log.Printf("failed to unpack dns message question: %s", err)
-		return err
-	}
-
-	if len(q.Question) != 1 {
-		log.Printf("malformed dns message, has no questions, skipping")
-		return nil
-	}
-
 	hostname := strings.ToLower(strings.TrimSuffix(q.Question[0].Name, "."))
 
 	// get value, if not set, crate one
@@ -213,7 +194,7 @@ func (r *dayTop) addEntry(entry *logEntry, now time.Time) error {
 	return nil
 }
 
-func loadTopFromFiles() error {
+func fillStatsFromQueryLog() error {
 	now := time.Now()
 	runningTop.loadedWriteLock()
 	defer runningTop.loadedWriteUnlock()
@@ -221,11 +202,55 @@ func loadTopFromFiles() error {
 		return nil
 	}
 	onEntry := func(entry *logEntry) error {
-		err := runningTop.addEntry(entry, now)
+		if len(entry.Question) == 0 {
+			log.Printf("entry question is absent, skipping")
+			return nil
+		}
+
+		if entry.Time.After(now) {
+			log.Printf("t %v vs %v is in the future, ignoring", entry.Time, now)
+			return nil
+		}
+
+		q := new(dns.Msg)
+		if err := q.Unpack(entry.Question); err != nil {
+			log.Printf("failed to unpack dns message question: %s", err)
+			return err
+		}
+
+		if len(q.Question) != 1 {
+			log.Printf("malformed dns message, has no questions, skipping")
+			return nil
+		}
+
+		err := runningTop.addEntry(entry, q, now)
 		if err != nil {
 			log.Printf("Failed to add entry to running top: %s", err)
 			return err
 		}
+
+		requests.IncWithTime(entry.Time)
+		if entry.Result.IsFiltered {
+			filtered.IncWithTime(entry.Time)
+		}
+		switch entry.Result.Reason {
+		case dnsfilter.NotFilteredWhiteList:
+			whitelisted.IncWithTime(entry.Time)
+		case dnsfilter.NotFilteredError:
+			errorsTotal.IncWithTime(entry.Time)
+		case dnsfilter.FilteredBlackList:
+			filteredLists.IncWithTime(entry.Time)
+		case dnsfilter.FilteredSafeBrowsing:
+			filteredSafebrowsing.IncWithTime(entry.Time)
+		case dnsfilter.FilteredParental:
+			filteredParental.IncWithTime(entry.Time)
+		case dnsfilter.FilteredInvalid:
+			// do nothing
+		case dnsfilter.FilteredSafeSearch:
+			safesearch.IncWithTime(entry.Time)
+		}
+		elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
+
 		return nil
 	}
 

From ca794aed6344cf0f5570bb359c731a96ca90739f Mon Sep 17 00:00:00 2001
From: Eugene Bujak <hmage@hmage.net>
Date: Tue, 9 Oct 2018 05:02:16 +0300
Subject: [PATCH 2/2] querylog file -- disable gzip compression

---
 coredns_plugin/querylog_file.go | 91 +++++++++++++++++++++------------
 1 file changed, 59 insertions(+), 32 deletions(-)

diff --git a/coredns_plugin/querylog_file.go b/coredns_plugin/querylog_file.go
index 17bd23df..72cd4d32 100644
--- a/coredns_plugin/querylog_file.go
+++ b/coredns_plugin/querylog_file.go
@@ -17,6 +17,8 @@ var (
 	fileWriteLock sync.Mutex
 )
 
+const enableGzip = false
+
 func flushToFile(buffer []logEntry) error {
 	if len(buffer) == 0 {
 		return nil
@@ -42,31 +44,37 @@ func flushToFile(buffer []logEntry) error {
 		return err
 	}
 
-	filenamegz := queryLogFileName + ".gz"
-
 	var zb bytes.Buffer
+	filename := queryLogFileName
 
-	zw := gzip.NewWriter(&zb)
-	zw.Name = queryLogFileName
-	zw.ModTime = time.Now()
+	// gzip enabled?
+	if enableGzip {
+		filename += ".gz"
 
-	_, err = zw.Write(b.Bytes())
-	if err != nil {
-		log.Printf("Couldn't compress to gzip: %s", err)
-		zw.Close()
-		return err
-	}
+		zw := gzip.NewWriter(&zb)
+		zw.Name = queryLogFileName
+		zw.ModTime = time.Now()
 
-	if err = zw.Close(); err != nil {
-		log.Printf("Couldn't close gzip writer: %s", err)
-		return err
+		_, err = zw.Write(b.Bytes())
+		if err != nil {
+			log.Printf("Couldn't compress to gzip: %s", err)
+			zw.Close()
+			return err
+		}
+
+		if err = zw.Close(); err != nil {
+			log.Printf("Couldn't close gzip writer: %s", err)
+			return err
+		}
+	} else {
+		zb = b
 	}
 
 	fileWriteLock.Lock()
 	defer fileWriteLock.Unlock()
-	f, err := os.OpenFile(filenamegz, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
+	f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
 	if err != nil {
-		log.Printf("failed to create file \"%s\": %s", filenamegz, err)
+		log.Printf("failed to create file \"%s\": %s", filename, err)
 		return err
 	}
 	defer f.Close()
@@ -77,7 +85,7 @@ func flushToFile(buffer []logEntry) error {
 		return err
 	}
 
-	log.Printf("ok \"%s\": %v bytes written", filenamegz, n)
+	log.Printf("ok \"%s\": %v bytes written", filename, n)
 
 	return nil
 }
@@ -111,8 +119,13 @@ func checkBuffer(buffer []logEntry, b bytes.Buffer) error {
 }
 
 func rotateQueryLog() error {
-	from := queryLogFileName + ".gz"
-	to := queryLogFileName + ".gz.1"
+	from := queryLogFileName
+	to := queryLogFileName + ".1"
+
+	if enableGzip {
+		from = queryLogFileName + ".gz"
+		to = queryLogFileName + ".gz.1"
+	}
 
 	if _, err := os.Stat(from); os.IsNotExist(err) {
 		// do nothing, file doesn't exist
@@ -143,9 +156,18 @@ func periodicQueryLogRotate() {
 func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error {
 	now := time.Now()
 	// read from querylog files, try newest file first
-	files := []string{
-		queryLogFileName + ".gz",
-		queryLogFileName + ".gz.1",
+	files := []string{}
+
+	if enableGzip {
+		files = []string{
+			queryLogFileName + ".gz",
+			queryLogFileName + ".gz.1",
+		}
+	} else {
+		files = []string{
+			queryLogFileName,
+			queryLogFileName + ".1",
+		}
 	}
 
 	// read from all files
@@ -158,7 +180,6 @@ func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, ti
 			continue
 		}
 
-		trace("Opening file %s", file)
 		f, err := os.Open(file)
 		if err != nil {
 			log.Printf("Failed to open file \"%s\": %s", file, err)
@@ -167,16 +188,22 @@ func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, ti
 		}
 		defer f.Close()
 
-		trace("Creating gzip reader")
-		zr, err := gzip.NewReader(f)
-		if err != nil {
-			log.Printf("Failed to create gzip reader: %s", err)
-			continue
-		}
-		defer zr.Close()
+		var d *json.Decoder
 
-		trace("Creating json decoder")
-		d := json.NewDecoder(zr)
+		if enableGzip {
+			trace("Creating gzip reader")
+			zr, err := gzip.NewReader(f)
+			if err != nil {
+				log.Printf("Failed to create gzip reader: %s", err)
+				continue
+			}
+			defer zr.Close()
+
+			trace("Creating json decoder")
+			d = json.NewDecoder(zr)
+		} else {
+			d = json.NewDecoder(f)
+		}
 
 		i := 0
 		// entries on file are in oldest->newest order