From 2307f55715c27f2affe5a975339cbea50377db72 Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Mon, 24 Jun 2019 19:00:03 +0300
Subject: [PATCH] * dnsfilter: use a single global context object

---
 dnsfilter/dnsfilter.go | 53 ++++++++++++++++++++++++------------------
 1 file changed, 30 insertions(+), 23 deletions(-)

diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go
index dcecbccc..ea514d00 100644
--- a/dnsfilter/dnsfilter.go
+++ b/dnsfilter/dnsfilter.go
@@ -128,14 +128,15 @@ const (
 	FilteredSafeSearch
 )
 
-// these variables need to survive coredns reload
-var (
+type dnsfContext struct {
 	stats             Stats
 	dialCache         gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
 	safebrowsingCache gcache.Cache
 	parentalCache     gcache.Cache
 	safeSearchCache   gcache.Cache
-)
+}
+
+var gctx dnsfContext // global dnsfilter context
 
 // Result holds state of hostname check
 type Result struct {
@@ -298,14 +299,10 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
 		defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host)
 	}
 
-	if safeSearchCache == nil {
-		safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
-	}
-
 	// Check cache. Return cached result if it was found
-	cachedValue, isFound, err := getCachedReason(safeSearchCache, host)
+	cachedValue, isFound, err := getCachedReason(gctx.safeSearchCache, host)
 	if isFound {
-		atomic.AddUint64(&stats.Safesearch.CacheHits, 1)
+		atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
 		log.Tracef("%s: found in SafeSearch cache", host)
 		return cachedValue, nil
 	}
@@ -322,7 +319,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
 	res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
 	if ip := net.ParseIP(safeHost); ip != nil {
 		res.IP = ip
-		err = safeSearchCache.Set(host, res)
+		err = gctx.safeSearchCache.Set(host, res)
 		if err != nil {
 			return Result{}, nil
 		}
@@ -349,7 +346,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
 	}
 
 	// Cache result
-	err = safeSearchCache.Set(host, res)
+	err = gctx.safeSearchCache.Set(host, res)
 	if err != nil {
 		return Result{}, nil
 	}
@@ -395,10 +392,7 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
 		}
 		return result, nil
 	}
-	if safebrowsingCache == nil {
-		safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
-	}
-	result, err := d.lookupCommon(host, &stats.Safebrowsing, safebrowsingCache, true, format, handleBody)
+	result, err := d.lookupCommon(host, &gctx.stats.Safebrowsing, gctx.safebrowsingCache, true, format, handleBody)
 	return result, err
 }
 
@@ -450,10 +444,7 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
 		}
 		return result, nil
 	}
-	if parentalCache == nil {
-		parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
-	}
-	result, err := d.lookupCommon(host, &stats.Parental, parentalCache, false, format, handleBody)
+	result, err := d.lookupCommon(host, &gctx.stats.Parental, gctx.parentalCache, false, format, handleBody)
 	return result, err
 }
 
@@ -620,7 +611,7 @@ func (d *Dnsfilter) shouldBeInDialCache(host string) bool {
 
 // Search for an IP address by host name
 func searchInDialCache(host string) string {
-	rawValue, err := dialCache.Get(host)
+	rawValue, err := gctx.dialCache.Get(host)
 	if err != nil {
 		return ""
 	}
@@ -632,7 +623,7 @@ func searchInDialCache(host string) string {
 
 // Add "hostname" -> "IP address" entry to cache
 func addToDialCache(host, ip string) {
-	err := dialCache.Set(host, ip)
+	err := gctx.dialCache.Set(host, ip)
 	if err != nil {
 		log.Debug("dialCache.Set: %s", err)
 	}
@@ -701,6 +692,23 @@ func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionTyp
 
 // New creates properly initialized DNS Filter that is ready to be used
 func New(c *Config, filters map[int]string) *Dnsfilter {
+
+	if c != nil {
+		// initialize objects only once
+		if c.SafeBrowsingEnabled && gctx.safebrowsingCache == nil {
+			gctx.safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
+		}
+		if c.SafeSearchEnabled && gctx.safeSearchCache == nil {
+			gctx.safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
+		}
+		if c.ParentalEnabled && gctx.parentalCache == nil {
+			gctx.parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
+		}
+		if len(c.ResolverAddress) != 0 && gctx.dialCache == nil {
+			gctx.dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
+		}
+	}
+
 	d := new(Dnsfilter)
 
 	// Customize the Transport to have larger connection pool,
@@ -714,7 +722,6 @@ func New(c *Config, filters map[int]string) *Dnsfilter {
 		ExpectContinueTimeout: 1 * time.Second,
 	}
 	if c != nil && len(c.ResolverAddress) != 0 {
-		dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
 		d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress)
 	}
 	d.client = http.Client{
@@ -790,5 +797,5 @@ func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
 
 // GetStats return dns filtering stats since startup
 func (d *Dnsfilter) GetStats() Stats {
-	return stats
+	return gctx.stats
 }