From 118b170210962da5d088d196feccc6e14ba5a132 Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Tue, 26 May 2020 11:42:42 +0300
Subject: [PATCH] + rewrites: support exceptions:

*.host -> IP
my.host -> my.host
*.my.host -> *.my.host

Requests for my.host and *.my.host will be passed to upstream servers,
 while all other requests for *.host will be answered with a rewritten IP
---
 dnsfilter/dnsfilter.go     |  7 +++++++
 dnsfilter/rewrites.go      | 10 +++++-----
 dnsfilter/rewrites_test.go | 40 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go
index 6a79b79a..bcd597d8 100644
--- a/dnsfilter/dnsfilter.go
+++ b/dnsfilter/dnsfilter.go
@@ -390,6 +390,7 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
 
 // Process rewrites table
 // . Find CNAME for a domain name (exact match or by wildcard)
+//  . if found and CNAME equals to domain name - this is an exception;  exit
 //  . if found, set domain name to canonical name
 //  . repeat for the new domain name (Note: we return only the last CNAME)
 // . Find A or AAAA record for a domain name (exact match or by wildcard)
@@ -409,6 +410,12 @@ func (d *Dnsfilter) processRewrites(host string) Result {
 	origHost := host
 	for len(rr) != 0 && rr[0].Type == dns.TypeCNAME {
 		log.Debug("Rewrite: CNAME for %s is %s", host, rr[0].Answer)
+
+		if host == rr[0].Answer { // "host == CNAME" is an exception
+			res.Reason = 0
+			return res
+		}
+
 		host = rr[0].Answer
 		_, ok := cnames[host]
 		if ok {
diff --git a/dnsfilter/rewrites.go b/dnsfilter/rewrites.go
index cbc02a16..166ff0bb 100644
--- a/dnsfilter/rewrites.go
+++ b/dnsfilter/rewrites.go
@@ -43,14 +43,14 @@ func (a rewritesArray) Len() int { return len(a) }
 func (a rewritesArray) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
 
 // Priority:
-//  . CNAME > A/AAAA;
-//  . exact > wildcard;
-//  . higher level wildcard > lower level wildcard
+//  . CNAME < A/AAAA;
+//  . exact < wildcard;
+//  . higher level wildcard < lower level wildcard
 func (a rewritesArray) Less(i, j int) bool {
 	if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME {
-		return false
-	} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
 		return true
+	} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
+		return false
 	}
 
 	if isWildcard(a[i].Domain) {
diff --git a/dnsfilter/rewrites_test.go b/dnsfilter/rewrites_test.go
index 6da3e0f9..aac9cedd 100644
--- a/dnsfilter/rewrites_test.go
+++ b/dnsfilter/rewrites_test.go
@@ -125,3 +125,43 @@ func TestRewritesLevels(t *testing.T) {
 	assert.Equal(t, 1, len(r.IPList))
 	assert.Equal(t, "3.3.3.3", r.IPList[0].String())
 }
+
+func TestRewritesException(t *testing.T) {
+	d := Dnsfilter{}
+	// wildcard; exception for a sub-domain
+	d.Rewrites = []RewriteEntry{
+		RewriteEntry{"*.host.com", "2.2.2.2", 0, nil},
+		RewriteEntry{"sub.host.com", "sub.host.com", 0, nil},
+	}
+	d.prepareRewrites()
+
+	// match sub-domain
+	r := d.processRewrites("my.host.com")
+	assert.Equal(t, ReasonRewrite, r.Reason)
+	assert.Equal(t, 1, len(r.IPList))
+	assert.Equal(t, "2.2.2.2", r.IPList[0].String())
+
+	// match sub-domain, but handle exception
+	r = d.processRewrites("sub.host.com")
+	assert.Equal(t, NotFilteredNotFound, r.Reason)
+}
+
+func TestRewritesExceptionWC(t *testing.T) {
+	d := Dnsfilter{}
+	// wildcard; exception for a sub-wildcard
+	d.Rewrites = []RewriteEntry{
+		RewriteEntry{"*.host.com", "2.2.2.2", 0, nil},
+		RewriteEntry{"*.sub.host.com", "*.sub.host.com", 0, nil},
+	}
+	d.prepareRewrites()
+
+	// match sub-domain
+	r := d.processRewrites("my.host.com")
+	assert.Equal(t, ReasonRewrite, r.Reason)
+	assert.Equal(t, 1, len(r.IPList))
+	assert.Equal(t, "2.2.2.2", r.IPList[0].String())
+
+	// match sub-domain, but handle exception
+	r = d.processRewrites("my.sub.host.com")
+	assert.Equal(t, NotFilteredNotFound, r.Reason)
+}