From 15d07a40eb83c021032bfca5bbd532f3b032f2c5 Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Thu, 25 Jul 2019 16:37:06 +0300
Subject: [PATCH] * refactor

---
 dnsfilter/dnsfilter.go      | 11 +-------
 dnsfilter/dnsfilter_test.go | 52 +++++++++++++++++++++----------------
 dnsforward/dnsforward.go    | 12 ++++++++-
 3 files changed, 42 insertions(+), 33 deletions(-)

diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go
index cc724ff7..47ce1f0e 100644
--- a/dnsfilter/dnsfilter.go
+++ b/dnsfilter/dnsfilter.go
@@ -206,7 +206,7 @@ func (r Reason) Matched() bool {
 }
 
 // CheckHost tries to match host against rules, then safebrowsing and parental if they are enabled
-func (d *Dnsfilter) CheckHost(host string, qtype uint16, clientAddr string) (Result, error) {
+func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFilteringSettings) (Result, error) {
 	// sometimes DNS clients will try to resolve ".", which is a request to get root servers
 	if host == "" {
 		return Result{Reason: NotFilteredNotFound}, nil
@@ -217,15 +217,6 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, clientAddr string) (Res
 		return Result{}, nil
 	}
 
-	var setts RequestFilteringSettings
-	setts.FilteringEnabled = true
-	setts.SafeSearchEnabled = d.SafeSearchEnabled
-	setts.SafeBrowsingEnabled = d.SafeBrowsingEnabled
-	setts.ParentalEnabled = d.ParentalEnabled
-	if d.FilterHandler != nil {
-		d.FilterHandler(clientAddr, &setts)
-	}
-
 	var result Result
 	var err error
 
diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go
index 7df5fe09..d5eabafe 100644
--- a/dnsfilter/dnsfilter_test.go
+++ b/dnsfilter/dnsfilter_test.go
@@ -16,6 +16,8 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
+var setts RequestFilteringSettings
+
 // HELPERS
 // SAFE BROWSING
 // SAFE SEARCH
@@ -46,10 +48,16 @@ func _Func() string {
 }
 
 func NewForTest(c *Config, filters map[int]string) *Dnsfilter {
+	setts = RequestFilteringSettings{}
+	setts.FilteringEnabled = true
 	if c != nil {
 		c.SafeBrowsingCacheSize = 1024
 		c.SafeSearchCacheSize = 1024
 		c.ParentalCacheSize = 1024
+
+		setts.SafeSearchEnabled = c.SafeSearchEnabled
+		setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
+		setts.ParentalEnabled = c.ParentalEnabled
 	}
 	d := New(c, filters)
 	purgeCaches()
@@ -58,7 +66,7 @@ func NewForTest(c *Config, filters map[int]string) *Dnsfilter {
 
 func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
 	t.Helper()
-	ret, err := d.CheckHost(hostname, dns.TypeA, "")
+	ret, err := d.CheckHost(hostname, dns.TypeA, &setts)
 	if err != nil {
 		t.Errorf("Error while matching host %s: %s", hostname, err)
 	}
@@ -69,7 +77,7 @@ func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) {
 
 func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype uint16) {
 	t.Helper()
-	ret, err := d.CheckHost(hostname, qtype, "")
+	ret, err := d.CheckHost(hostname, qtype, &setts)
 	if err != nil {
 		t.Errorf("Error while matching host %s: %s", hostname, err)
 	}
@@ -83,7 +91,7 @@ func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype
 
 func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) {
 	t.Helper()
-	ret, err := d.CheckHost(hostname, dns.TypeA, "")
+	ret, err := d.CheckHost(hostname, dns.TypeA, &setts)
 	if err != nil {
 		t.Errorf("Error while matching host %s: %s", hostname, err)
 	}
@@ -214,7 +222,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
 
 	// Check host for each domain
 	for _, host := range yandex {
-		result, err := d.CheckHost(host, dns.TypeA, "")
+		result, err := d.CheckHost(host, dns.TypeA, &setts)
 		if err != nil {
 			t.Errorf("SafeSearch doesn't work for yandex domain `%s` cause %s", host, err)
 		}
@@ -234,7 +242,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
 
 	// Check host for each domain
 	for _, host := range googleDomains {
-		result, err := d.CheckHost(host, dns.TypeA, "")
+		result, err := d.CheckHost(host, dns.TypeA, &setts)
 		if err != nil {
 			t.Errorf("SafeSearch doesn't work for %s cause %s", host, err)
 		}
@@ -254,7 +262,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
 	var err error
 
 	// Check host with disabled safesearch
-	result, err = d.CheckHost(domain, dns.TypeA, "")
+	result, err = d.CheckHost(domain, dns.TypeA, &setts)
 	if err != nil {
 		t.Fatalf("Cannot check host due to %s", err)
 	}
@@ -265,7 +273,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
 	d = NewForTest(&Config{SafeSearchEnabled: true}, nil)
 	defer d.Destroy()
 
-	result, err = d.CheckHost(domain, dns.TypeA, "")
+	result, err = d.CheckHost(domain, dns.TypeA, &setts)
 	if err != nil {
 		t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err)
 	}
@@ -291,7 +299,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
 	d := NewForTest(nil, nil)
 	defer d.Destroy()
 	domain := "www.google.ru"
-	result, err := d.CheckHost(domain, dns.TypeA, "")
+	result, err := d.CheckHost(domain, dns.TypeA, &setts)
 	if err != nil {
 		t.Fatalf("Cannot check host due to %s", err)
 	}
@@ -322,7 +330,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
 		}
 	}
 
-	result, err = d.CheckHost(domain, dns.TypeA, "")
+	result, err = d.CheckHost(domain, dns.TypeA, &setts)
 	if err != nil {
 		t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err)
 	}
@@ -435,7 +443,7 @@ func TestMatching(t *testing.T) {
 			d := NewForTest(nil, filters)
 			defer d.Destroy()
 
-			ret, err := d.CheckHost(test.hostname, dns.TypeA, "")
+			ret, err := d.CheckHost(test.hostname, dns.TypeA, &setts)
 			if err != nil {
 				t.Errorf("Error while matching host %s: %s", test.hostname, err)
 			}
@@ -451,7 +459,7 @@ func TestMatching(t *testing.T) {
 
 // CLIENT SETTINGS
 
-func applyClientSettings(clientAddr string, setts *RequestFilteringSettings) {
+func applyClientSettings(setts *RequestFilteringSettings) {
 	setts.FilteringEnabled = false
 	setts.ParentalEnabled = false
 	setts.SafeBrowsingEnabled = true
@@ -476,50 +484,50 @@ func TestClientSettings(t *testing.T) {
 	// no client settings:
 
 	// blocked by filters
-	r, _ = d.CheckHost("example.org", dns.TypeA, "1.1.1.1")
+	r, _ = d.CheckHost("example.org", dns.TypeA, &setts)
 	if !r.IsFiltered || r.Reason != FilteredBlackList {
 		t.Fatalf("CheckHost FilteredBlackList")
 	}
 
 	// blocked by parental
-	r, _ = d.CheckHost("pornhub.com", dns.TypeA, "1.1.1.1")
+	r, _ = d.CheckHost("pornhub.com", dns.TypeA, &setts)
 	if !r.IsFiltered || r.Reason != FilteredParental {
 		t.Fatalf("CheckHost FilteredParental")
 	}
 
 	// safesearch is disabled
-	r, _ = d.CheckHost("wmconvirus.narod.ru", dns.TypeA, "1.1.1.1")
+	r, _ = d.CheckHost("wmconvirus.narod.ru", dns.TypeA, &setts)
 	if r.IsFiltered {
 		t.Fatalf("CheckHost safesearch")
 	}
 
 	// not blocked
-	r, _ = d.CheckHost("facebook.com", dns.TypeA, "1.1.1.1")
+	r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts)
 	assert.True(t, !r.IsFiltered)
 
 	// override client settings:
-	d.FilterHandler = applyClientSettings
+	applyClientSettings(&setts)
 
 	// override filtering settings
-	r, _ = d.CheckHost("example.org", dns.TypeA, "1.1.1.1")
+	r, _ = d.CheckHost("example.org", dns.TypeA, &setts)
 	if r.IsFiltered {
 		t.Fatalf("CheckHost")
 	}
 
 	// override parental settings (force disable parental)
-	r, _ = d.CheckHost("pornhub.com", dns.TypeA, "1.1.1.1")
+	r, _ = d.CheckHost("pornhub.com", dns.TypeA, &setts)
 	if r.IsFiltered {
 		t.Fatalf("CheckHost")
 	}
 
 	// override safesearch settings (force enable safesearch)
-	r, _ = d.CheckHost("wmconvirus.narod.ru", dns.TypeA, "1.1.1.1")
+	r, _ = d.CheckHost("wmconvirus.narod.ru", dns.TypeA, &setts)
 	if !r.IsFiltered || r.Reason != FilteredSafeBrowsing {
 		t.Fatalf("CheckHost FilteredSafeBrowsing")
 	}
 
 	// blocked by additional rules
-	r, _ = d.CheckHost("facebook.com", dns.TypeA, "1.1.1.1")
+	r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts)
 	assert.True(t, r.IsFiltered && r.Reason == FilteredBlockedService)
 }
 
@@ -530,7 +538,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
 	defer d.Destroy()
 	for n := 0; n < b.N; n++ {
 		hostname := "wmconvirus.narod.ru"
-		ret, err := d.CheckHost(hostname, dns.TypeA, "")
+		ret, err := d.CheckHost(hostname, dns.TypeA, &setts)
 		if err != nil {
 			b.Errorf("Error while matching host %s: %s", hostname, err)
 		}
@@ -546,7 +554,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
 	b.RunParallel(func(pb *testing.PB) {
 		for pb.Next() {
 			hostname := "wmconvirus.narod.ru"
-			ret, err := d.CheckHost(hostname, dns.TypeA, "")
+			ret, err := d.CheckHost(hostname, dns.TypeA, &setts)
 			if err != nil {
 				b.Errorf("Error while matching host %s: %s", hostname, err)
 			}
diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go
index d09bf7be..a5aa81fd 100644
--- a/dnsforward/dnsforward.go
+++ b/dnsforward/dnsforward.go
@@ -533,7 +533,17 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
 	if d.Addr != nil {
 		clientAddr, _, _ = net.SplitHostPort(d.Addr.String())
 	}
-	res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, clientAddr)
+
+	var setts dnsfilter.RequestFilteringSettings
+	setts.FilteringEnabled = true
+	setts.SafeSearchEnabled = s.conf.SafeSearchEnabled
+	setts.SafeBrowsingEnabled = s.conf.SafeBrowsingEnabled
+	setts.ParentalEnabled = s.conf.ParentalEnabled
+	if s.conf.FilterHandler != nil {
+		s.conf.FilterHandler(clientAddr, &setts)
+	}
+
+	res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts)
 	if err != nil {
 		// Return immediately if there's an error
 		return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)