From 5077f1a2b31056e184156ce9602e2428c7676acb Mon Sep 17 00:00:00 2001
From: Andrey Meshkov <am@adguard.com>
Date: Mon, 23 Dec 2019 13:36:59 +0300
Subject: [PATCH] -(dnsforward): fix client settings for CNAME matching
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

✅ Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1274
---
 dnsforward/dnsforward.go      | 26 +++++++++++++++-----------
 dnsforward/dnsforward_test.go | 30 +++++++++++++++++++++++++++++-
 2 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go
index f87ee219..0b4e5081 100644
--- a/dnsforward/dnsforward.go
+++ b/dnsforward/dnsforward.go
@@ -593,22 +593,28 @@ func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dns
 	s.stats.Update(e)
 }
 
-// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
-func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) {
-	if !s.conf.ProtectionEnabled || s.dnsFilter == nil {
-		return &dnsfilter.Result{}, nil
-	}
-
+// getClientRequestFilteringSettings lookups client filtering settings
+// using the client's IP address from the DNSContext
+func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilter.RequestFilteringSettings {
 	setts := s.dnsFilter.GetConfig()
 	setts.FilteringEnabled = true
 	if s.conf.FilterHandler != nil {
 		clientAddr := ipFromAddr(d.Addr)
 		s.conf.FilterHandler(clientAddr, &setts)
 	}
+	return &setts
+}
 
+// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
+func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) {
+	if !s.conf.ProtectionEnabled || s.dnsFilter == nil {
+		return &dnsfilter.Result{}, nil
+	}
+
+	setts := s.getClientRequestFilteringSettings(d)
 	req := d.Req
 	host := strings.TrimSuffix(req.Question[0].Name, ".")
-	res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts)
+	res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
 	if err != nil {
 		// Return immediately if there's an error
 		return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
@@ -631,7 +637,6 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
 				a := s.genAAnswer(req, ip)
 				a.Hdr.Name = dns.Fqdn(name)
 				resp.Answer = append(resp.Answer, a)
-
 			} else if req.Question[0].Qtype == dns.TypeAAAA {
 				a := s.genAAAAAnswer(req, ip)
 				a.Hdr.Name = dns.Fqdn(name)
@@ -675,9 +680,8 @@ func (s *Server) filterResponse(d *proxy.DNSContext) (*dnsfilter.Result, error)
 			s.RUnlock()
 			continue
 		}
-		setts := dnsfilter.RequestFilteringSettings{}
-		setts.FilteringEnabled = true
-		res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts)
+		setts := s.getClientRequestFilteringSettings(d)
+		res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
 		s.RUnlock()
 
 		if err != nil {
diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go
index 88f7fb78..76f8f028 100644
--- a/dnsforward/dnsforward_test.go
+++ b/dnsforward/dnsforward_test.go
@@ -384,6 +384,30 @@ func TestBlockCNAME(t *testing.T) {
 	_ = s.Stop()
 }
 
+func TestClientRulesForCNAMEMatching(t *testing.T) {
+	s := createTestServer(t)
+	testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
+	s.conf.FilterHandler = func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) {
+		settings.FilteringEnabled = false
+	}
+	err := s.startWithUpstream(testUpstm)
+	assert.Nil(t, err)
+	addr := s.dnsProxy.Addr(proxy.ProtoUDP)
+
+	// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
+	// response is blocked
+	req := dns.Msg{}
+	req.Id = dns.Id()
+	req.Question = []dns.Question{
+		{Name: "badhost.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+	}
+	// However, in our case it should not be blocked
+	// as filtering is disabled on the client level
+	reply, err := dns.Exchange(&req, addr.String())
+	assert.Nil(t, err)
+	assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
+}
+
 func TestNullBlockedRequest(t *testing.T) {
 	s := createTestServer(t)
 	s.conf.FilteringConfig.BlockingMode = "null_ip"
@@ -563,7 +587,11 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
 }
 
 func createTestServer(t *testing.T) *Server {
-	rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1	host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n"
+	rules := `||nxdomain.example.org
+||null.example.org^
+127.0.0.1	host.example.org
+@@||whitelist.example.org^
+||127.0.0.255`
 	filters := map[int]string{}
 	filters[0] = rules
 	c := dnsfilter.Config{}