From 9a03190a622068eb9bf892ffabdd9e2f39535fef Mon Sep 17 00:00:00 2001
From: Andrey Meshkov <am@adguard.com>
Date: Sun, 10 Feb 2019 20:47:43 +0300
Subject: [PATCH] Fix #579

1. Added --workdir command-line argument that lets configure the working dir.
2. Made "dnsforward" use this workdir parameter when saving/reading querylog.
3. Reworked "dnsforward" -- moved http handlers out of there to control.go
---
 app.go                        |  21 ++-
 config.go                     |   4 +-
 control.go                    | 234 ++++++++++++++++++++++++---
 dhcp.go                       |   4 +-
 dns.go                        |  17 +-
 dnsforward/dnsforward.go      |  70 +++++---
 dnsforward/dnsforward_test.go |  97 +++++++++--
 dnsforward/querylog.go        |  87 +++++-----
 dnsforward/querylog_file.go   |  46 ++++--
 dnsforward/querylog_top.go    | 165 ++++++-------------
 dnsforward/stats.go           | 293 ++++++++++++++--------------------
 filter.go                     |   2 +-
 go.mod                        |   2 +
 helpers.go                    |   2 +-
 upgrade.go                    |   4 +-
 15 files changed, 630 insertions(+), 418 deletions(-)

diff --git a/app.go b/app.go
index e7ca88c8..1580bcb0 100644
--- a/app.go
+++ b/app.go
@@ -66,7 +66,7 @@ func run(args options) {
 
 	// print the first message after logger is configured
 	log.Printf("AdGuard Home, version %s\n", VersionString)
-	log.Tracef("Current working directory is %s", config.ourBinaryDir)
+	log.Tracef("Current working directory is %s", config.ourWorkingDir)
 	if args.runningAsService {
 		log.Printf("AdGuard Home is running as a service")
 	}
@@ -117,6 +117,10 @@ func run(args options) {
 		log.Fatal(err)
 	}
 
+	// Init the DNS server instance before registering HTTP handlers
+	dnsBaseDir := filepath.Join(config.ourWorkingDir, dataDir)
+	initDNSServer(dnsBaseDir)
+
 	if !config.firstRun {
 		err = startDNSServer()
 		if err != nil {
@@ -172,18 +176,19 @@ func run(args options) {
 	}
 }
 
-// initWorkingDir initializes the ourBinaryDir (basically, we use it as a working dir)
+// initWorkingDir initializes the ourWorkingDir
+// if no command-line arguments specified, we use the directory where our binary file is located
 func initWorkingDir(args options) {
 	exec, err := os.Executable()
 	if err != nil {
 		panic(err)
 	}
 
-	if args.configFilename != "" {
+	if args.workDir != "" {
 		// If there is a custom config file, use it's directory as our working dir
-		config.ourBinaryDir = filepath.Dir(args.configFilename)
+		config.ourWorkingDir = args.workDir
 	} else {
-		config.ourBinaryDir = filepath.Dir(exec)
+		config.ourWorkingDir = filepath.Dir(exec)
 	}
 }
 
@@ -218,7 +223,7 @@ func configureLogger(args options) {
 			log.Fatalf("cannot initialize syslog: %s", err)
 		}
 	} else {
-		logFilePath := filepath.Join(config.ourBinaryDir, ls.LogFile)
+		logFilePath := filepath.Join(config.ourWorkingDir, ls.LogFile)
 		file, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0755)
 		if err != nil {
 			log.Fatalf("cannot create a log file: %s", err)
@@ -244,6 +249,7 @@ func cleanup() {
 type options struct {
 	verbose        bool   // is verbose logging enabled
 	configFilename string // path to the config file
+	workDir        string // path to the working directory where we will store the filters data and the querylog
 	bindHost       string // host address to bind HTTP server on
 	bindPort       int    // port to serve HTTP pages on
 	logFile        string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
@@ -267,7 +273,8 @@ func loadOptions() options {
 		callbackWithValue func(value string)
 		callbackNoValue   func()
 	}{
-		{"config", "c", "path to config file", func(value string) { o.configFilename = value }, nil},
+		{"config", "c", "path to the config file", func(value string) { o.configFilename = value }, nil},
+		{"work-dir", "w", "path to the working directory", func(value string) { o.workDir = value }, nil},
 		{"host", "h", "host address to bind HTTP server on", func(value string) { o.bindHost = value }, nil},
 		{"port", "p", "port to serve HTTP pages on", func(value string) {
 			v, err := strconv.Atoi(value)
diff --git a/config.go b/config.go
index 751e3c41..ba30c027 100644
--- a/config.go
+++ b/config.go
@@ -28,7 +28,7 @@ type logSettings struct {
 // field ordering is important -- yaml fields will mirror ordering from here
 type configuration struct {
 	ourConfigFilename string // Config filename (can be overridden via the command line arguments)
-	ourBinaryDir      string // Location of our directory, used to protect against CWD being somewhere else
+	ourWorkingDir     string // Location of our directory, used to protect against CWD being somewhere else
 	firstRun          bool   // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
 
 	BindHost  string             `yaml:"bind_host"`
@@ -92,7 +92,7 @@ var config = configuration{
 func (c *configuration) getConfigFilename() string {
 	configFile := config.ourConfigFilename
 	if !filepath.IsAbs(configFile) {
-		configFile = filepath.Join(config.ourBinaryDir, config.ourConfigFilename)
+		configFile = filepath.Join(config.ourWorkingDir, config.ourConfigFilename)
 	}
 	return configFile
 }
diff --git a/control.go b/control.go
index 2077549d..fffb2d27 100644
--- a/control.go
+++ b/control.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -8,6 +9,7 @@ import (
 	"net"
 	"net/http"
 	"os"
+	"sort"
 	"strconv"
 	"strings"
 	"time"
@@ -32,9 +34,28 @@ var client = &http.Client{
 	Timeout: time.Second * 30,
 }
 
-// -------------------
+// ----------------
+// helper functions
+// ----------------
+
+func returnOK(w http.ResponseWriter) {
+	_, err := fmt.Fprintf(w, "OK\n")
+	if err != nil {
+		errorText := fmt.Sprintf("Couldn't write body: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, http.StatusInternalServerError)
+	}
+}
+
+func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
+	text := fmt.Sprintf(format, args...)
+	log.Println(text)
+	http.Error(w, text, code)
+}
+
+// ---------------
 // dns run control
-// -------------------
+// ---------------
 func writeAllConfigsAndReloadDNS() error {
 	err := writeAllConfigs()
 	if err != nil {
@@ -55,15 +76,6 @@ func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) {
 	returnOK(w)
 }
 
-func returnOK(w http.ResponseWriter) {
-	_, err := fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errorText := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, http.StatusInternalServerError)
-	}
-}
-
 func handleStatus(w http.ResponseWriter, r *http.Request) {
 	data := map[string]interface{}{
 		"dns_address":        config.BindHost,
@@ -117,12 +129,190 @@ func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) {
 	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
-func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
-	text := fmt.Sprintf(format, args...)
-	log.Println(text)
-	http.Error(w, text, code)
+func handleQueryLog(w http.ResponseWriter, r *http.Request) {
+	data := dnsServer.GetQueryLog()
+
+	jsonVal, err := json.Marshal(data)
+	if err != nil {
+		errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, http.StatusInternalServerError)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	_, err = w.Write(jsonVal)
+	if err != nil {
+		errorText := fmt.Sprintf("Unable to write response json: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, http.StatusInternalServerError)
+	}
 }
 
+func handleStatsTop(w http.ResponseWriter, r *http.Request) {
+	s := dnsServer.GetStatsTop()
+
+	// use manual json marshalling because we want maps to be sorted by value
+	statsJSON := bytes.Buffer{}
+	statsJSON.WriteString("{\n")
+
+	gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) {
+		json.WriteString("  ")
+		json.WriteString(fmt.Sprintf("%q", name))
+		json.WriteString(": {\n")
+		sorted := sortByValue(top)
+		// no more than 50 entries
+		if len(sorted) > 50 {
+			sorted = sorted[:50]
+		}
+		for i, key := range sorted {
+			json.WriteString("    ")
+			json.WriteString(fmt.Sprintf("%q", key))
+			json.WriteString(": ")
+			json.WriteString(strconv.Itoa(top[key]))
+			if i+1 != len(sorted) {
+				json.WriteByte(',')
+			}
+			json.WriteByte('\n')
+		}
+		json.WriteString("  }")
+		if addComma {
+			json.WriteByte(',')
+		}
+		json.WriteByte('\n')
+	}
+	gen(&statsJSON, "top_queried_domains", s.Domains, true)
+	gen(&statsJSON, "top_blocked_domains", s.Blocked, true)
+	gen(&statsJSON, "top_clients", s.Clients, true)
+	statsJSON.WriteString("  \"stats_period\": \"24 hours\"\n")
+	statsJSON.WriteString("}\n")
+
+	w.Header().Set("Content-Type", "application/json")
+	_, err := w.Write(statsJSON.Bytes())
+	if err != nil {
+		errorText := fmt.Sprintf("Couldn't write body: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, http.StatusInternalServerError)
+	}
+}
+
+// handleStatsReset resets the stats caches
+func handleStatsReset(w http.ResponseWriter, r *http.Request) {
+	dnsServer.ResetStats()
+	_, err := fmt.Fprintf(w, "OK\n")
+	if err != nil {
+		errorText := fmt.Sprintf("Couldn't write body: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, http.StatusInternalServerError)
+	}
+}
+
+// handleStats returns aggregated stats data for the 24 hours
+func handleStats(w http.ResponseWriter, r *http.Request) {
+	summed := dnsServer.GetAggregatedStats()
+
+	statsJSON, err := json.Marshal(summed)
+	if err != nil {
+		errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, 500)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	_, err = w.Write(statsJSON)
+	if err != nil {
+		errorText := fmt.Sprintf("Unable to write response json: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, 500)
+		return
+	}
+}
+
+// HandleStatsHistory returns historical stats data for the 24 hours
+func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
+	// handle time unit and prepare our time window size
+	timeUnitString := r.URL.Query().Get("time_unit")
+	var timeUnit time.Duration
+	switch timeUnitString {
+	case "seconds":
+		timeUnit = time.Second
+	case "minutes":
+		timeUnit = time.Minute
+	case "hours":
+		timeUnit = time.Hour
+	case "days":
+		timeUnit = time.Hour * 24
+	default:
+		http.Error(w, "Must specify valid time_unit parameter", http.StatusBadRequest)
+		return
+	}
+
+	// parse start and end time
+	startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
+	if err != nil {
+		errorText := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, http.StatusBadRequest)
+		return
+	}
+	endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
+	if err != nil {
+		errorText := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, http.StatusBadRequest)
+		return
+	}
+
+	data, err := dnsServer.GetStatsHistory(timeUnit, startTime, endTime)
+	if err != nil {
+		errorText := fmt.Sprintf("Cannot get stats history: %s", err)
+		http.Error(w, errorText, http.StatusBadRequest)
+		return
+	}
+
+	statsJSON, err := json.Marshal(data)
+	if err != nil {
+		errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, http.StatusInternalServerError)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	_, err = w.Write(statsJSON)
+	if err != nil {
+		errorText := fmt.Sprintf("Unable to write response json: %s", err)
+		log.Println(errorText)
+		http.Error(w, errorText, http.StatusInternalServerError)
+		return
+	}
+}
+
+// sortByValue is a helper function for querylog API
+func sortByValue(m map[string]int) []string {
+	type kv struct {
+		k string
+		v int
+	}
+	var ss []kv
+	for k, v := range m {
+		ss = append(ss, kv{k, v})
+	}
+	sort.Slice(ss, func(l, r int) bool {
+		return ss[l].v > ss[r].v
+	})
+
+	sorted := []string{}
+	for _, v := range ss {
+		sorted = append(sorted, v.k)
+	}
+	return sorted
+}
+
+// -----------------------
+// upstreams configuration
+// -----------------------
+
 func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
 	body, err := ioutil.ReadAll(r.Body)
 	if err != nil {
@@ -737,8 +927,8 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
 
 	data.Interfaces = make(map[string]interface{})
 	for _, iface := range ifaces {
-		addrs, err := iface.Addrs()
-		if err != nil {
+		addrs, e := iface.Addrs()
+		if e != nil {
 			httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err)
 			return
 		}
@@ -844,17 +1034,17 @@ func registerControlHandlers() {
 	http.HandleFunc("/control/status", postInstall(optionalAuth(ensureGET(handleStatus))))
 	http.HandleFunc("/control/enable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionEnable))))
 	http.HandleFunc("/control/disable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionDisable))))
-	http.HandleFunc("/control/querylog", postInstall(optionalAuth(ensureGET(dnsforward.HandleQueryLog))))
+	http.HandleFunc("/control/querylog", postInstall(optionalAuth(ensureGET(handleQueryLog))))
 	http.HandleFunc("/control/querylog_enable", postInstall(optionalAuth(ensurePOST(handleQueryLogEnable))))
 	http.HandleFunc("/control/querylog_disable", postInstall(optionalAuth(ensurePOST(handleQueryLogDisable))))
 	http.HandleFunc("/control/set_upstream_dns", postInstall(optionalAuth(ensurePOST(handleSetUpstreamDNS))))
 	http.HandleFunc("/control/test_upstream_dns", postInstall(optionalAuth(ensurePOST(handleTestUpstreamDNS))))
 	http.HandleFunc("/control/i18n/change_language", postInstall(optionalAuth(ensurePOST(handleI18nChangeLanguage))))
 	http.HandleFunc("/control/i18n/current_language", postInstall(optionalAuth(ensureGET(handleI18nCurrentLanguage))))
-	http.HandleFunc("/control/stats_top", postInstall(optionalAuth(ensureGET(dnsforward.HandleStatsTop))))
-	http.HandleFunc("/control/stats", postInstall(optionalAuth(ensureGET(dnsforward.HandleStats))))
-	http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(dnsforward.HandleStatsHistory))))
-	http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(dnsforward.HandleStatsReset))))
+	http.HandleFunc("/control/stats_top", postInstall(optionalAuth(ensureGET(handleStatsTop))))
+	http.HandleFunc("/control/stats", postInstall(optionalAuth(ensureGET(handleStats))))
+	http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(handleStatsHistory))))
+	http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(handleStatsReset))))
 	http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
 	http.HandleFunc("/control/filtering/enable", postInstall(optionalAuth(ensurePOST(handleFilteringEnable))))
 	http.HandleFunc("/control/filtering/disable", postInstall(optionalAuth(ensurePOST(handleFilteringDisable))))
diff --git a/dhcp.go b/dhcp.go
index a67b0ef6..f2f8201d 100644
--- a/dhcp.go
+++ b/dhcp.go
@@ -85,8 +85,8 @@ func handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
 			// this interface doesn't support broadcast, skip it
 			continue
 		}
-		addrs, err := iface.Addrs()
-		if err != nil {
+		addrs, e := iface.Addrs()
+		if e != nil {
 			httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err)
 			return
 		}
diff --git a/dns.go b/dns.go
index 12c71def..3e800892 100644
--- a/dns.go
+++ b/dns.go
@@ -3,6 +3,7 @@ package main
 import (
 	"fmt"
 	"net"
+	"os"
 
 	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
 	"github.com/AdguardTeam/AdGuardHome/dnsforward"
@@ -11,10 +12,22 @@ import (
 	"github.com/joomcode/errorx"
 )
 
-var dnsServer = dnsforward.Server{}
+var dnsServer *dnsforward.Server
+
+// initDNSServer creates an instance of the dnsforward.Server
+// Please note that we must do it even if we don't start it
+// so that we had access to the query log and the stats
+func initDNSServer(baseDir string) {
+	err := os.MkdirAll(baseDir, 0755)
+	if err != nil {
+		log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err)
+	}
+
+	dnsServer = dnsforward.NewServer(baseDir)
+}
 
 func isRunning() bool {
-	return dnsServer.IsRunning()
+	return dnsServer != nil && dnsServer.IsRunning()
 }
 
 func generateServerConfig() dnsforward.ServerConfig {
diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go
index b2cf0556..e9d87a08 100644
--- a/dnsforward/dnsforward.go
+++ b/dnsforward/dnsforward.go
@@ -35,14 +35,24 @@ const (
 //
 // The zero Server is empty and ready for use.
 type Server struct {
-	dnsProxy *proxy.Proxy // DNS proxy instance
-
+	dnsProxy  *proxy.Proxy         // DNS proxy instance
 	dnsFilter *dnsfilter.Dnsfilter // DNS filter instance
+	queryLog  *queryLog            // Query log instance
+	stats     *stats               // General server statistics
 
 	sync.RWMutex
 	ServerConfig
 }
 
+// NewServer creates a new instance of the dnsforward.Server
+// baseDir is the base directory for query logs
+func NewServer(baseDir string) *Server {
+	return &Server{
+		queryLog: newQueryLog(baseDir),
+		stats:    newStats(),
+	}
+}
+
 // FilteringConfig represents the DNS filtering configuration of AdGuard Home
 type FilteringConfig struct {
 	ProtectionEnabled  bool     `yaml:"protection_enabled"`   // whether or not use any of dnsfilter features
@@ -111,15 +121,16 @@ func (s *Server) startInternal(config *ServerConfig) error {
 	}
 
 	log.Tracef("Loading stats from querylog")
-	err = fillStatsFromQueryLog()
+	err = s.queryLog.fillStatsFromQueryLog(s.stats)
 	if err != nil {
 		return errorx.Decorate(err, "failed to load stats from querylog")
 	}
 
+	// TODO: Start starts rotators, stop stops rotators
 	once.Do(func() {
-		go periodicQueryLogRotate()
-		go periodicHourlyTopRotate()
-		go statsRotator()
+		go s.queryLog.periodicQueryLogRotate()
+		go s.queryLog.runningTop.periodicHourlyTopRotate()
+		go s.stats.statsRotator()
 	})
 
 	proxyConfig := proxy.Config{
@@ -187,17 +198,7 @@ func (s *Server) stopInternal() error {
 	}
 
 	// flush remainder to file
-	logBufferLock.Lock()
-	flushBuffer := logBuffer
-	logBuffer = nil
-	logBufferLock.Unlock()
-	err := flushToFile(flushBuffer)
-	if err != nil {
-		log.Printf("Saving querylog to file failed: %s", err)
-		return err
-	}
-
-	return nil
+	return s.queryLog.clearLogBuffer()
 }
 
 // IsRunning returns true if the DNS server is running
@@ -229,6 +230,36 @@ func (s *Server) Reconfigure(config *ServerConfig) error {
 	return nil
 }
 
+// GetQueryLog returns a map with the current query log ready to be converted to a JSON
+func (s *Server) GetQueryLog() []map[string]interface{} {
+	return s.queryLog.getQueryLog()
+}
+
+// GetStatsTop returns the current stop stats
+func (s *Server) GetStatsTop() *StatsTop {
+	return s.queryLog.runningTop.getStatsTop()
+}
+
+// ResetStats purges current server stats
+func (s *Server) ResetStats() {
+	// TODO: Locks?
+	s.stats.purgeStats()
+}
+
+// GetAggregatedStats returns aggregated stats data for the 24 hours
+func (s *Server) GetAggregatedStats() map[string]interface{} {
+	return s.stats.getAggregatedStats()
+}
+
+// GetStatsHistory gets stats history aggregated by the specified time unit
+// timeUnit is either time.Second, time.Minute, time.Hour, or 24*time.Hour
+// start is start of the time range
+// end is end of the time range
+// returns nil if time unit is not supported
+func (s *Server) GetStatsHistory(timeUnit time.Duration, startTime time.Time, endTime time.Time) (map[string]interface{}, error) {
+	return s.stats.getStatsHistory(timeUnit, startTime, endTime)
+}
+
 // handleDNSRequest filters the incoming DNS requests and writes them to the query log
 func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
 	start := time.Now()
@@ -261,7 +292,10 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
 		if d.Upstream != nil {
 			upstreamAddr = d.Upstream.Address()
 		}
-		logRequest(msg, d.Res, res, elapsed, d.Addr, upstreamAddr)
+		entry := s.queryLog.logRequest(msg, d.Res, res, elapsed, d.Addr, upstreamAddr)
+		if entry != nil {
+			s.stats.incrementCounters(entry)
+		}
 	}
 
 	return nil
diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go
index 0edde88b..9553b9ed 100644
--- a/dnsforward/dnsforward_test.go
+++ b/dnsforward/dnsforward_test.go
@@ -2,18 +2,19 @@ package dnsforward
 
 import (
 	"net"
+	"os"
 	"testing"
 	"time"
 
-	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
+	"github.com/stretchr/testify/assert"
 
+	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
 	"github.com/miekg/dns"
 )
 
 func TestServer(t *testing.T) {
-	s := Server{}
-	s.UDPListenAddr = &net.UDPAddr{Port: 0}
-	s.TCPListenAddr = &net.TCPAddr{Port: 0}
+	s := createTestServer(t)
+	defer removeDataDir(t)
 	err := s.Start(nil)
 	if err != nil {
 		t.Fatalf("Failed to start server: %s", err)
@@ -29,6 +30,14 @@ func TestServer(t *testing.T) {
 	}
 	assertResponse(t, reply)
 
+	// check query log and stats
+	log := s.GetQueryLog()
+	assert.Equal(t, 1, len(log), "Log size")
+	stats := s.GetStatsTop()
+	assert.Equal(t, 1, len(stats.Domains), "Top domains length")
+	assert.Equal(t, 0, len(stats.Blocked), "Top blocked length")
+	assert.Equal(t, 1, len(stats.Clients), "Top clients length")
+
 	// message over TCP
 	req = createTestMessage()
 	addr = s.dnsProxy.Addr("tcp")
@@ -39,6 +48,15 @@ func TestServer(t *testing.T) {
 	}
 	assertResponse(t, reply)
 
+	// check query log and stats again
+	log = s.GetQueryLog()
+	assert.Equal(t, 2, len(log), "Log size")
+	stats = s.GetStatsTop()
+	// Length did not change as we queried the same domain
+	assert.Equal(t, 1, len(stats.Domains), "Top domains length")
+	assert.Equal(t, 0, len(stats.Blocked), "Top blocked length")
+	assert.Equal(t, 1, len(stats.Clients), "Top clients length")
+
 	err = s.Stop()
 	if err != nil {
 		t.Fatalf("DNS server failed to stop: %s", err)
@@ -46,9 +64,8 @@ func TestServer(t *testing.T) {
 }
 
 func TestInvalidRequest(t *testing.T) {
-	s := Server{}
-	s.UDPListenAddr = &net.UDPAddr{Port: 0}
-	s.TCPListenAddr = &net.TCPAddr{Port: 0}
+	s := createTestServer(t)
+	defer removeDataDir(t)
 	err := s.Start(nil)
 	if err != nil {
 		t.Fatalf("Failed to start server: %s", err)
@@ -67,6 +84,15 @@ func TestInvalidRequest(t *testing.T) {
 		t.Fatalf("got a response to an invalid query")
 	}
 
+	// check query log and stats
+	// invalid requests aren't written to the query log
+	log := s.GetQueryLog()
+	assert.Equal(t, 0, len(log), "Log size")
+	stats := s.GetStatsTop()
+	assert.Equal(t, 0, len(stats.Domains), "Top domains length")
+	assert.Equal(t, 0, len(stats.Blocked), "Top blocked length")
+	assert.Equal(t, 0, len(stats.Clients), "Top clients length")
+
 	err = s.Stop()
 	if err != nil {
 		t.Fatalf("DNS server failed to stop: %s", err)
@@ -74,7 +100,8 @@ func TestInvalidRequest(t *testing.T) {
 }
 
 func TestBlockedRequest(t *testing.T) {
-	s := createTestServer()
+	s := createTestServer(t)
+	defer removeDataDir(t)
 	err := s.Start(nil)
 	if err != nil {
 		t.Fatalf("Failed to start server: %s", err)
@@ -99,6 +126,14 @@ func TestBlockedRequest(t *testing.T) {
 		t.Fatalf("Wrong response: %s", reply.String())
 	}
 
+	// check query log and stats
+	log := s.GetQueryLog()
+	assert.Equal(t, 1, len(log), "Log size")
+	stats := s.GetStatsTop()
+	assert.Equal(t, 1, len(stats.Domains), "Top domains length")
+	assert.Equal(t, 1, len(stats.Blocked), "Top blocked length")
+	assert.Equal(t, 1, len(stats.Clients), "Top clients length")
+
 	err = s.Stop()
 	if err != nil {
 		t.Fatalf("DNS server failed to stop: %s", err)
@@ -106,7 +141,8 @@ func TestBlockedRequest(t *testing.T) {
 }
 
 func TestBlockedByHosts(t *testing.T) {
-	s := createTestServer()
+	s := createTestServer(t)
+	defer removeDataDir(t)
 	err := s.Start(nil)
 	if err != nil {
 		t.Fatalf("Failed to start server: %s", err)
@@ -138,6 +174,14 @@ func TestBlockedByHosts(t *testing.T) {
 		t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
 	}
 
+	// check query log and stats
+	log := s.GetQueryLog()
+	assert.Equal(t, 1, len(log), "Log size")
+	stats := s.GetStatsTop()
+	assert.Equal(t, 1, len(stats.Domains), "Top domains length")
+	assert.Equal(t, 1, len(stats.Blocked), "Top blocked length")
+	assert.Equal(t, 1, len(stats.Clients), "Top clients length")
+
 	err = s.Stop()
 	if err != nil {
 		t.Fatalf("DNS server failed to stop: %s", err)
@@ -145,7 +189,8 @@ func TestBlockedByHosts(t *testing.T) {
 }
 
 func TestBlockedBySafeBrowsing(t *testing.T) {
-	s := createTestServer()
+	s := createTestServer(t)
+	defer removeDataDir(t)
 	err := s.Start(nil)
 	if err != nil {
 		t.Fatalf("Failed to start server: %s", err)
@@ -188,16 +233,25 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
 		t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
 	}
 
+	// check query log and stats
+	log := s.GetQueryLog()
+	assert.Equal(t, 1, len(log), "Log size")
+	stats := s.GetStatsTop()
+	assert.Equal(t, 1, len(stats.Domains), "Top domains length")
+	assert.Equal(t, 1, len(stats.Blocked), "Top blocked length")
+	assert.Equal(t, 1, len(stats.Clients), "Top clients length")
+
 	err = s.Stop()
 	if err != nil {
 		t.Fatalf("DNS server failed to stop: %s", err)
 	}
 }
 
-func createTestServer() *Server {
-	s := Server{}
+func createTestServer(t *testing.T) *Server {
+	s := NewServer(createDataDir(t))
 	s.UDPListenAddr = &net.UDPAddr{Port: 0}
 	s.TCPListenAddr = &net.TCPAddr{Port: 0}
+	s.QueryLogEnabled = true
 	s.FilteringConfig.FilteringEnabled = true
 	s.FilteringConfig.ProtectionEnabled = true
 	s.FilteringConfig.SafeBrowsingEnabled = true
@@ -209,7 +263,24 @@ func createTestServer() *Server {
 	}
 	filter := dnsfilter.Filter{ID: 1, Rules: rules}
 	s.Filters = append(s.Filters, filter)
-	return &s
+	return s
+}
+
+func createDataDir(t *testing.T) string {
+	dir := "testData"
+	err := os.MkdirAll(dir, 0755)
+	if err != nil {
+		t.Fatalf("Cannot create %s: %s", dir, err)
+	}
+	return dir
+}
+
+func removeDataDir(t *testing.T) {
+	dir := "testData"
+	err := os.RemoveAll(dir)
+	if err != nil {
+		t.Fatalf("Cannot remove %s: %s", dir, err)
+	}
 }
 
 func createTestMessage() *dns.Msg {
diff --git a/dnsforward/querylog.go b/dnsforward/querylog.go
index fc51d165..52fa115c 100644
--- a/dnsforward/querylog.go
+++ b/dnsforward/querylog.go
@@ -1,10 +1,9 @@
 package dnsforward
 
 import (
-	"encoding/json"
 	"fmt"
 	"net"
-	"net/http"
+	"path/filepath"
 	"strconv"
 	"strings"
 	"sync"
@@ -24,13 +23,27 @@ const (
 	queryLogTopSize        = 500             // Keep in memory only top N values
 )
 
-var (
+// queryLog is a structure that writes and reads the DNS query log
+type queryLog struct {
+	logFile    string  // path to the log file
+	runningTop *dayTop // current top charts
+
 	logBufferLock sync.RWMutex
 	logBuffer     []*logEntry
 
 	queryLogCache []*logEntry
 	queryLogLock  sync.RWMutex
-)
+}
+
+// newQueryLog creates a new instance of the query log
+func newQueryLog(baseDir string) *queryLog {
+	l := &queryLog{
+		logFile:    filepath.Join(baseDir, queryLogFileName),
+		runningTop: &dayTop{},
+	}
+	l.runningTop.init()
+	return l
+}
 
 type logEntry struct {
 	Question []byte
@@ -42,7 +55,7 @@ type logEntry struct {
 	Upstream string `json:",omitempty"` // if empty, means it was cached
 }
 
-func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, addr net.Addr, upstream string) {
+func (l *queryLog) logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, addr net.Addr, upstream string) *logEntry {
 	var q []byte
 	var a []byte
 	var err error
@@ -52,7 +65,7 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el
 		q, err = question.Pack()
 		if err != nil {
 			log.Printf("failed to pack question for querylog: %s", err)
-			return
+			return nil
 		}
 	}
 
@@ -60,7 +73,7 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el
 		a, err = answer.Pack()
 		if err != nil {
 			log.Printf("failed to pack answer for querylog: %s", err)
-			return
+			return nil
 		}
 	}
 
@@ -80,49 +93,49 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el
 	}
 	var flushBuffer []*logEntry
 
-	logBufferLock.Lock()
-	logBuffer = append(logBuffer, &entry)
-	if len(logBuffer) >= logBufferCap {
-		flushBuffer = logBuffer
-		logBuffer = nil
+	l.logBufferLock.Lock()
+	l.logBuffer = append(l.logBuffer, &entry)
+	if len(l.logBuffer) >= logBufferCap {
+		flushBuffer = l.logBuffer
+		l.logBuffer = nil
 	}
-	logBufferLock.Unlock()
-	queryLogLock.Lock()
-	queryLogCache = append(queryLogCache, &entry)
-	if len(queryLogCache) > queryLogSize {
-		toremove := len(queryLogCache) - queryLogSize
-		queryLogCache = queryLogCache[toremove:]
+	l.logBufferLock.Unlock()
+	l.queryLogLock.Lock()
+	l.queryLogCache = append(l.queryLogCache, &entry)
+	if len(l.queryLogCache) > queryLogSize {
+		toremove := len(l.queryLogCache) - queryLogSize
+		l.queryLogCache = l.queryLogCache[toremove:]
 	}
-	queryLogLock.Unlock()
+	l.queryLogLock.Unlock()
 
 	// add it to running top
-	err = runningTop.addEntry(&entry, question, now)
+	err = l.runningTop.addEntry(&entry, question, now)
 	if err != nil {
 		log.Printf("Failed to add entry to running top: %s", err)
 		// don't do failure, just log
 	}
 
-	incrementCounters(&entry)
-
 	// if buffer needs to be flushed to disk, do it now
 	if len(flushBuffer) > 0 {
 		// write to file
 		// do it in separate goroutine -- we are stalling DNS response this whole time
 		go func() {
-			err := flushToFile(flushBuffer)
+			err := l.flushToFile(flushBuffer)
 			if err != nil {
 				log.Printf("Failed to flush the query log: %s", err)
 			}
 		}()
 	}
+
+	return &entry
 }
 
-// HandleQueryLog handles query log web request
-func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
-	queryLogLock.RLock()
-	values := make([]*logEntry, len(queryLogCache))
-	copy(values, queryLogCache)
-	queryLogLock.RUnlock()
+// getQueryLogJson returns a map with the current query log ready to be converted to a JSON
+func (l *queryLog) getQueryLog() []map[string]interface{} {
+	l.queryLogLock.RLock()
+	values := make([]*logEntry, len(l.queryLogCache))
+	copy(values, l.queryLogCache)
+	l.queryLogLock.RUnlock()
 
 	// reverse it so that newest is first
 	for left, right := 0, len(values)-1; left < right; left, right = left+1, right-1 {
@@ -182,21 +195,7 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) {
 		data = append(data, jsonEntry)
 	}
 
-	jsonVal, err := json.Marshal(data)
-	if err != nil {
-		errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, http.StatusInternalServerError)
-		return
-	}
-
-	w.Header().Set("Content-Type", "application/json")
-	_, err = w.Write(jsonVal)
-	if err != nil {
-		errorText := fmt.Sprintf("Unable to write response json: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, http.StatusInternalServerError)
-	}
+	return data
 }
 
 func answerToMap(a *dns.Msg) []map[string]interface{} {
diff --git a/dnsforward/querylog_file.go b/dnsforward/querylog_file.go
index 9ab048db..8aff26ae 100644
--- a/dnsforward/querylog_file.go
+++ b/dnsforward/querylog_file.go
@@ -19,7 +19,23 @@ var (
 
 const enableGzip = false
 
-func flushToFile(buffer []*logEntry) error {
+// clearLogBuffer flushes the current buffer to file and resets the current buffer
+func (l *queryLog) clearLogBuffer() error {
+	// flush remainder to file
+	l.logBufferLock.Lock()
+	flushBuffer := l.logBuffer
+	l.logBuffer = nil
+	l.logBufferLock.Unlock()
+	err := l.flushToFile(flushBuffer)
+	if err != nil {
+		log.Printf("Saving querylog to file failed: %s", err)
+		return err
+	}
+	return nil
+}
+
+// flushToFile saves the specified log entries to the query log file
+func (l *queryLog) flushToFile(buffer []*logEntry) error {
 	if len(buffer) == 0 {
 		return nil
 	}
@@ -45,14 +61,14 @@ func flushToFile(buffer []*logEntry) error {
 	}
 
 	var zb bytes.Buffer
-	filename := queryLogFileName
+	filename := l.logFile
 
 	// gzip enabled?
 	if enableGzip {
 		filename += ".gz"
 
 		zw := gzip.NewWriter(&zb)
-		zw.Name = queryLogFileName
+		zw.Name = l.logFile
 		zw.ModTime = time.Now()
 
 		_, err = zw.Write(b.Bytes())
@@ -118,13 +134,13 @@ func checkBuffer(buffer []*logEntry, b bytes.Buffer) error {
 	return nil
 }
 
-func rotateQueryLog() error {
-	from := queryLogFileName
-	to := queryLogFileName + ".1"
+func (l *queryLog) rotateQueryLog() error {
+	from := l.logFile
+	to := l.logFile + ".1"
 
 	if enableGzip {
-		from = queryLogFileName + ".gz"
-		to = queryLogFileName + ".gz.1"
+		from = l.logFile + ".gz"
+		to = l.logFile + ".gz.1"
 	}
 
 	if _, err := os.Stat(from); os.IsNotExist(err) {
@@ -143,9 +159,9 @@ func rotateQueryLog() error {
 	return nil
 }
 
-func periodicQueryLogRotate() {
+func (l *queryLog) periodicQueryLogRotate() {
 	for range time.Tick(queryLogRotationPeriod) {
-		err := rotateQueryLog()
+		err := l.rotateQueryLog()
 		if err != nil {
 			log.Printf("Failed to rotate querylog: %s", err)
 			// do nothing, continue rotating
@@ -153,20 +169,20 @@ func periodicQueryLogRotate() {
 	}
 }
 
-func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error {
+func (l *queryLog) genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error {
 	now := time.Now()
 	// read from querylog files, try newest file first
 	var files []string
 
 	if enableGzip {
 		files = []string{
-			queryLogFileName + ".gz",
-			queryLogFileName + ".gz.1",
+			l.logFile + ".gz",
+			l.logFile + ".gz.1",
 		}
 	} else {
 		files = []string{
-			queryLogFileName,
-			queryLogFileName + ".1",
+			l.logFile,
+			l.logFile + ".1",
 		}
 	}
 
diff --git a/dnsforward/querylog_top.go b/dnsforward/querylog_top.go
index 5c08a223..25ad9791 100644
--- a/dnsforward/querylog_top.go
+++ b/dnsforward/querylog_top.go
@@ -1,14 +1,10 @@
 package dnsforward
 
 import (
-	"bytes"
 	"fmt"
-	"net/http"
 	"os"
 	"path"
 	"runtime"
-	"sort"
-	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -40,32 +36,30 @@ type dayTop struct {
 	loadedLock sync.Mutex
 }
 
-var runningTop dayTop
-
-func init() {
-	runningTop.hoursWriteLock()
+func (d *dayTop) init() {
+	d.hoursWriteLock()
 	for i := 0; i < 24; i++ {
 		hour := hourTop{}
 		hour.init()
-		runningTop.hours = append(runningTop.hours, &hour)
+		d.hours = append(d.hours, &hour)
 	}
-	runningTop.hoursWriteUnlock()
+	d.hoursWriteUnlock()
 }
 
-func rotateHourlyTop() {
+func (d *dayTop) rotateHourlyTop() {
 	log.Printf("Rotating hourly top")
 	hour := &hourTop{}
 	hour.init()
-	runningTop.hoursWriteLock()
-	runningTop.hours = append([]*hourTop{hour}, runningTop.hours...)
-	runningTop.hours = runningTop.hours[:24]
-	runningTop.hoursWriteUnlock()
+	d.hoursWriteLock()
+	d.hours = append([]*hourTop{hour}, d.hours...)
+	d.hours = d.hours[:24]
+	d.hoursWriteUnlock()
 }
 
-func periodicHourlyTopRotate() {
+func (d *dayTop) periodicHourlyTopRotate() {
 	t := time.Hour
 	for range time.Tick(t) {
-		rotateHourlyTop()
+		d.rotateHourlyTop()
 	}
 }
 
@@ -165,16 +159,16 @@ func (d *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
 	hostname := strings.ToLower(strings.TrimSuffix(q.Question[0].Name, "."))
 
 	// get value, if not set, crate one
-	runningTop.hoursReadLock()
-	defer runningTop.hoursReadUnlock()
-	err := runningTop.hours[hour].incrementDomains(hostname)
+	d.hoursReadLock()
+	defer d.hoursReadUnlock()
+	err := d.hours[hour].incrementDomains(hostname)
 	if err != nil {
 		log.Printf("Failed to increment value: %s", err)
 		return err
 	}
 
 	if entry.Result.IsFiltered {
-		err := runningTop.hours[hour].incrementBlocked(hostname)
+		err := d.hours[hour].incrementBlocked(hostname)
 		if err != nil {
 			log.Printf("Failed to increment value: %s", err)
 			return err
@@ -182,7 +176,7 @@ func (d *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
 	}
 
 	if len(entry.IP) > 0 {
-		err := runningTop.hours[hour].incrementClients(entry.IP)
+		err := d.hours[hour].incrementClients(entry.IP)
 		if err != nil {
 			log.Printf("Failed to increment value: %s", err)
 			return err
@@ -192,11 +186,11 @@ func (d *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error {
 	return nil
 }
 
-func fillStatsFromQueryLog() error {
+func (l *queryLog) fillStatsFromQueryLog(s *stats) error {
 	now := time.Now()
-	runningTop.loadedWriteLock()
-	defer runningTop.loadedWriteUnlock()
-	if runningTop.loaded {
+	l.runningTop.loadedWriteLock()
+	defer l.runningTop.loadedWriteUnlock()
+	if l.runningTop.loaded {
 		return nil
 	}
 	onEntry := func(entry *logEntry) error {
@@ -221,42 +215,49 @@ func fillStatsFromQueryLog() error {
 			return nil
 		}
 
-		err := runningTop.addEntry(entry, q, now)
+		err := l.runningTop.addEntry(entry, q, now)
 		if err != nil {
 			log.Printf("Failed to add entry to running top: %s", err)
 			return err
 		}
 
-		queryLogLock.Lock()
-		queryLogCache = append(queryLogCache, entry)
-		if len(queryLogCache) > queryLogSize {
-			toremove := len(queryLogCache) - queryLogSize
-			queryLogCache = queryLogCache[toremove:]
+		l.queryLogLock.Lock()
+		l.queryLogCache = append(l.queryLogCache, entry)
+		if len(l.queryLogCache) > queryLogSize {
+			toremove := len(l.queryLogCache) - queryLogSize
+			l.queryLogCache = l.queryLogCache[toremove:]
 		}
-		queryLogLock.Unlock()
-
-		incrementCounters(entry)
+		l.queryLogLock.Unlock()
 
+		s.incrementCounters(entry)
 		return nil
 	}
 
 	needMore := func() bool { return true }
-	err := genericLoader(onEntry, needMore, queryLogTimeLimit)
+	err := l.genericLoader(onEntry, needMore, queryLogTimeLimit)
 	if err != nil {
 		log.Printf("Failed to load entries from querylog: %s", err)
 		return err
 	}
 
-	runningTop.loaded = true
-
+	l.runningTop.loaded = true
 	return nil
 }
 
-// HandleStatsTop returns the current top stats
-func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
-	domains := map[string]int{}
-	blocked := map[string]int{}
-	clients := map[string]int{}
+// StatsTop represents top stat charts
+type StatsTop struct {
+	Domains map[string]int // Domains - top requested domains
+	Blocked map[string]int // Blocked - top blocked domains
+	Clients map[string]int // Clients - top DNS clients
+}
+
+// getStatsTop returns the current top stats
+func (d *dayTop) getStatsTop() *StatsTop {
+	s := &StatsTop{
+		Domains: map[string]int{},
+		Blocked: map[string]int{},
+		Clients: map[string]int{},
+	}
 
 	do := func(keys []interface{}, getter func(key string) (int, error), result map[string]int) {
 		for _, ikey := range keys {
@@ -273,79 +274,17 @@ func HandleStatsTop(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
-	runningTop.hoursReadLock()
+	d.hoursReadLock()
 	for hour := 0; hour < 24; hour++ {
-		runningTop.hours[hour].RLock()
-		do(runningTop.hours[hour].domains.Keys(), runningTop.hours[hour].lockedGetDomains, domains)
-		do(runningTop.hours[hour].blocked.Keys(), runningTop.hours[hour].lockedGetBlocked, blocked)
-		do(runningTop.hours[hour].clients.Keys(), runningTop.hours[hour].lockedGetClients, clients)
-		runningTop.hours[hour].RUnlock()
+		d.hours[hour].RLock()
+		do(d.hours[hour].domains.Keys(), d.hours[hour].lockedGetDomains, s.Domains)
+		do(d.hours[hour].blocked.Keys(), d.hours[hour].lockedGetBlocked, s.Blocked)
+		do(d.hours[hour].clients.Keys(), d.hours[hour].lockedGetClients, s.Clients)
+		d.hours[hour].RUnlock()
 	}
-	runningTop.hoursReadUnlock()
+	d.hoursReadUnlock()
 
-	// use manual json marshalling because we want maps to be sorted by value
-	json := bytes.Buffer{}
-	json.WriteString("{\n")
-
-	gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) {
-		json.WriteString("  ")
-		json.WriteString(fmt.Sprintf("%q", name))
-		json.WriteString(": {\n")
-		sorted := sortByValue(top)
-		// no more than 50 entries
-		if len(sorted) > 50 {
-			sorted = sorted[:50]
-		}
-		for i, key := range sorted {
-			json.WriteString("    ")
-			json.WriteString(fmt.Sprintf("%q", key))
-			json.WriteString(": ")
-			json.WriteString(strconv.Itoa(top[key]))
-			if i+1 != len(sorted) {
-				json.WriteByte(',')
-			}
-			json.WriteByte('\n')
-		}
-		json.WriteString("  }")
-		if addComma {
-			json.WriteByte(',')
-		}
-		json.WriteByte('\n')
-	}
-	gen(&json, "top_queried_domains", domains, true)
-	gen(&json, "top_blocked_domains", blocked, true)
-	gen(&json, "top_clients", clients, true)
-	json.WriteString("  \"stats_period\": \"24 hours\"\n")
-	json.WriteString("}\n")
-
-	w.Header().Set("Content-Type", "application/json")
-	_, err := w.Write(json.Bytes())
-	if err != nil {
-		errorText := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, http.StatusInternalServerError)
-	}
-}
-
-// helper function for querylog API
-func sortByValue(m map[string]int) []string {
-	type kv struct {
-		k string
-		v int
-	}
-	var ss []kv
-	for k, v := range m {
-		ss = append(ss, kv{k, v})
-	}
-	sort.Slice(ss, func(l, r int) bool {
-		return ss[l].v > ss[r].v
-	})
-
-	sorted := []string{}
-	for _, v := range ss {
-		sorted = append(sorted, v.k)
-	}
-	return sorted
+	return s
 }
 
 func (d *dayTop) hoursWriteLock()    { tracelock(); d.hoursLock.Lock() }
diff --git a/dnsforward/stats.go b/dnsforward/stats.go
index cbc25af7..705a250f 100644
--- a/dnsforward/stats.go
+++ b/dnsforward/stats.go
@@ -1,68 +1,76 @@
 package dnsforward
 
 import (
-	"encoding/json"
 	"fmt"
-	"net/http"
 	"sync"
 	"time"
 
 	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
-	"github.com/hmage/golibs/log"
 )
 
-var (
-	requests             = newDNSCounter("requests_total")
-	filtered             = newDNSCounter("filtered_total")
-	filteredLists        = newDNSCounter("filtered_lists_total")
-	filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total")
-	filteredParental     = newDNSCounter("filtered_parental_total")
-	whitelisted          = newDNSCounter("whitelisted_total")
-	safesearch           = newDNSCounter("safesearch_total")
-	errorsTotal          = newDNSCounter("errors_total")
-	elapsedTime          = newDNSHistogram("request_duration")
-)
-
-// entries for single time period (for example all per-second entries)
-type statsEntries map[string][statsHistoryElements]float64
-
 // how far back to keep the stats
 const statsHistoryElements = 60 + 1 // +1 for calculating delta
 
+// entries for single time period (for example all per-second entries)
+type statsEntries map[string][statsHistoryElements]float64
+
 // each periodic stat is a map of arrays
 type periodicStats struct {
-	Entries    statsEntries
+	entries    statsEntries
 	period     time.Duration // how long one entry lasts
-	LastRotate time.Time     // last time this data was rotated
+	lastRotate time.Time     // last time this data was rotated
 
 	sync.RWMutex
 }
 
+// stats is the DNS server historical statistics
 type stats struct {
-	PerSecond periodicStats
-	PerMinute periodicStats
-	PerHour   periodicStats
-	PerDay    periodicStats
+	perSecond periodicStats
+	perMinute periodicStats
+	perHour   periodicStats
+	perDay    periodicStats
+
+	requests             *counter   // total number of requests
+	filtered             *counter   // total number of filtered requests
+	filteredLists        *counter   // total number of requests blocked by filter lists
+	filteredSafebrowsing *counter   // total number of requests blocked by safebrowsing
+	filteredParental     *counter   // total number of requests blocked by the parental control
+	whitelisted          *counter   // total number of requests whitelisted by filter lists
+	safesearch           *counter   // total number of requests for which safe search rules were applied
+	errorsTotal          *counter   // total number of errors
+	elapsedTime          *histogram // requests duration histogram
 }
 
-// per-second/per-minute/per-hour/per-day stats
-var statistics stats
+// initializes an empty stats structure
+func newStats() *stats {
+	s := &stats{
+		requests:             newDNSCounter("requests_total"),
+		filtered:             newDNSCounter("filtered_total"),
+		filteredLists:        newDNSCounter("filtered_lists_total"),
+		filteredSafebrowsing: newDNSCounter("filtered_safebrowsing_total"),
+		filteredParental:     newDNSCounter("filtered_parental_total"),
+		whitelisted:          newDNSCounter("whitelisted_total"),
+		safesearch:           newDNSCounter("safesearch_total"),
+		errorsTotal:          newDNSCounter("errors_total"),
+		elapsedTime:          newDNSHistogram("request_duration"),
+	}
+
+	// Initializes empty per-sec/minute/hour/day stats
+	s.purgeStats()
+	return s
+}
 
 func initPeriodicStats(periodic *periodicStats, period time.Duration) {
-	periodic.Entries = statsEntries{}
-	periodic.LastRotate = time.Now()
+	periodic.entries = statsEntries{}
+	periodic.lastRotate = time.Now()
 	periodic.period = period
 }
 
-func init() {
-	purgeStats()
-}
-
-func purgeStats() {
-	initPeriodicStats(&statistics.PerSecond, time.Second)
-	initPeriodicStats(&statistics.PerMinute, time.Minute)
-	initPeriodicStats(&statistics.PerHour, time.Hour)
-	initPeriodicStats(&statistics.PerDay, time.Hour*24)
+func (s *stats) purgeStats() {
+	initPeriodicStats(&s.perSecond, time.Second)
+	initPeriodicStats(&s.perMinute, time.Minute)
+	initPeriodicStats(&s.perHour, time.Hour)
+	initPeriodicStats(&s.perDay, time.Hour*24)
 }
 
 func (p *periodicStats) Inc(name string, when time.Time) {
@@ -73,9 +81,9 @@ func (p *periodicStats) Inc(name string, when time.Time) {
 		return // outside of our timeframe
 	}
 	p.Lock()
-	currentValues := p.Entries[name]
+	currentValues := p.entries[name]
 	currentValues[elapsed]++
-	p.Entries[name] = currentValues
+	p.entries[name] = currentValues
 	p.Unlock()
 }
 
@@ -89,51 +97,51 @@ func (p *periodicStats) Observe(name string, when time.Time, value float64) {
 	p.Lock()
 	{
 		countname := name + "_count"
-		currentValues := p.Entries[countname]
+		currentValues := p.entries[countname]
 		v := currentValues[elapsed]
-		// log.Tracef("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1)
+		// log.Tracef("Will change p.entries[%s][%d] from %v to %v", countname, elapsed, value, value+1)
 		v++
 		currentValues[elapsed] = v
-		p.Entries[countname] = currentValues
+		p.entries[countname] = currentValues
 	}
 	{
 		totalname := name + "_sum"
-		currentValues := p.Entries[totalname]
+		currentValues := p.entries[totalname]
 		currentValues[elapsed] += value
-		p.Entries[totalname] = currentValues
+		p.entries[totalname] = currentValues
 	}
 	p.Unlock()
 }
 
 func (p *periodicStats) statsRotate(now time.Time) {
 	p.Lock()
-	rotations := int64(now.Sub(p.LastRotate) / p.period)
+	rotations := int64(now.Sub(p.lastRotate) / p.period)
 	if rotations > statsHistoryElements {
 		rotations = statsHistoryElements
 	}
 	// calculate how many times we should rotate
 	for r := int64(0); r < rotations; r++ {
-		for key, values := range p.Entries {
+		for key, values := range p.entries {
 			newValues := [statsHistoryElements]float64{}
 			for i := 1; i < len(values); i++ {
 				newValues[i] = values[i-1]
 			}
-			p.Entries[key] = newValues
+			p.entries[key] = newValues
 		}
 	}
 	if rotations > 0 {
-		p.LastRotate = now
+		p.lastRotate = now
 	}
 	p.Unlock()
 }
 
-func statsRotator() {
+func (s *stats) statsRotator() {
 	for range time.Tick(time.Second) {
 		now := time.Now()
-		statistics.PerSecond.statsRotate(now)
-		statistics.PerMinute.statsRotate(now)
-		statistics.PerHour.statsRotate(now)
-		statistics.PerDay.statsRotate(now)
+		s.perSecond.statsRotate(now)
+		s.perMinute.statsRotate(now)
+		s.perHour.statsRotate(now)
+		s.perDay.statsRotate(now)
 	}
 }
 
@@ -152,20 +160,16 @@ func newDNSCounter(name string) *counter {
 	}
 }
 
-func (c *counter) IncWithTime(when time.Time) {
-	statistics.PerSecond.Inc(c.name, when)
-	statistics.PerMinute.Inc(c.name, when)
-	statistics.PerHour.Inc(c.name, when)
-	statistics.PerDay.Inc(c.name, when)
+func (s *stats) incWithTime(c *counter, when time.Time) {
+	s.perSecond.Inc(c.name, when)
+	s.perMinute.Inc(c.name, when)
+	s.perHour.Inc(c.name, when)
+	s.perDay.Inc(c.name, when)
 	c.Lock()
 	c.value++
 	c.Unlock()
 }
 
-func (c *counter) Inc() {
-	c.IncWithTime(time.Now())
-}
-
 type histogram struct {
 	name  string // used as key in periodic stats
 	count int64
@@ -180,56 +184,52 @@ func newDNSHistogram(name string) *histogram {
 	}
 }
 
-func (h *histogram) ObserveWithTime(value float64, when time.Time) {
-	statistics.PerSecond.Observe(h.name, when, value)
-	statistics.PerMinute.Observe(h.name, when, value)
-	statistics.PerHour.Observe(h.name, when, value)
-	statistics.PerDay.Observe(h.name, when, value)
+func (s *stats) observeWithTime(h *histogram, value float64, when time.Time) {
+	s.perSecond.Observe(h.name, when, value)
+	s.perMinute.Observe(h.name, when, value)
+	s.perHour.Observe(h.name, when, value)
+	s.perDay.Observe(h.name, when, value)
 	h.Lock()
 	h.count++
 	h.total += value
 	h.Unlock()
 }
 
-func (h *histogram) Observe(value float64) {
-	h.ObserveWithTime(value, time.Now())
-}
-
 // -----
 // stats
 // -----
-func incrementCounters(entry *logEntry) {
-	requests.IncWithTime(entry.Time)
+func (s *stats) incrementCounters(entry *logEntry) {
+	s.incWithTime(s.requests, entry.Time)
 	if entry.Result.IsFiltered {
-		filtered.IncWithTime(entry.Time)
+		s.incWithTime(s.filtered, entry.Time)
 	}
 
 	switch entry.Result.Reason {
 	case dnsfilter.NotFilteredWhiteList:
-		whitelisted.IncWithTime(entry.Time)
+		s.incWithTime(s.whitelisted, entry.Time)
 	case dnsfilter.NotFilteredError:
-		errorsTotal.IncWithTime(entry.Time)
+		s.incWithTime(s.errorsTotal, entry.Time)
 	case dnsfilter.FilteredBlackList:
-		filteredLists.IncWithTime(entry.Time)
+		s.incWithTime(s.filteredLists, entry.Time)
 	case dnsfilter.FilteredSafeBrowsing:
-		filteredSafebrowsing.IncWithTime(entry.Time)
+		s.incWithTime(s.filteredSafebrowsing, entry.Time)
 	case dnsfilter.FilteredParental:
-		filteredParental.IncWithTime(entry.Time)
+		s.incWithTime(s.filteredParental, entry.Time)
 	case dnsfilter.FilteredInvalid:
 		// do nothing
 	case dnsfilter.FilteredSafeSearch:
-		safesearch.IncWithTime(entry.Time)
+		s.incWithTime(s.safesearch, entry.Time)
 	}
-	elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time)
+	s.observeWithTime(s.elapsedTime, entry.Elapsed.Seconds(), entry.Time)
 }
 
-// HandleStats returns aggregated stats data for the 24 hours
-func HandleStats(w http.ResponseWriter, r *http.Request) {
+// getAggregatedStats returns aggregated stats data for the 24 hours
+func (s *stats) getAggregatedStats() map[string]interface{} {
 	const numHours = 24
-	histrical := generateMapFromStats(&statistics.PerHour, 0, numHours)
+	historical := s.generateMapFromStats(&s.perHour, 0, numHours)
 	// sum them up
 	summed := map[string]interface{}{}
-	for key, values := range histrical {
+	for key, values := range historical {
 		summedValue := 0.0
 		floats, ok := values.([]float64)
 		if !ok {
@@ -249,33 +249,18 @@ func HandleStats(w http.ResponseWriter, r *http.Request) {
 	}
 
 	summed["stats_period"] = "24 hours"
-
-	json, err := json.Marshal(summed)
-	if err != nil {
-		errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, 500)
-		return
-	}
-	w.Header().Set("Content-Type", "application/json")
-	_, err = w.Write(json)
-	if err != nil {
-		errorText := fmt.Sprintf("Unable to write response json: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, 500)
-		return
-	}
+	return summed
 }
 
-func generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} {
+func (s *stats) generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} {
 	// clamp
 	start = clamp(start, 0, statsHistoryElements)
 	end = clamp(end, 0, statsHistoryElements)
 
 	avgProcessingTime := make([]float64, 0)
 
-	count := getReversedSlice(stats.Entries[elapsedTime.name+"_count"], start, end)
-	sum := getReversedSlice(stats.Entries[elapsedTime.name+"_sum"], start, end)
+	count := getReversedSlice(stats.entries[s.elapsedTime.name+"_count"], start, end)
+	sum := getReversedSlice(stats.entries[s.elapsedTime.name+"_sum"], start, end)
 	for i := 0; i < len(count); i++ {
 		var avg float64
 		if count[i] != 0 {
@@ -286,66 +271,48 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i
 	}
 
 	result := map[string]interface{}{
-		"dns_queries":           getReversedSlice(stats.Entries[requests.name], start, end),
-		"blocked_filtering":     getReversedSlice(stats.Entries[filtered.name], start, end),
-		"replaced_safebrowsing": getReversedSlice(stats.Entries[filteredSafebrowsing.name], start, end),
-		"replaced_safesearch":   getReversedSlice(stats.Entries[safesearch.name], start, end),
-		"replaced_parental":     getReversedSlice(stats.Entries[filteredParental.name], start, end),
+		"dns_queries":           getReversedSlice(stats.entries[s.requests.name], start, end),
+		"blocked_filtering":     getReversedSlice(stats.entries[s.filtered.name], start, end),
+		"replaced_safebrowsing": getReversedSlice(stats.entries[s.filteredSafebrowsing.name], start, end),
+		"replaced_safesearch":   getReversedSlice(stats.entries[s.safesearch.name], start, end),
+		"replaced_parental":     getReversedSlice(stats.entries[s.filteredParental.name], start, end),
 		"avg_processing_time":   avgProcessingTime,
 	}
 	return result
 }
 
-// HandleStatsHistory returns historical stats data for the 24 hours
-func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
-	// handle time unit and prepare our time window size
-	now := time.Now()
-	timeUnitString := r.URL.Query().Get("time_unit")
+// getStatsHistory gets stats history aggregated by the specified time unit
+// timeUnit is either time.Second, time.Minute, time.Hour, or 24*time.Hour
+// start is start of the time range
+// end is end of the time range
+// returns nil if time unit is not supported
+func (s *stats) getStatsHistory(timeUnit time.Duration, startTime time.Time, endTime time.Time) (map[string]interface{}, error) {
 	var stats *periodicStats
-	var timeUnit time.Duration
-	switch timeUnitString {
-	case "seconds":
-		timeUnit = time.Second
-		stats = &statistics.PerSecond
-	case "minutes":
-		timeUnit = time.Minute
-		stats = &statistics.PerMinute
-	case "hours":
-		timeUnit = time.Hour
-		stats = &statistics.PerHour
-	case "days":
-		timeUnit = time.Hour * 24
-		stats = &statistics.PerDay
-	default:
-		http.Error(w, "Must specify valid time_unit parameter", 400)
-		return
+
+	switch timeUnit {
+	case time.Second:
+		stats = &s.perSecond
+	case time.Minute:
+		stats = &s.perMinute
+	case time.Hour:
+		stats = &s.perHour
+	case 24 * time.Hour:
+		stats = &s.perDay
 	}
 
-	// parse start and end time
-	startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
-	if err != nil {
-		errorText := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, 400)
-		return
-	}
-	endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
-	if err != nil {
-		errorText := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, 400)
-		return
+	if stats == nil {
+		return nil, fmt.Errorf("unsupported time unit: %v", timeUnit)
 	}
 
+	now := time.Now()
+
 	// check if start and time times are within supported time range
 	timeRange := timeUnit * statsHistoryElements
 	if startTime.Add(timeRange).Before(now) {
-		http.Error(w, "start_time parameter is outside of supported range", http.StatusBadRequest)
-		return
+		return nil, fmt.Errorf("start_time parameter is outside of supported range: %s", startTime.String())
 	}
 	if endTime.Add(timeRange).Before(now) {
-		http.Error(w, "end_time parameter is outside of supported range", http.StatusBadRequest)
-		return
+		return nil, fmt.Errorf("end_time parameter is outside of supported range: %s", startTime.String())
 	}
 
 	// calculate start and end of our array
@@ -358,33 +325,7 @@ func HandleStatsHistory(w http.ResponseWriter, r *http.Request) {
 		start, end = end, start
 	}
 
-	data := generateMapFromStats(stats, start, end)
-	json, err := json.Marshal(data)
-	if err != nil {
-		errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, 500)
-		return
-	}
-	w.Header().Set("Content-Type", "application/json")
-	_, err = w.Write(json)
-	if err != nil {
-		errorText := fmt.Sprintf("Unable to write response json: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, 500)
-		return
-	}
-}
-
-// HandleStatsReset resets the stats caches
-func HandleStatsReset(w http.ResponseWriter, r *http.Request) {
-	purgeStats()
-	_, err := fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errorText := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errorText)
-		http.Error(w, errorText, http.StatusInternalServerError)
-	}
+	return s.generateMapFromStats(stats, start, end), nil
 }
 
 func clamp(value, low, high int) int {
diff --git a/filter.go b/filter.go
index 49b54f0e..46d050ab 100644
--- a/filter.go
+++ b/filter.go
@@ -251,5 +251,5 @@ func (filter *filter) load() error {
 
 // Path to the filter contents
 func (filter *filter) Path() string {
-	return filepath.Join(config.ourBinaryDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
+	return filepath.Join(config.ourWorkingDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
 }
diff --git a/go.mod b/go.mod
index b2b0a049..fd61e3e4 100644
--- a/go.mod
+++ b/go.mod
@@ -13,8 +13,10 @@ require (
 	github.com/kardianos/service v0.0.0-20181115005516-4c239ee84e7b
 	github.com/krolaw/dhcp4 v0.0.0-20180925202202-7cead472c414
 	github.com/miekg/dns v1.1.1
+	github.com/pkg/errors v0.8.0
 	github.com/shirou/gopsutil v2.18.10+incompatible
 	github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect
+	github.com/stretchr/testify v1.2.2
 	go.uber.org/goleak v0.10.0
 	golang.org/x/net v0.0.0-20181220203305-927f97764cc3
 	golang.org/x/sys v0.0.0-20181228144115-9a3f9b0469bb
diff --git a/helpers.go b/helpers.go
index 1bea694e..a0cf1fd7 100644
--- a/helpers.go
+++ b/helpers.go
@@ -100,7 +100,7 @@ func optionalAuthHandler(handler http.Handler) http.Handler {
 func detectFirstRun() bool {
 	configfile := config.ourConfigFilename
 	if !filepath.IsAbs(configfile) {
-		configfile = filepath.Join(config.ourBinaryDir, config.ourConfigFilename)
+		configfile = filepath.Join(config.ourWorkingDir, config.ourConfigFilename)
 	}
 	_, err := os.Stat(configfile)
 	if !os.IsNotExist(err) {
diff --git a/upgrade.go b/upgrade.go
index 02629797..0b3ddc5c 100644
--- a/upgrade.go
+++ b/upgrade.go
@@ -95,7 +95,7 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err
 func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
 	log.Printf("%s(): called", _Func())
 
-	dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt")
+	dnsFilterPath := filepath.Join(config.ourWorkingDir, "dnsfilter.txt")
 	if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
 		log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath)
 		err = os.Remove(dnsFilterPath)
@@ -116,7 +116,7 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
 func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
 	log.Printf("%s(): called", _Func())
 
-	coreFilePath := filepath.Join(config.ourBinaryDir, "Corefile")
+	coreFilePath := filepath.Join(config.ourWorkingDir, "Corefile")
 	if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
 		log.Printf("Deleting %s as we don't need it anymore", coreFilePath)
 		err = os.Remove(coreFilePath)