From a12f01793ff97e0ea53bc6f751bee758d1df6bb2 Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Tue, 21 May 2019 14:53:13 +0300
Subject: [PATCH] + clients: find DNS client's hostname by IP using rDNS

---
 dns.go                   | 130 +++++++++++++++++++++++++++++++++++++++
 dns_test.go              |   9 +++
 dnsforward/dnsforward.go |   5 ++
 3 files changed, 144 insertions(+)
 create mode 100644 dns_test.go

diff --git a/dns.go b/dns.go
index e637fa4a..4eebc110 100644
--- a/dns.go
+++ b/dns.go
@@ -4,16 +4,36 @@ import (
 	"fmt"
 	"net"
 	"os"
+	"strings"
+	"sync"
+	"time"
 
 	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
 	"github.com/AdguardTeam/AdGuardHome/dnsforward"
 	"github.com/AdguardTeam/dnsproxy/proxy"
+	"github.com/AdguardTeam/dnsproxy/upstream"
 	"github.com/AdguardTeam/golibs/log"
 	"github.com/joomcode/errorx"
+	"github.com/miekg/dns"
 )
 
 var dnsServer *dnsforward.Server
 
+const (
+	rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
+)
+
+type dnsContext struct {
+	rdnsChannel chan string // pass data from DNS request handling thread to rDNS thread
+	// contains IP addresses of clients to be resolved by rDNS
+	// if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP
+	rdnsIP   map[string]bool
+	rdnsLock sync.Mutex        // synchronize access to rdnsIP
+	upstream upstream.Upstream // Upstream object for our own DNS server
+}
+
+var dnsctx dnsContext
+
 // initDNSServer creates an instance of the dnsforward.Server
 // Please note that we must do it even if we don't start it
 // so that we had access to the query log and the stats
@@ -24,12 +44,121 @@ func initDNSServer(baseDir string) {
 	}
 
 	dnsServer = dnsforward.NewServer(baseDir)
+
+	bindhost := config.DNS.BindHost
+	if config.DNS.BindHost == "0.0.0.0" {
+		bindhost = "127.0.0.1"
+	}
+	resolverAddress := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
+	opts := upstream.Options{
+		Timeout: rdnsTimeout,
+	}
+	dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
+	if err != nil {
+		log.Error("upstream.AddressToUpstream: %s", err)
+		return
+	}
+	dnsctx.rdnsChannel = make(chan string, 256)
+	go asyncRDNSLoop()
 }
 
 func isRunning() bool {
 	return dnsServer != nil && dnsServer.IsRunning()
 }
 
+func beginAsyncRDNS(ip string) {
+	log.Tracef("Adding %s for rDNS resolve", ip)
+	select {
+	case dnsctx.rdnsChannel <- ip:
+		//
+	default:
+		log.Tracef("rDNS queue is full")
+	}
+}
+
+// Use rDNS to get hostname by IP address
+func resolveRDNS(ip string) string {
+	log.Tracef("Resolving host for %s", ip)
+
+	req := dns.Msg{}
+	req.Id = dns.Id()
+	req.RecursionDesired = true
+	req.Question = []dns.Question{
+		{
+			Qtype:  dns.TypePTR,
+			Qclass: dns.ClassINET,
+		},
+	}
+	var err error
+	req.Question[0].Name, err = dns.ReverseAddr(ip)
+	if err != nil {
+		log.Error("dns.ReverseAddr: %s", err)
+		return ""
+	}
+
+	resp, err := dnsctx.upstream.Exchange(&req)
+	if err != nil {
+		log.Error("upstream.Exchange: %s", err)
+		return ""
+	}
+	if len(resp.Answer) != 1 {
+		log.Error("len(resp.Answer) != 1")
+		return ""
+	}
+	ptr, ok := resp.Answer[0].(*dns.PTR)
+	if !ok {
+		log.Error("not a dns.PTR response")
+		return ""
+	}
+
+	log.Tracef("PTR response: %s", ptr.String())
+	if strings.HasSuffix(ptr.Ptr, ".") {
+		ptr.Ptr = ptr.Ptr[:len(ptr.Ptr)-1]
+	}
+
+	return ptr.Ptr
+}
+
+// Wait for a signal and then synchronously resolve hostname by IP address
+// Add the hostname:IP pair to "Clients" array
+func asyncRDNSLoop() {
+	for {
+		var ip string
+		ip = <-dnsctx.rdnsChannel
+
+		host := resolveRDNS(ip)
+		if len(host) == 0 {
+			continue
+		}
+
+		dnsctx.rdnsLock.Lock()
+		delete(dnsctx.rdnsIP, ip)
+		dnsctx.rdnsLock.Unlock()
+
+		clientAddHost(ip, host, ClientSourceRDNS)
+	}
+}
+
+func onDNSRequest(d *proxy.DNSContext) {
+	if d.Req.Question[0].Qtype == dns.TypeA {
+		ip, _, _ := net.SplitHostPort(d.Addr.String())
+		if clientExists(ip) {
+			return
+		}
+
+		// add IP to rdnsIP, if not exists
+		dnsctx.rdnsLock.Lock()
+		defer dnsctx.rdnsLock.Unlock()
+		_, ok := dnsctx.rdnsIP[ip]
+		if ok {
+			return
+		}
+		dnsctx.rdnsIP[ip] = true
+
+		beginAsyncRDNS(ip)
+	}
+}
+
 func generateServerConfig() dnsforward.ServerConfig {
 	filters := []dnsfilter.Filter{}
 	userFilter := userFilter()
@@ -71,6 +200,7 @@ func generateServerConfig() dnsforward.ServerConfig {
 	newconfig.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams
 	newconfig.AllServers = config.DNS.AllServers
 	newconfig.FilterHandler = applyClientSettings
+	newconfig.OnDNSRequest = onDNSRequest
 	return newconfig
 }
 
diff --git a/dns_test.go b/dns_test.go
new file mode 100644
index 00000000..6623d1be
--- /dev/null
+++ b/dns_test.go
@@ -0,0 +1,9 @@
+package main
+
+import "testing"
+
+func TestResolveRDNS(t *testing.T) {
+	if r := resolveRDNS("1.1.1.1", "1.1.1.1"); r != "one.one.one.one" {
+		t.Errorf("resolveRDNS(): %s", r)
+	}
+}
diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go
index 34e6285c..bc8ed460 100644
--- a/dnsforward/dnsforward.go
+++ b/dnsforward/dnsforward.go
@@ -88,6 +88,7 @@ type ServerConfig struct {
 	Upstreams                []upstream.Upstream            // Configured upstreams
 	DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams
 	Filters                  []dnsfilter.Filter             // A list of filters to use
+	OnDNSRequest             func(d *proxy.DNSContext)
 
 	FilteringConfig
 	TLSConfig
@@ -324,6 +325,10 @@ func (s *Server) GetStatsHistory(timeUnit time.Duration, startTime time.Time, en
 func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
 	start := time.Now()
 
+	if s.conf.OnDNSRequest != nil {
+		s.conf.OnDNSRequest(d)
+	}
+
 	// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
 	res, err := s.filterDNSRequest(d)
 	if err != nil {